In [1]:
import numpy as np
import sys
import time
import h5py
from tqdm import tqdm

import numpy as np
import re
from math import ceil
from sklearn.metrics import average_precision_score

import pandas as pd
import matplotlib.pyplot as plt
import pickle

from sklearn.model_selection import train_test_split

from scipy.sparse import load_npz
from glob import glob

from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.linear_model import LogisticRegression

from torch.utils.data import Dataset
import torch
import torch.nn as nn
from transformers import get_constant_schedule_with_warmup
from sklearn.metrics import precision_score,recall_score,accuracy_score

from src.train import trainModel
from src.dataloader import get_GTEX_v8_Data,spliceDataset,h5pyDataset,getDataPointList,getDataPointListGTEX,DataPointGTEX
from src.weight_init import keras_init
from src.losses import categorical_crossentropy_2d
from src.model import SpliceFormer
from src.evaluation_metrics import print_topl_statistics
import copy
import pyfastx
import gffutils

# Shared notebook configuration. All cells execute in one module namespace, so these
# names are visible to later cells and functions without any `global` declaration.

# Nucleotide complement lookup (used when reporting REF/ALT on the minus strand).
rev_comp_dict = dict(zip('ATCG', 'TAGC'))

# Integer base code -> one-hot row. Code 0 (N / anything else) is all zeros; codes 1-4
# (A, C, G, T in the digit encoding used by seqToArray) each set a single column.
IN_MAP = np.vstack([np.zeros((1, 4), dtype=int), np.eye(4, dtype=int)])

rng = np.random.default_rng(23673)
#ii = int(sys.argv[1])

# Model hyper-parameters:
#   L  - number of convolution kernels
#   W  - convolution window size in each residual unit
#   AR - atrous (dilation) rate in each residual unit
L = 32
N_GPUS = 4
k = 3
NUM_ACCUMULATION_STEPS = 1
W = np.repeat([11, 21, 41], [8, 4, 4])
AR = np.repeat([1, 4, 10, 25], 4)

# The batch size is fixed from the base multiplier k *before* k is scaled by the
# number of gradient-accumulation steps.
BATCH_SIZE = 16 * k * N_GPUS
k *= NUM_ACCUMULATION_STEPS

# Context length derived from the dilated stack: 2 * sum(AR * (W - 1)).
CL = 2 * np.sum(AR * (W - 1))

SL = 5000       # bases scored per window
CL_max = 40000  # flanking context per window, CL_max // 2 on each side

from collections import OrderedDict

temp = 1
n_models = 10   # number of ensemble members whose predictions are averaged
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
def generate_point_mutations(input_tensor, i, j,base_map,inside_gene,pos,base):
    """
    Apply a single-nucleotide substitution to a one-hot encoded sequence.

    Args:
    input_tensor (torch.Tensor): one-hot sequence of shape (4, n); channel c stands for
        the letter base_map[c]. It is not modified; a mutated clone is returned.
    i (int): Start of the scored window. Not used by this function; kept for
        call compatibility.
    j (int): End of the scored window. Not used by this function; kept for
        call compatibility.
    base_map (str): Letter for each channel, e.g. 'ACGT'.
    inside_gene (array-like of bool): Per-column mask; the substitution is only applied
        where inside_gene[pos] == 1.
    pos (int): Column of input_tensor to mutate (an index into the full length-n sequence).
    base (int): Which alternative base to substitute, 0-2. It indexes, in increasing
        channel order, the three channels that differ from the reference base at pos.

    Returns:
    tuple: (mutated_sequence, alt_base, ref_base, idx). mutated_sequence is a (4, n)
        tensor. alt_base and ref_base are one-letter lists, and idx is np.array([0]),
        when the substitution was applied. Otherwise the lists are empty, idx is an
        empty array, and mutated_sequence equals the input.
    """
    mutated_sequence = input_tensor.clone()
    ref_base = []
    alt_base = []
    idx = []

    original_base = int(torch.argmax(input_tensor[:, pos], dim=0))
    if inside_gene[pos] == 1:
        # `base` selects among the channels other than the reference one. Using it as a
        # raw channel index would make some "mutations" identical to the reference and
        # could never produce the last channel.
        alternatives = [b for b in range(input_tensor.shape[0]) if b != original_base]
        alt = alternatives[base]
        mutated_sequence[:, pos] = 0
        mutated_sequence[alt, pos] = 1
        ref_base.append(base_map[original_base])
        alt_base.append(base_map[alt])
        idx.append(0)  # offset of this (single) mutation within the returned batch
    return mutated_sequence, alt_base, ref_base, np.array(idx)


def getDeltas(delta,idx,inside_gene):
    """Summarise the strongest acceptor/donor gain and loss for each row of ``delta``.

    Channel 1 of ``delta`` is treated as the acceptor track and channel 2 as the donor
    track. Both are multiplied by ``inside_gene`` (presumably a 0/1 mask over the last
    axis — TODO confirm). For each row, the gain is the largest value and the loss the
    most negative value along the last axis.

    Returns an 8-tuple of numpy arrays:
        (gain_acceptor, loss_acceptor, gain_donor, loss_donor,
         pos_gain_acceptor, pos_loss_acceptor, pos_gain_donor, pos_loss_donor)
    Losses are reported as positive magnitudes. Positions are argmax indices minus
    ``idx``.
    """
    def _peak(track, flip):
        # flip=False: value/position of the maximum; flip=True: negated value and
        # position of the minimum.
        scored = -track if flip else track
        where = torch.argmax(scored, dim=1)
        picked = track.gather(1, where.unsqueeze(1)).cpu().numpy()[:, 0]
        return (-picked if flip else picked), where.cpu().numpy()

    acceptor = delta[:, 1, :] * inside_gene
    donor = delta[:, 2, :] * inside_gene

    acc_gain, acc_gain_at = _peak(acceptor, False)
    acc_loss, acc_loss_at = _peak(acceptor, True)
    don_gain, don_gain_at = _peak(donor, False)
    don_loss, don_loss_at = _peak(donor, True)

    return (acc_gain, acc_loss, don_gain, don_loss,
            acc_gain_at - idx, acc_loss_at - idx, don_gain_at - idx, don_loss_at - idx)

def ceil_div(x, y):
    """Return ceil(x / y) as a Python int, using true (float) division."""
    quotient = float(x) / y
    rounded_up = ceil(quotient)
    return int(rounded_up)

def one_hot_encode(Xd):
    """One-hot encode an array of integer base codes by row lookup in IN_MAP.

    Codes are cast to int8 first; code 0 maps to an all-zero row and codes 1-4 to the
    four unit rows of IN_MAP.
    """
    codes = Xd.astype('int8')
    return IN_MAP[codes]

def seqToArray(seq,strand,cl_max=None):
    """Encode a DNA string as an array of integer base codes.

    The sequence is padded with ``cl_max // 2`` N's on each side and upper-cased. Codes
    are A=1, C=2, G=3, T=4, and 0 for N or any other character. For the minus strand
    the reverse complement is encoded: the array is reversed and codes are mapped
    k -> (5 - k) % 5, so A<->T, C<->G and 0 stays 0.

    Args:
        seq (str): nucleotide sequence.
        strand (str): '+' or '-'.
        cl_max (int, optional): total padding; defaults to the module-level CL_max.

    Returns:
        np.ndarray: integer codes of length len(seq) + 2 * (cl_max // 2).

    Raises:
        ValueError: if strand is not '+' or '-'.
    """
    if strand not in ('+', '-'):
        raise ValueError("strand must be '+' or '-', got {!r}".format(strand))
    if cl_max is None:
        cl_max = CL_max
    pad = 'N' * (cl_max // 2)
    padded = (pad + seq + pad).upper()
    # Everything that is not A/C/G/T (including the N padding) becomes code 0.
    digits = re.sub(r'[^AGTC]', '0', padded)
    digits = digits.replace('A', '1').replace('C', '2').replace('G', '3').replace('T', '4')
    codes = np.asarray([int(d) for d in digits])
    if strand == '+':
        return codes
    return (5 - codes[::-1]) % 5  # reverse complement

class mutation_dataset(Dataset):
    """Map-style dataset that enumerates every single-nucleotide substitution in a gene.

    The gene sequence ``ref`` is padded with ``CL_max // 2`` N's on both sides and cut
    into ``num_points`` windows of ``SL + CL_max`` bases, one starting every ``shift``
    bases. A flat item index ``k`` decomposes as

        window = k // (3 * SL)
        pos    = (k % (3 * SL)) // 3   # position inside the window's central SL bases
        base   = k % 3                 # which of the three alternative bases

    Each item is ``(reference_window, mutated_window)``. The reference is a (4, SL + CL_max)
    float64 numpy array; the mutant is a float32 torch tensor of the same shape.
    Minus-strand windows are reverse-complemented.
    """

    # Letter for each one-hot channel produced by getData (digit codes 1-4 -> rows 0-3).
    BASE_MAP = 'ACGT'

    def __init__(self, chrom, strand,start,end,ref,SL,CL_max,shift,gene=None):
        """
        Args:
            chrom: chromosome name (stored only).
            strand: '+' or '-'.
            start, end: gene coordinates; the window count covers end - start + 1 bases.
            ref: gene sequence string spanning start..end.
            SL: number of scored bases per window.
            CL_max: total flanking context per window (CL_max // 2 on each side).
            shift: stride between consecutive window starts.
            gene: optional gene record, kept as ``self.gene``.

        Raises:
            ValueError: if strand is not '+' or '-'.
        """
        if strand not in ('+', '-'):
            raise ValueError("strand must be '+' or '-', got {!r}".format(strand))
        self.gene = gene
        self.chrom = chrom
        self.strand = strand
        self.start = start
        self.end = end
        self.SL = SL
        self.CL_max = CL_max
        self.shift = shift
        length = end-start+1
        self.num_points = ceil_div(length, shift)
        self.ref = 'N'*(CL_max//2) + ref + 'N'*(CL_max//2)
        self.length = (end-start)/SL
        # Last reference window produced by getData; consecutive items almost always
        # share it, so re-encoding 45k characters per item is avoided.
        self._cached_idx = None
        self._cached_window = None

    def __len__(self):
        # Three alternative bases for each of the SL scored positions of every window.
        return self.num_points*self.SL*3

    def _window_start(self, idx):
        """Offset in the padded sequence at which window ``idx`` begins."""
        return int(self.shift * idx)

    def reformat_data(self, X0):
        """Cut a padded, integer-encoded sequence into the dataset's windows.

        Row i holds ``X0[start_i : start_i + SL + CL_max]``, with start_i = shift * i.
        That is SL nucleotides of interest with CL_max // 2 context bases on either side.
        Windows running past the end of X0 are zero-filled (code 0 = N).

        Returns:
            np.ndarray of shape (num_points, SL + CL_max).
        """
        width = self.SL + self.CL_max
        Xd = np.zeros((self.num_points, width))
        for i in range(self.num_points):
            begin = self._window_start(i)
            chunk = X0[begin:begin + width]
            Xd[i, :len(chunk)] = chunk
        return Xd

    def getData(self,seq,idx):
        """One-hot encode window ``idx`` of the padded sequence ``seq``.

        Returns:
            np.ndarray of shape (4, SL + CL_max); rows are A, C, G, T and all-zero
            columns mark N / non-ACGT / padding. Minus-strand windows are
            reverse-complemented.
        """
        width = self.SL + self.CL_max
        begin = self._window_start(idx)
        chunk = seq[begin:begin + width]
        # Pad a short (final) window on the right, in genomic orientation, for both
        # strands. Column c then always corresponds to padded position begin + c
        # before any reverse complement, exactly as for full windows.
        chunk = chunk + 'N' * (width - len(chunk))
        chunk = re.sub(r'[^AGTC]', '0', chunk.upper())
        chunk = chunk.replace('A', '1').replace('C', '2').replace('G', '3').replace('T', '4')
        codes = np.asarray([int(c) for c in chunk])
        if self.strand == '-':
            codes = (5 - codes[::-1]) % 5  # reverse complement
        X = np.zeros((5, width))
        X[codes, np.arange(width)] = 1
        return X[1:, :].copy()

    def _window(self, idx):
        """Return a copy of reference window ``idx``, re-encoding only when idx changes."""
        if getattr(self, '_cached_idx', None) != idx:
            self._cached_window = self.getData(self.ref, idx)
            self._cached_idx = idx
        return self._cached_window.copy()

    def __getitem__(self, idx):
        window_idx, offset = divmod(idx, self.SL*3)
        pos, base = divmod(offset, 3)
        X = self._window(window_idx)
        # Columns that hold an actual A/C/G/T; padding and non-ACGT columns are all zero.
        inside_gene = (X.sum(axis=0) >= 1)
        # `pos` counts from the first scored base, so shift it past the left context
        # to land inside the central SL bases of the window.
        center = self.CL_max//2
        mutation, _, _, _ = generate_point_mutations(
            torch.Tensor(X), center, center+self.SL, self.BASE_MAP, inside_gene, center+pos, base)
        return X, mutation

def predictMutations(ref,strand,chrom,shift,start,end):
    """Score batches of point mutations with the model ensemble and tabulate delta scores.

    For each batch, the ensemble-mean prediction on the mutant sequences has
    ``ref_pred`` subtracted (delta = alt - ref). The result is summarised by getDeltas
    into acceptor/donor gain/loss scores (DS_*) and positions (DP_*). Minus-strand rows
    report POS counted back from ``end``, complemented REF/ALT, and negated DP_*
    offsets.

    NOTE(review): this function reads ``mutations``, ``models``, ``n_models``, ``device``,
    ``ref_pred``, ``inside_gene`` and ``CL_max`` from the notebook's global namespace
    rather than from its arguments. ``inside_gene`` is only ever a local inside
    mutation_dataset.__getitem__. mutation_dataset items are ``(reference, mutation)``
    pairs, not the 4-tuples unpacked here. So the function cannot run against the
    current dataset and is not called in this notebook — TODO rework to consume
    ``loader``. ``ref`` is unused.

    Returns:
        pd.DataFrame with columns CHROM, POS, REF, ALT, DS_AG, DS_AL, DS_DG, DS_DL,
        DP_AG, DP_AL, DP_DG, DP_DL. An empty frame with those columns is returned when
        ``mutations`` yields nothing.
    """
    columns = ['CHROM','POS','REF','ALT','DS_AG','DS_AL','DS_DG','DS_DL','DP_AG','DP_AL','DP_DG','DP_DL']
    # Collect per-batch frames and concatenate once (concat inside the loop is quadratic).
    frames = []
    with torch.no_grad():
        for mutation,alt_base,ref_base,idx in mutations:
            mutation = mutation.type(torch.FloatTensor).to(device)

            alt_pred = torch.stack([models[m](mutation)[0].detach() for m in range(n_models)])
            alt_pred = torch.mean(alt_pred,dim=0)

            delta = alt_pred-ref_pred
            a1,b1,c1,d1,a2,b2,c2,d2 = getDeltas(delta,CL_max//2+idx,inside_gene)

            if strand == '+':
                df = pd.DataFrame({'CHROM':chrom,'POS':start+shift+idx,'REF':ref_base,'ALT':alt_base,'DS_AG':a1,'DS_AL':b1,'DS_DG':c1,'DS_DL':d1,'DP_AG':a2,'DP_AL':b2,'DP_DG':c2,'DP_DL':d2})
            else:
                df = pd.DataFrame({'CHROM':chrom,'POS':end-shift-idx,'REF':[rev_comp_dict[x] for x in ref_base],'ALT':[rev_comp_dict[x] for x in alt_base],'DS_AG':a1,'DS_AL':b1,'DS_DG':c1,'DS_DL':d1,'DP_AG':-a2,'DP_AL':-b2,'DP_DG':-c2,'DP_DL':-d2})
            frames.append(df)
    if not frames:
        return pd.DataFrame(columns=columns)
    return pd.concat(frames,axis=0)

In [3]:
import os

# Directory containing the GENCODE v44 annotation database and the GRCh38.p14 FASTA.
# Override with the GENCODE_V44_DIR environment variable instead of editing the path.
GENCODE_V44_DIR = os.environ.get("GENCODE_V44_DIR", "/odinn/tmp/benediktj/Data/Gencode_V44")

gtf = gffutils.FeatureDB(os.path.join(GENCODE_V44_DIR, "gencode.v44.annotation.db"))
fasta = pyfastx.Fasta(os.path.join(GENCODE_V44_DIR, "GRCh38.p14.genome.fa"))

In [4]:
# One row per protein-coding gene, skipping the mitochondrial chromosome:
# [chromosome, start, end, strand, gene name, gene id].
df = []
for gene in tqdm(gtf.features_of_type('gene')):
    if gene['gene_type'][0] != "protein_coding" or gene[0] == 'chrM':
        continue
    df.append([
        gene[0],                # chromosome
        gene[3],                # start
        gene[4],                # end
        gene[6],                # strand
        gene['gene_name'][0],
        gene['gene_id'][0],
    ])

62700it [00:19, 3283.04it/s]


In [5]:
df = pd.DataFrame(df)
df.columns = ['CHROM','START','END','STRAND','NAME','ID']

# For sharded runs: df = np.array_split(df, 300)[ii] with ii = int(sys.argv[1]).

base_map = 'ACGT'

# Row of `df` whose point mutations are enumerated below (one gene per run).
GENE_INDEX = 0
gene_name,chrom,strand,start,end,gene_id = df.iloc[GENE_INDEX,:][['NAME','CHROM','STRAND','START','END','ID']]
# start/end are presumably 1-based inclusive (GFF convention) and the FASTA slice 0-based
# half-open, hence start-1 — TODO confirm against pyfastx docs.
ref = fasta[chrom][start-1:end].seq
mutations = mutation_dataset(chrom, strand,start,end,ref,SL,CL_max,SL)

  0%|                                                                                            | 0/20033 [00:00<?, ?it/s]


In [11]:
# Batches of (reference window, mutated window) pairs, kept in dataset order so that
# item k still maps back to (window, position, alternative base).
loader = torch.utils.data.DataLoader(
    dataset=mutations,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=32,
    pin_memory=False,
)

In [12]:
# Length in bases of the selected gene's reference sequence `ref`.
len(ref)

6167

In [13]:
def predictSplicing(seq,models):
    """Run a model ensemble over a stack of windows and concatenate the predictions.

    Args:
        seq: array indexed as seq[i, :, :]. Each window gets a batch axis and is
            transposed with swapaxes(1, 2) before the forward pass, so it is presumably
            (n_windows, length, 4) one-hot — TODO confirm.
        models: iterable of models. Each is called on a (1, 4, length) batch and
            element [0] of its output is used, matching the ensemble calls elsewhere in
            this notebook.

    Returns:
        numpy array: the per-window ensemble-mean predictions concatenated along axis 2.

    Raises:
        ValueError: if seq has no windows or models is empty.
    """
    models = list(models)
    if seq.shape[0] == 0 or not models:
        raise ValueError("predictSplicing needs at least one window and one model")
    outputs = []
    with torch.no_grad():
        for i in range(seq.shape[0]):
            batch_features = torch.tensor(seq[i,:,:], device=device).float().unsqueeze(0)
            batch_features = torch.swapaxes(batch_features,1,2)
            # Average the ensemble members' predictions for this window.
            preds = torch.stack([model(batch_features)[0] for model in models])
            outputs.append(torch.mean(preds, dim=0))
    outputs = torch.cat(outputs,dim=2)
    return outputs.cpu().detach().numpy()

In [14]:
# Build one SpliceFormer template (NUM_ACCUMULATION_STEPS, n_models and device come
# from the configuration cell), then load n_models fine-tuned checkpoints into copies.
model_m = SpliceFormer(CL_max,bn_momentum=0.01/NUM_ACCUMULATION_STEPS,depth=4,heads=4,n_transformer_blocks=2,determenistic=True)
model_m.apply(keras_init)
model_m = model_m.to(device)

# NOTE(review): checkpoint keys must match whether model_m ends up wrapped in
# DataParallel ("module." prefix) on this machine — confirm how they were saved.
if torch.cuda.device_count() > 1:
    model_m = nn.DataParallel(model_m)

output_class_labels = ['Null', 'Acceptor', 'Donor']

MODEL_PATH_TEMPLATE = '../Results/PyTorch_Models/transformer_encoder_40k_finetune_rnasplice-blood_all_050623_{}'

models = [copy.deepcopy(model_m) for _ in range(n_models)]
for model_idx, model in enumerate(models):
    # map_location lets checkpoints saved on a GPU load on whatever `device` is available.
    # torch.load unpickles the file: only load checkpoints from trusted locations.
    model.load_state_dict(torch.load(MODEL_PATH_TEMPLATE.format(model_idx), map_location=device))
    model.eval()

In [15]:
# Ensemble-average the predictions for each batch of reference windows and their
# point mutants. Inference only, so autograd is disabled.
# Delta follows the alt - ref convention used by predictMutations/getDeltas: positive
# values are gains, negative values losses.
# TODO: persist or summarise `delta` per batch; it is currently overwritten each iteration.
with torch.no_grad():
    for ref_batch, alt_batch in tqdm(loader):
        # Distinct names so the gene sequence `ref` is not clobbered.
        ref_batch = ref_batch.to(device).float()
        alt_batch = alt_batch.to(device).float()
        ref_pred = torch.mean(torch.stack([model(ref_batch)[0] for model in models]), dim=0)
        alt_pred = torch.mean(torch.stack([model(alt_batch)[0] for model in models]), dim=0)
        delta = alt_pred-ref_pred

 24%|████████████████████▌                                                                | 38/157 [06:00<18:49,  9.49s/it]


KeyboardInterrupt: 