In [None]:
%%capture --no-display
%load_ext autoreload
%autoreload 2
import sys
import os
sys.path.append(os.path.abspath("../src/"))
import model.profile_models as profile_models
import model.train_profile_model as train_profile_model
import model.spline as spline
import feature.util as feature_util
import feature.make_profile_dataset as make_profile_dataset
import keras
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as colors
import json
import pyBigWig
import tqdm
# Instantiate an empty notebook progress bar and discard it.
# NOTE(review): this appears to be kept for its side effect of loading the
# `tqdm.notebook` submodule, which a later cell uses via
# `tqdm.notebook.trange` — confirm before removing.
tqdm.tqdm_notebook(range(0))

### Define paths for the model and data of interest

In [None]:
# num_tasks = 9
# fold_num = 8
# run_num = 2
# epoch_num = 7
# files_spec_path = "/users/amtseng/tfmodisco/data/processed/ENCODE/config/{0}/{0}_training_paths.json".format(tf_name)
# # model_path = "/users/amtseng/tfmodisco/models/trained_models/%s_fold%d/%d/model_ckpt_epoch_%d.h5" % (tf_name, fold_num, run_num, epoch_num)

# model_path = "/users/amtseng/tfmodisco/models/trained_models/MAFK_finetune/tasks_separate_twoconv_univspline_concatlater/54/model_ckpt_epoch_7.h5"
# # model_path = "/users/amtseng/tfmodisco/models/trained_models/MAFK_finetune/tasks_separate_twoconv_univspline/27/model_ckpt_epoch_15.h5"

In [None]:
# Transcription factor under study, number of tasks the model was built for,
# and the key into the chromosome-splits JSON
tf_name, num_tasks, fold_num = "SPI1", 4, 7

# JSON spec that lists the peak BEDs and the profile HDF5 for this TF
files_spec_path = (
    "/users/amtseng/tfmodisco/data/processed/ENCODE/config/{0}/{0}_training_paths.json"
).format(tf_name)

# Checkpoint to evaluate; the commented-out path is an alternative checkpoint
# model_path = "/users/amtseng/tfmodisco/models/trained_models/SPI1_fold7/3/model_ckpt_epoch_9.h5"
model_path = "/users/amtseng/tfmodisco/models/trained_models/SPI1_finetune/tasks_separate_twoconv_univspline_concatlater/23/model_ckpt_epoch_6.h5"

In [None]:
# Genome reference files and the JSON that defines the chromosome folds
reference_fasta = "/users/amtseng/genomes/hg38.fasta"
chrom_sizes = "/users/amtseng/genomes/hg38.canon.chrom.sizes"
chrom_splits_json = "/users/amtseng/tfmodisco/data/processed/ENCODE/chrom_splits.json"

# bigWig track opened later for the per-region mappability lookup
mappability_path = "/users/amtseng/tfmodisco/data/processed/mappability/hg38.k24.umap.bw"

# Size of the sequence window given to the sequence fetcher, and the length of
# the profiles read from the dataset
input_length, profile_length = 1346, 1000

In [None]:
# Load the JSON spec for this TF, then pull out the list of peak BED paths
# and the path to the profile HDF5
with open(files_spec_path, "r") as spec_file:
    files_spec = json.load(spec_file)

peaks_beds = files_spec["peak_beds"]
profile_hdf5 = files_spec["profile_hdf5"]

In [None]:
# Read the fold definitions and select the partition for `fold_num`
# (the JSON is keyed by the fold number as a string)
with open(chrom_splits_json, "r") as splits_file:
    chrom_splits = json.load(splits_file)
split = chrom_splits[str(fold_num)]

# Chromosome names per set, plus their concatenation in train/val/test order
train_chroms, val_chroms, test_chroms = (
    split[set_name] for set_name in ("train", "val", "test")
)
all_chroms = train_chroms + val_chroms + test_chroms

In [None]:
# Load the trained checkpoint. The custom objects below are handed to Keras by
# name so it can resolve them while deserializing the saved model.
custom_objects = dict(
    kb=keras.backend,
    profile_loss=train_profile_model.get_profile_loss_function(
        num_tasks, profile_length
    ),
    count_loss=train_profile_model.get_count_loss_function(num_tasks),
    SplineWeight1D=spline.SplineWeight1D,
)
model = keras.models.load_model(model_path, custom_objects=custom_objects)

# Uncomment to render the architecture:
# keras.utils.plot_model(model, show_shapes=True)

### Data preparation
Use classes from `make_profile_dataset` to prepare positive and negative inputs.

In [None]:
# Callable that maps coordinates to 1-hot encoded sequence; `input_length` is
# passed as the size of the center window to use
coords_to_seq = feature_util.CoordsToSeq(
    reference_fasta, center_size_to_use=input_length
)

# Callable that maps coordinates to profiles read from the profile HDF5
coords_to_vals = make_profile_dataset.CoordsToVals(
    profile_hdf5, profile_length
)

def coords_to_network_inputs(coords):
    """
    Maps a batch of coordinates to the sequences and profiles for the network.
    Returns the pair (sequences, profiles): the output of `coords_to_seq`, and
    the output of `coords_to_vals` with its axes 1 and 2 exchanged.
    """
    sequences = coords_to_seq(coords)
    profiles = coords_to_vals(coords)
    # Exchange axes 1 and 2; callers slice axis 1 by task afterwards
    profiles = np.swapaxes(profiles, 1, 2)
    return sequences, profiles

In [None]:
# Read each task's gzipped peak BED, keep only peaks on chromosomes that belong
# to this fold, and retain the first three columns (used downstream as
# chromosome, start, end)
pos_coords = []
for task_index in range(num_tasks):
    pos_coords_table = pd.read_csv(
        peaks_beds[task_index], sep="\t", header=None, compression="gzip"
    )
    on_fold_chroms = pos_coords_table[0].isin(all_chroms)
    pos_coords_table = pos_coords_table.loc[on_fold_chroms]
    pos_coords.append(pos_coords_table.values[:, :3])

In [None]:
# Mappability of regions: open the mappability bigWig for reading. The handle
# is kept open at module level because `get_mappability` queries it for every
# batch of coordinates; it is not closed anywhere in the visible cells.
mappability_reader = pyBigWig.open(mappability_path, "r")

### Predicting

In [None]:
def predict_coords(coords):
    """
    Runs the model on the given coordinates and returns its predictions
    together with the observed values read from the dataset.

    Returns a 4-tuple of NumPy arrays:
    (log_pred_profs, log_pred_counts, true_profs, true_counts).
    Predicted profiles are log probabilities and predicted counts are log
    counts; the true values are left untransformed.
    """
    sequences, all_profs = coords_to_network_inputs(coords)

    # The first `num_tasks` entries along axis 1 are the observed profiles;
    # the remainder is fed to the model as its second input (presumably the
    # control tracks, going by the variable name — TODO confirm)
    true_profs = all_profs[:, :num_tasks, :, :]
    cont_profs = all_profs[:, num_tasks:, :, :]

    # Total counts are the profiles summed along axis 2
    true_counts = true_profs.sum(axis=2)

    logit_pred_profs, log_pred_counts = model.predict([sequences, cont_profs])

    # The network emits logits for profiles; convert to log probabilities
    log_pred_profs = profile_models.profile_logits_to_log_probs(logit_pred_profs)

    return log_pred_profs, log_pred_counts, true_profs, true_counts

In [None]:
def get_mappability(coords):
    """
    Computes the average mappability of each interval in `coords`, a B x 3
    object array whose rows are (chromosome, start, end). NaN values reported
    by the track are counted as 0 in the average. Returns a B-array holding
    one average per interval.
    """
    interval_means = []
    for row in coords:
        track_values = mappability_reader.values(row[0], row[1], row[2])
        # NaNs become 0 so they pull the mean down rather than poisoning it
        interval_means.append(np.mean(np.nan_to_num(track_values)))
    return np.array(interval_means)

In [None]:
# For every task, run the model over that task's peaks in batches and collect
# the predictions, observed values, and mappability for that task only
preds = []
for task_index in range(num_tasks):
    print("Predicting on task %d" % task_index)
    task_coords = pos_coords[task_index]
    num_coords = len(task_coords)

    # Preallocated per-peak outputs; the trailing axis of size 2 is presumably
    # the two strands (the plotting cell treats it that way) — TODO confirm
    log_pred_profs = np.empty((num_coords, profile_length, 2))
    true_profs = np.empty((num_coords, profile_length, 2))
    log_pred_counts = np.empty((num_coords, 2))
    true_counts = np.empty((num_coords, 2))
    mappability = np.empty((num_coords,))

    batch_size = 128
    num_batches = int(np.ceil(num_coords / batch_size))
    for i in tqdm.notebook.trange(num_batches):
        batch_slice = slice(i * batch_size, (i + 1) * batch_size)
        batch_coords = task_coords[batch_slice]

        batch_outputs = predict_coords(batch_coords)
        b_mappability = get_mappability(batch_coords)

        # predict_coords returns values for all tasks, in the same order as
        # the destination arrays below; keep only this task's entry
        destinations = (log_pred_profs, log_pred_counts, true_profs, true_counts)
        for destination, batch_values in zip(destinations, batch_outputs):
            destination[batch_slice] = batch_values[:, task_index]
        mappability[batch_slice] = b_mappability

    preds.append(dict(
        log_pred_profs=log_pred_profs,
        log_pred_counts=log_pred_counts,
        true_profs=true_profs,
        true_counts=true_counts,
        mappability=mappability,
    ))

### View count correlations

In [None]:
# For each task, draw one figure with four panels (train / val / test / all
# chromosomes) of log true counts vs. log predicted counts, colored by the
# mean mappability of each peak.
for task_index in range(num_tasks):
    log_true_counts = np.log(preds[task_index]["true_counts"] + 1)
    log_pred_counts = preds[task_index]["log_pred_counts"]
    mappability = preds[task_index]["mappability"]
    mappability = np.stack([mappability, mappability], axis=1)  # Tile into N x 2 to match counts
    pos_coords_chroms = pos_coords[task_index][:, 0]

    # One fixed color scale shared by every panel and by the colorbar. Without
    # it, each scatter call autoscales colors to its own data range, so the
    # same color means different mappability in different panels and none of
    # them is guaranteed to match the colorbar.
    # Assumes mean mappability lies in [0, 1] — TODO confirm for this track.
    mappability_norm = colors.Normalize(vmin=0, vmax=1)

    fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(40, 10))
    for i, (desc, chroms) in enumerate([
        ("train", train_chroms), ("val", val_chroms), ("test", test_chroms), ("all", all_chroms)
    ]):
        mask = np.isin(pos_coords_chroms, chroms)
        # Plot counts scatterplot of the specified index, treating strands separately
        ax[i].scatter(
            np.ravel(log_true_counts[mask]), np.ravel(log_pred_counts[mask]),
            c=np.ravel(mappability[mask]), cmap=cm.plasma, norm=mappability_norm,
            alpha=0.1,
        )

        # Draw y = x, but don't change axes
        xlims = ax[i].get_xlim()
        ylims = ax[i].get_ylim()
        min_limit, max_limit = np.min([xlims, ylims]), np.max([xlims, ylims])
        ax[i].plot([min_limit, max_limit], [min_limit, max_limit], color="black")
        ax[i].set_xlim(xlims)
        ax[i].set_ylim(ylims)

        ax[i].set_xlabel("Log true counts")
        ax[i].set_ylabel("Log predicted counts")
        ax[i].set_title("Count predictions of task %d on %s chromosomes" % (task_index, desc))

    # The colorbar reuses the exact norm and colormap the points were drawn with
    fig.colorbar(
        cm.ScalarMappable(norm=mappability_norm, cmap=cm.plasma), ax=ax,
        label="Mean mappability",
    )