In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import glob
import numpy as np
import seaborn as sns
from scipy import stats
import matplotlib.colors as mcolors


In [None]:
# Root of the checkout that contains the LOL-EVE repository; edit before running.
current_dir = "<path_to_repo>"

In [None]:
# gnomAD indel benchmark table; later cells read its `MAF` column and the per-model score columns.
df = pd.read_csv('/'.join([current_dir, 'LOL-EVE/data/benchmark_data/gnomad_indel_freq.csv']))

In [None]:


def calculate_odds_ratio(low_freq_count, common_count, total_low_freq, total_common):
    """Odds ratio of being flagged in the low-frequency bin vs. the common bin.

    Odds in each bin are ``flagged / not flagged``. The "not flagged" count is
    floored at 1, so a bin in which every variant is flagged does not divide by zero.

    :param low_freq_count: number of flagged variants in the low-frequency bin
    :param common_count: number of flagged variants in the common bin
    :param total_low_freq: total number of variants in the low-frequency bin
    :param total_common: total number of variants in the common bin
    :return: ``odds_low / odds_common``. Returns ``inf`` when no common variant is
        flagged but at least one low-frequency variant is. Returns ``nan`` when
        neither bin has a flagged variant (the ratio is undefined).
    """
    odds_low = low_freq_count / max(1, total_low_freq - low_freq_count)
    odds_common = common_count / max(1, total_common - common_count)
    if odds_common == 0:
        # Dividing here would raise ZeroDivisionError and abort a whole threshold
        # sweep whenever a model flags no common variant at a strict cutoff.
        return float('inf') if odds_low > 0 else float('nan')
    return odds_low / odds_common

def run_analysis_for_threshold(df, low_freq_range, common_threshold, columns_to_analyze, threshold_percentile,
                               higher_is_significant=('PhyloP', 'CADD_raw_score')):
    """Odds ratio (low-frequency vs. common) of being flagged at a percentile cutoff.

    For each score column, a variant is "flagged" when its score lies in the most
    extreme ``threshold_percentile`` fraction of ``df``. That is the top tail for
    columns in ``higher_is_significant`` and the bottom tail for all others.
    ``df`` is not modified.

    :param df: DataFrame with a ``MAF`` column and one column per entry of ``columns_to_analyze``
    :param low_freq_range: ``(lo, hi)``; variants with ``lo < MAF < hi`` count as low-frequency
    :param common_threshold: variants with ``MAF >= common_threshold`` count as common
    :param columns_to_analyze: score columns to evaluate
    :param threshold_percentile: tail fraction, e.g. 0.01 for the most extreme 1%
    :param higher_is_significant: columns where larger scores are more significant
    :return: list of ``(column, odds_ratio)`` in the order of ``columns_to_analyze``
    """
    low_freq_mask = (df.MAF > low_freq_range[0]) & (df.MAF < low_freq_range[1])
    common_mask = df.MAF >= common_threshold
    total_low_freq = int(low_freq_mask.sum())
    total_common = int(common_mask.sum())

    data = []
    for col in columns_to_analyze:
        scores = df[col]
        # nanpercentile: with plain percentile, one missing score makes the cutoff NaN
        # and silently flags nothing.
        if col in higher_is_significant:
            cutoff = np.nanpercentile(scores, 100 - threshold_percentile * 100)
            significant = scores >= cutoff
        else:
            cutoff = np.nanpercentile(scores, threshold_percentile * 100)
            significant = scores <= cutoff
        # Keep the mask local rather than writing a `<col>_significant` column into
        # the caller's DataFrame.
        low_freq_count = int((significant & low_freq_mask).sum())
        common_count = int((significant & common_mask).sum())
        odds_ratio = calculate_odds_ratio(low_freq_count, common_count, total_low_freq, total_common)
        data.append((col, odds_ratio))

    return data


def create_threshold_variation_plot(ax, df, low_freq_range, common_threshold, columns_to_analyze, thresholds, row, col):
    """Plot odds ratio against flagging threshold for each model on a single axes.

    Draws one line per entry of ``columns_to_analyze``. LOL-EVE and PhyloP are
    drawn thicker and with markers. Both axes use a log scale.

    :param ax: matplotlib Axes to draw on
    :param df: variant table passed to ``run_analysis_for_threshold``
    :param low_freq_range: ``(lo, hi)`` MAF bounds of the low-frequency bin (exclusive)
    :param common_threshold: minimum MAF of the common bin (inclusive)
    :param columns_to_analyze: score columns to plot
    :param thresholds: tail fractions to evaluate; these are the x values
    :param row: panel row in the caller's grid; only row 1 gets an x-axis label
    :param col: panel column in the caller's grid (not used here)
    """
    colors = {
        "LOL-EVE": "#00aa55",
        "PhyloP": "#FF9AA2",
        "hyenadna-tiny-1k-seqlen": "#A8E6CF",
        "hyenadna-medium-450k-seqlen": "#A2D2FF",
        "hyenadna-medium-160k-seqlen": "#FDFD96",
        "hyenadna-large-1m-seqlen": "#FFB347",
        "hyenadna-small-32k-seqlen": "#E0AAFF",
        "caduceus-ph_seqlen-131k_d_model-256_n_layer-16": "#A3C1AD",
        "caduceus-ps_seqlen-131k_d_model-256_n_layer-16": "#B19CD9",
        "DNABERT-2-117M": "#FFD1DC",
        "nucleotide-transformer-2.5b-multi-species": "#AFEEEE",
        "nucleotide-transformer-2.5b-1000g": "#FFE4E1",
        "nucleotide-transformer-500m-human-ref": "#D0F0C0",
        "nucleotide-transformer-v2-500m-multi-species": "#F0E68C",
    }
    prefix = 'mean_cross_entropy_diff_'

    for col_name in columns_to_analyze:
        odds_ratios = [
            run_analysis_for_threshold(df, low_freq_range, common_threshold, [col_name], threshold)[0][1]
            for threshold in thresholds
        ]

        label = col_name.split(prefix)[-1] if prefix in col_name else col_name
        # Models without a palette entry fall back to matplotlib's default colour
        # cycle (color=None) instead of raising KeyError.
        color = colors.get(label)

        if label in ['LOL-EVE', 'PhyloP']:
            ax.plot(thresholds, odds_ratios, marker='o', label=label, color=color, linewidth=3, markersize=8)
        else:
            ax.plot(thresholds, odds_ratios, label=label, color=color, linewidth=2)

    # Only the bottom row of the 2x2 grid carries the shared x-axis label.
    if row == 1:
        ax.set_xlabel('Threshold (percentile)', fontsize=20)
    else:
        ax.set_xlabel('')

    # Add y-axis label to all plots
    ax.set_ylabel('Odds Ratio', fontsize=20)

    ax.set_title(f'MAF {low_freq_range[0]}-{low_freq_range[1]} vs. ≥{common_threshold}', fontsize=20)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.grid(True, which="both", ls="-", alpha=0.1)
    ax.tick_params(axis='both', which='major', labelsize=20)

def create_maf_threshold_panel(df, maf_thresholds, columns_to_analyze, thresholds,
                               output_path='odds_ratio_threshold_variation_panel_improved.png'):
    """Draw a 2x2 panel of odds-ratio-vs-threshold plots, one panel per MAF split.

    All panels share y-limits so they can be compared directly. The figure gets a
    single legend below the grid, is saved to ``output_path``, and is then shown.

    :param df: variant table passed to ``run_analysis_for_threshold``
    :param maf_thresholds: up to four ``((lo, hi), common_threshold)`` pairs, filled
        row by row. Extra entries are ignored and unused panels are hidden.
    :param columns_to_analyze: score columns, one line per column in each panel
    :param thresholds: tail fractions evaluated in every panel (x values)
    :param output_path: where the PNG is written
    """
    fig, axs = plt.subplots(2, 2, figsize=(20, 18), dpi=300)
    positions = [(0, 0), (0, 1), (1, 0), (1, 1)]
    panels = list(zip(positions, maf_thresholds))

    # Shared y-limits. Only strictly positive, finite odds ratios can appear on a log
    # axis; a 0, inf or nan here would produce ignored or invalid limits.
    plottable = []
    for _, (low_freq_range, common_threshold) in panels:
        for score_col in columns_to_analyze:
            for threshold in thresholds:
                ratio = run_analysis_for_threshold(df, low_freq_range, common_threshold, [score_col], threshold)[0][1]
                if np.isfinite(ratio) and ratio > 0:
                    plottable.append(ratio)

    for (row, col), (low_freq_range, common_threshold) in panels:
        ax = axs[row, col]
        create_threshold_variation_plot(ax, df, low_freq_range, common_threshold, columns_to_analyze, thresholds, row, col)
        if plottable:
            ax.set_ylim(min(plottable) * 0.9, max(plottable) * 1.1)

    # Hide grid cells with no MAF split assigned.
    for row, col in positions[len(panels):]:
        axs[row, col].set_visible(False)

    fig.tight_layout()

    # Every panel plots the same models, so the first panel's handles serve as the legend.
    handles, labels = axs[0, 0].get_legend_handles_labels()
    simplified_labels = [label.replace('nucleotide-transformer', 'NT') for label in labels]

    fig.legend(handles, simplified_labels, loc='lower center', bbox_to_anchor=(0.5, -0.05),
               ncol=2, fontsize=20, borderaxespad=0)

    # Leave room at the bottom for the figure-level legend.
    fig.subplots_adjust(bottom=0.2)

    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()


# Threshold-sweep panel: configuration followed by the call.
# Each MAF split is ((low_min, low_max), common_min): variants with
# low_min < MAF < low_max form the low-frequency bin, variants with
# MAF >= common_min form the common bin.
maf_thresholds = [
    ((0, 0.05), 0.05),      # low frequency (< 0.05) vs. common (>= 0.05)
    ((0, 0.01), 0.01),      # very rare (< 0.01) vs. others (>= 0.01)
    ((0.001, 0.05), 0.05),  # 0.001 < low < 0.05 vs. common (>= 0.05)
    ((0.001, 0.01), 0.01),  # rare (0.001 < low < 0.01) vs. others (>= 0.01)
]

# Language-model scores share the cross-entropy-diff prefix; PhyloP and LOL-EVE do not.
columns_to_analyze = [
    'mean_cross_entropy_diff_' + model
    for model in (
        'hyenadna-tiny-1k-seqlen',
        'hyenadna-medium-450k-seqlen',
        'hyenadna-medium-160k-seqlen',
        'hyenadna-large-1m-seqlen',
        'hyenadna-small-32k-seqlen',
        'caduceus-ph_seqlen-131k_d_model-256_n_layer-16',
        'caduceus-ps_seqlen-131k_d_model-256_n_layer-16',
        'DNABERT-2-117M',
        'nucleotide-transformer-2.5b-multi-species',
        'nucleotide-transformer-2.5b-1000g',
        'nucleotide-transformer-500m-human-ref',
        'nucleotide-transformer-v2-500m-multi-species',
    )
] + ['PhyloP', 'LOL-EVE']

# Tail fractions of variants each model flags (0.1%, 1%, 5%); these are fractions, not percentiles.
thresholds = [0.001, 0.01, 0.05]

create_maf_threshold_panel(df, maf_thresholds, columns_to_analyze, thresholds)

In [None]:


def find_threshold(scores, n_predictions):
    """
    Return the score cutoff that selects the ``n_predictions`` highest scores.

    Selecting ``scores >= threshold`` yields at least ``n_predictions`` items, or more
    when there are ties at the cutoff. Higher scores are treated as more significant;
    ``run_analysis`` passes negated scores for models where lower is better.

    :param scores: 1-D array-like of scores; NaNs are ignored
    :param n_predictions: number of top-scoring items to select
    :return: the ``n_predictions``-th largest non-NaN score.
        Returns ``inf`` when ``n_predictions <= 0`` or there is no non-NaN score.
        Returns the smallest score when ``n_predictions`` exceeds the number of
        non-NaN scores, so everything is selected.
    """
    values = np.asarray(scores, dtype=float)
    # NaN would sort to the front of a descending order and become the cutoff.
    values = values[~np.isnan(values)]
    if n_predictions <= 0 or values.size == 0:
        # Indexing [n_predictions - 1] with n_predictions == 0 would wrap to [-1]
        # and return the *smallest* score, flagging every item.
        return np.inf
    descending = np.sort(values)[::-1]
    return descending[min(n_predictions, descending.size) - 1]

def calculate_odds_ratio(low_freq_count, common_count, total_low_freq, total_common):
    """Odds ratio of being flagged in the low-frequency bin vs. the common bin.

    This cell redefines a function of the same name from an earlier cell. Keep the
    two definitions identical.

    Odds in each bin are ``flagged / not flagged``. The "not flagged" count is
    floored at 1, so a bin in which every variant is flagged does not divide by zero.

    :param low_freq_count: number of flagged variants in the low-frequency bin
    :param common_count: number of flagged variants in the common bin
    :param total_low_freq: total number of variants in the low-frequency bin
    :param total_common: total number of variants in the common bin
    :return: ``odds_low / odds_common``. Returns ``inf`` when no common variant is
        flagged but at least one low-frequency variant is. Returns ``nan`` when
        neither bin has a flagged variant (the ratio is undefined).
    """
    odds_low = low_freq_count / max(1, total_low_freq - low_freq_count)
    odds_common = common_count / max(1, total_common - common_count)
    if odds_common == 0:
        # Dividing here would raise ZeroDivisionError whenever a model flags no
        # common variant.
        return float('inf') if odds_low > 0 else float('nan')
    return odds_low / odds_common

def run_analysis(df, low_freq_range, common_threshold, columns_to_analyze, top_fraction=0.01,
                 higher_is_significant=('PhyloP', 'CADD_raw_score')):
    """
    Run the analysis for a given MAF threshold configuration.

    Each score column flags its ``int(len(df) * top_fraction)`` most extreme variants:
    the highest scores for columns in ``higher_is_significant``, the lowest otherwise.
    The odds ratio then compares how often low-frequency and common variants are
    flagged. ``df`` is not modified.

    :param df: The DataFrame containing the data (a ``MAF`` column plus the score columns)
    :param low_freq_range: Tuple of (min_maf, max_maf) for low-frequency variants (exclusive bounds)
    :param common_threshold: Minimum MAF for common variants (inclusive)
    :param columns_to_analyze: List of columns to analyze
    :param top_fraction: Fraction of all variants each model flags (default: top 1%)
    :param higher_is_significant: Columns where larger scores are more significant
    :return: Tuple of (sorted_data, total_low_freq, total_common), where sorted_data is a
        list of (column, odds_ratio) sorted by odds ratio, largest first
    """
    n_predictions = int(len(df) * top_fraction)

    low_freq_mask = (df.MAF > low_freq_range[0]) & (df.MAF < low_freq_range[1])
    common_mask = df.MAF >= common_threshold
    total_low_freq = int(low_freq_mask.sum())
    total_common = int(common_mask.sum())

    data = []
    for col in columns_to_analyze:
        scores = df[col]
        if col in higher_is_significant:
            significant = scores >= find_threshold(scores, n_predictions)
        else:
            # find_threshold ranks high-to-low, so rank the negated scores and flip the cutoff back.
            significant = scores <= -find_threshold(-scores, n_predictions)
        # Masks stay local instead of being written back as `<col>_significant` columns.
        low_freq_count = int((significant & low_freq_mask).sum())
        common_count = int((significant & common_mask).sum())
        odds_ratio = calculate_odds_ratio(low_freq_count, common_count, total_low_freq, total_common)
        data.append((col, odds_ratio))

    sorted_data = sorted(data, key=lambda item: item[1], reverse=True)
    return sorted_data, total_low_freq, total_common

def plot_results(sorted_data, total_low_freq, total_common, low_freq_range, common_threshold):
    """
    Plot the results of the analysis with specified color palette.

    Draws one horizontal bar per model in the given order, first entry on top.
    LOL-EVE is green and every other model blue. Each odds ratio is written next
    to its bar, and a dashed red line marks an odds ratio of 1. The x-axis starts
    at 1, so bars for odds ratios below 1 are not visible.

    :param sorted_data: list of (column, odds_ratio) pairs, e.g. from ``run_analysis``
    :param total_low_freq: low-frequency variant count (only used by the disabled title)
    :param total_common: common variant count (only used by the disabled title)
    :param low_freq_range: (min_maf, max_maf) of the low-frequency bin (only used by the disabled title)
    :param common_threshold: minimum MAF of the common bin (only used by the disabled title)
    """
    if not sorted_data:
        # zip(*[]) cannot be unpacked into two names; there is nothing to draw.
        return
    names, odds = zip(*sorted_data)

    # inf/NaN odds ratios (e.g. no common variant flagged) cannot serve as axis limits
    # or text positions, so size the axis from the finite values only. The right limit
    # never drops below 1.1; otherwise all-odds-below-1 would invert the axis.
    finite_odds = [v for v in odds if np.isfinite(v)]
    x_max = max(max(finite_odds, default=1.0) * 1.1, 1.1)
    # +inf is drawn as a bar reaching the right edge; NaN as an empty bar.
    widths = [v if np.isfinite(v) else (x_max if v > 0 else 0.0) for v in odds]

    fig, ax = plt.subplots(figsize=(12, 8))

    # Define colors
    colors = {
        "LOL-EVE": "#00aa55",  # Dartmouth green for LOLEVE
        "Other": "#2f9aea"    # Green Blue for other bars
    }

    y_pos = np.arange(len(names))
    ax.barh(y_pos, widths, align='center',
            color=[colors["LOL-EVE"] if name == "LOL-EVE" else colors["Other"] for name in names])

    ax.set_yticks(y_pos)
    ax.set_yticklabels([name.split('mean_cross_entropy_diff_')[-1] for name in names], fontsize=16)
    ax.invert_yaxis()  # first entry of sorted_data at the top
    ax.set_xlabel('Odds ratio', fontsize=16)
    ax.set_ylabel('Model', fontsize=16)
    ax.yaxis.labelpad = 20
    #ax.set_title(f'gnomAD v4 low-frequency ({low_freq_range[0]}-{low_freq_range[1]}) vs. common (>={common_threshold})\nn={total_low_freq} vs. {total_common}')

    min_x = 1.05  # Minimum x-coordinate for text
    offset = 0.05  # Offset from the end of the bar

    for i, v in enumerate(odds):
        x_pos = max(v + offset, min_x) if np.isfinite(v) else min_x
        ax.text(x_pos, i, f'{v:.2f}', va='center', ha='left', fontsize=16)

    ax.set_xlim(1, x_max)
    ax.axvline(x=1, color='r', linestyle='--', linewidth=1)

    fig.tight_layout()
    plt.show()
    
# Short display names for the models in the bar chart. rename() returns a new frame,
# so `df` keeps its original column names and the threshold-panel cell can be re-run.
df_short_names = df.rename(columns={
    'mean_cross_entropy_diff_hyenadna-medium-450k-seqlen': 'HyenaDNA',
    'mean_cross_entropy_diff_caduceus-ps_seqlen-131k_d_model-256_n_layer-16': 'Caduceus',
    'mean_cross_entropy_diff_nucleotide-transformer-500m-human-ref': 'NT',
    'mean_cross_entropy_diff_DNABERT-2-117M': 'DNABERT-2',
})

columns_to_analyze = [
    'HyenaDNA',
    'Caduceus',
    'NT',
    'DNABERT-2',
    'PhyloP',
    'LOL-EVE'
]


# MAF splits to analyze, as ((low_min, low_max), common_min).
maf_thresholds = [
    ((0, 0.05), 0.05),  # low frequency (0 < MAF < 0.05) vs. common (MAF >= 0.05)
    # Alternative splits (also used in the threshold panel):
    #  ((0, 0.01), 0.01),  # very rare (MAF < 0.01) vs. others (MAF >= 0.01)
    #  ((0.001, 0.05), 0.05),  # 0.001 < MAF < 0.05 vs. common (MAF >= 0.05)
    #  ((0.001, 0.01), 0.01)  # rare (0.001 < MAF < 0.01) vs. others (MAF >= 0.01)
]

# Run analysis for each threshold configuration
for low_freq_range, common_threshold in maf_thresholds:
    print(f"\nAnalyzing MAF thresholds: Low frequency {low_freq_range}, Common >= {common_threshold}")
    sorted_data, total_low_freq, total_common = run_analysis(df_short_names, low_freq_range, common_threshold, columns_to_analyze)
    plot_results(sorted_data, total_low_freq, total_common, low_freq_range, common_threshold)