In [1]:
import numpy as np
import sys
import time
import h5py
from tqdm import tqdm

import numpy as np
import re
from math import ceil
from sklearn.metrics import average_precision_score
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import pickle
#import pickle5 as pickle

from sklearn.model_selection import train_test_split

from scipy.sparse import load_npz
from glob import glob

from transformers import get_constant_schedule_with_warmup
from sklearn.metrics import precision_score,recall_score,accuracy_score
import copy

from src.train import trainModel

#from src.dataloader import getData,spliceDataset,h5pyDataset,collate_fn
from src.dataloader import get_GTEX_v8_Data,spliceDataset,h5pyDataset,getDataPointList,getDataPointListGTEX,DataPointGTEX
from src.weight_init import keras_init
from src.losses import categorical_crossentropy_2d,kl_div_2d
from src.model import SpliceFormer
from src.evaluation_metrics import print_topl_statistics,cross_entropy_2d

In [2]:
# Shell escape: print the GPU status table for this machine before training starts.
!nvidia-smi

Tue May 23 17:16:53 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.43.04    Driver Version: 515.43.04    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-PCI...  Off  | 00000000:31:00.0 Off |                    0 |
| N/A   32C    P0    36W / 250W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-PCI...  Off  | 00000000:98:00.0 Off |                    0 |
| N/A   30C    P0    35W / 250W |      0MiB / 40960MiB |      0%      Default |
|       

In [3]:
#!pip install pickle5

In [4]:
# One seed for every random number generator this notebook relies on.
SEED = 23673

# NumPy generator for any NumPy-side sampling.
rng = np.random.default_rng(SEED)

# Seed PyTorch too: the training cells draw their randomness from torch (the
# DataLoader is created with shuffle=True), so seeding NumPy alone leaves runs
# irreproducible. The trailing semicolon hides the returned Generator repr.
torch.manual_seed(SEED);

In [5]:
#gtf = None

In [6]:
# Hyper-parameters:
#   L:  Number of convolution kernels
#   W:  Convolution window size in each residual unit
#   AR: Atrous rate in each residual unit
L = 32
N_GPUS = 3
NUM_ACCUMULATION_STEPS = 1
k = 2

# 16 residual units: eight with window 11, then four with 21 and four with 41.
W = np.repeat([11, 21, 41], [8, 4, 4])
# Atrous rates 1, 4, 10 and 25, each used for four consecutive units.
AR = np.repeat([1, 4, 10, 25], 4)

# The batch size uses k *before* it is scaled by the accumulation steps.
BATCH_SIZE = 16 * k * N_GPUS
k *= NUM_ACCUMULATION_STEPS

# Context length derived from the window sizes and atrous rates.
CL = 2 * (AR * (W - 1)).sum()

In [7]:
# Directory holding the rnasplice-blood data set and the subset to load from it.
data_dir = '/odinn/tmp/benediktj/Data/SplicePrediction-rnasplice-blood/'
setType = 'all'

# get_GTEX_v8_Data returns three objects, unpacked here as the annotation
# table, the per-gene labels and the sequence data.
annotation, gene_to_label, seqData = get_GTEX_v8_Data(
    data_dir,
    setType,
    'annotation_rnasplice-blood.txt',
)

In [8]:
# SL: sequence length of the model output. The model input is longer than
# this by the context length (input = SL + CL, output = SL).
SL = 5000

# CL_max: maximum nucleotide context length, i.e. CL_max/2 on either side of
# the position of interest. CL_max should be an even number.
CL_max = 40000

In [9]:
# The context is split as CL_max/2 on either side of a position, so it must be even.
assert not CL_max % 2

In [10]:
#train_gene, validation_gene = train_test_split(annotation['gene'].drop_duplicates(),test_size=.1,random_state=435)
#annotation_train = annotation[annotation['gene'].isin(train_gene)]
#annotation_validation = annotation[annotation['gene'].isin(validation_gene)]

In [11]:
#with open('{}/sparse_discrete_gene_label_data_{}.pickle'.format(data_dir,setType), 'rb') as handle:
#    gene_to_label_old = pickle.load(handle)

In [12]:
#for gene in gene_to_label_old.keys():
#    if len(gene_to_label[gene])==0:
#        gene_to_label[gene] = gene_to_label_old[gene]

In [13]:
# Build the training data points from the full annotation (no train/validation
# split is made here) and wrap them in a dataset.
train_points = getDataPointListGTEX(annotation, gene_to_label, SL, CL_max, shift=SL)
train_dataset = spliceDataset(train_points)
# The dataset is handed the sequence data after construction.
train_dataset.seqData = seqData

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

# Validation counterpart, currently disabled:
#val_dataset = spliceDataset(getDataPointListGTEX(annotation_validation,gene_to_label,SL,CL_max,shift=SL))
#val_dataset.seqData = seqData
#val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE//4, shuffle=False, num_workers=16)

In [14]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

epochs = 4
# Base rate of 1e-4 multiplied by k.
learning_rate = 1e-4 * k
# Decay factor handed to the StepLR scheduler in the training cell.
gamma = 0.5
# Only referenced by the commented-out kl_div_2d loss in the training cell.
temp = 1
# Collects the value returned by trainModel for every model.
hs = []

# Alternative decay, derived from a target final learning rate (disabled):
#final_lr = 1e-5
#gamma = 1/(learning_rate/final_lr)**(1/5)

In [15]:
# Fine-tune each of the 10 pre-trained SpliceFormer checkpoints on the data in
# train_loader and save the result under a new file name.
#
# NOTE(review): in the saved output of this cell the run stopped during epoch
# 4/4 with a ValueError raised from select_action in src/model.py: the float16
# policy logits handed to torch.distributions.Categorical contained NaN rows.
# The training loss had jumped just before the failure. The cause lies in code
# outside this notebook (model / trainModel) -- TODO investigate there.

pretrained_template = '../Results/PyTorch_Models/transformer_encoder_40k_031122_{}'
finetuned_template = '../Results/PyTorch_Models/transformer_encoder_40k_finetune_rnasplice-blood_all_120523_{}'

# Start from an empty history so that re-running this cell (e.g. after a
# crash) does not append duplicate entries to the list.
hs = []

for model_nr in range(10):
    model_m = SpliceFormer(CL_max,bn_momentum=0.01/NUM_ACCUMULATION_STEPS,depth=4,heads=4,n_transformer_blocks=2)
    model_m.apply(keras_init)
    model_m = model_m.to(device)
    if torch.cuda.device_count() > 1:
        #print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model_m = nn.DataParallel(model_m)

    # map_location puts the checkpoint tensors on the device used here instead
    # of whatever device they were saved from, which may not exist on this
    # machine. The state dict is loaded after the optional DataParallel wrap,
    # exactly as before.
    pretrained_state = torch.load(pretrained_template.format(model_nr), map_location=device)
    model_m.load_state_dict(pretrained_state)

    modelFileName = finetuned_template.format(model_nr)
    loss = categorical_crossentropy_2d().loss
    #loss = kl_div_2d(temp=temp).loss
    optimizer = torch.optim.AdamW(model_m.parameters(), lr=learning_rate, weight_decay=1e-5)
    # Multiply the learning rate by gamma after every scheduler step.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
    warmup = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=100)

    # Validation is skipped: no validation loader is passed.
    #if model_nr>0:
    h = trainModel(model_m,modelFileName,loss,train_loader,None,optimizer,scheduler,warmup,BATCH_SIZE,epochs,device,skipValidation=True,lowValidationGPUMem=True,NUM_ACCUMULATION_STEPS=NUM_ACCUMULATION_STEPS,CL_max=CL_max,reinforce=True,continous_labels=False)
    #else:
    #    h = trainModel(model_m,modelFileName,loss,train_loader,val_loader,optimizer,scheduler,warmup,BATCH_SIZE,epochs,device,skipValidation=False,lowValidationGPUMem=True,NUM_ACCUMULATION_STEPS=NUM_ACCUMULATION_STEPS,CL_max=CL_max,reinforce=True,continous_labels=False)
    #    plt.plot(range(epochs),h['loss'],label='Train')
    #    plt.plot(range(epochs),h['val_loss'],label='Validation')
    #    plt.xlabel('Epoch')
    #    plt.ylabel('Loss')
    #    plt.legend()
    #    plt.show()
    hs.append(h)

Epoch (train) 1/4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1618/1618 [30:20<00:00,  1.13s/it, a_r=0.561, d_r=0.591, loss=0.000981, r_a=0.99, r_d=0.99, r_loss=7.64]


epoch: 1/4, train loss = 0.001065


Epoch (train) 2/4: 100%|█| 1618/1618 [28:28<00:00,  1.06s/it, a_r=0.564, d_r=0.595, loss=0.000929, r_a=0.992, r_


epoch: 2/4, train loss = 0.000937


Epoch (train) 3/4: 100%|█| 1618/1618 [28:37<00:00,  1.06s/it, a_r=0.559, d_r=0.587, loss=0.00092, r_a=0.992, r_d


epoch: 3/4, train loss = 0.000917


Epoch (train) 4/4:  32%|▎| 519/1618 [09:12<19:29,  1.06s/it, a_r=0.525, d_r=0.561, loss=0.00546, r_a=0.529, r_d=


ValueError: Caught ValueError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/splice-site-prediction/Code/src/model.py", line 251, in forward
    actions,acceptor_actions,donor_actions,acceptor_log_probs,donor_log_probs = self.select_action(state)
  File "/splice-site-prediction/Code/src/model.py", line 305, in select_action
    m = torch.distributions.Categorical(logits=policy_logits[:,:,0]-torch.nan_to_num(float('inf')*c, nan=0))
  File "/opt/conda/lib/python3.8/site-packages/torch/distributions/categorical.py", line 64, in __init__
    super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributions/distribution.py", line 55, in __init__
    raise ValueError(
ValueError: Expected parameter logits (Tensor of shape (32, 45000)) of distribution Categorical(logits: torch.Size([32, 45000])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([[-28768., -29952., -30624.,  ...,    -inf,    -inf,    -inf],
        [   -inf,    -inf,    -inf,  ...,    -inf,    -inf,    -inf],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        ...,
        [ -2832.,  -3424.,  -4096.,  ...,    -inf,    -inf,    -inf],
        [-27696.,    -inf, -30416.,  ...,    -inf,    -inf,    -inf],
        [ -2832.,  -3424.,  -4096.,  ...,    -inf,    -inf,    -inf]],
       device='cuda:1', dtype=torch.float16, grad_fn=<SubBackward0>)
