# In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import gzip
from functools import lru_cache
import os

# Set display options and plotting style
pd.set_option('display.max_columns', 100)
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and 3.8
# removed the old names, so 'seaborn-whitegrid' alone raises OSError on current
# releases. Use whichever of the two names this matplotlib install provides
# (falling back to the default style if neither exists).
PLOT_STYLE = next(
    (name for name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid')
     if name in plt.style.available),
    None,
)
if PLOT_STYLE is not None:
    plt.style.use(PLOT_STYLE)

# Create output directory (all CSVs and figures are written here)
OUTPUT_DIR = './diagonal_constraint_analysis'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Input file paths. Either edit the defaults below or set an environment
# variable with the same name (e.g. GTEX_FILE=/data/gtex.gct.gz) to override
# a path without editing this file; unset variables keep the default.
# Expected formats, as read by load_data / load_gtex_expression_data:
#   TFBS_RESULTS_PATH, OLD_TFBS_PATH, ORTHODB_RATES_PATH: CSV files
#   GTEX_FILE: gzip-compressed tab-separated table whose first two lines are
#       skipped before the header row
#   BIOMART_PATH: tab-separated with 'Gene stable ID' and 'Gene name' columns
#   TRAINING_GENES_PATH: headerless tab-separated; column 3 contains 'promoter_<gene>'
TFBS_RESULTS_PATH = os.environ.get('TFBS_RESULTS_PATH', '')
OLD_TFBS_PATH = os.environ.get('OLD_TFBS_PATH', '')  # Contains TF information
ORTHODB_RATES_PATH = os.environ.get('ORTHODB_RATES_PATH', '')
GTEX_FILE = os.environ.get('GTEX_FILE', "")
BIOMART_PATH = os.environ.get('BIOMART_PATH', "")
TRAINING_GENES_PATH = os.environ.get('TRAINING_GENES_PATH', '')

# Score columns to evaluate, in the order they are processed and written to the
# raw results. Columns that are absent from the merged data are skipped in main().
model_columns = [
    # HyenaDNA models
    'mean_cross_entropy_diff_hyenadna-tiny-1k-seqlen',
    'mean_cross_entropy_diff_hyenadna-medium-450k-seqlen',
    'mean_cross_entropy_diff_hyenadna-medium-160k-seqlen',
    'mean_cross_entropy_diff_hyenadna-large-1m-seqlen',
    'mean_cross_entropy_diff_hyenadna-small-32k-seqlen',

    # Other transformer models
    'mean_cross_entropy_diff_DNABERT-2-117M',
    'mean_cross_entropy_diff_caduceus-ph_seqlen-131k_d_model-256_n_layer-16',
    'mean_cross_entropy_diff_caduceus-ps_seqlen-131k_d_model-256_n_layer-16',
    'mean_cross_entropy_diff_nucleotide-transformer-2.5b-multi-species',
    'mean_cross_entropy_diff_nucleotide-transformer-2.5b-1000g',
    'mean_cross_entropy_diff_nucleotide-transformer-500m-human-ref',
    'mean_cross_entropy_diff_nucleotide-transformer-v2-500m-multi-species',

    # Columns that already carry their final names in the old TFBS data
    'GPN',
    'Phylop',
    'LOL-EVE',
    'Enformer',

    # Other models
    'mean_diff_evo_1_131k_base',
    'mean_cross_entropy_diff_johahi/specieslm-metazoa-upstream-k6',
    'mean_cross_entropy_diff_evo2-7b',
    # Trailing comma keeps a future appended entry from being silently
    # concatenated with this string.
    'mean_cross_entropy_diff_songlab/gpn-animal-promoter',

    # LOL-EVE ablation columns are deliberately excluded to focus on the main models.
]

# Bar colours for the summary plot, keyed by score column. Columns missing from
# this mapping are drawn in grey ('#7f7f7f') by main(). Every model has its own
# colour so that no two bars are indistinguishable.
model_colors = {
    'LOL-EVE': '#00aa55',
    'Enformer': '#9467bd',
    'GPN': '#ffbb78',
    'Phylop': '#ff7f0e',

    # HyenaDNA family - shades of blue/teal
    'mean_cross_entropy_diff_hyenadna-tiny-1k-seqlen': '#17becf',
    'mean_cross_entropy_diff_hyenadna-small-32k-seqlen': '#14a3c7',
    'mean_cross_entropy_diff_hyenadna-medium-160k-seqlen': '#1190b8',
    'mean_cross_entropy_diff_hyenadna-medium-450k-seqlen': '#0e7ca8',
    'mean_cross_entropy_diff_hyenadna-large-1m-seqlen': '#0b6999',

    # Caduceus family - shades of pink/magenta
    'mean_cross_entropy_diff_caduceus-ph_seqlen-131k_d_model-256_n_layer-16': '#e377c2',
    'mean_cross_entropy_diff_caduceus-ps_seqlen-131k_d_model-256_n_layer-16': '#d85fb8',

    # DNABERT-2 and the Nucleotide Transformer family - shades of brown/tan
    'mean_cross_entropy_diff_DNABERT-2-117M': '#8c564b',
    'mean_cross_entropy_diff_nucleotide-transformer-2.5b-multi-species': '#a0784f',
    'mean_cross_entropy_diff_nucleotide-transformer-2.5b-1000g': '#b49a53',
    'mean_cross_entropy_diff_nucleotide-transformer-500m-human-ref': '#c8bc57',
    'mean_cross_entropy_diff_nucleotide-transformer-v2-500m-multi-species': '#dcde5b',

    # Other models (Evo-1 / Evo-2 share a red family)
    'mean_cross_entropy_diff_evo2-7b': '#ff9896',
    'mean_cross_entropy_diff_johahi/specieslm-metazoa-upstream-k6': '#9edae5',
    'mean_cross_entropy_diff_songlab/gpn-animal-promoter': '#2ca02c',
    # Previously '#ffbb78', identical to GPN's colour.
    'mean_diff_evo_1_131k_base': '#d62728',
}

# Axis labels for the summary plot, keyed by score column; any column missing
# here is labelled with its raw column name.
model_display_names = {
    # Scores whose column names are already readable (only PhyloP is re-cased)
    'LOL-EVE': 'LOL-EVE',
    'Enformer': 'Enformer',
    'GPN': 'GPN',
    'Phylop': 'PhyloP',
    # The cross-entropy-difference columns share one prefix, so list them by
    # suffix and expand them here.
    **{
        'mean_cross_entropy_diff_' + suffix: label
        for suffix, label in (
            # HyenaDNA family
            ('hyenadna-tiny-1k-seqlen', 'HyenaDNA-Tiny'),
            ('hyenadna-small-32k-seqlen', 'HyenaDNA-Small'),
            ('hyenadna-medium-160k-seqlen', 'HyenaDNA-Medium-160k'),
            ('hyenadna-medium-450k-seqlen', 'HyenaDNA-Medium-450k'),
            ('hyenadna-large-1m-seqlen', 'HyenaDNA-Large'),
            # Other transformers
            ('DNABERT-2-117M', 'DNABERT-2'),
            ('caduceus-ph_seqlen-131k_d_model-256_n_layer-16', 'Caduceus-PH'),
            ('caduceus-ps_seqlen-131k_d_model-256_n_layer-16', 'Caduceus-PS'),
            ('nucleotide-transformer-2.5b-multi-species', 'NT-2.5B-Multi'),
            ('nucleotide-transformer-2.5b-1000g', 'NT-2.5B-1000G'),
            ('nucleotide-transformer-500m-human-ref', 'NT-500M-Ref'),
            ('nucleotide-transformer-v2-500m-multi-species', 'NT-V2-500M-Multi'),
            # Other models
            ('evo2-7b', 'Evo-2'),
            ('johahi/specieslm-metazoa-upstream-k6', 'Species-LM'),
            ('songlab/gpn-animal-promoter', 'GPN-Promoter'),
        )
    },
    # Evo-1 scores use a different column prefix
    'mean_diff_evo_1_131k_base': 'Evo-1',
}

@lru_cache(maxsize=1)
def load_gtex_expression_data(file_path):
    """Read a gzip-compressed GTEx expression table, memoising the last path.

    The first two lines of the file are skipped; the next line is used as the
    header and the first column becomes the row index. A 'Description' column,
    if present, is dropped so that only expression columns remain.

    Note: the cached DataFrame object is returned as-is on repeat calls, so
    callers should not modify it in place.
    """
    print(f"Loading GTEx data from {file_path}...")
    with gzip.open(file_path, 'rt') as handle:
        # Two preamble lines precede the header row
        expression = pd.read_csv(handle, sep='\t', index_col=0, skiprows=2)
    if 'Description' not in expression.columns:
        return expression
    return expression.drop(columns='Description')

def load_data():
    """Load and merge every dataset needed for the diagonal-constraint analysis.

    Reads the module-level paths TFBS_RESULTS_PATH, OLD_TFBS_PATH,
    ORTHODB_RATES_PATH, TRAINING_GENES_PATH, GTEX_FILE and BIOMART_PATH.

    Returns:
        merged_df: TFBS variant scores inner-joined with the older TFBS table
            (which carries TF annotations) and then with per-gene evolutionary
            rates on GENE. GENE values are lowercased. 'ar_forward_llr_no_ablation'
            is renamed to 'LOL-EVE' (unless 'LOL-EVE' already exists), and
            'Mammalia_Evo_Rate' is renamed to 'Mammalian_Constraint' when present.
        cv_df: DataFrame with columns 'Gene' and 'Expression'. Expression is the
            coefficient of variation (std / mean across the GTEx columns) of each
            training gene. Genes whose CV is not finite (e.g. zero mean
            expression) are dropped.
    """
    variant_renames = {'chrom': 'CHROM', 'pos': 'POS', 'ref': 'REF', 'alt': 'ALT',
                       'gene': 'GENE', 'species': 'SPECIES'}

    print("Loading TFBS data...")
    tfbs_df = pd.read_csv(TFBS_RESULTS_PATH).rename(columns=variant_renames)

    print("Loading old TFBS data with TF information...")
    old_tfbs = pd.read_csv(OLD_TFBS_PATH).rename(columns=variant_renames)

    print("Merging TFBS datasets...")
    merge_columns = [col for col in ['CHROM', 'POS', 'ALT', 'REF', 'GENE', 'SPECIES']
                     if col in tfbs_df.columns and col in old_tfbs.columns]
    tfbs_df = tfbs_df.merge(old_tfbs, on=merge_columns, how='inner')
    # evo_rates GENE is lowercased below; match case here so mixed-case gene
    # symbols are not silently dropped by the inner join on GENE.
    tfbs_df['GENE'] = tfbs_df['GENE'].str.lower()

    print("Loading evolutionary rates data...")
    evo_rates = pd.read_csv(ORTHODB_RATES_PATH)
    evo_rates['GENE'] = evo_rates['GENE'].str.lower()
    evo_rates = evo_rates.dropna()

    print("Merging TFBS and evolutionary rates...")
    merged_df = tfbs_df.merge(evo_rates, on='GENE')

    # Rename columns for clarity (only rename what needs renaming)
    column_mapping = {}
    if 'ar_forward_llr_no_ablation' in merged_df.columns and 'LOL-EVE' not in merged_df.columns:
        column_mapping['ar_forward_llr_no_ablation'] = 'LOL-EVE'
    if 'Mammalia_Evo_Rate' in merged_df.columns:
        column_mapping['Mammalia_Evo_Rate'] = 'Mammalian_Constraint'
    if column_mapping:
        merged_df = merged_df.rename(columns=column_mapping)
        print(f"Renamed columns: {column_mapping}")

    # After the rename above, a present 'ar_forward_llr_no_ablation' always
    # implies a 'LOL-EVE' column, so this is the only case left to warn about.
    if 'LOL-EVE' not in merged_df.columns:
        print("Warning: No LOL-EVE column found!")

    print("Loading training genes...")
    training_genes_df = pd.read_table(TRAINING_GENES_PATH, header=None)
    # Column 3 holds names containing 'promoter_<gene>'; keep the text after it.
    training_genes = training_genes_df[3].apply(lambda x: x.split('promoter_')[1]).unique()
    training_gene_set = {gene.lower() for gene in training_genes}

    print("Loading expression data...")
    # load_gtex_expression_data is lru_cached and returns the same DataFrame
    # object on every call, so it must not be modified in place here.
    expression_data = load_gtex_expression_data(GTEX_FILE)

    # Load gene mapping
    biomart_df = pd.read_csv(BIOMART_PATH, sep='\t', usecols=['Gene stable ID', 'Gene name'])
    gene_map = dict(zip(biomart_df['Gene stable ID'], biomart_df['Gene name']))

    def get_gene_name(ensembl_id):
        # Drop everything after the first '.' (presumably the Ensembl version
        # suffix) before lookup; fall back to the raw ID when unmapped.
        return gene_map.get(ensembl_id.split('.')[0], ensembl_id)

    # Coefficient of variation of each row across all expression columns,
    # computed on the cached frame and relabelled on the new Series only.
    cv = expression_data.std(axis=1) / expression_data.mean(axis=1)
    cv.index = cv.index.map(get_gene_name)
    # Zero-mean rows yield NaN/inf; a single NaN would make np.percentile return
    # NaN downstream and empty every expression-based gene group.
    cv = cv[np.isfinite(cv)]

    cv_df = cv.reset_index()
    cv_df.columns = ['Gene', 'Expression']
    # .copy() so later column assignments on cv_df act on a real frame, not a slice.
    cv_df = cv_df[cv_df['Gene'].str.lower().isin(training_gene_set)].copy()

    return merged_df, cv_df

def create_gene_categories(merged_df, cv_df, percentile=25):
    """Select the two opposite "diagonal" corners of the constraint/variability plane.

    'Mammalian_Constraint' holds a per-gene evolutionary rate, so a LOW value
    means HIGH constraint. Expression is the per-gene expression variability (CV).

    Args:
        merged_df: DataFrame with at least 'GENE' and 'Mammalian_Constraint'
            (duplicate rows per gene are collapsed).
        cv_df: DataFrame with 'Gene' and 'Expression' columns. It is not modified.
        percentile: Tail size in percent used for both axes.

    Returns:
        dict with two sets of lowercased gene names:
          'high_constraint_low_variability': rate <= p-th percentile AND
              CV <= p-th percentile.
          'low_constraint_high_variability': rate >= (100-p)-th percentile AND
              CV >= (100-p)-th percentile.
        Both sets are empty when no gene has finite values on both axes.
    """
    analysis_df = (
        merged_df[['GENE', 'Mammalian_Constraint']]
        .drop_duplicates()
        .assign(GENE=lambda frame: frame['GENE'].str.lower())
    )
    # Lowercase a copy so the caller's cv_df is left untouched.
    expression_df = cv_df.assign(Gene=cv_df['Gene'].str.lower())
    analysis_df = analysis_df.merge(expression_df, left_on='GENE', right_on='Gene', how='inner')

    # np.percentile returns NaN if any input is NaN, which would make every
    # comparison below False and silently empty both groups.
    value_columns = ['Mammalian_Constraint', 'Expression']
    analysis_df = analysis_df[np.isfinite(analysis_df[value_columns]).all(axis=1)]
    if analysis_df.empty:
        # np.percentile raises on empty input.
        return {
            'high_constraint_low_variability': set(),
            'low_constraint_high_variability': set(),
        }

    rate = analysis_df['Mammalian_Constraint']
    variability = analysis_df['Expression']

    # Low rate == high constraint, hence the "high" threshold is the low percentile.
    constraint_high = np.percentile(rate, percentile)
    constraint_low = np.percentile(rate, 100 - percentile)
    expr_var_high = np.percentile(variability, 100 - percentile)
    expr_var_low = np.percentile(variability, percentile)

    high_constraint = rate <= constraint_high
    low_constraint = rate >= constraint_low
    high_variability = variability >= expr_var_high
    low_variability = variability <= expr_var_low

    return {
        'high_constraint_low_variability': set(analysis_df.loc[high_constraint & low_variability, 'GENE']),
        'low_constraint_high_variability': set(analysis_df.loc[low_constraint & high_variability, 'GENE']),
    }

def compute_delta_accuracy_across_percentiles(merged_df, cv_df, score_columns, percentiles=range(15, 45)):
    """Compare per-TF model scores between the two diagonal gene groups.

    For each percentile, the groups come from create_gene_categories:
    group 1 = high constraint / low variability, group 2 = low constraint /
    high variability. A percentile is skipped if either group has fewer than 10
    genes. For every TF that has rows in both groups, the same number of genes
    (the smaller group's count, at least 5) is drawn without replacement from
    each group. Each score column then records whether group 2's mean score
    exceeds group 1's. A comparison is skipped if any sampled score is NaN.

    Sampling uses NumPy's global RNG; seed it (np.random.seed) beforehand for
    reproducible output. Genes and TFs are sorted before sampling so that the
    draws do not depend on Python's per-process string hash randomization.

    Returns:
        DataFrame with columns percentile, tf, model, success, group1_mean,
        group2_mean, n_samples (present even when no comparison qualified).
    """
    result_columns = ['percentile', 'tf', 'model', 'success',
                      'group1_mean', 'group2_mean', 'n_samples']
    results = []

    for pct in percentiles:
        print(f"Processing percentile {pct}...")
        categories = create_gene_categories(merged_df, cv_df, percentile=pct)
        group1_genes = categories['high_constraint_low_variability']
        group2_genes = categories['low_constraint_high_variability']

        if len(group1_genes) < 10 or len(group2_genes) < 10:
            continue

        # TFs with binding-site rows in both groups
        group1_tfs = set(merged_df.loc[merged_df['GENE'].isin(group1_genes), 'TF'])
        group2_tfs = set(merged_df.loc[merged_df['GENE'].isin(group2_genes), 'TF'])

        # Set iteration order for strings varies between interpreter runs
        # (hash randomization), which would change which genes np.random.choice
        # picks despite a fixed seed; sorting makes the draws reproducible.
        for tf in sorted(group1_tfs & group2_tfs, key=str):
            tf_df = merged_df[merged_df['TF'] == tf]
            available_g1 = sorted(set(tf_df.loc[tf_df['GENE'].isin(group1_genes), 'GENE']), key=str)
            available_g2 = sorted(set(tf_df.loc[tf_df['GENE'].isin(group2_genes), 'GENE']), key=str)

            n_samples = min(len(available_g1), len(available_g2))
            if n_samples < 5:
                continue

            # Sample equal numbers of genes from each group
            sampled_g1 = np.random.choice(available_g1, n_samples, replace=False)
            sampled_g2 = np.random.choice(available_g2, n_samples, replace=False)

            g1_data = tf_df[tf_df['GENE'].isin(sampled_g1)]
            g2_data = tf_df[tf_df['GENE'].isin(sampled_g2)]

            for score_col in score_columns:
                if score_col not in tf_df.columns:
                    continue

                g1_scores = g1_data[score_col].values
                g2_scores = g2_data[score_col].values

                if np.isnan(g1_scores).any() or np.isnan(g2_scores).any():
                    continue

                g1_mean = np.mean(g1_scores)
                g2_mean = np.mean(g2_scores)
                results.append({
                    'percentile': pct,
                    'tf': tf,
                    'model': score_col,
                    # Success: the less constrained / more variable group scores higher
                    'success': g2_mean > g1_mean,
                    'group1_mean': g1_mean,
                    'group2_mean': g2_mean,
                    'n_samples': n_samples,
                })

    return pd.DataFrame(results, columns=result_columns)

def main():
    """Run the diagonal-constraint analysis end to end and write its outputs.

    Seeds NumPy's global RNG with 42, loads and merges the inputs, and scores
    every configured model column present in the data across percentile
    thresholds. It then aggregates a per-model delta-accuracy, plots it, and
    saves CSVs and figures under OUTPUT_DIR.

    Delta-accuracy for a (model, percentile) pair is the percentage of per-TF
    comparisons in which group 2's mean score exceeded group 1's, minus the
    50% chance level.

    Returns:
        (raw_results, model_summary): the per-comparison DataFrame and the
        per-model summary (mean/std/count/sem of delta-accuracy across
        percentiles, sorted by mean descending). model_summary is empty when
        no comparison passed the filters.
    """
    np.random.seed(42)

    # Load data
    merged_df, cv_df = load_data()

    # Keep only the configured model columns that actually exist in the data
    available_columns = [col for col in model_columns if col in merged_df.columns]
    print(f"Found {len(available_columns)} model columns in data")

    print("Computing delta accuracy across percentiles...")
    raw_results = compute_delta_accuracy_across_percentiles(merged_df, cv_df, available_columns)

    raw_results.to_csv(f'{OUTPUT_DIR}/raw_diagonal_results.csv', index=False)
    print(f"Raw results saved to {OUTPUT_DIR}/raw_diagonal_results.csv")

    if raw_results.empty:
        # Nothing survived the group-size / sample-size / NaN filters: grouping
        # would fail (the frame may lack columns) and there is nothing to plot.
        print("No comparisons passed the filters; skipping aggregation and plotting.")
        return raw_results, pd.DataFrame(columns=['model', 'mean', 'std', 'count', 'sem'])

    # Per (model, percentile): success rate in percent minus the 50% chance level
    df_agg = (
        raw_results.groupby(['model', 'percentile'])
        .agg(success_rate=('success', 'mean'), n_comparisons=('success', 'size'))
        .reset_index()
    )
    df_agg['delta_accuracy'] = df_agg['success_rate'] * 100 - 50
    df_agg = df_agg[['model', 'percentile', 'delta_accuracy', 'n_comparisons']]

    # Aggregate across all percentiles per model
    model_summary = (
        df_agg.groupby('model')['delta_accuracy']
        .agg(['mean', 'std', 'count'])
        .reset_index()
    )
    model_summary['sem'] = model_summary['std'] / np.sqrt(model_summary['count'])
    model_summary = model_summary.sort_values('mean', ascending=False)

    positions = np.arange(len(model_summary))
    plt.figure(figsize=(12, 6))
    plt.bar(
        x=positions,
        height=model_summary['mean'],
        yerr=model_summary['sem'],
        capsize=5,
        color=[model_colors.get(m, '#7f7f7f') for m in model_summary['model']],
        alpha=0.8,
    )
    plt.axhline(0, color='red', linestyle='--', linewidth=1)  # chance level
    plt.xticks(
        ticks=positions,
        labels=[model_display_names.get(m, m) for m in model_summary['model']],
        rotation=45,
        ha='right',
    )
    plt.ylabel('Mean Δ-Accuracy (%)')
    plt.title('Average Δ-Accuracy Across All Percentiles\nwith SEM Error Bars')
    plt.tight_layout()

    plt.savefig(f'{OUTPUT_DIR}/delta_accuracy_plot.png', dpi=300, bbox_inches='tight')
    plt.savefig(f'{OUTPUT_DIR}/delta_accuracy_plot.pdf', bbox_inches='tight')
    plt.show()
    plt.close()  # release the figure once it has been saved and shown

    model_summary.to_csv(f'{OUTPUT_DIR}/model_summary_results.csv', index=False)
    df_agg.to_csv(f'{OUTPUT_DIR}/percentile_results.csv', index=False)

    print(f"Analysis complete! Results saved to {OUTPUT_DIR}/")
    return raw_results, model_summary

if __name__ == "__main__":
    # Results are bound at module level so they remain inspectable after the
    # run (e.g. from an interactive session).
    (raw_results, model_summary) = main()