In [1]:
import numpy as np
import sys
import time
import h5py
from tqdm import tqdm

import numpy as np
import re
from math import ceil
from sklearn.metrics import average_precision_score

import pandas as pd
import matplotlib.pyplot as plt
import pickle
#import pickle5 as pickle

from sklearn.model_selection import train_test_split

from scipy.sparse import load_npz
from glob import glob

from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.linear_model import LogisticRegression

In [2]:
from torch.utils.data import Dataset
import torch
import torch.nn as nn
from transformers import get_constant_schedule_with_warmup
from sklearn.metrics import precision_score,recall_score,accuracy_score

from src.train import trainModel
from src.dataloader import get_GTEX_v8_Data,spliceDataset,h5pyDataset,getDataPointList,getDataPointListGTEX,DataPointGTEX
from src.weight_init import keras_init
from src.losses import categorical_crossentropy_2d
from src.model import SpliceFormer
from src.evaluation_metrics import print_topl_statistics
import copy
#import tensorflow as tf

In [3]:
# Seeded NumPy random generator. None of the cells visible in this notebook
# draws from it, so it does not influence the results shown below.
rng = np.random.default_rng(23673)

In [4]:
# Hyper-parameters:
#   L:  number of convolution kernels
#   W:  convolution window size in each residual unit
#   AR: atrous rate in each residual unit
L = 32
N_GPUS = 3
NUM_ACCUMULATION_STEPS = 1
k = 2

# Sixteen residual units: eight with window 11, then four each with 21 and 41.
W = np.repeat([11, 21, 41], [8, 4, 4])
# The atrous rate changes every four units: 1, 4, 10, 25.
AR = np.repeat([1, 4, 10, 25], 4)

# The batch size is derived from the un-scaled k; k is rescaled only afterwards.
BATCH_SIZE = 16 * k * N_GPUS
k *= NUM_ACCUMULATION_STEPS

# Twice the summed AR*(W-1) over all residual units.
CL = 2 * (AR * (W - 1)).sum()

In [5]:
# Absolute, machine-specific location of the GTEx v8 splice-prediction data.
data_dir = '/odinn/tmp/benediktj/Data/SplicePrediction-GTEX-V8'
# Forwarded unchanged to the loader; presumably selects which gene subset is
# returned ('all' = no split) -- TODO confirm in src.dataloader.
setType = 'all'
# Later cells use the three results as follows: `annotation` is a DataFrame with
# (at least) name / chrom / strand / tx_start columns, while `gene_to_label` and
# `seqData` are handed straight to getDataPointListGTEX and the dataset object.
annotation, gene_to_label, seqData = get_GTEX_v8_Data(data_dir, setType,'annotation_GTEX_v8.txt')

In [6]:
# SL is passed to getDataPointListGTEX both positionally and as `shift`, and it is
# the number of positions mutated per window in the scan loop
# (indices CL_max//2 .. CL_max//2 + SL - 1). Presumably the per-window sequence length.
SL=5000
# CL_max is passed to getDataPointListGTEX and to the SpliceFormer constructor;
# the mutated region of every window starts at index CL_max//2.
CL_max=40000

In [7]:
# Keep only the annotation rows of the LDLR gene; every later cell works on this subset.
# NOTE: this rebinds `annotation`, so the full table is gone until the data-loading
# cell is run again.
annotation = annotation[annotation['name']=='LDLR']

In [8]:
# Data points for the genes left in `annotation`, built with shift == SL
# (presumably consecutive, non-overlapping windows -- see src.dataloader).
# No validation dataset or loader is built in this notebook.
train_dataset = spliceDataset(
    getDataPointListGTEX(annotation, gene_to_label, SL, CL_max, shift=SL)
)
train_dataset.seqData = seqData

# shuffle=False keeps the windows in dataset order; the scan loop derives each
# row's genomic position from the window's index inside the batch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=16,
    pin_memory=True,
)

In [9]:
temp = 1
n_models = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Template network. Its freshly initialised weights are only placeholders: every
# ensemble member below overwrites them with a saved checkpoint.
model_m = SpliceFormer(CL_max,bn_momentum=0.01/NUM_ACCUMULATION_STEPS,depth=4,heads=4,n_transformer_blocks=2,determenistic=True,crop=False)
model_m.apply(keras_init)
model_m = model_m.to(device)

if torch.cuda.device_count() > 1:
    model_m = nn.DataParallel(model_m)

output_class_labels = ['Null', 'Acceptor', 'Donor']

# One independent copy of the template per ensemble member.
models = [copy.deepcopy(model_m) for _ in range(n_models)]
for member_idx, model in enumerate(models):
    # map_location makes loading follow the device chosen above, so checkpoints
    # saved from a GPU run can still be restored when the cell falls back to CPU.
    # NOTE(review): nn.DataParallel prefixes state-dict keys with "module."; the
    # checkpoints presumably were saved from wrapped models -- verify before
    # running with fewer than two GPUs.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load('../Results/PyTorch_Models/transformer_encoder_40k_finetune_rnasplice-blood_all_050623_{}'.format(member_idx), map_location=device))
    # Inference only: switch the member to evaluation mode.
    model.eval()

# Accumulators kept for evaluation code; the cells visible here do not fill them.
Y_true_acceptor, Y_pred_acceptor = [],[]
Y_true_donor, Y_pred_donor = [],[]

ce_2d = []

In [10]:
def generate_point_mutations(input_tensor, i, j,base_map):
    """
    Lazily generate every single-base substitution at positions i .. j-1 for a batch of
    one-hot encoded genomic sequences.

    Positions are visited in ascending order and, for each position, the four
    channels 0..3 are substituted in ascending order. Each yielded batch applies the
    same substitution to every sequence of the batch. The substitution that equals a
    sequence's own base is yielded too; callers are expected to filter on ref != alt.

    Args:
    input_tensor (torch.Tensor): Batch of one-hot encoded genomic sequences with shape (b, 4, n).
        It is never modified; every yielded batch is a fresh clone.
    i (int): Start index of the mutation range (inclusive).
    j (int): End index of the mutation range (exclusive).
    base_map (str or sequence): Symbol of each channel index 0..3, e.g. 'ACGT'.

    Yields:
    tuple: (mutated_sequence, alt_base, ref_base) where
        mutated_sequence (torch.Tensor) has shape (b, 4, n) and differs from
            input_tensor only in the column of the current position,
        alt_base is the symbol of the substituted base,
        ref_base (list) holds, per sequence, the symbol of the original base at the
            current position. It is the argmax over the channel axis, so a column
            without a single set channel (e.g. all zeros) is reported as base_map[0].
            The same list object is yielded for all four substitutions of a position.
    """
    for pos in range(i, j):
        # Reference base of every sequence at this position (channel with the
        # largest value), looked up once and shared by the four substitutions.
        original_base = torch.argmax(input_tensor[:, :, pos], dim=1).tolist()
        ref_base = [base_map[base] for base in original_base]

        for base in range(4):
            # Work on a copy so that the caller's tensor and previously yielded
            # batches stay untouched.
            mutated_sequence = input_tensor.clone()

            # Overwrite the column with the one-hot encoding of the new base.
            mutated_sequence[:, :, pos] = 0
            mutated_sequence[:, base, pos] = 1

            yield mutated_sequence,base_map[base],ref_base

In [20]:
def getDeltas(delta,idx,inside_gene):
    """
    Summarise a prediction difference by its strongest gain and loss per sequence.

    Args:
    delta (torch.Tensor): Difference between two predictions, indexed as
        (batch, channel, position); channel 1 is read as the acceptor score and
        channel 2 as the donor score.
    idx (int): Position subtracted from every reported position, so positions come
        back relative to it.
    inside_gene (torch.Tensor): Per-position mask multiplied onto both channels;
        positions where it is False/0 contribute a difference of zero.

    Returns:
    tuple of eight NumPy arrays with one entry per sequence, in this order:
        acceptor gain, acceptor loss, donor gain, donor loss (losses are returned
        as the negated most-negative difference), followed by the positions, relative to idx, of
        acceptor gain, acceptor loss, donor gain, donor loss.
    """
    def _peak(scores):
        # Largest value of every row together with the index where it occurs.
        where = torch.argmax(scores, dim=1)
        value = scores.gather(1, where.unsqueeze(1)).cpu().numpy()[:, 0]
        return value, where.cpu().numpy()

    acceptor = delta[:, 1, :] * inside_gene
    donor = delta[:, 2, :] * inside_gene

    # A loss is the peak of the negated difference, reported as a magnitude.
    gain_a, at_gain_a = _peak(acceptor)
    loss_a, at_loss_a = _peak(-acceptor)
    gain_d, at_gain_d = _peak(donor)
    loss_d, at_loss_d = _peak(-donor)

    return (
        gain_a, loss_a, gain_d, loss_d,
        at_gain_a - idx, at_loss_a - idx, at_gain_d - idx, at_loss_d - idx,
    )

In [32]:
loader = train_loader
base_map = 'ACGT'
chrom = annotation.chrom.values[0]
# NOTE(review): `strand` is read but never used. POS/REF/ALT are emitted in window
# order starting at tx_start, which presumably is only valid for forward-strand
# genes -- confirm before scanning a '-' strand gene.
strand = annotation.strand.values[0]
start = annotation.tx_start.values[0]


def ensemble_predict(batch):
    """Return the mean over all ensemble members of the first model output for `batch`."""
    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        return torch.stack([member(batch)[0] for member in models]).mean(dim=0)


RESULT_COLUMNS = ['CHROM','POS','REF','ALT','DS_AG','DS_AL','DS_DG','DS_DL','DP_AG','DP_AL','DP_DG','DP_DL']
scan_start = CL_max//2  # index of the first mutated position inside a window
frames = []             # one filtered frame per substitution, concatenated once at the end
n_loops = 0
windows_seen = 0        # windows consumed by earlier batches
for (batch_features ,target) in loader:
    batch_features = batch_features.type(torch.FloatTensor).to(device)
    ref_pred = ensemble_predict(batch_features)
    # Positions with at least one set channel; all-zero columns are excluded
    # from the gain/loss search.
    inside_gene = (batch_features.sum(axis=(1)) >= 1)

    # Genomic coordinate of the first mutated position of every window in the batch.
    # Assumes window w of the (unshuffled) dataset starts at tx_start + w*SL; the
    # running offset keeps this correct when the loader yields more than one batch.
    n_windows = batch_features.shape[0]
    window_start = start + (windows_seen + np.arange(n_windows))*SL

    mutations = generate_point_mutations(batch_features, scan_start, scan_start+SL, base_map)
    # Four substitutions are generated per position, so i//4 is the position offset.
    for i,(mutation,alt_base,ref_base) in tqdm(enumerate(mutations), total=4*SL):
        offset = i//4
        # `mutation` is a clone of batch_features, i.e. already float and on `device`.
        alt_pred = ensemble_predict(mutation)

        delta = alt_pred-ref_pred
        a1,b1,c1,d1,a2,b2,c2,d2 = getDeltas(delta,scan_start+offset,inside_gene)
        df = pd.DataFrame({'CHROM':chrom,'POS':window_start+offset,'REF':ref_base,'ALT':alt_base,'DS_AG':a1,'DS_AL':b1,'DS_DG':c1,'DS_DL':d1,'DP_AG':a2,'DP_AL':b2,'DP_DG':c2,'DP_DL':d2})
        # Drop the "substitution" that leaves a sequence's base unchanged.
        frames.append(df[df['REF']!=df['ALT']])
        n_loops+=1
    windows_seen += n_windows

# Concatenate once instead of growing a frame inside the loop; the per-batch row
# index of every frame is preserved, as before.
results = pd.concat(frames,axis=0) if frames else pd.DataFrame(columns=RESULT_COLUMNS)

20000it [2:17:54,  2.42it/s]


In [33]:
# Write the scan to disk ordered by variant. The DataFrame index (row number
# within the batch) goes into the first, unnamed column of the file.
results.sort_values(['CHROM','POS','REF','ALT']).to_csv('../Results/LDLR_splice_delta.tsv',sep='\t')

In [34]:
# Bring the in-memory table into the same variant order as the saved file.
results = results.sort_values(['CHROM','POS','REF','ALT'])

In [35]:
# Display the sorted scan table (pandas truncates the rendering of long frames).
results

Unnamed: 0,CHROM,POS,REF,ALT,DS_AG,DS_AL,DS_DG,DS_DL,DP_AG,DP_AL,DP_DG,DP_DL
0,chr19,11089463,G,A,0.000455,0.000215,0.000251,0.000082,20,1948,19399,1456
0,chr19,11089463,G,C,0.000301,0.000630,0.000789,0.001100,20,1948,42,1598
0,chr19,11089463,G,T,0.000875,0.000120,0.000483,0.000641,20,24123,1598,42
0,chr19,11089464,T,A,0.000758,0.000513,0.001016,0.001127,2,15,41,1597
0,chr19,11089464,T,C,0.000108,0.000543,0.000779,0.001323,0,1947,41,1597
...,...,...,...,...,...,...,...,...,...,...,...,...
8,chr19,11134461,A,G,0.000014,0.000086,0.000032,0.000016,-1227,-3174,-2052,-3708
8,chr19,11134461,A,T,0.000004,0.000038,0.000016,0.000020,-2233,-3178,-2666,-3708
8,chr19,11134462,A,C,0.000006,0.000055,0.000019,0.000008,-11459,-3175,-17453,-3709
8,chr19,11134462,A,G,0.000013,0.000042,0.000030,0.000016,-1545,-3179,-2667,-3709


In [37]:
# Variants for which at least one of the four delta scores exceeds 0.2,
# ordered by the acceptor-gain score.
results[(results[['DS_AG','DS_AL','DS_DG','DS_DL']]>0.2).any(axis=1)].sort_values('DS_AG')

Unnamed: 0,CHROM,POS,REF,ALT,DS_AG,DS_AL,DS_DG,DS_DL,DP_AG,DP_AL,DP_DG,DP_DL
0,chr19,11091062,G,A,0.000051,0.007535,0.015981,0.228605,-1556,349,-39,-1
0,chr19,11091062,G,C,0.000052,0.008070,0.013035,0.228569,-1556,349,-39,-1
0,chr19,11091067,T,A,0.000052,0.008301,0.013570,0.208772,-1561,344,-44,-6
0,chr19,11091063,T,A,0.000052,0.008837,0.013771,0.228600,-1557,348,-144,-2
0,chr19,11091063,T,C,0.000052,0.008439,0.013543,0.228489,-1557,348,-144,-2
...,...,...,...,...,...,...,...,...,...,...,...,...
4,chr19,11110639,T,A,0.999154,0.779250,0.005006,0.010409,2,13,702,143
4,chr19,11113307,C,A,0.999156,0.504283,0.016423,0.001758,2,-29,-201,89
3,chr19,11107381,G,A,0.999165,0.937096,0.001074,0.003443,2,11,9628,403
4,chr19,11113267,C,A,0.999431,0.811681,0.005142,0.000562,2,11,-161,561


In [36]:
# All substitutions recorded at position 11116650.
results[results['POS']==11116650]

Unnamed: 0,CHROM,POS,REF,ALT,DS_AG,DS_AL,DS_DG,DS_DL,DP_AG,DP_AL,DP_DG,DP_DL
5,chr19,11116650,C,A,0.003215,0.000597,7.6e-05,0.002107,103,14924,453,87
5,chr19,11116650,C,G,0.000499,0.000555,0.004346,0.000283,103,14924,87,15145
5,chr19,11116650,C,T,0.003463,0.000515,0.000385,0.005563,103,-295,453,87


In [13]:
# Reload the saved scan so the comparison cells can run without repeating the
# scan; index_col=0 turns the first, unnamed column of the file back into the index.
results = pd.read_csv('../Results/LDLR_splice_delta.tsv',sep='\t',index_col=0)

In [19]:
# Space-separated, header-less SpliceAI score table (absolute, machine-specific path).
# NOTE(review): no dtype is given, so pandas infers column types while parsing. If
# the chromosome column mixes numbers and strings, equality filters on it must
# handle both representations.
splice_ai_results = pd.read_csv('/odinn/groups/machinelearning/genomics/data/SpliceAI_scores/spliceai_scores.raw.snv.hg38.Over02.tsv',sep=' ',header=None)

  splice_ai_results = pd.read_csv('/odinn/groups/machinelearning/genomics/data/SpliceAI_scores/spliceai_scores.raw.snv.hg38.Over02.tsv',sep=' ',header=None)


In [23]:
# Display the scan table as reloaded from disk.
results

Unnamed: 0,CHROM,POS,REF,ALT,DS_AG,DS_AL,DS_DG,DS_DL,DP_AG,DP_AL,DP_DG,DP_DL
0,chr19,11089463,G,A,0.000455,0.000215,0.000251,0.000082,20,1948,19399,1456
0,chr19,11089463,G,C,0.000301,0.000630,0.000789,0.001100,20,1948,42,1598
0,chr19,11089463,G,T,0.000875,0.000120,0.000483,0.000641,20,24123,1598,42
0,chr19,11089464,T,A,0.000758,0.000513,0.001016,0.001127,2,15,41,1597
0,chr19,11089464,T,C,0.000108,0.000543,0.000779,0.001323,0,1947,41,1597
...,...,...,...,...,...,...,...,...,...,...,...,...
8,chr19,11134461,A,G,0.000014,0.000086,0.000032,0.000016,-1227,-3174,-2052,-3708
8,chr19,11134461,A,T,0.000004,0.000038,0.000016,0.000020,-2233,-3178,-2666,-3708
8,chr19,11134462,A,C,0.000006,0.000055,0.000019,0.000008,-11459,-3175,-17453,-3709
8,chr19,11134462,A,G,0.000013,0.000042,0.000030,0.000016,-1545,-3179,-2667,-3709


In [22]:
# Name the five positional columns. The last column holds the raw annotation
# string (e.g. 'SpliceAI=A|LDLR|0.00|0.00|0.32|0.99|0|-23|27|-1'), not a single
# numeric delta; it has to be split before scores can be compared.
splice_ai_results.columns = ['CHROM', 'POS', 'REF', 'ALT', 'SpliceAI_delta']

In [25]:
# Keep chromosome 19 only. CHROM was read without an explicit dtype, so it can hold
# the integer 19 or the string '19' depending on how pandas inferred the column;
# comparing the string form matches both representations.
splice_ai_results = splice_ai_results[splice_ai_results['CHROM'].astype(str)=='19']

In [31]:
# Restrict the SpliceAI rows to the scanned interval; both end points are included.
splice_ai_results = splice_ai_results[splice_ai_results['POS'].between(11089463, 11134462)]

In [32]:
# Display the SpliceAI rows left after the chromosome and position filters.
splice_ai_results

Unnamed: 0,CHROM,POS,REF,ALT,SpliceAI_delta
16679893,19,11089616,G,A,SpliceAI=A|LDLR|0.00|0.00|0.32|0.99|0|-23|27|-1
16679894,19,11089616,G,C,SpliceAI=C|LDLR|0.00|0.00|0.34|1.00|0|-23|27|-1
16679895,19,11089616,G,T,SpliceAI=T|LDLR|0.00|0.00|0.33|1.00|27|-23|27|-1
16679896,19,11089617,T,A,SpliceAI=A|LDLR|0.00|0.00|0.39|1.00|-1|-24|26|-2
16679897,19,11089617,T,C,SpliceAI=C|LDLR|0.00|0.00|0.31|0.51|-2|-24|26|-2
...,...,...,...,...,...
16680936,19,11131285,A,T,SpliceAI=T|LDLR|0.01|0.55|0.00|0.00|-31|2|4|2
16680937,19,11131286,G,A,SpliceAI=A|LDLR|0.01|0.55|0.00|0.00|9|1|3|1
16680938,19,11131286,G,C,SpliceAI=C|LDLR|0.01|0.55|0.00|0.00|9|1|3|1
16680939,19,11131286,G,T,SpliceAI=T|LDLR|0.00|0.55|0.00|0.00|9|1|3|1


In [33]:
# Scan variants for which at least one of the four delta scores exceeds 0.2.
results[(results[['DS_AG','DS_AL','DS_DG','DS_DL']]>0.2).any(axis=1)]

Unnamed: 0,CHROM,POS,REF,ALT,DS_AG,DS_AL,DS_DG,DS_DL,DP_AG,DP_AL,DP_DG,DP_DL
0,chr19,11089492,C,G,0.003643,0.000381,0.325300,0.004704,18157,1099,-1,13
0,chr19,11089615,G,C,0.012459,0.003162,0.289203,0.001296,1796,235,28,0
0,chr19,11089615,G,T,0.014856,0.003259,0.244395,0.001629,1796,235,28,0
0,chr19,11089616,G,A,0.047505,0.003401,0.881801,0.889001,1795,234,27,-1
0,chr19,11089616,G,C,0.051614,0.003496,0.903167,0.976664,1795,234,27,-1
...,...,...,...,...,...,...,...,...,...,...,...,...
8,chr19,11131286,G,A,0.013055,0.463166,0.003940,0.004452,288,1,3,-533
8,chr19,11131286,G,C,0.014993,0.460369,0.001994,0.004622,288,1,3,-533
8,chr19,11131286,G,T,0.025595,0.463588,0.003406,0.008612,288,1,509,-533
8,chr19,11131418,C,G,0.416425,0.050967,0.463469,0.037729,156,-131,-1,-14409


In [34]:
# Number of scan variants with any delta score above 0.2, divided by the number of
# SpliceAI rows in the same interval.
# NOTE(review): this compares row counts only; the two sets are not matched on
# (POS, REF, ALT), so the ratio says nothing about which variants agree.
results[(results[['DS_AG','DS_AL','DS_DG','DS_DL']]>0.2).any(axis=1)].shape[0]/splice_ai_results.shape[0]

1.3234732824427482