In [37]:
import pyranges as pr
import pandas as pd
import numpy as np
import h5py 
import os

from cerberus.cerberus import *
from cerberus.main import *


In [27]:
def merge_ics(df, ic):
    """
    Look up which reference intron chain each transcript uses.

    Transcripts are left-joined to the reference on chromosome, strand,
    intron chain coordinates, and gene id, so a transcript with no matching
    reference entry is kept and gets missing values for `ic` and `ic_id`.

    Parameters:
        df (pandas DataFrame): One row per transcript. Must have the columns
            Chromosome, Strand, Coordinates, gene_id, and transcript_id
        ic (pandas DataFrame): Reference intron chains. Must have the columns
            Chromosome, Strand, Coordinates, gene_id, Name, and ic

    Returns:
        df (pandas DataFrame): DataFrame with the columns transcript_id, ic,
            and ic_id (the reference `Name` of the matched intron chain)
    """
    join_keys = ['Chromosome', 'Strand', 'Coordinates', 'gene_id']
    merged = df.merge(ic, how='left', on=join_keys)

    # the reference calls the intron chain identifier `Name`
    merged = merged.rename(columns={'Name': 'ic_id'})

    return merged[['transcript_id', 'ic', 'ic_id']]

In [80]:
def merge_ends(ends, ref, mode):
    """
    Merge ends from a GTF with those already annotated in a cerberus reference

    Each unique end is first matched to its nearest reference region without
    restricting by strand (`strandedness=None`). Ends whose nearest region
    belongs to a different gene are then re-matched, gene by gene, against
    only the reference regions of their own gene.

    Parameters:
        ends (pyranges PyRanges): PyRanges object of ends from GTF annotation.
            Must have gene_id and transcript_id columns
        ref (pyranges PyRanges): PyRanges object of reference ends from cerberus
            annotation. Must have Name, gene_id, and `mode` columns
        mode (str): {'tss', 'tes'}

    Returns:
        df (pandas DataFrame): DataFrame detailing which transcript uses which
            end from the cerberus reference. Columns are transcript_id,
            '<mode>_id' (the reference `Name`), and `mode`
    """
    coord_cols = ['Chromosome', 'Start', 'End', 'Strand', 'gene_id']

    # limit to relevant columns
    ends = ends[coord_cols + ['transcript_id']]

    # keep the per-transcript table for the final merge and search with the
    # deduplicated ends only, so each distinct end is looked up once
    t_ends = ends.df
    uniq_ends = t_ends.drop('transcript_id', axis=1).drop_duplicates()

    # find closest interval in ref
    nearest = pr.PyRanges(uniq_ends).nearest(ref, strandedness=None).df

    # split into ends that landed on their own gene and ends that need to be
    # re-matched because the closest region belongs to another gene
    same_gene = nearest.gene_id == nearest.gene_id_b
    matched = [nearest.loc[same_gene]]
    fix_ends = nearest.loc[~same_gene, coord_cols]

    # fix the ends with mismatching gene ids - this part can be slow :(
    # the reference table and the gene list do not change inside the loop,
    # so build them once up front
    ref_df = ref.df
    fix_gene_ids = fix_ends.gene_id.unique().tolist()
    n_genes = len(fix_gene_ids)

    for i, gid in enumerate(fix_gene_ids):
        gene_ends = fix_ends.loc[fix_ends.gene_id == gid].copy(deep=True)
        gene_refs = ref_df.loc[ref_df.gene_id == gid].copy(deep=True)

        # NOTE(review): a gene with no regions in `ref` gives an empty
        # gene_refs here; behavior of nearest() in that case is not handled
        gene_ends = pr.PyRanges(gene_ends).nearest(pr.PyRanges(gene_refs),
                                                   strandedness=None)
        matched.append(gene_ends.df)

        if i % 100 == 0:
            print('Processed {} / {} genes'.format(i, n_genes))

    # concatenate once rather than growing the table on every iteration
    ends = pd.concat(matched)

    # merge back in to get transcript ids
    t_ends = t_ends.merge(ends, how='left', on=coord_cols)

    # formatting
    end_id_col = '{}_id'.format(mode)
    t_ends = t_ends.rename(columns={'Name': end_id_col})
    t_ends = t_ends[['transcript_id', end_id_col, mode]]

    return t_ends

## convert transcriptome

In [83]:
def make_end_gtf_df(c,s,st,e,g,t):
    """
    Build a test table of transcript ends laid out like ends taken from a GTF.

    Parameters:
        c: Chromosome values
        s: Strand values
        st: Start values
        e: End values
        g: gene_id values
        t: transcript_id values

    Any argument that is not a plain list is skipped, so its column is left
    out of the table.

    Returns:
        df (pandas DataFrame): Table passed through format_end_df
    """
    columns = {'Chromosome': c,
               'Strand': s,
               'Start': st,
               'End': e,
               'gene_id': g,
               'transcript_id': t}

    df = pd.DataFrame()
    for name, values in columns.items():
        if type(values) is list:
            df[name] = values

    return format_end_df(df)

def make_end_df(c,s,st,e,n, source,mode, add_id=True):
    """
    Build a test table of reference end regions.

    Parameters:
        c: Chromosome values
        s: Strand values
        st: Start values
        e: End values
        n: Name values, written as '<gene id>_<end number>'
        source (str): Value written to the source column of every row
        mode (str): Name of the column that receives the end number
        add_id (bool): Whether to add an `id` column numbering the rows

    Any of c, s, st, e, n that is not a plain list is skipped, so its column
    is left out of the table.

    Returns:
        df (pandas DataFrame): Table passed through format_end_df, with
            gene_id and `mode` columns (and `id` when add_id is set)
    """
    columns = {'Chromosome': c,
               'Strand': s,
               'Start': st,
               'End': e,
               'Name': n}

    df = pd.DataFrame()
    for name, values in columns.items():
        if type(values) is list:
            df[name] = values

    # add source
    df['source'] = source

    df = format_end_df(df)

    # get end # and gene id; if any name is missing, leave both columns empty
    if df.Name.isnull().any():
        df['gene_id'] = np.nan
        df[mode] = np.nan
    else:
        name_parts = df.Name.str.split('_', expand=True)
        df['gene_id'] = name_parts[0]
        df[mode] = name_parts[1]

    # get arbitrary unique ids
    if add_id:
        df['id'] = list(range(len(df.index)))

    return df

def format_end_df(df):
    """
    Sort an end table and put its columns in a fixed order.

    Rows are sorted by Chromosome, Start, End, and Strand. Only columns from
    the preferred list below are kept, in that order; any other column is
    dropped. The index is renumbered from 0.

    Parameters:
        df (pandas DataFrame): Table with at least the four sort columns

    Returns:
        df (pandas DataFrame): New, sorted table; the input is not modified
    """
    preferred = ['Chromosome', 'Start', 'End', 'Strand', 'Name', 'gene_id', 'transcript_id', 'start', 'source']
    present = [col for col in preferred if col in df.columns]

    ordered = df.sort_values(by=['Chromosome', 'Start', 'End', 'Strand'])
    return ordered[present].reset_index(drop=True)

def test_merge_ends(print_dfs=True):
    """
    Check merge_ends against a small hand-built annotation and reference.

    The same data is run once for each mode ('tss' and 'tes'). Cases the
    author listed for this test:
    - ends overlap existing ends as well as gene id
    - ends do not overlap existing ends and we need to choose closest
    - ends overlap existing ends but gene ids don't match
    - ends overlap but are not on same strand
    - equidistant matches?

    Parameters:
        print_dfs (bool): Whether to print the result and expected tables,
            their indices, and their dtypes before comparing

    Raises:
        AssertionError: If the merge_ends result differs from the expected
            table
    """

    def by_transcript(df):
        # sort on transcript id and renumber rows so the two tables line up
        return df.sort_values(by='transcript_id').reset_index(drop=True)

    for mode in ['tss', 'tes']:

        # annot
        n_annot = 4
        annot_df = make_end_gtf_df(['1'] * n_annot,
                                   ['+'] * n_annot,
                                   [1, 200, 100, 300],
                                   [2, 201, 101, 301],
                                   ['gene1', 'gene1', 'gene2', 'gene3'],
                                   ['t1', 't2', 't3', 't4'])
        annot_df = pr.PyRanges(annot_df)

        # ref; the last region is on the minus strand
        n_ref = 5
        ref_strands = ['+'] * (n_ref - 1) + ['-']
        ref_names = ['gene1_1', # entry in annot intersects
                     'gene1_2', # entry in annot needs to choose closest
                     'gene2_1', # entry in annot matches but wrong gene
                     'gene3_1', #
                     'gene4_1'] # entry overlaps on the wrong strand
        ref_df = make_end_df(['1'] * n_ref,
                             ref_strands,
                             [1, 150, 90, 200, 290],
                             [40, 160, 98, 250, 340],
                             ref_names,
                             'test', mode, add_id=False)
        ref_df = pr.PyRanges(ref_df)

        # ctrl
        ctrl = pd.DataFrame()
        ctrl['transcript_id'] = ['t1', 't2', 't3', 't4']
        ctrl['{}_id'.format(mode)] = ['gene1_1', 'gene1_2', 'gene2_1', 'gene3_1']
        ctrl[mode] = ['1', '2', '1', '1']

        test = merge_ends(annot_df, ref_df, mode)

        ctrl = by_transcript(ctrl)
        test = by_transcript(test)

        if print_dfs:
            for label, frame in (('test', test), ('ctrl', ctrl)):
                print(label)
                print(frame)
                print(frame.index)
                print(frame.dtypes)

        pd.testing.assert_frame_equal(ctrl, test, check_like=True)

        assert len(ctrl.index) == len(test.index)


Processed 0 / 2 genes
  Chromosome  Start  End Strand gene_id transcript_id  Start_b  End_b  \
0          1      1    2      +   gene1            t1        1     40   
1          1    100  101      +   gene2            t3       90     98   
2          1    200  201      +   gene1            t2      150    160   
3          1    300  301      +   gene3            t4      200    250   

  Strand_b   tss_id source gene_id_b tss  Distance  
0        +  gene1_1   test     gene1   1         0  
1        +  gene2_1   test     gene2   1         3  
2        +  gene1_2   test     gene1   2        41  
3        +  gene3_1   test     gene3   1        51  

test
  transcript_id   tss_id tss
0            t1  gene1_1   1
1            t2  gene1_2   2
2            t3  gene2_1   1
3            t4  gene3_1   1
RangeIndex(start=0, stop=4, step=1)
transcript_id    object
tss_id           object
tss              object
dtype: object
ctrl
  transcript_id   tss_id tss
0            t1  gene1_1   1
1          

In [66]:
# NOTE(review): in the visible cells `annot_df` is only assigned inside the
# test functions, so this display presumably relies on leftover interactive
# state - confirm it is defined on a fresh kernel
annot_df

Unnamed: 0,Chromosome,Start,End,Strand,gene_id,transcript_id
0,1,1,2,+,gene1,t1
1,1,100,101,+,gene2,t3
2,1,200,201,+,gene1,t2
3,1,300,301,+,gene3,t4


In [67]:
# NOTE(review): in the visible cells `ref_df` is only assigned inside the
# test functions, so this display presumably relies on leftover interactive
# state - confirm it is defined on a fresh kernel
ref_df

Unnamed: 0,Chromosome,Start,End,Strand,Name,source,gene_id,tss
0,1,1,40,+,gene1_1,test,gene1,1
1,1,90,98,+,gene2_1,test,gene2,1
2,1,150,160,+,gene1_2,test,gene1,2
3,1,200,250,+,gene3_1,test,gene3,1
4,1,290,340,-,gene4_1,test,gene4,1


In [70]:
# NOTE(review): in the visible cells `test` is only assigned inside the
# test functions, so this display presumably relies on leftover interactive
# state - confirm it is defined on a fresh kernel
test

Unnamed: 0,transcript_id,tss_id,tss
0,t1,gene1_1,1
1,t3,gene2_1,1
2,t2,gene1_2,2
3,t4,gene3_1,1


In [71]:
# NOTE(review): in the visible cells `ctrl` is only assigned inside the
# test functions, so this display presumably relies on leftover interactive
# state - confirm it is defined on a fresh kernel
ctrl

Unnamed: 0,transcript_id,tss_id
0,t1,gene1_1
1,t2,gene1_2
2,t3,gene2_1
3,t4,gene3_1


In [35]:
def make_ics_gtf_df(c,s,co,g,t):
    """
    Build a test table of transcript intron chains laid out like those taken
    from a GTF.

    Parameters:
        c: Chromosome values
        s: Strand values
        co: Coordinates values (intron chain strings)
        g: gene_id values
        t: transcript_id values

    Any argument that is not a plain list is skipped, so its column is left
    out of the table.

    Returns:
        df (pandas DataFrame): Table passed through format_ics_df
    """
    columns = {'Chromosome': c,
               'Strand': s,
               'Coordinates': co,
               'gene_id': g,
               'transcript_id': t}

    df = pd.DataFrame()
    for name, values in columns.items():
        if type(values) is list:
            df[name] = values

    return format_ics_df(df)

def make_ics_map_df(t, ic, i):
    """
    Build the expected transcript-to-intron-chain table used as a control.

    Parameters:
        t: transcript_id values
        ic: ic values (intron chain number)
        i: ic_id values (intron chain identifier)

    Returns:
        df (pandas DataFrame): Table with the columns transcript_id, ic, and
            ic_id, in that order
    """
    df = pd.DataFrame()
    df['transcript_id'] = t
    df['ic'] = ic
    df['ic_id'] = i
    return df

def make_ics_df(c,s,co,n, source):
    """
    Build a test table of reference intron chains.

    Parameters:
        c: Chromosome values
        s: Strand values
        co: Coordinates values (intron chain strings)
        n: Name values, written as '<gene id>_<intron chain number>'
        source (str): Value written to the source column of every row

    Any of c, s, co, n that is not a plain list is skipped, so its column is
    left out of the table.

    Returns:
        df (pandas DataFrame): Table passed through format_ics_df, with
            gene_id, ic (as int), and source columns added
    """
    columns = {'Chromosome': c,
               'Strand': s,
               'Coordinates': co,
               'Name': n}

    df = pd.DataFrame()
    for name, values in columns.items():
        if type(values) is list:
            df[name] = values

    df = format_ics_df(df)

    # the name carries the gene id before the underscore and the intron
    # chain number after it
    name_parts = df.Name.str.split('_', expand=True)
    df['gene_id'] = name_parts[0]
    df['ic'] = name_parts[1]
    df['source'] = source
    df['ic'] = df['ic'].astype(int)

    return df

def format_ics_df(df):
    """
    Sort an intron chain table and put its columns in a fixed order.

    Rows are sorted by Chromosome, Strand, and Coordinates. Only columns from
    the preferred list below are kept, in that order; any other column is
    dropped. The index is renumbered from 0.

    Parameters:
        df (pandas DataFrame): Table with at least the three sort columns

    Returns:
        df (pandas DataFrame): New, sorted table; the input is not modified
    """
    preferred = ['Chromosome', 'Strand', 'Coordinates', 'Name', 'gene_id', 'transcript_id', 'ic', 'source']
    present = [col for col in preferred if col in df.columns]

    ordered = df.sort_values(by=['Chromosome', 'Strand', 'Coordinates'])
    return ordered[present].reset_index(drop=True)

# merge ics
def test_merge_ics(print_dfs=True):
    """
    Check merge_ics against a small hand-built annotation and reference.

    Two transcripts of gene1 share the intron chain '1-2-3'. The chain
    '4-5-6' is present in the reference for both gene1 and gene2, while the
    annotated transcript with that chain belongs to gene2.

    Parameters:
        print_dfs (bool): Whether to print the result and expected tables,
            their indices, and their dtypes before comparing

    Raises:
        AssertionError: If the merge_ics result differs from the expected
            table
    """
    n_rows = 3

    # ics to annotate
    annot_df = make_ics_gtf_df(['1'] * n_rows,
                               ['+'] * n_rows,
                               ['1-2-3', '1-2-3', '4-5-6'],
                               ['gene1', 'gene1', 'gene2'],
                               ['t1', 't2', 't3'])

    # ref ics
    ref_df = make_ics_df(['1'] * n_rows,
                         ['+'] * n_rows,
                         ['1-2-3', '4-5-6', '4-5-6'],
                         ['gene1_1', 'gene1_2', 'gene2_1'],
                         'test')

    # ctrl
    ctrl = make_ics_map_df(['t1', 't2', 't3'],
                           [1, 1, 1],
                           ['gene1_1', 'gene1_1', 'gene2_1'])

    test = merge_ics(annot_df, ref_df)

    if print_dfs:
        for label, frame in (('test', test), ('ctrl', ctrl)):
            print(label)
            print(frame)
            print(frame.index)
            print(frame.dtypes)

    pd.testing.assert_frame_equal(ctrl, test, check_like=True)
    assert len(ctrl.index) == len(test.index)


In [36]:
# run the intron chain merge test; with the default print_dfs=True it prints
# both tables and raises AssertionError if they differ
test_merge_ics()

test
  transcript_id  ic    ic_id
0            t1   1  gene1_1
1            t2   1  gene1_1
2            t3   1  gene2_1
Int64Index([0, 1, 2], dtype='int64')
transcript_id    object
ic                int64
ic_id            object
dtype: object
ctrl
  transcript_id  ic    ic_id
0            t1   1  gene1_1
1            t2   1  gene1_1
2            t3   1  gene2_1
RangeIndex(start=0, stop=3, step=1)
transcript_id    object
ic                int64
ic_id            object
dtype: object
