In [None]:
import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/notebooks/reports/"))
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
import motif.match_motifs as match_motifs
import motif.read_motifs as read_motifs
import plot.viz_sequence as viz_sequence
from util import figure_to_vdom_image, create_motif_similarity_matrix
import tempfile
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import vdom.helpers as vdomh
from IPython.display import display
import tqdm
tqdm.tqdm_notebook(range(1))

In [None]:
# Plotting defaults
# Register the extra fonts so that the "font.family" setting below can be
# resolved. `font_manager.createFontList` was deprecated in Matplotlib 3.2 and
# later removed, so prefer `FontManager.addfont` when it exists and only fall
# back to the legacy call on older Matplotlib versions.
extra_font_paths = font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts")
if hasattr(font_manager.fontManager, "addfont"):
    for extra_font_path in extra_font_paths:
        try:
            font_manager.fontManager.addfont(extra_font_path)
        except (OSError, RuntimeError, ValueError, KeyError, NotImplementedError):
            # A single unreadable font file should not abort the notebook;
            # skip it and keep registering the rest
            continue
else:
    font_manager.fontManager.ttflist.extend(
        font_manager.createFontList(extra_font_paths)
    )
plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold",
    "svg.fonttype": "none"
}
plt.rcParams.update(plot_params)

In [None]:
# The TF to analyze comes from the TFM_TF_NAME environment variable when it is
# set; otherwise fall back to E2F6
tf_name = os.environ.get("TFM_TF_NAME", "E2F6")

In [None]:
# Number of tasks for each supported TF
tf_num_tasks = {
    "E2F6": 2, "FOXA2": 4, "SPI1": 4, "CEBPB": 7, "MAX": 7,
    "GABPA": 9, "MAFK": 9, "JUND": 14, "NR3C1-reddytime": 16, "REST": 20
}

# Best model type for each task of each TF, one character per task: "M" picks
# the multi-task fine-tuned model and anything else (here "S") the single-task
# one when TF-MoDISco motifs are imported
best_model_type_codes = {
    "E2F6": "MM",
    "FOXA2": "SSMM",
    "SPI1": "MSSS",
    "CEBPB": "MMMMSMM",
    "MAX": "MMSMMSS",
    "GABPA": "MMMSMMMMM",
    "MAFK": "MMMMMMMMM",
    "JUND": "SMMSMSSSSSSSMS",
    "NR3C1-reddytime": "MMMSMMSMMMMSMMMM",
    "REST": "MMMMMMMMMSMMSMMSMMMM"
}
tf_best_model_types = {
    name: list(codes) for name, codes in best_model_type_codes.items()
}

# Motif files for the selected TF, one per motif-discovery method
tfm_motif_file = "/users/amtseng/tfmodisco/results/motifs/tfmodisco/%s_tfmodisco_cpmerged_motifs.h5" % (tf_name,)
meme_motif_file = "/users/amtseng/tfmodisco/results/motifs/meme/%s_meme_motifs.h5" % (tf_name,)
homer_motif_file = "/users/amtseng/tfmodisco/results/motifs/homer/%s_homer_motifs.h5" % (tf_name,)
dichipmunk_motif_file = "/users/amtseng/tfmodisco/results/motifs/dichipmunk/%s_dichipmunk_motifs.h5" % (tf_name,)

# Settings for the selected TF (a KeyError here means the TF is not configured)
best_model_types = tf_best_model_types[tf_name]
num_tasks = tf_num_tasks[tf_name]

In [None]:
# Path to the reference motif database (JASPAR 2020, per the file name); it is
# parsed below by `import_database_pfms`, which looks for "MOTIF" records
motif_database_path = "/users/amtseng/tfmodisco/data/processed/motif_databases/JASPAR2020_CORE_vertebrates_non-redundant_pfms_meme.txt"

### Helper functions

In [None]:
def purine_rich_motif(motif):
    """
    Returns the motif in its purine-rich orientation. If columns 0 and 2
    (presumably A and G, i.e. ACGT column order) hold less than half of
    the motif's total weight, the motif is reversed along both axes (the
    reverse complement under that column order); otherwise the input is
    returned unchanged.
    """
    purine_weight = np.sum(motif[:, [0, 2]])
    total_weight = np.sum(motif)
    needs_flip = purine_weight < 0.5 * total_weight
    return np.flip(motif, axis=(0, 1)) if needs_flip else motif

In [None]:
def renorm_motif(motif, pseudocount=1e-10):
    """
    Rescales each position (row) of an L x 4 motif so that its four base
    values sum to 1. Every row must have a strictly positive sum; this is
    enforced with an assertion.
    Note that `pseudocount` is accepted but is not used by this function.
    """
    row_totals = np.sum(motif, axis=1, keepdims=True)
    # A row summing to zero (or less) cannot be turned into a distribution
    assert np.all(row_totals > 0)
    normalized = motif / row_totals
    return normalized

In [None]:
def import_tfmodisco_motifs(motif_file, model_types, motif_type="cwm_trimmed"):
    """
    Reads the TF-MoDISco motifs of the fine-tuned models from an HDF5 file
    holding all motifs for a TF. `model_types` gives one entry per task:
    "M" reads that task from the "multitask_finetune" group, and anything
    else reads it from "singletask_finetune". Only motifs whose key
    contains "0_" are kept. `motif_type` names the dataset to read for
    each motif (trimmed CWMs by default); types starting with "pfm" are
    renormalized. All motifs are put in their purine-rich orientation.
    Returns a list of dictionaries (one per task, in task order), each
    mapping "T<task index>:<motif key>" to the motif array.
    """
    per_task_motifs = []
    with h5py.File(motif_file, "r") as h5_file:
        multitask_group = h5_file["multitask_finetune"]
        singletask_group = h5_file["singletask_finetune"]
        for task_index, model_type in enumerate(model_types):
            source_group = multitask_group if model_type == "M" else singletask_group
            task_group = source_group["task_%d" % task_index]
            kept_motifs = {}
            for motif_key in task_group.keys():
                if "0_" not in motif_key:
                    # Keep only motifs that are (or are constructed from)
                    # the positive metacluster
                    continue
                values = task_group[motif_key][motif_type][:]
                if motif_type.startswith("pfm"):
                    values = renorm_motif(values)
                kept_motifs["T%d:%s" % (task_index, motif_key)] = purine_rich_motif(values)
            per_task_motifs.append(kept_motifs)
    return per_task_motifs

In [None]:
def import_classic_benchmark_motifs(motif_file, mode):
    """
    From a file containing all motifs for that TF from a benchmark
    method, imports the PFMs of the motifs for each task.
    `mode` must be one of "dichipmunk", "homer", or "meme"; it determines
    which entry of each task group holds scores rather than a motif, and
    that entry is skipped. Every top-level key of the file other than
    "task_agg" is expected to be of the form "task_<integer>".
    Returns a list of dictionaries (one for each task, in increasing
    task order), where each dictionary maps "T<task>:<motif key>" to the
    renormalized PFM in its purine-rich orientation.
    Raises a ValueError if `mode` is not recognized.
    """
    score_keys = {
        "dichipmunk": "supporting_seqs",
        "homer": "log_enrichment",
        "meme": "evalue"
    }
    # Reject unknown modes up front; otherwise the score key would be
    # undefined and only fail later, deep inside the file-reading loop
    if mode not in score_keys:
        raise ValueError(
            "Unknown mode %r; expected one of: %s" % (mode, ", ".join(sorted(score_keys)))
        )
    score_key = score_keys[mode]

    motifs = []
    with h5py.File(motif_file, "r") as f:
        # Strip the "task_" prefix and sort numerically, so that task 10
        # comes after task 9
        tasks = sorted(int(key[5:]) for key in f.keys() if key != "task_agg")
        for i in tasks:
            dset = f["task_%d" % i]
            task_motifs = {}
            for motif_key in dset.keys():
                if motif_key == score_key:
                    continue
                task_motifs["T%d:%s" % (i, motif_key)] = purine_rich_motif(renorm_motif(dset[motif_key][:]))
            motifs.append(task_motifs)
    return motifs

In [None]:
def import_database_pfms(database_path):
    """
    Imports the database of PFMs by reading through the entire database and
    constructing a dictionary mapping motif IDs to NumPy arrays of PFMs.
    Each record starts at a line beginning with "MOTIF"; the next line is
    the matrix header, whose 6th whitespace-separated field is read as the
    motif width, followed by that many rows of 4 values. Each PFM is
    renormalized and put in its purine-rich orientation.
    The dictionary key is the part of the motif identifier between the
    first and second underscore (e.g. "X_NAME" becomes "NAME"); identifiers
    without an underscore are used whole. If two motifs share a key, the
    later one overwrites the earlier one. A file that ends in the middle of
    a record silently drops that incomplete record.
    """
    motif_dict = {}
    with open(database_path, "r") as f:
        lines = iter(f)
        try:
            for line in lines:
                if not line.startswith("MOTIF"):
                    continue
                key = line.strip().split()[1]
                header = next(lines)
                motif_width = int(header.split()[5])
                motif = np.empty((motif_width, 4))
                for i in range(motif_width):
                    motif[i] = np.array([
                        float(x) for x in next(lines).strip().split()
                    ])
                # Add the motif with a shortened key; standard MEME
                # identifiers may have no underscore, so fall back to the
                # whole identifier instead of failing with an IndexError
                key_parts = key.split("_")
                short_key = key_parts[1] if len(key_parts) > 1 else key
                motif_dict[short_key] = purine_rich_motif(renorm_motif(motif))
        except StopIteration:
            # The file ended part-way through a record; keep what was read
            pass
    return motif_dict

In [None]:
def get_closest_tomtom_motif_similarities(query_dict, target_dict):
    """
    From a dictionary mapping N motif keys to query motifs, and a
    dictionary mapping M motif keys to target motifs, returns a
    dictionary mapping the N query motif keys to the similarity and
    key of the closest target motif (a pair). Similarity is the
    -log10(p) from TOMTOM, and the closest target is the one with the
    smallest p-value. A query motif with no TOMTOM result maps to
    (0, "N/A").
    If `query_dict` is empty, an empty dictionary is returned; if
    `target_dict` is empty, every query maps to (0, "N/A"). TOMTOM is not
    run in either case.
    """
    # Nothing to compare: avoid unpacking an empty zip below
    if not query_dict:
        return {}
    if not target_dict:
        return {query_key: (0, "N/A") for query_key in query_dict}

    query_keys, query_pfms = zip(*query_dict.items())
    target_keys, target_pfms = zip(*target_dict.items())

    # Do all work in a temporary directory; the context manager removes it
    # even if exporting the motifs or running TOMTOM raises
    with tempfile.TemporaryDirectory() as temp_dir:
        # Convert motifs to MEME format
        query_motif_file = os.path.join(temp_dir, "query_motifs.txt")
        target_motif_file = os.path.join(temp_dir, "target_motifs.txt")
        match_motifs.export_pfms_to_meme_format(query_pfms, query_motif_file)
        match_motifs.export_pfms_to_meme_format(target_pfms, target_motif_file)

        # Run TOMTOM
        tomtom_dir = os.path.join(temp_dir, "tomtom")
        match_motifs.run_tomtom(
            query_motif_file, target_motif_file, tomtom_dir,
            show_output=False
        )

        # Find results, mapping each query motif to target index
        # The query/target IDs are the indices
        tomtom_table = match_motifs.import_tomtom_results(tomtom_dir)
        matches = []
        for i in range(len(query_pfms)):
            rows = tomtom_table[tomtom_table["Query_ID"] == i]
            if rows.empty:
                matches.append((0, "N/A"))
                continue
            min_row = rows.loc[rows["p-value"].idxmin()]
            score = -np.log10(min_row["p-value"])
            # Cast to a plain int so that the tuple lookup works regardless
            # of the numeric type the table stores the ID as
            target_key = target_keys[int(min_row["Target_ID"])]
            matches.append((score, target_key))

    return dict(zip(query_keys, matches))

### Import motifs

In [None]:
# TF-MoDISco motifs: CWMs and PFMs, both read with the same per-task model choices
tfm_motif_cwms = import_tfmodisco_motifs(
    tfm_motif_file, best_model_types, motif_type="cwm_trimmed"
)
tfm_motif_pfms = import_tfmodisco_motifs(
    tfm_motif_file, best_model_types, motif_type="pfm_trimmed"
)

# Benchmark motifs, one list of per-task dictionaries per method
meme_motifs = import_classic_benchmark_motifs(meme_motif_file, mode="meme")
homer_motifs = import_classic_benchmark_motifs(homer_motif_file, mode="homer")
dichipmunk_motifs = import_classic_benchmark_motifs(dichipmunk_motif_file, mode="dichipmunk")

# Reference database motifs
database_motifs = import_database_pfms(motif_database_path)

### Compute similarity of TF-MoDISco motifs to each benchmark motif
For each benchmark motif, compute the closest TF-MoDISco motif to it.

In [None]:
# Compute TF-MoDISco similarity dictionaries by task: each benchmark motif is
# matched against the TF-MoDISco PFMs of the same task index
meme_tfm_sims = [
    get_closest_tomtom_motif_similarities(task_motifs, tfm_motif_pfms[task_index])
    for task_index, task_motifs in enumerate(meme_motifs)
]
homer_tfm_sims = [
    get_closest_tomtom_motif_similarities(task_motifs, tfm_motif_pfms[task_index])
    for task_index, task_motifs in enumerate(homer_motifs)
]
dichipmunk_tfm_sims = [
    get_closest_tomtom_motif_similarities(task_motifs, tfm_motif_pfms[task_index])
    for task_index, task_motifs in enumerate(dichipmunk_motifs)
]

In [None]:
# Compute database similarity dictionaries by task: each benchmark motif is
# matched against the full set of database motifs
meme_database_sims = [
    get_closest_tomtom_motif_similarities(task_motifs, database_motifs)
    for task_motifs in meme_motifs
]
homer_database_sims = [
    get_closest_tomtom_motif_similarities(task_motifs, database_motifs)
    for task_motifs in homer_motifs
]
dichipmunk_database_sims = [
    get_closest_tomtom_motif_similarities(task_motifs, database_motifs)
    for task_motifs in dichipmunk_motifs
]

### Construct plots of benchmark motif database similarity vs TF-MoDISco similarity

In [None]:
# For each task, similarity of motifs to database vs TF-MoDISco
for task_index in range(len(tfm_motif_pfms)):

    fig, ax = plt.subplots(figsize=(8, 5))

    # One scatter series per benchmark, always drawn in the same order so
    # each method gets the same colour from the default cycle in every figure
    for bench_label, bench_motifs, bench_tfm_sims, bench_database_sims in [
        ("MEME", meme_motifs, meme_tfm_sims, meme_database_sims),
        ("HOMER", homer_motifs, homer_tfm_sims, homer_database_sims),
        ("DiChIPMunk", dichipmunk_motifs, dichipmunk_tfm_sims, dichipmunk_database_sims),
    ]:
        bench_keys = bench_motifs[task_index].keys()
        ax.scatter(
            [bench_tfm_sims[task_index][k][0] for k in bench_keys],
            [bench_database_sims[task_index][k][0] for k in bench_keys],
            label=bench_label
        )

    # Give both axes the same range and mark the y = x diagonal, so points
    # can be read as "closer to a database motif" vs "closer to TF-MoDISco"
    x_range, y_range = ax.get_xlim(), ax.get_ylim()
    lower = min(x_range[0], y_range[0])
    upper = max(x_range[1], y_range[1])
    ax.set_xlim(lower, upper)
    ax.set_ylim(lower, upper)
    ax.plot(
        [lower, upper], [lower, upper],
        color="black", linestyle="--", alpha=0.3, zorder=0
    )

    ax.legend()
    ax.set_title("Motif benchmark reproducibility: task %d" % task_index)
    ax.set_xlabel("Maximum similarity to a TF-MoDISco motif")
    ax.set_ylabel("Maximum similarity to a JASPAR motif")

In [None]:
# For each task, show all motifs in order of similarity ratio

# Column widths are percentages of the table width. A CSS length needs a unit,
# so bare values such as "30" are invalid and ignored; the percent sign makes
# them take effect (and matches the later table in this notebook).
colgroup = vdomh.colgroup(
    vdomh.col(style={"width": "30%"}),
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "20%"}),
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "20%"}),
)
# Columns: benchmark motif, then (key, similarity, motif) for the closest
# TF-MoDISco motif, then (key, similarity, motif) for the closest database motif
header = vdomh.thead(
    vdomh.th("Motif", style={"text-align": "center"}),
    vdomh.th("Key", style={"text-align": "center"}),
    vdomh.th("Similarity", style={"text-align": "center"}),
    vdomh.th("Motif", style={"text-align": "center"}),
    vdomh.th("Key", style={"text-align": "center"}),
    vdomh.th("Similarity", style={"text-align": "center"}),
    vdomh.th("Motif", style={"text-align": "center"}),
)


def safe_div(a, b):
    """
    Returns a / b, or positive infinity when b is zero (or otherwise falsy),
    regardless of the value of a.
    """
    return a / b if b else float("inf")

for task_index in range(len(tfm_motif_pfms)):
    display(vdomh.h3("Task %d" % task_index))

    # One table per benchmark method
    for bench_type, motif_dict, tfm_sim_dict, database_sim_dict in [
        ("MEME", meme_motifs[task_index], meme_tfm_sims[task_index], meme_database_sims[task_index]),
        ("HOMER", homer_motifs[task_index], homer_tfm_sims[task_index], homer_database_sims[task_index]),
        ("DiChIPMunk", dichipmunk_motifs[task_index], dichipmunk_tfm_sims[task_index], dichipmunk_database_sims[task_index])
    ]:
        # Rank motif keys by decreasing similarity ratio of TF-MoDISco / database
        key_list = sorted(
            database_sim_dict.keys(),
            key=lambda k: -safe_div(tfm_sim_dict[k][0], database_sim_dict[k][0])
        )

        # First row labels the three column groups of the table
        rows = [vdomh.tr(
            vdomh.td(vdomh.b(bench_type), colspan="1", style={"text-align": "center"}),
            vdomh.td(vdomh.b("TF-MoDISco"), colspan="3", style={"text-align": "center"}),
            vdomh.td(vdomh.b("Database"), colspan="3", style={"text-align": "center"}),
        )]

        for motif_key in key_list:
            tfm_score, tfm_match_key = tfm_sim_dict[motif_key]
            database_score, database_match_key = database_sim_dict[motif_key]

            # The benchmark motif itself
            bench_fig = viz_sequence.plot_weights(
                read_motifs.pfm_to_pwm(motif_dict[motif_key]),
                figsize=(20, 4), return_fig=True
            )
            bench_fig.tight_layout()

            # Closest TF-MoDISco motif (shown as its CWM), if there was a match
            tfm_fig_cell = "N/A"
            if tfm_match_key != "N/A":
                tfm_fig = viz_sequence.plot_weights(
                    tfm_motif_cwms[task_index][tfm_match_key],
                    figsize=(20, 4), return_fig=True
                )
                tfm_fig.tight_layout()
                tfm_fig_cell = figure_to_vdom_image(tfm_fig)

            # Closest database motif, if there was a match
            database_fig_cell = "N/A"
            if database_match_key != "N/A":
                database_fig = viz_sequence.plot_weights(
                    read_motifs.pfm_to_pwm(database_motifs[database_match_key]),
                    figsize=(20, 4), return_fig=True
                )
                database_fig.tight_layout()
                database_fig_cell = figure_to_vdom_image(database_fig)

            rows.append(vdomh.tr(
                vdomh.td(figure_to_vdom_image(bench_fig)),
                vdomh.td(tfm_match_key),
                vdomh.td("%.3f" % tfm_score),
                vdomh.td(tfm_fig_cell),
                vdomh.td(database_match_key),
                vdomh.td("%.3f" % database_score),
                vdomh.td(database_fig_cell)
            ))

        display(vdomh.table(colgroup, header, vdomh.tbody(*rows)))
        # The figures are already embedded in the table; free them
        plt.close("all")

### Construct plots of benchmark motif rank and TF-MoDISco similarity

In [None]:
# For each task, plot (cumulative) distribution of similarities
for task_index in range(len(tfm_motif_pfms)):

    fig, ax = plt.subplots(figsize=(8, 5))

    # One line per benchmark: similarities to TF-MoDISco in increasing order,
    # plotted against their rank
    for bench_label, bench_motifs, bench_tfm_sims in [
        ("MEME", meme_motifs, meme_tfm_sims),
        ("HOMER", homer_motifs, homer_tfm_sims),
        ("DiChIPMunk", dichipmunk_motifs, dichipmunk_tfm_sims),
    ]:
        bench_keys = bench_motifs[task_index].keys()
        task_sims = bench_tfm_sims[task_index]
        ax.plot(
            np.arange(len(task_sims)),
            np.sort([task_sims[k][0] for k in bench_keys]),
            label=bench_label
        )

    ax.legend()
    ax.set_title("Motif benchmark reproducibility: task %d" % task_index)
    ax.set_xlabel("Motif rank (by max similarity to TF-MoDISco motif)")
    ax.set_ylabel("Maximum similarity to a TF-MoDISco motif")

In [None]:
# For each task, show the most and least reproducible motifs for each benchmark
num_to_show = 5

# Three identical column groups (similarity, key, motif), one per benchmark
cols, heads = [], []
for _ in range(3):
    cols.extend([
        vdomh.col(style={"width": "%.2f%%" % (10 / 3)}),
        vdomh.col(style={"width": "%.2f%%" % (10 / 3)}),
        vdomh.col(style={"width": "%.2f%%" % (80 / 3)}),
    ])
    heads.extend([
        vdomh.th("TF-MoDISco similarity", style={"text-align": "center"}),
        vdomh.th("TF-MoDISco key", style={"text-align": "center"}),
        vdomh.th("Motif", style={"text-align": "center"}),
    ])
colgroup = vdomh.colgroup(*cols)
header = vdomh.thead(heads)
subheader = vdomh.tr(
    vdomh.td(vdomh.b("MEME"), colspan="3", style={"text-align": "center"}),
    vdomh.td(vdomh.b("HOMER"), colspan="3", style={"text-align": "center"}),
    vdomh.td(vdomh.b("DiChIPMunk"), colspan="3", style={"text-align": "center"}),
)

for task_index in range(len(tfm_motif_pfms)):
    display(vdomh.h3("Task %d" % task_index))

    # Rank motif keys by similarity (most similar to TF-MoDISco first)
    ranked_benchmarks = []
    for motif_dict, sim_dict, database_sim_dict in [
        (meme_motifs[task_index], meme_tfm_sims[task_index], meme_database_sims[task_index]),
        (homer_motifs[task_index], homer_tfm_sims[task_index], homer_database_sims[task_index]),
        (dichipmunk_motifs[task_index], dichipmunk_tfm_sims[task_index], dichipmunk_database_sims[task_index])
    ]:
        key_list = sorted(
            database_sim_dict.keys(),
            key=lambda k: -sim_dict[k][0]
        )
        ranked_benchmarks.append((key_list, motif_dict, sim_dict))
    longest = max([len(key_list) for key_list, _, _ in ranked_benchmarks])

    # The two tables differ only in which end of the ranking they read from.
    # The second one skips the first `num_to_show` ranks so that it does not
    # repeat motifs that have already been shown in the first table.
    for table_title, num_skipped, pick_key in [
        ("Most reproducible motifs", 0, lambda keys, i: keys[i]),
        ("Least reproducible motifs", num_to_show, lambda keys, i: keys[len(keys) - i - 1])
    ]:
        display(vdomh.h4(table_title))

        rows = [subheader]
        for i in range(min(num_to_show, longest - num_skipped)):
            row = []
            for key_list, motif_dict, sim_dict in ranked_benchmarks:
                if i < len(key_list) - num_skipped:
                    motif_key = pick_key(key_list, i)
                    fig = viz_sequence.plot_weights(
                        read_motifs.pfm_to_pwm(motif_dict[motif_key]),
                        figsize=(20, 4), return_fig=True
                    )
                    fig.tight_layout()
                    row.extend([
                        vdomh.td("%.3f" % sim_dict[motif_key][0]),
                        vdomh.td(sim_dict[motif_key][1]),
                        vdomh.td(figure_to_vdom_image(fig))
                    ])
                else:
                    # This benchmark has no motif left for this row
                    row.extend([vdomh.td(), vdomh.td(), vdomh.td()])
            rows.append(vdomh.tr(*row))

        display(vdomh.table(colgroup, header, vdomh.tbody(*rows)))
        plt.close("all")