In [None]:
import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/notebooks/reports/"))
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
import motif.tfmodisco_hit_scoring as tfmodisco_hit_scoring
import motif.match_motifs as match_motifs
import plot.viz_sequence as viz_sequence
from util import figure_to_vdom_image, create_motif_similarity_matrix, aggregate_motifs, aggregate_motifs_from_inds
import tempfile
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import scipy.signal
import scipy.cluster.hierarchy
import vdom.helpers as vdomh
from IPython.display import display
import tqdm
# NOTE(review): presumably run once up front so the notebook progress-bar
# widget is initialised before later cells use tqdm. `tqdm.tqdm_notebook` is
# deprecated in newer tqdm releases in favour of `tqdm.notebook.tqdm`.
tqdm.tqdm_notebook(range(1))

In [None]:
# Plotting defaults

# Register the fonts found in the user's font directory (presumably including
# Roboto, which `font.family` below relies on).
# `FontManager.addfont` is the supported API since Matplotlib 3.2, where
# `font_manager.createFontList` was deprecated (it was later removed), so the
# old call is kept only as a fallback for older Matplotlib versions.
custom_font_files = font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts")
if hasattr(font_manager.fontManager, "addfont"):
    for font_file in custom_font_files:
        font_manager.fontManager.addfont(font_file)
else:
    font_manager.fontManager.ttflist.extend(
        font_manager.createFontList(custom_font_files)
    )

plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold",
    "svg.fonttype": "none"
}
plt.rcParams.update(plot_params)

In [None]:
# TF to analyse; can be overridden through the TFM_TF_NAME environment variable
tf_name = os.environ.get("TFM_TF_NAME", "E2F6")

In [None]:
# Manually define the clusters of core motifs.
# Each entry is (aggregate motif name, [constituent motif keys]). A key has the
# form "T<task>:<heads>", where <heads> is one or more ":"-separated motif IDs,
# each prefixed with "C" (count head) or "P" (profile head).

if tf_name == "E2F6":
    core_motif_defs = [
        ("MAX", ["T0:C0_0:P0_0", "T1:C0_0:P0_0"]),
        ("E2F6", ["T0:C0_1:P0_1", "T1:C0_1:P0_1"]),
        ("AP-1", ["T0:C0_5", "T0:P0_9", "T1:C0_7:P0_11"]),
        ("other E2F", ["T0:C0_4", "T0:P0_4", "T0:P0_6", "T1:C0_3:P0_3"]),
        ("weak MAX v1", ["T1:C0_2:P0_7", "T1:C0_6:P0_2"]),
        ("weak MAX v2", ["T0:C0_2:P0_2", "T1:C0_5", "T1:P0_4"])
    ]
elif tf_name == "SPI1":
    core_motif_defs = [
        ("SPI1", ["T1:C0_0:P0_0", "T3:C0_0:P0_0", "T1:P0_7", "T0:C0_4", "T0:C0_0:P0_0", "T2:C0_0:P0_0"]),
        ("SPI1 v2", ["T1:C0_4", "T1:P0_5", "T3:C0_2"]),
        ("AP-1", ["T2:C0_5", "T1:C0_3", "T2:C0_6", "T2:P0_7", "T2:C0_2"]),
        ("CEBP", ["T3:C0_1:P0_1"]),
        ("RUNX", ["T2:C0_8", "T3:C0_3", "T3:P0_5", "T3:P0_7"]),
        ("GATA", ["T1:C0_1"])
    ]
else:
    # Fail here with a clear message, rather than with a NameError on
    # `core_motif_defs` several cells later
    raise ValueError("No core motif definitions for TF %s" % tf_name)

In [None]:
# Number of tasks per TF
tf_num_tasks = {
    "E2F6": 2,
    "FOXA2": 4,
    "SPI1": 4,
    "CEBPB": 7,
    "MAX": 7,
    "GABPA": 9,
    "MAFK": 9,
    "JUND": 14,
    "NR3C1-reddytime": 16,
    "REST": 20
}

# Per-task choice of fine-tuned model: "M" = multi-task, "S" = single-task
tf_best_model_types = {
    "E2F6": list("MM"),
    "FOXA2": list("SSMM"),
    "SPI1": list("MSSS"),
    "CEBPB": list("MMMMSMM"),
    "MAX": list("MMSMMSS"),
    "GABPA": list("MMMSMMMMM"),
    "MAFK": list("MMMMMMMMM"),
    "JUND": list("SMMSMSSSSSSSMS"),
    "NR3C1-reddytime": list("MMMSMMSMMMMSMMMM"),
    "REST": list("MMMMMMMMMSMMSMMSMMMM")
}

num_tasks = tf_num_tasks[tf_name]
best_model_types = tf_best_model_types[tf_name]
# The two tables are maintained by hand; make sure they agree for this TF
assert len(best_model_types) == num_tasks, \
    "Expected %d model types for %s, got %d" % (num_tasks, tf_name, len(best_model_types))

tfm_motif_file = "/users/amtseng/tfmodisco/results/motifs/tfmodisco/%s_tfmodisco_cpmerged_motifs.h5" % tf_name
meme_motif_file = "/users/amtseng/tfmodisco/results/motifs/meme/%s_meme_motifs.h5" % tf_name
homer_motif_file = "/users/amtseng/tfmodisco/results/motifs/homer/%s_homer_motifs.h5" % tf_name
dichipmunk_motif_file = "/users/amtseng/tfmodisco/results/motifs/dichipmunk/%s_dichipmunk_motifs.h5" % tf_name

multitask_finetune_model_def_tsv = "/users/amtseng/tfmodisco/results/model_stats/multitask_profile_finetune_stats.tsv"
singletask_finetune_model_def_tsv = "/users/amtseng/tfmodisco/results/model_stats/singletask_profile_finetune_stats.tsv"

In [None]:
# JASPAR 2020 CORE vertebrate (non-redundant) PFMs in MEME format; the
# aggregated motifs are matched against this database with TOMTOM below
motif_database_path = "/users/amtseng/tfmodisco/data/processed/motif_databases/JASPAR2020_CORE_vertebrates_non-redundant_pfms_meme.txt"

### Helper functions

In [None]:
def get_motif_hit_paths():
    """
    Returns a list of pairs, one per task, where each pair is the count and
    profile motif hit paths (in that order) for the task.

    For each task, the entry in `best_model_types` selects the source: "M"
    uses the fine-tuned multi-task model's hits, and anything else uses the
    fine-tuned single-task model's hits. The fold of each model is read from
    the model-stats TSVs. Relies on the module-level `tf_name`, `num_tasks`,
    `best_model_types`, `multitask_finetune_model_def_tsv` and
    `singletask_finetune_model_def_tsv`.

    Raises ValueError if some task needs the multi-task model but the
    multi-task TSV has no matching row for this TF.
    """
    # First, import the best fold definitions

    # Finetuned multi-task model: the single row for this TF whose second
    # column equals `num_tasks - 1`; its third column is the fold
    best_mt_fold = None
    with open(multitask_finetune_model_def_tsv, "r") as f:
        for line in f:
            tokens = line.strip().split("\t")
            if tokens[0] == tf_name and int(tokens[1]) == num_tasks - 1:
                assert best_mt_fold is None
                best_mt_fold = int(tokens[2])

    if best_mt_fold is None and "M" in best_model_types:
        # Otherwise the "%d" formatting of the multi-task path below would
        # fail with an unhelpful TypeError
        raise ValueError(
            "No fine-tuned multi-task model fold found for %s in %s" % (
                tf_name, multitask_finetune_model_def_tsv
            )
        )

    # Finetuned single-task models: one row per task for this TF, with the
    # fold in the third column
    # NOTE(review): assumes the rows are listed in task order
    best_st_folds = []
    with open(singletask_finetune_model_def_tsv, "r") as f:
        for line in f:
            tokens = line.strip().split("\t")
            if tokens[0] == tf_name:
                best_st_folds.append(int(tokens[2]))

    assert len(best_st_folds) == num_tasks

    # Get paths to motif hits; the "{0}" placeholder is filled with
    # "count" or "profile" for the two heads
    task_motif_hit_paths = []
    base_path = "/users/amtseng/tfmodisco/results/tfmodisco_hit_scoring"
    for task_index, model_type in enumerate(best_model_types):
        if model_type == "M":
            path = os.path.join(
                base_path,
                "multitask_profile_finetune",
                "%s_multitask_profile_finetune_task%d_fold%d_{0}" % (tf_name, task_index, best_mt_fold),
                "tfm_matches.bed"
            )
        else:
            path = os.path.join(
                base_path,
                "singletask_profile_finetune",
                "%s_singletask_profile_finetune_fold%d_{0}" % (tf_name, best_st_folds[task_index]),
                "task_%d" % task_index,
                "tfm_matches.bed"
            )
        task_motif_hit_paths.append(
            (path.format("count"), path.format("profile"))
        )
    return task_motif_hit_paths

In [None]:
def purine_rich_motif(motif):
    """
    Return `motif` in its purine-rich orientation.

    Columns 0 and 2 are treated as the purines (A and G, assuming ACGT column
    order). If they hold less than half of the motif's total mass, the motif
    is flipped along both axes (the reverse complement under that ordering)
    and returned. Otherwise the input array itself is returned.
    """
    purine_mass = np.sum(motif[:, [0, 2]])
    needs_flip = purine_mass < 0.5 * np.sum(motif)
    return np.flip(motif, axis=(0, 1)) if needs_flip else motif

In [None]:
def import_tfmodisco_motifs(motif_file, model_types, motif_type="cwm_trimmed"):
    """
    Load the TF-MoDISco motifs of one TF from `motif_file`.

    For task i, motifs come from the "multitask_finetune" group if
    `model_types[i]` is "M", and from the "singletask_finetune" group
    otherwise. Only motif keys containing "0_" are kept, and each motif's
    `motif_type` entry (trimmed CWM by default) is read and put in its
    purine-rich orientation.

    Returns a list with one dictionary per task, mapping "T<i>:<motif key>"
    to the motif array.
    """
    motifs = []
    with h5py.File(motif_file, "r") as f:
        multitask_group = f["multitask_finetune"]
        singletask_group = f["singletask_finetune"]
        for task_index, model_type in enumerate(model_types):
            group = multitask_group if model_type == "M" else singletask_group
            task_dset = group["task_%d" % task_index]
            # Motifs that are (or are constructed from) positive metacluster only
            motifs.append({
                "T%d:%s" % (task_index, key): purine_rich_motif(task_dset[key][motif_type][:])
                for key in task_dset.keys()
                if "0_" in key
            })
    return motifs

In [None]:
def import_classic_benchmark_motifs(motif_file, mode):
    """
    From a file containing all motifs for that TF from a benchmark
    method, imports the PFMs of the motifs for each task.

    Arguments:
        `motif_file`: HDF5 file with one "task_<i>" group per task (a
            "task_agg" group, if present, is ignored)
        `mode`: benchmark method; one of "dichipmunk", "homer" or "meme"
    Returns a list of dictionaries (one for each task, in increasing task
    index order), where each dictionary maps "T<i>:<motif key>" to the
    motif in its purine-rich orientation.
    Raises ValueError for an unknown `mode`.
    """
    # Each method also stores a per-task score entry inside every task group;
    # it is not a motif and must be skipped
    score_keys = {
        "dichipmunk": "supporting_seqs",
        "homer": "log_enrichment",
        "meme": "evalue",
    }
    if mode not in score_keys:
        # Without this check, `score_key` would be unbound further down
        raise ValueError(
            "Unknown benchmark mode %r; expected one of %s" % (mode, sorted(score_keys))
        )
    score_key = score_keys[mode]

    motifs = []
    with h5py.File(motif_file, "r") as f:
        # Group names are "task_<i>"; strip the "task_" prefix to get <i>
        tasks = sorted(int(key[5:]) for key in f.keys() if key != "task_agg")
        for i in tasks:
            dset = f["task_%d" % i]
            task_motifs = {}
            for motif_key in dset.keys():
                if motif_key == score_key:
                    continue
                task_motifs["T%d:%s" % (i, motif_key)] = purine_rich_motif(dset[motif_key][:])
            motifs.append(task_motifs)
    return motifs

In [None]:
def import_database_pfms(database_path):
    """
    Parse a MEME-format motif database into a dictionary of PFMs.

    For every "MOTIF" line, the motif ID is the line's second field. The next
    line is read as the matrix header, whose sixth field is the motif width,
    and that many following lines are the matrix rows. Each PFM is stored as
    a (width, 4) NumPy array in its purine-rich orientation. It is keyed by
    the second "_"-separated field of the motif ID. Parsing stops quietly at
    end of file; a motif cut off by end of file is dropped.
    """
    pfms = {}
    with open(database_path, "r") as handle:
        lines = iter(handle)
        try:
            for line in lines:
                if not line.startswith("MOTIF"):
                    continue
                full_id = line.strip().split()[1]
                width = int(next(lines).split()[5])
                pfm = np.empty((width, 4))
                for row in range(width):
                    pfm[row] = np.array([float(val) for val in next(lines).strip().split()])
                # Store the motif under a shortened key
                pfms[full_id.split("_")[1]] = purine_rich_motif(pfm)
        except StopIteration:
            # Ran out of lines in the middle of a motif record
            pass
    return pfms

In [None]:
def get_closest_tomtom_motif_similarities(query_dict, target_dict):
    """
    From a dictionary mapping N motif keys to query motifs, and a
    dictionary mapping M motif keys to target motifs, returns a
    dictionary mapping the N query motif keys to the similarity and
    key of the closest target motif (a pair). Similarity is the
    -log10(p) from TOMTOM; a query with no TOMTOM match maps to
    (0, "N/A").

    An empty `query_dict` yields an empty dictionary. An empty
    `target_dict` maps every query key to (0, "N/A") without running
    TOMTOM. The temporary working directory is removed even if
    TOMTOM fails.
    """
    if not query_dict:
        return {}
    if not target_dict:
        # Nothing to compare against: same placeholder as an unmatched query
        return {key: (0, "N/A") for key in query_dict}

    query_keys, query_pfms = list(zip(*query_dict.items()))
    target_keys, target_pfms = list(zip(*target_dict.items()))

    # Do all work in a temporary directory; the context manager removes it
    # even if exporting, running or parsing TOMTOM raises
    with tempfile.TemporaryDirectory() as temp_dir:
        # Convert motifs to MEME format
        query_motif_file = os.path.join(temp_dir, "query_motifs.txt")
        target_motif_file = os.path.join(temp_dir, "target_motifs.txt")
        match_motifs.export_pfms_to_meme_format(query_pfms, query_motif_file)
        match_motifs.export_pfms_to_meme_format(target_pfms, target_motif_file)

        # Run TOMTOM
        tomtom_dir = os.path.join(temp_dir, "tomtom")
        match_motifs.run_tomtom(
            query_motif_file, target_motif_file, tomtom_dir,
            show_output=False
        )

        # Find results, mapping each query motif to target index
        # The query/target IDs are the indices into the exported motif lists
        tomtom_table = match_motifs.import_tomtom_results(tomtom_dir)
        matches = []
        for i in range(len(query_pfms)):
            rows = tomtom_table[tomtom_table["Query_ID"] == i]
            if rows.empty:
                matches.append((0, "N/A"))
                continue
            # Keep the target with the smallest p-value
            min_row = rows.loc[rows["p-value"].idxmin()]
            score = -np.log10(min_row["p-value"])
            target_key = target_keys[int(min_row["Target_ID"])]
            matches.append((score, target_key))

    return dict(zip(query_keys, matches))

### Import motifs

In [None]:
# TF-MoDISco motifs of the chosen fine-tuned model for each task: trimmed CWMs and trimmed PFMs
tfm_cwm_motifs = import_tfmodisco_motifs(tfm_motif_file, best_model_types, motif_type="cwm_trimmed")
tfm_pfm_motifs = import_tfmodisco_motifs(tfm_motif_file, best_model_types, motif_type="pfm_trimmed")

# Motifs discovered by each classic benchmark method, per task
meme_motifs = import_classic_benchmark_motifs(motif_file=meme_motif_file, mode="meme")
homer_motifs = import_classic_benchmark_motifs(motif_file=homer_motif_file, mode="homer")
dichipmunk_motifs = import_classic_benchmark_motifs(motif_file=dichipmunk_motif_file, mode="dichipmunk")

### Cluster the TF-MoDISco motifs across each task
We need to decide which motifs to merge together. We start with an automatic hierarchical clustering as a best guess, and then manually decide which motifs from each task belong in each cluster (recorded in `core_motif_defs`).

In [None]:
# Flatten all TF-MoDISco motifs across all tasks into parallel lists, ordered
# by task and then by each task's CWM key order
tfm_motif_keys = [key for task_cwms in tfm_cwm_motifs for key in task_cwms.keys()]
tfm_motif_cwms = [
    task_cwms[key]
    for task_cwms in tfm_cwm_motifs
    for key in task_cwms.keys()
]
tfm_motif_pfms = [
    tfm_pfm_motifs[task_index][key]
    for task_index, task_cwms in enumerate(tfm_cwm_motifs)
    for key in task_cwms.keys()
]

In [None]:
# Compute similarity matrix between the flattened TF-MoDISco CWMs
# (square, one row/column per entry of `tfm_motif_cwms`; converted to
# distances for clustering in the next cell)
sim_matrix = create_motif_similarity_matrix(tfm_motif_cwms)

In [None]:
# Cut height for flat clusters (on the greedy side)
cluster_distance = 0.6

# Convert similarities to distances with zero self-distance, then to the
# condensed vector form that scipy's linkage expects.
# NOTE(review): `scipy.spatial` is not imported explicitly in this notebook; it
# resolves because scipy.cluster.hierarchy loads it. Consider importing it.
dist_matrix = 1 - sim_matrix
np.fill_diagonal(dist_matrix, 0)
dist_vec = scipy.spatial.distance.squareform(dist_matrix)

# NOTE(review): Ward linkage is only well-defined for Euclidean distances, and
# 1 - similarity need not be Euclidean; treat these clusters as a first guess
linkage = scipy.cluster.hierarchy.linkage(dist_vec, method="ward")
clusters = scipy.cluster.hierarchy.fcluster(
    linkage, t=cluster_distance, criterion="distance"
)

In [None]:
# Show aggregated and constituent motifs for each automatic cluster
colgroup = vdomh.colgroup(
    *[vdomh.col(style={"width": col_width}) for col_width in ("45%", "45%", "10%")]
)

header = vdomh.thead(
    vdomh.tr(
        *[
            vdomh.th(title, style={"text-align": "center"})
            for title in ("Aggregate motif", "Constituent motifs", "Constituent motif IDs")
        ]
    )
)

cluster_ids, counts = np.unique(clusters, return_counts=True)
for position, cluster_id in enumerate(cluster_ids, start=1):
    # Indices (into the flattened motif lists) of the motifs in this cluster
    match_inds = np.where(clusters == cluster_id)[0]
    match_cwms = [tfm_motif_cwms[j] for j in match_inds]
    match_keys = [tfm_motif_keys[j] for j in match_inds]

    consensus_cwm = aggregate_motifs(match_cwms)

    display(vdomh.h3("Cluster %d (%d/%d)" % (cluster_id, position, len(cluster_ids))))
    display(vdomh.h4("%d motifs" % len(match_cwms)))

    agg_fig = viz_sequence.plot_weights(consensus_cwm, figsize=(20, 4), return_fig=True)
    agg_fig.tight_layout()

    # Render every constituent CWM to an embeddable image
    const_figs = []
    for cwm in match_cwms:
        fig = viz_sequence.plot_weights(cwm, figsize=(20, 4), return_fig=True)
        fig.tight_layout()
        const_figs.append(figure_to_vdom_image(fig))

    # The first row also holds the aggregate motif, spanning all constituent rows
    table_rows = [
        vdomh.tr(
            vdomh.td(figure_to_vdom_image(agg_fig), rowspan=str(len(match_cwms))),
            vdomh.td(const_figs[0]),
            vdomh.td(match_keys[0])
        )
    ]
    for const_img, const_key in zip(const_figs[1:], match_keys[1:]):
        table_rows.append(vdomh.tr(vdomh.td(const_img), vdomh.td(const_key)))

    display(vdomh.table(colgroup, header, vdomh.tbody(*table_rows)))
    plt.close("all")

In [None]:
# Show aggregated and constituent motifs for the final, manually curated clusterings
colgroup = vdomh.colgroup(
    *[vdomh.col(style={"width": col_width}) for col_width in ("45%", "45%", "10%")]
)

header = vdomh.thead(
    vdomh.tr(
        *[
            vdomh.th(title, style={"text-align": "center"})
            for title in ("Aggregate motif", "Constituent motifs", "Constituent motif IDs")
        ]
    )
)

# Aggregated CWM and PFM for each core motif, keyed by aggregate motif name
core_motif_cwms = {}
core_motif_pfms = {}
for agg_motif_name, const_motif_keys in core_motif_defs:
    # Locate each constituent in the flattened TF-MoDISco motif lists
    const_positions = [tfm_motif_keys.index(key) for key in const_motif_keys]
    const_motif_cwms = [tfm_motif_cwms[pos] for pos in const_positions]
    const_motif_pfms = [tfm_motif_pfms[pos] for pos in const_positions]

    # Aggregate the CWMs, keeping the indices `aggregate_motifs` returns so the
    # PFMs can be combined with the same indices
    agg_cwm, (const_inds, agg_inds) = aggregate_motifs(const_motif_cwms, return_inds=True)
    core_motif_cwms[agg_motif_name] = agg_cwm

    # Construct aggregate PFM, renormalizing each row to sum to 1
    agg_pfm = aggregate_motifs_from_inds(const_motif_pfms, const_inds, agg_inds)
    agg_pfm = agg_pfm / np.sum(agg_pfm, axis=1, keepdims=True)
    core_motif_pfms[agg_motif_name] = agg_pfm

    display(vdomh.h3(agg_motif_name))

    agg_fig = viz_sequence.plot_weights(agg_cwm, figsize=(20, 4), return_fig=True)
    agg_fig.tight_layout()

    const_figs = []
    for motif in const_motif_cwms:
        fig = viz_sequence.plot_weights(motif, figsize=(20, 4), return_fig=True)
        fig.tight_layout()
        const_figs.append(figure_to_vdom_image(fig))

    # The first row also holds the aggregate motif, spanning all constituent rows
    table_rows = [
        vdomh.tr(
            vdomh.td(figure_to_vdom_image(agg_fig), rowspan=str(len(const_motif_keys))),
            vdomh.td(const_figs[0]),
            vdomh.td(const_motif_keys[0])
        )
    ]
    for const_img, const_key in zip(const_figs[1:], const_motif_keys[1:]):
        table_rows.append(vdomh.tr(vdomh.td(const_img), vdomh.td(const_key)))

    display(vdomh.table(colgroup, header, vdomh.tbody(*table_rows)))
    plt.close("all")

### Extract constituent motif prevalences
For each aggregated motif, extract the prevalence of the constituent motifs (by task) in the peaks.

In [None]:
# Import the motif hits for each task, keeping only the key and peak index of
# each hit, separately for the count ("C") and profile ("P") heads
task_motif_hit_paths = get_motif_hit_paths()
task_motif_hits = [
    {
        "C": tfmodisco_hit_scoring.import_tfmodisco_hits(count_path)[["key", "peak_index"]],
        "P": tfmodisco_hit_scoring.import_tfmodisco_hits(profile_path)[["key", "peak_index"]],
    }
    for count_path, profile_path in task_motif_hit_paths
]

In [None]:
def get_hit_prevalence(hit_table, motif_keys, total_peaks=None):
    """
    Computes the motif prevalence from the hit table, as the proportion of
    peaks that have at least one hit whose key is in `motif_keys`.

    Arguments:
        `hit_table`: table of motif hits with (at least) "key" and
            "peak_index" columns
        `motif_keys`: collection of motif keys whose hits are counted
        `total_peaks`: number of peaks to divide by; by default, the number
            of distinct peaks in `hit_table`. That default only counts peaks
            with a hit of some motif, not peaks without any hits.
    Returns the prevalence as a float, or 0.0 when there are no peaks to
    divide by (e.g. an empty hit table).
    """
    if total_peaks is None:
        total_peaks = hit_table["peak_index"].nunique()
    if total_peaks == 0:
        return 0.0
    hit_peaks = hit_table.loc[hit_table["key"].isin(motif_keys), "peak_index"].nunique()
    return hit_peaks / total_peaks

In [None]:
# Create matrix of motif prevalences: one row per aggregate motif, one column per task
motif_prevalences = np.zeros((len(core_motif_defs), len(task_motif_hits)))
for i, (_, const_keys) in enumerate(core_motif_defs):
    # Group the constituent keys by task, dropping the "T<task>:" prefix
    # (e.g. "T1:C0_2:P0_7" -> task 1, "C0_2:P0_7")
    task_const_keys = {j: [] for j in range(len(task_motif_hits))}
    for const_key in const_keys:
        task_token, _, head_keys = const_key.partition(":")
        task_const_keys[int(task_token[1:])].append(head_keys)

    for j, compound_keys in task_const_keys.items():
        # Split the (possibly compound) keys by head letter: "C0_2" -> ("C", "0_2")
        keys_by_head = {}
        for compound_key in compound_keys:
            for token in compound_key.split(":"):
                keys_by_head.setdefault(token[0], []).append(token[1:])

        # Prevalence is the maximum over the count/profile heads; 0 if this
        # aggregate motif has no constituents in the task
        if keys_by_head:
            motif_prevalences[i, j] = max(
                get_hit_prevalence(task_motif_hits[j][head], head_motif_keys)
                for head, head_motif_keys in keys_by_head.items()
            )
        else:
            motif_prevalences[i, j] = 0

### Compute similarity of benchmark motifs to aggregated motifs
For each aggregated motif and each task, compute the similarity of the closest motif discovered by each benchmark method (MEME, HOMER and DiChIPMunk).

In [None]:
def get_closest_motifs(query_motifs, target_motifs):
    """
    For each of the M query CWMs in `query_motifs`, finds the similarity of
    the most similar of the N target motifs in `target_motifs`.

    Returns an M-array of best similarities, as scored by
    `create_motif_similarity_matrix`. If there are no target motifs,
    every query gets a similarity of 0 (nothing similar found).
    """
    if len(target_motifs) == 0:
        # Taking the max over an empty set of targets would fail
        return np.zeros(len(query_motifs))
    # Assumes rows index query motifs and columns index target motifs
    query_target_sims = create_motif_similarity_matrix(query_motifs, target_motifs, show_progress=False)
    return np.max(query_target_sims, axis=1)

In [None]:
def get_benchmark_similarities(query_motifs, benchmark_motifs):
    """
    From a list of N query CWMs and a list of T dictionaries (one per task)
    mapping motif keys to benchmark motifs, computes an N x T matrix whose
    entry (i, j) is the best similarity between query motif i and any motif
    of task j.

    NOTE(review): the benchmark motifs come from
    `import_classic_benchmark_motifs`, which imports PFMs, so CWMs are being
    compared with PFMs here; confirm this is intended.
    """
    matrix = np.empty((len(query_motifs), len(benchmark_motifs)))
    if len(query_motifs) == 0:
        return matrix
    # A single call scores every query motif against one task's motifs, so
    # one call per task fills a whole column
    for j, task_motifs in enumerate(benchmark_motifs):
        matrix[:, j] = get_closest_motifs(query_motifs, list(task_motifs.values()))
    return matrix

In [None]:
# Aggregated CWMs, in the same order as `core_motif_defs`
query_motifs = [core_motif_cwms[name] for name, _ in core_motif_defs]

# Best similarity of each aggregated motif to each benchmark's motifs, per task
meme_best_sims, homer_best_sims, dichipmunk_best_sims = [
    get_benchmark_similarities(query_motifs, benchmark)
    for benchmark in (meme_motifs, homer_motifs, dichipmunk_motifs)
]

### Compute similarity of database motifs to aggregated motifs
For each aggregated motif, find the closest motif in the JASPAR motif database using TOMTOM, and record its similarity as $-\log_{10}(p)$.

In [None]:
# Load every PFM in the motif database, keyed by shortened motif ID
database_motifs = import_database_pfms(motif_database_path)

In [None]:
# Aggregated PFMs keyed by aggregate motif name, for matching with TOMTOM
query_motifs = {name: core_motif_pfms[name] for name, _ in core_motif_defs}

# For each aggregated motif: (-log10 p of the best TOMTOM match, matched database key)
database_best_sims = get_closest_tomtom_motif_similarities(query_motifs, database_motifs)

### Construct the plot

In [None]:
# One row per aggregate motif. Panels: per-task prevalence (circles),
# per-task benchmark similarity (3 columns per task) and closest database motif
height = motif_prevalences.shape[0] * 2
width = motif_prevalences.shape[1] * 4 + 1

fig, ax = plt.subplots(
    ncols=3, figsize=(width, height),
    gridspec_kw={
        "width_ratios": [(width - 1) * (1/2), (width - 1) * (3/8), (width - 1) * (1/8)],
        "wspace": 0,
    }
)

# Plot motif prevalences in each task

# Set the radius such that the area is proportional to the prevalence; a
# prevalence of 1 gives a radius-0.5 circle that fills its unit grid cell
max_area = np.pi * (0.5 ** 2)
assert np.min(motif_prevalences) >= 0 and np.max(motif_prevalences) <= 1
area = motif_prevalences * max_area
radius = np.sqrt(area / np.pi)

# Plot the data on a unit grid, with the first aggregate motif in the top row
ax[0].set_xlim(0, motif_prevalences.shape[1])
ax[0].set_ylim(0, motif_prevalences.shape[0])
ax[0].set_xticks(np.arange(0.5, motif_prevalences.shape[1] + 0.5))
ax[0].set_yticks(np.arange(0.5, motif_prevalences.shape[0] + 0.5))
ax[0].set_xticklabels(["task_%d" % i for i in np.arange(0, motif_prevalences.shape[1])])
ax[0].set_yticklabels([pair[0] for pair in core_motif_defs][::-1])  # Flip y-axis
for i in range(motif_prevalences.shape[1]):
    ax[0].axvline(i, color="gray", alpha=0.2)
for i in range(motif_prevalences.shape[0]):
    ax[0].axhline(i, color="gray", alpha=0.2)

for i in range(motif_prevalences.shape[0]):
    for j in range(motif_prevalences.shape[1]):
        # Row i is drawn at height (n - i - 1) so it lines up with the heatmap rows
        circle = plt.Circle((j + 0.5, motif_prevalences.shape[0] - i - 1 + 0.5), radius[i, j], alpha=0.3)
        ax[0].add_patch(circle)

# Plot benchmark motif similarities in each task

# Interleave the benchmarks so each task gets adjacent MEME/HOMER/DiChIPMunk columns
full_sim_matrix = np.empty((meme_best_sims.shape[0], meme_best_sims.shape[1] * 3))
sim_matrices = [meme_best_sims, homer_best_sims, dichipmunk_best_sims]
for i in range(3):
    full_sim_matrix[:, np.arange(0, meme_best_sims.shape[1] * 3, 3) + i] = sim_matrices[i]
hm = ax[1].imshow(full_sim_matrix, cmap="Oranges")
# Attach the colorbar to the benchmark heatmap explicitly; without `ax`, older
# Matplotlib versions take the space from the current axes (after
# plt.subplots, the database panel) instead
fig.colorbar(hm, ax=ax[1])
ax[1].set_aspect("auto")
ax[1].set_xticks(np.arange(full_sim_matrix.shape[1]))
ax[1].set_xticklabels(
    sum([["task_%d_%s" % (i, s) for s in ("M", "H", "D")] for i in range(meme_best_sims.shape[1])], []),
    rotation=90
)
ax[1].set_yticks([])

# Plot the database motif similarities (-log10 p from TOMTOM), each cell
# labelled with the matched database motif
# NOTE(review): this panel uses its own color scale and has no colorbar
database_sims = np.array([database_best_sims[pair[0]][0] for pair in core_motif_defs])
database_labels = [database_best_sims[pair[0]][1] for pair in core_motif_defs]
ax[2].imshow(database_sims[:, None], cmap="Oranges")

# Create annotations
for i in range(len(database_labels)):
    ax[2].text(0, i, database_labels[i], ha="center", va="center")

ax[2].set_aspect("auto")
ax[2].set_xticks([0])
ax[2].set_xticklabels(["Database"])
ax[2].set_yticks([])

fig.tight_layout()
plt.show()

In [None]:
# Show each aggregated CWM on its own, in the order of `core_motif_defs`
for motif_name in (pair[0] for pair in core_motif_defs):
    display(vdomh.h3(motif_name))
    viz_sequence.plot_weights(core_motif_cwms[motif_name])