In [None]:
import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
import plot.viz_sequence as viz_sequence
from feature.util import one_hot_to_seq
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager

In [None]:
# Plotting defaults
# Register the fonts found under the custom font directory so that the
# "Roboto" family requested below can be resolved.
# `font_manager.createFontList` was deprecated in Matplotlib 3.2 and removed in
# 3.3, so prefer `FontManager.addfont` and only fall back to the legacy call on
# Matplotlib versions that do not provide it.
font_paths = font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts")
if hasattr(font_manager.fontManager, "addfont"):
    for font_path in font_paths:
        try:
            font_manager.fontManager.addfont(font_path)
        except Exception as exc:
            # Skip an unreadable font file rather than aborting the notebook
            print("Warning: could not load font %s: %s" % (font_path, exc))
else:
    font_manager.fontManager.ttflist.extend(
        font_manager.createFontList(font_paths)
    )

# Global Matplotlib settings applied to every figure in this notebook
plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold",
    "svg.fonttype": "none"  # Keep text as text (not paths) in saved SVGs
}
plt.rcParams.update(plot_params)

### Define constants and paths

In [None]:
# Directory under which this notebook saves its figures
out_path = "/users/amtseng/tfmodisco/figures/profile_imp_score_example"
# Create the directory (and any missing parents); no-op if it already exists
os.makedirs(out_path, exist_ok=True)

In [None]:
# Number of tasks for each TF; each count equals the length of that TF's entry
# in `tf_best_model_types` below
tf_num_tasks = dict([
    ("E2F6", 2),
    ("FOXA2", 4),
    ("SPI1", 4),
    ("CEBPB", 7),
    ("MAX", 7),
    ("GABPA", 9),
    ("MAFK", 9),
    ("JUND", 14),
    ("NR3C1-reddytime", 16),
    ("REST", 20),
])

# For each TF, a list with one model-type code per task index: "M" selects the
# multi-task model and "S" the single-task model (presumably whichever one
# performed best for that task -- confirm against how these were chosen)
tf_best_model_types = {
    name: list(type_codes)
    for name, type_codes in (
        ("E2F6", "MM"),
        ("FOXA2", "SSMM"),
        ("SPI1", "MSSS"),
        ("CEBPB", "MMMMSMM"),
        ("MAX", "MMSMMSS"),
        ("GABPA", "MMMSMMMMM"),
        ("MAFK", "MMMMMMMMM"),
        ("JUND", "SMMSMSSSSSSSMS"),
        ("NR3C1-reddytime", "MMMSMMSMMMMSMMMM"),
        ("REST", "MMMMMMMMMSMMSMMSMMMM"),
    )
}

### Helper functions

In [None]:
def get_predictions_impscores_path(tf_name, model_type, task_index):
    """
    Finds the predictions HDF5 and the importance scores HDF5 for a TF/task.
    Arguments:
        `tf_name`: name of the TF; matching files start with `tf_name` + "_"
        `model_type`: "M" to search the multi-task results directories, or "S"
            to search the single-task ones
        `task_index`: index of the task; the file name must contain
            "task<index>_", except for multi-task predictions, where the task
            is not checked
    Returns a pair of paths (predictions, importance scores). Either entry is
    None if no matching ".h5" file exists. Fails an assertion if `model_type` is
    not "M"/"S", or if more than one file matches in either search.
    """
    assert model_type in ("M", "S")

    if model_type == "M":
        model_subdir = "multitask_profile_finetune"
    else:
        model_subdir = "singletask_profile_finetune"
    preds_root = os.path.join(
        "/users/amtseng/tfmodisco/results/peak_predictions", model_subdir
    )
    scores_root = os.path.join(
        "/users/amtseng/tfmodisco/results/importance_scores", model_subdir
    )

    def is_match(filename, ignore_task):
        # Checks run in order: TF prefix, then task tag (unless ignored), then
        # extension; the task tag is only formatted if the prefix matched
        if not filename.startswith(tf_name + "_"):
            return False
        if not ignore_task and "task%d_" % task_index not in filename:
            return False
        return filename.endswith(".h5")

    def find_unique(root, ignore_task):
        # Walk the whole tree under `root`, insisting on at most one match
        found = None
        for dirpath, _, filenames in os.walk(root):
            for filename in filenames:
                if is_match(filename, ignore_task):
                    assert found is None
                    found = os.path.join(dirpath, filename)
        return found

    # Multi-task predictions are matched without checking the task index;
    # importance scores always require the task tag in the file name
    preds_path = find_unique(preds_root, model_type == "M")
    scores_path = find_unique(scores_root, False)
    return preds_path, scores_path

### Show an example of motif hits

In [None]:
def plot_example_hits(
    chrom, start, end, profiles_hdf5_path, imp_scores_hdf5_path, task_index, prof_center_size=700,
    score_center_size=150, hyp_score_key="profile_hyp_scores", save_path=None
):
    """
    For a given region, plots the true/predicted profiles and importance scores.
    Arguments:
        `chrom`, `start`, `end`: coordinate of the region; both plotted windows
            are centered on the midpoint of `start` and `end`
        `profiles_hdf5_path`: path to an HDF5 containing
            "predictions/log_pred_profs", "predictions/true_profs", and
            "coords/coords_chrom", "coords/coords_start", "coords/coords_end";
            the profile datasets are indexed as [example][task][position], and
            tracks 0 and 1 of the last axis are drawn above and below the
            x-axis (presumably the two strands -- TODO confirm)
        `imp_scores_hdf5_path`: path to an HDF5 containing "coords_chrom",
            "coords_start", "coords_end", "input_seqs", and `hyp_score_key`;
            position 0 of each example is taken to be its "coords_start"
        `task_index`: task whose profiles are plotted; forced to 0 if the
            profile file holds only a single task
        `prof_center_size`: width of the window of profiles to plot
        `score_center_size`: width of the window of importance scores to plot
        `hyp_score_key`: name of the hypothetical score dataset to use
        `save_path`: if given, path prefix for the saved figures; writes
            `save_path` + "_profiles.svg" and `save_path` + "_impscores.svg"
    Shows both figures and prints the sequence under the importance score
    window. If either file has no entry covering the requested window, prints a
    warning and returns without plotting anything.
    """
    mid = (start + end) // 2
    prof_start = mid - (prof_center_size // 2)
    prof_end = prof_start + prof_center_size
    score_start = mid - (score_center_size // 2)
    score_end = score_start + score_center_size
    
    with h5py.File(profiles_hdf5_path, "r") as f:
        log_pred_profs = f["predictions"]["log_pred_profs"]
        # Need to use the coordinates of the profiles themselves: re-center
        # each stored coordinate to the length of the profile tracks
        prof_len = log_pred_profs.shape[2]
        prof_coords_chrom = f["coords"]["coords_chrom"][:].astype(str)
        prof_coords_mid = (
            f["coords"]["coords_start"][:] + f["coords"]["coords_end"][:]
        ) // 2
        prof_coords_start = prof_coords_mid - (prof_len // 2)
        prof_coords_end = prof_coords_start + prof_len
        match_inds = np.where(
            (prof_coords_chrom == chrom) &
            (prof_coords_start <= prof_start) &
            (prof_coords_end >= prof_end)
        )[0]
        if not match_inds.size:
            print("Warning: did not find sufficiently large prediction track for %s:%d-%s" % (chrom, prof_start, prof_end))
            return
        
        # Several tracks may cover the window; use the first one
        match_ind = match_inds[0]
        
        # Offset of the requested window within the matched track
        cut_start = prof_start - prof_coords_start[match_ind]
        cut_end = cut_start + prof_center_size
        
        # A file with only one task has nothing to select by task index
        if log_pred_profs.shape[1] == 1:
            task_index = 0
        pred_profs = np.exp(log_pred_profs[match_ind][task_index][cut_start:cut_end])
        true_profs = f["predictions"]["true_profs"][match_ind][task_index][cut_start:cut_end]
    with h5py.File(imp_scores_hdf5_path, "r") as f:
        match_inds = np.where(
            (f["coords_chrom"][:].astype(str) == chrom) &
            (f["coords_start"][:] <= score_start) &
            (f["coords_end"][:] >= score_end)
        )[0]
        if not match_inds.size:
            print("Warning: did not find sufficiently large importance score track for %s:%d-%s" % (chrom, score_start, score_end))
            return
        
        match_ind = match_inds[0]
        
        # Offset of the requested window within the matched score track
        cut_start = score_start - f["coords_start"][match_ind]
        cut_end = cut_start + score_center_size
        hyp_scores = f[hyp_score_key][match_ind][cut_start:cut_end]
        one_hot_seq = f["input_seqs"][match_ind][cut_start:cut_end]
    
    prof_fig, ax = plt.subplots(nrows=2, sharex=True, figsize=(20, 8))
    # Draw profiles; the second track is negated so it is drawn below the axis
    ax[0].plot(true_profs[:, 0], color="darkslateblue")
    ax[0].plot(-true_profs[:, 1], color="darkorange")
    ax[0].set_title("True ChIP-seq profiles")
    ax[1].plot(pred_profs[:, 0], color="darkslateblue")
    ax[1].plot(-pred_profs[:, 1], color="darkorange")
    ax[1].set_title("Predicted ChIP-seq profiles")
    
    # Draw vertical lines that denote the portion with importance scores
    for i in range(2):
        ax[i].axvline(score_start - prof_start, color="gray")
        ax[i].axvline(score_end - prof_start, color="gray")
    if save_path:
        # Save through the figure handle rather than pyplot's implicit
        # "current figure", so the right figure is always the one written
        prof_fig.savefig(save_path + "_profiles.svg", format="svg")
    plt.show()
    # Release the figure; this function is typically called many times in a loop
    plt.close(prof_fig)

    score_fig = viz_sequence.plot_weights(
        hyp_scores * one_hot_seq, figsize=(20, 4), subticks_frequency=score_center_size, return_fig=True
    )
    score_fig.tight_layout()
    if save_path:
        score_fig.savefig(save_path + "_impscores.svg", format="svg")
    plt.show()
    plt.close(score_fig)
    print(one_hot_to_seq(one_hot_seq))

In [None]:
# Pick the TF/task to show, and use the model type listed for that task
tf_name = "MAX"
task_index = 0
model_type = tf_best_model_types[tf_name][task_index]

profiles_hdf5_path, imp_scores_hdf5_path = get_predictions_impscores_path(tf_name, model_type, task_index)

# Collect (chrom, start, end) for every entry in the importance score file
with h5py.File(imp_scores_hdf5_path, "r") as f:
    num_coords = len(f["coords_chrom"])
    coords = np.empty((num_coords, 3), dtype=object)
    coords[:, 0] = f["coords_chrom"][:].astype(str)
    coords[:, 1] = f["coords_start"][:]
    coords[:, 2] = f["coords_end"][:]

# Draw a fixed-seed random sample of entries (without replacement) to plot
num_to_take = 20
seed = 20211203
rng = np.random.RandomState(seed)
sampled_peak_inds = rng.choice(len(coords), size=num_to_take, replace=False)

# Only these entries have their figures written to disk; the rest are just
# shown. The indices refer to rows of this particular importance score file.
peak_inds_to_save = (9735, 43745)

for peak_ind in sampled_peak_inds:
    chrom, start, end = coords[peak_ind]

    print("%s:%d-%d (index %d)" % (chrom, start, end, peak_ind))

    if peak_ind in peak_inds_to_save:
        save_path = os.path.join(out_path, "motif_hit_example_%s_task%d_peak%d" % (tf_name, task_index, peak_ind))
    else:
        # None is the default of `plot_example_hits`, i.e. do not save
        save_path = None
    plot_example_hits(
        chrom, start, end, profiles_hdf5_path, imp_scores_hdf5_path, task_index, save_path=save_path
    )
    print("")