# In [None]:
#!/usr/bin/env python
# coding: utf-8

print('reading libraries...')

from os.path import exists
from bx.intervals.io import GenomicIntervalReader
from bx.bbi.bigwig_file import BigWigFile
import numpy as np
import time
import sys
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
import pickle 
import helper_functions as hf
import sigfig as sf

def get_list_from_counts(value2cont):
    """Expand a value -> count mapping into a flat list.

    Every key of ``value2cont`` appears in the result as many times as its
    associated count, following the mapping's iteration order. Keys with a
    count of zero (or less) contribute nothing.
    """
    return [value for value,count in value2cont.items() for _ in range(count)]

# Inputs
# directory holding the per-chromosome kmer pickles read by the main loop
inDir='data-outputted/'

# bigWig basename -> genome assembly used to look up its sequence
fn2genome={
    'GSM2218592_B_Ets1_R1_bow':'mm9', # b cells, Saelee P et al
    'GSM4110116_Ets1_S0_1_129929':'hg19', # t cells, McCarter AC et al
    'GSM3520734_NK_ETS1_chip':'hg38', # nk cells, Taveirne S et al
}

print('loading genomes...')
# assembly name -> {contig name -> sequence}, as returned by hf.faLoadGenome
genome2chr2seq={
    'hg38':hf.faLoadGenome('hg38.fa'),
    'hg19':hf.faLoadGenome('hg19.fa'),
    'mm9' :hf.faLoadGenome('mm9.fa')
}

# exit early to plot fast
# (when True the chromosome loops stop after three chromosomes and the
#  random-kmer sample size is fixed at 1000)
beta=False

# set up plot parameters
# ymin/ymax are only used as-is when setYlimCustom is True; otherwise they
# are recomputed from the data for every dataset
setYlimCustom=False
ymax=3
ymin=0

# remove mitochondrial chromosome
# `genome` is not bound until the per-dataset loop runs, so strip chrM from
# every loaded assembly here; the default avoids a KeyError when an assembly
# has no chrM entry
for chr2seq in genome2chr2seq.values():
    chr2seq.pop('chrM',None)

# For each ChIP-seq dataset:
#   1. merge the per-chromosome kmer pickles into signal counts per kmer
#   2. bin kmers by PBM affinity and collect the ChIP signal of every bin
#   3. plot the per-chromosome mean signal of each bin plus a linear trendline
#   4. plot the per-chromosome mean signal over random non-ETS 8-mers alongside
for fn,genome in fn2genome.items():

    print(f'loading {fn} bigwig...')
    # keep a name for the handle so it can be closed after the last bigwig query
    bwFile=open( f'../../data-downloaded/{fn}.bigWig' ,'rb')
    bw = BigWigFile( bwFile )

    # initialize dataset
    merged_Chr2Kmer2Data={}

    print('\tloading ets kmer data...')

    # remove non chromosomal contigs from genome
    # (collect first: the dict cannot be mutated while it is being iterated)
    rm_contigs=[contig for contig in genome2chr2seq[genome] if not contig.startswith('chr')]
    for x in rm_contigs:
        genome2chr2seq[genome].pop(x)

    for i,chrom in enumerate(genome2chr2seq[genome]):

        merged_Chr2Kmer2Data[chrom]={}

        # load data processed in FS2ABi-Preprocess-Data.py
        fnPickle=f'{inDir}/{fn}/{fn}__Kmer2Data__beta=False__chrom={chrom}.pydict.pickle'

        # chromosomes without a pickle keep their empty entry and are skipped
        if not exists(fnPickle): continue

        # NOTE: unpickling can execute arbitrary code; only load files written
        # by our own preprocessing step
        with open(fnPickle,'rb') as f:
            Kmer2Data=pickle.load(f)

        # iterate over data
        for kmer in Kmer2Data:

            if kmer not in merged_Chr2Kmer2Data[chrom]:
                merged_Chr2Kmer2Data[chrom][kmer]={'pbm-aff':Kmer2Data[kmer]['pbm-aff']}
                merged_Chr2Kmer2Data[chrom][kmer]['ChipSignal2Count']={}

            # tally how often each chip signal value occurs for this kmer
            signal2count=merged_Chr2Kmer2Data[chrom][kmer]['ChipSignal2Count']
            for chipSignal in Kmer2Data[kmer]['chip-signal-list']:
                signal2count[chipSignal]=signal2count.get(chipSignal,0)+1

        if i>1 and beta: break

    # bin ets sites by their affinity (ie .1, .2, ... 1.0)
    print('\tbinning data...')
    bins=np.arange(0,1.1,.1)

    # long-format columns: one row per (affinity bin, chip signal, chromosome)
    c2v={'bin':[],'signal':[],'chrom':[]}

    for i,chrom in enumerate(merged_Chr2Kmer2Data):

        x2y={} # mapping xposition to y value of binned data at that xposition

        for kmer in merged_Chr2Kmer2Data[chrom]:
            # right=True: an affinity a is assigned the bin edge b with (b-.1) < a <= b
            binIdx=np.digitize(merged_Chr2Kmer2Data[chrom][kmer]['pbm-aff'], bins,right=True)
            merged_Chr2Kmer2Data[chrom][kmer]['bin']=bins[binIdx]

            x=merged_Chr2Kmer2Data[chrom][kmer]['bin']
            y=get_list_from_counts(merged_Chr2Kmer2Data[chrom][kmer]['ChipSignal2Count'])

            x2y.setdefault(x,[]).extend(y)

        for xi,yiList in x2y.items():
            c2v['bin'].extend([round(xi,3)]*len(yiList))
            c2v['signal'].extend(yiList)
            c2v['chrom'].extend([chrom]*len(yiList))

        if i>1 and beta: break

    # convert to dataframe
    df=pd.DataFrame(c2v)

    # drop infinites and na
    df=df.replace(np.inf, np.nan).replace(-np.inf, np.nan).dropna()

    # begin plot
    fig,ax=plt.subplots(1,2,figsize=(13,3),dpi=300,gridspec_kw={'width_ratios': [5,1]})


    #####################################################################
    # averages
    #####################################################################

    avgFunc=np.mean
    errFunc=stats.sem

    minAvg=+np.inf
    maxAvg=-np.inf


    avgList=[]
    errList=[]
    biList=[]
    space=.03

    # plot average of each column
    # (one point per chromosome per affinity bin; a chromosome with no data in
    #  a bin yields NaN, which never updates minAvg/maxAvg)
    swarmDF={c:[] for c in ['bin','agg-signal','chrom']}
    for chrom in df.chrom.unique():
        for bi in sorted(df.bin.unique()):

            biData=df.loc[(df['bin']==bi) & (df['chrom']==chrom),'signal']
            biAvg=avgFunc(biData)
            biErr=errFunc(biData)

            avgList.append(biAvg)
            errList.append(biErr)

            if biAvg<minAvg: minAvg=biAvg
            if biAvg>maxAvg: maxAvg=biAvg

            biList .append(float(bi))

            swarmDF['bin'].append(bi)
            swarmDF['agg-signal'].append(biAvg)
            swarmDF['chrom'].append(chrom)

    swarmDF=pd.DataFrame(swarmDF)

    # set custom y-axis limits
    if setYlimCustom:
        ax[0].set_ylim(ymin,ymax)

    sns.swarmplot(x='bin',y='agg-signal',data=swarmDF,ax=ax[0],palette='Blues',size=2)



    #####################################################################
    # trendlines
    #####################################################################

    pround=10

    sCorrList=[]
    pCorrList=[]

    # correlation of bin vs. mean signal over all bins up to each end point
    for endPoint in bins[2:]:

        corrDF=swarmDF.loc[swarmDF['bin']<=endPoint ,:]
        # drop rows holding NaN or +/-inf in any column
        corrDF=corrDF.replace([np.inf,-np.inf], np.nan).dropna()

        pr,pp=[round(i,pround) for i in stats.pearsonr( corrDF['bin'], corrDF['agg-signal']) ]
        sr,sp=[round(i,pround) for i in stats.spearmanr(corrDF['bin'], corrDF['agg-signal']) ]

        sCorrList.append((sr,sp))
        pCorrList.append((pr,pp))

    # linear fit over all bins, leaving out chrM and chrY
    fitDF=swarmDF.loc[~swarmDF['chrom'].isin(['chrM','chrY']),:].dropna(how='any')

    a, b = np.polyfit(fitDF['bin'],fitDF['agg-signal'], 1)

    # the swarm plot places categories at 0,1,2,...; evaluate the fit at the
    # actual bin values but draw it at those category positions
    X=range(len(swarmDF['bin'].unique()))
    Y= a * np.array(sorted(swarmDF['bin'].unique())) + b


    ax[0].plot(X,Y,zorder=100,lw=1,color='blue')


    #####################################################################
    # labels
    #####################################################################


    ax[0].spines['top'].set_visible(False)
    ax[0].spines['right'].set_visible(False)
    ax[0].set_xlabel('Affinity Bin')
    ax[0].set_ylabel('ETS ChIP-seq Signal\nOver GGAW 8mer')

    # change scale
    if not setYlimCustom:
        ymin=minAvg-.5
        ymax=maxAvg+.5
    ax[0].set_ylim(ymin,ymax)

    # tick labels: bin edge plus pearson (p=) and spearman (s=) coefficients of
    # all bins up to that edge; the first bin has no correlation to report
    # NOTE(review): this builds exactly 10 labels from bins[1:] — assumes the
    # swarm plot shows 10 bins; confirm for datasets with missing/extra bins
    xticklabels=[]
    sigfigRhoRound=3
    for (pr,pp),(sr,sp),xlabi in zip([('','')]+pCorrList,[('','')]+sCorrList,bins[1:]):
        if pr!='':
            newLab=f'{sf.round(xlabi,1)}\np={sf.round(pr,sigfigRhoRound)}\ns={sf.round(sr,sigfigRhoRound)}'
        else:
            newLab=f'{sf.round(xlabi,1)}'
        xticklabels.append(newLab)

    ax[0].set_xticklabels(xticklabels)


    #####################################################################
    # Noises from random kmers
    #####################################################################

    # # Add "noise level" from random kmers
    print('\tgenerating random kmer data...')
    genomeLen,Chrom2NumStartEnd=hf.generate_Chrom2NumStartEnd(genome2chr2seq[genome])

    chipSignalRandomKmers=[]

    # NOTE(review): the original comment read "there are 10 bins, so take
    # average bin size", but the divisor is 1000 — confirm the intended size
    if beta:randomSamplesize=1000
    else:   randomSamplesize=int(len(df)/1000)

    allowedChroms=set(swarmDF['chrom'])
    etsCores=('GGAA','GGAT','TTCC','ATCC')
    randSwarmDF={c:[]  for c in  ['chrom','signal']}
    try:
        for i in range(randomSamplesize):

            # redraw until a position passes every filter below
            while True:
                chrom,start=hf.choose_random_chrom_pos(genomeLen,Chrom2NumStartEnd)

                if chrom not in allowedChroms: continue

                end=start+8
                seq=genome2chr2seq[genome][chrom][start:end]

                # require real seq
                if 'N' in seq: continue

                # require non-ets site
                # (the check must restart the while loop; a `continue` inside a
                #  per-core for loop would only skip to the next core)
                if any(etsCore in seq for etsCore in etsCores): continue

                # if all is well, parse chip signal
                result=bw.query(chrom, start, end, 1)
                if not result: continue
                result=result[0]['mean']

                if np.isnan(result): continue

                chipSignalRandomKmers.append(result)
                randSwarmDF['chrom'].append(chrom)
                randSwarmDF['signal'].append(result)
                break
    finally:
        # the bigwig is not queried past this point
        bwFile.close()

    randSwarmDF=pd.DataFrame(randSwarmDF)

    # average the signal per chromosome (select the numeric column explicitly;
    # averaging the whole frame would include the string 'chrom' column)
    randSwarmDF=randSwarmDF.groupby('chrom')['signal'].apply(avgFunc).reset_index()

    ax[1].set_xlim(-1,1)

    sns.swarmplot(y='signal',data=randSwarmDF,ax=ax[1],size=3,color='lightgrey')

    ax[1].spines['top'].set_visible(False)
    ax[1].spines['right'].set_visible(False)
    ax[1].set_xlabel(f'Random Non-ETS Kmers\nN={randomSamplesize:,}')
    ax[1].set_ylabel('ETS ChIP-seq Signal')

    # share the y-range of the affinity panel so both panels are comparable
    ax[1].set_ylim(ymin,ymax)

    ax[1].set_xticks([0])

    plt.show()