In [None]:
import numpy as np
from tqdm import tqdm
import os

import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt

from transformers import get_constant_schedule_with_warmup
import copy

from src.train import trainModel
from src.dataloader import (
    getData_multispecies,
    spliceDataset,
    getDataPointListFull,
)

from src.weight_init import keras_init
from src.losses import categorical_crossentropy_2d
from src.model import SpliceFormer
from src.evaluation_metrics import print_topl_statistics,cross_entropy_2d
from src.gpu_metrics import run_bootstrap

In [None]:
# IPython shell escape: print the nvidia-smi report for the GPUs on this host.
!nvidia-smi

In [None]:
# One seed for every random number generator used in this notebook.
SEED = 23673

# NumPy generator for any draws made explicitly through `rng`.
rng = np.random.default_rng(SEED)

# `rng` does not influence PyTorch. Parameter initialisation done by torch and
# the shuffling DataLoader draw from torch's global generator, so seed it as
# well. The trailing semicolon hides the returned Generator repr in the cell
# output.
torch.manual_seed(SEED);

In [None]:
# --- Hardware / batching ----------------------------------------------------
N_GPUS = 3
NUM_ACCUMULATION_STEPS = 1
# os.cpu_count() can return None; fall back to 8 in that case, and never use
# more than 8 workers.
NUM_WORKERS = min(8, os.cpu_count() or 8)

# --- Hyper-parameters -------------------------------------------------------
# L:  Number of convolution kernels
# W:  Convolution window size in each residual unit
# AR: Atrous rate in each residual unit
L = 32
W = np.asarray([11] * 8 + [21] * 4 + [41] * 4)
AR = np.repeat([1, 4, 10, 25], 4)

# k scales the batch size. Once BATCH_SIZE is fixed, k is multiplied by the
# number of accumulation steps; the training configuration later uses it to
# scale the learning rate.
k = 2
BATCH_SIZE = 16 * k * N_GPUS
k *= NUM_ACCUMULATION_STEPS

# CL: twice the sum of AR * (W - 1) over all residual units.
CL = 2 * (AR * (W - 1)).sum()

In [None]:
# Directory that getData_multispecies reads from.
data_dir = "data/processed_data"

# Species requested for every split.
species_list = ["homo_sapiens", "mus_musculus", "danio_rerio", "pan_troglodytes"]

# Training split: all three returned objects are kept.
annotation_train, transcriptToLabel_train, seqData = getData_multispecies(data_dir, 'train', species_list)

# Validation split: the third returned object is discarded, so the validation
# dataset later reuses `seqData` from the training call.
annotation_val, transcriptToLabel_val, _ = getData_multispecies(data_dir, 'val', species_list)

In [None]:
# SL is handed to getDataPointListFull (also as its `shift`) when the datasets
# are built. CL_max is handed to getDataPointListFull and SpliceFormer, and
# CL_max//2 positions are cropped from each end of the targets at evaluation.
SL = CL_max = 5_000

In [None]:
# The evaluation code slices targets with [CL_max//2 : -CL_max//2]. For an odd
# CL_max, -CL_max//2 floors to -(CL_max//2 + 1), so the two sides would be
# cropped by different amounts. Fail early with a readable message instead.
assert CL_max % 2 == 0, "CL_max must be even so CL_max//2 can be cropped from each side (got {})".format(CL_max)

In [None]:
# Datasets for the pre-defined train / val splits. SL is used as the `shift`
# argument for both.
train_points = getDataPointListFull(annotation_train, transcriptToLabel_train, SL, CL_max, shift=SL)
train_dataset = spliceDataset(train_points)

val_points = getDataPointListFull(annotation_val, transcriptToLabel_val, SL, CL_max, shift=SL)
val_dataset = spliceDataset(val_points)

# spliceDataset does not include seqData internally, so attach it by hand.
train_dataset.seqData = seqData
val_dataset.seqData = seqData

# Shuffled loader for optimisation (pinned host memory).
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True,
)

# Fixed-order loader for validation.
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS,
)

In [None]:
# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

epochs = 5
gamma = 0.5               # decay factor handed to StepLR
learning_rate = 1e-3 * k  # 1e-3 scaled by k
hs = []                   # collects the object returned by trainModel for each model

In [None]:
# Train an ensemble of SpliceFormer models. Every iteration builds a fresh
# network, optimizer and LR schedules, trains with trainModel, stores the
# returned history in `hs` and plots the training loss.

# Start from an empty list so that re-running this cell does not append a
# second copy of every history.
hs = []

for model_nr in range(10):  # the evaluation cell loads 10 checkpoints (indices 0-9)
    model_m = SpliceFormer(CL_max, bn_momentum=0.01/NUM_ACCUMULATION_STEPS, depth=4, heads=4, n_transformer_blocks=2)
    model_m.apply(keras_init)
    model_m = model_m.to(device)
    if torch.cuda.device_count() > 1:
        # Spread each batch over all visible GPUs.
        model_m = nn.DataParallel(model_m)

    modelFileName = '../Results/PyTorch_Models/transformer_encoder_10k_090724_{}'.format(model_nr)
    # trainModel presumably writes the checkpoint to this path (the evaluation
    # cell loads it from there), so make sure the directory exists up front
    # rather than failing after training.
    os.makedirs(os.path.dirname(modelFileName), exist_ok=True)

    loss = categorical_crossentropy_2d().loss
    optimizer = torch.optim.AdamW(model_m.parameters(), lr=learning_rate, weight_decay=1e-5)
    # Both schedules wrap the same optimizer; how they are stepped is decided
    # inside trainModel.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
    warmup = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=1000)

    h = trainModel(
        model_m, modelFileName, loss, train_loader, val_loader, optimizer, scheduler, warmup,
        BATCH_SIZE, epochs, device,
        skipValidation=False, NUM_ACCUMULATION_STEPS=NUM_ACCUMULATION_STEPS, CL_max=CL_max,
    )
    hs.append(h)

    # One figure per model, titled so the curves can be told apart.
    fig, ax = plt.subplots()
    ax.plot(range(epochs), h['loss'], label='Train')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training loss (model {})'.format(model_nr))
    ax.legend()
    plt.show()
    plt.close(fig)  # release the figure; ten are created by this loop

In [None]:
# Load the test split for the same species. This rebinds `seqData`; the
# train/val datasets keep the object that was attached to them earlier.
setType = 'test'
data_dir = 'data/processed_data'
annotation_test, transcriptToLabel_test, seqData = getData_multispecies(data_dir, setType, species_list)

In [None]:
# Batch size for inference (1024). This rebinds the BATCH_SIZE used for
# training, so re-run the hyper-parameter cell before rebuilding the
# train/val loaders.
BATCH_SIZE = 16 * 8 * 8

In [None]:
# Evaluate the ensemble on the test split: load the trained checkpoints,
# average the models' outputs per batch, and collect the cross entropy plus
# the acceptor/donor labels and predictions.
temp = 1  # not used in this cell; kept in case other cells reference it
n_models = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Template network; every ensemble member is a deep copy of it that then
# receives its own checkpoint.
model_m = SpliceFormer(CL_max,bn_momentum=0.01/NUM_ACCUMULATION_STEPS,depth=4,heads=4,n_transformer_blocks=2,determenistic=True)
model_m.apply(keras_init)
model_m = model_m.to(device)

if torch.cuda.device_count() > 1:
    model_m = nn.DataParallel(model_m)

output_class_labels = ['Null', 'Acceptor', 'Donor']

models = [copy.deepcopy(model_m) for _ in range(n_models)]
for idx, model in enumerate(models):
    checkpoint_path = '../Results/PyTorch_Models/transformer_encoder_10k_090724_{}'.format(idx)
    # map_location lets checkpoints that were saved on a GPU load on whatever
    # device is available here (including CPU-only hosts).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

test_dataset = spliceDataset(getDataPointListFull(annotation_test,transcriptToLabel_test,SL,CL_max,shift=SL))
test_dataset.seqData = seqData
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

# Per-batch arrays are collected and concatenated once at the end. Extending a
# Python list element by element would create one scalar object per position.
acceptor_true_chunks, acceptor_pred_chunks = [], []
donor_true_chunks, donor_pred_chunks = [], []

ce_2d = []
# Inference only: disabling autograd avoids building a graph for each of the
# n_models forward passes.
with torch.no_grad():
    for (batch_features, targets) in tqdm(test_loader):
        batch_features = batch_features.type(torch.FloatTensor).to(device)
        # Crop CL_max//2 positions from each end of the last axis.
        targets = targets.to(device)[:,:,CL_max//2:-CL_max//2]

        # Ensemble prediction: mean of the first output of every model.
        outputs = torch.stack([model(batch_features)[0] for model in models])
        outputs = torch.mean(outputs, dim=0)

        # Swap axes 1 and 2 so the class index is the last axis.
        targets = torch.transpose(targets,1,2).cpu().numpy()
        outputs = torch.transpose(outputs,1,2).cpu().numpy()
        ce_2d.append(cross_entropy_2d(targets,outputs))

        # Keep only the windows whose target sum is at least 1.
        is_expr = (targets.sum(axis=(1,2)) >= 1)
        acceptor_true_chunks.append(targets[is_expr, :, 1].ravel())
        donor_true_chunks.append(targets[is_expr, :, 2].ravel())
        acceptor_pred_chunks.append(outputs[is_expr, :, 1].ravel())
        donor_pred_chunks.append(outputs[is_expr, :, 2].ravel())

# Flat 1-D arrays; an empty loader yields empty arrays.
Y_true_acceptor = np.concatenate(acceptor_true_chunks) if acceptor_true_chunks else np.array([])
Y_pred_acceptor = np.concatenate(acceptor_pred_chunks) if acceptor_pred_chunks else np.array([])
Y_true_donor = np.concatenate(donor_true_chunks) if donor_true_chunks else np.array([])
Y_pred_donor = np.concatenate(donor_pred_chunks) if donor_pred_chunks else np.array([])

In [None]:
# Average the per-batch cross entropies and report the result.
mean_ce = np.mean(ce_2d)
print('Cross entropy = {}'.format(mean_ce))

# Make sure labels and predictions are NumPy arrays before scoring.
Y_true_acceptor = np.array(Y_true_acceptor)
Y_pred_acceptor = np.array(Y_pred_acceptor)
Y_true_donor = np.array(Y_true_donor)
Y_pred_donor = np.array(Y_pred_donor)

# Bold section header (ANSI escape codes), then the top-l statistics for each
# splice-site class.
header = "\n\033[1m{}:\033[0m"

print(header.format('Acceptor'))
acceptor_val_results = print_topl_statistics(Y_true_acceptor, Y_pred_acceptor)

print(header.format('Donor'))
donor_val_results = print_topl_statistics(Y_true_donor, Y_pred_donor)

In [None]:
# Store labels and ensemble predictions side by side in a gzip-compressed CSV
# (compression is inferred from the .gz suffix), without the index column.
df = pd.DataFrame(
    {
        'Y_true_acceptor': Y_true_acceptor,
        'Y_pred_acceptor': Y_pred_acceptor,
        'Y_true_donor': Y_true_donor,
        'Y_pred_donor': Y_pred_donor,
    }
)
df.to_csv('transformer_10k_test_set_predictions_090724.csv.gz', index=False)

In [None]:
# Reload the stored predictions and bootstrap the metrics. Labels are cast to
# int8 and predictions to float32 before they are handed to run_bootstrap.
df = pd.read_csv('transformer_10k_test_set_predictions_090724.csv.gz')
run_bootstrap(
    df['Y_true_acceptor'].astype(np.int8),
    df['Y_pred_acceptor'].astype(np.float32),
    df['Y_true_donor'].astype(np.int8),
    df['Y_pred_donor'].astype(np.float32),
    n_bootstraps=1000,
)