In [1]:
import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/notebooks/reports/"))
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
import plot.viz_sequence as viz_sequence
import motif.read_motifs as read_motifs
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager

In [2]:
# Plotting defaults
font_manager.fontManager.ttflist.extend(
    font_manager.createFontList(
        font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts")
    )
)
plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold",
    "svg.fonttype": "none"
}
plt.rcParams.update(plot_params)

The createFontList function was deprecated in Matplotlib 3.2 and will be removed two minor releases later. Use FontManager.addfont instead.
  after removing the cwd from sys.path.


In [3]:
if "TFM_TF_NAME" in os.environ:
    tf_name = os.environ["TFM_TF_NAME"]
else:
    tf_name = "MAX"

In [4]:
countreg_out_path = "/users/amtseng/tfmodisco/figures/all_tfmodisco_motifs/%s_all_tfmodisco_motifs_countreg" % tf_name
svm_out_path = "/users/amtseng/tfmodisco/figures/all_tfmodisco_motifs/%s_all_tfmodisco_motifs_svm" % tf_name
os.makedirs(countreg_out_path, exist_ok=True)
os.makedirs(svm_out_path, exist_ok=True)

tf_num_tasks = {
    "FOXA2": 4,
    "SPI1": 4,
    "MAX": 7,
}

num_tasks = tf_num_tasks[tf_name]

countreg_motif_file = "/users/amtseng/tfmodisco/results/motifs/tfmodisco/countreg/%s_countreg_tfmodisco_motifs.h5" % tf_name
svm_motif_file = "/users/amtseng/tfmodisco/results/motifs/tfmodisco/svm/%s_svm_tfmodisco_motifs.h5" % tf_name

### Helper functions

In [5]:
def renorm_motif(motif, pseudocount=1e-10):
    """
    Renormalizes a motif (L x 4 array) so that the bases sum to 1.
    """
    s = np.sum(motif, axis=1, keepdims=True)
    assert np.all(s > 0)
    return motif / s

In [6]:
def import_countreg_svm_tfmodisco_motifs(motif_file):
    """
    From a file containing all motifs for that TF, imports the
    trimmed PFMs, CWMs, and hCWMs for each task.
    Returns a list of dictionaries (one for each task), where
    each dictionary maps motif key to tuple of PFM, CWM, and hCWM.
    """
    motifs = []
    with h5py.File(motif_file, "r") as f:
        for i in range(num_tasks):
            task = "task_%d" % i
            dset = f[task]
            task_motifs = {}
            for motif_key in dset.keys():
                pfm, cwm, hcwm = \
                    dset[motif_key]["pfm_trimmed"][:], dset[motif_key]["cwm_trimmed"][:], dset[motif_key]["hcwm_trimmed"][:]

                pfm = renorm_motif(pfm) 
                if np.sum(cwm[:, [0, 2]]) < 0.5 * np.sum(cwm):
                    pfm, cwm, hcwm = np.flip(pfm), np.flip(cwm), np.flip(hcwm)

                task_motifs["T%d:%s" % (i, motif_key)] = (pfm, cwm, hcwm)
            motifs.append(task_motifs)
    return motifs

### Import and save motifs

In [7]:
countreg_tfm_motifs = import_countreg_svm_tfmodisco_motifs(countreg_motif_file)
svm_tfm_motifs = import_countreg_svm_tfmodisco_motifs(svm_motif_file)

In [8]:
for tfm_motifs, out_path in [(countreg_tfm_motifs, countreg_out_path), (svm_tfm_motifs, svm_out_path)]:
    for task_index, task_motif_dict in enumerate(tfm_motifs):
        for key, (pfm, cwm, hcwm) in task_motif_dict.items():
            fig = viz_sequence.plot_weights(
                read_motifs.pfm_to_pwm(pfm),
                subticks_frequency=100, figsize=(20, 4), return_fig=True
            )
            fig.tight_layout()
            plt.savefig(
                os.path.join(out_path, "%s_task%d_%s_pwm.svg" % (tf_name, task_index, key)),
                format="svg"
            )

            fig = viz_sequence.plot_weights(
                cwm, subticks_frequency=100, figsize=(20, 4), return_fig=True
            )
            fig.tight_layout()
            plt.savefig(
                os.path.join(out_path, "%s_task%d_%s_cwm.svg" % (tf_name, task_index, key)),
                format="svg"
            )

            fig = viz_sequence.plot_weights(
                hcwm, subticks_frequency=100, figsize=(20, 4), return_fig=True
            )
            fig.tight_layout()
            plt.savefig(
                os.path.join(out_path, "%s_task%d_%s_hcwm.svg" % (tf_name, task_index, key)),
                format="svg"
            )

            plt.close("all")