In [1]:
import os
import sys
import pandas as pd 
import argparse
import numpy as np 
import glob
import time 
import logging
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import pickle
import multiprocessing
from multiprocessing import  Pool
import time
import logging
from functools import partial
import random
from functools import reduce
import tqdm

In [None]:
def plot_metagene(result, plot_group = 'alias', facet_by = None, x_col = 'pos', y_col = 'count', shade = None, title = None, sharey = False, vline = False, ylim = None) : 
    """Draw metagene profiles as a seaborn FacetGrid of line plots.

    Parameters
    ----------
    result : pandas.DataFrame
        Table holding the columns named by the other arguments. It is never
        modified; a private copy is made when a shaded band is requested.
    plot_group : str
        Column mapped to line colour (``hue``).
    facet_by : str or None
        Column used to split the figure into panels (wrapped two per row).
        With ``None`` a single panel is drawn.
    x_col, y_col : str
        Columns plotted on the x and y axes.
    shade : str or None
        Column holding a half-width; a band ``y_col - shade .. y_col + shade``
        is drawn behind each line when given.
    title : str or None
        Figure-level title.
    sharey : bool
        Accepted for backward compatibility; it is not forwarded to the grid.
    vline : number or False
        x position of a vertical marker drawn on every panel. ``False`` or
        ``None`` disables it; ``0`` is a valid position.
    ylim : number or None
        Upper y limit; the lower limit is fixed at 0.

    Returns
    -------
    seaborn.FacetGrid
    """
    # Private names for the band edges so that no column of the input can be
    # clobbered (the band used to be stored in 'm'/'M', which overwrote the
    # mean whenever y_col was 'M').
    band_lo, band_hi = '_band_lo', '_band_hi'

    data = result
    if shade : 
        data = result.copy()
        data[band_lo] = data[y_col] - data[shade]
        data[band_hi] = data[y_col] + data[shade]

    plot = sns.FacetGrid(data, 
                         col = facet_by, 
                         hue = plot_group, 
                         height = 4, 
                         aspect = 1.3, 
                         # wrapping only makes sense when a column facet exists
                         col_wrap = 2 if facet_by is not None else None)

    plot.map(sns.lineplot, x_col, y_col)

    plot.add_legend()

    if shade : 
        plot.map(plt.fill_between, x_col, band_lo, band_hi, alpha = 0.2)

    # Explicit sentinel test: a marker at x = 0 must still be drawn.
    if vline is not None and vline is not False : 
        # Mark every panel, not only the axes pyplot considers current.
        for ax in plot.axes.flat : 
            ax.axvline(x = vline, color = 'grey', label = f'{vline} bp', alpha = 0.5)

    if title : 
        # `figure` is the current attribute name; fall back to `fig` for older seaborn.
        figure = plot.figure if hasattr(plot, 'figure') else plot.fig
        figure.suptitle(f'{title}')

    if ylim : 
        plot.set(ylim=(0, ylim))

    return plot

def merge_and_plot(conditions, result, count = 'count', vline = False, ylim = None, facet_by = 'type') :
    """Attach condition labels to per-sample results, summarise and plot them.

    Parameters
    ----------
    conditions : pandas.DataFrame
        Sample sheet with at least the columns ``simple_name`` and ``condition``.
    result : pandas.DataFrame
        Per-sample table with at least ``alias``, ``pos``, the `count` column
        and, when faceting, the `facet_by` column.
    count : str
        Column that is averaged across samples.
    vline, ylim
        Forwarded unchanged to ``plot_metagene``.
    facet_by : str or None
        Column used both as a grouping key and as the panel variable.
        ``None`` draws a single panel.

    Returns
    -------
    pandas.DataFrame
        One row per (pos, facet, condition) with ``M`` (mean of `count`) and
        ``S`` (standard deviation of `count`).
    """
    result_merge = result.merge(
        conditions, how = 'left', left_on = 'alias', right_on = 'simple_name')

    # Group by the requested facet column instead of a hard-coded 'type';
    # dict.fromkeys drops duplicates (e.g. facet_by='condition') but keeps order.
    group_keys = list(dict.fromkeys(
        key for key in ('pos', facet_by, 'condition') if key is not None))

    result_grouped = result_merge.groupby(
        group_keys
    ).agg(
        M = (count, 'mean'),
        S = (count, 'std')
    ).reset_index()

    # The plotting helper receives a copy so that whatever it does to its
    # input cannot alter the summary handed back to the caller.
    plot_metagene(result_grouped.copy(), 
                  shade = 'S', 
                  facet_by = facet_by, 
                  x_col = 'pos', 
                  y_col = 'M', 
                  plot_group = 'condition',
                  vline = vline,
                  ylim = ylim
                 )

    return result_grouped

def parse_transcripts(transcripts) : 
    """Read a FASTA file and return one full-length window per record.

    Parameters
    ----------
    transcripts : str
        Path to a FASTA file.

    Returns
    -------
    dict
        ``{header: [[0, sequence_length]]}`` where ``header`` is the header
        line with every ``>`` removed and surrounding whitespace stripped.
        An empty file yields an empty dict.
    """
    windows = {}
    name = None      # header of the record currently being read
    length = 0       # number of sequence characters seen for that record

    with open(transcripts, 'r') as f : 
        for line in f :
            line = line.strip()
            if line.startswith(">") :
                # Flush the finished record under ITS OWN name before moving
                # on to the new header.
                if name is not None :
                    windows[name] = [ [0, length] ]
                name = line.replace(">", "")
                length = 0
            elif name is not None :
                # Only the length is needed, so the sequence is not stored.
                length += len(line)

    if name is not None :
        windows[name] = [ [0, length] ]

    return windows

def parse_windows(bed_windows) : 
    """Collect BED-style windows per gene.

    Parameters
    ----------
    bed_windows : pandas.DataFrame or str
        Either a DataFrame whose first three columns are name, start and end
        (any further columns are ignored, and the index may be arbitrary), or
        the path of a tab-separated file with the same three leading fields.

    Returns
    -------
    dict
        ``{name: [[start, end], ...]}`` in input order. The number of distinct
        names is also printed.
    """
    windows = {}

    if isinstance(bed_windows, pd.DataFrame) : 
        # Select columns by position and walk them together. Row index labels
        # are deliberately not used, so filtered or re-indexed frames work too.
        names = bed_windows.iloc[:, 0]
        starts = bed_windows.iloc[:, 1]
        ends = bed_windows.iloc[:, 2]
        for gene, start, end in zip(names, starts, ends) : 
            windows.setdefault(gene, []).append([int(start), int(end)])
    else :
        with open(bed_windows, 'r') as f : 
            for line in f : 
                if not line.strip() : 
                    # tolerate blank lines (e.g. a trailing newline)
                    continue
                info = line.strip().split("\t")
                gene = info[0]
                start = int(info[1])
                # the end field may be written as a float (e.g. "120.0")
                end = int(round(float(info[2])))
                windows.setdefault(gene, []).append([start, end])

    print(f"Number of windows/transcripts: {len(windows)}")
    return windows

def calculate_coverage(bed, p5nt = ['G', 'A'], min_len = 21, max_len = 23, aggregate = 'density') : 
    """Build a per-position coverage table from a tab-separated read file.

    Each data line is expected to carry, in order: name, start, end, read
    sequence, count. Lines starting with ``chrom`` or ``gene_name`` are
    treated as headers and skipped, as are blank lines.

    Parameters
    ----------
    bed : str
        Path of the input file.
    p5nt : sequence of str
        Accepted first nucleotides of the read sequence.
    min_len, max_len : int
        Inclusive bounds on the read sequence length.
    aggregate : {'density', 'five_prime_position'}
        ``'density'`` adds ``count / len(sequence)`` to every position from
        start to end (both included); ``'five_prime_position'`` adds the whole
        count at the ``end`` coordinate.

    Returns
    -------
    dict or int
        ``{name: {first_nt: {position: value}}}``. ``0`` is returned (after
        printing a message) when `aggregate` is not recognised; nothing is
        written in that case.

    Side effects
    ------------
    The table is pickled next to the input as
    ``<bed>.cov.<joined p5nt>.<aggregate>.pickle``.
    NOTE(review): the cache name does not encode `min_len`/`max_len`, so a
    cache written with one length filter looks identical to one written with
    another — confirm this is acceptable for the callers.
    """
    # Validate up front instead of part-way through the file.
    if aggregate not in ('density', 'five_prime_position') : 
        print('Aggregate must be set to either [density OR five_prime_position]')
        return 0 

    cov = {}
    with open(bed, 'r') as f :
        for line in f : 
            if line.startswith('chrom') or line.startswith("gene_name") : 
                continue
            if not line.strip() : 
                continue

            info = line.strip().split('\t')
            chrom = info[0]
            start = int(info[1])
            end = int(info[2])
            seq = str(info[3])

            if not seq : 
                # nothing to classify by 5' nucleotide
                continue
            five_prime = seq[0]

            if five_prime not in p5nt : 
                continue
            if not (min_len <= len(seq) <= max_len) : 
                continue

            count = float(info[4])

            # cov > gene > 5' nucleotide > position
            by_position = cov.setdefault(chrom, {}).setdefault(five_prime, {})

            if aggregate == 'density' : 
                per_base = count / len(seq)
                for i in range(start, end + 1) :
                    by_position[i] = by_position.get(i, 0) + per_base
            else : 
                by_position[end] = by_position.get(end, 0) + count

    pklf = f"{bed}.cov.{''.join(p5nt)}.{aggregate}.pickle"
    with open(pklf, 'wb') as handle : 
        pickle.dump(cov, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return cov
    
def calculate_metagene_coord(bed, windows = None, scaling = False, sample = None, p5nt = ['G', 'A'], combine_five_prime_nt = False, aggregate = 'density', missing_data_as_0 = False, sum_coverage = False) :
    """Project cached per-position coverage onto a set of windows.

    Parameters
    ----------
    bed : str
        Path of the read file. Coverage is loaded from the pickle
        ``<bed>.cov.<joined p5nt>.<aggregate>.pickle`` when it exists and is
        computed with ``calculate_coverage`` otherwise.
    windows : dict
        ``{name: [[start, end], ...]}``; both coordinates are treated as
        included in the window.
    scaling : bool
        When true, positions are expressed as a percentage of the window
        length (rounded to steps of 0.2) and only 0 < pos < 100 is kept in
        the per-position output.
    sample : str or None
        Label stored in the ``alias`` column; defaults to the file name up to
        its first dot.
    p5nt : sequence of str
        5' nucleotides to include.
    combine_five_prime_nt : bool
        Label every row with the comma-joined `p5nt` so that the different
        nucleotides are summed together.
    aggregate : str
        Part of the cache file name; forwarded to ``calculate_coverage``.
    missing_data_as_0 : bool
        Emit explicit zeros for window positions without coverage.
    sum_coverage : bool
        Sum per gene instead of per position.

    Returns
    -------
    pandas.DataFrame or int
        Columns ``pos`` (or ``gene`` when `sum_coverage`), ``alias``, ``type``,
        ``five_prime_nt``, ``count`` and ``zscore``. ``0`` is returned when no
        windows are supplied or when no coverage table could be obtained.
    """
    time1 = time.time()

    pklf = f"{bed}.cov.{''.join(p5nt)}.{aggregate}.pickle"

    if not os.path.exists(pklf) : 
        meta_info = calculate_coverage(bed, p5nt = p5nt, aggregate = aggregate)
    else : 
        # NOTE: unpickling can execute arbitrary code; only point this at
        # cache files written by this pipeline.
        with open(pklf, 'rb') as handle : 
            meta_info = pickle.load(handle)

    if not sample : 
        sample = os.path.basename(bed).split(".")[0]

    if not windows : 
        return 0

    if not isinstance(meta_info, dict) : 
        # the coverage step signals failure with a non-dict value
        return 0

    positions = []
    counts = []
    five_prime_nts = []
    genes = []

    for gene, win_coord in windows.items() :
        gene_cov = meta_info.get(gene)
        if gene_cov is None : 
            continue

        for sub_win in win_coord :
            win_start, win_end = sub_win[0], sub_win[1]
            win_size = win_end - win_start + 1

            for nt, cov in gene_cov.items() : 
                if nt not in p5nt :
                    continue

                cov_sub = {}
                for i in range(win_start, win_end + 1) : 
                    if i in cov : 
                        cov_sub[i] = cov[i]
                    elif missing_data_as_0 :
                        cov_sub[i] = 0

                if not cov_sub : 
                    continue

                if scaling :
                    rel_pos = [ round(100 * ((k - win_start) / win_size) * 5) / 5 for k in cov_sub ]
                else :
                    rel_pos = [ k - win_start for k in cov_sub ]

                n_rows = len(rel_pos)
                label = ','.join(p5nt) if combine_five_prime_nt else nt

                positions.extend(rel_pos)
                counts.extend(cov_sub.values())
                # One gene label PER ROW: the list is repeated, not the string,
                # so every column ends up with the same length.
                genes.extend([gene] * n_rows)
                five_prime_nts.extend([label] * n_rows)

    df1 = pd.DataFrame({
        'pos' : positions,
        'count' : counts, 
        'five_prime_nt' : five_prime_nts,
        'alias' : f"{sample}", 
        'type' : 'density',
        'gene' : genes
    })

    if not sum_coverage : 
        res = df1.groupby(['pos', 'alias', 'type', 'five_prime_nt'])['count'].sum().reset_index()
    else :
        res = df1.groupby(['gene', 'alias', 'type', 'five_prime_nt'])['count'].sum().reset_index()

    res['zscore'] = 2**((res['count'] - res['count'].mean())/res['count'].std(ddof=0))

    if not sum_coverage and scaling : 
        res = res.query('pos < 100 & pos > 0')

    time2 = time.time() 
    print(f"Processing {sample} took {time2 - time1} s")

    return res

def parallel_cov(files, func, windows = None, scaling = False, sample = None, p5nt = ['G', 'A'], combine_five_prime_nt = False, aggregate = 'density', missing_data_as_0 = False, sum_coverage = False, n_cores = multiprocessing.cpu_count()-1) : 
    
    """ Parallelize coverage calculation 

    Calls ``func(file, windows, scaling, sample, p5nt, combine_five_prime_nt,
    aggregate, missing_data_as_0, sum_coverage)`` once per entry of `files`
    in a process pool and concatenates the DataFrames that come back.

    Parameters
    ----------
    files : iterable of str
        Inputs, one task each.
    func : callable
        Worker taking the nine positional arguments listed above. It must be
        picklable by ``multiprocessing``.
    n_cores : int
        Requested pool size. It is clamped to the range 1..len(files), so a
        single-core machine (default of 0) still gets one worker.

    Returns
    -------
    pandas.DataFrame
        Concatenation of all DataFrame results; an empty DataFrame when there
        is nothing to concatenate. Non-DataFrame results (workers in this file
        signal failure with ``0``) are reported and skipped.
    """
    
    files = list(files)
    if not files : 
        return pd.DataFrame()

    args = [
        (f, windows, scaling, sample, p5nt, combine_five_prime_nt, aggregate, missing_data_as_0, sum_coverage)
        for f in files
    ]

    workers = max(1, min(int(n_cores), len(files)))

    # The context manager tears the pool down even if a worker raises.
    with Pool(workers) as pool : 
        results = pool.starmap(func, args)

    frames = [ r for r in results if isinstance(r, pd.DataFrame) ]
    skipped = len(results) - len(frames)
    if skipped : 
        print(f"{skipped} of {len(results)} files returned no table and were skipped")

    if not frames : 
        return pd.DataFrame()

    return pd.concat(frames)

def run_metagene(bed_input, bed_windows = None, scaling = False, transcripts = None, sample = None, p5nt = None, min_len = 21, max_len = 23, sum_coverage = False, combine_five_prime_nt = False, aggregate = 'density', missing_data_as_0 = False) : 
    """Entry point: build windows, then compute metagene tables for one or more files.

    Parameters
    ----------
    bed_input : str or list/tuple of str
        A single read file, or several. Several files are processed in a
        process pool via ``parallel_cov``; a single file is processed in the
        current process.
    bed_windows : pandas.DataFrame or str or None
        Windows passed to ``parse_windows``. Takes precedence over `transcripts`.
    transcripts : str or None
        FASTA path passed to ``parse_transcripts`` when no windows are given.
    scaling, sample, p5nt, sum_coverage, combine_five_prime_nt, aggregate, missing_data_as_0
        Forwarded to ``calculate_metagene_coord``. `p5nt` defaults to
        ``['G', 'A']``.
    min_len, max_len : int
        NOTE(review): these are only echoed in the progress message; they are
        not forwarded, and ``calculate_metagene_coord`` invokes the coverage
        step without length arguments. Wiring them through needs a signature
        change there.

    Returns
    -------
    pandas.DataFrame or int
        The result table, or ``0`` after printing a message when the inputs
        are unusable (no windows/transcripts, unknown `aggregate`, empty file
        list).
    """
    # isinstance is tested first because the truth value of a DataFrame is ambiguous.
    if isinstance(bed_windows, pd.DataFrame) or bed_windows : 
        my_windows = parse_windows(bed_windows)
    elif transcripts :
        my_windows = parse_transcripts(transcripts)
    else :
        print("Neither bed or transcripts provided...exiting")
        return 0
    
    if not p5nt : 
        p5nt = ['G', 'A']
    
    print(f'Calculating metagene coordinates for {min_len} to {max_len} RNAs starting with {", ".join(p5nt)}')

    if aggregate != 'density' and aggregate != 'five_prime_position' : 
        print('Aggregate must be set to either [density OR five_prime_position]')
        return 0 

    options = dict(
        windows = my_windows,
        scaling = scaling,
        sample = sample,
        p5nt = p5nt,
        combine_five_prime_nt = combine_five_prime_nt,
        aggregate = aggregate,
        missing_data_as_0 = missing_data_as_0,
        sum_coverage = sum_coverage,
    )

    if isinstance(bed_input, (list, tuple)) :
        files = list(bed_input)
        if not files : 
            print("No bed files provided...exiting")
            return 0
        # At most 16 workers, and never more workers than files.
        return parallel_cov(files,
                            calculate_metagene_coord,
                            n_cores = max(1, min(16, len(files))),
                            **options)

    # A single path used to fall through and return None; process it directly.
    return calculate_metagene_coord(bed_input, **options)

def pad_bedfile(bed) : 
    """Replace each interval by a window of 100 on either side of its ``end``.

    ``start`` becomes ``end - 100`` (0 wherever that would be negative) and
    ``end`` becomes ``end + 100``.

    Parameters
    ----------
    bed : pandas.DataFrame
        Must contain an ``end`` column. The frame is modified in place.

    Returns
    -------
    pandas.DataFrame
        The same object that was passed in.
    """
    # Vectorised form of the former row-by-row apply; it also works on an
    # empty frame. Anything that is not >= 0 is replaced by 0.
    shifted = bed['end'] - 100
    bed['start'] = shifted.where(shifted >= 0, 0)
    
    bed['end'] = bed['end'] + 100
    
    return bed

def reformat_bedfile(FILES, outdir) : 
    """Rewrite ``*.bed.tsv`` tables as six-column ``*.rpm`` files.

    For every input, fields 0, 1, 2, 3, 7 and 5 (in that order) of each data
    line are written tab-separated to ``<outdir>/<basename>`` with
    ``.bed.tsv`` replaced by ``.rpm``. Lines starting with ``gene`` and blank
    lines are skipped.
    NOTE(review): the ``gene`` prefix test presumably targets a header row; it
    would also drop data rows whose first field starts with "gene" — confirm.

    Parameters
    ----------
    FILES : iterable of str
        Input paths.
    outdir : str
        Output directory; created (including parents) when missing.

    Outputs that already exist in `outdir` are left untouched.
    """
    os.makedirs(outdir, exist_ok = True)
        
    for F in FILES : 
        name = os.path.basename(F).replace(".bed.tsv", ".rpm")
        outname = os.path.join(outdir, name)
        
        # Test the real destination, not the bare file name in the CWD.
        if os.path.exists(outname) :
            continue

        # Collect first, write afterwards: a malformed input line then fails
        # before any (partial) output file has been created.
        rows = []
        with open(F, 'r') as f : 
            for line in f : 
                if line.startswith("gene") or not line.strip() : 
                    continue
                info = line.strip().split("\t")
                rows.append(f"{info[0]}\t{info[1]}\t{info[2]}\t{info[3]}\t{info[7]}\t{info[5]}\n")

        with open(outname, 'w') as out : 
            out.writelines(rows)