In [None]:
import sys
import os
sys.path.append(os.path.abspath("../../src/"))
import feature.util as feature_util
import feature.make_profile_dataset as make_profile_dataset
import model.profile_performance as profile_performance
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
import sklearn.cluster
import scipy.cluster.hierarchy
import json
import tqdm
tqdm.tqdm_notebook(range(0))

### Define paths for data of interest

In [None]:
# Name of the transcription factor; substituted into the training-paths JSON path below
tf_name = "GABPA"
# Added to `task_index` later in the notebook to select the control track
# (presumably the total number of tasks for this TF -- TODO confirm)
num_tasks = 9
# Fold identifier; its string form is the key looked up in the chromosome-splits JSON
fold_num = 9
# Index used to pick both the peak BED file and the target profile track
task_index = 7

# Note that this only works for single task queries right now

# JSON file from which the "peak_beds" and "profile_hdf5" entries are read
files_spec_path = "/users/amtseng/tfmodisco/data/processed/ENCODE/config/{0}/{0}_training_paths.json".format(tf_name)

In [None]:
# Read the JSON file specification, then pull out the two entries this
# notebook uses: the peak BED files and the profile HDF5
with open(files_spec_path, "r") as f:
    files_spec = json.loads(f.read())
peaks_beds, profile_hdf5 = files_spec["peak_beds"], files_spec["profile_hdf5"]

In [None]:
# Define the paths to the files and model, and some constants
# Reference genome FASTA handed to the coordinate-to-sequence converter
reference_fasta = "/users/amtseng/genomes/hg38.fasta"
# NOTE(review): `chrom_sizes` is not referenced in the visible part of this notebook
chrom_sizes = "/users/amtseng/genomes/hg38.canon.chrom.sizes"
# JSON of chromosome splits, keyed by fold number (as a string)
chrom_splits_json = "/users/amtseng/tfmodisco/data/processed/ENCODE/chrom_splits.json"
# Width (bp) of the summit-centered coordinate windows and of the input sequences
input_length = 2114
# Length of each profile; also the size of axis 1 of the arrays that collect profiles
profile_length = 1000

# If True, the reverse-complemented sequences and flipped profiles are appended to each batch
revcomp = False
# Which chromosome partition to use: "train", "val", or "test"; any other
# value pools all three partitions
chrom_set_key = "test"

In [None]:
# Load the chromosome splits and select the entry for the fold of interest
with open(chrom_splits_json, "r") as f:
    chrom_splits = json.load(f)
split = chrom_splits[str(fold_num)]

# A recognized key selects that single partition; anything else pools the
# train, validation, and test chromosomes together
chrom_set = (
    split[chrom_set_key]
    if chrom_set_key in ("train", "val", "test")
    else split["train"] + split["val"] + split["test"]
)

### Data preparation
Prepare inputs

In [None]:
# Maps coordinates to 1-hot encoded sequence, read from the reference FASTA
# (presumably `center_size_to_use` re-centers each interval to `input_length` bp -- TODO confirm)
coords_to_seq = feature_util.CoordsToSeq(reference_fasta, center_size_to_use=input_length)

# Maps coordinates to profiles read from the profile HDF5
# (presumably each profile is `profile_length` bp around the interval center -- TODO confirm)
coords_to_vals = make_profile_dataset.CoordsToVals(profile_hdf5, profile_length)

# Maps many coordinates to inputs sequences and profiles for the network
def coords_to_network_inputs(coords):
    """
    Converts a batch of coordinates into input sequences and profiles.
    Arguments:
        `coords`: the coordinates to convert, passed unchanged to both
            `coords_to_seq` and `coords_to_vals`
    Returns a pair: the input sequences, and the profiles with axes 1 and 2
    swapped. If the global `revcomp` is set, both arrays are doubled along
    axis 0: the original entries come first, followed by the sequences flipped
    along axes 1 and 2 and the profiles flipped along axes 1 and 3.
    """
    seqs = coords_to_seq(coords)
    profiles = coords_to_vals(coords)

    if not revcomp:
        return seqs, np.swapaxes(profiles, 1, 2)

    # Flipped copies are stacked after the originals
    seqs_rc = np.flip(seqs, axis=(1, 2))
    profiles_rc = np.flip(profiles, axis=(1, 3))
    all_seqs = np.concatenate([seqs, seqs_rc])
    all_profiles = np.concatenate([profiles, profiles_rc])
    return all_seqs, np.swapaxes(all_profiles, 1, 2)

In [None]:
# Import set of positive peaks for the chosen task (gzipped, tab-separated,
# no header row), keeping only peaks whose column 0 is in `chrom_set`
pos_coords_table = pd.read_csv(peaks_beds[task_index], sep="\t", header=None, compression="gzip")
pos_coords_table = pos_coords_table[pos_coords_table[0].isin(chrom_set)]

# Summit-center: column 9 is added to the start in column 1 (presumably the
# narrowPeak summit offset -- TODO confirm), and the result is shifted left by
# half the input length; the end in column 2 is then set to exactly
# `input_length` past the new start. Order matters: the second assignment
# reads the start written by the first.
arr = pos_coords_table.values
arr[:, 1] = arr[:, 1] + arr[:, 9] - (input_length // 2)
arr[:, 2] = arr[:, 1] + input_length
# Keep only the first three columns as the coordinates
pos_coords = arr[:, :3]

### Get the set of all control and target profiles

In [None]:
# Preallocate the arrays that collect the target and control profiles. When
# `revcomp` is set, each coordinate contributes two samples: rows
# [0, num_coords) hold the original orientation and rows
# [num_coords, 2 * num_coords) hold the reverse complements.
num_coords = len(pos_coords)
num_samples = num_coords * (2 if revcomp else 1)
target_profs = np.empty((num_samples, profile_length, 2))
cont_profs = np.empty((num_samples, profile_length, 2))

batch_size = 128
num_batches = int(np.ceil(num_coords / batch_size))
for i in tqdm.notebook.trange(num_batches):
    batch_start = i * batch_size
    batch_end = min(batch_start + batch_size, num_coords)
    batch_slice = slice(batch_start, batch_end)
    num_in_batch = batch_end - batch_start
    _, profiles = coords_to_network_inputs(pos_coords[batch_slice])

    # The target track sits at `task_index`; its control track is offset by `num_tasks`
    batch_target = profiles[:, task_index]
    batch_cont = profiles[:, num_tasks + task_index]

    # The first `num_in_batch` rows are always the original orientation
    target_profs[batch_slice] = batch_target[:num_in_batch]
    cont_profs[batch_slice] = batch_cont[:num_in_batch]

    if revcomp:
        # With `revcomp`, the batch comes back with the reverse complements
        # appended after the originals; store them in the second half of the
        # output arrays instead of forcing them into the same slice
        rc_slice = slice(num_coords + batch_start, num_coords + batch_end)
        target_profs[rc_slice] = batch_target[num_in_batch:]
        cont_profs[rc_slice] = batch_cont[num_in_batch:]

### View profile correlations

In [None]:
def plot_profiles(target_profs, cont_profs, random_state=None):
    """
    Plots the given profiles with a heatmap. The top row shows the mean
    normalized profile (one strand above the axis, the other mirrored below),
    and the two rows beneath show one heatmap per strand. Heatmap rows are
    ordered by k-means clusters of the target profiles, and the same row order
    is applied to the control profiles.
    Arguments:
        `target_profs`: an N x O x 2 NumPy array of target profiles, either as raw
            counts or probabilities (they will be normalized)
        `cont_profs`: an N x O x 2 NumPy array of control profiles, either as
            raw counts or probabilities (they will be normalized)
        `random_state`: optional seed passed to k-means; set it to make the
            row ordering of the heatmaps reproducible across runs
    Raises a ValueError if there are no profiles to plot.
    """
    assert len(target_profs.shape) == 3
    assert target_profs.shape == cont_profs.shape
    num_profs, width, _ = target_profs.shape
    if num_profs == 0:
        raise ValueError("Cannot plot an empty set of profiles")

    # First, normalize the profiles along the output profile dimension
    def normalize(arr, axis=0):
        arr_sum = np.sum(arr, axis=axis, keepdims=True)
        arr_sum[arr_sum == 0] = 1  # If 0, keep 0 as the quotient instead of dividing by 0
        return arr / arr_sum
    target_profs_norm = normalize(target_profs, axis=1)
    cont_profs_norm = normalize(cont_profs, axis=1)

    # Compute the mean profiles across all examples
    target_profs_mean = np.mean(target_profs_norm, axis=0)
    cont_profs_mean = np.mean(cont_profs_norm, axis=0)

    # Perform k-means clustering on the target profiles, with the strands pooled
    # Set number of clusters based on number of profiles, with a minimum of 5,
    # but never more clusters than there are profiles to cluster
    kmeans_clusters = min(num_profs, max(5, num_profs // 50))
    kmeans = sklearn.cluster.KMeans(n_clusters=kmeans_clusters, random_state=random_state)
    cluster_assignments = kmeans.fit_predict(
        np.reshape(target_profs_norm, (num_profs, -1))
    )

    # Perform hierarchical clustering on the cluster centers to determine optimal ordering
    kmeans_centers = kmeans.cluster_centers_
    if len(kmeans_centers) > 2:
        cluster_order = scipy.cluster.hierarchy.leaves_list(
            scipy.cluster.hierarchy.optimal_leaf_ordering(
                scipy.cluster.hierarchy.linkage(kmeans_centers, method="centroid"), kmeans_centers
            )
        )
    else:
        # With one or two clusters every ordering is equivalent, and linkage
        # cannot be computed on a single center
        cluster_order = np.arange(len(kmeans_centers))

    # Order the profiles so that the cluster assignments follow the optimal ordering
    cluster_inds = np.concatenate([
        np.where(cluster_assignments == cluster_id)[0] for cluster_id in cluster_order
    ])

    # Compute a matrix of profiles, normalized to the maximum height, ordered by clusters
    def make_profile_matrix(flat_profs, order_inds):
        matrix = flat_profs[order_inds]
        maxes = np.max(matrix, axis=1, keepdims=True)
        maxes[maxes == 0] = 1  # If 0, keep 0 as the quotient instead of dividing by 0
        return matrix / maxes
    target_matrix = make_profile_matrix(target_profs_norm, cluster_inds)
    cont_matrix = make_profile_matrix(cont_profs_norm, cluster_inds)

    # Create a figure with the right dimensions; the heatmaps grow with the
    # number of profiles, up to a cap
    mean_height = 4
    heatmap_height = min(num_profs * 0.004, 8)
    fig_height = mean_height + (2 * heatmap_height)
    fig, ax = plt.subplots(
        3, 2, figsize=(16, fig_height), sharex=True,
        gridspec_kw={
            "width_ratios": [1, 1],
            "height_ratios": [mean_height / fig_height, heatmap_height / fig_height, heatmap_height / fig_height]
        }
    )

    # Plot the average profiles; the second strand is negated so it is drawn below the axis
    ax[0, 0].plot(target_profs_mean[:, 0], color="darkslateblue")
    ax[0, 0].plot(-target_profs_mean[:, 1], color="darkorange")
    ax[0, 1].plot(cont_profs_mean[:, 0], color="darkslateblue")
    ax[0, 1].plot(-cont_profs_mean[:, 1], color="darkorange")

    # Set axes on average profiles, sharing symmetric limits between both panels
    max_mean_val = max(np.max(target_profs_mean), np.max(cont_profs_mean))
    mean_ylim = max_mean_val * 1.05  # Make 5% higher
    ax[0, 0].set_title("Target profiles")
    ax[0, 0].set_ylabel("Average probability")
    ax[0, 1].set_title("Control profiles")
    for j in (0, 1):
        ax[0, j].set_ylim(-mean_ylim, mean_ylim)
        ax[0, j].label_outer()

    # Plot the heatmaps, one row of axes per strand
    ax[1, 0].imshow(target_matrix[:, :, 0], interpolation="nearest", aspect="auto", cmap="Blues")
    ax[1, 1].imshow(cont_matrix[:, :, 0], interpolation="nearest", aspect="auto", cmap="Blues")
    ax[2, 0].imshow(target_matrix[:, :, 1], interpolation="nearest", aspect="auto", cmap="Oranges")
    ax[2, 1].imshow(cont_matrix[:, :, 1], interpolation="nearest", aspect="auto", cmap="Oranges")

    # Set axes on heatmaps
    for i in (1, 2):
        for j in (0, 1):
            ax[i, j].set_yticks([])
            ax[i, j].set_yticklabels([])
            ax[i, j].label_outer()

    # Label the x-axis by distance from the profile center, with a tick every `delta` bp
    delta = 100
    num_deltas = (width // 2) // delta
    labels = list(range(max(-width // 2, -num_deltas * delta), min(width // 2, num_deltas * delta) + 1, delta))
    tick_locs = [label + max(width // 2, num_deltas * delta) for label in labels]
    for j in (0, 1):
        ax[2, j].set_xticks(tick_locs)
        ax[2, j].set_xticklabels(labels)
        ax[2, j].set_xlabel("Distance from summit center (bp)")

    fig.tight_layout()
    plt.show()

In [None]:
# Show mean profiles and clustered per-strand heatmaps for the target and control tracks
plot_profiles(target_profs, cont_profs)

In [None]:
# Compute profile correlations between the target and control tracks. A
# singleton axis is inserted at position 1 of both arrays before the call
# (presumably the task dimension that `profile_corr_mse` expects -- TODO confirm).
# NOTE(review): the meaning of the positional arguments `1, 0, False, False`
# is defined in the `model.profile_performance` module and is not visible here.
prof_pears, prof_spear, _ = profile_performance.profile_corr_mse(
    np.expand_dims(target_profs, axis=1), np.expand_dims(cont_profs, axis=1),
    1, 0, False, False
)
# Keep only column 0 of each result (presumably the single task added above)
prof_pears, prof_spear = prof_pears[:, 0], prof_spear[:, 0]

In [None]:
# Histograms of the target-vs-control profile correlations:
# Pearson on top, Spearman below
fig, ax = plt.subplots(nrows=2, figsize=(20, 8))
ax[0].hist(prof_pears, bins=100)
ax[1].hist(prof_spear, bins=100)
ax[0].set_xlabel("Profile Pearson correlations")
ax[1].set_xlabel("Profile Spearman correlations")
plt.show()

### View count correlations

In [None]:
# Total counts per profile and strand (summed over the profile axis), then
# log(x + 1)-transformed and flattened so that both strands are pooled
target_counts = target_profs.sum(axis=1)
cont_counts = cont_profs.sum(axis=1)
log_target_counts = np.log(target_counts + 1).ravel()
log_cont_counts = np.log(cont_counts + 1).ravel()

fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(log_target_counts, log_cont_counts, alpha=0.1)

# Draw the y = x diagonal across the full extent of both axes, then restore
# the limits that the scatter plot chose on its own
xlims, ylims = ax.get_xlim(), ax.get_ylim()
min_limit = np.min([xlims, ylims])
max_limit = np.max([xlims, ylims])
diagonal = [min_limit, max_limit]
ax.plot(diagonal, diagonal, color="black", linestyle="--")
ax.set_xlim(xlims)
ax.set_ylim(ylims)

ax.set_xlabel("Log target counts")
ax.set_ylabel("Log control counts")
plt.show()

# Correlation of the log counts between target and control
pearson_r = scipy.stats.pearsonr(log_target_counts, log_cont_counts)[0]
spearman_r = scipy.stats.spearmanr(log_target_counts, log_cont_counts)[0]
print("Pearson: %f" % pearson_r)
print("Spearman: %f" % spearman_r)