In [1]:
import numpy as np
import pandas as pd
from scipy.stats import gmean
from tqdm import tqdm
import gffutils
import os.path
from liftover import ChainFile
from collections import defaultdict
from math import ceil
import pyfastx
import re
from scipy.sparse.csgraph import connected_components
from glob import glob
from scipy.sparse import lil_matrix,csr_matrix,coo_matrix,dok_matrix, save_npz
import pickle
#
# Tab-separated BED file without a header row. findLowCountJunctions reads
# column 0 as the chromosome and columns 1 / 2 as the region bounds.
blacklist = pd.read_csv(
    '/odinn/tmp/benediktj/Data/SplicePrediction-GTEX-V8/ENCFF220FIN.bed',
    sep='\t',
    header=None,
)



In [2]:
# Tab-separated table without a header row. findLowCountJunctions reads
# column 1 as the chromosome and columns 2 / 3 as the repeat bounds.
simpileRepeats = pd.read_csv(
    '/odinn/tmp/benediktj/Data/SplicePrediction-GTEX-V8/simpleRepeat.txt',
    sep='\t',
    header=None,
)

In [3]:
#gffutils.create_db('/odinn/tmp/benediktj/Data/SplicePrediction-GTEX/gencode.v42.annotation.gtf', "/odinn/tmp/benediktj/Data/SplicePrediction-GTEX/gencode.v42.annotation.db", force=True,disable_infer_genes=True, disable_infer_transcripts=True)

In [4]:
# Open the gffutils annotation database (gencode.v42.annotation.db) that the
# transcript loop and getJunctions query.
gtf_gencode = gffutils.FeatureDB(
    '/odinn/tmp/benediktj/Data/SplicePrediction-GTEX-V8/gencode.v42.annotation.db'
)

In [5]:
# NOTE(review): `leafcutterFiles` is reassigned from a different glob in a
# later cell before anything reads this list, so this value is never used.
leafcutterFiles = glob(
    '/nfs/odinn/users/solvir/GTEx/GTEx_Analysis_v8_sQTL_leafcutter_counts/*_perind_numers.counts.gz'
)

In [6]:
#junctions = pd.read_csv('/odinn/tmp/benediktj/Data/SplicePrediction-GTEX-V8/GTEx_Analysis_2017-06-05_v8_STARv2.5.3a_junctions.gct', skiprows=2,sep='\t')

In [7]:
def findLowCountJunctions(junctions,blacklist,simpileRepeats):
    """Flag junctions that should be discarded and record why.

    A junction is discarded for the first of these reasons that applies:

    * 'LowReadCount'        - it has a non-zero count in fewer than 4
                              individuals (sample columns are summed per
                              individual, taken as the second '-'-separated
                              token of the column name);
    * 'InBlacklistedRegion' - its start or end lies inside a blacklist region
                              on the same chromosome (bounds inclusive);
    * 'InRepeatRegion'      - its start or end lies inside a simple-repeat
                              region on the same chromosome (bounds inclusive).

    Parameters
    ----------
    junctions : DataFrame
        Column 0 (`Name`) is '<chrom>_<start>_<end>', column 1 is a versioned
        gene id ('<id>.<version>'), and every remaining column is a per-sample
        read count.
    blacklist : DataFrame
        Columns 0 / 1 / 2 are chromosome / start / end.
    simpileRepeats : DataFrame
        Columns 1 / 2 / 3 are chromosome / start / end.

    Returns
    -------
    (discardJunctionDict, discardReason) : tuple of defaultdict(bool)
        Both are keyed by '<junction name>_<gene id without version>'. The
        first maps discarded junctions to True, the second to the reason
        string. Junctions that are kept are not inserted.
    """
    # Collapse the sample columns into one column per individual.
    sample_counts = junctions.iloc[:,2:].copy()
    sample_counts.columns = [x.split('-')[1] for x in junctions.columns[2:]]
    individual_counts = sample_counts.T.groupby(level=0).sum().T

    # Positional boolean array, so it lines up with the positional (iloc)
    # row access below whatever the index of `junctions` looks like.
    includeJunction = (individual_counts > 0).sum(axis=1).to_numpy() >= 4

    discardJunctionDict = defaultdict(bool)
    discardReason = defaultdict(bool)

    def endpoint_in_region(region_start, region_end, start, end):
        # True when some region overlaps the junction AND contains its start
        # or its end (region bounds inclusive). A region lying strictly inside
        # the junction does not count.
        overlapping = (region_start < end) & (region_end > start)
        start_inside = (region_start <= start) & (start <= region_end)
        end_inside = (region_start <= end) & (end <= region_end)
        return bool(np.any(overlapping & (start_inside | end_inside)))

    # Region bounds for the chromosome currently being scanned. They are
    # refreshed at the top of every iteration whose chromosome differs, so no
    # early `continue` can leave the tables of a previous chromosome behind.
    cached_chrom = None
    blacklist_start = blacklist_end = repeat_start = repeat_end = None

    for i,junction in tqdm(enumerate(junctions.Name.values)):
        # rsplit keeps chromosome names that themselves contain '_' intact.
        chrom,start,end = junctions.iloc[i,0].rsplit('_',2)
        start,end = int(start),int(end)

        if chrom != cached_chrom:
            blacklist_chrom = blacklist[blacklist[0]==chrom]
            blacklist_start = blacklist_chrom[1].to_numpy()
            blacklist_end = blacklist_chrom[2].to_numpy()
            simpileRepeats_chrom = simpileRepeats[simpileRepeats[1]==chrom]
            repeat_start = simpileRepeats_chrom[2].to_numpy()
            repeat_end = simpileRepeats_chrom[3].to_numpy()
            cached_chrom = chrom

        if not includeJunction[i]:
            reason = 'LowReadCount'
        elif endpoint_in_region(blacklist_start, blacklist_end, start, end):
            reason = 'InBlacklistedRegion'
        elif endpoint_in_region(repeat_start, repeat_end, start, end):
            reason = 'InRepeatRegion'
        else:
            continue

        geneID = junctions.iloc[i,1].split('.')[0]
        discardJunctionDict[junction+'_'+geneID] = True
        discardReason[junction+'_'+geneID] = reason

    return discardJunctionDict,discardReason

In [8]:
#discardJunctionDict,discardReason = findLowCountJunctions(junctions,blacklist,simpileRepeats)

In [9]:
# Restore the discard tables from disk (presumably written by the commented-out
# findLowCountJunctions / pickle.dump cells - confirm before relying on them).
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('/odinn/tmp/benediktj/Data/SplicePrediction-GTEX-V8/discardJunctions.pkl', 'rb') as junction_pickle:
    discardJunctionDict = pickle.load(junction_pickle)

with open('/odinn/tmp/benediktj/Data/SplicePrediction-GTEX-V8/discardReason.pkl', 'rb') as reason_pickle:
    discardReason = pickle.load(reason_pickle)


In [10]:
#with open('/odinn/tmp/benediktj/Data/SplicePrediction-GTEX-V8/discardJunctions.pkl', 'wb') as f:
#    pickle.dump(discardJunctionDict, f)

In [11]:
#with open('/odinn/tmp/benediktj/Data/SplicePrediction-GTEX-V8/discardReason.pkl', 'wb') as f:
#    pickle.dump(discardReason, f)

In [12]:
# Collect every LeafCutter cluster summary and bucket the paths by the second
# '_'-separated token of the file name (clusters_<token>_summary.tab).
leafcutterFiles = glob('/odinn/tmp/bjarnih/RNA/leafCutter/GTEx/*/meta_results/clusters_*_summary.tab')
chrmToLeafcutterFiles = defaultdict(list)
for summary_path in leafcutterFiles:
    base_name = summary_path.split('/')[-1]
    chrmToLeafcutterFiles[base_name.split('_')[1]].append(summary_path)

In [13]:
def getCombinedLeafCutterDF(chrm,chrmToLeafcutterFiles):
    """Load and merge all LeafCutter summary tables for one chromosome.

    Parameters
    ----------
    chrm : str
        Key into `chrmToLeafcutterFiles`.
    chrmToLeafcutterFiles : mapping of str -> list of str
        Paths of tab-separated files. Each must contain at least the columns
        'splice_event_id' and 'Start'.

    Returns
    -------
    DataFrame
        Rows of all files in list order, keeping the first occurrence of each
        'splice_event_id', sorted by 'Start'. De-duplication is applied even
        when there is only a single file.

    Raises
    ------
    ValueError
        If no files are registered for `chrm`. Previously this surfaced as an
        UnboundLocalError on the return statement.
    """
    files = chrmToLeafcutterFiles[chrm]
    if len(files) == 0:
        raise ValueError('No LeafCutter summary files found for {}'.format(chrm))

    # Read everything first and concatenate once; concatenating inside the
    # loop copies the accumulated frame on every iteration.
    frames = [pd.read_csv(path,sep='\t') for path in files]
    combined = pd.concat(frames,axis=0).drop_duplicates('splice_event_id')
    return combined.sort_values('Start')

In [14]:
# Genome FASTA; indexed by chromosome name below to get sequence lengths and
# to slice out gene sequences.
fasta = pyfastx.Fasta(
    '/odinn/tmp/benediktj/Data/SplicePrediction-GTEX-V8/GRCh38.p13.genome.fa'
)

In [15]:
# Output directory for the annotation file, sparse sequence matrices and labels.
data_dir = '/odinn/tmp/benediktj/Data/SplicePrediction-GTEX-V8/'

# Chromosomes that get a sequence matrix; a transcript on any other contig
# has no entry in seqData.
CHROM_GROUP = ['chr1', 'chr3', 'chr5', 'chr7', 'chr9',
'chr11', 'chr13', 'chr15', 'chr17', 'chr19', 'chr21',
'chr2', 'chr4', 'chr6', 'chr8', 'chr10', 'chr12',
'chr14', 'chr16', 'chr18', 'chr20', 'chr22', 'chrX', 'chrY']

# One empty sparse int8 matrix per chromosome with one row per base and five
# columns (the width of the one-hot rows in IN_MAP). The transcript loop fills
# in the rows covered by each selected gene.
# The trailing re-initialisation of the last chromosome through the leaked
# loop variable was redundant and has been removed.
seqData = {}
for chrom in CHROM_GROUP:
    seqData[chrom] = dok_matrix((len(fasta[chrom]), 5), dtype=np.int8)

In [16]:
def create_datapoints(seq, strand, tx_start, tx_end):
    """One-hot encode a nucleotide string.

    The sequence is upper-cased and mapped base by base to integer codes
    (A, C, G, T -> 1, 2, 3, 4; every other character -> 5), and the code
    array is passed to one_hot_encode, whose result is returned.

    `strand` is accepted but unused: no reverse complementing happens here.
    `tx_start` and `tx_end` are only coerced to int (so non-numeric values
    raise) and do not influence the result.
    """
    upper_seq = seq.upper()

    tx_start, tx_end = int(tx_start), int(tx_end)

    base_codes = {'A': 1, 'C': 2, 'G': 3, 'T': 4}
    X0 = np.asarray([base_codes.get(base, 5) for base in upper_seq])

    return one_hot_encode(X0)

def ceil_div(x, y):
    """Return ceil(x / y) as an int.

    Integer operands are divided with exact integer arithmetic, so large
    values are not rounded by a detour through float. Any other operand type
    keeps the original float-based computation.

    Raises ZeroDivisionError when y is 0.
    """
    if isinstance(x, (int, np.integer)) and isinstance(y, (int, np.integer)):
        # -(-a // b) is the ceiling of a / b for integers of any sign.
        return int(-(-int(x) // int(y)))
    return int(ceil(float(x)/y))


# One-hot encoding of the inputs: row 0 (all zeros) is for padding, and rows
# 1, 2, 3, 4, 5 correspond to A, C, G, T, Missing respectively.
IN_MAP = np.vstack([np.zeros((1, 5), dtype=int), np.eye(5, dtype=int)])

# Rows 0-2 are the 3x3 identity and row 3 is all zeros. OUT_MAP is not
# referenced by the cells shown in this notebook.
OUT_MAP = np.vstack([np.eye(3, dtype=int), np.zeros((1, 3), dtype=int)])

def one_hot_encode(Xd):
    """Look up the IN_MAP row for every integer code in the array `Xd`.

    `Xd` is cast to int8 and used to index IN_MAP along its first axis, so
    the result has the shape of `Xd` plus a trailing axis of length 5.
    """
    codes = Xd.astype('int8')
    return IN_MAP[codes]

def getJunctions(gtf,transcript,strand):
    """Return the intron boundaries of a transcript in genomic coordinates.

    Parameters
    ----------
    gtf : object with `children(feature, featuretype=...)`
        Queried for the transcript's exon features. Field 3 / 4 of each exon
        is read as its start / end.
    transcript : feature
        The transcript whose exons are looked up.
    strand : str
        '+' or '-'. Any other value yields an empty result.

    Returns
    -------
    numpy.ndarray
        One row per intron, `(end of upstream exon, start of downstream exon)`
        with the exons taken in ascending genomic order on both strands.
        Empty (shape (0,)) when the transcript has fewer than two exons.
    """
    if strand not in ('+', '-'):
        return np.array([])

    # Sort by start instead of relying on the order the database returns
    # exons in. For '+' (ascending) and '-' (descending, then reversed) input
    # this gives exactly what the previous strand-specific branches produced.
    exon_bounds = sorted(
        (int(exon[3]), int(exon[4]))
        for exon in gtf.children(transcript, featuretype="exon")
    )

    # Each pair of neighbouring exons delimits one intron; zero or one exon
    # simply yields no pairs.
    intron_junctions = [
        (upstream[1], downstream[0])
        for upstream, downstream in zip(exon_bounds, exon_bounds[1:])
    ]

    return np.array(intron_junctions)

In [17]:
# Build labels and sequence matrices for every selected transcript.
#
# A transcript is selected when it is tagged 'Ensembl_canonical', its gene
# type is 'protein_coding' and its annotation level is below 3. For each
# selected transcript with at least one intron and at least one LeafCutter
# junction for its gene, the loop
#   * collects the annotated intron starts/ends plus the starts/ends of all
#     non-discarded junctions in LeafCutter clusters that share a start or
#     end with an annotated intron,
#   * stores them in gene_to_label[gene_id],
#   * and, if save_seq is set, writes the one-hot encoded gene sequence into
#     seqData[chrom] and appends one line to annotation_GTEX_v8.txt.
# NOTE(review): genes without any LeafCutter junction are skipped entirely
# (no label, no sequence) - confirm this is intended.
# NOTE: the annotation file is opened in append mode, so re-running this cell
# adds duplicate lines.
transcripts = gtf_gencode.features_of_type('transcript')
gene_to_label = {}
save_seq = True

# (transcript_id, error) for every transcript whose processing raised; this
# replaces the former silent bare `except: pass`.
skipped_transcripts = []

# Chromosome whose sequence matrix is currently being filled. It only advances
# when a transcript is saved, and the previous matrix is flushed at that point.
prev_chrom = 'chr1'

# Chromosome the cached LeafCutter table belongs to. This is tracked
# separately from prev_chrom so the table is reloaded exactly once per
# chromosome change and can never belong to a different chromosome than the
# transcript being processed.
leafcutter_chrom = prev_chrom
leaf_cutter_junctions = getCombinedLeafCutterDF(leafcutter_chrom,chrmToLeafcutterFiles)

for transcript in tqdm(transcripts):
    chrom,gene_start,gene_end,strand = transcript[0],transcript[3],transcript[4],transcript[6]
    gene_id,transcript_id = transcript[8]['gene_id'][0],transcript[8]['transcript_id'][0]
    gene_type,gene_name,level = transcript[8]['gene_type'][0],transcript[8]['gene_name'][0],transcript[8]['level'][0]

    try:
        cond1 = 'Ensembl_canonical' in transcript[8]['tag']
        cond2 = gene_type=='protein_coding'
        cond3 = int(level)<3
        if cond1 and cond2 and cond3:
            intron_junctions = getJunctions(gtf_gencode,transcript,strand)
            junction_starts = defaultdict(int)
            junction_ends = defaultdict(int)

            if len(intron_junctions) > 0:
                if chrom != leafcutter_chrom:
                    leaf_cutter_junctions = getCombinedLeafCutterDF(chrom,chrmToLeafcutterFiles)
                    leafcutter_chrom = chrom
                for junction in intron_junctions:
                    junction_starts[junction[0]] = 1
                    junction_ends[junction[1]] = 1
                simple_gene_id = gene_id.split('.')[0]
                alt_junctions = leaf_cutter_junctions[leaf_cutter_junctions['Gene_id']==simple_gene_id]
                if alt_junctions.shape[0]>0:
                    # Clusters that contain at least one junction sharing a
                    # start or an end with an annotated intron.
                    clusters = defaultdict(int)
                    for i,pos in enumerate(alt_junctions['Start']):
                        if junction_starts[pos] == 1:
                            clusters[alt_junctions.iloc[i,:]['ClusterID']] = 1
                    for i,pos in enumerate(alt_junctions['End']):
                        if junction_ends[pos] == 1:
                            clusters[alt_junctions.iloc[i,:]['ClusterID']] = 1
                    # Add every junction of those clusters unless it was discarded.
                    for cluster in clusters.keys():
                        cluster_junctions = alt_junctions[alt_junctions['ClusterID']==cluster][['Start','End']]
                        for i_junc in range(cluster_junctions.shape[0]):
                            start,end = cluster_junctions.iloc[i_junc,:]['Start'],cluster_junctions.iloc[i_junc,:]['End']
                            junction_id = '{}_{}_{}_{}'.format(chrom,start,end,simple_gene_id)
                            if not discardJunctionDict[junction_id]:
                                junction_starts[start] = 1
                                junction_ends[end] = 1

                    # The defaultdict lookups above insert zero entries for
                    # positions that were only probed; drop them.
                    junction_starts = {k:v for k,v in junction_starts.items() if v != 0}
                    junction_ends = {k:v for k,v in junction_ends.items() if v != 0}
                    gene_to_label[gene_id] = [junction_starts, junction_ends]

                    if save_seq:
                        seq = fasta[chrom][int(gene_start)-1:int(gene_end)]
                        seq = seq.seq
                        X = create_datapoints(seq, strand, gene_start, gene_end)
                        seqData[chrom][int(gene_start)-1:int(gene_end)] = X
                        jn_start = list(junction_starts.keys())
                        jn_end = list(junction_ends.keys())
                        name = '{}\t{}\t{}\t{}'.format(gene_name,gene_id,transcript_id,level)

                        # On the '-' strand the two position columns are
                        # written in swapped order (ends first, then starts).
                        if strand in ('+', '-'):
                            first_col,second_col = (jn_start,jn_end) if strand=='+' else (jn_end,jn_start)
                            with open('{}/annotation_GTEX_v8.txt'.format(data_dir), 'a') as the_file:
                                the_file.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(name,chrom,strand,gene_start,gene_end,','.join([str(x) for x in first_col]),','.join([str(x) for x in second_col])))

                        # First saved transcript of a new chromosome: flush
                        # and free the matrix of the previous one.
                        if chrom!=prev_chrom:
                            save_npz('{}/sparse_sequence_data/{}.npz'.format(data_dir,prev_chrom), seqData[prev_chrom].tocoo())
                            del seqData[prev_chrom]

                        prev_chrom = chrom

    except Exception as err:
        # Best effort, as before: a transcript that fails (e.g. no 'tag'
        # attribute, chromosome without a matrix or LeafCutter files) is
        # skipped. KeyboardInterrupt / SystemExit now propagate.
        skipped_transcripts.append((transcript_id, repr(err)))

252416it [5:16:39, 13.29it/s]  


In [18]:
# Flush the matrix of the last chromosome that was processed; the transcript
# loop only saves a chromosome once a later one has started.
save_npz(
    '{}/sparse_sequence_data/{}.npz'.format(data_dir,prev_chrom),
    seqData[prev_chrom].tocoo(),
)
#del seqData[prev_chrom]

In [19]:
# Persist the per-gene [junction_starts, junction_ends] dictionaries built by
# the transcript loop.
with open('{}/gene_to_label.pickle'.format(data_dir), 'wb') as label_file:
    pickle.dump(
        gene_to_label,
        label_file,
        protocol=pickle.HIGHEST_PROTOCOL,
    )