In [None]:
import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/notebooks/reports/"))
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
import plot.viz_sequence as viz_sequence
import motif.read_motifs as read_motifs
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager

In [None]:
# Plotting defaults
# Register every font file found under the custom font directory with
# Matplotlib's global font manager, so that font families stored there
# (presumably the "Roboto" family selected below; verify) can be resolved.
# `font_manager.createFontList` is deprecated as of Matplotlib 3.2 and absent
# from newer releases, so prefer `FontManager.addfont` whenever it exists and
# only fall back to the legacy call on older installations.
font_paths = font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts")
if hasattr(font_manager.fontManager, "addfont"):
    for font_path in font_paths:
        try:
            font_manager.fontManager.addfont(font_path)
        except (OSError, RuntimeError, ValueError, KeyError, UnicodeError, NotImplementedError) as err:
            # The legacy call skipped unreadable font files instead of
            # failing; keep that best-effort behaviour, but say so
            print("Skipping font %s: %s" % (font_path, err), file=sys.stderr)
else:
    font_manager.fontManager.ttflist.extend(
        font_manager.createFontList(font_paths)
    )

# Global Matplotlib settings applied to every figure created afterwards
plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold",
    # Keep text as text (rather than paths) in saved SVGs
    "svg.fonttype": "none"
}
plt.rcParams.update(plot_params)

In [None]:
# Name of the TF to process: taken from the TFM_TF_NAME environment variable
# when it is set, otherwise falling back to "E2F6"
tf_name = os.environ.get("TFM_TF_NAME", "E2F6")

In [None]:
# Number of tasks for each supported TF
tf_num_tasks = {
    "E2F6": 2,
    "FOXA2": 4,
    "SPI1": 4,
    "CEBPB": 7,
    "MAX": 7,
    "GABPA": 9,
    "MAFK": 9,
    "JUND": 14,
    "NR3C1-reddytime": 16,
    "REST": 20
}

# For each supported TF, one letter per task (in task order) naming the model
# type whose motifs are used for that task: "M" selects the multi-task
# fine-tuned model, any other letter ("S" here) the single-task one
tf_best_model_types = {
    "E2F6": list("MM"),
    "FOXA2": list("SSMM"),
    "SPI1": list("MSSS"),
    "CEBPB": list("MMMMSMM"),
    "MAX": list("MMSMMSS"),
    "GABPA": list("MMMSMMMMM"),
    "MAFK": list("MMMMMMMMM"),
    "JUND": list("SMMSMSSSSSSSMS"),
    "NR3C1-reddytime": list("MMMSMMSMMMMSMMMM"),
    "REST": list("MMMMMMMMMSMMSMMSMMMM")
}

# Validate the TF name before touching the filesystem, so that a mistyped
# name fails with a helpful message instead of first creating a stray
# output directory
if tf_name not in tf_num_tasks or tf_name not in tf_best_model_types:
    raise KeyError(
        "Unknown TF name %r; expected one of: %s" % (
            tf_name, ", ".join(sorted(tf_num_tasks))
        )
    )

num_tasks = tf_num_tasks[tf_name]
best_model_types = tf_best_model_types[tf_name]

# The two tables above are maintained by hand; fail loudly if they ever
# disagree about how many tasks a TF has
if len(best_model_types) != num_tasks:
    raise ValueError(
        "%s has %d tasks but %d model types" % (
            tf_name, num_tasks, len(best_model_types)
        )
    )

# Directory that receives the rendered motif figures for this TF
out_path = "/users/amtseng/tfmodisco/figures/all_tfmodisco_motifs/%s_all_tfmodisco_motifs" % tf_name
os.makedirs(out_path, exist_ok=True)

# HDF5 file holding all TF-MoDISco motifs for this TF
tfm_motif_file = "/users/amtseng/tfmodisco/results/motifs/tfmodisco/%s_tfmodisco_cpmerged_motifs.h5" % tf_name

### Helper functions

In [None]:
def renorm_motif(motif, pseudocount=1e-10):
    """
    Renormalizes a motif (L x 4 array) so that the bases sum to 1.
    Arguments:
        `motif`: L x 4 array-like of non-negative base weights
        `pseudocount`: value added to every entry before normalizing, so
            that a position whose entries are all 0 becomes uniform instead
            of triggering a division by zero; pass 0 for plain
            normalization
    Returns a new L x 4 NumPy array whose rows each sum to 1; the input is
    not modified. Raises an AssertionError if some position still has a
    non-positive total after the pseudocount is added.
    """
    # Adding the pseudocount allocates a fresh array, leaving the caller's
    # data untouched
    smoothed = np.asarray(motif) + pseudocount
    s = np.sum(smoothed, axis=1, keepdims=True)
    assert np.all(s > 0), "Every motif position must have a positive total"
    return smoothed / s

In [None]:
def import_tfmodisco_motifs(motif_file, model_types):
    """
    Reads the trimmed motifs of the fine-tuned models from an HDF5 file
    that holds all motifs for a single TF.
    Arguments:
        `motif_file`: path to an HDF5 file with top-level groups
            "multitask_finetune" and "singletask_finetune", each holding
            one "task_<i>" group of motifs per task
        `model_types`: one entry per task, in task order; "M" takes that
            task's motifs from the multi-task group, anything else takes
            them from the single-task group
    Returns a list with one dictionary per task; each dictionary maps
    "T<task index>:<motif key>" to a tuple of (PFM, CWM, hCWM). The PFM
    is renormalized with `renorm_motif`.
    """
    all_task_motifs = []
    with h5py.File(motif_file, "r") as h5_file:
        # Both groups are looked up eagerly, whichever model types are used
        multitask_group = h5_file["multitask_finetune"]
        singletask_group = h5_file["singletask_finetune"]

        for task_index, model_type in enumerate(model_types):
            source_group = multitask_group if model_type == "M" else singletask_group
            task_group = source_group["task_%d" % task_index]

            per_task = {}
            for motif_key in task_group.keys():
                motif_group = task_group[motif_key]
                pfm = motif_group["pfm_trimmed"][:]
                cwm = motif_group["cwm_trimmed"][:]
                hcwm = motif_group["hcwm_trimmed"][:]

                pfm = renorm_motif(pfm)

                # Pick a consistent orientation: when columns 0 and 2 carry
                # less than half of the total CWM weight, reverse all three
                # arrays along every axis (np.flip without an axis).
                # NOTE(review): assumes columns are ordered A, C, G, T, in
                # which case this is the reverse complement, preferring the
                # purine-rich strand (TODO confirm)
                cols_0_2_weight = np.sum(cwm[:, [0, 2]])
                if cols_0_2_weight < 0.5 * np.sum(cwm):
                    pfm = np.flip(pfm)
                    cwm = np.flip(cwm)
                    hcwm = np.flip(hcwm)

                per_task["T%d:%s" % (task_index, motif_key)] = (pfm, cwm, hcwm)
            all_task_motifs.append(per_task)
    return all_task_motifs

### Import and save motifs

In [None]:
# Load, for every task, the motifs of the model type chosen for that task
tfm_motifs = import_tfmodisco_motifs(
    motif_file=tfm_motif_file, model_types=best_model_types
)

In [None]:
# For every motif of every task, render and save three sequence logos as SVG
# files in `out_path`: the PWM derived from the PFM, the CWM, and the hCWM
for task_index, task_motif_dict in enumerate(tfm_motifs):
    for key, (pfm, cwm, hcwm) in task_motif_dict.items():
        # (filename suffix, weights to plot), in the order they are saved
        logos = (
            ("pwm", read_motifs.pfm_to_pwm(pfm)),
            ("cwm", cwm),
            ("hcwm", hcwm)
        )
        for suffix, weights in logos:
            fig = viz_sequence.plot_weights(
                weights, subticks_frequency=100, figsize=(20, 4), return_fig=True
            )
            fig.tight_layout()
            # Save through the returned figure handle rather than pyplot's
            # implicit "current figure", so the right figure is always
            # written
            fig.savefig(
                os.path.join(
                    out_path,
                    "%s_task%d_%s_%s.svg" % (tf_name, task_index, key, suffix)
                ),
                format="svg"
            )
            # Release the figure right away; many are created in this loop
            plt.close(fig)