# TF Complex Screening

Covariance screening process.

### Import

In [1]:
from IPython.core.display import display, HTML
import warnings
warnings.filterwarnings('ignore')
display(HTML("<style>.container { width:100% !important; }</style>"))
%matplotlib inline

In [2]:
repo_path = '/Users/mincheolkim/Github/'
data_path = '/Users/mincheolkim/Documents/'

In [3]:
import sys
sys.path.append(repo_path + 'scVI')
sys.path.append(repo_path + 'scVI-extensions')

In [4]:
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.stats.multitest import multipletests
import imp
import numpy as np
import pandas as pd
from scipy.stats import pearsonr

In [5]:
import matplotlib.pyplot as plt
%matplotlib inline

In [6]:
import scvi_extensions.dataset.supervised_data_loader as sdl
import scvi_extensions.dataset.cropseq as cs
import scvi_extensions.inference.supervised_variational_inference as svi
import scvi_extensions.hypothesis_testing.mean as mn
import scvi_extensions.hypothesis_testing.variance as vr
import scvi_extensions.dataset.label_data_loader as ldl

In [7]:
import scanpy.api as sc
import itertools

### Read into scanpy

In [8]:
embedded_adata = sc.read('/Users/mincheolkim/Documents/nsnp20.raw.sng.km_vb1_default.pc60.norm.h5ad')

In [9]:
embedded_adata.obs['ko_gene_cov'] = embedded_adata.obs['guide_cov'].str.extract(r'^([^.]*).*')

### Transcription factors that survived the vargenes cut

In [192]:
tfs_considered = list(set(embedded_adata.var.index.tolist()) & set(embedded_adata.obs.ko_gene_cov.tolist()))

In [193]:
print(tfs_considered)                                                                                                                                                      

['NCOA3', 'ARID5B', 'GTF2I', 'IRF1', 'ID2', 'FOXP1', 'NCOA4', 'ZFP36L1', 'JUNB', 'XBP1', 'FLI1', 'STAT1', 'IRF2', 'STAT3', 'KLF6', 'MYC', 'SATB1', 'HOPX']


In [204]:
np.random.choice(embedded_adata.obs.ko_gene_cov.value_counts().index)

'DDB2'

### Correlation testing for each of these TFs

For each TF, find genes that correlate with it.

In [143]:
f
tfs_considered = ['IRF1', 'STAT1', 'STAT3', 'JUNB', 'IRF2']

In [145]:
for tf in tfs_considered:
    print(tf)

IRF1
STAT1
STAT3
JUNB
IRF2


In [146]:
hits = {}
for tf in tfs_considered:
    corrs = []
    pvals = []
    for gene in embedded_adata.var.index.tolist():
        corr, pval = pearsonr(
            x=embedded_adata.X[(embedded_adata.obs.guide_cov == "0").values, embedded_adata.var.index.tolist().index(tf)],
            y=embedded_adata.X[(embedded_adata.obs.guide_cov == "0").values, embedded_adata.var.index.tolist().index(gene)])
        corrs.append(corr)
        pvals.append(pval)
    reject, pvals_corrected, _, _ = multipletests(pvals, method='bonferroni')
    hits[tf] = embedded_adata.var.index[reject & (np.array(corrs) > 0)].tolist()

### For each TF pair, find the genes that they both are correlated with

In [147]:
tf_gene_overlaps = {}
tf_gene_overlap_counts = {}

In [148]:
for tf1, tf2 in itertools.combinations(hits.keys(), 2):
    overlap = set(hits[tf1]) & set(hits[tf2])4rbnm,ec\
    if len(overlap) > 0:
        tf_gene_overlaps[tf1 + '_' + tf2] = overlap
        tf_gene_overlap_counts[tf1 + '_' + tf2] = len(overlap)

In [149]:
tf_gene_overlap_counts

{'IRF1_STAT1': 71,
 'IRF1_STAT3': 46,
 'IRF1_JUNB': 47,
 'IRF1_IRF2': 10,
 'STAT1_STAT3': 19,
 'STAT1_JUNB': 13,
 'STAT1_IRF2': 19,
 'STAT3_JUNB': 35,
 'STAT3_IRF2': 2}

### For each TF pair, perform differential correlation analysis with target gene

In [150]:
def differential_correlation(g1, g2, label_1, label_2, label_col='ko_gene_cov'):
    corr_1, pval_1 = correlation(g1, g2, label_1, label_col)
    corr_2, pval_2 = correlation(g1, g2, label_2, label_col)
    
    n_1 = (embedded_adata.obs[label_col] == label_1).sum()
    n_2 = (embedded_adata.obs[label_col] == label_2).sum()
    
    return (np.arctanh(corr_1) - np.arctanh(corr_2))/(np.sqrt(np.absolute((1/n_1) - (1/n_2))))

In [151]:
def correlation(g1, g2, label, label_col='ko_gene_cov'):
    return pearsonr(
        x=embedded_adata.X[(embedded_adata.obs[label_col] == label).values, embedded_adata.var.index.tolist().index(g1)],
        y=embedded_adata.X[(embedded_adata.obs[label_col] == label).values, embedded_adata.var.index.tolist().index(g2)])

In [152]:
all_stats = {}
for pair, genes in tf_gene_overlaps.items():
    genes = tf_gene_overlaps[pair]
    tf1, tf2 = pair.split('_')
    all_stats[pair] = {}
    for gene in genes:
        if gene == tf1 or gene == tf2:
            continue
        all_stats[pair][gene] = (
            differential_correlation(gene, tf2, '0', tf1),
            differential_correlation(gene, tf1, '0', tf2))

### Generate null set of test statistics by shuffling labels

In [153]:
embedded_adata.obs['shuffled_ko_gene_cov'] = embedded_adata.obs['ko_gene_cov'].values[np.random.permutation(len(embedded_adata.obs))]

In [217]:
null_stats = []
for pair, genes in tf_gene_overlaps.items():
    genes = tf_gene_overlaps[pair]
    tf1, tf2 = pair.split('_')
    for gene in genes:
        null_stats.append(differential_correlation(gene, tf2, '0', tf1, label_col='shuffled_ko_gene_cov'))
        null_stats.append(differential_correlation(gene, tf1, '0', tf2, label_col='shuffled_ko_gene_cov'))
null_stats = np.array(null_stats)

In [224]:
final_pval_cutoff = 0.2

In [225]:
cutoff_stat = np.sort(null_stats)[-1*int(np.round(len(null_stats)*final_pval_cutoff))]

In [226]:
cutoff_stat

0.9925583274065771

### For each TF pair, keep the genes that are regulated by their complex

In [227]:
def emp_pval(val, null_vals):
    return (null_stats > val).mean()

In [228]:
complex_regulators = {}
for pair, gene_stats in all_stats.items():
    for gene, stats in gene_stats.items():
        s1, s2 = stats
        if s1 < 0 or s2 < 0:
            continue
        pval_1 = emp_pval(s1, null_stats)
        pval_2 = emp_pval(s2, null_stats)
        if pval_1 < final_pval_cutoff and pval_2 < final_pval_cutoff:
            if pair not in complex_regulators.keys():
                complex_regulators[pair] = []
            complex_regulators[pair].append(gene)

In [229]:
complex_regulators

{'IRF1_STAT1': ['RPS27', 'FHIT', 'RPL27A'],
 'IRF1_STAT3': ['CD27',
  'CD7',
  'RPL10',
  'MIR142',
  'HLA-A',
  'S100A11',
  'RPS26',
  'RPL32',
  'RPL18A',
  'RPS14',
  'S100A4',
  'RPL27A',
  'RPS29'],
 'IRF1_JUNB': ['EEF2', 'GNB2L1', 'RPL10', 'RPL32', 'RPS14', 'CORO1B', 'BIN1'],
 'STAT1_STAT3': ['HLA-A', 'RPL11'],
 'STAT1_JUNB': ['RPL32', 'STAT3'],
 'STAT1_IRF2': ['TNFSF13B'],
 'STAT3_JUNB': ['RPL10']}