In [None]:
import os
import math
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torchinfo import summary
from model_structure import (SequenceSignal, 
                             transformer_model, 
                             train_val_loops)
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from scipy.stats import pearsonr, gaussian_kde

## Train the CNN + Transformer model

In [None]:
# Global variables: dataset locations, split sizes and loader settings.
DATA_DIR = Path('../ML_datasets/ATAC_data/')
X_TRAIN_PATH = DATA_DIR.joinpath('dataset_1kb_300bp_train_augmented_encoding.npy')
Y_TRAIN_PATH = DATA_DIR.joinpath('train_target_Z_scores.npy')
X_VAL_PATH = DATA_DIR.joinpath('dataset_1kb_300bp_val_encoding.npy')
Y_VAL_PATH = DATA_DIR.joinpath('val_target.npy')

# Only the number of rows is needed here, so memory-map the target files
# instead of reading the complete arrays into RAM just to inspect their shape.
N_VAL_EXAMPLES = np.load(Y_VAL_PATH, mmap_mode='r').shape[0]
N_TRAIN_EXAMPLES = np.load(Y_TRAIN_PATH, mmap_mode='r').shape[0]

# These two are file paths (not arrays); the arrays are loaded where needed.
Z_SCORE_INVERSE_MEANS = DATA_DIR.joinpath('sample_means.npy')
Z_SCORE_INVERSE_STD = DATA_DIR.joinpath('sample_stds.npy')

BATCH_SIZE = 512
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Later cells use dataloaders[0] as the training loader and dataloaders[1]
# as the validation loader.
dataloaders = SequenceSignal.load_dataset(X_TRAIN_PATH, 
                                          Y_TRAIN_PATH, 
                                          X_VAL_PATH, 
                                          Y_VAL_PATH, 
                                          BATCH_SIZE, 
                                          device = DEVICE)

In [None]:
# Training-run settings.
N_EPOCHS = 30
PATIENCE = 20      # forwarded as `patience` to the training loop
OUTPUT_SHAPE = 9   # number of model outputs (also the prediction column count)
CRITERION = nn.MSELoss()
CHECKPOINT_NAME = DATA_DIR / 'ATAC_transformer.pth'

# CNN + Transformer mixture model; arguments are grouped by sub-network.
model = transformer_model.TransformerCNNMixtureModel(
    # convolutional part
    input_size=4,
    n_conv_layers=4,
    n_filters=[256, 60, 60, 120],
    kernel_sizes=[7, 3, 5, 3],
    dilation=[1, 1, 1, 1],
    drop_conv=0.1,
    # transformer encoder part
    n_encoder_layers=2,
    n_heads=8,
    n_transformer_FC_layers=256,
    drop_transformer=0.2,
    # fully connected head
    n_fc_layers=2,
    n_neurons=[256, 256],
    drop_fc=0.4,
    output_size=OUTPUT_SHAPE,
)

# Move the parameters to the selected device, then show the layer summary
# as the cell output.
model.to(DEVICE)
summary(model)

In [None]:
# AdamW optimiser over all model parameters.
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-3)

# One-cycle schedule spanning the whole run: the total number of steps is
# N_EPOCHS times the number of training batches per epoch.
lr_scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=2e-3,
    epochs=N_EPOCHS,
    steps_per_epoch=math.ceil(N_TRAIN_EXAMPLES / dataloaders[0].batch_size),
    pct_start=0.15,
    anneal_strategy='linear',
)
print(optimizer)

In [None]:
# Reuse a saved checkpoint when one is present; otherwise train from scratch
# (the training call receives CHECKPOINT_NAME as its `model_path`).
if os.path.exists(CHECKPOINT_NAME):

    # Restore the saved training state onto the current device.
    training_state = torch.load(CHECKPOINT_NAME,
                                map_location=DEVICE,
                                weights_only=True)

    optimizer.load_state_dict(training_state['optimizer'])
    lr_scheduler.load_state_dict(training_state['lr_sched'])
    model.load_state_dict(training_state['network'])
    best_valid_loss = training_state['best_valid_loss']

else:

    output = train_val_loops.train_N_epochs(
        model,
        optimizer,
        criterion=CRITERION,
        train_loader=dataloaders[0],
        valid_loader=dataloaders[1],
        num_epochs=N_EPOCHS,
        patience=PATIENCE,
        model_path=CHECKPOINT_NAME,
        lr_scheduler=lr_scheduler,
        means_path=Z_SCORE_INVERSE_MEANS,
        stds_path=Z_SCORE_INVERSE_STD,
        DEVICE=DEVICE,
        use_amp=True,
    )

In [None]:
# Evaluate the model on the held-out TEST split.
# Dedicated *_TEST_* names are used so the validation paths, example count and
# `dataloaders` defined earlier are not overwritten: re-running the training
# cell after this one must not silently validate on the test set.
X_TEST_PATH = DATA_DIR.joinpath('dataset_1kb_300bp_test_encoding.npy')
Y_TEST_PATH = DATA_DIR.joinpath('test_target.npy')
# Only the row count is needed, so memory-map instead of loading the array.
N_TEST_EXAMPLES = np.load(Y_TEST_PATH, mmap_mode='r').shape[0]

# load_dataset is called with train paths as well, matching its use earlier
# in the notebook; only the second loader (index 1, the test split) is used here.
test_dataloaders = SequenceSignal.load_dataset(X_TRAIN_PATH, 
                                               Y_TRAIN_PATH, 
                                               X_TEST_PATH, 
                                               Y_TEST_PATH, 
                                               BATCH_SIZE, 
                                               device = DEVICE)

Y_PREDICT_PATH = DATA_DIR.joinpath('y_predict_ATAC.npy')

# Load mean and std used for z-score normalization of the training targets;
# they are used to map model outputs back to the original scale.
means = torch.from_numpy(np.load(Z_SCORE_INVERSE_MEANS)).float().to(DEVICE)
stds = torch.from_numpy(np.load(Z_SCORE_INVERSE_STD)).float().to(DEVICE)

model.eval()
y_predict = torch.zeros((N_TEST_EXAMPLES, OUTPUT_SHAPE), device = DEVICE)
y_true = torch.zeros((N_TEST_EXAMPLES, OUTPUT_SHAPE), device = DEVICE)

# Running write position. Using the real batch length (instead of
# i * BATCH_SIZE) stays correct even when a batch is shorter than BATCH_SIZE.
offset = 0

with torch.inference_mode():

    for inputs, labels in test_dataloaders[1]:
        outputs = model(inputs)

        # Undo the z-score normalization.
        # Assumes means/stds broadcast against (batch, OUTPUT_SHAPE) - TODO confirm.
        outputs_denorm = outputs * stds + means

        n_batch = outputs_denorm.shape[0]
        y_predict[offset:offset + n_batch] = outputs_denorm
        # Labels are compared against de-normalized outputs, so the test targets
        # are assumed to be on the original (non z-scored) scale - TODO confirm.
        y_true[offset:offset + n_batch] = labels
        offset += n_batch

# Guard against silently scoring zero-filled rows (e.g. a loader that drops
# the last incomplete batch).
assert offset == N_TEST_EXAMPLES, (
    f'filled {offset} of {N_TEST_EXAMPLES} test rows; metrics would include zero rows')

y_true_array = y_true.cpu().numpy()
y_predict_array = y_predict.cpu().numpy()

print(CRITERION(y_predict, y_true))
np.save(Y_PREDICT_PATH, y_predict_array)

In [None]:
# Pearson correlation between observed and predicted values, one output column at a time.
for column in range(OUTPUT_SHAPE):
    observed = y_true_array[:, column]
    predicted = y_predict_array[:, column]
    print(pearsonr(observed, predicted))

In [None]:
# Quick-look scatter of observed (x) vs. predicted (y) values for output column 5.
plt.scatter(x=y_true_array[:, 5],
            y=y_predict_array[:, 5],
            c='salmon',
            s=0.1)

In [None]:
def make_joints_plot2(C, C2, output_dir,
                     contexts = ('E5', 'E11', 'E13', 'EAD', 'HID', 'WID', 'LB', 'AB', 'O'),
                     y_true = None, y_pred = None):
    """Draw and save a density-coloured joint plot of observed vs. predicted values.

    Column ``C`` of the observed array is plotted against column ``C2`` of the
    predicted array. Points are coloured by a Gaussian KDE of the point cloud,
    the marginals are histograms, and the Pearson r is annotated on the plot.

    Parameters
    ----------
    C : int
        Column index into the observed array (x axis).
    C2 : int
        Column index into the predicted array (y axis).
    output_dir : pathlib.Path
        Directory where ``jointplot_<contexts[C]>_<contexts[C2]>.png`` is written.
    contexts : sequence of str
        Label for each column index; used in axis labels and the file name.
    y_true, y_pred : array-like, optional
        2-D arrays to plot from. When omitted, the notebook-level
        ``y_true_array`` / ``y_predict_array`` are used.

    Returns
    -------
    None
        The figure is saved, shown and then closed.
    """
    # Fall back to the notebook-level arrays so existing calls keep working.
    if y_true is None:
        y_true = y_true_array
    if y_pred is None:
        y_pred = y_predict_array

    sns.set_theme(style="white", font_scale = 1.5)
    x = y_true[:, C]
    y = y_pred[:, C2]

    # Density of the (x, y) cloud evaluated at every point; only used for colouring.
    xy = np.vstack([x, y])
    z = gaussian_kde(xy)(xy)
    r, _ = pearsonr(x, y)

    data = pd.DataFrame({'True Values': x, 'Predicted Values': y, 'Density': z})

    # `palette` is not passed: no `hue` variable is assigned, so it is not
    # expected to have any effect here.
    g = sns.jointplot(
        data=data, x='True Values', y='Predicted Values',
        kind="scatter", alpha=0.7, s=10,
        marginal_kws=dict(bins=50, fill=True, color='#3D348B'),
        height=10
    )

    # Joint scatter with density
    # NOTE(review): this is drawn on top of the scatter already produced by
    # jointplot above, which stays underneath the semi-transparent overlay.
    g.plot_joint(plt.scatter, c=data['Density'], cmap="viridis_r", s=10, alpha=0.5)

    # Diagonal reference line (y = x) across the observed range
    lo, hi = x.min(), x.max()
    g.ax_joint.plot([lo, hi], [lo, hi], 'k--', linewidth=1.5)

    # Axis labels and title
    g.ax_joint.set_xlabel(f'log2 (ATAC-seq) observado {contexts[C]}')
    g.ax_joint.set_ylabel(f'log2 (ATAC-seq) predicho {contexts[C2]}')
    # Pearson correlation annotation
    g.ax_joint.text(0.05, 0.95, f'Pearson r = {r:.2f}',
                    transform=g.ax_joint.transAxes,
                    fontsize=18, verticalalignment='top',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="gray"))

    # Save through the figure that owns the joint axes rather than relying on
    # pyplot's implicit "current figure".
    fig = g.ax_joint.figure
    fig.savefig(output_dir.joinpath(f'jointplot_{contexts[C]}_{contexts[C2]}.png'), dpi=100, bbox_inches='tight')
    plt.show()
    # This function is called many times in a loop; close the figure explicitly
    # so open figures do not accumulate on backends that keep them after show().
    plt.close(fig)

# One joint plot per (observed column, predicted column) pair. The bound is
# OUTPUT_SHAPE rather than a literal 9 so it follows the model's output size,
# matching the per-output Pearson loop above. Images are written into DATA_DIR.
for j in range(OUTPUT_SHAPE):
    for j2 in range(OUTPUT_SHAPE):
        make_joints_plot2(j, j2, DATA_DIR)