In [1]:
import numpy as np
import sys
import time
import h5py
from tqdm import tqdm

import numpy as np
import re
from math import ceil
from sklearn.metrics import average_precision_score
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import pickle
#import pickle5 as pickle

from sklearn.model_selection import train_test_split

from scipy.sparse import load_npz
from glob import glob

from transformers import get_constant_schedule_with_warmup
from sklearn.metrics import precision_score,recall_score,accuracy_score
import copy

from src.train import trainModel
from src.dataloader import getData,spliceDataset,h5pyDataset,collate_fn
from src.weight_init import keras_init
from src.losses import categorical_crossentropy_2d
from src.models import SpliceFormer
from src.evaluation_metrics import print_topl_statistics,cross_entropy_2d

In [2]:
!nvidia_smi

/bin/bash: nvidia_smi: command not found


In [3]:
# Fix the random sources used below. `rng` is NumPy's generator; torch.manual_seed
# seeds PyTorch's global RNG on all devices. That RNG drives PyTorch's default
# weight initialisation and DataLoader(shuffle=True) ordering, so training runs
# become repeatable (up to non-deterministic GPU kernels).
SEED = 23673
torch.manual_seed(SEED)
rng = np.random.default_rng(SEED)

In [4]:
#gtf = None

In [5]:
# ---- Batch sizing -------------------------------------------------------
# k scales both the batch sizes (k*100 per loader batch, k*6*N_GPUS per
# sub-batch) and the learning rate (k*1e-3) used further down.
k = 2
N_GPUS = 8
BATCH_SIZE = k*6*N_GPUS

# ---- Residual-unit settings -------------------------------------------
#   L  : number of convolution kernels
#   W  : convolution window size of each residual unit
#   AR : atrous (dilation) rate of each residual unit
L = 32
W = np.repeat([11, 21, 41], [8, 4, 4])
AR = np.repeat([1, 4, 10, 25], 4)

# Context length implied by W and AR: 2 * sum(AR * (W - 1)).
# NOTE: the models below are built with CL_max, not this value.
CL = 2 * np.sum(AR*(W-1))

In [6]:
# Load the 'train' split from data_dir with the project's getData helper.
data_dir, setType = "../Data", "train"
annotation, transcriptToLabel, seqData = getData(data_dir, setType)

In [7]:
# SL     : sequence length of each example; this is the model's output length.
# CL_max : maximum nucleotide context, CL_max/2 on either side of the
#          positions of interest. It must be even.
# The model input length is SL + CL_max.
SL, CL_max = 5000, 100000

In [8]:
# CL_max is split evenly between the two flanks, so it must be even.
assert not CL_max % 2

In [9]:
# Hold out 10% of the unique genes (rather than individual annotation rows), so
# that every gene's rows land entirely in either the training or validation set.
unique_genes = annotation['gene'].drop_duplicates()
train_gene, validation_gene = train_test_split(unique_genes, test_size=.1, random_state=435)
annotation_train = annotation.loc[annotation['gene'].isin(train_gene)]
annotation_validation = annotation.loc[annotation['gene'].isin(validation_gene)]

In [10]:
# Training and validation datasets share the sequence store loaded by getData.
train_dataset = spliceDataset(annotation_train, transcriptToLabel, SL, CL_max)
val_dataset = spliceDataset(annotation_validation, transcriptToLabel, SL, CL_max)
for _split_dataset in (train_dataset, val_dataset):
    _split_dataset.seqData = seqData

# Both loaders draw k*100 items per batch with 16 workers and the project's
# collate_fn. Only the training loader shuffles and pins host memory.
_loader_kwargs = dict(batch_size=k*100, num_workers=16, collate_fn=collate_fn)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, pin_memory=True, **_loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [11]:
# Train independently initialised SpliceFormer models. Each run writes to its own
# file, suffixed with the model number; the evaluation cells load these paths
# with torch.load.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 10
hs = []  # one trainModel history per trained model

# NOTE(review): this trains models 1-4 only, but the evaluation cells load
# checkpoints _0.._4. Checkpoint _0 must come from a separate run — confirm.
for model_nr in range(1, 5):
    model_m = SpliceFormer(CL_max, heads=16, n_transformer_blocks=4, depth=2)
    model_m.apply(keras_init)
    model_m = model_m.to(device)
    if torch.cuda.device_count() > 1:
        # Replicate the model on every GPU; each batch is split along dim 0.
        model_m = nn.DataParallel(model_m)

    modelFileName = '../Results/PyTorch_Models/transformer_encoder_100k_130722_{}'.format(model_nr)
    loss = categorical_crossentropy_2d().loss
    learning_rate = k*1e-3
    optimizer = torch.optim.Adam(model_m.parameters(), lr=learning_rate)
    # Halves the learning rate on every scheduler.step()
    # (presumably once per epoch inside trainModel — TODO confirm).
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    # Linear warm-up over the first 1000 warmup.step() calls, then constant.
    warmup = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=1000)
    h = trainModel(model_m, modelFileName, loss, train_loader, val_loader, optimizer, scheduler, warmup, BATCH_SIZE, epochs, device, skipValidation=True)
    # Keep every run's history; previously `h` was overwritten each iteration
    # and `hs` stayed empty.
    hs.append(h)

Epoch (train) 1/10: 100%|█████████████████████████████████████████████████████████████| 101/101 [28:59<00:00, 17.22s/it, accepor_recall=0.86, donor_recall=0.882, loss=0.000226, pred_l1_dist=0]


epoch: 1/10, train loss = 0.010921


Epoch (train) 2/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:37<00:00, 16.41s/it, accepor_recall=0.896, donor_recall=0.911, loss=0.000174, pred_l1_dist=0]


epoch: 2/10, train loss = 0.000202


Epoch (train) 3/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:23<00:00, 16.27s/it, accepor_recall=0.908, donor_recall=0.916, loss=0.000169, pred_l1_dist=0]


epoch: 3/10, train loss = 0.000178


Epoch (train) 4/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:25<00:00, 16.29s/it, accepor_recall=0.923, donor_recall=0.931, loss=0.000157, pred_l1_dist=0]


epoch: 4/10, train loss = 0.000161


Epoch (train) 5/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:41<00:00, 16.45s/it, accepor_recall=0.923, donor_recall=0.928, loss=0.000144, pred_l1_dist=0]


epoch: 5/10, train loss = 0.000152


Epoch (train) 6/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:40<00:00, 16.44s/it, accepor_recall=0.935, donor_recall=0.941, loss=0.000137, pred_l1_dist=0]


epoch: 6/10, train loss = 0.000140


Epoch (train) 7/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:25<00:00, 16.29s/it, accepor_recall=0.936, donor_recall=0.937, loss=0.000127, pred_l1_dist=0]


epoch: 7/10, train loss = 0.000132


Epoch (train) 8/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:28<00:00, 16.32s/it, accepor_recall=0.952, donor_recall=0.954, loss=0.000104, pred_l1_dist=0]


epoch: 8/10, train loss = 0.000107


Epoch (train) 9/10: 100%|█████████████████████████████████████████████████████████████| 101/101 [27:45<00:00, 16.49s/it, accepor_recall=0.963, donor_recall=0.965, loss=9.12e-5, pred_l1_dist=0]


epoch: 9/10, train loss = 0.000092


Epoch (train) 10/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:21<00:00, 16.26s/it, accepor_recall=0.962, donor_recall=0.968, loss=8.02e-5, pred_l1_dist=0]


epoch: 10/10, train loss = 0.000081


Epoch (train) 1/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:27<00:00, 16.31s/it, accepor_recall=0.866, donor_recall=0.879, loss=0.000228, pred_l1_dist=0]


epoch: 1/10, train loss = 0.011393


Epoch (train) 2/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:28<00:00, 16.32s/it, accepor_recall=0.891, donor_recall=0.902, loss=0.000212, pred_l1_dist=0]


epoch: 2/10, train loss = 0.000201


Epoch (train) 3/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:41<00:00, 16.45s/it, accepor_recall=0.911, donor_recall=0.914, loss=0.000163, pred_l1_dist=0]


epoch: 3/10, train loss = 0.000177


Epoch (train) 4/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:38<00:00, 16.42s/it, accepor_recall=0.914, donor_recall=0.926, loss=0.000155, pred_l1_dist=0]


epoch: 4/10, train loss = 0.000164


Epoch (train) 5/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:44<00:00, 16.48s/it, accepor_recall=0.927, donor_recall=0.932, loss=0.000149, pred_l1_dist=0]


epoch: 5/10, train loss = 0.000153


Epoch (train) 6/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:24<00:00, 16.28s/it, accepor_recall=0.923, donor_recall=0.935, loss=0.000146, pred_l1_dist=0]


epoch: 6/10, train loss = 0.000141


Epoch (train) 7/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:46<00:00, 16.50s/it, accepor_recall=0.936, donor_recall=0.938, loss=0.000141, pred_l1_dist=0]


epoch: 7/10, train loss = 0.000133


Epoch (train) 8/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:28<00:00, 16.32s/it, accepor_recall=0.954, donor_recall=0.958, loss=0.000102, pred_l1_dist=0]


epoch: 8/10, train loss = 0.000107


Epoch (train) 9/10: 100%|██████████████████████████████████████████████████████████████| 101/101 [27:37<00:00, 16.41s/it, accepor_recall=0.96, donor_recall=0.963, loss=9.23e-5, pred_l1_dist=0]


epoch: 9/10, train loss = 0.000091


Epoch (train) 10/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:51<00:00, 16.55s/it, accepor_recall=0.964, donor_recall=0.968, loss=9.44e-5, pred_l1_dist=0]


epoch: 10/10, train loss = 0.000080


Epoch (train) 1/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:33<00:00, 16.37s/it, accepor_recall=0.864, donor_recall=0.875, loss=0.000223, pred_l1_dist=0]


epoch: 1/10, train loss = 0.011015


Epoch (train) 2/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:32<00:00, 16.36s/it, accepor_recall=0.897, donor_recall=0.903, loss=0.000168, pred_l1_dist=0]


epoch: 2/10, train loss = 0.000201


Epoch (train) 3/10: 100%|██████████████████████████████████████████████████████████████| 101/101 [27:39<00:00, 16.43s/it, accepor_recall=0.9, donor_recall=0.911, loss=0.000168, pred_l1_dist=0]


epoch: 3/10, train loss = 0.000176


Epoch (train) 4/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:51<00:00, 16.55s/it, accepor_recall=0.912, donor_recall=0.924, loss=0.000151, pred_l1_dist=0]


epoch: 4/10, train loss = 0.000164


Epoch (train) 5/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:29<00:00, 16.33s/it, accepor_recall=0.925, donor_recall=0.934, loss=0.000154, pred_l1_dist=0]


epoch: 5/10, train loss = 0.000152


Epoch (train) 6/10: 100%|█████████████████████████████████████████████████████████████| 101/101 [27:39<00:00, 16.43s/it, accepor_recall=0.931, donor_recall=0.94, loss=0.000142, pred_l1_dist=0]


epoch: 6/10, train loss = 0.000140


Epoch (train) 7/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:31<00:00, 16.35s/it, accepor_recall=0.941, donor_recall=0.945, loss=0.000125, pred_l1_dist=0]


epoch: 7/10, train loss = 0.000131


Epoch (train) 8/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:36<00:00, 16.40s/it, accepor_recall=0.946, donor_recall=0.954, loss=0.000122, pred_l1_dist=0]


epoch: 8/10, train loss = 0.000106


Epoch (train) 9/10: 100%|██████████████████████████████████████████████████████████████| 101/101 [27:45<00:00, 16.50s/it, accepor_recall=0.96, donor_recall=0.965, loss=9.91e-5, pred_l1_dist=0]


epoch: 9/10, train loss = 0.000091


Epoch (train) 10/10: 100%|█████████████████████████████████████████████████████████████| 101/101 [27:44<00:00, 16.48s/it, accepor_recall=0.967, donor_recall=0.97, loss=8.14e-5, pred_l1_dist=0]


epoch: 10/10, train loss = 0.000080


Epoch (train) 1/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:41<00:00, 16.45s/it, accepor_recall=0.863, donor_recall=0.876, loss=0.000232, pred_l1_dist=0]


epoch: 1/10, train loss = 0.010805


Epoch (train) 2/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:45<00:00, 16.49s/it, accepor_recall=0.894, donor_recall=0.906, loss=0.000186, pred_l1_dist=0]


epoch: 2/10, train loss = 0.000200


Epoch (train) 3/10: 100%|█████████████████████████████████████████████████████████████| 101/101 [27:37<00:00, 16.41s/it, accepor_recall=0.913, donor_recall=0.92, loss=0.000163, pred_l1_dist=0]


epoch: 3/10, train loss = 0.000178


Epoch (train) 4/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:34<00:00, 16.38s/it, accepor_recall=0.921, donor_recall=0.928, loss=0.000156, pred_l1_dist=0]


epoch: 4/10, train loss = 0.000162


Epoch (train) 5/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:26<00:00, 16.31s/it, accepor_recall=0.923, donor_recall=0.927, loss=0.000153, pred_l1_dist=0]


epoch: 5/10, train loss = 0.000152


Epoch (train) 6/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:52<00:00, 16.56s/it, accepor_recall=0.928, donor_recall=0.932, loss=0.000129, pred_l1_dist=0]


epoch: 6/10, train loss = 0.000142


Epoch (train) 7/10: 100%|█████████████████████████████████████████████████████████████| 101/101 [27:51<00:00, 16.55s/it, accepor_recall=0.934, donor_recall=0.94, loss=0.000125, pred_l1_dist=0]


epoch: 7/10, train loss = 0.000132


Epoch (train) 8/10: 100%|████████████████████████████████████████████████████████████| 101/101 [27:37<00:00, 16.42s/it, accepor_recall=0.953, donor_recall=0.956, loss=0.000102, pred_l1_dist=0]


epoch: 8/10, train loss = 0.000106


Epoch (train) 9/10: 100%|█████████████████████████████████████████████████████████████| 101/101 [27:40<00:00, 16.44s/it, accepor_recall=0.962, donor_recall=0.967, loss=8.86e-5, pred_l1_dist=0]


epoch: 9/10, train loss = 0.000091


Epoch (train) 10/10: 100%|█████████████████████████████████████████████████████████████| 101/101 [27:38<00:00, 16.42s/it, accepor_recall=0.961, donor_recall=0.967, loss=9.1e-5, pred_l1_dist=0]

epoch: 10/10, train loss = 0.000080





In [11]:
# Evaluate the model ensemble on the HDF5 test file.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Open read-only: nothing here writes to the file, and h5py < 3.0 defaults to
# mode 'a', which silently creates an empty file if the path is wrong.
h5f = h5py.File('/odinn/tmp/benediktj/Data/SplicePrediction-050422/gencode_100k_dataset_test_.h5', 'r')

# NOTE(review): assumes the file holds two datasets per shard (inputs and
# targets), hence the // 2 — confirm against the file layout.
num_idx = len(h5f.keys())//2

test_dataset = h5pyDataset(h5f, list(range(num_idx)))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

temp = 1
n_models = 5
model_m = SpliceFormer(CL_max, heads=16, n_transformer_blocks=4, depth=2)
model_m.apply(keras_init)
model_m = model_m.to(device)

if torch.cuda.device_count() > 1:
    model_m = nn.DataParallel(model_m)

output_class_labels = ['Null', 'Acceptor', 'Donor']

# One copy of the architecture per ensemble member, each loaded with its own
# checkpoint. map_location lets GPU-saved checkpoints load on whatever device is
# available here.
# NOTE(review): checkpoints saved from an nn.DataParallel model carry a
# "module." key prefix and only load into a model wrapped the same way.
models = [copy.deepcopy(model_m) for _ in range(n_models)]
for i, model in enumerate(models):
    model.load_state_dict(torch.load('../Results/PyTorch_Models/transformer_encoder_100k_130722_{}'.format(i), map_location=device))
    model.eval()

Y_true_acceptor, Y_pred_acceptor = [], []
Y_true_donor, Y_pred_donor = [], []
ce_2d = []

# Pure inference: no_grad stops autograd from recording the forward passes, so
# activations are freed at once instead of being held until an explicit detach.
with torch.no_grad():
    for (batch_chunks, target_chunks) in tqdm(test_loader):
        # The loader's batch_size is 1, so [0] strips that outer dimension.
        batch_chunks = torch.transpose(batch_chunks[0].to(device), 1, 2)
        target_chunks = torch.transpose(torch.squeeze(target_chunks[0].to(device), 0), 1, 2)
        # Split into n_chunks pieces, each with at most BATCH_SIZE rows.
        n_chunks = int(np.ceil(batch_chunks.shape[0]/BATCH_SIZE))
        batch_chunks = torch.chunk(batch_chunks, n_chunks, dim=0)
        target_chunks = torch.chunk(target_chunks, n_chunks, dim=0)
        targets_list = []
        outputs_list = []
        for batch_features, targets in zip(batch_chunks, target_chunks):
            # Ensemble mean of element [1] of every member's output. Summing a
            # generator works for any n_models, not exactly five.
            outputs = sum(models[m](batch_features)[1] for m in range(n_models))/n_models
            targets_list.append(targets)
            outputs_list.append(outputs)

        targets = torch.transpose(torch.vstack(targets_list), 1, 2).cpu().numpy()
        outputs = torch.transpose(torch.vstack(outputs_list), 1, 2).cpu().numpy()
        ce_2d.append(cross_entropy_2d(targets, outputs))

        # Keep only sequences whose targets sum to at least 1
        # (presumably this drops all-zero, unlabelled targets — confirm).
        is_expr = (targets.sum(axis=(1, 2)) >= 1)
        Y_true_acceptor.extend(targets[is_expr, :, 1].flatten())
        Y_true_donor.extend(targets[is_expr, :, 2].flatten())
        Y_pred_acceptor.extend(outputs[is_expr, :, 1].flatten())
        Y_pred_donor.extend(outputs[is_expr, :, 2].flatten())

# The loader has been fully consumed; release the HDF5 file handle.
h5f.close()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [08:24<00:00, 29.66s/it]


In [12]:
# Summarise the evaluation above. First, the mean of the per-batch cross
# entropies (note: batches are weighted equally regardless of size).
mean_ce = np.mean(ce_2d)
print('Cross entropy = {}'.format(mean_ce))

# Flatten the accumulated per-position lists into arrays for scoring.
Y_true_acceptor = np.array(Y_true_acceptor)
Y_pred_acceptor = np.array(Y_pred_acceptor)
Y_true_donor = np.array(Y_true_donor)
Y_pred_donor = np.array(Y_pred_donor)

# Per-channel statistics from print_topl_statistics.
# NOTE: despite the *_val_results names, these come from the test file.
print("\n\033[1m{}:\033[0m".format('Acceptor'))
acceptor_val_results = print_topl_statistics(Y_true_acceptor, Y_pred_acceptor)

print("\n\033[1m{}:\033[0m".format('Donor'))
donor_val_results = print_topl_statistics(Y_true_donor, Y_pred_donor)

Cross entropy = 0.00022665928169708042

[1mAcceptor:[0m
0.9962	0.9484	0.9878	0.9912	0.9772	0.9788	0.3465	0.0029	0.0005	13552	14289.0	14289

[1mDonor:[0m
0.9962	0.9514	0.9901	0.9924	0.9798	0.9806	0.3656	0.0019	0.0002	13594	14289.0	14289


In [15]:
# Sum of the 10th column of the Acceptor and Donor rows printed above.
# NOTE(review): hand-copied values; they go stale if the evaluation is re-run.
sum((13552, 13594))

27146

In [14]:
# Mean of the 2nd column of the Acceptor and Donor rows above (hand-copied values).
(0.9514 + 0.9484) / 2

0.9499

In [13]:
# Mean of the 5th column of the Acceptor and Donor rows above (hand-copied values).
(0.9798 + 0.9772) / 2

0.9784999999999999

In [12]:
# Load the held-out 'test' split. This rebinds seqData, so the training
# sequences are no longer reachable under that name; train_dataset and
# val_dataset keep their own reference via their .seqData attribute.
setType = "test"
annotation_test, transcriptToLabel_test, seqData = getData(data_dir, setType)


In [13]:
# Evaluate the model ensemble on the 'test' split loaded through spliceDataset.
temp = 1
n_models = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_m = SpliceFormer(CL_max, heads=16, n_transformer_blocks=4, depth=2)
model_m.apply(keras_init)
model_m = model_m.to(device)

if torch.cuda.device_count() > 1:
    model_m = nn.DataParallel(model_m)

output_class_labels = ['Null', 'Acceptor', 'Donor']

# One copy of the architecture per ensemble member, each loaded with its own
# checkpoint. map_location lets GPU-saved checkpoints load on whatever device is
# available here.
# NOTE(review): checkpoints saved from an nn.DataParallel model carry a
# "module." key prefix and only load into a model wrapped the same way.
models = [copy.deepcopy(model_m) for _ in range(n_models)]
for i, model in enumerate(models):
    model.load_state_dict(torch.load('../Results/PyTorch_Models/transformer_encoder_100k_130722_{}'.format(i), map_location=device))
    model.eval()

Y_true_acceptor, Y_pred_acceptor = [], []
Y_true_donor, Y_pred_donor = [], []
test_dataset = spliceDataset(annotation_test, transcriptToLabel_test, SL, CL_max)
test_dataset.seqData = seqData
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=0, collate_fn=collate_fn, pin_memory=True)
ce_2d = []

# Pure inference: no_grad stops autograd from recording the forward passes, so
# activations are freed at once instead of being held until an explicit detach.
with torch.no_grad():
    for (batch_chunks, target_chunks) in tqdm(test_loader):
        batch_chunks = torch.transpose(batch_chunks.to(device), 1, 2)
        # NOTE(review): squeeze(..., 0) only changes anything when dim 0 has size 1.
        # If dim 0 is the batch dimension, a final batch holding a single item
        # would lose it here while batch_chunks keeps it — confirm the target
        # shape produced by collate_fn.
        target_chunks = torch.transpose(torch.squeeze(target_chunks.to(device), 0), 1, 2)
        # Sub-batches of BATCH_SIZE rows (the last one may be smaller).
        batch_chunks = torch.split(batch_chunks, BATCH_SIZE, dim=0)
        target_chunks = torch.split(target_chunks, BATCH_SIZE, dim=0)
        targets_list = []
        outputs_list = []
        for batch_features, targets in zip(batch_chunks, target_chunks):
            # Ensemble mean of element [1] of every member's output. Summing a
            # generator works for any n_models, not exactly five.
            outputs = sum(models[m](batch_features)[1] for m in range(n_models))/n_models
            targets_list.append(targets)
            outputs_list.append(outputs)

        targets = torch.transpose(torch.vstack(targets_list), 1, 2).cpu().numpy()
        outputs = torch.transpose(torch.vstack(outputs_list), 1, 2).cpu().numpy()
        ce_2d.append(cross_entropy_2d(targets, outputs))

        # Keep only sequences whose targets sum to at least 1
        # (presumably this drops all-zero, unlabelled targets — confirm).
        is_expr = (targets.sum(axis=(1, 2)) >= 1)
        Y_true_acceptor.extend(targets[is_expr, :, 1].flatten())
        Y_true_donor.extend(targets[is_expr, :, 2].flatten())
        Y_pred_acceptor.extend(outputs[is_expr, :, 1].flatten())
        Y_pred_donor.extend(outputs[is_expr, :, 2].flatten())



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [3:11:08<00:00, 127.43s/it]


In [14]:
# Summarise the test-split evaluation. First, the mean of the per-batch cross
# entropies (note: batches are weighted equally regardless of size).
mean_ce = np.mean(ce_2d)
print('Cross entropy = {}'.format(mean_ce))

# Flatten the accumulated per-position lists into arrays for scoring.
Y_true_acceptor = np.array(Y_true_acceptor)
Y_pred_acceptor = np.array(Y_pred_acceptor)
Y_true_donor = np.array(Y_true_donor)
Y_pred_donor = np.array(Y_pred_donor)

# Per-channel statistics from print_topl_statistics.
# NOTE: despite the *_val_results names, these are test-split results.
print("\n\033[1m{}:\033[0m".format('Acceptor'))
acceptor_val_results = print_topl_statistics(Y_true_acceptor, Y_pred_acceptor)

print("\n\033[1m{}:\033[0m".format('Donor'))
donor_val_results = print_topl_statistics(Y_true_donor, Y_pred_donor)

Cross entropy = 0.00013630991255302998

[1mAcceptor:[0m
0.9829	0.9393	0.9917	0.996	0.9685	0.9813	0.4760	0.0027	0.0004	84265	89712.0	89712

[1mDonor:[0m
0.9845	0.9438	0.994	0.9971	0.9722	0.9831	0.5154	0.0018	0.0002	84674	89712.0	89712


In [15]:
# Mean of the 5th column of the Acceptor and Donor rows above (hand-copied values).
(0.9722 + 0.9685) / 2

0.97035

In [16]:
# Sum of the 10th column of the Acceptor and Donor rows above.
# NOTE(review): hand-copied values; they go stale if the evaluation is re-run.
sum((84265, 84674))

168939

In [17]:
# Mean of the 2nd column of the Acceptor and Donor rows above (hand-copied values).
(0.9438 + 0.9393) / 2

0.94155