In [1]:
import numpy as np
import sys
import time
import h5py
from tqdm import tqdm

import numpy as np
import re
from math import ceil
from sklearn.metrics import average_precision_score

import pandas as pd
import matplotlib.pyplot as plt
import pickle
#import pickle5 as pickle

from sklearn.model_selection import train_test_split

from scipy.sparse import load_npz
from glob import glob



In [2]:
from torch.utils.data import Dataset
import torch
import torch.nn as nn
from transformers import get_constant_schedule_with_warmup
from sklearn.metrics import precision_score,recall_score,accuracy_score

from src.train import trainModel
from src.dataloader import getData,spliceDataset,h5pyDataset,getDataPointList,getDataPointListFull,DataPoint
from src.weight_init import keras_init
from src.losses import categorical_crossentropy_2d
from src.model import SpliceFormer
from src.evaluation_metrics import print_topl_statistics
import copy
#import tensorflow as tf

In [3]:
!nvidia-smi

Fri Apr 21 17:06:24 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Tesla V1...  Off  | 00000000:5E:00.0 Off |                    0 |
| N/A   33C    P0    32W / 250W |      0MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA Tesla V1...  Off  | 00000000:86:00.0 Off |                    0 |
| N/A   31C    P0    32W / 250W |      0MiB / 32510MiB |      2%      Defaul

In [4]:
#snps = df.drop_duplicates(subset=['not_sQTL'])['not_sQTL'].values
#res = pd.DataFrame({'Chr':[x.split(':')[0] for x in snps],'Pos':[int(x.split(':')[1]) for x in snps],'marker':snps})
#res.sort_values(['Chr','Pos'],ascending=True).to_csv('../Data/not_sQTL.gor',sep='\t',index=False)

In [5]:
# NumPy random generator with a fixed seed, so anything drawn from `rng` is repeatable.
rng = np.random.default_rng(23673)

In [6]:
# Directory containing the input data, relative to the notebook's working directory.
data_dir = '../Data'

In [7]:
# Tab-separated variant table. Later cells read (at least) its 'category', 'chr',
# 'strand', 'snp_position_hg38_0based_start', 'ref_allele', 'alt_allele',
# 'ensembl_transcript_id', 'internal_id', 'is_sdv' and 'dpsi' columns.
snv_list = pd.read_csv('../Data/snv_list.txt', sep='\t')

In [8]:
# Hyper-parameters:
#   L  - number of convolution kernels
#   W  - convolution window size in each residual unit
#   AR - atrous rate in each residual unit
L = 32
N_GPUS = 8
k = 2

# 16 residual units: 8 with window 11, then 4 with window 21 and 4 with window 41.
W = np.asarray([11] * 8 + [21] * 4 + [41] * 4)
# Atrous rates grow in groups of four units: 1, 4, 10, 25.
AR = np.asarray([1] * 4 + [4] * 4 + [10] * 4 + [25] * 4)

BATCH_SIZE = N_GPUS * 6 * k

# Twice the summed (window - 1) * rate over all units
# (presumably the context length implied by W and AR - TODO confirm).
CL = 2 * np.sum((W - 1) * AR)

In [9]:
# SL: number of nucleotides of interest per data point (see reformat_data).
SL=5000
# CL_max: number of context nucleotides per data point, CL_max/2 on either side of
# the SL nucleotides of interest (see reformat_data).
CL_max=40000

In [10]:
import pyfastx
# Input locations. Note that data_dir is assigned here with a trailing slash.
data_dir = '../Data/'
# Genome FASTA; opened below with pyfastx and indexed by chromosome name in the scoring loop.
fasta_file_path = '../Data/genome.fa'
# Annotation database; opened with gffutils.FeatureDB in a later cell.
gtf_file_path = '../Data/Homo_sapiens.GRCh38.87.db'
fasta = pyfastx.Fasta(fasta_file_path)

In [12]:
# Build the SpliceFormer ensemble used for variant scoring.
NUM_ACCUMULATION_STEPS = 1
temp = 1
n_models = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Template network; every ensemble member starts out as a deep copy of it.
model_m = SpliceFormer(CL_max,bn_momentum=0.01/NUM_ACCUMULATION_STEPS,depth=4,heads=4,n_transformer_blocks=2,determenistic=True)
model_m.apply(keras_init)
model_m = model_m.to(device)

if torch.cuda.device_count() > 1:
    model_m = nn.DataParallel(model_m)

output_class_labels = ['Null', 'Acceptor', 'Donor']

models = [copy.deepcopy(model_m) for _ in range(n_models)]

# Load one checkpoint per ensemble member. map_location remaps the stored tensors
# onto `device`, so checkpoints written on a GPU machine can still be loaded when
# this notebook falls back to the CPU.
for i, model in enumerate(models):
    checkpoint_path = '../Results/PyTorch_Models/transformer_encoder_40k_031122_{}'.format(i)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # Inference only: switch to evaluation mode.
    model.eval()

In [13]:
#gene_boundries = {}
#for gene in tqdm(genes):
#    gene_boundries[gene["gene_name"][0]] = [int(gene[3]),int(gene[4])]

In [14]:
#with open('/odinn/tmp/benediktj/Data/SplicePrediction-050422/gene_boundries.pkl', 'wb') as f:
#    pickle.dump(gene_boundries, f)

In [15]:
import gffutils

In [16]:
# Open the gffutils feature database; the scoring loop looks transcripts up in it
# by their Ensembl transcript id to obtain start/end coordinates.
gtf = gffutils.FeatureDB(gtf_file_path)

In [17]:
# Keep only the rows whose 'category' is 'mutant'.
is_mutant = snv_list['category'] == 'mutant'
mutant_snv = snv_list.loc[is_mutant]

In [None]:
def predictSplicing(seq,models):
    """Average the predictions of every model in ``models`` over each chunk of ``seq``.

    Parameters
    ----------
    seq : array indexed as ``seq[chunk, :, :]``
        Each chunk is turned into a float tensor on ``device``, given a leading
        batch axis, and has its last two axes swapped before being fed to the models.
    models : sequence of callables
        Ensemble members. The first element of each model's return value is
        taken as its prediction. All members in the sequence are used.

    Returns
    -------
    numpy.ndarray
        The ensemble-mean prediction of every chunk, concatenated along axis 2.
    """
    chunk_outputs = []
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for chunk_idx in range(seq.shape[0]):
            batch_features = torch.tensor(seq[chunk_idx,:,:], device=device).float().unsqueeze(0)
            batch_features = torch.swapaxes(batch_features,1,2)
            # Iterate over the models that were passed in rather than relying on
            # a global ensemble size.
            member_predictions = torch.stack([model(batch_features)[0].detach() for model in models])
            chunk_outputs.append(torch.mean(member_predictions,dim=0))

    outputs = torch.cat(chunk_outputs,dim=2)
    return outputs.cpu().numpy()

def plotPrediction(outputs):
    """Plot acceptor and donor score tracks in two stacked panels sharing the x axis.

    Parameters
    ----------
    outputs : array indexed as ``outputs[0, class, position]``
        Class index 1 is drawn as the acceptor track and class index 2 as the
        donor track; the x axis runs over ``outputs.shape[2]`` positions.

    Notes
    -----
    Updates the global matplotlib font size and shows the figure.
    """
    plt.rcParams.update({'font.size': 18})
    fig, (ax1, ax2) = plt.subplots(2, 1,figsize=(22, 6),sharex=True)
    x = np.arange(outputs.shape[2])
    # The curves are labelled so the legends below have entries to show;
    # without labels matplotlib warns and draws an empty legend.
    ax1.plot(x,outputs[0,1,:],linewidth=2,zorder=-32,label='Acceptor')
    ax2.plot(x,outputs[0,2,:],linewidth=2,zorder=-32,label='Donor')
    ax2.set_xlabel('Distance from transcript start (nt)')
    ax1.set_ylabel('Acceptor score')
    ax2.set_ylabel('Donor Score')
    ax1.legend(prop={'size': 14})
    ax2.legend(prop={'size': 14})
    plt.tight_layout()
    plt.show()

def ceil_div(x, y):
    """Return ``ceil(x / y)`` as an ``int``.

    Integer arguments are handled with exact integer arithmetic, so the result
    is correct even for values too large to be represented exactly as a float.
    Any other numeric input falls back to float division followed by ``ceil``.

    Raises
    ------
    ZeroDivisionError
        If ``y`` is zero.
    """
    if isinstance(x, int) and isinstance(y, int):
        # -(-x // y) is the ceiling of x / y without a float round-trip.
        return -(-x // y)
    return int(ceil(float(x)/y))


# One-hot lookup table: row 0 is all zeros, rows 1-4 are the four unit vectors.
IN_MAP = np.vstack([np.zeros((1, 4), dtype=int), np.eye(4, dtype=int)])


def one_hot_encode(Xd):
    """Replace every integer code in ``Xd`` by the matching row of ``IN_MAP``.

    ``Xd`` is cast to int8 before the lookup; the result has the shape of ``Xd``
    with an extra trailing axis of length 4.
    """
    codes = Xd.astype('int8')
    return IN_MAP[codes]

def reformat_data(X0):
    """Cut an integer-encoded sequence into windows of length ``SL + CL_max``.

    Consecutive windows start ``SL`` positions apart. Each window holds ``SL``
    nucleotides of interest plus ``CL_max`` context nucleotides (``CL_max/2``
    on either side). The input is zero-padded by ``SL`` at its end so the last
    window can be filled.

    Returns a float array of shape ``(num_points, SL + CL_max)``.
    """
    window = SL + CL_max
    num_points = ceil_div(len(X0) - CL_max, SL)

    Xd = np.zeros((num_points, window))
    padded = np.pad(X0, [0, SL], 'constant', constant_values=0)

    for row in range(num_points):
        offset = SL * row
        Xd[row] = padded[offset:offset + window]

    return Xd

def seqToArray(seq,strand):
    """One-hot encode a nucleotide string for the model.

    The sequence is upper-cased and mapped to integer codes (A=1, C=2, G=3,
    T=4, anything else=0). For the '-' strand the codes are reverse-complemented.
    The codes are then windowed with ``reformat_data`` and one-hot encoded with
    ``one_hot_encode``.

    Parameters
    ----------
    seq : str
        Nucleotide sequence.
    strand : str
        '+' or '-'.

    Raises
    ------
    ValueError
        If ``strand`` is neither '+' nor '-'.
    """
    if strand not in ('+', '-'):
        raise ValueError("strand must be '+' or '-', got {!r}".format(strand))

    seq = seq.upper()
    # Every character other than A/C/G/T (including N) becomes the padding code 0.
    seq = re.sub(r'[^AGTC]', '0',seq)
    seq = seq.replace('A', '1').replace('C', '2')
    seq = seq.replace('G', '3').replace('T', '4')
    X0 = np.asarray([int(x) for x in seq])

    if strand == '-':
        # Reverse complement: 1<->4 (A<->T), 2<->3 (C<->G), 0 stays 0.
        X0 = (5-X0[::-1]) % 5

    Xd = reformat_data(X0)
    return  one_hot_encode(Xd)

# Score every 'mutant' variant with the model ensemble. For each variant a window
# around the variant is extracted (clipped to the transcript and to the chromosome),
# N-padded to the model input length, and scored for the reference and the alternate
# allele. The largest gain and loss in acceptor / donor score, and where they occur
# in the window, are stored under the variant's internal_id.
results = {}

flank = SL//2            # nucleotides scored on each side of the variant
window_len = SL+CL_max   # length of the N-padded sequence handed to seqToArray

for i in tqdm(range(mutant_snv.shape[0])):
    row = mutant_snv.iloc[i]
    # Read the id before entering the try block so the failure message below
    # always names the variant that actually failed.
    event_id = row['internal_id']
    try:
        chrm,strand,pos,ref_s,alt_s,transcript_id = row[['chr','strand','snp_position_hg38_0based_start','ref_allele', 'alt_allele','ensembl_transcript_id']]
        pos = int(pos)+1  # 0-based start -> 1-based coordinate
        chrm = f'chr{chrm}'
        transcript = gtf[transcript_id]
        gene_start = transcript.start
        gene_end = transcript.end

        # Requested window, clipped to the transcript and then to the chromosome.
        start = max(pos-CL_max//2-flank+1,gene_start)
        end = min(pos+flank+CL_max//2,gene_end)
        chrm_length = len(fasta[chrm])
        start = max(start,1)
        end = min(end,chrm_length)

        pos_s = pos-start  # offset of the variant inside the extracted sequence

        ref = fasta[chrm][start-1:end].seq
        ref_len = len(ref_s)
        # NOTE(review): this comparison is case-sensitive - confirm the FASTA is
        # not soft-masked (lower-case), otherwise such variants are rejected here.
        if ref_s != ref[pos_s:(pos_s+ref_len)]:
            raise ValueError('reference allele does not match the genome sequence')
        alt = ref[:pos_s]+alt_s+ref[(pos_s+ref_len):]

        # Pad with N so the variant always lands at the same index of the window,
        # even when the window was clipped above.
        left_pad = 'N'*(CL_max//2+flank-1-pos_s)
        right_pad = 'N'*(CL_max//2+flank-(end-pos))
        ref = left_pad+ref+right_pad
        alt = left_pad+alt+right_pad
        if len(ref) != window_len:
            raise ValueError('padded reference has length {}, expected {}'.format(len(ref),window_len))

        ref_prediction = predictSplicing(seqToArray(ref,strand),models)[0,:,:]
        alt_prediction = predictSplicing(seqToArray(alt,strand),models)[0,:,:]

        # Minus-strand sequences were reverse-complemented for the model; flip the
        # predictions back so positions follow the forward-strand window again.
        if strand=='-':
            ref_prediction = ref_prediction[:,::-1]
            alt_prediction = alt_prediction[:,::-1]

        # Deltas are taken position by position, which requires the ref and alt
        # predictions to have the same length. Length-changing alleles are not
        # re-aligned here.
        acceptorDelta = alt_prediction[1,:]-ref_prediction[1,:]
        donorDelta = alt_prediction[2,:]-ref_prediction[2,:]

        top_a_creation_pos = np.argmax(acceptorDelta)
        top_d_creation_pos = np.argmax(donorDelta)
        top_a_disruption_pos = np.argmin(acceptorDelta)
        top_d_disruption_pos = np.argmin(donorDelta)
        top_a_creation_delta = acceptorDelta[top_a_creation_pos]
        top_d_creation_delta = donorDelta[top_d_creation_pos]
        top_a_disruption_delta = acceptorDelta[top_a_disruption_pos]
        top_d_disruption_delta = donorDelta[top_d_disruption_pos]

        # Disruption deltas are stored with their sign flipped (larger = stronger loss).
        results[event_id] = [gene_start,gene_end,top_a_creation_pos,top_d_creation_pos,top_a_disruption_pos,top_d_disruption_pos,top_a_creation_delta,top_d_creation_delta,-top_a_disruption_delta,-top_d_disruption_delta]
    except Exception as err:
        # Best effort: report and skip variants that cannot be scored. Catching
        # Exception (not a bare except) keeps Ctrl-C able to stop the long loop.
        print('{} failed: {!r}'.format(event_id,err))

  1%|▍                                                                             | 177/28972 [01:33<3:43:54,  2.14it/s]

In [None]:
# Persist the per-variant scores so the slow scoring loop need not be repeated.
results_path = '/odinn/tmp/benediktj/Data/SplicePrediction-050422/mfas_transformer_gtex_210423.pkl'
with open(results_path, 'wb') as f:
    pickle.dump(results, f)

In [10]:
#with open('/odinn/tmp/benediktj/Data/SplicePrediction-050422/no_sqtl_deltas_transformer_gtex_130223.pkl', 'rb') as f:
#    results1 = pickle.load(f)

In [17]:
#with open('/odinn/tmp/benediktj/Data/SplicePrediction-050422/mfas_transformer_gtex_130323.pkl', 'rb') as f:
#    results = pickle.load(f)

In [None]:
# One row per scored variant (indexed by internal_id), one column per stored value.
# Building the frame row-wise keeps per-column dtypes (integer positions stay
# integers instead of being coerced to float by a transpose) and also works when
# no variant was scored.
score_columns = ['gene_start','gene_end','top_a_creation_pos','top_d_creation_pos','top_a_disruption_pos','top_d_disruption_pos','top_a_creation_delta','top_d_creation_delta','top_a_disruption_delta','top_d_disruption_delta']
df = pd.DataFrame.from_dict(results, orient='index', columns=score_columns)
df.index.name = 'internal_id'

In [None]:
# Attach the model scores to the variant table; only variants present in both are kept.
results = pd.merge(snv_list, df, on='internal_id')

In [None]:
results

In [None]:
# Display only the rows flagged as is_sdv.
results.loc[results['is_sdv'] == True]

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, auc

In [None]:
# ROC curve for separating is_sdv variants by their disruption score.
tmp = results[['is_sdv','top_d_disruption_delta','top_a_disruption_delta']].dropna()
# Score per variant: the larger of the donor and acceptor disruption deltas.
X = tmp[['top_d_disruption_delta','top_a_disruption_delta']].max(axis=1)
y = tmp['is_sdv'].astype(int)

fpr1, tpr1, t1 = roc_curve(y, X)
auc_1 = auc(fpr1, tpr1)

#aucs_1.append(auc_1)
plt.plot(fpr1, tpr1, label=f"Transformer-40k (AUC = {auc_1 :.3f})")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend()
#plt.savefig('../Results/mafs_transformer_auc.png',dpi=300,bbox_inches='tight')
plt.show()

In [None]:
# Disruption score (larger of donor / acceptor disruption delta) against dpsi.
plt.scatter(results['dpsi'],results[['top_d_disruption_delta','top_a_disruption_delta']].max(axis=1),s=1)
plt.ylabel('Transformer disruption delta score')
# Raw string: '\D' is not a valid escape sequence in a normal string literal.
plt.xlabel(r'$\Delta$psi')
#plt.savefig('../Results/mafs_dpsi_correlation.png',dpi=300,bbox_inches='tight')
plt.show()

In [None]:
# Flag rows with a disruption score above 0.5 whose dpsi is nevertheless above -0.2.
high_score = results[['top_d_disruption_delta','top_a_disruption_delta']].max(axis=1) > 0.5
results['fp'] = ((results['dpsi'] > -0.2) & high_score).to_numpy()

In [None]:
# For every row pick the position belonging to the larger disruption delta:
# column 0 = donor, column 1 = acceptor.
higher = np.argmax(results[['top_d_disruption_delta','top_a_disruption_delta']].values,axis=1)
# Boolean mask with exactly one True per row, in the column chosen by `higher`.
cond = higher[:, None] == np.arange(2)
results['max_pos'] = results[['top_d_disruption_pos','top_a_disruption_pos']].values[cond]

In [None]:
# Distance (in window positions) between the most disrupted site and the variant.
# The scoring loop left-pads every sequence with CL_max//2 + 2500 - 1 - (pos - start)
# N's, which places the variant at 0-based index 2499 of the 5000-wide prediction
# window (assuming the model returns the central SL positions), so the distance is
# measured from 2499 rather than 2500.
results['splice_dist'] = np.abs(results['max_pos']-2499)

In [None]:
from matplotlib import colors

In [None]:
# dpsi against the predicted distance of the disrupted site from the variant.
# Unflagged rows are coloured by their disruption score on a log scale; rows
# flagged in the 'fp' column are overlaid in red.
cond = results['fp']==False
plt.scatter(results['splice_dist'][cond],results['dpsi'][cond],s=1,c=results[['top_d_disruption_delta','top_a_disruption_delta']].max(axis=1)[cond], norm=colors.LogNorm())
cbar = plt.colorbar()
cbar.set_label('Transformer delta score', rotation=270, labelpad=15)
cond = results['fp']==True
plt.scatter(results['splice_dist'][cond],results['dpsi'][cond],c='red',s=10,label='False positives')


plt.legend()
plt.xlabel('Predicted Distance of Disrupted Site from variant')
# Raw string: '\D' is not a valid escape sequence in a normal string literal.
plt.ylabel(r'$\Delta$psi')
#plt.savefig('../Results/mafs_predicted_distance_vs_dpsi.png',dpi=300,bbox_inches='tight')
plt.show()

In [None]:
from sklearn.metrics import average_precision_score

In [None]:
from sklearn.metrics import precision_recall_curve
# Precision-recall curve for the same labels (y) and scores (X) used for the ROC curve.
auc_1 = average_precision_score(y, X)
precision, recall, t1 = precision_recall_curve(y, X)

#aucs_1.append(auc_1)
plt.plot(recall, precision, label=f"Transformer-40k (PR-AUC = {auc_1 :.3f})")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.legend()
#plt.savefig('../Results/mafs_prc.png',dpi=300,bbox_inches='tight')
plt.show()