In [1]:
import os
import math
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torchinfo import summary
from optimization import (SequenceSignal, 
                          transformer_model, 
                          train_val_loops)
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from sklearn.metrics import (precision_score, recall_score, auc, 
                             precision_recall_curve, PrecisionRecallDisplay)

In [2]:
# Global variables: dataset locations, dataset sizes and loader settings.
DATA_DIR = Path('../ATACNet/peaks/')
X_TRAIN_PATH = DATA_DIR.joinpath('subset_X_train.npy')
Y_TRAIN_PATH = DATA_DIR.joinpath('subset_y_train.npy')
X_VAL_PATH = DATA_DIR.joinpath('dataset_1kb_300bp_S3_val_encoding.npy')
Y_VAL_PATH = DATA_DIR.joinpath('val_target.npy')

# Only the number of rows is needed here, so memory-map the target files
# instead of reading the full arrays into RAM just to look at their shape.
N_VAL_EXAMPLES = np.load(Y_VAL_PATH, mmap_mode='r').shape[0]
N_TRAIN_EXAMPLES = np.load(Y_TRAIN_PATH, mmap_mode='r').shape[0]

# These two paths are passed to the training loop as `means_path` / `stds_path`.
# NOTE(review): presumably used there to invert a z-score normalisation of the
# targets -- confirm in `train_val_loops`.
Z_SCORE_INVERSE_MEANS = DATA_DIR.joinpath('zscore_means.npy')
Z_SCORE_INVERSE_STD = DATA_DIR.joinpath('zscore_stds.npy')

BATCH_SIZE = 256
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Indexable collection of loaders: element 0 is used as the training loader
# and element 1 as the validation loader further down in this notebook.
dataloaders = SequenceSignal.load_dataset(X_TRAIN_PATH,
                                          Y_TRAIN_PATH,
                                          X_VAL_PATH,
                                          Y_VAL_PATH,
                                          BATCH_SIZE,
                                          device = DEVICE)

In [3]:
# Training hyper-parameters and checkpoint locations.
N_EPOCHS = 30
PATIENCE = 20
OUTPUT_SHAPE = 9
CRITERION = nn.MSELoss()
CHECKPOINT_NAME = DATA_DIR / 'best_model_dELSs.pth'
PRETRAINED_WEIGHTS = DATA_DIR / 'best_model_ATAC.pth'
USE_PRETRAIN = True

# Architecture settings, gathered in one mapping so they can be read at a glance.
model_kwargs = dict(
    # convolutional stack
    n_conv_layers=4,
    n_filters=[256, 60, 60, 120],
    kernel_sizes=[7, 3, 5, 3],
    dilation=[1, 1, 1, 1],
    drop_conv=0.1,
    # fully connected stack
    n_fc_layers=2,
    drop_fc=0.4,
    n_neurons=[256, 256],
    output_size=OUTPUT_SHAPE,
    # transformer encoder
    drop_transformer=0.2,
    input_size=4,
    n_encoder_layers=2,
    n_heads=8,
    n_transformer_FC_layers=256,
)

model = transformer_model.TransformerCNNMixtureModel(**model_kwargs)
model.to(device=DEVICE)

# Optionally start from previously saved weights; the checkpoint is a dict
# whose 'network' entry holds the model state dict.
if USE_PRETRAIN:
    training_state = torch.load(PRETRAINED_WEIGHTS,
                                map_location=DEVICE,
                                weights_only=True)
    model.load_state_dict(training_state['network'])

# Last expression of the cell: display the layer / parameter-count table.
summary(model)

Layer (type:depth-idx)                                            Param #
TransformerCNNMixtureModel                                        --
├─ModuleList: 1-1                                                 --
│    └─Sequential: 2-1                                            --
│    │    └─ConvPoolingBlock: 3-1                                 7,936
│    └─Sequential: 2-2                                            --
│    │    └─ConvPoolingBlock: 3-2                                 46,260
│    └─Sequential: 2-3                                            --
│    │    └─ConvPoolingBlock: 3-3                                 18,180
│    └─Sequential: 2-4                                            --
│    │    └─ConvPoolingBlock: 3-4                                 21,960
├─PositionalEncoding: 1-2                                         --
├─TransformerEncoder: 1-3                                         --
│    └─ModuleList: 2-5                                            --
│    │    └─Tr

In [4]:
# Number of optimizer steps in one pass over the training set (last partial
# batch counted as a full step).
steps_per_epoch = math.ceil(N_TRAIN_EXAMPLES / dataloaders[0].batch_size)

# Adam with weight decay. The `lr` given here is only a placeholder: the
# one-cycle scheduler overwrites it on construction (the printed optimizer
# state reports lr = 4e-05, not 1e-3).
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# One-cycle schedule: the first 10% of all steps ramp up to `max_lr`,
# the remainder anneals linearly.
lr_scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=N_EPOCHS,
    steps_per_epoch=steps_per_epoch,
    pct_start=0.1,
    anneal_strategy='linear',
)
print(optimizer)

Adam (
Parameter Group 0
    amsgrad: False
    base_momentum: 0.85
    betas: (0.95, 0.999)
    capturable: False
    decoupled_weight_decay: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 4e-05
    lr: 4e-05
    max_lr: 0.001
    max_momentum: 0.95
    maximize: False
    min_lr: 4e-09
    weight_decay: 0.0001
)


In [6]:
# Reuse a finished run when its checkpoint is already on disk; train otherwise.
checkpoint_found = os.path.exists(CHECKPOINT_NAME)

if checkpoint_found:
    # Restore the saved weights instead of retraining.
    training_state = torch.load(CHECKPOINT_NAME,
                                map_location=DEVICE,
                                weights_only=True)
    model.load_state_dict(training_state['network'])
else:
    # No checkpoint yet: run the full train / validation loop. The checkpoint
    # path is handed over as `model_path`.
    output = train_val_loops.train_N_epochs(
        model,
        optimizer,
        criterion=CRITERION,
        train_loader=dataloaders[0],
        valid_loader=dataloaders[1],
        num_epochs=N_EPOCHS,
        patience=PATIENCE,
        model_path=CHECKPOINT_NAME,
        lr_scheduler=lr_scheduler,
        means_path=Z_SCORE_INVERSE_MEANS,
        stds_path=Z_SCORE_INVERSE_STD,
        DEVICE=DEVICE,
    )

In [8]:
# Run the model over the validation loader, print the criterion value and
# export NumPy copies (`y_true_array`, `y_predict`) for the metric cells.
model.eval()

# Pre-allocated buffers on DEVICE; rows are filled batch by batch.
y_logits = torch.zeros((N_VAL_EXAMPLES, OUTPUT_SHAPE), device = DEVICE)
y_true = torch.zeros((N_VAL_EXAMPLES, OUTPUT_SHAPE), device = DEVICE)

n_seen = 0
with torch.inference_mode():

    for inputs, labels in dataloaders[1]:
        batch_logits = model(inputs)
        # Advance by the real batch length rather than the global BATCH_SIZE,
        # so a shorter final batch (or a loader built with another batch
        # size) still lands in the right rows.
        n_batch = batch_logits.shape[0]
        y_logits[n_seen:(n_seen + n_batch)] = batch_logits
        y_true[n_seen:(n_seen + n_batch)] = labels
        n_seen += n_batch

# Drop rows that were never written (e.g. if the loader skips a partial batch).
y_logits = y_logits[:n_seen]
y_true = y_true[:n_seen]

# NOTE(review): the loss is evaluated on softmax-normalised outputs, keeping the
# order of operations this cell had before; confirm this matches how the model
# was trained.
y_prob = nn.functional.softmax(y_logits, dim=1)

# The criterion needs torch tensors, so evaluate it before converting to NumPy.
print(CRITERION(y_prob, y_true))

y_true_array = y_true.cpu().numpy()
y_predict = y_prob.cpu().numpy()

KeyboardInterrupt: 

In [9]:
# One precision-recall curve per output column, all drawn on the same axes,
# with the area under each curve printed in column order.
for output_idx in range(OUTPUT_SHAPE):
    precision, recall, _ = precision_recall_curve(
        y_true_array[:, output_idx],
        y_predict[:, output_idx],
    )
    # Recall on the x-axis, precision on the y-axis.
    plt.plot(recall, precision)
    pr_auc = auc(recall, precision)
    print(pr_auc)

PearsonRResult(statistic=0.009558266, pvalue=0.0029595103343786843)
PearsonRResult(statistic=0.027728753, pvalue=6.519479564496305e-18)
PearsonRResult(statistic=-0.005506557, pvalue=0.08687730995222236)
PearsonRResult(statistic=0.013493934, pvalue=2.7201970128727547e-05)
PearsonRResult(statistic=0.031183572, pvalue=3.077158800661475e-22)
PearsonRResult(statistic=0.05363699, pvalue=1.5821796523795144e-62)
PearsonRResult(statistic=0.013789144, pvalue=1.806887300075937e-05)
PearsonRResult(statistic=0.07370263, pvalue=1.5879400530097262e-116)
PearsonRResult(statistic=-0.09653193, pvalue=7.8778400926948835e-199)
