In [None]:
import os
import sys
import pandas as pd 
import argparse
import numpy as np 
import glob
import time 
import logging
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import pickle

def plot_metagene(result, plot_group = 'alias', facet_by = None, x_col = 'pos', y_col = 'count', shade = None, title = None, sharey = False, vline = False, ylim = None) :
    """Draw a metagene line plot on a seaborn ``FacetGrid``.

    Parameters
    ----------
    result : pandas.DataFrame
        Table holding at least ``x_col``, ``y_col`` and ``plot_group`` (plus
        ``facet_by`` and ``shade`` when those are given). It is not modified.
    plot_group : str
        Column mapped to line colour (``hue``).
    facet_by : str or None
        Column used to split the figure into one panel per value.
    x_col, y_col : str
        Columns plotted on the x and y axes.
    shade : str or None
        Column holding a spread value; a band from ``y_col - shade`` to
        ``y_col + shade`` is filled around each line.
    title : str or None
        Figure-level title.
    sharey : bool
        Accepted for interface compatibility; it is not forwarded to the grid.
    vline : number or False
        x position of a grey vertical marker drawn in every panel.
    ylim : number or None
        Upper y limit; the lower limit is fixed at 0.

    Returns
    -------
    seaborn.FacetGrid
        The grid that was drawn.
    """
    # Private column names so the band limits can never collide with (and
    # overwrite) a user column such as a y_col literally called 'M'.
    lower_col = '_shade_lower'
    upper_col = '_shade_upper'

    data = result
    if shade :
        # Work on a copy: the caller's frame must not gain or lose columns.
        data = result.copy()
        data[lower_col] = data[y_col] - data[shade]
        data[upper_col] = data[y_col] + data[shade]

    plot = sns.FacetGrid(data,
                         col = facet_by,
                         hue = plot_group,
                         height = 4,
                         aspect = 1.3)

    plot.map(sns.lineplot, x_col, y_col)

    plot.add_legend()

    if shade :
        plot.map(plt.fill_between, x_col, lower_col, upper_col, alpha=0.2)

    # `vline = 0` is a legitimate position, so test against the sentinels
    # rather than truthiness.
    if vline is not None and vline is not False :
        # plt.axvline would only touch the current (last) axes; mark every panel.
        for ax in plot.axes.flat :
            ax.axvline(x = vline, color = 'grey', label = f'{vline} bp', alpha = 0.5)

    if title :
        plot.fig.suptitle(f'{title}')

    if ylim :
        plot.set(ylim=(0, ylim))

    return plot

def merge_and_plot(conditions, result, count = 'count', vline = False, ylim = None, facet_by = 'type') :
    """Average replicate profiles per condition, plot them and return the table.

    Parameters
    ----------
    conditions : pandas.DataFrame
        Sample sheet with a ``simple_name`` column (matched against
        ``result['alias']``) and a ``condition`` column.
    result : pandas.DataFrame
        Per-sample profile with ``pos``, ``alias``, the ``count`` column and,
        when ``facet_by`` is set, that column as well.
    count : str
        Name of the value column to summarise.
    vline, ylim
        Forwarded to ``plot_metagene``.
    facet_by : str or None
        Column used both as a grouping key and as the facet variable.

    Returns
    -------
    pandas.DataFrame
        One row per group with ``M`` (mean of ``count``) and ``S`` (standard
        deviation of ``count``).
    """
    merged = result.merge(
        conditions, how = 'left', left_on = 'alias', right_on = 'simple_name')

    # Honour `facet_by` instead of a hard-coded 'type'; with the default the
    # grouping keys are exactly ['pos', 'type', 'condition'] as before.
    if facet_by is None :
        group_cols = ['pos', 'condition']
    else :
        group_cols = ['pos', facet_by, 'condition']

    result_grouped = merged.groupby(group_cols).agg(
        M = (count, 'mean'),
        S = (count, 'std')
    ).reset_index()

    # Plot from a copy so that whatever helper columns the plotting routine
    # adds cannot alter the table handed back to the caller.
    plot_metagene(result_grouped.copy(),
                  shade = 'S',
                  facet_by = facet_by,
                  x_col = 'pos',
                  y_col = 'M',
                  plot_group = 'condition',
                  vline = vline,
                  ylim = ylim
                 )

    return result_grouped

def parse_transcripts(transcripts) :
    """Read a FASTA file into sequence and length lookups.

    Parameters
    ----------
    transcripts : str
        Path to a FASTA file.

    Returns
    -------
    tuple of (dict, dict)
        ``seqs`` maps each header (the header line with every ``>`` removed)
        to its concatenated sequence; ``sizes`` maps the same header to the
        sequence length.
    """
    seqs = {}
    sizes = {}

    header = None
    chunks = []

    def _flush() :
        # Store the record collected so far under its OWN header.
        if header is not None :
            sequence = "".join(chunks)
            seqs[header] = sequence
            sizes[header] = len(sequence)

    with open(transcripts, 'r') as f :
        for raw in f :
            line = raw.strip()
            if not line :
                continue
            if line.startswith(">") :
                _flush()
                header = line.replace(">", "")
                chunks = []
            elif header is not None :
                chunks.append(line)

    # The last record has no following header to trigger a flush.
    _flush()

    return seqs, sizes

def parse_windows(bed_windows) :
    """Collect ``[start, end]`` windows per gene.

    Parameters
    ----------
    bed_windows : pandas.DataFrame or str
        Either a table whose first three columns are gene, start and end, or
        the path to a tab-separated file with the same first three fields.

    Returns
    -------
    dict
        Maps each gene to a list of ``[start, end]`` integer pairs, in input
        order.
    """
    windows = {}

    if isinstance(bed_windows, pd.DataFrame) :
        # Iterate rows positionally so the result does not depend on the
        # frame's index labels (e.g. after filtering or sorting).
        for row in bed_windows.itertuples(index = False, name = None) :
            windows.setdefault(row[0], []).append([int(row[1]), int(row[2])])
        return windows

    with open(bed_windows, 'r') as f :
        for line in f :
            if not line.strip() :
                continue  # tolerate blank lines
            info = line.strip().split("\t")
            start = int(info[1])
            # The end field may be written as a float; round it to an integer.
            end = int(round(float(info[2])))
            windows.setdefault(info[0], []).append([start, end])

    return windows

def calculate_coverage(bed) :
    """Build a per-position coverage table from a tab-separated read file.

    Each data line supplies chrom, start, end, sequence and count in its first
    five fields. ``count / len(sequence)`` is added to every position from
    ``start`` to ``end`` inclusive. Lines beginning with ``chrom`` or
    ``gene_name`` are treated as headers and skipped, as are blank lines.

    The table is also pickled to ``<bed>.cov.pickle`` next to the input.

    Parameters
    ----------
    bed : str
        Path to the input file.

    Returns
    -------
    dict
        ``{chrom: {position: coverage}}``.
    """
    with open(bed, 'r') as f :
        bed_entries = sum(1 for _ in f)
    print(f"{bed_entries} total bed entries...")

    cov = {}
    with open(bed, 'r') as f :
        for bed_line, line in enumerate(f, start = 1) :
            is_header = line.startswith('chrom') or line.startswith("gene_name")
            if line.strip() and not is_header :
                info = line.strip().split('\t')
                chrom = info[0]
                start = int(info[1])
                end = int(info[2])
                seq = str(info[3])
                count = float(info[4])

                # Same value for every covered position: compute it once.
                per_base = count / len(seq)

                gene_cov = cov.setdefault(chrom, {})
                for i in range(start, end + 1) :
                    gene_cov[i] = gene_cov.get(i, 0.0) + per_base

            print(f'{bed_line} of {bed_entries} ({ round(( (bed_line)/(bed_entries) )*100,2) }%) total bed entries', end='\r')

    pklf = f"{bed}.cov.pickle"
    with open(pklf, 'wb') as handle :
        pickle.dump(cov, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return cov
    
def adjust_pos(pos) :
    """Nudge the two boundary positions inward.

    Returns 99.8 for a position equal to 100, 0.2 for a position equal to 0,
    and the position itself otherwise.
    """
    if pos == 100 :
        return 99.8
    if pos == 0 :
        return 0.2
    return pos
    
def sum_metagene_count_per_gene(bed, windows = None, scaling = False, transcript_sizes = None, sample = None, p5nt = ['G', 'A'], min_len = 21, max_len = 23) :
    """Summarise coverage per window (or per gene) for one sample.

    The coverage table is read from ``<bed>.cov.pickle`` when that file exists
    and is otherwise computed with ``calculate_coverage``.

    Parameters
    ----------
    bed : str
        Path to the sample's read file.
    windows : dict or None
        ``{gene: [[start, end], ...]}``. When given, each window yields one
        row named ``<gene>_<start>_<end>`` whose value is the summed coverage
        inside the window divided by the window length. When omitted, each
        gene yields one row holding its total coverage.
    sample : str or None
        Name of the value column; defaults to the part of the file name before
        the first dot.
    scaling, transcript_sizes, p5nt, min_len, max_len
        Accepted for signature parity with the other entry points; not used.

    Returns
    -------
    pandas.DataFrame
        Columns ``gene`` and ``<sample>``; rows with a zero sum are omitted
        and duplicate rows are dropped.
    """
    pklf = f"{bed}.cov.pickle"
    if os.path.exists(pklf) :
        # NOTE: unpickling runs arbitrary code from the file; only reuse cache
        # files this pipeline wrote itself.
        with open(pklf, 'rb') as handle :
            cov = pickle.load(handle)
    else :
        cov = calculate_coverage(bed)

    if sample is None :
        sample = os.path.basename(bed).split(".")[0]

    genes = []
    counts = []
    if windows :
        for gene, win_coord in windows.items() :
            gene_cov = cov.get(gene)
            if gene_cov is None :
                continue
            for sub_win in win_coord :
                lo, hi = sub_win[0], sub_win[1]
                val = sum(v for k, v in gene_cov.items() if lo <= k <= hi)
                if val :
                    genes.append(f"{gene}_{lo}_{hi}")
                    counts.append(val / (hi - lo + 1))
    else :
        # Genes are dict keys and therefore unique, so no duplicate check is
        # needed; record the gene's total coverage.
        for gene, subdict in cov.items() :
            val = sum(subdict.values())
            if val :
                genes.append(gene)
                counts.append(val)

    df = pd.DataFrame({
        'gene' : genes,
        f'{sample}' : counts
    })
    df = df.drop_duplicates()
    return df

def calculate_metagene_coord(bed, windows = None, scaling = False, transcript_sizes = None, sample = None, p5nt = ['G', 'A'], min_len = 21, max_len = 23) :
    """Aggregate one sample's coverage into a position-wise metagene profile.

    The coverage table is read from ``<bed>.cov.pickle`` when that file exists
    and is otherwise computed with ``calculate_coverage``.

    Parameters
    ----------
    bed : str
        Path to the sample's read file.
    windows : dict or None
        ``{gene: [[start, end], ...]}``. When given, positions are expressed
        relative to the window start and each coverage value is divided by the
        number of genes in ``windows``. When omitted, raw positions and raw
        coverage values of every gene are used.
    scaling : bool
        Rescale positions to a 0-100 axis rounded to the nearest 0.2, relative
        to the window length (with windows) or to the transcript length
        (without). With windows, positions equal to 0 or >= 100 are dropped
        from the result.
    transcript_sizes : dict or None
        ``{gene: length}``; required when scaling without windows.
    sample : str or None
        Value written to the ``alias`` column; defaults to the part of the
        file name before the first dot.
    p5nt, min_len, max_len
        Accepted for signature parity; not used here.

    Returns
    -------
    pandas.DataFrame
        Columns ``pos``, ``alias``, ``type`` (always ``'density'``), ``count``
        (sum per position) and ``zscore`` (2 raised to the population z-score
        of ``count``).

    Raises
    ------
    ValueError
        If ``scaling`` is requested without windows and without
        ``transcript_sizes``.
    """
    time1 = time.time()

    pklf = f"{bed}.cov.pickle"
    if os.path.exists(pklf) :
        # NOTE: unpickling runs arbitrary code from the file; only reuse cache
        # files this pipeline wrote itself.
        with open(pklf, 'rb') as handle :
            cov = pickle.load(handle)
    else :
        cov = calculate_coverage(bed)

    # A None alias would make the groupby below discard every row.
    if sample is None :
        sample = os.path.basename(bed).split(".")[0]

    positions = []
    counts = []
    if windows :
        total_windows = len(windows)
        for tracker, (gene, win_coord) in enumerate(windows.items(), start = 1) :
            gene_cov = cov.get(gene)
            if gene_cov is not None :
                for sub_win in win_coord :
                    lo, hi = sub_win[0], sub_win[1]
                    win_size = hi - lo + 1
                    for k, v in gene_cov.items() :
                        if lo <= k <= hi :
                            if scaling :
                                pos = round(100 * ((k - lo) / win_size) * 5) / 5
                            else :
                                pos = k - lo
                            positions.append(pos)
                            counts.append(v / total_windows)

            print(f'{tracker} of {total_windows} ({ round(( (tracker)/(total_windows) )*100,2) }%) genes processed', end='\r')

    # if no windows are specified
    else :
        if scaling and transcript_sizes is None :
            raise ValueError("scaling without windows requires transcript_sizes")

        total_genes = len(cov)
        for tracker, (gene, subdict) in enumerate(cov.items(), start = 1) :
            if scaling :
                size = transcript_sizes[gene]
                # Round to the nearest 0.2, matching the windowed branch; the
                # division by 5 must sit outside round() to have any effect.
                positions.extend(round(100 * (i / size) * 5) / 5 for i in subdict)
            else :
                positions.extend(subdict.keys())

            counts.extend(subdict.values())

            print(f'{tracker} of {total_genes} ({ round(( (tracker)/(total_genes) )*100,2) }%) transcripts processed', end='\r')

    df1 = pd.DataFrame({
        'pos' : positions,
        'count' : counts,
        'alias' : sample,
        'type' : 'density'
    })

    res = df1.groupby(['pos', 'alias', 'type'])['count'].sum().reset_index()

    res['zscore'] = 2**((res['count'] - res['count'].mean())/res['count'].std(ddof=0))

    if scaling and windows :
        res = res.query('pos < 100 & pos > 0')

    time2 = time.time()
    print(f"Processing {os.path.basename(bed)} took {round(time2-time1, 3)} s.\n")

    return res

def run_metagene(bed_input, bed_windows = None, scaling = False, transcripts = None, sample = None, p5nt = ['G', 'A'], min_len = 21, max_len = 23) :
    """Compute and plot metagene profiles for one file or a list of files.

    Parameters
    ----------
    bed_input : str or list of str
        A single read file, or a list of them. For a list, each file's alias
        is the part of its base name before the first dot.
    bed_windows : pandas.DataFrame, str or None
        Windows table or path, passed to ``parse_windows``.
    scaling : bool
        Forwarded to ``calculate_metagene_coord``; requires ``bed_windows`` or
        ``transcripts``.
    transcripts : str or None
        FASTA path; its sequence lengths are used as transcript sizes.
    sample : str or None
        Alias for the single-file case; defaults to the part of the file name
        before the first dot.
    p5nt, min_len, max_len
        Forwarded unchanged to ``calculate_metagene_coord``.

    Returns
    -------
    pandas.DataFrame or int
        The profile table (all files concatenated for a list input), or ``0``
        after printing a message when ``scaling`` is requested without windows
        or transcripts.

    Raises
    ------
    ValueError
        If ``bed_input`` is an empty list.
    """
    # A DataFrame has no unambiguous truth value, so test its type first.
    windows_is_frame = isinstance(bed_windows, pd.DataFrame)
    have_windows = windows_is_frame or bool(bed_windows)

    if scaling and not (have_windows or transcripts) :
        print("If scaling is enable must provide either bed windows OR transcripts")
        return 0

    my_windows = parse_windows(bed_windows) if have_windows else None
    transcript_sizes = parse_transcripts(transcripts)[1] if transcripts else None

    if windows_is_frame :
        bed_string = "my plot"
    elif bed_windows is not None :
        bed_string = f".{os.path.basename(bed_windows).replace('.bed', '')}"
    else :
        bed_string = ""

    if type(bed_input) is list :
        if not bed_input :
            raise ValueError("bed_input is an empty list")

        frames = [
            calculate_metagene_coord(
                F,
                windows = my_windows,
                scaling = scaling,
                transcript_sizes = transcript_sizes,
                sample = os.path.basename(F).split(".")[0],
                p5nt = p5nt,
                min_len = min_len,
                max_len = max_len)
            for F in bed_input
        ]

        # A single frame is returned as-is (its index is left untouched).
        if len(frames) == 1 :
            result = frames[0]
        else :
            result = pd.concat(frames, ignore_index=True)

        plot_metagene(result)
        return result

    # A None alias would make the downstream groupby drop every row.
    if sample is None :
        sample = os.path.basename(bed_input).split(".")[0]

    my_coords = calculate_metagene_coord(
        bed_input,
        windows = my_windows,
        scaling = scaling,
        transcript_sizes = transcript_sizes,
        sample = sample,
        p5nt = p5nt,
        min_len = min_len,
        max_len = max_len)

    plot_metagene(my_coords, y_col = 'zscore', title = f'{bed_string}')
    return my_coords
        
def run_counting(bed_input, bed_windows = None, scaling = False, transcripts = None, sample = None, p5nt = ['G', 'A'], min_len = 21, max_len = 23) :
    """Build a per-gene (or per-window) count table for one or many files.

    Parameters
    ----------
    bed_input : str or list of str
        A single read file, or a list of them. For a list, each file's sample
        name is the part of its base name before the first dot.
    bed_windows : pandas.DataFrame, str or None
        Windows table or path, passed to ``parse_windows``.
    transcripts : str or None
        FASTA path; its sequence lengths are forwarded as transcript sizes.
    scaling, sample, p5nt, min_len, max_len
        Forwarded unchanged to ``sum_metagene_count_per_gene``.

    Returns
    -------
    pandas.DataFrame
        For a single file, that file's table. For a list, the per-file tables
        outer-merged on ``gene``.

    Raises
    ------
    ValueError
        If ``bed_input`` is an empty list.
    """
    # A DataFrame has no unambiguous truth value, so test its type first.
    if isinstance(bed_windows, pd.DataFrame) or bed_windows :
        my_windows = parse_windows(bed_windows)
    else :
        my_windows = None

    transcript_sizes = parse_transcripts(transcripts)[1] if transcripts else None

    if type(bed_input) is list :
        result = None
        for F in bed_input :
            me_counts = sum_metagene_count_per_gene(
                F,
                windows = my_windows,
                scaling = scaling,
                transcript_sizes = transcript_sizes,
                sample = os.path.basename(F).split(".")[0],
                p5nt = p5nt,
                min_len = min_len,
                max_len = max_len)

            if result is None :
                result = me_counts
            else :
                result = result.merge(me_counts, how = 'outer', on = 'gene')

        if result is None :
            raise ValueError("bed_input is an empty list")

        return result

    return sum_metagene_count_per_gene(
        bed_input,
        windows = my_windows,
        scaling = scaling,
        transcript_sizes = transcript_sizes,
        sample = sample,
        p5nt = p5nt,
        min_len = min_len,
        max_len = max_len)
    
def pad_bedfile(bed) :
    """Replace each interval with one spanning 100 units either side of its end.

    ``start`` becomes ``end - 100`` (floored at 0) and ``end`` becomes
    ``end + 100``. The frame is modified in place and also returned.
    NOTE(review): the new start is derived from ``end``, not from the old
    ``start`` — confirm that this is the intended padding.
    """
    shifted = bed['end'] - 100
    # Keep non-negative values; everything else becomes 0.
    bed['start'] = shifted.where(shifted >= 0, 0)
    bed['end'] = bed['end'] + 100
    return bed

def reformat_bedfile(FILES, outdir) :
    """Write a six-column ``.rpm`` file for each input table.

    For every path in ``FILES`` the output is ``outdir/<basename>`` with
    ``.bed.tsv`` replaced by ``.rpm``. Each data line contributes fields
    0, 1, 2, 3, 7 and 5 (in that order), tab-separated. Lines starting with
    ``gene`` are treated as headers and skipped, as are blank lines. Inputs
    whose output file already exists are left alone.

    Parameters
    ----------
    FILES : iterable of str
        Paths of the tab-separated input tables.
    outdir : str
        Output directory; created (including parents) when missing.
    """
    os.makedirs(outdir, exist_ok = True)

    for F in FILES :
        name = os.path.basename(F).replace(".bed.tsv", ".rpm")
        outname = os.path.join(outdir, name)

        # Check the real destination, not the bare file name in the cwd.
        if os.path.exists(outname) :
            continue

        rows = []
        with open(F, 'r') as f :
            for line in f :
                if line.startswith("gene") or not line.strip() :
                    continue
                info = line.strip().split("\t")
                rows.append(f"{info[0]}\t{info[1]}\t{info[2]}\t{info[3]}\t{info[7]}\t{info[5]}\n")

        with open(outname, 'w') as out :
            out.writelines(rows)