In [1]:
import io
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

import re
import sys
import h5py
from math import ceil

# Window geometry: every input window carries CL_max nucleotides of flanking
# context (CL_max // 2 on each side) around SL labelled nucleotides.
CL_max, SL = 100000, 5000

In [2]:
import pyfastx
import gffutils

In [3]:
#gffutils.create_db('/odinn/tmp/benediktj/Data/SplicePrediction/gencode.v39.annotation.gtf', "/odinn/tmp/benediktj/Data/SplicePrediction/gencode.v39.annotation.db", force=True,disable_infer_genes=True, disable_infer_transcripts=True)

In [4]:
#gtf = gffutils.FeatureDB("/odinn/tmp/benediktj/Data/SplicePrediction/gencode.v39.annotation.db")

In [5]:
# Reference genome read (through the global `fasta`) by createDataset.
# hg38 alternative: '/odinn/tmp/benediktj/SpliceAITrainingCode/hg38.fa'
FASTA_PATH = '/odinn/tmp/benediktj/SpliceAITrainingCode/hg19.fa'
fasta = pyfastx.Fasta(FASTA_PATH)

In [6]:
# Tab-separated transcript table for the test split.
GENCODE_TEST_PATH = '/odinn/tmp/benediktj/SpliceAITrainingCode/gencode_test.tsv'
gencode_test = pd.read_csv(GENCODE_TEST_PATH, sep='\t')

In [7]:
#canonical_data.columns = ['name','?','chrm','strand','tx_start','tx_end','acceptor','donor']

In [8]:
# Preview the transcript table. createDataset reads its columns by position
# (iloc 1-6), so the column order shown here matters.
gencode_test

Unnamed: 0,#name,chrom,strand,txStart,txEnd,exonEnds,exonStarts
0,AZIN2,chr1,+,33546714,33585995,"33546895,33547109,33547413,33547955,33549728,3...","33546989,33547202,33547779,33549555,33557651,3..."
1,PRUNE,chr1,+,150980953,151008189,"150981147,150990380,150991145,150997271,150998...","150990288,150990943,150997087,150997991,150999..."
2,C1orf21,chr1,+,184356192,184598154,"184356502,184446737,184476816,184559949,184567...","184446520,184476722,184559873,184567535,184588..."
3,LIN9,chr1,-,226418858,226497434,"226420307,226420894,226421224,226426797,226438...","226420797,226421045,226426672,226438540,226453..."
4,C1orf159,chr1,-,1017203,1051478,"1018367,1019763,1019886,1021392,1022584,102297...","1019733,1019861,1021258,1022519,1022882,102573..."
...,...,...,...,...,...,...,...
1647,NELFB,chr9,+,140149759,140167999,"140150063,140150522,140150880,140151506,140157...","140150359,140150781,140151276,140157489,140158..."
1648,NSMF,chr9,-,140342025,140353786,"140343943,140344126,140344477,140344707,140346...","140344051,140344375,140344628,140346817,140347..."
1649,MRPL41,chr9,+,140445651,140447007,140446430,140446526
1650,DPH7,chr9,-,140449361,140473387,"140450100,140459058,140459410,140459606,140468...","140458886,140459345,140459537,140468660,140469..."


In [9]:
# Byte -> base code lookup: A/C/G/T (either case) -> 1/2/3/4; every other byte
# (N, IUPAC ambiguity codes, padding, ...) -> 0.
_BASE_LUT = np.zeros(256, dtype=np.int64)
_BASE_LUT[np.frombuffer(b'ACGTacgt', dtype=np.uint8)] = [1, 2, 3, 4, 1, 2, 3, 4]


def create_datapoints(seq, strand, tx_start, tx_end, jn_start, jn_end):
    """Convert one transcript into model inputs X and labels Y.

    The sequence is padded with CL_max//2 'N's on each side and mapped to
    integers: A, C, G, T -> 1, 2, 3, 4 (case-insensitive), anything else -> 0.
    For the '-' strand the integer sequence is reverse complemented and labels
    are indexed from tx_end downwards, so both X and Y run in transcript
    orientation.

    Labels cover every position in [tx_start, tx_end] with 0, 1, 2, -1 meaning
    no splicing, acceptor, donor and missing information respectively. If
    jn_start is empty the whole transcript is labelled -1. Junctions outside
    [tx_start, tx_end] are ignored.

    Args:
        seq: sequence string covering the transcript (without flanks).
        strand: '+' or '-'.
        tx_start, tx_end: inclusive transcript bounds (anything int() accepts).
        jn_start: positions labelled donor (2) on '+', acceptor (1) on '-'.
        jn_end: positions labelled acceptor (1) on '+', donor (2) on '-'.

    Returns:
        (X, Y) as returned by one_hot_encode(*reformat_data(X0, Y0)).

    Raises:
        ValueError: if strand is neither '+' nor '-' (previously this surfaced
            as an UnboundLocalError further down).
    """
    if strand not in ('+', '-'):
        raise ValueError("strand must be '+' or '-', got {!r}".format(strand))

    tx_start = int(tx_start)
    tx_end = int(tx_end)
    n_labels = tx_end - tx_start + 1
    jn_start = [int(c) for c in jn_start]
    jn_end = [int(c) for c in jn_end]

    # Pad with CL_max//2 'N's on each side, then map every character to 0-4 in
    # one vectorized lookup instead of a per-character Python loop.
    # Non-ASCII characters become '?' and therefore 0.
    pad = 'N' * (CL_max // 2)
    raw = (pad + seq + pad).encode('ascii', errors='replace')
    codes = _BASE_LUT[np.frombuffer(raw, dtype=np.uint8)]

    if strand == '+':
        X0 = codes
        donors, acceptors = jn_start, jn_end

        def to_index(pos):
            return pos - tx_start
    else:
        # Reverse complement: reverse the order, then (5 - x) % 5 swaps
        # A<->T (1<->4) and C<->G (2<->3) while keeping 0 at 0.
        X0 = (5 - codes[::-1]) % 5
        donors, acceptors = jn_end, jn_start

        def to_index(pos):
            return tx_end - pos

    labels = -np.ones(n_labels)  # -1 = missing information
    if len(jn_start) > 0:
        labels = np.zeros(n_labels)
        # Donors are written first, so a position listed in both lists ends up
        # labelled as an acceptor.
        for value, positions in ((2, donors), (1, acceptors)):
            for pos in positions:
                if tx_start <= pos <= tx_end:  # ignore junctions outside the transcript
                    labels[to_index(pos)] = value

    Xd, Yd = reformat_data(X0, [labels])
    X, Y = one_hot_encode(Xd, Yd)

    return X, Y

def reformat_data(X0, Y0):
    """Cut one transcript into fixed-size windows of SL labelled positions.

    Window w covers label positions [w*SL, (w+1)*SL) and receives the input
    slice X0[w*SL : w*SL + SL + CL_max], i.e. SL nucleotides of interest plus
    CL_max context nucleotides. This is CL_max/2 on either side, assuming X0
    carries CL_max//2 flanking entries at each end as built by
    create_datapoints. Both X0 and Y0[0] are right-padded with SL entries
    (0 for inputs, -1 for labels) so the final, partial window can be filled.

    Args:
        X0: integer-coded sequence.
        Y0: one-element list holding the per-position label array.

    Returns:
        (Xd, Yd): Xd is a float array of shape (n_windows, SL + CL_max);
        Yd is a one-element list holding an (n_windows, SL) float array.
    """
    n_windows = ceil_div(len(Y0[0]), SL)
    window_len = SL + CL_max

    x_padded = np.pad(X0, [0, SL], 'constant', constant_values=0)
    y_padded = np.pad(Y0[0], [0, SL], 'constant', constant_values=-1)

    Xd = np.zeros((n_windows, window_len))
    labels = -np.ones((n_windows, SL))
    for w in range(n_windows):
        lo = w * SL
        Xd[w] = x_padded[lo:lo + window_len]
        labels[w] = y_padded[lo:lo + SL]

    return Xd, [labels]

def ceil_div(x, y):
    """Return ceil(x / y) as an int.

    Uses exact integer arithmetic (-(-x // y)) instead of a float round-trip,
    which silently loses precision once x exceeds 2**53.
    """
    return int(-(-x // y))


# One-hot encoding of the inputs: row 0 (padding / unknown base) is all zeros,
# rows 1, 2, 3, 4 correspond to A, C, G, T respectively.
IN_MAP = np.vstack([np.zeros((1, 4), dtype=int), np.eye(4, dtype=int)])

# One-hot encoding of the labels: rows 0, 1, 2 = no splice site, acceptor,
# donor. Row 3 is all zeros; label -1 (missing information) selects it via
# negative indexing.
OUT_MAP = np.vstack([np.eye(3, dtype=int), np.zeros((1, 3), dtype=int)])

def one_hot_encode(Xd, Yd):
    """One-hot encode integer-coded input windows and their labels.

    Values of Xd index rows of IN_MAP (0 -> all zeros, 1-4 -> A/C/G/T). Values
    of the first label array Yd[0] index rows of OUT_MAP, where -1 picks the
    last (all-zero) row. Both are cast to int8 before indexing.

    Returns:
        (X, [Y]) where X gains a trailing axis of size 4 and Y one of size 3.
    """
    x_codes = Xd.astype('int8')
    y_codes = Yd[0].astype('int8')
    return IN_MAP[x_codes], [OUT_MAP[y_codes]]

In [10]:
def _chrom_group(setType):
    """Return the chromosome list for a split: 'train', 'test', or anything else (all)."""
    if setType == 'train':
        return ['chr11', 'chr13', 'chr15', 'chr17', 'chr19', 'chr21',
                'chr2', 'chr4', 'chr6', 'chr8', 'chr10', 'chr12',
                'chr14', 'chr16', 'chr18', 'chr20', 'chr22', 'chrX', 'chrY']
    if setType == 'test':
        return ['chr1', 'chr3', 'chr5', 'chr7', 'chr9']
    return ['chr1', 'chr3', 'chr5', 'chr7', 'chr9',
            'chr11', 'chr13', 'chr15', 'chr17', 'chr19', 'chr21',
            'chr2', 'chr4', 'chr6', 'chr8', 'chr10', 'chr12',
            'chr14', 'chr16', 'chr18', 'chr20', 'chr22', 'chrX', 'chrY']


def _parse_positions(cell):
    """Parse a comma-separated position list into a list of ints.

    Empty fields are ignored, so "10,20" and the trailing-comma form "10,20,"
    both give [10, 20]. A missing cell (NaN/None) gives [].
    """
    if not isinstance(cell, str) and pd.isna(cell):
        return []
    return [int(x) for x in str(cell).split(',') if x.strip()]


def _write_chunk(h5f, chunk_id, X_batch, Y_batch):
    """Store one chunk as int8 datasets 'X<chunk_id>' and 'Y<chunk_id>' in h5f.

    Y is stored wrapped in a one-element list, i.e. with a leading axis of size 1.
    """
    h5f.create_dataset('X' + str(chunk_id), data=np.asarray(X_batch).astype('int8'))
    h5f.create_dataset('Y' + str(chunk_id), data=[np.asarray(Y_batch[0]).astype('int8')])


def createDataset(gencode, setType, data_dir, filter_chroms=False):
    """Encode every transcript in `gencode` and write the result to one HDF5 file.

    The file is <data_dir>/gencode_<CL_max//1000>k_dataset_<setType>_.h5 and is
    opened with mode 'w', so an existing file is overwritten. Successfully read
    genes are grouped in chunks of CHUNK_SIZE. Chunk k is stored as int8
    datasets 'Xk' and 'Yk', holding the concatenated windows from
    create_datapoints.

    Args:
        gencode: transcript table read by position. Column 1 is chrom, 2 strand,
            3 tx start, 4 tx end, 5 the positions passed as jn_start, and 6 the
            positions passed as jn_end.
        setType: split name. It is used in the file name and selects the
            chromosome group ('train', 'test', anything else = all chromosomes).
        data_dir: output directory.
        filter_chroms: if True, skip rows whose chromosome is not in the group
            for setType. Defaults to False, which keeps the previous behaviour:
            `gencode` is assumed to be pre-split and every row is used.

    Rows whose sequence cannot be read from the global `fasta` are reported
    and skipped.
    """
    chrom_group = set(_chrom_group(setType))
    CHUNK_SIZE = 100
    file_name = ('gencode_{}k_dataset'.format(CL_max//1000)
                 + '_' + setType + '_'
                 + '.h5')

    idx = 0  # number of genes added to the file so far
    X_batch, Y_batch = [], [[]]
    with h5py.File(os.path.join(data_dir, file_name), 'w') as h5f2:
        for gene_nr in tqdm(range(gencode.shape[0])):
            chrom = gencode.iloc[gene_nr, 1]
            if filter_chroms and chrom not in chrom_group:
                continue
            strand = gencode.iloc[gene_nr, 2]
            tx_start = int(gencode.iloc[gene_nr, 3])
            tx_end = int(gencode.iloc[gene_nr, 4])
            # Keep every listed junction. A blind split(',')[:-1] assumes a
            # trailing comma and drops the last junction when there is none.
            jn_start = _parse_positions(gencode.iloc[gene_nr, 5])
            jn_end = _parse_positions(gencode.iloc[gene_nr, 6])

            try:
                seq = fasta[chrom][tx_start-1:tx_end].seq
            except Exception as err:  # not bare: Ctrl-C must still interrupt
                print('Failed reading fasta file for {}:{}-{}'.format(chrom, tx_start, tx_end))
                print('  ({}: {})'.format(type(err).__name__, err))
                print('SKIPPING')
                continue

            X, Y = create_datapoints(seq, strand, tx_start, tx_end, jn_start, jn_end)
            if idx % CHUNK_SIZE == 0:
                X_batch, Y_batch = [], [[]]
            X_batch.extend(X)
            Y_batch[0].extend(Y[0])

            if idx % CHUNK_SIZE == CHUNK_SIZE - 1:
                _write_chunk(h5f2, idx // CHUNK_SIZE, X_batch, Y_batch)
            idx += 1

        # Flush the final, partially filled chunk. This also covers runs with
        # fewer than CHUNK_SIZE genes and writes nothing if no gene was added.
        if idx % CHUNK_SIZE != 0:
            _write_chunk(h5f2, idx // CHUNK_SIZE, X_batch, Y_batch)

In [11]:
# Output directory for the HDF5 files written by createDataset. The file name
# is built from this prefix, so keep the trailing '/'.
data_dir = '/odinn/tmp/benediktj/Data/SplicePrediction-050422/'

In [12]:
# Build the test split: writes one HDF5 file (with 'test' in its name) into data_dir.
createDataset(gencode_test,'test',data_dir)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1652/1652 [04:09<00:00,  6.63it/s]


In [11]:
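#createDataset(<train-split dataframe>, 'train', data_dir)  # NOTE: signature is createDataset(gencode, setType, data_dir); the earlier call here omitted the dataframe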

58051it [11:50, 81.74it/s]  
