### Prep test filters
- Test filtering criteria to use for analysis cutoffs
- Ideally removing these low abundance genes (basically dropouts in some libraries) will improve inter-replicate correlation.
- Figure out what CPM corresponds to 10 counts as referenced by this post
https://support.bioconductor.org/p/69433/

##### Conclusions
- Choose genes with at least 10 counts in all libraries of one condition (input or pd), or, in the BG3 case, in all libraries of one input/pd, mutant/non-mutant condition
- For reproducibility scatter plot, it seems reasonable to further limit the comparison to genes with > 10 counts in the replicates being compared
- We lose many genes by applying a degradation-rate CV cutoff. To preserve N numbers for gene-group comparisons, it seems reasonable not to perform further filtering and to use all genes that pass the read-count filtering
- However, for the histogram in Fig. 1, where we are trying to estimate the range of decay and synthesis rates, it seems reasonable to limit this to strictly 'high confidence' rates, so as not to over- or underestimate the range

In [None]:
#Imports
import sys
import os
import pandas as pd
import seaborn as sns
import numpy as np
import math
import scipy.stats as stats
import gffutils

sys.path.append('../scripts')
from plot_helpers import *
from utilities import filter_low_exp, load_dataset, calc_pseudocount_val

# NOTE(review): `gffutils_db` is not defined in the visible code — presumably it is provided by
# `from plot_helpers import *`; confirm. `db` is not used by any of the visible cells.
db = gffutils.FeatureDB(gffutils_db)

# IPython autoreload (mode 2): re-import edited helper modules (e.g. ../scripts) automatically
%load_ext autoreload
%autoreload 2

#### Choice of filtering level for the experiments
- Set the filtering level to 10 counts in at least 2/3 of libraries in one condition/replicate set
- For most experimental sets, npass=2 because there were three replicates
- For the ph/mock BG3 data, npass=4 for input mock because there were 6 replicates
##### Another option would be to set filtering level to n counts in all libraries, and maybe that would solve the r2 dropout issue

In [None]:
# Get the genes passing the read-count filtering for each experiment
outdir = '../Figures/summary_files'
os.makedirs(outdir, exist_ok = True)

# Exp1 (Brain4sU)
res_file1 = os.path.join(results_dir, 'gene_quantification','summary_abundance_by_gene_filtered.csv')
# Examine the following filters:
# A) 10 counts in all libs of one condition
# B) 10 counts in all libs
# C) 1 CPM in all libs of one condition
# D) 1 CPM in all libs
# Note that 1 CPM is up to 2X more than 10 counts for the brain library, but counts are
# preferred because they are the raw signal.

# Per-RNA-type requirement for filters A and C; npass is presumably the number of libraries
# that must reach the cutoff (3 = all three replicates) — see filter_low_exp.
exp = [{'RNAtype':'input', 'condition':1, 'npass':3}, {'RNAtype':'pd', 'condition':1, 'npass':3}]

passed_genes_A = filter_low_exp(res_file1, filter_co=10, filter_col='summed_est_counts', experiments=exp,
                              outname=os.path.join(outdir,'brain4sU_10count_lib'))

passed_genes_B = filter_low_exp(res_file1, filter_col='summed_est_counts', filter_co=10, npass=6,
                              outname=os.path.join(outdir,'brain4sU_10count_all'))

passed_genes_C = filter_low_exp(res_file1, filter_co=1, filter_col='CPM', experiments=exp,
                              outname=os.path.join(outdir,'brain4sU_1cpm_lib'))

# Filter D is the "1 CPM in all libs" filter (see list above and the '1cpm_all' outname),
# so its cutoff must be 1 CPM, not 10.
passed_genes_D = filter_low_exp(res_file1, filter_col='CPM', filter_co=1, npass=6,
                              outname=os.path.join(outdir,'brain4sU_1cpm_all'))

print(f'co_10count_lib: {len(passed_genes_A)} passed the filter')
print(f'co_10count_all: {len(passed_genes_B)} passed the filter')
print(f'co_1cpm_lib: {len(passed_genes_C)} passed the filter')
print(f'co_1cpm_all: {len(passed_genes_D)} passed the filter')
# 1 CPM all causes a sharp decline from 8.7k genes for 10 counts per one condition to 3.7k genes for 1CPM all
# NOTE(review): that observation was made with filter D at filter_co=10 (i.e. 10 CPM); re-run to
# confirm how large the decline is at 1 CPM.

In [None]:
# Load the per-gene results for each filter, using each filter's *_passed.csv gene list
# (presumably load_dataset keeps only those genes — see utilities.load_dataset).
filter_tags = ('10count_lib', '10count_all', '1cpm_lib', '1cpm_all')
df_A, df_B, df_C, df_D = (
    load_dataset(res_file1, os.path.join(outdir, f'brain4sU_{tag}_passed.csv'))
    for tag in filter_tags
)

# Add CPM to the CPM-filtered datasets since we might need to filter further for the scatter plots.
# CPM = counts * 1e6 / total counts of the library (replicate, condition, RNAtype) within this df.
# NOTE(review): the library totals are summed over the rows of each loaded df, so if those rows are
# only the passed genes these CPMs may differ from the 'CPM' values filter_low_exp used — confirm.
read_col = 'summed_est_counts'
library_keys = ['replicate', 'condition', 'RNAtype']
for cpm_df in (df_C, df_D):
    library_totals = cpm_df.groupby(library_keys)[read_col].transform('sum')
    cpm_df['CPM'] = cpm_df[read_col]*1e6/library_totals


In [None]:
from plotting_fxns import *
def compare_plot(df, experiments=None, val_col=None, remove_low_count=None, ax=None, other_cols=None, low_count_val=-4):
    '''
    Wrapper to make a comparison df for two experiments and draw their scatter plot.

    Parameters
    ----------
    df : DataFrame; reset_index() is applied and genes are matched on the 'gene' column
        (presumably df is indexed by gene — confirm against callers).
    experiments : list of two selector dicts passed through to compare_experiments.
    val_col : column to compare; the plot shows f'{val_col}_x' against f'{val_col}_y'.
    remove_low_count : optional column name. If given, keep only rows where both
        f'{remove_low_count}_x' and f'{remove_low_count}_y' are > low_count_val. The column
        should also be listed in other_cols so it is carried into the comparison df.
    ax : matplotlib axes to draw on (passed to plot_scatter).
    other_cols : extra columns to carry into the comparison df; None means no extra columns.
    low_count_val : threshold used with remove_low_count.
        NOTE(review): compare_experiments is called with log=True and the default threshold is
        negative, which suggests the comparison may be made on log-transformed values; confirm
        before passing raw count/CPM thresholds.

    Returns
    -------
    (cdf, ax) : the (optionally filtered) comparison dataframe and the axes from plot_scatter.
    '''
    # Fresh list per call: avoids a shared mutable default and keeps any in-place change made
    # by compare_experiments from leaking back into the caller's list.
    other_cols = [] if other_cols is None else list(other_cols)
    cdf = compare_experiments(df.reset_index(), experiments=experiments, id_col='gene', val_col=val_col,
                              other_cols=other_cols, pseudo='min', log=True)

    if remove_low_count:
        cdf = cdf.query(f'{remove_low_count}_x > {low_count_val} & {remove_low_count}_y > {low_count_val}')
    ax = plot_scatter(cdf, experiments=[f'{val_col}_x', f'{val_col}_y'], id_col='gene', rsquare=True, ax=ax)
    return cdf, ax

In [None]:
# Make scatter plots (input rep1 vs rep2 TPM) for filters A-D without further filtering
fig = plt.figure(figsize=(dfig*2, dfig*2), constrained_layout=True)
ax1, ax2, ax3, ax4 = [fig.add_subplot(2, 2, pos) for pos in range(1, 5)]

panel_results = [
    compare_plot(
        data,
        experiments=[{'condition': 1, 'RNAtype': 'input', 'replicate': rep} for rep in (1, 2)],
        val_col='summed_tpm_recalc',
        ax=panel,
    )
    for data, panel in zip((df_A, df_B, df_C, df_D), (ax1, ax2, ax3, ax4))
]
(cdf1, ax1), (cdf2, ax2), (cdf3, ax3), (cdf4, ax4) = panel_results

axes = [ax1, ax2, ax3, ax4]
for a, letter in zip(axes, 'ABCD'):
    a.set_xlabel('rep1 TPM')
    a.set_ylabel('rep2 TPM')
    # Panel letter = which filter (A-D) produced the gene set
    a.text(0.1, 0.7, letter, transform=a.transAxes)



In [None]:
# Make scatter plots with further filtering, so each gene must reach the cutoff in both of the
# libraries being compared in that panel.
# (dataset, column used for the extra filter, cutoff) for panels A-D
panel_settings = [
    (df_A, 'summed_est_counts', 10),
    (df_B, 'summed_est_counts', 10),
    (df_C, 'CPM', 1),
    (df_D, 'CPM', 1),
]
fig = plt.figure(figsize=(dfig*2, dfig*2), constrained_layout=True)
ax1, ax2, ax3, ax4 = [fig.add_subplot(2, 2, pos) for pos in range(1, 5)]

panel_results = [
    compare_plot(
        data,
        experiments=[{'condition': 1, 'RNAtype': 'input', 'replicate': rep} for rep in (1, 2)],
        val_col='summed_tpm_recalc',
        ax=panel,
        remove_low_count=cutoff_col,
        other_cols=[cutoff_col],
        low_count_val=cutoff,
    )
    for (data, cutoff_col, cutoff), panel in zip(panel_settings, (ax1, ax2, ax3, ax4))
]
(cdf1, ax1), (cdf2, ax2), (cdf3, ax3), (cdf4, ax4) = panel_results

axes = [ax1, ax2, ax3, ax4]
for a, letter in zip(axes, 'ABCD'):
    a.set_xlabel('rep1 TPM')
    a.set_ylabel('rep2 TPM')
    a.text(0.1, 0.7, letter, transform=a.transAxes)

In [None]:
# Now look at the SD of the decay rate for the differently filtered gene sets.
# NOTE(review): rate_df is re-read from ../Figures/GEO_files/INSPEcT_rates.csv (index 'gene_ID')
# two cells below, which replaces this table; this read may be redundant.
inspect_rates_csv = '../Figures/summary_files/INSPEcT_rates.csv'
rate_df = pd.read_csv(inspect_rates_csv, index_col='gene')


In [None]:
# Gene IDs that passed each filter (first column of each headerless *_passed.csv)
passed_A, passed_B, passed_C, passed_D = (
    set(pd.read_csv(os.path.join(outdir, f'brain4sU_{tag}_passed.csv'), header=None)[0])
    for tag in ('10count_lib', '10count_all', '1cpm_lib', '1cpm_all')
)

In [None]:
geo_outdir = '../Figures/GEO_files/'
rate_df = pd.read_csv(os.path.join(geo_outdir, 'INSPEcT_rates.csv'), index_col='gene_ID')


def _rates_for(gene_set):
    '''Return a copy of the rate_df rows in gene_set with log10 decay rate and decay-rate CV (%) added.'''
    subset = rate_df.loc[rate_df.index.isin(gene_set)].copy()
    subset['log_deg_rate'] = subset['degradation_rate'].apply(np.log10)
    subset['deg_CV'] = subset['degradation_sd']*100/subset['degradation_rate']
    return subset


rate_df_A, rate_df_B, rate_df_C, rate_df_D = map(_rates_for, (passed_A, passed_B, passed_C, passed_D))

In [None]:
# Step histograms of log10 decay rate for each filtering level (A-D) on shared axes
fig = plt.figure(figsize=(dfig, dfig))
ax = fig.add_subplot(111)
for label, frame, colour in (('deg_A', rate_df_A, 'grey'),
                             ('deg_B', rate_df_B, 'blue'),
                             ('deg_C', rate_df_C, 'purple'),
                             ('deg_D', rate_df_D, 'green')):
    ax = sns.histplot(x='log_deg_rate', data=frame, label=label, element='step',
                      color=color_dict[colour], fill=False, ax=ax)
ax.set_xlim(-5, 0)
ax.legend()
# More stringent filtering makes many of the very low decay rate genes disappear.
# Low decay rate => present in the total but with very low counts in the PD, so these are
# likely genes that were not detected in the pulldown libraries.

In [None]:
# Step histograms of the decay-rate CV (%) for each filtering level (A-D) on shared axes
fig = plt.figure(figsize=(dfig, dfig))
ax = fig.add_subplot(111)
for label, frame, colour in (('deg_A', rate_df_A, 'grey'),
                             ('deg_B', rate_df_B, 'blue'),
                             ('deg_C', rate_df_C, 'purple'),
                             ('deg_D', rate_df_D, 'green')):
    ax = sns.histplot(x='deg_CV', data=frame, label=label, element='step',
                      color=color_dict[colour], fill=False, ax=ax)
ax.set_xlim(0, 200)
ax.legend()

In [None]:
# What is the relationship between CV and decay rate? One 2D histogram per filter (A-D).
fig = plt.figure(figsize=(dfig*2, dfig*2), constrained_layout=True)
ax1, ax2, ax3, ax4 = [fig.add_subplot(2, 2, pos) for pos in range(1, 5)]

panels = zip((ax1, ax2, ax3, ax4), (rate_df_A, rate_df_B, rate_df_C, rate_df_D))
ax1, ax2, ax3, ax4 = [
    sns.histplot(data=frame, x='log_deg_rate', y='deg_CV', cmap='rocket', ax=panel, zorder=2)
    for panel, frame in panels
]

# Same view window on every panel
for ax in (ax1, ax2, ax3, ax4):
    ax.set_ylim(0, 200)
    ax.set_xlim(-5, 0)

In [None]:
# Filtering implications for the SD of decay rates:
# median decay-rate CV per filtered set, and the fraction of that set with CV below cv_co.
cv_co = 25
dfs = {'A':rate_df_A, 'B':rate_df_B, 'C':rate_df_C, 'D':rate_df_D}

for d, frame in dfs.items():
    median_cv = frame["deg_CV"].median()
    n_below = len(frame.query("deg_CV < @cv_co"))
    print(f'{d} CV median {median_cv}')
    print(f'frac passed CV co {n_below/len(frame)}')


In [None]:
# How does the decay-rate histogram (filter A genes) change if we keep only genes under various
# CV cutoffs? Each legend entry shows the SD of log10(decay rate) for that subset.
cv_cos = [10, 25, 33, 50, 100]
fig = plt.figure(figsize=(dfig*2, dfig), constrained_layout=True)
ax = fig.add_subplot(111)

sigma_deg = f"{rate_df_A['log_deg_rate'].std():1.2f}"
ax = sns.histplot(x='log_deg_rate', data=rate_df_A, label=f'none ({sigma_deg})',
                  element='step', fill=False, ax=ax)
for co in cv_cos:
    # '@co' in the query refers to this loop variable
    this_df = rate_df_A.query('deg_CV <@co')
    sigma_deg = f"{this_df['log_deg_rate'].std():1.2f}"
    ax = sns.histplot(x='log_deg_rate', data=this_df, label=f'{co} ({sigma_deg})',
                      element='step', fill=False, ax=ax)

ax.set_ylabel('number of genes')
ax.legend()
ax.set_xlim(-5, 3)
plt.savefig(f"{os.path.join(outdir, 'cvco_deg_hist')}.{out_fmt}", dpi = out_dpi)
# Conclusion: doesn't change a ton and no cutoff actually looks pretty similar to 10, so going to leave as is

Now figure out the percentile change vs. the deg rate CV

In [None]:
# Stability percentile: half-life rank scaled to (0, 100]; the longest half-life gets the highest value
rate_df_A['stab_percentile'] = 100 * rate_df_A['halflife'].rank(pct=True)
# Sigmoid-fitting references:
#   https://stackoverflow.com/questions/55102473/how-do-we-fit-a-sigmoid-function-in-python
#   https://stackoverflow.com/questions/65233445/how-to-calculate-sums-in-log-space-without-underflow
from scipy.optimize import curve_fit

def sigmoid (x, A, h, slope, C):
    '''
    Four-parameter logistic curve: A / (1 + exp((x - h) / slope)) + C.

    For slope > 0 the curve falls from A + C (x -> -inf) to C (x -> +inf) and passes through
    A/2 + C at x = h. Works element-wise on scalars, numpy arrays and pandas Series.
    '''
    # 1 / (1 + exp(t)) == (1 - tanh(t / 2)) / 2. tanh saturates at +/-1 instead of overflowing,
    # so very large |x - h| / slope yields the asymptote without an overflow RuntimeWarning.
    t = (x - h) / slope
    return 0.5 * (1 - np.tanh(t / 2)) * A + C

# Fit a sigmoid mapping log10(decay rate) -> stability percentile.
x = rate_df_A['degradation_rate'].apply(np.log10)
y = rate_df_A['stab_percentile']
# curve_fit raises on NaN/inf (e.g. log10 of a zero or missing rate), so fit finite points only.
fit_mask = np.isfinite(x) & np.isfinite(y)
x_fit = x[fit_mask].to_numpy()
y_fit = y[fit_mask].to_numpy()
# Starting guess: a ~100-point drop (A=100, C=0) centred on the median log rate. curve_fit's
# default p0 (all ones) is far from this scale and may fail to converge.
p0 = [100, np.median(x_fit), 0.5, 0]
p, _ = curve_fit(sigmoid, x_fit, y_fit, p0=p0)

fig = plt.figure(figsize=(dfig,dfig))
ax = fig.add_subplot(111)

# Sorted numpy array (not a Python list) so sigmoid() gets numeric input and the fit line
# is drawn left to right.
sorted_x = np.sort(x_fit)
ax.scatter(x, rate_df_A['stab_percentile'], label='original')
ax.plot(sorted_x, sigmoid(sorted_x, *p), label='sigmoid fit', color=color_dict['purple'])
ax.legend()
ax.set_xlim(-5,0)
ax.set_xlabel('deg rate log'r'$_{10}$')
ax.set_ylabel('stability percentile')


In [None]:
# Now calculate how much the stability percentile would change for given CV amounts:
# map (rate - sd) and (rate + sd) through the fitted sigmoid and take the difference.
# NOTE(review): this reads 'stability_percentile', but the visible code only creates
# 'stab_percentile' — presumably INSPEcT_rates.csv supplies it; confirm they agree.
# NOTE(review): when deg_CV > 100, rate - sd is negative, so log_lower (and therefore
# lower_percentile and percentile_spread) is NaN for those genes.
test_df = rate_df_A[['degradation_rate', 'degradation_sd', 'deg_CV', 'stability_percentile']].copy()
rate, sd = test_df['degradation_rate'], test_df['degradation_sd']
test_df['lower_linear'] = rate - sd
test_df['upper_linear'] = rate + sd

# log10 of the estimate, the SD and both bounds (columns added in this order)
for log_col, linear_col in (('log_deg_rate', 'degradation_rate'),
                            ('log_deg_sd', 'degradation_sd'),
                            ('log_lower', 'lower_linear'),
                            ('log_upper', 'upper_linear')):
    test_df[log_col] = test_df[linear_col].apply(np.log10)


def _log_rate_to_percentile(log_rate):
    '''Stability percentile predicted by the fitted sigmoid for a single log10 decay rate.'''
    return sigmoid(log_rate, *p)


test_df['lower_percentile'] = test_df['log_lower'].apply(_log_rate_to_percentile)
test_df['upper_percentile'] = test_df['log_upper'].apply(_log_rate_to_percentile)
# For a decreasing fitted sigmoid the lower rate bound maps to the higher percentile
test_df['percentile_spread'] = test_df['lower_percentile'] - test_df['upper_percentile']

In [None]:
def find_passed_genes(df, geneset=None, cv_co=(10, 25, 33, 50, 100)):
    '''
    Return the fraction of a gene group that passes each decay-rate CV cutoff.

    Parameters
    ----------
    df : DataFrame indexed by gene with a 'deg_CV' column.
    geneset : list-like of gene IDs defining the group; None means every gene in df.
    cv_co : iterable of CV cutoffs. A gene passes cutoff c when deg_CV < c (strict), matching
        the "CV < cutoff" convention used by the other CV filters in this notebook.

    Returns
    -------
    DataFrame indexed by cutoff with a single column 0: the fraction of the group's genes that
    are present in df and pass that cutoff. Values are NaN if none of the group's genes are in df.
    '''
    if geneset is None:
        geneset = df.index
    group_n = len(df.index.intersection(geneset))
    fracs = {}
    for co in cv_co:
        # NaN CVs compare False, so genes without a CV never pass
        passed = df.index[df['deg_CV'] < co]
        # Guard against an empty group instead of raising ZeroDivisionError
        fracs[co] = len(passed.intersection(geneset)) / group_n if group_n else float('nan')
    return pd.DataFrame.from_dict(fracs, orient='index')

# Gene categories; 'me3' marks genes whose category is 'updowngene'
df = pd.read_csv('../Figures/Devreg/gene_cat_me3.csv', index_col='gene')
df['me3'] = df['category'] == 'updowngene'

# Gene IDs for each group, in the column order used for frac_df
group_index = {
    'TF': df.query('TF').index,
    'me3': df.query('me3').index,
    'CTS': df.query('CTS').index,
    'all': df.index,
}
tf_df, me_df, cts_df, all_df = (
    find_passed_genes(test_df, geneset=genes) for genes in group_index.values()
)

# One column per group, one row per CV cutoff
frac_df = pd.concat([tf_df, me_df, cts_df, all_df], axis=1)
frac_df.columns = list(group_index)

In [None]:
# Display the fraction of each gene group (columns) passing each CV cutoff (rows)
frac_df

In [None]:
# Flag genes whose percentile spread exceeds their decay-rate CV.
# NOTE(review): NaN spreads (e.g. genes with CV > 100) compare as False, so they are never flagged.
spread_exceeds_cv = test_df['percentile_spread'].gt(test_df['deg_CV'])
test_df['percentile_spread_higher?'] = spread_exceeds_cv
# Percentile spread is not higher than CV in any case:
test_df['percentile_spread_higher?'].any()

##### So what CV cutoff should we use?
- Conservative: CV <10, gene would go from 10th percentile to 20th percentile
- Moderate: CV < 25, gene would go from 10th percentile to 35th percentile
- Liberal: CV < 33, gene would go from 10th percentile to 43rd percentile
- How many TFs will pass at each threshold?

In [None]:
fig = plt.figure(figsize=(dfig*2, dfig))
ax1, ax2 = (fig.add_subplot(1, 2, pos) for pos in (1, 2))

# Left: raw scatter of percentile spread against decay-rate CV
ax1.scatter(test_df['deg_CV'], test_df['percentile_spread'], color='k', alpha=0.3, s=1)
ax1.set_xlabel('deg CV')
ax1.set_ylabel('percentile spread')
# Right: 2D histogram of the same data
ax2 = sns.histplot(data=test_df, x='deg_CV', y='percentile_spread', cmap='rocket', ax=ax2, zorder=2)

This shows that the CV is about equal to the percentile spread on average, but some genes have a smaller percentile spread.
These are probably genes closer to the edges of the distribution, where larger quantitative changes have less effect on the percentile (the sigmoid is flatter and less linear there).

In [None]:
# What about stability percentile vs. decay-rate CV? (only the left panel of the 1x2 grid is used)
fig = plt.figure(figsize=(dfig*2, dfig))
ax1 = fig.add_subplot(1, 2, 1)
ax1.scatter(x=test_df['deg_CV'], y=test_df['stability_percentile'], color='k', alpha=0.3, s=1)
ax1.set_xlim(0, 500)