In [1]:
import pandas as pd
from pathlib import Path

pd.set_option("display.max_columns", 10)

base = Path("data/clean_classifier_data")

# ----- FILE REGISTRY -----
parquet_files = {
    "atac_pseudobulk_df": base / "clean_data_files/atac_pseudobulk_df.parquet",
    "rna_pseudobulk_df":  base / "clean_data_files/rna_pseudobulk_df.parquet",
    "peak_to_gene_dist_df": base / "clean_feature_files/peak_to_gene_dist_df.parquet",
    "sliding_window_df":    base / "clean_feature_files/sliding_window_df.parquet",
}

csv_files = {
    "bear_no_beeline_df":              base / "clean_ground_truth_files/bear_no_beeline_df.csv",
    "bear_no_beeline_or_chip_df":      base / "clean_ground_truth_files/bear_no_beeline_or_chip_df.csv",
    "bear_no_chip_df":                 base / "clean_ground_truth_files/bear_no_chip_df.csv",
    "beeline_no_bear_df":              base / "clean_ground_truth_files/beeline_no_bear_df.csv",
    "beeline_no_chip_df":              base / "clean_ground_truth_files/beeline_no_chip_df.csv",
    "beeline_no_chip_or_bear_df":      base / "clean_ground_truth_files/beeline_no_chip_or_bear_df.csv",
    "chip_no_beeline_df":              base / "clean_ground_truth_files/chip_no_beeline_df.csv",
    "chip_no_beeline_or_bear_df":      base / "clean_ground_truth_files/chip_no_beeline_or_bear_df.csv",
    "chip_no_bear_df":                 base / "clean_ground_truth_files/chip_no_bear_df.csv",
}

# ----- LOAD DATA -----
dfs = {}
for name, path in parquet_files.items():
    dfs[name] = pd.read_parquet(path)
for name, path in csv_files.items():
    dfs[name] = pd.read_csv(path)
    


# ----- LOG HEADS & SHAPES -----
for name, df in dfs.items():
    print(f"\n{name}")
    print(df.head(3))
    print(df.shape)



atac_pseudobulk_df
                      E7.5_REP1.AAACAGCCAAACCCTA  E7.5_REP1.AAACAGCCAGGAACTG  \
peak_id                                                                        
chr1:3142536-3143136                    0.093238                    0.000000   
chr1:3553099-3553699                    0.035439                    0.199449   
chr1:3584996-3585596                    0.000000                    0.000000   

                      E7.5_REP1.AAACAGCCATCCTGAA  E7.5_REP1.AAACAGCCATGCTATG  \
peak_id                                                                        
chr1:3142536-3143136                         0.0                    0.040567   
chr1:3553099-3553699                         0.0                    0.360751   
chr1:3584996-3585596                         0.0                    0.000000   

                      E7.5_REP1.AAACATGCAATGAATG  ...  \
peak_id                                           ...   
chr1:3142536-3143136                    0.000000  ...   
chr1:35

In [None]:
import numpy as np
import pandas as pd
from numba import jit
from scipy.stats import pearsonr, spearmanr
from scipy.special import rel_entr
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import cdist
import warnings
from tqdm.auto import tqdm
warnings.filterwarnings('ignore')

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def safe_divide(a, b, fill_value=0.0):
    """Safe division with fill value for division by zero."""
    return np.divide(a, b, out=np.full_like(a, fill_value, dtype=float), where=b!=0)

def entropy(x):
    """Calculate Shannon entropy of array x."""
    x = x[x > 0]  # Remove zeros
    if len(x) == 0:
        return 0.0
    p = x / np.sum(x)
    return -np.sum(p * np.log2(p))

def partial_correlation(x, y, z):
    """Calculate partial correlation between x and y controlling for z."""
    if len(x) < 3:
        return np.nan
    try:
        # Residuals after regressing out z
        x_resid = x - np.mean(x)
        y_resid = y - np.mean(y)
        z_norm = (z - np.mean(z)) / (np.std(z) + 1e-10)
        
        x_resid = x_resid - np.dot(x_resid, z_norm) * z_norm / (np.dot(z_norm, z_norm) + 1e-10)
        y_resid = y_resid - np.dot(y_resid, z_norm) * z_norm / (np.dot(z_norm, z_norm) + 1e-10)
        
        # Correlation of residuals
        corr = np.corrcoef(x_resid, y_resid)[0, 1]
        return corr if not np.isnan(corr) else 0.0
    except:
        return np.nan

@jit(nopython=True)
def fast_gini(x):
    """Numba-compiled Gini coefficient."""
    if len(x) == 0 or np.sum(x) == 0:
        return 0.0
    sorted_x = np.sort(x)
    n = len(x)
    cumsum = np.cumsum(sorted_x)
    return (2 * np.sum((n - np.arange(n)) * sorted_x)) / (n * cumsum[-1]) - (n + 1) / n

@jit(nopython=True)
def fast_entropy(x):
    """Numba-compiled Shannon entropy."""
    x_pos = x[x > 0]
    if len(x_pos) == 0:
        return 0.0
    p = x_pos / np.sum(x_pos)
    return -np.sum(p * np.log2(p))

@jit(nopython=True)
def compute_distance_features(scores, tss_scores, tss_distances):
    """Numba-compiled distance feature computation."""
    n = len(scores)
    
    # Initialize outputs
    min_tss_dist = np.inf
    mean_tss_dist = np.inf
    closest_peak_binding = 0.0
    peaks_1kb = 0
    peaks_5kb = 0
    peaks_10kb = 0
    peaks_50kb = 0
    
    abs_distances = np.abs(tss_distances)
    binding_mask = scores > 0
    
    if np.sum(binding_mask) > 0:
        bound_distances = abs_distances[binding_mask]
        bound_scores = scores[binding_mask]
        
        min_tss_dist = np.min(bound_distances)
        closest_idx = np.argmin(bound_distances)
        closest_peak_binding = bound_scores[closest_idx]
        
        # Weighted mean distance
        mean_tss_dist = np.sum(bound_distances * bound_scores) / np.sum(bound_scores)
        
        # Binned counts
        peaks_1kb = np.sum((abs_distances < 1000) & binding_mask)
        peaks_5kb = np.sum((abs_distances < 5000) & binding_mask)
        peaks_10kb = np.sum((abs_distances < 10000) & binding_mask)
        peaks_50kb = np.sum((abs_distances < 50000) & binding_mask)
    
    return min_tss_dist, mean_tss_dist, closest_peak_binding, peaks_1kb, peaks_5kb, peaks_10kb, peaks_50kb

# ============================================================================
# PRECOMPUTE CORRELATION MATRICES (ONE-TIME COMPUTATION)
# ============================================================================

def precompute_correlation_matrices(rna_df, atac_df, peak_to_gene_df):
    """
    Pre-compute correlation matrices to avoid redundant calculations.
    Fully vectorized for maximum speed.
    """
    print("Pre-computing correlation matrices...")
    
    # Get unique genes and TFs
    genes = peak_to_gene_df['TG'].unique()
    tfs = rna_df.index.tolist()
    
    # ========================================================================
    # 1. GENE ACCESSIBILITY (sum of peaks per gene)
    # ========================================================================
    print("\n[1/4] Computing gene accessibility...")
    gene_to_peaks = peak_to_gene_df.groupby('TG')['peak_id'].apply(list).to_dict()
    
    gene_accessibility_dict = {}
    for gene in tqdm(genes, desc="  Gene accessibility"):
        if gene in gene_to_peaks:
            peaks = gene_to_peaks[gene]
            valid_peaks = [p for p in peaks if p in atac_df.index]
            if valid_peaks:
                gene_accessibility_dict[gene] = atac_df.loc[valid_peaks].sum(axis=0).values
            else:
                gene_accessibility_dict[gene] = np.zeros(atac_df.shape[1])
        else:
            gene_accessibility_dict[gene] = np.zeros(atac_df.shape[1])
    
    # ========================================================================
    # 2. TF-TG EXPRESSION CORRELATIONS (Vectorized)
    # ========================================================================
    print("\n[2/4] Computing TF-TG expression correlations...")
    
    # Get subset of RNA matrix for genes of interest
    genes_in_rna = [g for g in genes if g in rna_df.index]
    tfs_in_rna = [tf for tf in tfs if tf in rna_df.index]
    
    # Create index mappings for fast lookup
    rna_gene_to_idx = {gene: idx for idx, gene in enumerate(rna_df.index)}
    
    # Extract TF and gene expression matrices
    tf_indices = [rna_gene_to_idx[tf] for tf in tfs_in_rna]
    gene_indices = [rna_gene_to_idx[g] for g in genes_in_rna]
    
    tf_matrix = rna_df.values[tf_indices, :]  # (n_tfs, n_cells)
    gene_matrix = rna_df.values[gene_indices, :]  # (n_genes, n_cells)
    
    # Compute Pearson correlation matrix all at once
    print("  Computing Pearson correlations (vectorized)...")
    # Pearson: use numpy corrcoef on stacked arrays, then extract submatrix
    # More memory efficient: compute in chunks if needed
    
    tf_tg_corr_pearson = {}
    tf_tg_corr_spearman = {}
    
    # Use pandas for efficient correlation computation
    tf_df = pd.DataFrame(tf_matrix, index=tfs_in_rna)
    gene_df = pd.DataFrame(gene_matrix, index=genes_in_rna)
    
    # Compute full correlation matrix (TFs x Genes)
    # This is much faster than nested loops
    pearson_corr_matrix = np.corrcoef(tf_matrix, gene_matrix)[:len(tfs_in_rna), len(tfs_in_rna):]
    
    # Store in dictionary
    for i, tf in enumerate(tqdm(tfs_in_rna, desc="  Storing Pearson")):
        for j, gene in enumerate(genes_in_rna):
            corr_val = pearson_corr_matrix[i, j]
            tf_tg_corr_pearson[(tf, gene)] = corr_val if not np.isnan(corr_val) else 0.0
    
    # Compute Spearman (on ranks) - this is slower, so we can do in batches
    print("  Computing Spearman correlations (vectorized)...")
    
    # Rank transform the matrices
    from scipy.stats import rankdata
    tf_ranks = np.apply_along_axis(rankdata, 1, tf_matrix)
    gene_ranks = np.apply_along_axis(rankdata, 1, gene_matrix)
    
    # Compute correlation on ranks (Spearman)
    spearman_corr_matrix = np.corrcoef(tf_ranks, gene_ranks)[:len(tfs_in_rna), len(tfs_in_rna):]
    
    for i, tf in enumerate(tqdm(tfs_in_rna, desc="  Storing Spearman")):
        for j, gene in enumerate(genes_in_rna):
            corr_val = spearman_corr_matrix[i, j]
            tf_tg_corr_spearman[(tf, gene)] = corr_val if not np.isnan(corr_val) else 0.0
    
    # ========================================================================
    # 3. TF-ACCESSIBILITY CORRELATIONS (Vectorized)
    # ========================================================================
    print("\n[3/4] Computing TF-accessibility correlations...")
    
    # Create gene accessibility matrix
    gene_acc_matrix = np.array([gene_accessibility_dict[g] for g in genes])  # (n_genes, n_cells)
    
    # Compute correlation between TFs and gene accessibility
    tf_atac_corr_matrix = np.corrcoef(tf_matrix, gene_acc_matrix)[:len(tfs_in_rna), len(tfs_in_rna):]
    
    tf_atac_corr = {}
    for i, tf in enumerate(tqdm(tfs_in_rna, desc="  TF-accessibility")):
        for j, gene in enumerate(genes):
            corr_val = tf_atac_corr_matrix[i, j]
            tf_atac_corr[(tf, gene)] = corr_val if not np.isnan(corr_val) else 0.0
    
    # ========================================================================
    # 4. TG EXPRESSION-ACCESSIBILITY CORRELATIONS (Vectorized)
    # ========================================================================
    print("\n[4/4] Computing TG expression-accessibility correlations...")
    
    atac_rna_corr = {}
    
    # For each gene, compute correlation between its expression and accessibility
    for i, gene in enumerate(tqdm(genes, desc="  Gene expr-accessibility")):
        if gene in genes_in_rna:
            gene_idx = gene_indices[genes_in_rna.index(gene)]
            tg_expr = rna_df.values[gene_idx, :]
            gene_acc = gene_accessibility_dict[gene]
            
            corr = np.corrcoef(tg_expr, gene_acc)[0, 1]
            atac_rna_corr[gene] = corr if not np.isnan(corr) else 0.0
        else:
            atac_rna_corr[gene] = 0.0
    
    print("\n✓ Correlation matrices computed!\n")
    
    return {
        'tf_tg_corr_pearson': tf_tg_corr_pearson,
        'tf_tg_corr_spearman': tf_tg_corr_spearman,
        'tf_atac_corr': tf_atac_corr,
        'atac_rna_corr': atac_rna_corr,
        'gene_accessibility_dict': gene_accessibility_dict
    }

# ============================================================================
# MAIN FEATURE COMPUTATION
# ============================================================================

def compute_binding_features_fast(tf_tg_binding):
    """
    Fast vectorized computation of binding features.
    Uses groupby with built-in aggregations + custom numba functions.
    """
    print("Computing binding features (vectorized)...")
    
    # Pre-compute weighted scores
    tf_tg_binding['weighted_score'] = (
        tf_tg_binding['sliding_window_score'] * tf_tg_binding['tss_distance_score']
    )
    tf_tg_binding['abs_tss_distance'] = np.abs(tf_tg_binding['tss_distance'])
    tf_tg_binding['has_binding'] = (tf_tg_binding['sliding_window_score'] > 0).astype(int)
    
    # Proximity score (exponential decay)
    tf_tg_binding['proximity_score'] = (
        tf_tg_binding['sliding_window_score'] * 
        np.exp(-tf_tg_binding['abs_tss_distance'] / 1000)
    )
    
    # Binned distance indicators
    tf_tg_binding['within_1kb'] = (
        (tf_tg_binding['abs_tss_distance'] < 1000) & 
        (tf_tg_binding['sliding_window_score'] > 0)
    ).astype(int)
    tf_tg_binding['within_5kb'] = (
        (tf_tg_binding['abs_tss_distance'] < 5000) & 
        (tf_tg_binding['sliding_window_score'] > 0)
    ).astype(int)
    tf_tg_binding['within_10kb'] = (
        (tf_tg_binding['abs_tss_distance'] < 10000) & 
        (tf_tg_binding['sliding_window_score'] > 0)
    ).astype(int)
    tf_tg_binding['within_50kb'] = (
        (tf_tg_binding['abs_tss_distance'] < 50000) & 
        (tf_tg_binding['sliding_window_score'] > 0)
    ).astype(int)
    
    # ========================================================================
    # FAST AGGREGATION (built-in functions)
    # ========================================================================
    print("  Phase 1: Basic aggregations...")
    
    agg_dict = {
        'sliding_window_score': [
            ('max_binding_score', 'max'),
            ('mean_binding_score', 'mean'),
            ('median_binding_score', 'median'),
            ('std_binding_score', 'std'),
            ('sum_binding_score', 'sum'),
        ],
        'weighted_score': [
            ('distance_weighted_binding', 'sum'),
            ('max_weighted_binding', 'max'),
        ],
        'has_binding': [
            ('n_peaks_with_binding', 'sum'),
        ],
        'proximity_score': [
            ('proximity_binding_score', 'sum'),
        ],
        'within_1kb': [('peaks_within_1kb', 'sum')],
        'within_5kb': [('peaks_within_5kb', 'sum')],
        'within_10kb': [('peaks_within_10kb', 'sum')],
        'within_50kb': [('peaks_within_50kb', 'sum')],
    }
    
    # This is MUCH faster than apply()
    features = tf_tg_binding.groupby(['TF', 'TG'], observed=True).agg(**{
        name: (col, func) for col, funcs in agg_dict.items() for name, func in funcs
    })
    
    # Add fraction peaks bound
    peak_counts = tf_tg_binding.groupby(['TF', 'TG'], observed=True).size()
    features['fraction_peaks_bound'] = features['n_peaks_with_binding'] / peak_counts
    
    # Compute 75th percentile threshold and strong peaks count
    print("  Phase 2: Strong peaks (75th percentile)...")
    percentile_75 = tf_tg_binding.groupby(['TF', 'TG'], observed=True)['sliding_window_score'].quantile(0.75)
    
    # Count peaks above 75th percentile
    tf_tg_binding_with_p75 = tf_tg_binding.merge(
        percentile_75.rename('p75'),
        left_on=['TF', 'TG'],
        right_index=True
    )
    tf_tg_binding_with_p75['is_strong'] = (
        tf_tg_binding_with_p75['sliding_window_score'] > tf_tg_binding_with_p75['p75']
    ).astype(int)
    
    n_strong_peaks = tf_tg_binding_with_p75.groupby(['TF', 'TG'], observed=True)['is_strong'].sum()
    features['n_strong_peaks'] = n_strong_peaks
    
    # ========================================================================
    # DISTANCE FEATURES (using numba for speed)
    # ========================================================================
    print("  Phase 3: Distance features (numba-accelerated)...")
    
    # Group data for numba processing
    grouped = tf_tg_binding.groupby(['TF', 'TG'], observed=True)
    
    distance_results = []
    for (tf, tg), group in tqdm(grouped, desc="    Distance features"):
        scores = group['sliding_window_score'].values
        tss_scores = group['tss_distance_score'].values
        tss_distances = group['tss_distance'].values
        
        min_dist, mean_dist, closest_binding, p1kb, p5kb, p10kb, p50kb = compute_distance_features(
            scores, tss_scores, tss_distances
        )
        
        distance_results.append({
            'TF': tf,
            'TG': tg,
            'min_tss_distance': min_dist,
            'mean_tss_distance': mean_dist,
            'closest_peak_binding': closest_binding,
            'peaks_within_1kb_check': p1kb,
            'peaks_within_5kb_check': p5kb,
            'peaks_within_10kb_check': p10kb,
            'peaks_within_50kb_check': p50kb,
        })
    
    distance_df = pd.DataFrame(distance_results).set_index(['TF', 'TG'])
    
    # Merge distance features
    features = features.join(distance_df)
    
    # Use the numba-computed binned counts (they should match the vectorized ones, but numba is more accurate)
    features['peaks_within_1kb'] = features['peaks_within_1kb_check']
    features['peaks_within_5kb'] = features['peaks_within_5kb_check']
    features['peaks_within_10kb'] = features['peaks_within_10kb_check']
    features['peaks_within_50kb'] = features['peaks_within_50kb_check']
    features.drop(columns=['peaks_within_1kb_check', 'peaks_within_5kb_check', 
                           'peaks_within_10kb_check', 'peaks_within_50kb_check'], inplace=True)
    
    # ========================================================================
    # ENTROPY AND GINI (using numba)
    # ========================================================================
    print("  Phase 4: Entropy and Gini (numba-accelerated)...")
    
    entropy_results = []
    for (tf, tg), group in tqdm(grouped, desc="    Entropy/Gini"):
        scores = group['sliding_window_score'].values
        
        entropy_val = fast_entropy(scores)
        gini_val = fast_gini(scores)
        
        entropy_results.append({
            'TF': tf,
            'TG': tg,
            'binding_entropy': entropy_val,
            'binding_gini': gini_val,
        })
    
    entropy_df = pd.DataFrame(entropy_results).set_index(['TF', 'TG'])
    features = features.join(entropy_df)
    
    # ========================================================================
    # NEGATIVE EVIDENCE
    # ========================================================================
    features['no_binding_near_tss'] = (features['peaks_within_10kb'] == 0).astype(int)
    
    # Reset index
    features = features.reset_index()
    
    print(f"✓ Computed binding features for {len(features)} TF-TG pairs\n")
    
    return features

def add_expression_features(features_df, rna_df, precomputed):
    """Add expression-based features."""
    print("Adding expression features...")
    
    rna_genes = rna_df.index.tolist()
    rna_matrix = rna_df.values
    
    # Vectorized expression features
    tf_list = features_df['TF'].values
    tg_list = features_df['TG'].values
    
    # Initialize arrays
    n_pairs = len(features_df)
    tf_mean_expr = np.zeros(n_pairs)
    tf_std_expr = np.zeros(n_pairs)
    tf_cv = np.zeros(n_pairs)
    tf_detection_rate = np.zeros(n_pairs)
    
    tg_mean_expr = np.zeros(n_pairs)
    tg_std_expr = np.zeros(n_pairs)
    tg_cv = np.zeros(n_pairs)
    tg_detection_rate = np.zeros(n_pairs)
    
    tf_tg_expr_ratio = np.zeros(n_pairs)
    codetection_rate = np.zeros(n_pairs)
    
    tf_tg_pearson_corr = np.zeros(n_pairs)
    tf_tg_spearman_corr = np.zeros(n_pairs)
    
    # Context-aware features
    tf_active_when_tg_active = np.zeros(n_pairs)
    tg_expr_in_top_tf_quartile = np.zeros(n_pairs)
    tg_expr_in_bottom_tf_quartile = np.zeros(n_pairs)
    differential_tg_expr = np.zeros(n_pairs)
    
    tf_not_expressed = np.zeros(n_pairs)
    anticorrelation_flag = np.zeros(n_pairs)
    
    for i, (tf, tg) in enumerate(tqdm(zip(tf_list, tg_list), desc="  Computing expression features")):
        # Get expression data
        if tf in rna_genes:
            tf_idx = rna_genes.index(tf)
            tf_expr = rna_matrix[tf_idx]
            
            tf_mean_expr[i] = np.mean(tf_expr)
            tf_std_expr[i] = np.std(tf_expr)
            tf_cv[i] = tf_std_expr[i] / (tf_mean_expr[i] + 1e-10)
            tf_detection_rate[i] = np.mean(tf_expr > 0)
            
            # Thresholds for context features
            tf_threshold = np.percentile(tf_expr, 25) if tf_mean_expr[i] > 0 else 0
            tf_75 = np.percentile(tf_expr, 75)
            tf_25 = np.percentile(tf_expr, 25)
            
            tf_not_expressed[i] = 1 if tf_mean_expr[i] < tf_threshold else 0
        else:
            tf_expr = None
            tf_not_expressed[i] = 1
        
        if tg in rna_genes:
            tg_idx = rna_genes.index(tg)
            tg_expr_arr = rna_matrix[tg_idx]
            
            tg_mean_expr[i] = np.mean(tg_expr_arr)
            tg_std_expr[i] = np.std(tg_expr_arr)
            tg_cv[i] = tg_std_expr[i] / (tg_mean_expr[i] + 1e-10)
            tg_detection_rate[i] = np.mean(tg_expr_arr > 0)
            
            # TG threshold for context features
            tg_threshold = np.percentile(tg_expr_arr, 25) if tg_mean_expr[i] > 0 else 0
        else:
            tg_expr_arr = None
        
        # TF-TG relationship features
        if tf_expr is not None and tg_expr_arr is not None:
            tf_tg_expr_ratio[i] = tf_mean_expr[i] / (tg_mean_expr[i] + 1e-10)
            codetection_rate[i] = np.mean((tf_expr > 0) & (tg_expr_arr > 0))
            
            # Get precomputed correlations
            tf_tg_pearson_corr[i] = precomputed['tf_tg_corr_pearson'].get((tf, tg), 0.0)
            tf_tg_spearman_corr[i] = precomputed['tf_tg_corr_spearman'].get((tf, tg), 0.0)
            
            # Context-aware features
            tg_active_mask = tg_expr_arr > tg_threshold
            if np.any(tg_active_mask):
                tf_active_when_tg_active[i] = np.mean(tf_expr[tg_active_mask])
            
            tf_top_mask = tf_expr > tf_75
            tf_bottom_mask = tf_expr < tf_25
            
            if np.any(tf_top_mask):
                tg_expr_in_top_tf_quartile[i] = np.mean(tg_expr_arr[tf_top_mask])
            if np.any(tf_bottom_mask):
                tg_expr_in_bottom_tf_quartile[i] = np.mean(tg_expr_arr[tf_bottom_mask])
            
            differential_tg_expr[i] = tg_expr_in_top_tf_quartile[i] - tg_expr_in_bottom_tf_quartile[i]
            
            # Negative evidence
            anticorrelation_flag[i] = 1 if tf_tg_pearson_corr[i] < -0.3 else 0
    
    # Add to dataframe
    features_df['tf_mean_expr'] = tf_mean_expr
    features_df['tf_std_expr'] = tf_std_expr
    features_df['tf_cv'] = tf_cv
    features_df['tf_detection_rate'] = tf_detection_rate
    
    features_df['tg_mean_expr'] = tg_mean_expr
    features_df['tg_std_expr'] = tg_std_expr
    features_df['tg_cv'] = tg_cv
    features_df['tg_detection_rate'] = tg_detection_rate
    
    features_df['tf_tg_expr_ratio'] = tf_tg_expr_ratio
    features_df['tf_tg_pearson_corr'] = tf_tg_pearson_corr
    features_df['tf_tg_spearman_corr'] = tf_tg_spearman_corr
    features_df['codetection_rate'] = codetection_rate
    
    features_df['tf_active_when_tg_active'] = tf_active_when_tg_active
    features_df['tg_expr_in_top_tf_quartile'] = tg_expr_in_top_tf_quartile
    features_df['tg_expr_in_bottom_tf_quartile'] = tg_expr_in_bottom_tf_quartile
    features_df['differential_tg_expr'] = differential_tg_expr
    
    features_df['tf_not_expressed'] = tf_not_expressed
    features_df['anticorrelation_flag'] = anticorrelation_flag
    
    return features_df

def add_accessibility_features(features_df, rna_df, atac_df, precomputed):
    """Add accessibility-based features."""
    print("Adding accessibility features...")
    
    rna_genes = rna_df.index.tolist()
    rna_matrix = rna_df.values
    
    tf_list = features_df['TF'].values
    tg_list = features_df['TG'].values
    
    n_pairs = len(features_df)
    
    # Initialize arrays
    gene_accessibility_mean = np.zeros(n_pairs)
    gene_accessibility_std = np.zeros(n_pairs)
    gene_accessibility_cv = np.zeros(n_pairs)
    atac_variability_ratio = np.zeros(n_pairs)
    
    atac_rna_corr = np.zeros(n_pairs)
    tf_atac_corr = np.zeros(n_pairs)
    tf_expr_tg_atac_corr = np.zeros(n_pairs)
    
    atac_when_tf_expressed = np.zeros(n_pairs)
    partial_corr_tg_atac_ctrl_tf = np.zeros(n_pairs)
    
    for i, (tf, tg) in enumerate(tqdm(zip(tf_list, tg_list), desc="  Computing accessibility features")):
        # Get gene accessibility
        if tg in precomputed['gene_accessibility_dict']:
            gene_acc = precomputed['gene_accessibility_dict'][tg]
            
            gene_accessibility_mean[i] = np.mean(gene_acc)
            gene_accessibility_std[i] = np.std(gene_acc)
            gene_accessibility_cv[i] = gene_accessibility_std[i] / (gene_accessibility_mean[i] + 1e-10)
            atac_variability_ratio[i] = gene_accessibility_std[i] / (gene_accessibility_mean[i] + 1e-10)
            
            # Get precomputed correlations
            atac_rna_corr[i] = precomputed['atac_rna_corr'].get(tg, 0.0)
            tf_atac_corr[i] = precomputed['tf_atac_corr'].get((tf, tg), 0.0)
            tf_expr_tg_atac_corr[i] = tf_atac_corr[i]  # Same thing
            
            # Context-aware: accessibility when TF is expressed
            if tf in rna_genes:
                tf_idx = rna_genes.index(tf)
                tf_expr = rna_matrix[tf_idx]
                
                tf_threshold = np.percentile(tf_expr, 25) if np.mean(tf_expr) > 0 else 0
                tf_expressed_mask = tf_expr > tf_threshold
                
                if np.any(tf_expressed_mask):
                    atac_when_tf_expressed[i] = np.mean(gene_acc[tf_expressed_mask])
                
                # Partial correlation
                if tg in rna_genes:
                    tg_idx = rna_genes.index(tg)
                    tg_expr = rna_matrix[tg_idx]
                    
                    partial_corr_tg_atac_ctrl_tf[i] = partial_correlation(tg_expr, gene_acc, tf_expr)
    
    # Add to dataframe
    features_df['gene_accessibility_mean'] = gene_accessibility_mean
    features_df['gene_accessibility_std'] = gene_accessibility_std
    features_df['gene_accessibility_cv'] = gene_accessibility_cv
    features_df['atac_variability_ratio'] = atac_variability_ratio
    
    features_df['atac_rna_corr'] = atac_rna_corr
    features_df['tf_atac_corr'] = tf_atac_corr
    features_df['tf_expr_tg_atac_corr'] = tf_expr_tg_atac_corr
    
    features_df['atac_when_tf_expressed'] = atac_when_tf_expressed
    features_df['partial_corr_tg_atac_ctrl_tf'] = partial_corr_tg_atac_ctrl_tf
    
    return features_df

def add_interaction_features(features_df):
    """Add interaction and composite features."""
    print("Adding interaction features...")
    
    # Interaction terms
    features_df['binding_x_tf_expr'] = features_df['max_binding_score'] * features_df['tf_mean_expr']
    features_df['binding_x_corr'] = features_df['max_binding_score'] * features_df['tf_tg_pearson_corr']
    features_df['binding_x_distance'] = features_df['max_binding_score'] * safe_divide(
        1.0, features_df['min_tss_distance'].replace([np.inf, -np.inf], 1e6), fill_value=0.0
    )
    
    # Regulatory potential score (composite)
    # Use distance_weighted_binding as a proxy for max_weighted_binding * tss_distance_score
    features_df['regulatory_potential'] = (
        features_df['max_weighted_binding'] * 
        features_df['tf_mean_expr'] * 
        np.abs(features_df['tf_tg_pearson_corr'])
    )
    
    return features_df

# ============================================================================
# MAIN PIPELINE
# ============================================================================

def compute_all_features(sliding_window_df, peak_to_gene_dist_df, rna_df, atac_df, 
                          tf_tg_pairs=None, batch_size=None):
    """
    Main function to compute all features for TF-TG pairs.
    
    Parameters:
    -----------
    sliding_window_df : DataFrame with columns ['TF', 'peak_id', 'sliding_window_score']
    peak_to_gene_dist_df : DataFrame with peak-to-gene mappings
    rna_df : DataFrame with RNA expression (genes x cells)
    atac_df : DataFrame with ATAC accessibility (peaks x cells)
    tf_tg_pairs : Optional list of (TF, TG) tuples to compute features for.
                  If None, compute for all possible pairs.
    batch_size : If provided, process in batches (useful for very large datasets)
    
    Returns:
    --------
    features_df : DataFrame with all computed features
    """
    
    # Step 1: Pre-compute correlations (one-time cost)
    precomputed = precompute_correlation_matrices(rna_df, atac_df, peak_to_gene_dist_df)
    
    # Step 2: Merge sliding window with peak-to-gene links
    print("Merging sliding window with peak-to-gene distances...")
    tf_tg_binding = sliding_window_df.merge(
        peak_to_gene_dist_df[['peak_id', 'TG', 'tss_distance', 'tss_distance_score']], 
        on='peak_id',
        how='inner'
    )
    print(f"  Merged shape: {tf_tg_binding.shape}")
    
    # Step 3: Compute binding features via groupby
    print("Computing binding features...")
    binding_features = compute_binding_features_fast(tf_tg_binding)
    
    print(f"  Computed binding features for {len(binding_features)} TF-TG pairs\n")
    
    # Step 4: Add expression features
    features_df = add_expression_features(binding_features, rna_df, precomputed)
    
    # Step 5: Add accessibility features
    features_df = add_accessibility_features(features_df, rna_df, atac_df, precomputed)
    
    # Step 6: Add interaction features
    features_df = add_interaction_features(features_df)
    
    # Handle any remaining NaN or inf values
    print("Cleaning up features...")
    features_df = features_df.replace([np.inf, -np.inf], np.nan)
    
    # Fill NaN with 0 for most features (or could use median/mean)
    numeric_cols = features_df.select_dtypes(include=[np.number]).columns
    features_df[numeric_cols] = features_df[numeric_cols].fillna(0)
    
    print(f"\n✓ Feature computation complete!")
    print(f"  Final shape: {features_df.shape}")
    print(f"  Features: {features_df.shape[1] - 2} (excluding TF and TG columns)")
    
    return features_df

# ============================================================================
# USAGE EXAMPLE
# ============================================================================

sliding_window_df = dfs["sliding_window_df"]
peak_to_gene_dist_df = dfs["peak_to_gene_dist_df"]
rna_pseudobulk_df = dfs["rna_pseudobulk_df"]
atac_pseudobulk_df = dfs["atac_pseudobulk_df"]

features_df = compute_all_features(
    sliding_window_df=sliding_window_df,
    peak_to_gene_dist_df=peak_to_gene_dist_df,
    rna_df=rna_pseudobulk_df,
    atac_df=atac_pseudobulk_df
)

Pre-computing correlation matrices...

[1/4] Computing gene accessibility...


  Gene accessibility:   0%|          | 0/15161 [00:00<?, ?it/s]


[2/4] Computing TF-TG expression correlations...
  Computing Pearson correlations (vectorized)...


  Storing Pearson:   0%|          | 0/2925 [00:00<?, ?it/s]

  Computing Spearman correlations (vectorized)...


  Storing Spearman:   0%|          | 0/2925 [00:00<?, ?it/s]


[3/4] Computing TF-accessibility correlations...


  TF-accessibility:   0%|          | 0/2925 [00:00<?, ?it/s]


[4/4] Computing TG expression-accessibility correlations...


  Gene expr-accessibility:   0%|          | 0/15161 [00:00<?, ?it/s]


✓ Correlation matrices computed!

Merging sliding window with peak-to-gene distances...
  Merged shape: (58809252, 6)
Computing binding features...
Computing binding features (vectorized)...
  Phase 1: Basic aggregations...
  Phase 2: Strong peaks (75th percentile)...
  Phase 3: Distance features (numba-accelerated)...


    Distance features:   0%|          | 0/3289937 [00:00<?, ?it/s]

  Phase 4: Entropy and Gini (numba-accelerated)...


    Entropy/Gini:   0%|          | 0/3289937 [00:00<?, ?it/s]

✓ Computed binding features for 3289937 TF-TG pairs

  Computed binding features for 3289937 TF-TG pairs

Adding expression features...


  Computing expression features: 0it [00:00, ?it/s]