In [1]:
import numpy as np
import sys
import time
import h5py
from tqdm import tqdm

import numpy as np
import re
from math import ceil
from sklearn.metrics import average_precision_score
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import pickle
#import pickle5 as pickle

from sklearn.model_selection import train_test_split

from scipy.sparse import load_npz
from glob import glob

from transformers import get_constant_schedule_with_warmup
from sklearn.metrics import precision_score,recall_score,accuracy_score
import copy

from src.train import trainModel

#from src.dataloader import getData,spliceDataset,h5pyDataset,collate_fn
from src.dataloader import get_GTEX_v8_Data,spliceDataset,h5pyDataset,getDataPointList,getDataPointListGTEX,DataPointGTEX
from src.weight_init import keras_init
from src.losses import categorical_crossentropy_2d,kl_div_2d
from src.model import SpliceFormer
from src.evaluation_metrics import print_topl_statistics,cross_entropy_2d

In [2]:
# Show the GPUs visible to this machine.
# NOTE(review): N_GPUS in the hyper-parameter cell is hard-coded; compare it with this listing.
!nvidia-smi

Wed Jun  7 18:14:06 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.43.04    Driver Version: 515.43.04    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-PCI...  Off  | 00000000:31:00.0 Off |                    0 |
| N/A   37C    P0    38W / 250W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-PCI...  Off  | 00000000:98:00.0 Off |                    0 |
| N/A   38C    P0    40W / 250W |      0MiB / 40960MiB |      0%      Default |
|       

In [3]:
#!pip install pickle5

In [4]:
# Fix randomness up front so runs are reproducible: numpy via `rng`, and
# torch's global generator (which `torch.manual_seed` sets on all devices).
# Torch-based randomness, e.g. the shuffled DataLoader and presumably
# keras_init, then draws from a known seed.
SEED = 23673
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED);  # trailing ';' suppresses the Generator repr

In [5]:
#gtf = None

In [6]:
# SpliceAI-style residual CNN hyper-parameters:
#   L  -- number of convolution kernels
#   W  -- convolution window size of each residual unit
#   AR -- atrous (dilation) rate of each residual unit
# NOTE(review): in this notebook these only feed CL, and CL is not used
# afterwards (the data pipeline and SpliceFormer receive CL_max instead).
L = 32
W = np.asarray([11] * 8 + [21] * 4 + [41] * 4)
AR = np.asarray([1] * 4 + [4] * 4 + [10] * 4 + [25] * 4)

# DataLoader batch size: 16 samples per unit of k, per GPU.
N_GPUS = 3
NUM_ACCUMULATION_STEPS = 1
k = 2
BATCH_SIZE = 16 * k * N_GPUS

# From here on k also folds in gradient accumulation; it scales the learning
# rate in the training-settings cell.
k *= NUM_ACCUMULATION_STEPS

# Context length implied by W and AR.
CL = 2 * (AR * (W - 1)).sum()

In [7]:
# Load the rnasplice-blood GTEx v8 export. get_GTEX_v8_Data returns
# (annotation, gene_to_label, seqData); their structure is defined in src.dataloader.
# NOTE(review): absolute cluster path -- consider making it configurable.
data_dir = '/odinn/tmp/benediktj/Data/SplicePrediction-rnasplice-blood-070623/'
setType = 'all'
annotation_file = 'annotation_GTEX_v8.txt'
annotation, gene_to_label, seqData = get_GTEX_v8_Data(data_dir, setType, annotation_file)

In [8]:
from collections import defaultdict

In [9]:
# CL_max: maximum nucleotide context length, i.e. CL_max/2 nucleotides on
#         either side of the positions of interest; must be an even number.
# SL:     SpliceAI's sequence length -- the output window length. The model
#         input spans SL plus the flanking context.
CL_max = 40000
SL = 5000

In [10]:
# The flanking context is split evenly between both sides, so it must be even.
assert not CL_max % 2

In [11]:
#train_gene, validation_gene = train_test_split(annotation['gene'].drop_duplicates(),test_size=.1,random_state=435)
#annotation_train = annotation[annotation['gene'].isin(train_gene)]
#annotation_validation = annotation[annotation['gene'].isin(validation_gene)]

In [12]:
#with open('{}/sparse_discrete_gene_label_data_{}.pickle'.format(data_dir,setType), 'rb') as handle:
#    gene_to_label_old = pickle.load(handle)

In [13]:
#for gene in gene_to_label_old.keys():
#    if len(gene_to_label[gene])==0:
#        gene_to_label[gene] = gene_to_label_old[gene]

In [14]:
# Build training windows over every annotated gene (window length SL, shift SL,
# CL_max of context) and attach the shared sequence data loaded earlier.
train_dataset = spliceDataset(
    getDataPointListGTEX(annotation, gene_to_label, SL, CL_max, shift=SL)
)
train_dataset.seqData = seqData

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

# Validation is currently disabled; to re-enable it:
#val_dataset = spliceDataset(getDataPointListGTEX(annotation_validation,gene_to_label,SL,CL_max,shift=SL))
#val_dataset.seqData = seqData
#val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE//4, shuffle=False, num_workers=16)

In [15]:
# Training settings shared by every model fine-tuned below.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
epochs = 4
learning_rate = 1e-4 * k   # base LR of 1e-4 scaled by k
gamma = 0.5                # StepLR decay factor (step_size=1 in the training cell)
temp = 1                   # only used by the commented-out kl_div_2d loss
hs = []                    # one trainModel history per model
# Alternative: choose gamma so the LR decays to final_lr over 5 steps:
#final_lr = 1e-5
#gamma = 1/(learning_rate/final_lr)**(1/5)

In [16]:
# Fine-tune each of the pretrained 40k-context SpliceFormer models on the
# rnasplice-blood data. modelFileName is handed to trainModel (presumably
# where the fine-tuned weights are written -- confirm in src.train).
N_MODELS = 10
PRETRAINED_TEMPLATE = '../Results/PyTorch_Models/transformer_encoder_40k_031122_{}'
FINETUNED_TEMPLATE = '../Results/PyTorch_Models/transformer_encoder_40k_finetune_rnasplice-blood_all_050623_{}'


def _load_pretrained_weights(model, checkpoint_path):
    """Load a saved ``state_dict`` into ``model``, with or without DataParallel.

    The checkpoint is materialised on the CPU (``map_location='cpu'``), so
    loading no longer depends on the GPU the file was saved from.
    ``load_state_dict`` then copies the tensors onto whatever device ``model``
    already lives on. A leading ``'module.'`` key prefix, as written when an
    ``nn.DataParallel`` wrapper is saved, is stripped. The weights are then
    loaded (strictly) into the underlying module. The same checkpoint therefore
    works with several GPUs, one GPU or CPU only.

    Args:
        model: an ``nn.Module``, optionally wrapped in ``nn.DataParallel``.
        checkpoint_path: path to a file written by ``torch.save(state_dict)``.

    Raises:
        RuntimeError: if the checkpoint keys/shapes do not match the model.
    """
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    state_dict = torch.load(checkpoint_path, map_location='cpu')
    consume_prefix_in_state_dict_if_present(state_dict, 'module.')  # in place
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(state_dict)


for model_nr in range(N_MODELS):
    model_m = SpliceFormer(CL_max,bn_momentum=0.01/NUM_ACCUMULATION_STEPS,depth=4,heads=4,n_transformer_blocks=2)
    # NOTE(review): the strict load below overwrites every parameter, so this
    # init looks redundant (it still advances the torch RNG).
    model_m.apply(keras_init)
    model_m = model_m.to(device)
    if torch.cuda.device_count() > 1:
        # DataParallel splits each batch along dim 0 across the visible GPUs,
        # e.g. [30, ...] -> 3 x [10, ...] on 3 GPUs.
        model_m = nn.DataParallel(model_m)

    _load_pretrained_weights(model_m, PRETRAINED_TEMPLATE.format(model_nr))
    modelFileName = FINETUNED_TEMPLATE.format(model_nr)
    loss = categorical_crossentropy_2d().loss
    #loss = kl_div_2d(temp=temp).loss
    optimizer = torch.optim.AdamW(model_m.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
    warmup = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=100)
    #if model_nr>0:
    h = trainModel(model_m,modelFileName,loss,train_loader,None,optimizer,scheduler,warmup,BATCH_SIZE,epochs,device,skipValidation=True,lowValidationGPUMem=True,NUM_ACCUMULATION_STEPS=NUM_ACCUMULATION_STEPS,CL_max=CL_max,reinforce=True,continous_labels=False)
    #else:
    #    h = trainModel(model_m,modelFileName,loss,train_loader,val_loader,optimizer,scheduler,warmup,BATCH_SIZE,epochs,device,skipValidation=False,lowValidationGPUMem=True,NUM_ACCUMULATION_STEPS=NUM_ACCUMULATION_STEPS,CL_max=CL_max,reinforce=True,continous_labels=False)
    #    plt.plot(range(epochs),h['loss'],label='Train')
    #    plt.plot(range(epochs),h['val_loss'],label='Validation')
    #    plt.xlabel('Epoch')
    #    plt.ylabel('Loss')
    #    plt.legend()
    #    plt.show()
    hs.append(h)

Epoch (train) 1/4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [51:13<00:00,  1.27s/it, a_r=0.62, d_r=0.619, loss=0.000879, r_a=0.99, r_d=0.99, r_loss=5.12]


epoch: 1/4, train loss = 0.000931


Epoch (train) 2/4: 100%|█| 2422/2422 [50:15<00:00,  1.24s/it, a_r=0.625, d_r=0.631, loss=0.000844, r_a=0.991, r_


epoch: 2/4, train loss = 0.000850


Epoch (train) 3/4: 100%|█| 2422/2422 [50:43<00:00,  1.26s/it, a_r=0.628, d_r=0.636, loss=0.00082, r_a=0.991, r_d


epoch: 3/4, train loss = 0.000828


Epoch (train) 4/4: 100%|█| 2422/2422 [50:30<00:00,  1.25s/it, a_r=0.638, d_r=0.635, loss=0.000795, r_a=0.991, r_


epoch: 4/4, train loss = 0.000813


Epoch (train) 1/4: 100%|█| 2422/2422 [50:33<00:00,  1.25s/it, a_r=0.623, d_r=0.627, loss=0.000859, r_a=0.991, r_


epoch: 1/4, train loss = 0.000924


Epoch (train) 2/4: 100%|█| 2422/2422 [50:34<00:00,  1.25s/it, a_r=0.626, d_r=0.627, loss=0.000839, r_a=0.992, r_


epoch: 2/4, train loss = 0.000845


Epoch (train) 3/4: 100%|█| 2422/2422 [50:20<00:00,  1.25s/it, a_r=0.629, d_r=0.634, loss=0.000816, r_a=0.991, r_


epoch: 3/4, train loss = 0.000824


Epoch (train) 4/4: 100%|█| 2422/2422 [50:43<00:00,  1.26s/it, a_r=0.633, d_r=0.632, loss=0.000801, r_a=0.992, r_


epoch: 4/4, train loss = 0.000808


Epoch (train) 1/4: 100%|█| 2422/2422 [50:30<00:00,  1.25s/it, a_r=0.617, d_r=0.625, loss=0.00088, r_a=0.991, r_d


epoch: 1/4, train loss = 0.000931


Epoch (train) 2/4: 100%|█| 2422/2422 [50:32<00:00,  1.25s/it, a_r=0.625, d_r=0.626, loss=0.000824, r_a=0.991, r_


epoch: 2/4, train loss = 0.000850


Epoch (train) 3/4: 100%|█| 2422/2422 [50:14<00:00,  1.24s/it, a_r=0.626, d_r=0.637, loss=0.000821, r_a=0.992, r_


epoch: 3/4, train loss = 0.000828


Epoch (train) 4/4: 100%|█| 2422/2422 [50:17<00:00,  1.25s/it, a_r=0.631, d_r=0.636, loss=0.000809, r_a=0.992, r_


epoch: 4/4, train loss = 0.000813


Epoch (train) 1/4: 100%|█| 2422/2422 [50:33<00:00,  1.25s/it, a_r=0.617, d_r=0.616, loss=0.000877, r_a=0.991, r_


epoch: 1/4, train loss = 0.000929


Epoch (train) 2/4: 100%|█| 2422/2422 [50:42<00:00,  1.26s/it, a_r=0.628, d_r=0.63, loss=0.000841, r_a=0.991, r_d


epoch: 2/4, train loss = 0.000849


Epoch (train) 3/4: 100%|█| 2422/2422 [50:25<00:00,  1.25s/it, a_r=0.629, d_r=0.639, loss=0.000817, r_a=0.992, r_


epoch: 3/4, train loss = 0.000827


Epoch (train) 4/4: 100%|█| 2422/2422 [50:41<00:00,  1.26s/it, a_r=0.629, d_r=0.633, loss=0.000806, r_a=0.991, r_


epoch: 4/4, train loss = 0.000812


Epoch (train) 1/4: 100%|█| 2422/2422 [50:05<00:00,  1.24s/it, a_r=0.617, d_r=0.621, loss=0.000848, r_a=0.991, r_


epoch: 1/4, train loss = 0.000928


Epoch (train) 2/4: 100%|█| 2422/2422 [49:52<00:00,  1.24s/it, a_r=0.622, d_r=0.629, loss=0.000843, r_a=0.991, r_


epoch: 2/4, train loss = 0.000849


Epoch (train) 3/4: 100%|█| 2422/2422 [49:25<00:00,  1.22s/it, a_r=0.625, d_r=0.633, loss=0.000817, r_a=0.992, r_


epoch: 3/4, train loss = 0.000826


Epoch (train) 4/4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:48<00:00,  1.23s/it, a_r=0.63, d_r=0.641, loss=0.000811, r_a=0.992, r_d=0.991, r_loss=4.57]


epoch: 4/4, train loss = 0.000811


Epoch (train) 1/4: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [50:02<00:00,  1.24s/it, a_r=0.616, d_r=0.621, loss=0.000877, r_a=0.99, r_d=0.99, r_loss=7.32]


epoch: 1/4, train loss = 0.000935


Epoch (train) 2/4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [50:06<00:00,  1.24s/it, a_r=0.616, d_r=0.625, loss=0.000848, r_a=0.991, r_d=0.991, r_loss=7.27]


epoch: 2/4, train loss = 0.000854


Epoch (train) 3/4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [50:07<00:00,  1.24s/it, a_r=0.632, d_r=0.63, loss=0.000824, r_a=0.991, r_d=0.991, r_loss=7.37]


epoch: 3/4, train loss = 0.000832


Epoch (train) 4/4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:41<00:00,  1.23s/it, a_r=0.633, d_r=0.631, loss=0.000797, r_a=0.991, r_d=0.991, r_loss=6.99]


epoch: 4/4, train loss = 0.000816


Epoch (train) 1/4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:41<00:00,  1.23s/it, a_r=0.614, d_r=0.621, loss=0.000875, r_a=0.991, r_d=0.99, r_loss=5.27]


epoch: 1/4, train loss = 0.000929


Epoch (train) 2/4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:46<00:00,  1.23s/it, a_r=0.626, d_r=0.632, loss=0.000818, r_a=0.992, r_d=0.992, r_loss=4.67]


epoch: 2/4, train loss = 0.000849


Epoch (train) 3/4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [50:26<00:00,  1.25s/it, a_r=0.628, d_r=0.631, loss=0.000822, r_a=0.992, r_d=0.991, r_loss=4.72]


epoch: 3/4, train loss = 0.000827


Epoch (train) 4/4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:55<00:00,  1.24s/it, a_r=0.627, d_r=0.626, loss=0.000802, r_a=0.992, r_d=0.992, r_loss=4.7]


epoch: 4/4, train loss = 0.000811


Epoch (train) 1/4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [50:05<00:00,  1.24s/it, a_r=0.618, d_r=0.623, loss=0.000859, r_a=0.991, r_d=0.991, r_loss=6.72]


epoch: 1/4, train loss = 0.000927


Epoch (train) 2/4: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:45<00:00,  1.23s/it, a_r=0.623, d_r=0.624, loss=0.000858, r_a=0.99, r_d=0.99, r_loss=6.84]


epoch: 2/4, train loss = 0.000849


Epoch (train) 3/4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:56<00:00,  1.24s/it, a_r=0.632, d_r=0.625, loss=0.000822, r_a=0.992, r_d=0.991, r_loss=6.81]


epoch: 3/4, train loss = 0.000828


Epoch (train) 4/4: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:44<00:00,  1.23s/it, a_r=0.633, d_r=0.63, loss=0.000814, r_a=0.991, r_d=0.99, r_loss=6.72]


epoch: 4/4, train loss = 0.000813


Epoch (train) 1/4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:42<00:00,  1.23s/it, a_r=0.62, d_r=0.622, loss=0.000866, r_a=0.991, r_d=0.99, r_loss=7.1]


epoch: 1/4, train loss = 0.000930


Epoch (train) 2/4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:58<00:00,  1.24s/it, a_r=0.632, d_r=0.637, loss=0.000815, r_a=0.991, r_d=0.991, r_loss=7.31]


epoch: 2/4, train loss = 0.000850


Epoch (train) 3/4: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [50:01<00:00,  1.24s/it, a_r=0.63, d_r=0.631, loss=0.000814, r_a=0.991, r_d=0.99, r_loss=6.78]


epoch: 3/4, train loss = 0.000827


Epoch (train) 4/4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:48<00:00,  1.23s/it, a_r=0.634, d_r=0.636, loss=0.000819, r_a=0.991, r_d=0.991, r_loss=6.73]


epoch: 4/4, train loss = 0.000812


Epoch (train) 1/4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [50:03<00:00,  1.24s/it, a_r=0.613, d_r=0.616, loss=0.000877, r_a=0.99, r_d=0.99, r_loss=8.1]


epoch: 1/4, train loss = 0.000930


Epoch (train) 2/4: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:30<00:00,  1.23s/it, a_r=0.622, d_r=0.633, loss=0.000843, r_a=0.99, r_d=0.99, r_loss=8.11]


epoch: 2/4, train loss = 0.000851


Epoch (train) 3/4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [49:57<00:00,  1.24s/it, a_r=0.632, d_r=0.629, loss=0.000812, r_a=0.991, r_d=0.991, r_loss=7.9]


epoch: 3/4, train loss = 0.000828


Epoch (train) 4/4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2422/2422 [50:41<00:00,  1.26s/it, a_r=0.63, d_r=0.636, loss=0.000805, r_a=0.991, r_d=0.991, r_loss=7.92]

epoch: 4/4, train loss = 0.000813



