In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import numpy as np
import glob
import re
from collections import Counter
import os

import pandas as pd
import numpy as np
from scipy.stats import variation
import matplotlib.pyplot as plt
import seaborn as sns
import gzip
from scipy import stats
from statsmodels.stats.multitest import fdrcorrection
from scipy.stats import pointbiserialr
from functools import lru_cache

In [None]:
# NOTE(review): the CSV path is an empty string, presumably a placeholder -- fill in
# the real path before running this cell. The resulting frame is later bound to
# `results` and passed to plot_accuracy_vs_percentile.
df = pd.read_csv('')

In [None]:
# Names of the columns that are treated as scores. The list is passed to
# plot_accuracy_vs_percentile, which draws one line per entry and forwards the list to
# compare_tf_across_groups_vectorized, where each name is used to select a column of
# the results table.
score_columns = [
    # Entries sharing the 'mean_cross_entropy_diff_' prefix
    'mean_cross_entropy_diff_hyenadna-tiny-1k-seqlen',
    'mean_cross_entropy_diff_hyenadna-medium-450k-seqlen',
    'mean_cross_entropy_diff_hyenadna-medium-160k-seqlen',
    'mean_cross_entropy_diff_hyenadna-large-1m-seqlen',
    'mean_cross_entropy_diff_hyenadna-small-32k-seqlen', 
    'mean_cross_entropy_diff_DNABERT-2-117M',
    'mean_cross_entropy_diff_caduceus-ph_seqlen-131k_d_model-256_n_layer-16',
    'mean_cross_entropy_diff_caduceus-ps_seqlen-131k_d_model-256_n_layer-16',
    'mean_cross_entropy_diff_nucleotide-transformer-2.5b-multi-species',
    'mean_cross_entropy_diff_nucleotide-transformer-2.5b-1000g',
    'mean_cross_entropy_diff_nucleotide-transformer-500m-human-ref',
    'mean_cross_entropy_diff_nucleotide-transformer-v2-500m-multi-species',
    # Remaining score columns (no common prefix)
    'Phylop','GPN',
    'GC_percentage_delta',
    'Distance_TSS',
    'LOL-EVE', 'Enformer'
]

In [None]:
def load_biomart_data(file_path):
    """Build a 'Gene stable ID' -> 'Gene name' lookup from a tab-separated export.

    Parameters
    ----------
    file_path : str or path-like
        Tab-separated file that contains at least the columns
        'Gene stable ID' and 'Gene name'; other columns are ignored.

    Returns
    -------
    dict
        Maps each 'Gene stable ID' to its 'Gene name'. Rows in which either
        field is missing are left out, so ``mapping.get(gene_id, default)``
        yields the default for unnamed genes instead of NaN.
    """
    biomart_df = pd.read_csv(file_path, sep='\t', usecols=['Gene stable ID', 'Gene name'])

    # Blank cells are parsed as NaN. Keeping them would store NaN as the gene
    # name and defeat any `.get(id, fallback)` lookup done by callers.
    biomart_df = biomart_df.dropna(subset=['Gene stable ID', 'Gene name'])

    return dict(zip(biomart_df['Gene stable ID'], biomart_df['Gene name']))

@lru_cache(maxsize=1)
def load_gtex_expression_data(file_path):
    """Read a gzip-compressed, tab-separated expression table into a DataFrame.

    The first two lines of the file are discarded; the following line is
    parsed as the column header and the first column becomes the index.
    A 'Description' column, when present, is removed from the result.

    The most recent result is memoised by ``lru_cache(maxsize=1)``, so calling
    again with the same path returns the same DataFrame object.
    """
    with gzip.open(file_path, 'rt') as handle:
        # Consume the two leading lines that precede the column header
        for _ in range(2):
            next(handle)
        expression = pd.read_csv(handle, sep='\t', index_col=0)

    if 'Description' not in expression.columns:
        return expression
    return expression.drop(columns='Description')

def map_gene_names(cv_df, gene_map):
    """Relabel a Series indexed by dotted IDs using a lookup of gene names.

    For every index label, the part before the first '.' is looked up in
    ``gene_map``; labels without an entry are kept exactly as they were.
    The input is not modified. The relabelled copy is returned sorted by
    value, largest first.
    """
    def _lookup(label):
        # Drop the '.<suffix>' part before consulting the map
        stem = label.partition('.')[0]
        return gene_map.get(stem, label)

    relabelled = cv_df.set_axis(cv_df.index.map(_lookup))
    return relabelled.sort_values(ascending=False)


def series_to_df(series, value_column_name='Value'):
    """Turn an indexed Series into a two-column DataFrame sorted by value.

    Parameters
    ----------
    series : pandas.Series
        Values to convert; the index labels become the 'Gene' column.
    value_column_name : str, default 'Value'
        Name given to the column holding the Series values.

    Returns
    -------
    pandas.DataFrame
        Columns 'Gene' and ``value_column_name``, rows ordered by value in
        descending order.
    """
    # reset_index() names the new column after the index when the index has a
    # name and falls back to 'index' otherwise. Renaming only 'index' would
    # silently leave a named index (e.g. 'Name') un-renamed.
    index_column = series.index.name if series.index.name is not None else 'index'

    df = series.to_frame(name=value_column_name).reset_index()
    df = df.rename(columns={index_column: 'Gene'})

    return df.sort_values(by=value_column_name, ascending=False)


def calculate_expression_variability(expression_data):
    """Return the per-row coefficient of variation, smallest first.

    For each row, the standard deviation across columns (``DataFrame.std``
    with its default settings) is divided by the row mean. The resulting
    Series is sorted in ascending order.
    """
    row_std = expression_data.std(axis=1)
    row_mean = expression_data.mean(axis=1)
    return row_std.div(row_mean).sort_values()


def compare_tf_across_groups_vectorized(df, group1_genes, group2_genes, score_columns):
    """Compare score distributions between two gene groups, separately per TF.

    For every distinct value of ``df['TF']`` the rows are split into those
    whose ``GENE`` is in ``group1_genes`` and those whose ``GENE`` is in
    ``group2_genes``. TFs with no rows in either group are skipped. For each
    remaining TF and each score column the function records a two-sided
    Mann-Whitney U test, the point-biserial correlation between group
    membership (0 = group 1, 1 = group 2) and the score, and the mean and
    median of each group.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'GENE', 'TF' and every name in ``score_columns``.
    group1_genes, group2_genes : collection
        Gene identifiers compared against ``df['GENE']`` via ``isin``.
    score_columns : iterable of str
        Names of the score columns to test.

    Returns
    -------
    pandas.DataFrame
        One row per (TF, score column) with the columns 'TF', 'score_column',
        'statistic', 'p_value', 'biserial_corr', 'group1_mean', 'group2_mean',
        'group1_median', 'group2_median' and 'q_value' (Benjamini-Hochberg
        adjusted p-value from ``fdrcorrection``). The frame has these columns
        even when it has no rows.
    """
    result_columns = [
        'TF', 'score_column', 'statistic', 'p_value', 'biserial_corr',
        'group1_mean', 'group2_mean', 'group1_median', 'group2_median',
    ]

    # Materialise once so any iterable can be reused for every TF
    score_columns = list(score_columns)

    # Group membership does not depend on the TF, so compute it a single time
    group1_mask = df['GENE'].isin(group1_genes)
    group2_mask = df['GENE'].isin(group2_genes)

    results = []
    for tf in df['TF'].unique():
        tf_mask = df['TF'] == tf

        group1_data = df.loc[group1_mask & tf_mask, score_columns]
        group2_data = df.loc[group2_mask & tf_mask, score_columns]

        if len(group1_data) == 0 or len(group2_data) == 0:
            continue

        # Membership labels are identical for every score column of this TF
        group_labels = np.concatenate(
            [np.zeros(len(group1_data)), np.ones(len(group2_data))]
        )

        for score_col in score_columns:
            group1_scores = group1_data[score_col].values
            group2_scores = group2_data[score_col].values

            all_scores = np.concatenate([group1_scores, group2_scores])
            biserial_corr, _ = pointbiserialr(group_labels, all_scores)

            # The two groups are selected by gene membership, so their rows are
            # not matched observations. A paired test is therefore not
            # applicable even when both groups happen to have the same number
            # of rows; the same independent-samples test is used for every
            # comparison so that 'statistic' is comparable across rows.
            statistic, p_value = stats.mannwhitneyu(
                group1_scores, group2_scores, alternative='two-sided'
            )

            results.append({
                'TF': tf,
                'score_column': score_col,
                'statistic': statistic,
                'p_value': p_value,
                'biserial_corr': biserial_corr,
                'group1_mean': np.mean(group1_scores),
                'group2_mean': np.mean(group2_scores),
                'group1_median': np.median(group1_scores),
                'group2_median': np.median(group2_scores),
            })

    results_df = pd.DataFrame(results, columns=result_columns)

    # A NaN p-value sorts last inside fdrcorrection and then propagates
    # through its cumulative minimum, turning every q-value into NaN.
    # Correct only the finite p-values and leave NaN where no test was possible.
    p_values = results_df['p_value'].to_numpy(dtype=float)
    q_values = np.full(p_values.shape, np.nan)
    testable = np.isfinite(p_values)
    if testable.any():
        q_values[testable] = fdrcorrection(p_values[testable])[1]
    results_df['q_value'] = q_values

    return results_df

In [None]:
def plot_accuracy_vs_percentile(cv, results, score_columns, training_genes, percentiles=None):
    """Plot, per score, how often the top-variability genes score higher than the bottom ones.

    For each percentile ``p`` the genes whose 'Expression' value is at or
    below the p-th percentile form the bottom group and those at or above the
    (100 - p)-th percentile form the top group. The two groups are compared
    per TF with ``compare_tf_across_groups_vectorized`` (bottom = group 1,
    top = group 2). For every score column, "accuracy" is the percentage of
    compared TFs whose group-2 mean exceeds the group-1 mean; the plotted
    value is that percentage minus 50, so 0 marks an even split.

    Parameters
    ----------
    cv : pandas.DataFrame
        Must contain the columns 'Gene' and 'Expression'.
    results : pandas.DataFrame
        Passed unchanged to ``compare_tf_across_groups_vectorized``.
    score_columns : sequence of str
        Score columns to evaluate; one line is drawn per column that produced
        at least one comparison.
    training_genes : any
        Not used by this function; kept so existing callers keep working.
    percentiles : iterable of numbers, optional
        Percentile thresholds to evaluate. Defaults to ``[1]``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the line plot. It is returned even when no
        comparison produced data, in which case it holds no lines.
    """
    # Avoid a mutable default argument; None stands for the historical default
    if percentiles is None:
        percentiles = [1]

    print(f"Starting analysis with {len(cv)} genes and {len(results)} results rows")
    print(f"Will analyze percentiles: {percentiles}")

    # Line colour for each score column
    colors = {
        "LOL-EVE": "#00aa55",
        "Phylop": "#FAD4D4",
        "mean_cross_entropy_diff_hyenadna-tiny-1k-seqlen": "#B7E4C7",
        "mean_cross_entropy_diff_hyenadna-medium-450k-seqlen": "#A9D6E5",
        "mean_cross_entropy_diff_hyenadna-medium-160k-seqlen": "#FFF3B0",
        "mean_cross_entropy_diff_hyenadna-large-1m-seqlen": "#FBC4AB",
        "mean_cross_entropy_diff_hyenadna-small-32k-seqlen": "#D7BBF5",
        "mean_cross_entropy_diff_caduceus-ph_seqlen-131k_d_model-256_n_layer-16": "#C1D3B4",
        "mean_cross_entropy_diff_caduceus-ps_seqlen-131k_d_model-256_n_layer-16": "#CFB7E4",
        "mean_cross_entropy_diff_DNABERT-2-117M": "#FFC9C9",
        "mean_cross_entropy_diff_nucleotide-transformer-2.5b-multi-species": "#B5C9E5",
        "mean_cross_entropy_diff_nucleotide-transformer-2.5b-1000g": "#FFE5B4",
        "mean_cross_entropy_diff_nucleotide-transformer-500m-human-ref": "#C6E5B0",
        "mean_cross_entropy_diff_nucleotide-transformer-v2-500m-multi-species": "#BEE0E5",
        "Enformer": "#F7CAB9",
        "Distance_TSS": "#C4D6E7",
        "GC_percentage_delta": "#E6CCF5",
        "GPN": "#B6E2D3"
    }

    # One record per (percentile, score column) that yielded comparisons
    plot_data = []
    cv_values = cv['Expression'].values

    for p in percentiles:
        print(f"\nProcessing percentile {p}...")
        bottom_cutoff = np.percentile(cv_values, p)
        top_cutoff = np.percentile(cv_values, 100 - p)

        top_genes = set(cv.loc[cv['Expression'] >= top_cutoff, 'Gene'])
        bottom_genes = set(cv.loc[cv['Expression'] <= bottom_cutoff, 'Gene'])
        print(f"Selected {len(top_genes)} top genes and {len(bottom_genes)} bottom genes")

        # bottom genes are group 1, top genes are group 2
        results_comp = compare_tf_across_groups_vectorized(results, bottom_genes, top_genes, score_columns)
        if len(results_comp) == 0:
            continue

        for score_col in score_columns:
            score_results = results_comp[results_comp['score_column'] == score_col]
            if len(score_results) == 0:
                continue
            accuracy = (np.sum(score_results['group2_mean'] > score_results['group1_mean'])
                        / len(score_results) * 100)
            plot_data.append({
                'percentile': p,
                'accuracy': accuracy - 50,  # offset so that 0 means a 50/50 split
                'score': score_col
            })

    print("\nCreating final visualization...")
    fig, ax = plt.subplots(figsize=(15, 8))
    # Fix the columns up front so an empty plot_data still yields a frame
    # with a 'score' column instead of raising KeyError below
    data = pd.DataFrame(plot_data, columns=['percentile', 'accuracy', 'score'])
    print(f"Final data points: {len(data)}")

    # One line per score, drawn in its configured colour
    for score in data['score'].unique():
        score_data = data[data['score'] == score]
        color = colors.get(score, '#000000')  # Default to black if color not found
        sns.lineplot(
            data=score_data,
            x='percentile',
            y='accuracy',
            label=score,
            color=color,
            marker='o',
            ax=ax
        )

    ax.set_xlabel('Percentile Threshold (%)')
    ax.set_ylabel('Biological Accuracy (%)')
    ax.set_title('Biological Accuracy vs Percentile Threshold by Score')
    # Reference line at the even-split level
    ax.axhline(y=0, color='red', linestyle='--')
    # Keep x tick labels horizontal
    ax.tick_params(axis='x', labelrotation=0)

    # Place the legend outside the axes; skip it when nothing was drawn
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

    fig.tight_layout()
    return fig
# Main execution

# File paths
gtex_file = "/n/groups/marks/users/courtney/projects/regulatory_genomics/models/promEV_private/benchmarks/tfbs_removal/data/GTEx_Analysis_2017-06-05_v8_RNASeQCv1.1.9_gene_median_tpm.gct.gz"
biomart_file_path = "/n/groups/marks/users/courtney/projects/regulatory_genomics/datasets/BIOMART/mart_export_9_20.txt"
gene_map = load_biomart_data(biomart_file_path)

# Training gene names: column 3 of the BED file, keeping the text after 'promoter_'
training_genes_df = pd.read_table('/n/groups/marks/databases/whole_genome_alignments/raw_sequences/447_full/promoters_1000_v2/Homo_sapiens_promoters_1000_raw_no_overlap_filtered.bed', header=None)
training_genes_df[3] = training_genes_df[3].apply(lambda x: x.split('promoter_')[1])
training_genes = training_genes_df[3].unique()

# Load GTEx data and calculate CV; rows whose CV is NaN are discarded
expression_data = load_gtex_expression_data(gtex_file)
cv = calculate_expression_variability(expression_data).dropna()

# Replace Ensembl IDs by gene names and convert to a Gene/Expression table
mapped_df = map_gene_names(cv, gene_map)
cv = series_to_df(mapped_df, value_column_name='Expression')
# The index column may still carry the original index name 'Name'
cv = cv.rename(columns={'Name': 'Gene'})

# Restrict to training genes first (case-sensitive match on the mapped names),
# then lower-case the names. The explicit copy avoids assigning into a filtered
# view of the previous frame.
# NOTE(review): lower-casing presumably matches the GENE column of `results` -- confirm.
cv = cv[cv['Gene'].isin(training_genes)].copy()
cv['Gene'] = cv['Gene'].str.lower()

# The results table is the frame loaded in the earlier cell
results = df

# Create and save plot
fig = plot_accuracy_vs_percentile(cv, results, score_columns, training_genes)
fig.savefig('accuracy_vs_percentile_with_training.svg',
            dpi=300,
            bbox_inches='tight',
            format='svg',
            transparent=True,
            facecolor='none',
            edgecolor='none')

plt.close(fig)  # Clean up memory