In [1]:
import os
import math
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torchinfo import summary
from optimization import (SequenceSignal, 
                          CNN_model,
                          train_val_loops)
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from scipy.stats import pearsonr
from sklearn.metrics import (precision_score, recall_score, auc, 
                             precision_recall_curve, PrecisionRecallDisplay)

## Train the CNN 

In [2]:
# Global variables
DATA_DIR=Path('../net_output/')
X_TRAIN_PATH = DATA_DIR.joinpath('dataset_1kb_300bp_train_augmented_encoding.npy')
Y_TRAIN_PATH = DATA_DIR.joinpath('train_target_Z_norm.npy')
X_VAL_PATH = DATA_DIR.joinpath('dataset_1kb_300bp_val_encoding.npy')
Y_VAL_PATH = DATA_DIR.joinpath('val_target.npy')
N_VAL_EXAMPLES = np.load(Y_VAL_PATH).shape[0]
N_TRAIN_EXAMPLES = np.load(Y_TRAIN_PATH).shape[0]
Z_SCORE_INVERSE_MEANS = DATA_DIR.joinpath('means.npy')
Z_SCORE_INVERSE_STD = DATA_DIR.joinpath('stds.npy')

BATCH_SIZE = 256
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataloaders = SequenceSignal.load_dataset(X_TRAIN_PATH, 
                                          Y_TRAIN_PATH, 
                                          X_VAL_PATH, 
                                          Y_VAL_PATH, 
                                          BATCH_SIZE, 
                                          device = DEVICE)

In [3]:
N_EPOCHS = 30
PATIENCE = 20
OUTPUT_SHAPE = 9
CRITERION = torch.nn.MSELoss()
CHECKPOINT_NAME = DATA_DIR.joinpath('best_model_ATAC.pth')

model = CNN_model.ConvNet(n_conv_layers = 4, 
                           n_filters = [256, 60, 60, 120], 
                           kernel_sizes = [7, 3, 5, 3], 
                           dilation = [1, 1, 1, 1], 
                           drop_conv = 0.2, 
                           n_fc_layers = 2, 
                           n_neurons = [256, 256], 
                           drop_fc = 0.4, 
                           output_size = 9, 
                           sequence_length = 1000)

model.to(device=DEVICE)
summary(model)

Layer (type:depth-idx)                        Param #
ConvNet                                       --
├─ModuleList: 1-1                             --
│    └─Sequential: 2-1                        --
│    │    └─ConvPoolingBlock: 3-1             7,936
│    └─Sequential: 2-2                        --
│    │    └─ConvPoolingBlock: 3-2             46,260
│    └─Sequential: 2-3                        --
│    │    └─ConvPoolingBlock: 3-3             18,180
│    └─Sequential: 2-4                        --
│    │    └─ConvPoolingBlock: 3-4             21,960
├─ModuleList: 1-2                             --
│    └─Linear: 2-5                            1,812,736
│    └─Linear: 2-6                            65,792
├─Dropout: 1-3                                --
├─Linear: 1-4                                 2,313
Total params: 1,975,177
Trainable params: 1,975,177
Non-trainable params: 0

In [4]:
optimizer = torch.optim.Adam(params = model.parameters(), lr = 1e-3, weight_decay = 1e-4)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 
                                             epochs = N_EPOCHS, 
                                             max_lr = 1e-3, 
                                             steps_per_epoch = math.ceil(N_TRAIN_EXAMPLES / dataloaders[0].batch_size), 
                                             pct_start = 0.1, 
                                             anneal_strategy = 'linear')
print(optimizer)

Adam (
Parameter Group 0
    amsgrad: False
    base_momentum: 0.85
    betas: (0.95, 0.999)
    capturable: False
    decoupled_weight_decay: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 4e-05
    lr: 4e-05
    max_lr: 0.001
    max_momentum: 0.95
    maximize: False
    min_lr: 4e-09
    weight_decay: 0.0001
)


In [5]:
if not os.path.exists(CHECKPOINT_NAME):
    
    output = train_val_loops.train_N_epochs(model, optimizer, 
                                            criterion = CRITERION, 
                                            train_loader = dataloaders[0], 
                                            valid_loader = dataloaders[1], 
                                            num_epochs = N_EPOCHS, 
                                            patience = PATIENCE, 
                                            model_path = CHECKPOINT_NAME, 
                                            lr_scheduler = lr_scheduler, 
                                            means_path = Z_SCORE_INVERSE_MEANS, 
                                            stds_path = Z_SCORE_INVERSE_STD, 
                                            DEVICE = DEVICE)
else:
    
    training_state = torch.load(CHECKPOINT_NAME, 
                                weights_only = True, 
                                map_location = DEVICE)
    
    optimizer.load_state_dict(training_state['optimizer'])
    lr_scheduler.load_state_dict(training_state['lr_sched'])
    model.load_state_dict(training_state['network'])
    best_valid_loss = training_state['best_valid_loss']

Training model:
SEConvNet(
  (convs): ModuleList(
    (0): Sequential(
      (0): ConvPoolingBlock(
        (block): Sequential(
          (0): Conv1d(4, 256, kernel_size=(7,), stride=(1,))
          (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Dropout1d(p=0.2, inplace=False)
          (4): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
      )
    )
    (1): Sequential(
      (0): ConvPoolingBlock(
        (block): Sequential(
          (0): Conv1d(256, 120, kernel_size=(3,), stride=(1,))
          (1): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Dropout1d(p=0.2, inplace=False)
          (4): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
      )
    )
    (2): Sequential(
      (0): ConvPoolingBlock(
        (block): Sequential(
          (0): Conv1d(120, 60, 

KeyboardInterrupt: 

In [6]:
# Load mean and std used for z-score normalization
Y_PREDICT_PATH = DATA_DIR.joinpath('y_predict_ATAC.npy')
means = torch.from_numpy(np.load(Z_SCORE_INVERSE_MEANS)).float().to(DEVICE)
stds = torch.from_numpy(np.load(Z_SCORE_INVERSE_STD)).float().to(DEVICE)

model.eval()
y_predict = torch.zeros((N_VAL_EXAMPLES, OUTPUT_SHAPE), device = DEVICE)
y_true = torch.zeros((N_VAL_EXAMPLES, OUTPUT_SHAPE), device = DEVICE)

with torch.inference_mode():
    
    for i, data in enumerate(dataloaders[1], 0):
        inputs, labels = data
        outputs = model(inputs)

        outputs_denorm = outputs * stds + means

        start_idx = i * BATCH_SIZE
        end_idx = start_idx + BATCH_SIZE

        y_predict[start_idx:end_idx] = outputs_denorm
        y_true[start_idx:end_idx] = labels

y_true_array = y_true.cpu().numpy()
y_predict_array = y_predict.cpu().numpy()

print(CRITERION(y_predict, y_true))
np.save(Y_PREDICT_PATH, y_predict_array)

tensor(1.3911, device='cuda:0')


In [7]:
for j in range(OUTPUT_SHAPE):
    print(pearsonr(y_true_array[:, j], y_predict_array[:, j]))

PearsonRResult(statistic=np.float32(0.56627053), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.5609125), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.58280313), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.5828313), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.5425546), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.60089225), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.6085993), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.60682625), pvalue=np.float64(0.0))
PearsonRResult(statistic=np.float32(0.63440585), pvalue=np.float64(0.0))
