In [3]:
import numpy as np
import sys
import time
import h5py
from tqdm import tqdm

import numpy as np
import re
from math import ceil
from sklearn.metrics import average_precision_score
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import pickle
#import pickle5 as pickle

from sklearn.model_selection import train_test_split

from scipy.sparse import load_npz
from glob import glob

from transformers import get_constant_schedule_with_warmup
from sklearn.metrics import precision_score,recall_score,accuracy_score
import copy

from src.train import trainModel

#from src.dataloader import getData,spliceDataset,h5pyDataset,collate_fn
from src.dataloader import get_GTEX_v8_Data,spliceDataset,h5pyDataset,getDataPointList,getDataPointListGTEX,DataPointGTEX
from src.weight_init import keras_init
from src.losses import categorical_crossentropy_2d,kl_div_2d
from src.model import SpliceAI_10K
from src.evaluation_metrics import print_topl_statistics,cross_entropy_2d,kullback_leibler_divergence_2d

In [4]:
# Show the GPUs visible on this machine and their current memory use.
!nvidia-smi

Wed Nov  8 17:23:16 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-PCIE...  Off  | 00000000:5E:00.0 Off |                    0 |
| N/A   34C    P0    35W / 250W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-PCIE...  Off  | 00000000:86:00.0 Off |                    0 |
| N/A   32C    P0    34W / 250W |      0MiB / 32510MiB |      2%      Defaul

In [5]:
#!pip install pickle5

In [6]:
# Seeded NumPy random generator.
# NOTE(review): none of the cells visible in this notebook draw from `rng`,
# and this does not seed torch's own random number generator.
rng = np.random.default_rng(seed=23673)

In [7]:
#gtf = None

In [8]:
# Hyper-parameters:
# L:  Number of convolution kernels
# W:  Convolution window size in each residual unit
# AR: Atrous rate in each residual unit
L = 32
W = np.asarray([11] * 8 + [21] * 4 + [41] * 4)
AR = np.asarray([1] * 4 + [4] * 4 + [10] * 4 + [25] * 4)

# Batch sizing: 16 samples times the multiplier k, per GPU.
N_GPUS = 3
NUM_ACCUMULATION_STEPS = 1
k = 2
BATCH_SIZE = 16 * k * N_GPUS

# From here on k also accounts for gradient accumulation; BATCH_SIZE above
# was deliberately computed from the un-scaled value.
k = NUM_ACCUMULATION_STEPS * k

# Context length implied by the window sizes and atrous rates above.
CL = 2 * np.sum(AR * (W - 1))

In [9]:
# Root directory of the rnasplice-blood data set and the split to read from it.
data_dir = '/odinn/tmp/benediktj/Data/SplicePrediction-rnasplice-blood-070623/'
setType = 'train'

# The loader returns three objects, bound here as the annotation, the
# gene -> label mapping and the sequence data used by the datasets below.
annotation, gene_to_label, seqData = get_GTEX_v8_Data(
    data_dir,
    setType,
    'annotation_GTEX_v8.txt',
)

In [8]:
#for key in tqdm(gene_to_label.keys()):
#    for d_key in gene_to_label[key][0]:
#        gene_to_label[key][0][d_key] = 1
#    for d_key in gene_to_label[key][1]:
#        gene_to_label[key][1][d_key] = 1

In [10]:
# CL_max: maximum nucleotide context length (CL_max/2 on either side of the
#         position of interest). CL_max should be an even number.
# SL:     sequence length handled per window by the SpliceAI models
#         (SL+CL will be the input length and SL will be the output length).

SL=5000
CL_max=10000

In [11]:
# The evaluation cells crop CL_max//2 positions from each side of the targets,
# so the context length has to split evenly.
assert CL_max % 2 == 0

In [11]:
#train_gene, validation_gene = train_test_split(annotation['gene'].drop_duplicates(),test_size=.1,random_state=435)
#annotation_train = annotation[annotation['gene'].isin(train_gene)]
#annotation_validation = annotation[annotation['gene'].isin(validation_gene)]

In [12]:
#with open('{}/sparse_discrete_gene_label_data_{}.pickle'.format(data_dir,setType), 'rb') as handle:
#    gene_to_label_old = pickle.load(handle)

In [13]:
#for gene in gene_to_label_old.keys():
#    if len(gene_to_label[gene])==0:
#        gene_to_label[gene] = gene_to_label_old[gene]

In [12]:
# Build the training data points (window length SL, context CL_max, windows
# advanced by shift=SL) and wrap them in the project dataset class.
train_points = getDataPointListGTEX(annotation, gene_to_label, SL, CL_max, shift=SL)
train_dataset = spliceDataset(train_points)
# The dataset reads its sequences from this attribute, which is attached after construction.
train_dataset.seqData = seqData

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

# Validation split is currently disabled:
#val_dataset = spliceDataset(getDataPointListGTEX(annotation_validation,gene_to_label,SL,CL_max,shift=SL))
#val_dataset.seqData = seqData
#val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE//2, shuffle=False, num_workers=16)

In [13]:
# Settings consumed by the fine-tuning loop.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

epochs = 4                  # epochs per model, passed to trainModel
learning_rate = k * 1e-4    # base rate of 1e-4 scaled by the multiplier k
gamma = 0.5                 # decay factor handed to the StepLR scheduler
temp = 1                    # only referenced by the commented-out KL-divergence loss
hs = []                     # collects the value returned by trainModel for each model

# Disabled alternative: derive gamma from a target final learning rate.
#final_lr = 1e-5
#gamma = 1/(learning_rate/final_lr)**(1/5)

In [14]:
# Fine-tune each of the 10 pre-trained SpliceAI-10k models on the training
# loader and save it under a model-specific file name.
for model_nr in range(10):
    model_m = SpliceAI_10K(CL_max)
    model_m.apply(keras_init)
    model_m = model_m.to(device)
    if torch.cuda.device_count() > 1:
        # DataParallel splits each batch along dim 0 across the visible GPUs,
        # e.g. [30, ...] -> [10, ...], [10, ...], [10, ...] on 3 GPUs.
        model_m = nn.DataParallel(model_m)

    # Start from the pre-trained weights. map_location places the checkpoint
    # tensors on the current device instead of the device index recorded at
    # save time, which may not exist on this machine.
    pretrained_path = '../Results/PyTorch_Models/spliceai_encoder_10k_071122_{}'.format(model_nr)
    model_m.load_state_dict(torch.load(pretrained_path, map_location=device))

    modelFileName = '../Results/PyTorch_Models/spliceai_encoder_10k_finetune_rnasplice-blood_050623_{}'.format(model_nr)
    loss = categorical_crossentropy_2d().loss
    #loss = kl_div_2d(temp=temp).loss
    optimizer = torch.optim.AdamW(model_m.parameters(), lr=learning_rate, weight_decay=1e-5)
    # Two schedulers share the optimizer: a per-step decay by `gamma` and a
    # 100-step warm-up. How they are stepped is decided inside trainModel.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
    warmup = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=100)

    # No validation loader is passed (None) and validation is skipped.
    h = trainModel(
        model_m,
        modelFileName,
        loss,
        train_loader,
        None,
        optimizer,
        scheduler,
        warmup,
        BATCH_SIZE,
        epochs,
        device,
        skipValidation=True,
        lowValidationGPUMem=True,
        NUM_ACCUMULATION_STEPS=NUM_ACCUMULATION_STEPS,
        CL_max=CL_max,
        reinforce=False,
        continous_labels=False,
    )
    hs.append(h)

    

    There is an imbalance between your GPUs. You may want to exclude GPU 0 which
    has less than 75% of the memory or cores of GPU 1. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.
Epoch (train) 1/4: 100%|██████████████| 1619/1619 [20:13<00:00,  1.33it/s, a_r=0.621, d_r=0.616, loss=0.000935]


epoch: 1/4, train loss = 0.001040


Epoch (train) 2/4: 100%|██████████████| 1619/1619 [19:45<00:00,  1.37it/s, a_r=0.633, d_r=0.629, loss=0.000916]


epoch: 2/4, train loss = 0.000936


Epoch (train) 3/4: 100%|██████████████| 1619/1619 [19:46<00:00,  1.36it/s, a_r=0.631, d_r=0.624, loss=0.000906]


epoch: 3/4, train loss = 0.000912


Epoch (train) 4/4: 100%|███████████████| 1619/1619 [19:45<00:00,  1.37it/s, a_r=0.63, d_r=0.635, loss=0.000893]


epoch: 4/4, train loss = 0.000895


Epoch (train) 1/4: 100%|██████████████| 1619/1619 [19:45<00:00,  1.37it/s, a_r=0.619, d_r=0.624, loss=0.000984]


epoch: 1/4, train loss = 0.001047


Epoch (train) 2/4: 100%|███████████████| 1619/1619 [19:46<00:00,  1.36it/s, a_r=0.63, d_r=0.628, loss=0.000922]


epoch: 2/4, train loss = 0.000937


Epoch (train) 3/4: 100%|███████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.632, d_r=0.63, loss=0.000899]


epoch: 3/4, train loss = 0.000912


Epoch (train) 4/4: 100%|██████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.634, d_r=0.629, loss=0.000901]


epoch: 4/4, train loss = 0.000895


Epoch (train) 1/4: 100%|███████████████| 1619/1619 [19:47<00:00,  1.36it/s, a_r=0.624, d_r=0.622, loss=0.00098]


epoch: 1/4, train loss = 0.001044


Epoch (train) 2/4: 100%|██████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.627, d_r=0.627, loss=0.000956]


epoch: 2/4, train loss = 0.000937


Epoch (train) 3/4: 100%|██████████████| 1619/1619 [19:49<00:00,  1.36it/s, a_r=0.633, d_r=0.629, loss=0.000913]


epoch: 3/4, train loss = 0.000912


Epoch (train) 4/4: 100%|███████████████| 1619/1619 [19:46<00:00,  1.36it/s, a_r=0.631, d_r=0.63, loss=0.000906]


epoch: 4/4, train loss = 0.000895


Epoch (train) 1/4: 100%|██████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.618, d_r=0.616, loss=0.000957]


epoch: 1/4, train loss = 0.001041


Epoch (train) 2/4: 100%|██████████████| 1619/1619 [19:49<00:00,  1.36it/s, a_r=0.628, d_r=0.627, loss=0.000921]


epoch: 2/4, train loss = 0.000936


Epoch (train) 3/4: 100%|██████████████| 1619/1619 [19:47<00:00,  1.36it/s, a_r=0.636, d_r=0.633, loss=0.000898]


epoch: 3/4, train loss = 0.000911


Epoch (train) 4/4: 100%|██████████████| 1619/1619 [19:47<00:00,  1.36it/s, a_r=0.634, d_r=0.631, loss=0.000891]


epoch: 4/4, train loss = 0.000895


Epoch (train) 1/4: 100%|██████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.623, d_r=0.623, loss=0.000972]


epoch: 1/4, train loss = 0.001037


Epoch (train) 2/4: 100%|██████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.627, d_r=0.631, loss=0.000929]


epoch: 2/4, train loss = 0.000935


Epoch (train) 3/4: 100%|███████████████| 1619/1619 [19:49<00:00,  1.36it/s, a_r=0.633, d_r=0.63, loss=0.000905]


epoch: 3/4, train loss = 0.000911


Epoch (train) 4/4: 100%|████████████████| 1619/1619 [19:49<00:00,  1.36it/s, a_r=0.637, d_r=0.633, loss=0.0009]


epoch: 4/4, train loss = 0.000895


Epoch (train) 1/4: 100%|██████████████| 1619/1619 [19:46<00:00,  1.36it/s, a_r=0.621, d_r=0.621, loss=0.000988]


epoch: 1/4, train loss = 0.001038


Epoch (train) 2/4: 100%|██████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.626, d_r=0.627, loss=0.000901]


epoch: 2/4, train loss = 0.000935


Epoch (train) 3/4: 100%|███████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.635, d_r=0.63, loss=0.000903]


epoch: 3/4, train loss = 0.000913


Epoch (train) 4/4: 100%|███████████████| 1619/1619 [19:50<00:00,  1.36it/s, a_r=0.63, d_r=0.629, loss=0.000916]


epoch: 4/4, train loss = 0.000896


Epoch (train) 1/4: 100%|██████████████| 1619/1619 [19:49<00:00,  1.36it/s, a_r=0.626, d_r=0.624, loss=0.000957]


epoch: 1/4, train loss = 0.001042


Epoch (train) 2/4: 100%|██████████████| 1619/1619 [19:49<00:00,  1.36it/s, a_r=0.626, d_r=0.633, loss=0.000935]


epoch: 2/4, train loss = 0.000937


Epoch (train) 3/4: 100%|██████████████| 1619/1619 [19:49<00:00,  1.36it/s, a_r=0.636, d_r=0.633, loss=0.000913]


epoch: 3/4, train loss = 0.000913


Epoch (train) 4/4: 100%|██████████████| 1619/1619 [19:47<00:00,  1.36it/s, a_r=0.638, d_r=0.633, loss=0.000905]


epoch: 4/4, train loss = 0.000896


Epoch (train) 1/4: 100%|███████████████| 1619/1619 [19:47<00:00,  1.36it/s, a_r=0.623, d_r=0.623, loss=0.00096]


epoch: 1/4, train loss = 0.001041


Epoch (train) 2/4: 100%|██████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.635, d_r=0.629, loss=0.000914]


epoch: 2/4, train loss = 0.000937


Epoch (train) 3/4: 100%|██████████████| 1619/1619 [19:50<00:00,  1.36it/s, a_r=0.632, d_r=0.636, loss=0.000898]


epoch: 3/4, train loss = 0.000913


Epoch (train) 4/4: 100%|██████████████| 1619/1619 [19:49<00:00,  1.36it/s, a_r=0.636, d_r=0.638, loss=0.000882]


epoch: 4/4, train loss = 0.000896


Epoch (train) 1/4: 100%|██████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.624, d_r=0.627, loss=0.000959]


epoch: 1/4, train loss = 0.001044


Epoch (train) 2/4: 100%|██████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.627, d_r=0.624, loss=0.000937]


epoch: 2/4, train loss = 0.000937


Epoch (train) 3/4: 100%|██████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.631, d_r=0.632, loss=0.000913]


epoch: 3/4, train loss = 0.000913


Epoch (train) 4/4: 100%|██████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.636, d_r=0.635, loss=0.000892]


epoch: 4/4, train loss = 0.000897


Epoch (train) 1/4: 100%|██████████████| 1619/1619 [19:49<00:00,  1.36it/s, a_r=0.622, d_r=0.622, loss=0.000961]


epoch: 1/4, train loss = 0.001043


Epoch (train) 2/4: 100%|███████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.63, d_r=0.629, loss=0.000937]


epoch: 2/4, train loss = 0.000937


Epoch (train) 3/4: 100%|██████████████| 1619/1619 [19:48<00:00,  1.36it/s, a_r=0.633, d_r=0.633, loss=0.000914]


epoch: 3/4, train loss = 0.000913


Epoch (train) 4/4: 100%|███████████████| 1619/1619 [19:49<00:00,  1.36it/s, a_r=0.64, d_r=0.639, loss=0.000916]

epoch: 4/4, train loss = 0.000896





In [15]:
# Evaluate the fine-tuned 10-model ensemble on the HDF5 test set.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

temp = 1
n_models = 10

# Template network; every ensemble member is a deep copy of it whose weights
# are then replaced by a fine-tuned checkpoint.
model_m = SpliceAI_10K(CL_max)
model_m.apply(keras_init)
model_m = model_m.to(device)

if torch.cuda.device_count() > 1:
    model_m = nn.DataParallel(model_m)

output_class_labels = ['Null', 'Acceptor', 'Donor']

models = [copy.deepcopy(model_m) for _ in range(n_models)]
for i, model in enumerate(models):
    # map_location keeps loading independent of the device the checkpoint was saved from.
    checkpoint = torch.load('../Results/PyTorch_Models/spliceai_encoder_10k_finetune_rnasplice-blood_050623_{}'.format(i), map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()

# Per-batch arrays are collected and concatenated once at the end; pushing
# every position into a Python list as an individual scalar is far more
# expensive in time and memory.
acceptor_true_parts, acceptor_pred_parts = [], []
donor_true_parts, donor_pred_parts = [], []
ce_2d = []

# Open the file read-only and close it as soon as the evaluation is done.
with h5py.File('/odinn/tmp/benediktj/SpliceAITrainingCode/dataset_test_0_10k.h5', 'r') as h5f:
    # Half the number of keys is used as the index count
    # (presumably one input and one target entry per index - verify against h5pyDataset).
    num_idx = len(h5f.keys())//2

    test_dataset = h5pyDataset(h5f,list(range(num_idx)))
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

    # Inference only: do not record autograd graphs for the forward passes.
    with torch.no_grad():
        for (batch_chunks,target_chunks) in tqdm(test_loader):
            batch_chunks = torch.transpose(batch_chunks[0].to(device),1,2)
            target_chunks = torch.transpose(torch.squeeze(target_chunks[0].to(device),0),1,2)
            # Split the entry into pieces of at most BATCH_SIZE windows.
            n_chunks = int(np.ceil(batch_chunks.shape[0]/BATCH_SIZE))
            batch_chunks = torch.chunk(batch_chunks, n_chunks, dim=0)
            target_chunks = torch.chunk(target_chunks, n_chunks, dim=0)
            targets_list = []
            outputs_list = []
            for batch_features, targets in zip(batch_chunks, target_chunks):
                # Ensemble prediction = arithmetic mean over the member outputs.
                outputs = torch.mean(torch.stack([member(batch_features) for member in models]), dim=0)
                targets_list.append(targets)
                outputs_list.append(outputs)

            targets = torch.transpose(torch.vstack(targets_list),1,2).cpu().numpy()
            outputs = torch.transpose(torch.vstack(outputs_list),1,2).cpu().numpy()
            ce_2d.append(cross_entropy_2d(targets,outputs))

            # Keep only windows whose target values sum to at least 1.
            is_expr = (targets.sum(axis=(1,2)) >= 1)
            acceptor_true_parts.append(targets[is_expr, :, 1].ravel())
            donor_true_parts.append(targets[is_expr, :, 2].ravel())
            acceptor_pred_parts.append(outputs[is_expr, :, 1].ravel())
            donor_pred_parts.append(outputs[is_expr, :, 2].ravel())

# Flat per-position arrays (empty when nothing was collected).
Y_true_acceptor = np.concatenate(acceptor_true_parts) if acceptor_true_parts else np.array([])
Y_pred_acceptor = np.concatenate(acceptor_pred_parts) if acceptor_pred_parts else np.array([])
Y_true_donor = np.concatenate(donor_true_parts) if donor_true_parts else np.array([])
Y_pred_donor = np.concatenate(donor_pred_parts) if donor_pred_parts else np.array([])
del acceptor_true_parts, acceptor_pred_parts, donor_true_parts, donor_pred_parts

100%|██████████████████████████████████████████████████████████████████████████| 16/16 [06:18<00:00, 23.67s/it]


In [16]:
# Report the HDF5 test-set results: the mean of the per-batch cross-entropy
# values, followed by the top-L statistics of each splice-site class.
mean_ce = np.mean(ce_2d)
print('Cross entropy = {}'.format(mean_ce))

# Make sure all four label/score collections are NumPy arrays.
Y_true_acceptor = np.array(Y_true_acceptor)
Y_pred_acceptor = np.array(Y_pred_acceptor)
Y_true_donor = np.array(Y_true_donor)
Y_pred_donor = np.array(Y_pred_donor)

print("\n\033[1m{}:\033[0m".format('Acceptor'))
acceptor_val_results = print_topl_statistics(Y_true_acceptor, Y_pred_acceptor)

print("\n\033[1m{}:\033[0m".format('Donor'))
donor_val_results = print_topl_statistics(Y_true_donor, Y_pred_donor)

Cross entropy = 0.0005279586512631472

[1mAcceptor:[0m
0.9969	0.9496	0.9873	0.9897	0.9782	0.9918	0.8195	0.1928	0.0542	13569	14289.0	14289

[1mDonor:[0m
0.9971	0.9515	0.9896	0.9917	0.9802	0.9919	0.8320	0.2018	0.0556	13596	14289.0	14289


In [22]:
from src.dataloader import getData
# Load the test split of the SplicePrediction-050422 data set.
# NOTE: this rebinds `setType` and `seqData`, which earlier cells set for the training data.
setType = 'test'
annotation_test, transcriptToLabel_test, seqData = getData(
    '/odinn/tmp/benediktj/Data/SplicePrediction-050422',
    setType,
)

In [23]:
from src.dataloader import getDataPointListFull,DataPointFull

In [24]:
# Evaluate the fine-tuned 10-model ensemble on the test split loaded with getData.
temp = 1
n_models = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Template network; every ensemble member is a deep copy of it whose weights
# are then replaced by a fine-tuned checkpoint.
model_m = SpliceAI_10K(CL_max)
model_m.apply(keras_init)
model_m = model_m.to(device)

if torch.cuda.device_count() > 1:
    model_m = nn.DataParallel(model_m)

output_class_labels = ['Null', 'Acceptor', 'Donor']

models = [copy.deepcopy(model_m) for _ in range(n_models)]
for i, model in enumerate(models):
    # map_location keeps loading independent of the device the checkpoint was saved from.
    checkpoint = torch.load('../Results/PyTorch_Models/spliceai_encoder_10k_finetune_rnasplice-blood_050623_{}'.format(i), map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()

test_dataset = spliceDataset(getDataPointListFull(annotation_test,transcriptToLabel_test,SL,CL_max,shift=SL))
test_dataset.seqData = seqData
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

# Per-batch arrays are collected and concatenated once at the end; pushing
# every position into a Python list as an individual scalar is far more
# expensive in time and memory.
acceptor_true_parts, acceptor_pred_parts = [], []
donor_true_parts, donor_pred_parts = [], []
ce_2d = []

# Inference only: do not record autograd graphs for the forward passes.
with torch.no_grad():
    for (batch_features, targets) in tqdm(test_loader):
        batch_features = batch_features.float().to(device)
        # Drop the flanking context so the targets cover the same span as the outputs.
        targets = targets.to(device)[:,:,CL_max//2:-CL_max//2]
        # Ensemble prediction = arithmetic mean over the member outputs.
        outputs = torch.mean(torch.stack([member(batch_features) for member in models]), dim=0)

        targets = torch.transpose(targets,1,2).cpu().numpy()
        outputs = torch.transpose(outputs,1,2).cpu().numpy()
        ce_2d.append(cross_entropy_2d(targets,outputs))

        # Keep only windows whose target values sum to at least 1.
        is_expr = (targets.sum(axis=(1,2)) >= 1)
        acceptor_true_parts.append(targets[is_expr, :, 1].ravel())
        donor_true_parts.append(targets[is_expr, :, 2].ravel())
        acceptor_pred_parts.append(outputs[is_expr, :, 1].ravel())
        donor_pred_parts.append(outputs[is_expr, :, 2].ravel())

# Flat per-position arrays (empty when nothing was collected).
Y_true_acceptor = np.concatenate(acceptor_true_parts) if acceptor_true_parts else np.array([])
Y_pred_acceptor = np.concatenate(acceptor_pred_parts) if acceptor_pred_parts else np.array([])
Y_true_donor = np.concatenate(donor_true_parts) if donor_true_parts else np.array([])
Y_pred_donor = np.concatenate(donor_pred_parts) if donor_pred_parts else np.array([])
del acceptor_true_parts, acceptor_pred_parts, donor_true_parts, donor_pred_parts


100%|██████████████████████████████████████████████████████████████████████████████████████| 1386/1386 [23:40<00:00,  1.02s/it]


In [25]:
# Report the results for the test split: the mean of the per-batch
# cross-entropy values, followed by the top-L statistics of each class.
mean_ce = np.mean(ce_2d)
print('Cross entropy = {}'.format(mean_ce))

# Make sure all four label/score collections are NumPy arrays.
Y_true_acceptor = np.array(Y_true_acceptor)
Y_pred_acceptor = np.array(Y_pred_acceptor)
Y_true_donor = np.array(Y_true_donor)
Y_pred_donor = np.array(Y_pred_donor)

print("\n\033[1m{}:\033[0m".format('Acceptor'))
acceptor_val_results = print_topl_statistics(Y_true_acceptor, Y_pred_acceptor)

print("\n\033[1m{}:\033[0m".format('Donor'))
donor_val_results = print_topl_statistics(Y_true_donor, Y_pred_donor)

Cross entropy = 0.0003201272902582534

[1mAcceptor:[0m
0.9799	0.9255	0.9885	0.9948	0.9583	0.9916	0.8257	0.1771	0.0503	83027	89712.0	89712

[1mDonor:[0m
0.9815	0.9302	0.9916	0.9964	0.9634	0.9916	0.8360	0.1814	0.0518	83451	89712.0	89712


In [17]:
# Load the test split of the GTEx v8 annotated data from `data_dir`.
# NOTE: this rebinds `gene_to_label` and `seqData`, which the training cells
# bound to the training split.
setType = 'test'
annotation_test, gene_to_label, seqData = get_GTEX_v8_Data(
    data_dir,
    setType,
    'annotation_GTEX_v8.txt',
)

In [18]:
# Evaluate the fine-tuned 10-model ensemble on the GTEx v8 test split.
temp = 1
n_models = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Template network; every ensemble member is a deep copy of it whose weights
# are then replaced by a fine-tuned checkpoint.
model_m = SpliceAI_10K(CL_max)
model_m.apply(keras_init)
model_m = model_m.to(device)

if torch.cuda.device_count() > 1:
    model_m = nn.DataParallel(model_m)

output_class_labels = ['Null', 'Acceptor', 'Donor']

models = [copy.deepcopy(model_m) for _ in range(n_models)]
for i, model in enumerate(models):
    # map_location keeps loading independent of the device the checkpoint was saved from.
    checkpoint = torch.load('../Results/PyTorch_Models/spliceai_encoder_10k_finetune_rnasplice-blood_050623_{}'.format(i), map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()

test_dataset = spliceDataset(getDataPointListGTEX(annotation_test,gene_to_label,SL,CL_max,shift=SL))
test_dataset.seqData = seqData
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

# Per-batch arrays are collected and concatenated once at the end; pushing
# every position into a Python list as an individual scalar is far more
# expensive in time and memory.
acceptor_true_parts, acceptor_pred_parts = [], []
donor_true_parts, donor_pred_parts = [], []
ce_2d = []

# Inference only: do not record autograd graphs for the forward passes.
with torch.no_grad():
    for (batch_features, targets) in tqdm(test_loader):
        batch_features = batch_features.float().to(device)
        # Drop the flanking context so the targets cover the same span as the outputs.
        targets = targets.to(device)[:,:,CL_max//2:-CL_max//2]
        # Ensemble prediction = arithmetic mean over the member outputs.
        outputs = torch.mean(torch.stack([member(batch_features) for member in models]), dim=0)

        targets = torch.transpose(targets,1,2).cpu().numpy()
        outputs = torch.transpose(outputs,1,2).cpu().numpy()
        ce_2d.append(cross_entropy_2d(targets,outputs))

        # Keep only windows whose target values sum to at least 1.
        is_expr = (targets.sum(axis=(1,2)) >= 1)
        acceptor_true_parts.append(targets[is_expr, :, 1].ravel())
        donor_true_parts.append(targets[is_expr, :, 2].ravel())
        acceptor_pred_parts.append(outputs[is_expr, :, 1].ravel())
        donor_pred_parts.append(outputs[is_expr, :, 2].ravel())

# Flat per-position arrays (empty when nothing was collected).
Y_true_acceptor = np.concatenate(acceptor_true_parts) if acceptor_true_parts else np.array([])
Y_pred_acceptor = np.concatenate(acceptor_pred_parts) if acceptor_pred_parts else np.array([])
Y_true_donor = np.concatenate(donor_true_parts) if donor_true_parts else np.array([])
Y_pred_donor = np.concatenate(donor_pred_parts) if donor_pred_parts else np.array([])
del acceptor_true_parts, acceptor_pred_parts, donor_true_parts, donor_pred_parts


    There is an imbalance between your GPUs. You may want to exclude GPU 0 which
    has less than 75% of the memory or cores of GPU 1. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.
100%|████████████████████████████████████████████████████████████████████████| 803/803 [25:53<00:00,  1.93s/it]


In [19]:
# Report the GTEx test-split results: the mean of the per-batch cross-entropy
# values, followed by the top-L statistics of each splice-site class.
mean_ce = np.mean(ce_2d)
print('Cross entropy = {}'.format(mean_ce))

# Make sure all four label/score collections are NumPy arrays.
Y_true_acceptor = np.array(Y_true_acceptor)
Y_pred_acceptor = np.array(Y_pred_acceptor)
Y_true_donor = np.array(Y_true_donor)
Y_pred_donor = np.array(Y_pred_donor)

print("\n\033[1m{}:\033[0m".format('Acceptor'))
acceptor_val_results = print_topl_statistics(Y_true_acceptor, Y_pred_acceptor)

print("\n\033[1m{}:\033[0m".format('Donor'))
donor_val_results = print_topl_statistics(Y_true_donor, Y_pred_donor)

Cross entropy = 0.0007128483277989966

[1mAcceptor:[0m
0.9977	0.7404	0.8768	0.951	0.8324	0.9444	0.2401	0.0732	0.0207	73199	98870.0	98870

[1mDonor:[0m
0.9976	0.7412	0.8749	0.9459	0.8308	0.9422	0.2447	0.0743	0.0207	74201	100114.0	100114


In [20]:
# Acceptor/donor average of the second value in each statistics row printed
# for the GTEx test split (numbers copied by hand from that output).
(0.7404+0.7412)/2

0.7407999999999999

In [21]:
# Acceptor/donor average of the fifth value in each statistics row printed
# for the GTEx test split (numbers copied by hand from that output).
(0.8324+0.8308)/2

0.8316

In [22]:
# Sum of the last value of the acceptor and donor statistics rows printed for
# the GTEx test split (numbers copied by hand from that output).
98870+100114

198984

In [23]:
# Sum of the tenth value of the acceptor and donor statistics rows printed for
# the GTEx test split (numbers copied by hand from that output).
73199+74201

147400