In [1]:
import numpy as np
import sys
import time
import h5py
from tqdm import tqdm

import numpy as np
import re
from math import ceil
from sklearn.metrics import average_precision_score
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import pickle
#import pickle5 as pickle

from sklearn.model_selection import train_test_split

from scipy.sparse import load_npz
from glob import glob

from transformers import get_constant_schedule_with_warmup
from sklearn.metrics import precision_score,recall_score,accuracy_score
import copy

from src.train import trainModel
from src.dataloader import getData,spliceDataset,h5pyDataset,collate_fn
from src.weight_init import keras_init
from src.losses import categorical_crossentropy_2d
from src.models import SpliceFormer, SpliceAI_10K
from src.evaluation_metrics import print_topl_statistics

In [2]:
# Notebook configuration. data_dir is resolved relative to the notebook's
# working directory.
data_dir = '../Data'

# Maximum nucleotide context length: CL_max/2 nucleotides on either side of
# the position of interest, so it has to be an even number.
CL_max=40000
assert CL_max % 2 == 0, 'CL_max must be an even number'

# Sequence-chunk length (the same value the test datasets are built with).
SL=5000

# Number of chunks per forward pass in getPredictions.
BATCH_SIZE = 1

# getData returns (annotation, transcriptToLabel, seqData) for the chosen split.
setType = 'test'
annotation, transcriptToLabel, seqData = getData(data_dir, setType)

In [3]:
def makeTrackPlot(targets_list,outputs_list,name,showPlot=False):
    """Plot predicted acceptor/donor score tracks for one transcript with true sites marked above them.

    Args:
        targets_list: 2-D array with one row per position; a 1 in column 1
            marks an acceptor and a 1 in column 2 marks a donor.
        outputs_list: 2-D array of predicted scores; column 1 is drawn in red
            and column 2 in green.
        name: figure title.
        showPlot: when True, display the figure with plt.show() before it is closed.

    Side effect: sets the global plt.rcParams['savefig.facecolor'] to 'white'.
    """
    n_positions = targets_list.shape[0]
    positions = np.arange(n_positions)

    fig = plt.figure(figsize=(40, 6))
    for column, colour in ((2, 'green'), (1, 'red')):
        plt.plot(outputs_list[:, column], c=colour)

    # Site markers sit at y = 1.1 / 1.05, at the top of the 0-1.1 y-range;
    # clip_on=False stops markers on the edge from being clipped.
    site_styles = (
        (1, 1.1, 'red', 'v', 'Acceptor'),
        (2, 1.05, 'green', '^', 'Donor'),
    )
    for column, height, colour, marker, label in site_styles:
        site_x = positions[targets_list[:, column] == 1]
        plt.scatter(site_x, np.full(site_x.shape[0], height), c=colour, marker=marker,
                    clip_on=False, zorder=10, label=label)

    plt.legend(bbox_to_anchor=(1, 0.5))
    plt.axis([0, n_positions, 0, 1.1])
    plt.xlabel('Distance from transcript start')
    plt.ylabel('Score')
    plt.title(name)
    plt.rcParams['savefig.facecolor'] = 'white'
    # Saving is disabled; the original export was:
    #fig.savefig('/odinn/tmp/benediktj/Data/SplicePrediction/TransformerTraingSetTracks_070122/{}_Score.svg'.format(name), format='svg', dpi=400)
    if showPlot:
        plt.show()
    plt.close()
        
def ceil_div(x, y):
    """Return ceil(x / y) as an int.

    Uses floor division of the negated numerator instead of float division,
    so the result stays exact for integers too large to be represented
    exactly as a float (float(x) starts rounding once |x| > 2**53).

    Raises:
        ZeroDivisionError: if y == 0.
    """
    return int(-(-x // y))
        
# Inference setup: build the transformer (40k context) and SpliceAI (10k
# context) ensembles plus one test DataLoader per context length.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_models = 5  # checkpoints (ensemble members) loaded per architecture
CL_SPLICEAI = 10000  # flanking context used by the SpliceAI_10K ensemble
output_class_labels = ['Null', 'Acceptor', 'Donor']

# Not used by the cells shown here; kept so the notebook namespace is unchanged.
temp=1
transcriptToLoss = {}
loss = categorical_crossentropy_2d().loss


def _prepare(model):
    """Move `model` to `device`, wrapping it in DataParallel when more than one GPU is visible."""
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    return model


def _load_ensemble(template, checkpoint_pattern):
    """Return n_models deep copies of `template`; copy i gets checkpoint_pattern.format(i) and is put in eval mode."""
    members = []
    for i in range(n_models):
        member = copy.deepcopy(template)
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        member.load_state_dict(torch.load(checkpoint_pattern.format(i), map_location=device))
        member.eval()
        members.append(member)
    return members


def _make_test_loader(context_length):
    """Return (dataset, loader) over the test set with the given CL_max; batch_size=1 and no shuffling."""
    dataset = spliceDataset(annotation,transcriptToLabel,SL=SL,CL_max=context_length)
    dataset.seqData = seqData  # reuse the sequences returned by getData
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,collate_fn=collate_fn, pin_memory=True)
    return dataset, loader


# keras_init is applied before the checkpoints overwrite the weights, as before.
transformer_template = _prepare(SpliceFormer(CL_max).apply(keras_init))
transformer_models = _load_ensemble(transformer_template, '../Results/PyTorch_Models/transformer_encoder_40k_270522_{}')

spliceai_template = _prepare(SpliceAI_10K(CL_SPLICEAI))
spliceai_models = _load_ensemble(spliceai_template, '../Results/PyTorch_Models/spliceai_encoder_10k_060522_{}')

# Both loaders iterate the same annotation in order (shuffle=False), so they can be zipped.
_, test_loader_40k = _make_test_loader(CL_max)
test_dataset, test_loader_10k = _make_test_loader(CL_SPLICEAI)

def getPredictions(models,batch_chunks,target_chunks,isTransformer):
    """Run an ensemble over one transcript's chunks and return the averaged scores and the labels.

    Args:
        models: non-empty sequence of callables (the ensemble members). Each is
            called on a BATCH_SIZE slice of the inputs; when isTransformer is
            True the scores are element [1] of the returned value, otherwise
            the returned value itself.
        batch_chunks: input tensor; dims 1 and 2 are swapped before inference
            (presumably (n_chunks, length, channels) -> channels-first; TODO confirm).
        target_chunks: label tensor with the same leading (chunk) dimension;
            dims 1 and 2 are swapped the same way.
        isTransformer: selects how member outputs are unpacked (see above).

    Returns:
        (outputs, targets): numpy arrays with one row per position (chunks
        concatenated in order) and one column per class. outputs is the
        arithmetic mean of the members' scores.

    Raises:
        ValueError: if `models` is empty or inputs and labels have a different
            number of chunks.

    Uses the notebook globals `device` and `BATCH_SIZE`.
    """
    n_members = len(models)
    if n_members == 0:
        raise ValueError('getPredictions needs at least one model')

    inputs = torch.transpose(batch_chunks.to(device), 1, 2)
    # Labels never pass through a model, so they stay on their current device
    # instead of being copied to the GPU and back. (The previous conditional
    # squeeze(0) only ran when dim 0 had size > 1, where it is a no-op.)
    labels = torch.transpose(target_chunks, 1, 2)
    if inputs.shape[0] != labels.shape[0]:
        raise ValueError('batch_chunks and target_chunks have a different number of chunks')

    targets_list = []
    outputs_list = []
    with torch.no_grad():  # inference only: don't build autograd graphs
        for batch_features, targets in zip(torch.split(inputs, BATCH_SIZE, dim=0),
                                           torch.split(labels, BATCH_SIZE, dim=0)):
            if isTransformer:
                member_outputs = [model(batch_features)[1] for model in models]
            else:
                member_outputs = [model(batch_features) for model in models]
            # Same left-to-right summation as before, for any ensemble size.
            averaged = sum(member_outputs) / n_members
            # Extending with a batch tensor appends one (classes, length) slice per chunk.
            targets_list.extend(targets)
            outputs_list.extend(averaged)

    # hstack joins the chunks along the position axis; .T gives (positions, classes).
    targets_array = torch.hstack(targets_list).T.cpu().numpy()
    outputs_array = torch.hstack(outputs_list).T.cpu().numpy()
    return outputs_array, targets_array

# Score every test transcript with both ensembles and keep only positions where
# a label or any score exceeds SCORE_THRESHOLD. Assumes dataset item i
# corresponds to annotation row i (both loaders use shuffle=False).
SCORE_THRESHOLD = 0.01
result_chunks = []  # one filtered array per transcript, stacked once at the end
with torch.no_grad():  # inference only
    for i, ((batch_chunks_10k, target_chunks_10k), (batch_chunks_40k, target_chunks_40k)) in tqdm(
            enumerate(zip(test_loader_10k, test_loader_40k)),
            total=min(len(test_loader_10k), len(test_loader_40k))):
        spliceai_outputs, targets = getPredictions(spliceai_models, batch_chunks_10k, target_chunks_10k, False)
        transformer_outputs, _ = getPredictions(transformer_models, batch_chunks_40k, target_chunks_40k, True)
        # Column 0 ('Null') is dropped, leaving acceptor and donor columns.
        data = np.concatenate([targets[:, 1:], spliceai_outputs[:, 1:], transformer_outputs[:, 1:]], axis=1)
        toKeep = np.any(data > SCORE_THRESHOLD, axis=1)

        info = annotation.iloc[i, :]
        tx_start = int(info['tx_start'])
        # Rows cover the transcript padded up to a multiple of SL.
        length = SL * ceil_div(int(info['tx_end']) - tx_start + 1, SL)
        position = np.arange(tx_start, tx_start + length)
        if info['strand'] != '+':
            # NOTE(review): this reverses the padded range, so row 0 maps to
            # tx_start + length - 1 rather than tx_end whenever length includes
            # padding. That is only right if spliceDataset pads minus-strand
            # transcripts on the genomic left — TODO confirm in src.dataloader.
            position = position[::-1]

        n_rows = data.shape[0]
        name = np.repeat(info['name'], n_rows)[:, None]
        chrom = np.repeat(info['chrom'], n_rows)[:, None]
        strand = np.repeat(info['strand'], n_rows)[:, None]
        # Mixing strings and numbers promotes the whole row to a string dtype.
        data = np.concatenate([name, chrom, strand, position[:, None], data], axis=1)
        result_chunks.append(data[toKeep])

if result_chunks:
    results = np.vstack(result_chunks)
else:
    # No transcripts: an empty 10-column table instead of leaving `results` undefined.
    results = np.empty((0, 10), dtype=str)



8955it [8:48:15,  3.54s/it]


In [4]:
# Turn the stacked string array into a labelled table and write it as a TSV.
# The 'Posistion' header spelling is kept as-is for downstream readers.
results = pd.DataFrame(
    results,
    columns=[
        'Name', 'Chrom', 'Strand', 'Posistion',
        'Acceptor label', 'Donor label',
        'SpliceAI acceptor score', 'SpliceAI donor score',
        'Transformer acceptor score', 'Transformer donor score',
    ],
)
results.to_csv(
    data_dir + '/SpliceSitePredictionScores_020622.tsv',
    sep='\t',
    index=False,
)