In [1]:
import os
import math
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchinfo import summary
from model_structure import (SequenceSignal, 
                          CNN_model,
                          train_val_loops, 
                          transformer_model)
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from scipy.stats import pearsonr, gaussian_kde
from sklearn.preprocessing import StandardScaler

In [2]:
# Read metadata file as a dictionary for later
metadata = pd.read_csv('../metadata_files/ATAC_ids_complete_data.csv', header = None)
metadata_dict = dict()
for j in range(metadata.shape[0]):
    metadata_dict[metadata.iloc[j, 0]] = metadata.iloc[j, 1]

In [9]:
# define model
BATCH_SIZE = 2048
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_NAME = 'ATAC_transformer.pth'
OUTPUT_SHAPE = 9
CRITERION = torch.nn.MSELoss()

model = transformer_model.TransformerCNNMixtureModel(n_conv_layers = 4, 
                                                     n_filters = [256, 60, 60, 120], 
                                                     kernel_sizes = [7, 3, 5, 3], 
                                                     dilation = [1, 1, 1, 1], 
                                                     drop_conv = 0.1, 
                                                     n_fc_layers = 2, 
                                                     drop_fc = 0.4, 
                                                     n_neurons = [256, 256], 
                                                     #n_neurons = [128, 128], 
                                                     output_size = OUTPUT_SHAPE, 
                                                     drop_transformer=0.2, 
                                                     input_size=4, 
                                                     n_encoder_layers = 2, 
                                                     n_heads=8, 
                                                     n_transformer_FC_layers=256) 
                                                     #n_transformer_FC_layers=128)

model.to(device=DEVICE)
state = torch.load(CHECKPOINT_NAME, map_location=DEVICE)
state_dict = state["network"]

# Handle compiled model keys
if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
    print("Detected compiled model weights, stripping '_orig_mod.' prefixes...")
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

model.load_state_dict(state_dict, strict=True)
summary(model)

RuntimeError: Error(s) in loading state_dict for TransformerCNNMixtureModel:
	Missing key(s) in state_dict: "transformerEncoder.0.self_attn.in_proj_weight", "transformerEncoder.0.self_attn.in_proj_bias", "transformerEncoder.0.self_attn.out_proj.weight", "transformerEncoder.0.self_attn.out_proj.bias", "transformerEncoder.0.linear1.weight", "transformerEncoder.0.linear1.bias", "transformerEncoder.0.linear2.weight", "transformerEncoder.0.linear2.bias", "transformerEncoder.0.norm1.weight", "transformerEncoder.0.norm1.bias", "transformerEncoder.0.norm2.weight", "transformerEncoder.0.norm2.bias", "transformerEncoder.1.self_attn.in_proj_weight", "transformerEncoder.1.self_attn.in_proj_bias", "transformerEncoder.1.self_attn.out_proj.weight", "transformerEncoder.1.self_attn.out_proj.bias", "transformerEncoder.1.linear1.weight", "transformerEncoder.1.linear1.bias", "transformerEncoder.1.linear2.weight", "transformerEncoder.1.linear2.bias", "transformerEncoder.1.norm1.weight", "transformerEncoder.1.norm1.bias", "transformerEncoder.1.norm2.weight", "transformerEncoder.1.norm2.bias". 
	Unexpected key(s) in state_dict: "transformerEncoder.layers.0.self_attn.in_proj_weight", "transformerEncoder.layers.0.self_attn.in_proj_bias", "transformerEncoder.layers.0.self_attn.out_proj.weight", "transformerEncoder.layers.0.self_attn.out_proj.bias", "transformerEncoder.layers.0.linear1.weight", "transformerEncoder.layers.0.linear1.bias", "transformerEncoder.layers.0.linear2.weight", "transformerEncoder.layers.0.linear2.bias", "transformerEncoder.layers.0.norm1.weight", "transformerEncoder.layers.0.norm1.bias", "transformerEncoder.layers.0.norm2.weight", "transformerEncoder.layers.0.norm2.bias", "transformerEncoder.layers.1.self_attn.in_proj_weight", "transformerEncoder.layers.1.self_attn.in_proj_bias", "transformerEncoder.layers.1.self_attn.out_proj.weight", "transformerEncoder.layers.1.self_attn.out_proj.bias", "transformerEncoder.layers.1.linear1.weight", "transformerEncoder.layers.1.linear1.bias", "transformerEncoder.layers.1.linear2.weight", "transformerEncoder.layers.1.linear2.bias", "transformerEncoder.layers.1.norm1.weight", "transformerEncoder.layers.1.norm1.bias", "transformerEncoder.layers.1.norm2.weight", "transformerEncoder.layers.1.norm2.bias". 

In [4]:
# Load mean and std used for z-score normalization
Z_SCORE_INVERSE_MEANS = '../net_output/z_score_means.npy'
Z_SCORE_INVERSE_STD = '../net_output/z_score_std.npy'
means = torch.from_numpy(np.load(Z_SCORE_INVERSE_MEANS)).float().to(DEVICE)
stds = torch.from_numpy(np.load(Z_SCORE_INVERSE_STD)).float().to(DEVICE)

In [5]:
X_TRAIN_PATH = '../net_output/dataset_1kb_300bp_train_augmented_encoding.npy'
Y_TRAIN_PATH = '../net_output/train_target_Z_scores.npy'
X_VAL_PATH = '../net_output/dataset_1kb_300bp_test_encoding.npy'
Y_VAL_PATH = '../net_output/test_target_ATAC.npy'
N_VAL_EXAMPLES = np.load(Y_VAL_PATH).shape[0]

dataloaders = SequenceSignal.load_dataset(X_TRAIN_PATH, 
                                          Y_TRAIN_PATH, 
                                          X_VAL_PATH, 
                                          Y_VAL_PATH, 
                                          BATCH_SIZE, 
                                          device = DEVICE, shuffle=False)

means = torch.from_numpy(np.load(Z_SCORE_INVERSE_MEANS)).float().to(DEVICE)
stds = torch.from_numpy(np.load(Z_SCORE_INVERSE_STD)).float().to(DEVICE)

model.eval()
y_predict = torch.zeros((N_VAL_EXAMPLES, OUTPUT_SHAPE), device = DEVICE)
y_true = torch.zeros((N_VAL_EXAMPLES, OUTPUT_SHAPE), device = DEVICE)

with torch.inference_mode():
    
    for i, data in enumerate(dataloaders[1], 0):
        inputs, labels = data
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        outputs = model(inputs)
        outputs_denorm = outputs * stds + means
        start_idx = i * BATCH_SIZE
        end_idx = start_idx + BATCH_SIZE

        y_predict[start_idx:end_idx] = outputs_denorm
        y_true[start_idx:end_idx] = labels

mse = (CRITERION(y_predict, y_true)).cpu().numpy()

y_true_array = y_true.cpu().numpy()
y_predict_array = y_predict.cpu().numpy()

In [6]:
for j in range(OUTPUT_SHAPE):
    print(pearsonr(y_true_array[:, j], y_predict_array[:, j]))

PearsonRResult(statistic=np.float32(0.6173046), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.62245554), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.66681993), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.65547746), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.6535795), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.70623577), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.7135478), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.7096356), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.6556064), pvalue=np.float64(0.0))


In [7]:
mse

array(1.0726405, dtype=float32)

## Heatmap

In [16]:
scaler = StandardScaler()
data1 = y_predict_array[:10000, :]

# Cluster just once
g = sns.clustermap(data1,
                   cmap='icefire',
                   yticklabels=False,
                   xticklabels=False,
                   figsize=(6, 10))

# Save row and column order
row_order = g.dendrogram_row.reordered_ind
col_order = g.dendrogram_col.reordered_ind

# Save for reuse
np.save("row_order.npy", row_order)
np.save("col_order.npy", col_order)

plt.close()

data2 = y_true_array[:10000, :]
row_order = np.load("row_order.npy")
col_order = np.load("col_order.npy")

# Reorder both datasets
data1_reordered = data1[row_order, :][:, col_order]
data2_reordered = data2[row_order, :][:, col_order]

vmin = min(data1_reordered.min(), data2_reordered.min())
vmax = max(data1_reordered.max(), data2_reordered.max())

fig, axes = plt.subplots(1, 2, figsize=(12, 16), gridspec_kw={'wspace': 0.1}, constrained_layout=True)

hm1 = sns.heatmap(data1_reordered,
                  cmap='vlag',
                  vmin=vmin, vmax=vmax,
                  yticklabels=False,
                  xticklabels=['E5', 'E11', 'E13', 'EAD', 'HID', 'WID', 'LB', 'AB', 'O'],
                  cbar=False,
                  ax=axes[0])
axes[0].set_title("Predicho", fontsize = 16)
axes[0].tick_params(axis='x', labelsize=15)

sns.heatmap(data2_reordered,
            cmap='vlag',
            vmin=vmin, vmax=vmax,
            yticklabels=False,
            xticklabels=['E5', 'E11', 'E13', 'EAD', 'HID', 'WID', 'LB', 'AB', 'O'],
            cbar=False,
            ax=axes[1])
axes[1].set_title("Observado", fontsize=16)
axes[1].tick_params(axis='x', labelsize=15)

cbar = fig.colorbar(hm1.collections[0],
                    ax=axes,
                    orientation="horizontal",
                    fraction=0.03,
                    pad=0.08)
cbar.set_label(r'$\log_{2}$' + '(ATAC-seq)', fontsize=16) 
cbar.ax.tick_params(labelsize=16) 

for ax in axes:
    ncols = data1_reordered.shape[1]
    for x in range(1, ncols):
        ax.axvline(x, color="white", linewidth=0.5)

plt.savefig("comparison.png", format="png", dpi=300, bbox_inches="tight")
plt.close()

## Correlation plots

In [28]:
contexts = ['E5', 'E11', 'E13', 'EAD', 'HID', 'WID', 'LB', 'AB', 'O']
sns.set_theme(style="white", font_scale=1.5)
fig, axes = plt.subplots(3, 3, figsize=(15, 15))
axes = axes.ravel() 

for i, j in enumerate(contexts):
    
    print(f'Processing {j} ...')
    
    x = y_true_array[:, i]
    y = y_predict_array[:, i]
    xy = np.vstack([x, y])
    z = gaussian_kde(xy)(xy)
    r, _ = pearsonr(x, y)

    sc = axes[i].scatter(x, y, c=z, cmap="vlag", s = 3, alpha=0.4)
    
    axes[i].plot([min(x), max(x)], [min(x), max(x)], 'k--', linewidth=1.5)
    
    axes[i].set_xlabel(r'$\log_{2}$' + '(ATAC-seq) observado')
    axes[i].set_ylabel(r'$\log_{2}$' + '(ATAC-seq) predicho')
    
    axes[i].text(0.05, 0.95, fr'$\rho$ = {r:.2f}',
                 transform=axes[i].transAxes,
                 fontsize=15, verticalalignment='top',
                 bbox=dict(boxstyle="round,pad=0.3",
                           facecolor="white", edgecolor="gray"))
    
    axes[i].set_title(f"Contexto: {j}")

from matplotlib.ticker import MaxNLocator
for ax in axes:
    ax.xaxis.set_major_locator(MaxNLocator(nbins=5))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
    
plt.tight_layout()
plt.savefig("figura_corr2.png", dpi=200)
plt.close()

Processing E5 ...
Processing E11 ...
Processing E13 ...
Processing EAD ...
Processing HID ...
Processing WID ...
Processing LB ...
Processing AB ...
Processing O ...


In [29]:
#Predicho vs observado
contexts_comb = [('EAD', 'AB'), ('HID', 'AB'), ('WID', 'AB'), ('AB', 'E11'), ('AB', 'EAD'), ('AB', 'WID'), ('LB', 'O'), ('AB', 'O'), ('LB', 'HID')]
contexts = {'E5': 0, 'E11': 1, 'E13': 2, 'EAD': 3, 'HID': 4, 'WID': 5, 'LB': 6, 'AB': 7, 'O': 8}
sns.set_theme(style="white", font_scale=1.5)
fig, axes = plt.subplots(3, 3, figsize=(15, 15))
axes = axes.ravel() 

for j, (i, k) in enumerate(contexts_comb):
    
    print(f'Processing {i}, {k} ...')
    
    x = y_true_array[:, contexts[k]]
    y = y_predict_array[:, contexts[i]]
    xy = np.vstack([x, y])
    z = gaussian_kde(xy)(xy)
    r, _ = pearsonr(x, y)

    sc = axes[j].scatter(x, y, c=z, cmap="vlag", s = 3, alpha=0.4)
    
    axes[j].plot([min(x), max(x)], [min(x), max(x)], 'k--', linewidth=1.5)
    
    axes[j].set_xlabel(r'$\log_{2}$' + f'(ATAC-seq) observado en {k}')
    axes[j].set_ylabel(r'$\log_{2}$' + f'(ATAC-seq) predicho en {i}')
    
    axes[j].text(0.05, 0.95, fr'$\rho$ = {r:.2f}',
                 transform=axes[j].transAxes,
                 fontsize=14, verticalalignment='top',
                 bbox=dict(boxstyle="round,pad=0.3",
                           facecolor="white", edgecolor="gray"))
    #axes[i].set_title(f"Contextos: {i}, {k}")

from matplotlib.ticker import MaxNLocator
for ax in axes:
    ax.xaxis.set_major_locator(MaxNLocator(nbins=5))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
    
plt.tight_layout()
plt.savefig("figura_corr_no_match.png", dpi=200)
plt.close()

Processing EAD, AB ...
Processing HID, AB ...
Processing WID, AB ...
Processing AB, E11 ...
Processing AB, EAD ...
Processing AB, WID ...
Processing LB, O ...
Processing AB, O ...
Processing LB, HID ...


## Model predictions to make bigwig files

In [5]:
#val = SequenceSignal.Sequence('model_validation/windows_1kb_1bp_3L_window_encoding.npy', device = DEVICE)
val = SequenceSignal.Sequence('../../predict_intervals_tests/windows_predict_1kp_10bp_validation_encoding.npy', 
                              device = DEVICE)
size = len(val)
val = DataLoader(val, batch_size = BATCH_SIZE, shuffle = False)

model.eval()
y_predict = torch.zeros((size, OUTPUT_SHAPE), device = DEVICE)

with torch.inference_mode():

    for i, data in enumerate(val, 0):
        outputs = model(data)
        
        outputs_denorm = outputs * stds + means
        start_idx = i * BATCH_SIZE
        end_idx = start_idx + BATCH_SIZE

        y_predict[start_idx:end_idx] = outputs_denorm
        
y_predict_array = (2**y_predict).cpu().numpy()
#y_predict_array = (y_predict).cpu().numpy()

In [6]:
np.savetxt("../../predict_intervals_tests/predictions_model_val.txt", y_predict_array)