# For a given (core) gene, obtain its regulators and their burden effects

In [1]:
import scanpy as sc
import numpy as np
import pandas as pd

def _format_float(value):
    """Render floats for display: scientific notation below 1e-3 in magnitude, else 3 decimals."""
    if abs(value) < 0.001:
        return f"{value:.2e}"
    return f"{value:.3f}"


pd.set_option('display.float_format', _format_float)
pd.set_option('display.max_rows', 200)

# Trait whose per-gene burden estimates are loaded below.
trait = "LymphocyteCount"

In [2]:
datadir = '/mnt/oak/users/emma/data/GWT/CD4i_final/'
experiment_name = 'CD4i_final'
# `datadir` already ends with '/', so the relative part must not start with one
# (prepending another '/' produced a harmless but sloppy '//' in the path).
adata_de = sc.read_h5ad(datadir + f'DE_results_all_confounders/{experiment_name}.merged_DE_results.h5ad')

# z-score = log fold change / its standard error, for every entry of the layers.
# Cap extreme values symmetrically at +/-50. Clipping only the positive tail
# (values > 50) would leave arbitrarily large negative z-scores untouched,
# including -inf where lfcSE == 0 and log_fc < 0. NaNs pass through np.clip unchanged.
adata_de.layers['zscore'] = np.clip(adata_de.layers['log_fc'] / adata_de.layers['lfcSE'], -50, 50)

# Per culture condition, count for each gene (column) the number of rows with adj. p < 0.1.
for cond in adata_de.obs.culture_condition.unique():
    affected_gs_cond = np.sum(adata_de[adata_de.obs['culture_condition'] == cond].layers['adj_p_value'] < 0.1, axis=0)
    adata_de.var[f'n_signif_effects_{cond}'] = affected_gs_cond
    # affected_gs_cond = np.sum(adata_de[adata_de.obs['culture_condition'] == cond].layers['MASH_lfsr'] < 0.05, axis=0)
    # adata_de.var[f'n_mash_signif_effects_{cond}'] = affected_gs_cond

# Index variables by gene symbol so later cells can look genes up by name.
# NOTE(review): if `gene_name` contains duplicates, name-based lookups are ambiguous;
# consider `adata_de.var_names_make_unique()` — confirm against the data.
adata_de.var_names = adata_de.var['gene_name'].values
adata_de

AnnData object with n_obs × n_vars = 33986 × 13959
    obs: 'target_contrast_gene_name', 'culture_condition', 'target_contrast', 'chunk', 'n_cells_target'
    var: 'gene_ids', 'gene_name', 'n_signif_effects_Stim8hr', 'n_signif_effects_Stim48hr', 'n_signif_effects_Rest'
    layers: 'adj_p_value', 'baseMean', 'lfcSE', 'log_fc', 'p_value', 'zscore'

In [3]:
burden_file = f"/mnt/oak/users/mineto/workstation/250717_TcellPerturb/input/burden/Backman_{trait}_fullFeatures.per_gene_estimates.tsv"
burden_df = pd.read_csv(burden_file, sep="\t", header=0)

# Symbol/Ensembl correspondence table: one row per gene symbol (first ENSG seen wins),
# laid out as columns ['gene', 'ensg'] with a fresh 0..n-1 index.
corresp_file = "/mnt/oak/users/mineto/workstation/GRN/gencode_v41_gname_gid_ALL_sorted_onlyID"
corresp_df = (
    pd.read_csv(corresp_file, sep="\t", header=None, names=['ensg', 'gene'])
    .drop_duplicates(subset=['gene'])[['gene', 'ensg']]
    .reset_index(drop=True)
)

# Reverse lookup ENSG -> symbol. When an ENSG carries several symbols the first one is
# kept (use .agg('|'.join) instead of .first() to keep all of them).
ensg_to_gene = (
    corresp_df[["ensg", "gene"]]
    .dropna()
    .drop_duplicates()
    .groupby("ensg")["gene"]
    .first()
)

# Annotate burden estimates with gene symbols (NaN where the ENSG is not in the table).
burden_df["gene"] = burden_df["ensg"].map(ensg_to_gene)
burden_df.head()

Unnamed: 0,ensg,param0,param1,param2,prior_mean,post_mean,lower_95,upper_95,gene
0,ENSG00000131737,-1.35,1.351,40.602,-0.019,-0.019,-0.071,0.036,KRT34
1,ENSG00000107779,1.282,1.103,48.588,0.013,0.011,-0.037,0.056,BMPR1A
2,ENSG00000168016,0.679,0.839,82.535,0.003,0.003,-0.015,0.023,TRANK1
3,ENSG00000187566,1.119,1.049,51.739,0.01,0.021,-0.015,0.074,NHLRC1
4,ENSG00000109436,2.119,1.82,48.637,0.029,0.022,-0.034,0.067,TBC1D9


In [4]:
def get_significant_regulators(adata_de, burden_df, target_gene, condition, adj_p_threshold=0.05):
    """
    Get all significant regulators of a target gene under a given condition.

    A regulator is a row of ``adata_de`` (labelled by
    ``obs['target_contrast_gene_name']``) in ``condition`` whose adjusted p-value
    for ``target_gene`` is strictly below ``adj_p_threshold``. Regulators are
    inner-joined with ``burden_df`` on gene name, so regulators without a burden
    estimate are dropped.

    Parameters:
    -----------
    adata_de : AnnData
        Annotated data matrix with perturbation data; must have layers
        'adj_p_value' and 'log_fc' and obs columns 'culture_condition' and
        'target_contrast_gene_name'.
    burden_df : pandas.DataFrame
        DataFrame with burden effect data; columns 'gene', 'post_mean',
        'lower_95' and 'upper_95' are used.
    target_gene : str
        Name of the target gene to find regulators for (a var name of adata_de)
    condition : str
        Culture condition to filter by (e.g., 'Stim8hr', 'Stim48hr', 'Rest')
    adj_p_threshold : float, default=0.05
        Adjusted p-value threshold for significance (strict '<')

    Returns:
    --------
    pandas.DataFrame or None
        Columns 'gene', 'perturb_p', 'perturb_lfc', 'burden', 'burden_L95',
        'burden_U95', sorted by ascending 'burden'. Returns None (after printing
        a message) if ``target_gene`` is not a variable of ``adata_de``.
    """
    # Callers test for a None result; report a missing target gene that way
    # instead of letting scanpy raise a KeyError.
    if target_gene not in adata_de.var_names:
        print(f"Target gene {target_gene} not found in adata_de.var_names")
        return None

    # Restrict to perturbations measured under the requested condition.
    adata_cond = adata_de[adata_de.obs.culture_condition == condition]

    # One row per perturbation: its label plus the target gene's adj. p-value and log FC.
    results = pd.DataFrame({
        'gene': adata_cond.obs.target_contrast_gene_name,
        'adj_p_value': sc.get.obs_df(adata_cond, target_gene, layer='adj_p_value')[target_gene].values,
        'log_fold_change': sc.get.obs_df(adata_cond, target_gene, layer='log_fc')[target_gene].values,
    })

    # Keep significant regulators only (NaN p-values compare False and are dropped).
    significant = results[results['adj_p_value'] < adj_p_threshold].copy()

    # Attach burden estimates; regulators without one are dropped by the inner join.
    significant = pd.merge(significant, burden_df[['gene', 'post_mean', 'lower_95', 'upper_95']], on='gene', how='inner')

    significant.rename(columns={'adj_p_value': 'perturb_p', 'log_fold_change': 'perturb_lfc', 'post_mean': 'burden', 'lower_95': 'burden_L95', 'upper_95': 'burden_U95'}, inplace=True)

    # Sort by burden effect
    significant = significant.sort_values('burden')

    return significant


def compare_gene_regulators(adata_de, burden_df, target_gene1, target_gene2, condition1, condition2, adj_p_threshold=0.05):
    """
    Compare regulators between two target genes under different conditions and find overlapping vs. specific regulators.

    Parameters:
    -----------
    adata_de : AnnData
        Annotated data matrix with perturbation data
    burden_df : pandas.DataFrame
        DataFrame with burden effect data
    target_gene1 : str
        Name of the first target gene
    target_gene2 : str
        Name of the second target gene
    condition1 : str
        Culture condition for the first target gene (e.g., 'Stim8hr', 'Stim48hr', 'Rest')
    condition2 : str
        Culture condition for the second target gene (e.g., 'Stim8hr', 'Stim48hr', 'Rest')
    adj_p_threshold : float, default=0.05
        Adjusted p-value threshold for significance

    Returns:
    --------
    dict or None
        None if ``get_significant_regulators`` returns None for either target.
        Otherwise a dict with:
        - 'overlapping_regulators': one row per regulator significant for both
          targets, with per-target '<gene>_<condition>_perturb_p' /
          '<gene>_<condition>_perturb_lfc' columns and the regulator's
          'burden', 'burden_L95', 'burden_U95'; sorted by 'burden'. Empty (but
          with these columns) when there is no overlap.
        - 'specific_to_gene1' / 'specific_to_gene2': rows of the respective
          regulator table for regulators significant for only that target,
          sorted by 'burden'.
        - 'summary': counts of total, overlapping and specific regulators.
    """
    # Get significant regulators for both genes under their respective conditions
    regulators_gene1 = get_significant_regulators(adata_de, burden_df, target_gene1, condition1, adj_p_threshold)
    regulators_gene2 = get_significant_regulators(adata_de, burden_df, target_gene2, condition2, adj_p_threshold)

    if regulators_gene1 is None or regulators_gene2 is None:
        print("Could not find regulators for one or both genes")
        return None

    # Get sets of regulator gene names
    regulators_1_set = set(regulators_gene1['gene'])
    regulators_2_set = set(regulators_gene2['gene'])

    # Find overlapping and specific regulators
    overlapping_regulators = regulators_1_set & regulators_2_set
    specific_to_gene1 = regulators_1_set - regulators_2_set
    specific_to_gene2 = regulators_2_set - regulators_1_set

    # First row per regulator gene (in the tables' current order), indexed by gene so
    # each lookup is O(1) instead of a boolean filter over the whole table.
    first_row_1 = regulators_gene1.drop_duplicates(subset='gene').set_index('gene')
    first_row_2 = regulators_gene2.drop_duplicates(subset='gene').set_index('gene')

    p1_col = f'{target_gene1}_{condition1}_perturb_p'
    lfc1_col = f'{target_gene1}_{condition1}_perturb_lfc'
    p2_col = f'{target_gene2}_{condition2}_perturb_p'
    lfc2_col = f'{target_gene2}_{condition2}_perturb_lfc'
    # dict.fromkeys de-duplicates names (same gene + same condition) while keeping order,
    # mirroring how the per-row dicts below collapse repeated keys.
    overlap_columns = list(dict.fromkeys(
        ['gene', p1_col, lfc1_col, p2_col, lfc2_col, 'burden', 'burden_L95', 'burden_U95']
    ))

    # Create detailed results for overlapping regulators
    overlapping_results = []
    for reg in overlapping_regulators:
        row_1 = first_row_1.loc[reg]
        row_2 = first_row_2.loc[reg]
        overlapping_results.append({
            'gene': reg,
            p1_col: row_1['perturb_p'],
            lfc1_col: row_1['perturb_lfc'],
            p2_col: row_2['perturb_p'],
            lfc2_col: row_2['perturb_lfc'],
            'burden': row_2['burden'],
            'burden_L95': row_2['burden_L95'],
            'burden_U95': row_2['burden_U95'],
        })

    # An explicit column list keeps the schema when there is no overlap; without it
    # pd.DataFrame([]) has no 'burden' column and sort_values raises KeyError.
    overlapping_df = pd.DataFrame(overlapping_results, columns=overlap_columns).sort_values('burden')

    # Specific regulators for each gene, sorted by burden effect
    specific_gene1_df = regulators_gene1[regulators_gene1['gene'].isin(specific_to_gene1)].sort_values('burden')
    specific_gene2_df = regulators_gene2[regulators_gene2['gene'].isin(specific_to_gene2)].sort_values('burden')

    return {
        'overlapping_regulators': overlapping_df,
        'specific_to_gene1': specific_gene1_df,
        'specific_to_gene2': specific_gene2_df,
        'summary': {
            'total_regulators_gene1': len(regulators_gene1),
            'total_regulators_gene2': len(regulators_gene2),
            'overlapping_count': len(overlapping_regulators),
            'specific_to_gene1_count': len(specific_to_gene1),
            'specific_to_gene2_count': len(specific_to_gene2)
        }
    }

In [7]:
cond = 'Stim8hr'
target_gene = "TBC1D1"

# First print the burden effect of the target gene of interest. Filter once and
# handle targets that have no burden estimate (indexing [0] would raise IndexError).
target_burden = burden_df.loc[burden_df['gene'] == target_gene, ['post_mean', 'lower_95', 'upper_95']]
if target_burden.empty:
    print(f"No burden estimate found for target gene ({target_gene})")
else:
    burden_est, burden_lo, burden_hi = target_burden.iloc[0]
    print(f"Burden effect of target gene ({target_gene}): {burden_est:.3f}, 95% CI: [{burden_lo:.3f}, {burden_hi:.3f}]")

# Get significant regulators of `target_gene` under condition `cond`
significant_regulators = get_significant_regulators(
    adata_de,
    burden_df,
    target_gene=target_gene,
    condition=cond,
    adj_p_threshold=0.05
)

if significant_regulators is not None:
    print(f"Found {len(significant_regulators)} significant regulators:")
    print(significant_regulators)

Burden effect of target gene (TBC1D1): -0.007, 95% CI: [-0.032, 0.015]
Found 88 significant regulators:
        gene  perturb_p  perturb_lfc    burden  burden_L95  burden_U95
67     PLEK2   2.98e-06        0.759    -0.076      -0.148      -0.022
35    YEATS2      0.049       -0.881    -0.065      -0.140      -0.013
78     XRRA1      0.002        0.598    -0.064      -0.112      -0.026
58     APPL2   5.70e-05        0.811    -0.050      -0.091      -0.018
27    KIF16B   1.64e-04        0.700    -0.042      -0.090      -0.008
23    GTF3C6      0.048        0.837    -0.040      -0.102       0.022
29     ARID2      0.049       -0.561    -0.039      -0.108       0.013
81     MALT1      0.043        0.468    -0.033      -0.095       0.018
22      LEO1   6.34e-04        0.692    -0.033      -0.147       0.140
84       LAT   8.04e-28        1.652    -0.032      -0.100       0.029
72  ARHGAP27   7.04e-06        0.885    -0.023      -0.067       0.012
9     DOLPP1   6.02e-05        0.540    -0.0

In [6]:
gene1 = "IL18R1"
gene2 = "TBC1D1"
condition1 = "Stim48hr"
condition2 = "Stim8hr"

comparison_results = compare_gene_regulators(
    adata_de,
    burden_df,
    target_gene1=gene1,
    target_gene2=gene2,
    condition1=condition1,
    condition2=condition2,
    adj_p_threshold=0.05
)

# compare_gene_regulators returns None when regulators cannot be obtained for either
# target; guard so the report below does not fail by subscripting None.
if comparison_results is None:
    print(f"No regulator comparison available for {gene1} under {condition1} and {gene2} under {condition2}")
else:
    print(f"=== REGULATOR COMPARISON SUMMARY FOR {gene1} UNDER {condition1} AND {gene2} UNDER {condition2} ===")
    print(f"Overlapping regulators: \n{comparison_results['overlapping_regulators']}\n")
    print(f"Specific to {gene1}, {condition1}: \n{comparison_results['specific_to_gene1']}\n")
    print(f"Specific to {gene2}, {condition2}: \n{comparison_results['specific_to_gene2']}\n")

=== REGULATOR COMPARISON SUMMARY FOR IL18R1 UNDER Stim48hr AND TBC1D1 UNDER Stim8hr ===
Overlapping regulators: 
        gene  IL18R1_Stim48hr_perturb_p  IL18R1_Stim48hr_perturb_lfc  \
17     MALT1                      0.048                        0.517   
20      LEO1                      0.048                        0.511   
13       LAT                   1.03e-20                        1.833   
30  ARHGAP27                      0.017                        0.744   
22      TMX1                      0.025                        0.474   
18      CD3G                   7.21e-07                        1.392   
0      STAT3                   9.05e-08                        0.997   
12      CD3D                   9.11e-09                        1.883   
25       GPI                      0.010                        0.567   
16    NKAPD1                      0.015                        0.684   
26      CD3E                   3.74e-06                        1.592   
3      BCAT2           