## Goal -- how does our precision, sensitivity, and accuracy get impacted when applying the PODER filters? 

In [1]:
import pandas as pd
import numpy as np
import scipy.stats as st
import seaborn as sns
import sys
import os
import matplotlib.pyplot as plt
import yaml
from snakemake.io import expand
import pyranges as pr
from pyfaidx import Fasta
from mizani.formatters import percent_format
from scipy import stats


p = os.path.dirname(os.getcwd())
sys.path.append(p)

from scripts.utils import *
from scripts.vcf_utils import *
from scripts.plotting import *

from plotnine import *

In [2]:
def my_theme(base_size=11, w=4, h=3):
    """
    Custom plotnine theme with:
    - White background
    - Clean styling
    - Axes and ticks retained

    Parameters:
    - base_size: Base font size

    Returns:
    - plotnine.theme object
    """
    return (
        theme_minimal(base_size=base_size)
        + theme(
            # White background
            panel_background=element_rect(fill='white', color=None),
            plot_background=element_rect(fill='white', color=None),

            # Remove grid lines
            panel_grid_major=element_blank(),
            panel_grid_minor=element_blank(),
            panel_border=element_blank(),

            # Keep axis lines & ticks (don't blank them)
            axis_line=element_line(color='black'),
            axis_ticks=element_line(color='black'),

            plot_title=element_text(hjust=0.5, family='Helvetica'),
            axis_title_x=element_text(hjust=0.5, family='Helvetica'),
            axis_title_y=element_text(hjust=0.5, margin={'t':0, 'r':-2, 'b':0, 'l':0}, family='Helvetica'),
            
            # Styling text
            legend_title=element_blank(),
            axis_title=element_text(size=base_size + 1, family='Helvetica'),
            legend_text=element_text(size=base_size-2, family='Helvetica'),
            axis_text=element_text(size=base_size, color='black', family='Helvetica'),
            figure_size=(w, h),  # Controls plot dimensions (width x height in inches)
            plot_margin=0.05      # Shrinks surrounding white space
        )
    )

def clean_figure(ax):
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.tick_params(axis="x", rotation=45)

In [4]:
config = load_config()
od = '../'
meta_df = load_meta()
meta_df['lab_sample'] = meta_df['lab_number_sample'].astype(str)+'_'+\
                        meta_df['lab_sampleid']+'_'+\
                        meta_df['cell_line_id']


def proc_cfg(entry, od):
    entry = entry.replace('../../', '')
    entry = od+entry
    return entry

init_plot_settings()

In [5]:
## SQANTI reads stuff
f = proc_cfg(expand(config['lr']['qc_sirvs']['sqanti_reads']['class_summary'],
                    annot_completeness='C')[0], od)
df = pd.read_csv(f, sep='\t')

# rename structural categories
m = {'antisense': 'Antisense',
     'full-splice_match': 'FSM',
     'genic': 'Genic', 
     'incomplete-splice_match': 'ISM',
     'intergenic': 'Intergenic',
     'novel_in_catalog': 'NIC',
     'novel_not_in_catalog': 'NNC'}
df['structural_category'] = df.structural_category.map(m)

# add antisense gene ids as assc_gene_2
df['assc_gene_2'] = df.associated_gene
inds = df.loc[df.structural_category=='Antisense'].index
df.loc[inds, 'assc_gene_2'] = df.loc[inds, 'assc_gene_2'].str.split('_', expand=True)[1]

# add relevant metadata
df = df.merge(meta_df[['lab_sample', 'population']], 
              how='left', on='lab_sample')

# add # mapped reads 
df['n_mapped_reads'] = df.groupby('lab_sample')['isoform'].transform('count')

# add spike type
df['spike_type'] = np.nan
df.loc[df.chrom.str.contains('ERCC'), 'spike_type'] = 'ERCC'
df.loc[df.chrom.str.contains('SIRV'), 'spike_type'] = 'SIRV'

spliced_sirvs = ['SIRV1', 'SIRV2', 'SIRV3', 'SIRV4', 'SIRV5', 'SIRV6', 'SIRV7']
df.loc[df.chrom.isin(spliced_sirvs), 'spike_type_2'] = 'spliced_sirv'

# add # spliced sirv reads
df['n_spliced_sirv_reads'] = df.loc[df.spike_type_2=='spliced_sirv'].groupby('lab_sample')['isoform'].transform('count')

# splicing novelty
df['splicing_novelty'] = 'Novel'
df.loc[df.structural_category.isin(['FSM', 'ISM']), 'splicing_novelty'] = 'Known'

# overal known vs. novel
df['overall_nov'] = 'Novel'
df.loc[df.structural_category == 'FSM', 'overall_nov'] = 'Known'

## Ground truth sirv / ercc gtf
gtf_df = pr.read_gtf('../data/qc_sirvs/SIRV_ERCC_longSIRV_multi-fasta_20210507.gtf').df

temp = gtf_df.loc[gtf_df.Chromosome.isin(spliced_sirvs)]
n_spliced_sirvs_tot = len(temp.transcript_id.unique())

# there are 69 spliced sirvs as ground truth

# get the transcript length of each sirv transcript
gtf_df['exon_len'] = gtf_df['End'] - gtf_df['Start']
gtf_df['transcript_len'] = gtf_df.groupby('transcript_id')['exon_len'].transform('sum')

gtf_df['mean_transcript_len'] = gtf_df.groupby('gene_id')['transcript_len'].transform('mean')
gtf_df['med_transcript_len'] = gtf_df.groupby('gene_id')['transcript_len'].transform('median')

# also get number of exons
gtf_df['n_exons'] = gtf_df.groupby('transcript_id')['exon_assignment'].transform('count')
gtf_df['mean_n_exons'] = gtf_df.groupby('gene_id')['n_exons'].transform('mean')
gtf_df['med_n_exons'] = gtf_df.groupby('gene_id')['n_exons'].transform('median')

# also get number of transcripts
gtf_df['n_transcripts'] = gtf_df.groupby('gene_id')['transcript_id'].transform('nunique')

# get # monoexonic transcripts
temp = gtf_df[['gene_id', 'transcript_id', 'n_exons']].groupby(['gene_id', 'n_exons']).nunique().reset_index()
temp = temp.loc[temp.n_exons==1]
temp.drop('n_exons', axis=1, inplace=True)
temp.rename({'transcript_id': 'n_monoexonic_transcripts'}, axis=1, inplace=True)
temp.loc[temp.gene_id=='SIRV6']
gtf_df = gtf_df.merge(temp,
                      how='left',
                      on='gene_id')

# sort the long sirvs at least by length
sirv_order = ['1', '2', '3', '4', '5', '6', '7',
              '4001', '4002', '4003',
              '6001', '6002', '6003', 
              '8001', '8002', '8003',
              '10001', '10002', '10003',
              '12001', '12002', '12003']
sirv_order = [f'SIRV{s}' for s in sirv_order]



## Apply filters one by one and create a label to say whether the transcript passed each filter 
* FSM? (pass_fsm_filter)
* Monoexonic? (pass_monoexonic_filter)
* Reproducibile for novel transcripts? (pass_reproducibility_filter)
* Promoted ISM? (promoted_ism_filter)
* ISM? (pass_ism_filter)

In [11]:
# filter 1 -- novel transcripts must be present in 2+ samples
temp['sample_sharing'] = temp.groupby('jxnHash')['lab_sample'].transform('nunique')
temp.loc[temp.sample_sharing>=2, 'filter_pass'] = True

# filter 2 - for ISMs, if associated_transcript is not in original collection, add it and pass it

# first, mark all ISMs as non-passing
temp.loc[temp.structural_category=='ISM', 'filter_pass'] = False

# get a random fsm entry to add for each uniq one
fsm_temp = temp.loc[(temp.structural_category=='FSM')].drop_duplicates(subset='associated_transcript', keep='first')
fsm_temp['lab_sample'] = np.nan
assert len(fsm_temp.index) == n_spliced_sirvs_tot

# now loop through lab samples and tack on undet. entries FSMs
# for all ISMs
add_df = pd.DataFrame()
for s in temp.lab_sample.unique():
    undet_fsms_from_isms = temp.loc[(temp.lab_sample==s)&\
                                    (temp.structural_category=='ISM')]
    
    temp2 = temp.loc[temp.lab_sample==s]

    # these are the ids of transcripts that are detected via isms
    ism_fsm_ids = temp2.loc[(temp2.structural_category=='ISM'),
                           'associated_transcript'].unique().tolist()
    
    # these are the ids of transcripts that are detected via fsms
    fsm_ids = temp2.loc[(temp2.structural_category=='FSM'),
                        'associated_transcript'].unique().tolist()
    
    # the set difference will be the ones we need to "promote"
    promote_fsm_ids = list(set(ism_fsm_ids)-set(fsm_ids))
    promote_fsm_ids
    
    add_df_temp = fsm_temp.loc[fsm_temp.associated_transcript.isin(promote_fsm_ids)]
    add_df_temp['lab_sample'] = s
    add_df_temp['promoted_ism'] = True
    
    add_df = pd.concat([add_df, add_df_temp], axis=0)

temp = pd.concat([temp, add_df], axis=0)

# filter 0 -- all FSMs pass filtering
temp.loc[temp.structural_category == 'FSM', 'filter_pass'] = True

# filter 3 -- all monoexonics fail filtering 
temp.loc[temp.exons == 1, 'filter_pass'] = False

In [12]:
# now reduce to just uniq junction chains
keep_cols = ['structural_category', 'associated_transcript', 'exons','jxnHash', 'assc_gene_2',
             'lab_sample', 'population',
             'n_mapped_reads', 'spike_type', 'spike_type_2', 'n_spliced_sirv_reads',
             'splicing_novelty', 'overall_nov', 'filter_pass', 'sample_sharing', 'promoted_ism']
temp = temp[keep_cols].drop_duplicates()

In [15]:
print(len(temp.index))
print(len(temp[['jxnHash', 'filter_pass', 'lab_sample']].drop_duplicates()))

9170
8990


In [21]:
# temp2 = temp.loc[temp.duplicated(subset=['jxnHash', 'filter_pass', 'lab_sample'], keep=False)].sort_values(by=['jxnHash'])
# temp2.loc[temp2.structural_category == 'ISM'].head()

Unnamed: 0,structural_category,associated_transcript,exons,jxnHash,assc_gene_2,lab_sample,population,n_mapped_reads,spike_type,spike_type_2,n_spliced_sirv_reads,splicing_novelty,overall_nov,filter_pass,sample_sharing,promoted_ism
209354,ISM,SIRV610,2,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...,SIRV6,20_KE2_GM19328,LWK,6487,SIRV,spliced_sirv,3237.0,Known,Novel,False,37,
395155,ISM,SIRV615,2,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...,SIRV6,4_PY4_GM10495,MPC,29916,SIRV,spliced_sirv,11068.0,Known,Novel,False,37,
274481,ISM,SIRV610,2,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...,SIRV6,15_CH3_GM18631,HAC,7671,SIRV,spliced_sirv,3573.0,Known,Novel,False,37,
393335,ISM,SIRV610,2,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...,SIRV6,4_PY4_GM10495,MPC,29916,SIRV,spliced_sirv,11068.0,Known,Novel,False,37,
176301,ISM,SIRV615,2,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...,SIRV6,24_IS1_GM22234,AJI,7512,SIRV,spliced_sirv,3333.0,Known,Novel,False,37,


In [25]:
temp['ic_id'] = temp.jxnHash
temp.loc[temp.structural_category=='FSM', 'ic_id'] = 'associated_transcript'

print(len(temp.index))
print(len(temp[['ic_id', 'filter_pass', 'lab_sample']].drop_duplicates()))

9170
6348


In [24]:
temp2 = temp.loc[temp.duplicated(subset=['ic_id', 'filter_pass', 'lab_sample'], keep=False)].sort_values(by=['ic_id'])
temp2.loc[temp2.structural_category == 'ISM'].head()

Unnamed: 0,structural_category,associated_transcript,exons,jxnHash,assc_gene_2,lab_sample,population,n_mapped_reads,spike_type,spike_type_2,n_spliced_sirv_reads,splicing_novelty,overall_nov,filter_pass,sample_sharing,promoted_ism,ic_id
233872,ISM,SIRV615,2,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...,SIRV6,11_NI6_GM19129,YRI,8911,SIRV,spliced_sirv,3983.0,Known,Novel,False,37,,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...
330385,ISM,SIRV615,2,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...,SIRV6,43_EU3_GM12812,CEU,5457,SIRV,spliced_sirv,2897.0,Known,Novel,False,37,,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...
356911,ISM,SIRV615,2,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...,SIRV6,42_EU2_GM12778,CEU,30767,SIRV,spliced_sirv,11012.0,Known,Novel,False,37,,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...
234565,ISM,SIRV610,2,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...,SIRV6,11_NI6_GM19129,YRI,8911,SIRV,spliced_sirv,3983.0,Known,Novel,False,37,,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...
316868,ISM,SIRV610,2,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...,SIRV6,45_EU5_GM12878,CEU,15664,SIRV,spliced_sirv,6512.0,Known,Novel,False,37,,006f2d0edb78027e4384277f846a34c50f44f014fa5eec...
