In [18]:
# Automatically reload edited project modules (e.g. src.utils.analysis.*)
# before each cell runs, so changes to the simulation/evaluation helpers are
# picked up without restarting the kernel.
# NOTE(review): this cell was executed as In [18], i.e. after the imports —
# do a Restart & Run All to confirm the notebook works top to bottom.
%load_ext autoreload
%autoreload 2

In [1]:
import json
import pandas as pd
import numpy as np
import statsmodels.api as sm
from statsmodels.stats.multitest import multipletests
import networkx as nx
from collections import defaultdict
import itertools
from scipy import stats as sps
import json
import os
from src.utils.analysis.epistasis_simulation import simulate_epistasis, build_hierarchical_ontology



# Step 1: Generating high-power synthetic data

## --- Part 1: Generating synthetic genotypes and the ontology hierarchy ---

In [2]:
# Hyperparameters for the synthetic epistasis simulation.

# Every file produced in Part 1 is written under this directory.
output_dir = "./epistasis_simulation_samples/"

# Seed passed to simulate_epistasis and build_hierarchical_ontology.
seed = 42

# Simulation size and effect structure (consumed by simulate_epistasis).
n_samples = 10_000   # number of simulated individuals
n_additive = 100     # number of SNPs with additive effects
n_pairs = 50         # number of epistatic SNP pairs
h2_epistatic = 0.1   # presumably the phenotypic variance share from epistasis — confirm in simulate_epistasis

# Passed to build_hierarchical_ontology when placing the epistatic pairs in
# the hierarchy; presumably higher values keep partners in related systems.
ontology_coherence = 0.5

### 1. Running the simulation with the parameters above

In [3]:
# Simulate genotypes and a phenotype with additive and pairwise-epistatic
# components. Later cells read the result's 'G' (samples x SNPs genotype
# matrix), 'y' (phenotype), 'pair_idx' and 'additive_idx' (causal SNP indices).
# n_snps and h2_additive are fixed here rather than in the hyperparameter cell.
simulation_params = dict(
    n_samples=n_samples,
    n_snps=500,
    seed=seed,
    n_additive=n_additive,
    h2_additive=0.5,
    n_pairs=n_pairs,
    h2_epistatic=h2_epistatic,
)
sim = simulate_epistasis(**simulation_params)

### 2. Building the hierarchical ontology

In [4]:
# Build the SNP -> gene -> system -> supersystem hierarchy around the simulated
# causal SNPs. The next cells use snp_df's index as SNP IDs, gene_df's
# gene_id/system_id columns and system_df's system_id/supersystem_id columns.
# NOTE(review): n_causal_systems / causal_system_enrichment presumably
# concentrate causal SNPs in a subset of systems — confirm in the helper.
ontology_params = dict(
    n_genes=250,
    n_systems=50,
    ontology_coherence=ontology_coherence,
    n_causal_systems=20,
    causal_system_enrichment=20.0,
    seed=seed,
)
snp_df, gene_df, system_df = build_hierarchical_ontology(sim, **ontology_params)


Pre-calculating system pair pools for epistasis placement...
  - Found 50 same-system pairs.
  - Found 649 related-system pairs.
  - Found 576 distant-system pairs.
Assigning 50 epistatic pairs with ontology_coherence = 0.5...
LD structure found. Assigning chromosomes, positions, and blocks based on LD.


### 3. Saving genotypes and phenotypes, and creating random covariates

In [5]:
# Persist the simulated genotypes and phenotype, plus randomly generated
# covariates, as the tab-separated files passed to the training script.
os.makedirs(output_dir, exist_ok=True)  # to_csv does not create missing directories

n_individuals = sim['y'].shape[0]
# Consistent sample IDs shared by every output file (used as both FID and IID).
iids = [f'sample_{i}' for i in range(n_individuals)]

# Genotypes: one row per sample (IID index), one column per SNP ID from snp_df.
genotypes_df = pd.DataFrame(sim['G'], columns=snp_df.index.values)
genotypes_df.index = iids
genotypes_df.index.name = 'IID'
genotypes_df.to_csv(f"{output_dir}/genotypes.tsv", sep='\t')

# Phenotype file keyed by the same IDs.
pheno_df = pd.DataFrame({'FID': iids, 'IID': iids, 'phenotype': sim['y']})
pheno_df.to_csv(f"{output_dir}/simulation.pheno", index=False, sep='\t')

# Covariates are drawn after the phenotype was simulated, so they carry no
# signal. They must still be reproducible, so they come from a generator seeded
# from `seed` instead of the unseeded global NumPy RNG. The extra entropy word
# (1) keeps this stream distinct from any generator seeded with `seed` alone.
cov_rng = np.random.default_rng([seed, 1])
cov_df = pd.DataFrame({
    'FID': iids,
    'IID': iids,
    'SEX': cov_rng.integers(2, size=n_individuals),
    'AGE': cov_rng.integers(40, 70, size=n_individuals),  # high is exclusive: ages 40-69
})
cov_df.to_csv(f"{output_dir}/simulation.cov", index=False, sep='\t')

### 4. Saving the SNP-to-gene mapping, ontology, and ground-truth causal info

In [6]:
# SNP -> gene mapping, keeping snp_df's index (the SNP IDs).
snp_df.to_csv(f"{output_dir}/snp2gene.tsv", index=True, sep='\t')

# The model expects one headerless (parent, child, interaction) edge list:
# gene -> system edges tagged 'gene', system -> supersystem edges tagged 'default'.
gene_df = (
    gene_df
    .rename(columns={'gene_id': 'child', 'system_id': 'parent'})
    .assign(interaction='gene')
)
system_df = (
    system_df
    .rename(columns={'system_id': 'child', 'supersystem_id': 'parent'})
    .assign(interaction='default')
)

edge_cols = ['parent', 'child', 'interaction']
ontology_df = pd.concat([gene_df[edge_cols], system_df[edge_cols]])
ontology_df.to_csv(f"{output_dir}/ontology.tsv", index=False, sep='\t', header=False)

# Ground-truth causal SNP indices, cast to builtin int because numpy integer
# scalars are not JSON-serialisable.
causal_info = {
    'epistatic_pairs': [[int(snp) for snp in pair] for pair in sim['pair_idx']],
    'additive_snps': [int(snp) for snp in sim['additive_idx']],
}
with open(f"{output_dir}/causal_info.json", 'w') as f:
    json.dump(causal_info, f, indent=2)

print("Synthetic data generation complete.")


Synthetic data generation complete.


## --- Part 2: Statistical Sanity Check ---

### 1. Load Ground Truth and Data

In [7]:
def _maf_from_dosage(g):
    g = np.asarray(g, float)
    g = g[~np.isnan(g)]
    if g.size == 0: return np.nan
    p = np.clip(g.mean() / 2.0, 0.0, 1.0)
    return min(p, 1 - p)

def _generate_penetrance_table_from_df(df, snp_cols, phenocol):
    k = len(snp_cols)
    pen = {gt: [] for gt in itertools.product(range(3), repeat=k)}
    gmat = df[snp_cols].astype('float')
    y = df[phenocol].values
    for idx in range(len(df)):
        gt_tuple = tuple(int(gmat.iloc[idx, c]) if pd.notna(gmat.iloc[idx, c]) else -1 for c in range(k)) 
        if -1 in gt_tuple: continue
        pen[gt_tuple].append(y[idx])
    return pen

def _one_way_anova_from_ptable(pen_table):
    non_empty = [phens for phens in pen_table.values() if len(phens) > 0]
    if len(non_empty) < 2: return np.nan
    _, p = sps.f_oneway(*non_empty)
    return float(p)


In [8]:
# Reload everything from disk so the checks below use the same files that are
# passed to the training script.
causal_info_path = f"{output_dir}/causal_info.json"
with open(causal_info_path, 'r') as f:
    causal_info = json.load(f)

# Ground-truth SNP indices (ints) written in Part 1.
true_epistatic_pairs = list(map(tuple, causal_info['epistatic_pairs']))
true_additive_snps = causal_info.get('additive_snps', [])
true_epistatic_snps = list(itertools.chain.from_iterable(true_epistatic_pairs))

# Genotypes indexed by IID; SNP column labels made str because they are
# looked up as str(snp) below.
genotypes = pd.read_csv(f"{output_dir}/genotypes.tsv", sep='\t', index_col='IID')
genotypes.columns = genotypes.columns.astype(str)

# Covariates + phenotype joined on (FID, IID), then indexed by IID.
cov_raw = pd.read_csv(f"{output_dir}/simulation.cov", sep='\t')
pheno = pd.read_csv(f"{output_dir}/simulation.pheno", sep='\t')
cov_pheno = cov_raw.merge(pheno, on=['FID', 'IID']).set_index('IID')

# One frame aligned on IID: SNP dosage columns followed by FID/SEX/AGE/phenotype.
df_full = pd.concat([genotypes, cov_pheno], axis=1)

### 2. Comprehensive Additive Effect Check for ALL SNPs

In [9]:
# Marginal (additive) scan: for every SNP fit
#     phenotype ~ const + SEX + AGE + SNP
# by OLS and record the SNP coefficient and p-value with its ground-truth labels.
all_snps = genotypes.columns.tolist()
results_list = []
for snp in all_snps:
    snp_frame = df_full[['phenotype', 'SEX', 'AGE', snp]].dropna()
    design = sm.add_constant(snp_frame[['SEX', 'AGE', snp]])
    ols_fit = sm.OLS(snp_frame['phenotype'], design).fit()
    snp_index = int(snp)  # ground-truth lists hold integer SNP indices
    results_list.append(dict(
        SNP_ID=snp,
        Coefficient=ols_fit.params[snp],
        P_Value=ols_fit.pvalues[snp],
        MAF=_maf_from_dosage(snp_frame[snp]),
        Is_True_Additive=snp_index in true_additive_snps,
        Is_In_Epistatic_Pair=snp_index in true_epistatic_snps,
    ))

In [10]:
# Per-SNP results, most significant first.
results_df = (
    pd.DataFrame(results_list)
    .sort_values(by='P_Value')
    .reset_index(drop=True)
)
# Let the whole per-SNP table display without row truncation.
pd.options.display.max_rows = results_df.shape[0] + 10
pd.options.display.width = 120

In [11]:
# Top SNPs by marginal p-value.
results_df.head(n=5)

Unnamed: 0,SNP_ID,Coefficient,P_Value,MAF,Is_True_Additive,Is_In_Epistatic_Pair
0,237,-0.215986,9.857912000000001e-52,0.40345,True,False
1,28,0.300011,1.2113010000000001e-43,0.1199,True,False
2,238,-0.392271,1.348198e-40,0.05975,True,False
3,27,0.361018,9.330822000000001e-39,0.06875,True,False
4,74,0.349903,5.103858999999999e-38,0.0718,True,False


### 3. Interaction check for true epistatic pairs

In [12]:
# Interaction check for each ground-truth epistatic pair with two tests:
#   (a) one-way ANOVA of the phenotype across the genotype cells of the pair
#       (no covariates). NOTE: this tests whether ANY cell mean differs, so it
#       also picks up marginal effects of either SNP. It is a joint association
#       test, not a pure interaction test.
#   (b) OLS  phenotype ~ const + SEX + AGE + snp1 + snp2 + snp1*snp2,
#       using the p-value of the product term.
# Each set of p-values is Benjamini-Hochberg corrected across pairs.
FDR_ALPHA = 0.05


def _add_bh_fdr(df, pcol, alpha=FDR_ALPHA):
    """Add a Benjamini-Hochberg adjusted column ``<pcol>_fdr`` to ``df`` in place.

    NaN p-values are excluded from the correction and get NaN in the adjusted
    column. Returns the number of tests rejected at FDR ``alpha``.
    """
    fdr_col = f'{pcol}_fdr'
    tested = df[pcol].notna()
    df[fdr_col] = np.nan
    if not tested.any():
        return 0
    reject, fdr, _, _ = multipletests(
        df.loc[tested, pcol].to_numpy(dtype=float), alpha=alpha, method='fdr_bh'
    )
    df.loc[tested, fdr_col] = fdr
    return int(reject.sum())


interaction_cols = ['snp1', 'snp2', 'p_value_anova', 'p_value_linear', 'maf1', 'maf2']
interaction_results = []
for snp1, snp2 in true_epistatic_pairs:
    s1, s2 = str(snp1), str(snp2)
    row = dict.fromkeys(interaction_cols, np.nan)
    row.update(snp1=snp1, snp2=snp2)
    # Pairs whose SNPs are absent from the genotype matrix keep NaN statistics.
    if s1 in df_full.columns and s2 in df_full.columns:
        # (a) ANOVA over the genotype cells.
        df_pair = df_full[[s1, s2, 'phenotype']].dropna()
        row['maf1'] = _maf_from_dosage(df_pair[s1].values)
        row['maf2'] = _maf_from_dosage(df_pair[s2].values)
        pen_table = _generate_penetrance_table_from_df(df_pair, [s1, s2], 'phenotype')
        row['p_value_anova'] = _one_way_anova_from_ptable(pen_table)

        # (b) Covariate-adjusted OLS with a product term. Build the term with
        # .assign so no column is written into a sliced frame.
        df_linear = df_full[['phenotype', 'SEX', 'AGE', s1, s2]].dropna()
        df_linear = df_linear.assign(interaction=df_linear[s1] * df_linear[s2])
        X = sm.add_constant(df_linear[['SEX', 'AGE', s1, s2, 'interaction']])
        model = sm.OLS(df_linear['phenotype'], X).fit()
        row['p_value_linear'] = model.pvalues['interaction']
    interaction_results.append(row)

# Explicit columns so interaction_df (and the display cell below) also exist
# when there are no pairs to evaluate.
interaction_df = pd.DataFrame(interaction_results, columns=interaction_cols)
n_sig_anova = _add_bh_fdr(interaction_df, 'p_value_anova')
n_sig_linear = _add_bh_fdr(interaction_df, 'p_value_linear')
interaction_df = interaction_df.sort_values(by='p_value_linear').reset_index(drop=True)

if not true_epistatic_pairs:
    print("No true epistatic pairs to evaluate.")
else:
    n_pairs_total = len(true_epistatic_pairs)
    # The per-pair table itself is displayed by the next cell.
    print("\n--- Full Report for True Epistatic Pairs (Interaction Tests) ---")
    print(f"\nSummary (ANOVA): Found {n_sig_anova} / {n_pairs_total} pairs to be significant (FDR < {FDR_ALPHA}).")
    print(f"Summary (Linear Model): Found {n_sig_linear} / {n_pairs_total} pairs to be significant (FDR < {FDR_ALPHA}).")
    print("-----------------------------------------------------------------")



--- Full Report for True Epistatic Pairs (Interaction Tests) ---

Summary (ANOVA): Found 50 / 50 pairs to be significant (FDR < 0.05).
Summary (Linear Model): Found 49 / 50 pairs to be significant (FDR < 0.05).
-----------------------------------------------------------------


In [13]:
# Most significant true pairs by the linear-model interaction p-value.
interaction_df.head(n=5)

Unnamed: 0,snp1,snp2,p_value_anova,p_value_linear,maf1,maf2,p_value_anova_fdr,p_value_linear_fdr
0,453,212,2.26457e-24,2.087147e-14,0.1258,0.14115,3.774283e-23,1.043573e-12
1,208,346,1.405397e-10,8.352445e-12,0.10765,0.4783,5.019276e-10,2.088111e-10
2,86,427,8.580013e-43,1.512099e-11,0.37425,0.4288,4.290007e-41,2.520165e-10
3,263,498,3.200972e-12,5.623156e-10,0.1141,0.2081,1.778318e-11,7.028946e-09
4,255,67,7.145301e-10,7.218037e-10,0.39205,0.1854,2.381767e-09,7.218037e-09


# Step 2: Training the model on the generated data

In [14]:
# Train the SNP2P model (hierarchical-transformer variant) on the files written in Step 1.
# NOTE(review): the paths are hard-coded and must match `output_dir`. Step 3
# loads `output_model.txt.50`, presumably the checkpoint saved at epoch 50
# (--epochs 51 with --val-step 50) — confirm in train_snp2p_model.py.
!python train_snp2p_model.py \
    --train-tsv "epistasis_simulation_samples/genotypes.tsv" \
    --train-pheno "epistasis_simulation_samples/simulation.pheno" \
    --train-cov "epistasis_simulation_samples/simulation.cov" \
    --onto "epistasis_simulation_samples/ontology.tsv" \
    --snp2gene "epistasis_simulation_samples/snp2gene.tsv" \
    --out "epistasis_simulation_samples/output_model.txt" \
    --epochs 51 \
    --batch-size 64 \
    --lr 1e-4 \
    --qt "phenotype" \
    --jobs 4 \
    --cuda 0 \
    --sys2env --env2sys --sys2gene \
    --sys2pheno --gene2pheno \
    --val-step 50 \
    --use_hierarchical_transformer

Python __main__
Start Process
DDP setup done
[0/1] running on nrnb-gpu-06 GPU 0, rank: 0, local_rank: 0
Finish setup main worker 0
Processing Ontology dataframe...
Building system and gene indices dictionary..
Creating masks..
102 Systems are queried
250 Genes are queried
Building descendant dict
Subtree types:  ['default']
Processing Ontology dataframe...
Building system and gene indices dictionary..
Creating masks..
102 Systems are queried
250 Genes are queried
Building descendant dict
Subtree types:  ['default']
   snp   gene  chr   pos  block
0    0  G0103    7   415      4
1    1  G0243    7  1267      4
2    2  G0061    7  1310      4
3    3  G0020    7  1970      4
4    4  G0145    7  2419      4
The number of SNPs: 500
Loading TSV data from epistasis_simulation_samples/genotypes.tsv
Processing Covariates...
Loading Covariate file at epistasis_simulation_samples/simulation.cov
Processing Phenotypes
Loading Phenotype file at epistasis_simulation_samples/simulation.pheno
Phenotype

# Step 3: Predicting attention from the trained model

In [17]:
# Run the trained checkpoint over the same input files to extract attention.
# NOTE(review): the --out prefix presumably yields the
# `output_model.txt.50.phenotype.head_sum*.csv` files read in Step 4 — confirm.
!python predict_attention.py \
    --model "epistasis_simulation_samples/output_model.txt.50" \
    --tsv "epistasis_simulation_samples/genotypes.tsv" \
    --pheno "epistasis_simulation_samples/simulation.pheno" \
    --cov "epistasis_simulation_samples/simulation.cov" \
    --onto "epistasis_simulation_samples/ontology.tsv" \
    --snp2gene "epistasis_simulation_samples/snp2gene.tsv" \
    --out "epistasis_simulation_samples/output_model.txt.50" \
    --batch-size 256 \
    --cuda 0

Processing Ontology dataframe...
Building system and gene indices dictionary..
Creating masks..
102 Systems are queried
250 Genes are queried
Building descendant dict
Subtree types:  ['default']
Processing Ontology dataframe...
Building system and gene indices dictionary..
Creating masks..
102 Systems are queried
250 Genes are queried
Building descendant dict
Subtree types:  ['default']
   snp   gene  chr   pos  block
0    0  G0103    7   415      4
1    1  G0243    7  1267      4
2    2  G0061    7  1310      4
3    3  G0020    7  1970      4
4    4  G0145    7  2419      4
The number of SNPs: 500
  g2p_model_dict = torch.load(args.model, map_location='cuda:0')
Namespace(batch_size=64, block_bias=False, bt=[], bt_inds=[], cov_effect='pre', cov_ids=[], cov_mean_dict={'AGE': 54.4924}, cov_std_dict={'AGE': 8.590408706934571}, cuda=0, dense_attention=False, distributed=False, dropout=0.2, env2sys=True, epochs=51, flip=False, focal_loss_alpha=0.25, focal_loss_gamma=2.0, gene2pheno=True, ge

# Step 4: Searching for epistasis and evaluating epistasis-retrieval performance

In [19]:
from src.utils.analysis.epistasis_retrieval_evaluation import EvaluationConfig, EpistasisRetrievalEvaluator

In [32]:
# Evaluation inputs: the ground truth, the attention / system-importance files
# under the Step 3 --out prefix, and the same data files used for training.
config = EvaluationConfig(
    # Ground truth and model outputs
    causal_info='epistasis_simulation_samples/causal_info.json',
    attention_results='epistasis_simulation_samples/output_model.txt.50.phenotype.head_sum.csv',
    system_importance='epistasis_simulation_samples/output_model.txt.50.phenotype.head_sum.sys_importance.csv',
    # Data files (same as training)
    tsv='epistasis_simulation_samples/genotypes.tsv',
    pheno='epistasis_simulation_samples/simulation.pheno',
    cov='epistasis_simulation_samples/simulation.cov',
    onto='epistasis_simulation_samples/ontology.tsv',
    snp2gene='epistasis_simulation_samples/snp2gene.tsv',
    # Search scope
    top_n_systems=5,    # None -> retrieve from all systems
    snp_threshold=50,   # None -> retrieve epistasis regardless of system size
    quantiles=[0.9],
    # Execution and output
    num_workers=1,
    executor_type='threads',
    output_prefix='epistasis_simulation_samples/simulation_output.50',
)

In [33]:
# Build the retrieval evaluator from the configuration above; nothing is
# loaded until .evaluate() is called in the next cell.
evaluator = EpistasisRetrievalEvaluator(config)

In [34]:
# Run the retrieval evaluation. Per its log, it loads the causal info and the
# system importance, selects the top systems, then runs a pre-flight check of
# the data mapping before searching.
evaluator.evaluate()

--- [1/5] Loading inputs ---
--- Loading causal info from epistasis_simulation_samples/causal_info.json ---
--- Loaded 50 epistatic pairs and 100 additive SNPs ---
--- Loading system importance from epistasis_simulation_samples/output_model.txt.50.phenotype.head_sum.sys_importance.csv ---
--- System importance loaded. Shape: (102, 7) ---
--- Identified top 5 systems ---

--- Running Pre-flight Diagnostic Check ---
Initializing a temporary parser to check data mapping...
Processing Ontology dataframe...
Building system and gene indices dictionary..
Creating masks..
102 Systems are queried
250 Genes are queried
Building descendant dict
Subtree types:  ['default']
Processing Ontology dataframe...
Building system and gene indices dictionary..
Creating masks..
102 Systems are queried
250 Genes are queried
Building descendant dict
Subtree types:  ['default']
   snp   gene  chr   pos  block
0    0  G0103    7   415      4
1    1  G0243    7  1267      4
2    2  G0061    7  1310      4
3    3 