In [None]:
import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/notebooks/reports/"))
from util import figure_to_vdom_image
import plot.viz_sequence as viz_sequence
import numpy as np
import h5py
import matplotlib.pyplot as plt
import vdom.helpers as vdomh
from IPython.display import display

### Define constants and paths

In [None]:
# Define parameters/fetch arguments

tf_name = "REST"
model_type = "MTFT"  # "STFT" (single-task) or "MTFT" (multitask) fine-tuned model
fold_num = 7
task_index = None  # For MTFT, None omits the per-task component of the path; required for STFT
head = "count"  # Which model head's motifs to load: "profile" or "count"

assert model_type in ("STFT", "MTFT")
assert head in ("profile", "count")

# Root directory shared by both single-task and multitask motif caches
cache_dir = "/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/"

if model_type == "STFT":
    # Single-task paths embed the task index with "%d", so fail early with a
    # clear message instead of a TypeError from formatting None
    assert task_index is not None, "task_index must be set when model_type is STFT"
    motif_file = os.path.join(
        cache_dir,
        "singletask_profile_finetune",
        "%s_singletask_profile_finetune_fold%d" % (tf_name, fold_num),
        "task_%d" % task_index,
        "%s_singletask_profile_finetune_task%d_fold%d_%s" % (tf_name, task_index, fold_num, head),
        "all_motifs.h5"
    )
else:
    # Multitask runs share one fold directory; the inner directory name only
    # includes the task when a specific task is requested
    if task_index is None:
        run_name = "%s_multitask_profile_finetune_fold%d_%s" % (tf_name, fold_num, head)
    else:
        run_name = "%s_multitask_profile_finetune_task%d_fold%d_%s" % (tf_name, task_index, fold_num, head)
    motif_file = os.path.join(
        cache_dir,
        "multitask_profile_finetune",
        "%s_multitask_profile_finetune_fold%d" % (tf_name, fold_num),
        run_name,
        "all_motifs.h5"
    )

print("Motif file: %s" % motif_file)

### Helper functions
For loading the saved motifs from their HDF5 cache file

In [None]:
def import_motifs(motif_file):
    """
    Loads the trimmed CWM of every motif stored in an HDF5 file.

    Each top-level key of the file is taken as a motif name, and the
    `cwm_trimmed` dataset under that key is read fully into memory
    (presumably an L x 4 array).

    Arguments:
        `motif_file`: path to the HDF5 file of saved motifs
    Returns a pair of parallel lists `(motifs, motif_names)`, ordered by the
    file's key iteration order.
    """
    with h5py.File(motif_file, "r") as handle:
        names = list(handle.keys())
        cwms = [handle[name]["cwm_trimmed"][:] for name in names]
    return cwms, names

### Show motifs

In [None]:
# Read every saved motif's trimmed CWM, keeping each one's HDF5 key as its name
motifs, motif_names = import_motifs(motif_file=motif_file)

In [None]:
# Flip all motifs to be the purine-rich version.
# Assumes CWM columns are ordered A, C, G, T: flipping over both axes then
# reverses the positions and swaps A<->T and C<->G, i.e. reverse-complements.
def _purine_rich_orientation(cwm):
    """
    Returns `cwm` unchanged if at least half of its total absolute weight is
    on purines (columns 0 and 2, i.e. A and G), otherwise its reverse
    complement. Absolute values are compared so that motifs made of negative
    contributions are oriented by magnitude, like positive ones; a raw-sum
    comparison inverts the test when the total is negative. For non-negative
    CWMs the result is identical to comparing raw sums.
    """
    magnitude = np.abs(cwm)
    if np.sum(magnitude[:, [0, 2]]) < 0.5 * np.sum(magnitude):
        return np.flip(cwm)
    return cwm

motifs = [_purine_rich_orientation(motif) for motif in motifs]

In [None]:
# Show the CWM of every motif in a table, ordered numerically by the first two
# "_"-separated integer fields of its key (so "0_10" sorts after "0_2")
colgroup = vdomh.colgroup(
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "95%"})
)

header = vdomh.thead(
    vdomh.tr(
        vdomh.th("Motif key", style={"text-align": "center"}),
        vdomh.th("CWM", style={"text-align": "center"})
    )
)

def _motif_sort_key(motif_key):
    """Numeric sort key from the first two "_"-separated fields of a motif key."""
    fields = motif_key.split("_")
    return int(fields[0]), int(fields[1])

rows = []
# `sorted` is stable, so motifs with equal keys keep their loaded order
for motif_key, motif in sorted(zip(motif_names, motifs), key=lambda pair: _motif_sort_key(pair[0])):
    fig = viz_sequence.plot_weights(motif, figsize=(20, 4), return_fig=True)
    fig.tight_layout()
    rows.append(
        vdomh.tr(
            vdomh.td(motif_key),
            vdomh.td(figure_to_vdom_image(fig))
        )
    )
    # The figure has been converted for the table; close it now rather than
    # keeping every figure open until the loop finishes
    plt.close(fig)

display(vdomh.table(colgroup, header, vdomh.tbody(*rows)))
plt.close("all")