### Link to results
[Results](#results)

In [None]:
import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
from util import figure_to_vdom_image
import plot.viz_sequence as viz_sequence
import numpy as np
import h5py
import matplotlib.pyplot as plt
import vdom.helpers as vdomh
from IPython.display import display

### Define constants and paths

In [None]:
# Paths of the input motif HDF5 file and of the consolidated output file;
# both are supplied through environment variables
in_motif_file, out_motif_file = (
    os.environ[var_name] for var_name in ("TFM_IN_MOTIF_FILE", "TFM_OUT_MOTIF_FILE")
)

for message in (
    "Input motif file: %s" % in_motif_file,
    "Output motif file: %s" % out_motif_file,
):
    print(message)

In [None]:
# Minimum similarity (from `motif_similarity_score`) a count/profile motif
# pair must reach to be merged; weaker pairs are kept as separate motifs
min_match_sim = 0.90

In [None]:
# Create the output file's parent directory if needed. A bare filename has an
# empty dirname, and `os.makedirs("")` raises, so only create a non-empty one
out_motif_dir = os.path.dirname(out_motif_file)
if out_motif_dir:
    os.makedirs(out_motif_dir, exist_ok=True)

### Helper functions

In [None]:
def import_motifs_from_hdf5_group(h5_group):
    """
    Reads the trimmed CWM of every motif under the "count" and "profile"
    subgroups of an open HDF5 group laid out as:
        count:
            0_0:
                cwm_trimmed
                ...
        profile:
            0_0:
                cwm_trimmed
                ...
    Returns a nested dict {"count": {motif_key: cwm}, "profile": {...}}
    mirroring that layout.
    Each CWM is put in its purine-rich orientation: if columns 0 and 2
    (A and G, assuming ACGT column order) hold less than half of the total
    weight, the CWM is replaced by `np.flip(cwm)`, which reverses both axes
    (i.e. the reverse complement under ACGT ordering).
    """
    def _purine_oriented(cwm):
        purine_weight = np.sum(cwm[:, [0, 2]])
        if purine_weight < 0.5 * np.sum(cwm):
            return np.flip(cwm)
        return cwm

    return {
        head: {
            motif_key: _purine_oriented(h5_group[head][motif_key]["cwm_trimmed"][:])
            for motif_key in h5_group[head]
        }
        for head in ("count", "profile")
    }

In [None]:
def write_merged_motifs_to_hdf5_group(in_h5_group, out_h5_group, matched_keys):
    """
    Writes a set of motifs to an open HDF5 group.
    The input HDF5 group must be structured as follows:
        count:
            0_0:
                pfm_full, cwm_full, hcwm_full,
                pfm_trimmed, cwm_trimmed, hcwm_trimmed
            ...
        profile:
            0_0:
                (same datasets)
            ...
    The output HDF5 group will be structured as follows:
        C0_0:P0_1:      (a merged count/profile pair)
            pfm_full, cwm_full, ...
        C0_2:           (an unmatched count motif)
            ...
        P0_3:           (an unmatched profile motif)
            ...
    `matched_keys` is a list of pairs of keys, denoting count
    and profile keys that are matched together, respectively. Each pair is
    merged with `aggregate_motifs`, separately for every motif type; every
    count/profile motif that is in no pair is copied over on its own.
    Motifs are flipped to the purine-rich version. The orientation of each
    motif is decided from its `cwm_trimmed` (the same test that
    `import_motifs_from_hdf5_group` applies before matching), and that flip
    is applied to all of the motif's types, so both members of a pair are
    aggregated in the orientation in which they were matched.
    """
    motif_types = [
        "pfm_full", "cwm_full", "hcwm_full",
        "pfm_trimmed", "cwm_trimmed", "hcwm_trimmed",
    ]

    def oriented_motifs(head, key):
        # Returns {motif_type: array} for one motif, with every type flipped
        # (np.flip reverses both axes) when the trimmed CWM is not purine-rich
        motif_group = in_h5_group[head][key]
        cwm = motif_group["cwm_trimmed"][:]
        flip = np.sum(cwm[:, [0, 2]]) < 0.5 * np.sum(cwm)
        oriented = {}
        for motif_type in motif_types:
            motif = motif_group[motif_type][:]
            oriented[motif_type] = np.flip(motif) if flip else motif
        return oriented

    matched_count = {count_key for count_key, _ in matched_keys}
    matched_profile = {profile_key for _, profile_key in matched_keys}

    # Merged pairs first
    for count_key, profile_key in matched_keys:
        count_motifs = oriented_motifs("count", count_key)
        profile_motifs = oriented_motifs("profile", profile_key)
        group = out_h5_group.create_group("C%s:P%s" % (count_key, profile_key))
        for motif_type in motif_types:
            agg = aggregate_motifs([count_motifs[motif_type], profile_motifs[motif_type]])
            group.create_dataset(motif_type, data=agg, compression="gzip")

    # Then the leftovers of each head, copied over unmerged
    for head, prefix, matched in (
        ("count", "C", matched_count),
        ("profile", "P", matched_profile),
    ):
        for key in list(in_h5_group[head].keys()):
            if key in matched:
                continue
            group = out_h5_group.create_group("%s%s" % (prefix, key))
            for motif_type, motif in oriented_motifs(head, key).items():
                group.create_dataset(motif_type, data=motif, compression="gzip")

In [None]:
def motif_similarity_score(motif_1, motif_2, average=True, align_to_longer=True):
    """
    Computes the motif similarity score between two motifs by
    the summed cosine similarity, maximized over all possible sliding
    windows. Also returns the index relative to the start of `motif_2`
    where `motif_1` should be placed to maximize this score.
    Each position (row) of both motifs is mean-centered and scaled to unit
    L2 norm first, so per-position products are cosine similarities; rows
    that are constant (zero norm after centering) contribute 0 instead of NaN.
    If `average` is True, the summed similarity of the best window is divided
    by the number of positions that actually overlap in that window. The best
    window itself is always chosen by the summed similarity (presumably so
    that tiny overlaps with a high average are not favored).
    If `align_to_longer` is True, always use the longer motif as the basis
    for the index computation (if tie use `motif_2`). Otherwise, always use
    `motif_2`.
    Returns (score, index); a negative index means the placed motif starts
    before the basis motif.
    """
    def _normalize(motif):
        # Center each row across bases, then scale it to unit L2 norm
        centered = motif - np.mean(motif, axis=1, keepdims=True)
        norms = np.sqrt(np.sum(centered * centered, axis=1, keepdims=True))
        # Constant rows would be 0/0; keep them as all-zero rows instead
        norms[norms == 0] = 1
        return centered / norms

    motif_1 = _normalize(motif_1)
    motif_2 = _normalize(motif_2)

    # Optionally make motif_2 the longer one (the basis for the index)
    if align_to_longer and len(motif_1) > len(motif_2):
        motif_1, motif_2 = motif_2, motif_1

    len_1, len_2 = len(motif_1), len(motif_2)

    # Pad motif_2 by len(motif_1) - 1 on either side, so that every window
    # shares at least one position with the real motif_2
    pad_size = len_1 - 1
    padded_2 = np.pad(motif_2, ((pad_size, pad_size), (0, 0)))
    num_windows = len_2 + pad_size

    # Summed similarity for every placement of motif_1 over padded motif_2
    scores = np.array([
        np.sum(motif_1 * padded_2[i : i + len_1]) for i in range(num_windows)
    ])

    best_ind = np.argmax(scores)
    score = scores[best_ind]
    if average:
        # motif_1 covers padded rows [best_ind, best_ind + len_1) and motif_2
        # covers [pad_size, pad_size + len_2); this is correct whichever of
        # the two motifs is longer
        overlap = min(best_ind + len_1, pad_size + len_2) - max(best_ind, pad_size)
        score = score / overlap
    return score, best_ind - pad_size

In [None]:
def compute_motif_pairs(motifs_a, motifs_b):
    """
    Greedily pairs motifs of `motifs_a` with motifs of `motifs_b`.
    All cross similarities are scored with `motif_similarity_score`. Then the
    most similar remaining pair is taken repeatedly, and its row and column
    are retired, until one side runs out.
    Returns a list of triplets (A_i, B_i, s_i): indices into `motifs_a` and
    `motifs_b`, and their similarity. The list is in decreasing order of s_i.
    With lists of unequal length, the surplus motifs of the longer list are
    left unpaired.
    """
    sims = np.array(
        [[motif_similarity_score(m_a, m_b)[0] for m_b in motifs_b] for m_a in motifs_a],
        dtype=float,
    ).reshape(len(motifs_a), len(motifs_b))

    pairs = []
    for _ in range(min(sims.shape)):
        row, col = np.unravel_index(np.argmax(sims), sims.shape)
        pairs.append((row, col, sims[row, col]))
        # Retire the chosen motifs; masking keeps the original indices and
        # the same first-occurrence tie-breaking as deleting them would
        sims[row, :] = -np.inf
        sims[:, col] = -np.inf
    return pairs

In [None]:
def aggregate_motifs(motifs):
    """
    Merges a list of L x 4 motifs (lengths may differ) into one consensus
    motif, whose length is that of the seed motif chosen below.
    The motif with the highest total `motif_similarity_score` to all motifs
    (itself included) seeds the consensus. Every other motif, in decreasing
    order of that total, is aligned against the running consensus and added
    to the overlapping positions. Parts that hang off either end of the
    consensus are dropped. The sum is finally divided by the number of motifs,
    including at positions covered by fewer of them.
    NOTE(review): assumes the motifs already share one orientation.
    """
    count = len(motifs)

    # Symmetric similarity matrix; only pairs with row <= col are scored
    sims = np.empty((count, count))
    for row in range(count):
        for col in range(row, count):
            sims[row, col] = sims[col, row] = motif_similarity_score(motifs[row], motifs[col])[0]

    # Motif indices from most to least similar to the whole set
    order = np.argsort(np.sum(sims, axis=0))[::-1]

    seed = motifs[order[0]]
    consensus = np.zeros_like(seed) + seed
    cons_len = len(consensus)

    for idx in order[1:]:
        motif = motifs[idx]
        # Start of `motif` relative to the start of the consensus
        offset = motif_similarity_score(motif, consensus, align_to_longer=False)[1]
        # Clip the placed motif to the extent of the consensus
        cons_start = max(offset, 0)
        cons_end = min(offset + len(motif), cons_len)
        motif_start = cons_start - offset
        motif_end = cons_end - offset
        consensus[cons_start:cons_end] = (
            consensus[cons_start:cons_end] + motif[motif_start:motif_end]
        )
    return consensus / count

In [None]:
def _motif_vdom_image(motif):
    """
    Draws `motif` with `viz_sequence.plot_weights` (20 x 4 figure, tight
    layout) and returns it converted by `figure_to_vdom_image`. The figure is
    left open; `plot_matches` closes all figures after each table.
    """
    fig = viz_sequence.plot_weights(motif, figsize=(20, 4), return_fig=True)
    fig.tight_layout()
    return figure_to_vdom_image(fig)


def plot_matches(count_keys, count_motifs, profile_keys, profile_motifs, matches):
    """
    Show matched and leftover motifs.
    `count_keys`/`count_motifs` and `profile_keys`/`profile_motifs` are
    parallel lists of motif keys and motifs for the two heads. `matches` is a
    list of triplets (count index, profile index, similarity), as returned by
    `compute_motif_pairs`.
    Displays two tables:
    - one row per match: the merged ID, the similarity, the
      `aggregate_motifs` consensus, and both constituent motifs;
    - the count and profile motifs that appear in no match, side by side.
    """
    # Table of matched pairs
    display(vdomh.h4("Matched motifs"))
    colgroup = vdomh.colgroup(
        vdomh.col(style={"width": "5%"}),
        vdomh.col(style={"width": "5%"}),
        vdomh.col(style={"width": "45%"}),
        vdomh.col(style={"width": "45%"}),
    )
    header = vdomh.thead(
        vdomh.tr(
            vdomh.th("Merged ID", style={"text-align": "center"}),
            vdomh.th("Similarity", style={"text-align": "center"}),
            vdomh.th("Aggregate motif", style={"text-align": "center"}),
            vdomh.th("Constituent motifs", style={"text-align": "center"})
        )
    )

    rows = []
    for count_i, profile_i, sim in matches:
        constituents = [count_motifs[count_i], profile_motifs[profile_i]]
        consensus = aggregate_motifs(constituents)
        merged_id = "C%s:P%s" % (count_keys[count_i], profile_keys[profile_i])
        rows.append(vdomh.tr(
            vdomh.td(merged_id),
            vdomh.td("%.4f" % sim),
            vdomh.td(_motif_vdom_image(consensus)),
            vdomh.td(*[_motif_vdom_image(motif) for motif in constituents])
        ))
    display(vdomh.table(colgroup, header, vdomh.tbody(*rows)))
    plt.close("all")

    # Table of leftovers: motifs of either head that are in no match
    display(vdomh.h4("Unmatched motifs"))
    colgroup = vdomh.colgroup(
        vdomh.col(style={"width": "5%"}),
        vdomh.col(style={"width": "45%"}),
        vdomh.col(style={"width": "5%"}),
        vdomh.col(style={"width": "45%"}),
    )
    header = vdomh.thead(
        vdomh.tr(
            vdomh.th("Count ID", style={"text-align": "center"}),
            vdomh.th("Count motif", style={"text-align": "center"}),
            vdomh.th("Profile ID", style={"text-align": "center"}),
            vdomh.th("Profile motif", style={"text-align": "center"})
        )
    )

    matched_count = {count_i for count_i, _, _ in matches}
    matched_profile = {profile_i for _, profile_i, _ in matches}
    unmatched_count = [i for i in range(len(count_motifs)) if i not in matched_count]
    unmatched_profile = [i for i in range(len(profile_motifs)) if i not in matched_profile]

    def leftover_cells(prefix, keys, motifs, unmatched, row_i):
        # Two cells (ID, logo) for the row_i-th leftover of one head, or two
        # empty cells once that head has run out of leftovers
        if row_i >= len(unmatched):
            return [vdomh.td(), vdomh.td()]
        motif_i = unmatched[row_i]
        return [
            vdomh.td("%s%s" % (prefix, keys[motif_i])),
            vdomh.td(_motif_vdom_image(motifs[motif_i]))
        ]

    rows = [
        vdomh.tr(
            *leftover_cells("C", count_keys, count_motifs, unmatched_count, row_i),
            *leftover_cells("P", profile_keys, profile_motifs, unmatched_profile, row_i)
        )
        for row_i in range(max(len(unmatched_count), len(unmatched_profile)))
    ]
    display(vdomh.table(colgroup, header, vdomh.tbody(*rows)))
    plt.close("all")

In [None]:
def motif_key_sorter(k):
    """
    Sort key for motif keys such as "0_3". Returns the first two
    "_"-separated fields as a tuple of ints, so that keys sort numerically
    rather than lexically (e.g. "2_0" before "10_0").
    """
    fields = k.split("_")
    return (int(fields[0]), int(fields[1]))

### Consolidate count/profile head motifs for all conditions

In [None]:
def consolidate(in_h5_group, out_h5_group):
    """
    Performs consolidation, reading motifs from the in group and
    writing to the out group.
    The count-head and profile-head trimmed CWMs are paired greedily by
    similarity (`compute_motif_pairs`). Pairs whose similarity is at least
    the global `min_match_sim` are merged; all other motifs are copied over
    unmerged (`write_merged_motifs_to_hdf5_group`). The distribution of pair
    similarities and the resulting matches are displayed along the way.
    """
    # Import CWMs (purine-rich orientation) in numeric key order
    motifs = import_motifs_from_hdf5_group(in_h5_group)
    count_keys = sorted(motifs["count"].keys(), key=motif_key_sorter)
    profile_keys = sorted(motifs["profile"].keys(), key=motif_key_sorter)
    count_motifs = [motifs["count"][k] for k in count_keys]
    profile_motifs = [motifs["profile"][k] for k in profile_keys]
    matches = compute_motif_pairs(count_motifs, profile_motifs)

    # Plot distribution of match similarities (before thresholding)
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.hist([sim for _, _, sim in matches], bins=10)
    ax.set_title("Histogram of match similarities")
    ax.set_xlabel("Similarity")
    plt.show()

    # Restrict to matches of sufficient similarity
    matches = [match for match in matches if match[2] >= min_match_sim]

    # Plot matches themselves
    plot_matches(count_keys, count_motifs, profile_keys, profile_motifs, matches)

    # Write the consolidated motifs and unconsolidated motifs
    matched_keys = [
        (count_keys[count_i], profile_keys[profile_i])
        for count_i, profile_i, _ in matches
    ]
    write_merged_motifs_to_hdf5_group(in_h5_group, out_h5_group, matched_keys)

<a id="results"></a>

In [None]:
with h5py.File(in_motif_file, "r") as in_file, h5py.File(out_motif_file, "w") as out_file:
    # Multi-task models: one subgroup per fold
    mt_in = in_file["multitask"]
    mt_out = out_file.create_group("multitask")
    for fold in mt_in.keys():
        display(vdomh.h3("Multi-task %s" % fold))
        consolidate(mt_in[fold], mt_out.create_group(fold))

    # Single-task models: one subgroup per task, each with one per fold
    st_in = in_file["singletask"]
    st_out = out_file.create_group("singletask")
    for task in st_in.keys():
        task_in = st_in[task]
        task_out = st_out.create_group(task)
        for fold in task_in.keys():
            display(vdomh.h3("Single-task %s %s" % (task, fold)))
            consolidate(task_in[fold], task_out.create_group(fold))

    # Fine-tuned models (multi-task first, then single-task): one subgroup
    # per task
    for ft_name, heading in (
        ("multitask_finetune", "Multi-task fine-tune %s"),
        ("singletask_finetune", "Single-task fine-tune %s"),
    ):
        ft_in = in_file[ft_name]
        ft_out = out_file.create_group(ft_name)
        for task in ft_in.keys():
            display(vdomh.h3(heading % task))
            consolidate(ft_in[task], ft_out.create_group(task))