# Test perturbation effects in donor subsets

In [1]:
import scanpy as sc
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr

import os,sys
import anndata as ad
import pandas as pd
import mudata as md
import glob
from tqdm import tqdm
from itertools import combinations
import itertools
import re

from copy import deepcopy

import time

# Add the parent directory to the path to import from sibling directory
sys.path.append(os.path.abspath('../'))
sys.path.append(os.path.abspath('../../'))
from MultiStatePerturbSeqDataset import *
from merge_DE_results import parse_DE_results_2_adata

In [2]:
def get_reliability(x, sigma):
    """Estimate the reliability of a vector of noisy estimates.

    The observed variance of ``x`` (sample variance, ``ddof=1``) is treated as
    the sum of a true-signal variance and the mean squared standard error.
    Reliability is the share of that total attributable to true signal.

    Parameters
    ----------
    x : array-like
        Point estimates across features (presumably log fold changes, judging
        by the callers; any numeric vector works).
    sigma : array-like
        Standard errors matching ``x``; must support ``sigma**2``.

    Returns
    -------
    float
        ``true / (true + sem)``, always within [0, 1]. NaN when both the
        observed variance and the mean squared error are zero, or when the
        inputs yield NaN.
    """
    sigma2_obs = np.var(x, ddof=1)  # Observed variance across features
    sigma2_sem = np.mean(sigma**2)   # Average squared standard error
    # With noisy data the difference can fall below zero. A negative variance
    # is meaningless and would produce negative reliabilities (two of which
    # multiply to a spuriously positive correlation ceiling), so clamp at 0.
    # np.maximum propagates NaN, unlike the builtin max.
    sigma2_true = np.maximum(sigma2_obs - sigma2_sem, 0.0)
    sigma2_total = sigma2_true + sigma2_sem
    if sigma2_total == 0:
        # Constant x with zero standard errors: reliability is undefined.
        return np.nan
    return sigma2_true / sigma2_total

def get_max_correlation(x_a, sigma_a, x_b, sigma_b):
    """Return the correlation ceiling implied by measurement noise.

    The ceiling is the square root of the product of the two reliabilities
    computed by ``get_reliability`` for (x_a, sigma_a) and (x_b, sigma_b).
    """
    reliabilities = [
        get_reliability(values, errors)
        for values, errors in ((x_a, sigma_a), (x_b, sigma_b))
    ]
    return np.sqrt(reliabilities[0] * reliabilities[1])

def get_lfc_correlation(x_a, sigma_a, x_b, sigma_b):
    """Correlate two estimate vectors and report the attainable ceiling.

    ``x_a`` must expose a ``.corr`` method (e.g. a pandas Series).

    Returns
    -------
    tuple
        ``(x_a.corr(x_b), get_max_correlation(x_a, sigma_a, x_b, sigma_b))``.
    """
    observed = x_a.corr(x_b)
    ceiling = get_max_correlation(x_a, sigma_a, x_b, sigma_b)
    return observed, ceiling

def _run_DE_test(pbulk_adata, test_state = 'Rest'):
    """Run per-target differential expression on a pseudobulk AnnData subset.

    Parameters
    ----------
    pbulk_adata : AnnData
        Pseudobulk data whose ``obs`` provides ``n_cells`` and ``donor_id``
        plus the columns handed to ``MultistatePerturbSeqDataset`` below
        (``cell_sample_id``, ``perturbed_gene_id``, ``guide_id``,
        ``culture_condition``). The object is copied, so the caller's data is
        left untouched.
    test_state : str
        State passed (wrapped in a list) as ``test_state`` to ``run_target_DE``.

    Returns
    -------
    pandas.DataFrame
        The ``run_target_DE`` output with ``contrast`` renamed to ``target``,
        merged with the summed ``n_cells`` per target, plus the columns
        ``n_donors``, ``donors`` (sorted donor IDs joined by ``_``) and
        ``signif`` (``adj_p_value < 0.1``).
    """
    # Work on a private copy: callers pass AnnData subsets (views), and the
    # obs column added below would otherwise be written through to (or force
    # an implicit copy of) the caller's object.
    pbulk_adata = pbulk_adata.copy()
    pbulk_adata.obs['log10_n_cells'] = np.log10(pbulk_adata.obs['n_cells'])

    n_donors = pbulk_adata.obs['donor_id'].nunique()
    # The donor covariate is only included when more than one donor is present.
    if n_donors > 1:
        design_formula = '~ log10_n_cells + donor_id + target'
    else:
        design_formula = '~ log10_n_cells + target'

    min_counts_per_gene = 10

    ms_perturb_data = MultistatePerturbSeqDataset(
        pbulk_adata,
        sample_cols = ['cell_sample_id'],
        perturbation_type = 'CRISPRi',
        target_col = 'perturbed_gene_id',
        sgrna_col = 'guide_id',
        state_col = 'culture_condition',
        control_level = 'NTC'
        )

    results = ms_perturb_data.run_target_DE(
        design_formula = design_formula,
        test_state = [test_state],
        min_counts_per_gene = min_counts_per_gene,
        return_model = False,
        n_cpus=8
        )

    n_cells_target = ms_perturb_data.adata.obs.groupby('target')['n_cells'].sum().reset_index()
    results = pd.merge(results.rename({'contrast':'target'}, axis=1), n_cells_target)
    results['n_donors'] = n_donors
    # Sort the donor IDs so a given donor set always receives the same label.
    # Joining in order of appearance made the label depend on row order, so
    # the same pair could show up as both 'A_B' and 'B_A' across runs.
    donor_ids = sorted(pbulk_adata.obs['donor_id'].astype(str).unique())
    results['donors'] = '_'.join(donor_ids)
    results['signif'] = results['adj_p_value'] < 0.1
    return(results)

In [3]:
# Experiment identifier; interpolated into the input/output paths used in this notebook.
experiment_name = 'CD4i_final'
# Root of the processed-data tree, relative to the notebook's working directory.
datadir = '../../../../../3_expts/processed_data'

In [4]:
# Curated sgRNA library annotation table.
sgrna_library_metadata = pd.read_csv('../../../metadata/sgRNA_library_curated.csv', index_col=0)

# NOTE: despite their names, both dictionaries map an ID to a gene *name*:
#   gene_name_to_guide_id: sgrna_id          -> perturbed_gene_name
#   gene_name_to_gene_id:  perturbed_gene_id -> perturbed_gene_name
# The names are kept because other cells reference them.
gene_name_to_guide_id = dict(zip(sgrna_library_metadata['sgrna_id'], sgrna_library_metadata['perturbed_gene_name']))
gene_name_to_gene_id = dict(zip(sgrna_library_metadata['perturbed_gene_id'], sgrna_library_metadata['perturbed_gene_name']))

# Open the pseudobulk file once in backed mode and reuse the handle for the
# gene annotation table. The module is imported as `ad`; the bare name
# `anndata` is only defined if the star import above happens to leak it.
pbulk_adata = ad.read_h5ad(f'{datadir}/{experiment_name}/{experiment_name}_merged.DE_pseudobulk_corrected.h5ad', backed=True)
var_df = pbulk_adata.var.copy()

# Per-target DE summary counts (uncorrected file; the corrected variant is loaded separately).
de_counts = pd.read_csv(f'{datadir}/{experiment_name}/DE_results_all_confounders/DE_summary_stats_per_target.csv', index_col=0)

In [6]:
# Per-target DE summary statistics (the "corrected" file), first CSV column used as index.
de_summary_stats = pd.read_csv(
    f'{datadir}/{experiment_name}/DE_results_all_confounders/DE_summary_stats_per_target_corrected.csv',
    index_col=0,
)

In [7]:
# Show the per-target DE summary table (pandas elides the middle rows in the rendered output).
de_summary_stats

Unnamed: 0,target_contrast,target_name,condition,n_cells_target,n_up_genes,n_down_genes,n_total_de_genes,ontarget_effect_size,ontarget_significant,baseMean,offtarget_flag,n_total_genes_category,ontarget_effect_category,target_contrast_corrected,obs_names,target_name_corrected
0,ENSG00000012963,UBR7,Stim8hr,491.0,0,2,2,-12.952742,True,43.169196,True,2-10 DE genes,on-target KD,ENSG00000012963,ENSG00000012963_Stim8hr,UBR7
1,ENSG00000017260,ATP2C1,Stim8hr,469.0,0,1,1,-16.307246,True,102.399025,False,1 DE gene,on-target KD,ENSG00000017260,ENSG00000017260_Stim8hr,ATP2C1
2,ENSG00000067606,PRKCZ,Stim8hr,427.0,1,1,2,-1.658755,False,0.965897,False,2-10 DE genes,no on-target KD,ENSG00000067606,ENSG00000067606_Stim8hr,PRKCZ
3,ENSG00000092929,UNC13D,Stim8hr,830.0,0,2,2,-19.259466,True,60.904483,False,2-10 DE genes,on-target KD,ENSG00000092929,ENSG00000092929_Stim8hr,UNC13D
4,ENSG00000100504,PYGL,Stim8hr,414.0,1,0,1,0.000000,False,,False,1 DE gene,no on-target KD,ENSG00000100504,ENSG00000100504_Stim8hr,PYGL
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
33981,ENSG00000198342,ZNF442,Rest,324.0,0,0,0,0.000000,False,,False,no effect,no on-target KD,ENSG00000198342,ENSG00000198342_Rest,ZNF442
33982,ENSG00000205572,SERF1B,Rest,44.0,7,24,31,0.000000,False,,False,>10 DE genes,no on-target KD,ENSG00000205572,ENSG00000205572_Rest,SERF1B
33983,ENSG00000221890,NPTXR,Rest,293.0,0,3,3,-1.027171,False,0.798092,False,2-10 DE genes,no on-target KD,ENSG00000221890,ENSG00000221890_Rest,NPTXR
33984,ENSG00000236320,SLFN14,Rest,848.0,0,0,0,0.000000,False,,False,no effect,no on-target KD,ENSG00000236320,ENSG00000236320_Rest,SLFN14


### Run Stim8hr

In [7]:
cond = 'Stim8hr'

# Targets in this condition with more than 30 DE genes and more than 75 cells.
in_condition = de_summary_stats['condition'] == cond
enough_de_genes = de_summary_stats['n_total_de_genes'] > 30
enough_cells = de_summary_stats['n_cells_target'] > 75
mask = in_condition & enough_de_genes & enough_cells
selected_perturbed_genes = de_summary_stats.loc[mask, 'target_name'].tolist()

# Restrict the backed pseudobulk object to those targets plus the NTC
# controls in this condition, then load the subset into memory.
wanted_genes = selected_perturbed_genes + ['NTC']
keep = pbulk_adata.obs['perturbed_gene_name'].isin(wanted_genes) & (pbulk_adata.obs['culture_condition'] == cond)
pbulk_adata_test = pbulk_adata[keep].to_memory()
all_donors = pbulk_adata_test.obs['donor_id'].unique().tolist()

In [8]:
chunk_size = 50
all_targets = pbulk_adata_test.obs['perturbed_gene_id'].unique().tolist()
all_targets.remove('NTC')

# Shuffle in place with a fixed seed so chunk membership is reproducible.
np.random.seed(42)
np.random.shuffle(all_targets)

# Consecutive slices of at most `chunk_size` shuffled targets.
chunk_starts = range(0, len(all_targets), chunk_size)
target_chunks = [all_targets[start:start + chunk_size] for start in chunk_starts]

# Targets x chunks indicator matrix: 1 where the target belongs to the chunk, 0 elsewhere.
chunk_labels = [f'chunk_{i}' for i in range(len(target_chunks))]
target_chunk_matrix = pd.DataFrame(0, index=all_targets, columns=chunk_labels)
for label, members in zip(chunk_labels, target_chunks):
    target_chunk_matrix.loc[members, label] = 1

In [9]:
# Every unordered pair of donors, each stored as a two-element list.
donor_combinations = [list(donor_pair) for donor_pair in combinations(all_donors, 2)]

In [None]:
# For each chunk of targets, run the DE test once per donor pair and write all
# pairs for that chunk to a single gzipped CSV.
tmp_results_dir = f"{datadir}/{experiment_name}/donor_robustness_tmp_results"
os.makedirs(tmp_results_dir, exist_ok=True)  # to_csv does not create missing directories

for chunk_ix in range(len(target_chunks)):
    test_targets = target_chunk_matrix.index[target_chunk_matrix[f'chunk_{chunk_ix}'] == 1].tolist()
    # Chunk targets plus NTC controls, materialised once per chunk.
    pbulk_adata_test_for_de = pbulk_adata_test[pbulk_adata_test.obs.perturbed_gene_id.isin(test_targets+['NTC'])].copy()
    # Collect per-pair results and concatenate once, instead of growing a
    # DataFrame inside the loop.
    per_pair_results = []
    for train_ds in donor_combinations:
        in_pair = pbulk_adata_test_for_de.obs['donor_id'].isin(train_ds)
        per_pair_results.append(_run_DE_test(pbulk_adata_test_for_de[in_pair], test_state=cond))
    # pd.concat raises on an empty list; fall back to an empty frame as before.
    all_results_df = pd.concat(per_pair_results) if per_pair_results else pd.DataFrame()
    all_results_df.to_csv(f"{tmp_results_dir}/donor_robustness_results.{cond}.chunk_{chunk_ix}.csv.gz", compression='gzip')

### Run Rest

In [8]:
cond = 'Rest'

# Targets in this condition with more than 30 DE genes and more than 75 cells.
in_condition = de_summary_stats['condition'] == cond
enough_de_genes = de_summary_stats['n_total_de_genes'] > 30
enough_cells = de_summary_stats['n_cells_target'] > 75
mask = in_condition & enough_de_genes & enough_cells
selected_perturbed_genes = de_summary_stats.loc[mask, 'target_name'].tolist()

# Restrict the backed pseudobulk object to those targets plus the NTC
# controls in this condition, then load the subset into memory.
wanted_genes = selected_perturbed_genes + ['NTC']
keep = pbulk_adata.obs['perturbed_gene_name'].isin(wanted_genes) & (pbulk_adata.obs['culture_condition'] == cond)
pbulk_adata_test = pbulk_adata[keep].to_memory()
all_donors = pbulk_adata_test.obs['donor_id'].unique().tolist()

In [9]:
chunk_size = 50
all_targets = pbulk_adata_test.obs['perturbed_gene_id'].unique().tolist()
all_targets.remove('NTC')

# Shuffle in place with a fixed seed so chunk membership is reproducible.
np.random.seed(42)
np.random.shuffle(all_targets)

# Consecutive slices of at most `chunk_size` shuffled targets.
chunk_starts = range(0, len(all_targets), chunk_size)
target_chunks = [all_targets[start:start + chunk_size] for start in chunk_starts]

# Targets x chunks indicator matrix: 1 where the target belongs to the chunk, 0 elsewhere.
chunk_labels = [f'chunk_{i}' for i in range(len(target_chunks))]
target_chunk_matrix = pd.DataFrame(0, index=all_targets, columns=chunk_labels)
for label, members in zip(chunk_labels, target_chunks):
    target_chunk_matrix.loc[members, label] = 1

In [10]:
# Every unordered pair of donors, each stored as a two-element list.
donor_combinations = [list(donor_pair) for donor_pair in combinations(all_donors, 2)]

In [None]:
# For each chunk of targets, run the DE test once per donor pair and write all
# pairs for that chunk to a single gzipped CSV.
tmp_results_dir = f"{datadir}/{experiment_name}/donor_robustness_tmp_results"
os.makedirs(tmp_results_dir, exist_ok=True)  # to_csv does not create missing directories

for chunk_ix in range(len(target_chunks)):
    test_targets = target_chunk_matrix.index[target_chunk_matrix[f'chunk_{chunk_ix}'] == 1].tolist()
    # Chunk targets plus NTC controls, materialised once per chunk.
    pbulk_adata_test_for_de = pbulk_adata_test[pbulk_adata_test.obs.perturbed_gene_id.isin(test_targets+['NTC'])].copy()
    # Collect per-pair results and concatenate once, instead of growing a
    # DataFrame inside the loop.
    per_pair_results = []
    for train_ds in donor_combinations:
        in_pair = pbulk_adata_test_for_de.obs['donor_id'].isin(train_ds)
        per_pair_results.append(_run_DE_test(pbulk_adata_test_for_de[in_pair], test_state=cond))
    # pd.concat raises on an empty list; fall back to an empty frame as before.
    all_results_df = pd.concat(per_pair_results) if per_pair_results else pd.DataFrame()
    all_results_df.to_csv(f"{tmp_results_dir}/donor_robustness_results.{cond}.chunk_{chunk_ix}.csv.gz", compression='gzip')

### Run Stim48hr

In [None]:
cond = 'Stim48hr'

# Targets in this condition with more than 30 DE genes and more than 75 cells.
in_condition = de_summary_stats['condition'] == cond
enough_de_genes = de_summary_stats['n_total_de_genes'] > 30
enough_cells = de_summary_stats['n_cells_target'] > 75
mask = in_condition & enough_de_genes & enough_cells
selected_perturbed_genes = de_summary_stats.loc[mask, 'target_name'].tolist()

# Restrict the backed pseudobulk object to those targets plus the NTC
# controls in this condition, then load the subset into memory.
wanted_genes = selected_perturbed_genes + ['NTC']
keep = pbulk_adata.obs['perturbed_gene_name'].isin(wanted_genes) & (pbulk_adata.obs['culture_condition'] == cond)
pbulk_adata_test = pbulk_adata[keep].to_memory()
all_donors = pbulk_adata_test.obs['donor_id'].unique().tolist()

chunk_size = 50
all_targets = pbulk_adata_test.obs['perturbed_gene_id'].unique().tolist()
all_targets.remove('NTC')

# Shuffle in place with a fixed seed so chunk membership is reproducible.
np.random.seed(42)
np.random.shuffle(all_targets)

# Consecutive slices of at most `chunk_size` shuffled targets.
chunk_starts = range(0, len(all_targets), chunk_size)
target_chunks = [all_targets[start:start + chunk_size] for start in chunk_starts]

# Targets x chunks indicator matrix: 1 where the target belongs to the chunk, 0 elsewhere.
chunk_labels = [f'chunk_{i}' for i in range(len(target_chunks))]
target_chunk_matrix = pd.DataFrame(0, index=all_targets, columns=chunk_labels)
for label, members in zip(chunk_labels, target_chunks):
    target_chunk_matrix.loc[members, label] = 1

# Every unordered pair of donors, each stored as a two-element list.
donor_combinations = [list(donor_pair) for donor_pair in combinations(all_donors, 2)]

In [None]:
# For each chunk of targets, run the DE test once per donor pair and write all
# pairs for that chunk to a single gzipped CSV.
tmp_results_dir = f"{datadir}/{experiment_name}/donor_robustness_tmp_results"
os.makedirs(tmp_results_dir, exist_ok=True)  # to_csv does not create missing directories

for chunk_ix in range(len(target_chunks)):
    test_targets = target_chunk_matrix.index[target_chunk_matrix[f'chunk_{chunk_ix}'] == 1].tolist()
    # Chunk targets plus NTC controls, materialised once per chunk.
    pbulk_adata_test_for_de = pbulk_adata_test[pbulk_adata_test.obs.perturbed_gene_id.isin(test_targets+['NTC'])].copy()
    # Collect per-pair results and concatenate once, instead of growing a
    # DataFrame inside the loop.
    per_pair_results = []
    for train_ds in donor_combinations:
        in_pair = pbulk_adata_test_for_de.obs['donor_id'].isin(train_ds)
        per_pair_results.append(_run_DE_test(pbulk_adata_test_for_de[in_pair], test_state=cond))
    # pd.concat raises on an empty list; fall back to an empty frame as before.
    all_results_df = pd.concat(per_pair_results) if per_pair_results else pd.DataFrame()
    all_results_df.to_csv(f"{tmp_results_dir}/donor_robustness_results.{cond}.chunk_{chunk_ix}.csv.gz", compression='gzip')

### Parse results

In [None]:
# Convert every per-chunk results table into AnnData objects (one per donor
# pair) and stack them into a single object.
de_results_adatas = []
# Sorted so the stacking order does not depend on filesystem listing order.
de_results_files = sorted(glob.glob(f'{datadir}/{experiment_name}/donor_robustness_tmp_results/donor_robustness_results.*.csv.gz'))
for file in tqdm(de_results_files, desc="Processing DE result files"):
    # The chunk number is a property of the file: parse it once, and from the
    # file name only so a 'chunk_<n>' in a directory name cannot match first.
    chunk_number = int(re.search(r'chunk_(\d+)', os.path.basename(file)).group(1))
    df = pd.read_csv(file, compression='gzip', index_col=0)
    df = df.rename({'target': 'target_contrast'}, axis=1)
    # gene_name_to_gene_id maps perturbed_gene_id -> gene name; unknown IDs are kept as-is.
    df['target_contrast_gene_name'] = df['target_contrast'].map(lambda x: gene_name_to_gene_id.get(x, x))
    donor_pairs = df['donors'].unique()
    for donor_pair in donor_pairs:
        temp_adata = parse_DE_results_2_adata(df[df['donors']==donor_pair])
        temp_adata.obs['donor_pair'] = donor_pair
        temp_adata.obs['chunk'] = chunk_number
        de_results_adatas.append(temp_adata)
# The module is imported as `ad`; the bare name `anndata` is only defined if
# the star import at the top happens to leak it.
combined_de_adata = ad.concat(de_results_adatas)
# Suffix each observation name with its donor pair so names from different pairs do not collide.
combined_de_adata.obs_names = [item1+'_'+item2 for item1, item2 in zip(combined_de_adata.obs_names, combined_de_adata.obs['donor_pair'])]
combined_de_adata.var['gene_ids'] = combined_de_adata.var_names
combined_de_adata.var['gene_name'] = var_df.loc[combined_de_adata.var_names, 'gene_name']
assert combined_de_adata.obs_names.is_unique, "observation names are not unique after adding the donor-pair suffix"

In [None]:
# Resolve duplicated donor pairs: the same two donors can be labelled 'A_B' or
# 'B_A'. Ordering the IDs inside each label alphabetically gives one canonical
# label per pair. This reproduces the six remappings that used to be
# hard-coded here and also covers any other donor IDs.
# Assumes donor IDs themselves contain no underscore — TODO confirm.
combined_de_adata.obs['donor_pair'] = [
    '_'.join(sorted(label.split('_')))
    for label in combined_de_adata.obs['donor_pair'].astype(str)
]

In [None]:
# Save the combined donor-robustness DE results as an .h5ad file in the experiment directory.
combined_de_adata.write_h5ad(f'{datadir}/{experiment_name}/{experiment_name}.DE_donor_robustness.h5ad')

In [None]:
# For every target, correlate the log fold changes obtained from disjoint
# donor pairs (e.g. donors A+B vs C+D) and record the noise-limited ceiling.
conditions = ['Rest', 'Stim8hr', 'Stim48hr']
adj_p_cutoff = 0.2  # a gene is used if its adjusted p-value is below this in either donor pair
correlation_frames = []  # one merged frame per results file, concatenated once at the end
for cond in conditions:
    # Sorted so the row order of the final table is reproducible.
    de_results_files = sorted(glob.glob(f'{datadir}/{experiment_name}/donor_robustness_tmp_results/donor_robustness_results.{cond}.*.csv.gz'))
    for file in de_results_files:

        all_results_df = pd.read_csv(file, index_col=0)
        donor_pairs = all_results_df['donors'].unique()
        # Only compare donor pairs that share no donor, so the two estimates
        # come from independent samples. Labels are donor IDs joined by '_'.
        comparison_pairs = [
            (pair1, pair2)
            for pair1, pair2 in itertools.combinations(donor_pairs, 2)
            if set(pair1.split('_')).isdisjoint(pair2.split('_'))
        ]

        # (target, gene) x donor-pair tables of effect size, standard error and adjusted p-value.
        pivoted = all_results_df.pivot(columns='donors', index=['target','variable'], values='log_fc')
        pivoted_se = all_results_df.pivot(columns='donors', index=['target','variable'], values='lfcSE')
        pivoted_pval = all_results_df.pivot(columns='donors', index=['target','variable'], values='adj_p_value')

        # Collect plain records and build the DataFrame once per file.
        pair_correlations = []
        for target in tqdm(pivoted.index.get_level_values('target').unique()):
            target_data = pivoted.loc[target]
            target_se = pivoted_se.loc[target]
            target_p = pivoted_pval.loc[target]

            # Correlation and correlation ceiling for each disjoint pair of donor pairs
            for pair in comparison_pairs:
                if pair[0] in target_data.columns and pair[1] in target_data.columns:
                    # Genes with adjusted p below the cutoff in either test
                    sig_mask = (target_p[pair[0]] < adj_p_cutoff) | (target_p[pair[1]] < adj_p_cutoff)

                    # Filter data using mask
                    data1 = target_data[pair[0]][sig_mask]
                    data2 = target_data[pair[1]][sig_mask]
                    se1 = target_se[pair[0]][sig_mask]
                    se2 = target_se[pair[1]][sig_mask]

                    corr, corr_ceil = get_lfc_correlation(data1, se1, data2, se2)

                    # 'donors' / 'variable' hold the first / second donor-pair label.
                    pair_correlations.append({
                        'target': target,
                        'donors': pair[0],
                        'variable': pair[1],
                        'correlation': corr,
                        'correlation_ceiling': corr_ceil,
                        'n_signif': data1.shape[0],
                        'culture_condition': cond
                    })

        if not pair_correlations:
            # No disjoint donor pairs in this file: an empty frame has no
            # 'target_contrast' column and the merge below would raise.
            continue
        correlations = pd.DataFrame(pair_correlations)
        correlations = pd.merge(correlations.rename({'target':'target_contrast'}, axis=1), de_counts[de_counts['condition']==cond], on='target_contrast', how='left')
        correlation_frames.append(correlations)

all_correlations = pd.concat(correlation_frames) if correlation_frames else pd.DataFrame()

In [None]:
# Display the donor-pair correlation table.
all_correlations

In [None]:
# Write the correlation table to the results directory (the index is written as the first CSV column).
all_correlations.to_csv('../results/DE_donor_robustness_correlation.csv')

In [None]:
# Reload the saved correlation table, using the first CSV column as the index.
all_correlations = pd.read_csv('../results/DE_donor_robustness_correlation.csv', index_col=0)

In [None]:
# Display the correlation table currently bound to `all_correlations`.
all_correlations

In [None]:
# Display the per-target DE summary counts.
de_counts

In [None]:
# Read the saved table with the first column as index, matching how it is
# reloaded elsewhere in this notebook. Without index_col=0 the saved index
# comes back as a spurious 'Unnamed: 0' data column.
all_correlations = pd.read_csv('../results/DE_donor_robustness_correlation.csv', index_col=0)


def _summarize_donor_correlation(correlations, de_counts, condition):
    """Summarise donor-pair correlations per target for one culture condition.

    Parameters
    ----------
    correlations : pandas.DataFrame
        Needs the columns ``culture_condition``, ``target_name`` and ``correlation``.
    de_counts : pandas.DataFrame
        Needs the columns ``condition`` and ``target_name``.
    condition : str
        Value matched against both ``culture_condition`` and ``condition``.

    Returns
    -------
    pandas.DataFrame
        Mean and minimum correlation per ``target_name``, left-merged with the
        ``de_counts`` rows of the same condition.
    """
    per_target = correlations[correlations.culture_condition == condition].groupby('target_name')['correlation']
    summary = pd.DataFrame({'donor_correlation_mean': per_target.mean(),
                            'donor_correlation_min': per_target.min()})
    return pd.merge(summary, de_counts[de_counts.condition == condition], on='target_name', how='left')


# The per-condition names are kept because other cells may reference them.
donor_correlation_summary_rest = _summarize_donor_correlation(all_correlations, de_counts, 'Rest')
donor_correlation_summary_stim8hr = _summarize_donor_correlation(all_correlations, de_counts, 'Stim8hr')
donor_correlation_summary_stim48hr = _summarize_donor_correlation(all_correlations, de_counts, 'Stim48hr')

donor_correlation_summary = pd.concat([donor_correlation_summary_rest, donor_correlation_summary_stim8hr, donor_correlation_summary_stim48hr])

donor_correlation_summary.to_csv('../results/DE_donor_robustness_correlation_summary.csv')