In [1]:
#!/usr/bin/env python
# coding: utf-8



In [2]:


#!/usr/bin/env python
# coding: utf-8
# Module-level logger used by every cell below for progress and warning messages.
# NOTE(review): `from benchmark_utils import *` in the next cell appears to rebind
# `log` — the captured output is emitted under the "benchmark_utils" logger name.
# Confirm whether that is intended before relying on this logger's name/level.
import logging

log = logging.getLogger(name=__name__)

# When running outside a pre-configured environment, enable INFO output with:
# logging.basicConfig(level=logging.INFO)




In [3]:


import os
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from IPython.display import display, HTML # Added for explicit display calls

# Helper functions come from benchmark_utils.py, which must be importable
# (i.e. located on the Python path).
try:
    from benchmark_utils import *
except ImportError:
    log.error("benchmark_utils.py not found. Please ensure it's in the Python path.")

    # Fallback placeholders so the remaining cells still execute. They return
    # empty or zero values, so the analysis below produces no results.
    def load_assembly_annotations(gene_type, path):
        """Placeholder: no ground-truth annotations."""
        return {}

    def parse_parameter_dir(param_dir):
        """Placeholder: every parameter value reported as 0."""
        return {'landmark_groups': 0, 'landmarks_per_group': 0, 'stdev_scaling': 0}

    def get_samples_for_params(gene_type, param_dir):
        """Placeholder: no samples."""
        return []

    def load_genotype_calls(sample, gene_type, param_dir):
        """Placeholder: no genotype calls."""
        return []

    def calculate_accuracy_metrics(genotypes, ground_truth, cnv):
        """Placeholder: empty (tp, fp, fn)."""
        return [], [], []

    def calculate_precision_recall(tp, fp, fn):
        """Placeholder: (precision, recall) both zero."""
        return 0.0, 0.0

    def get_parameter_dirs(gene_type):
        """Placeholder: no parameter directories."""
        return []

    def analyze_default_params(gene_type, ground_truth, cnv):
        """Placeholder: empty results frame."""
        return pd.DataFrame()




In [4]:


# Gene types analysed by every cell below (order determines iteration order).
gene_types = 'IGHV IGLV IGKV TRAV TRBV TRGV TRDV'.split()




In [5]:


# Load ground-truth annotations for every gene type.
# The annotation directory can be overridden with the GROUND_TRUTH_PATH
# environment variable instead of editing this machine-specific default.
ground_truth_path = os.environ.get(
    "GROUND_TRUTH_PATH",
    "/home/fordmk/data/ImmunoTyper-expansion-methods/HPRC-assembly-benchmarking/digger-functional-annotations-all/",
)
ground_truth_data = {}
for gene_type in gene_types:
    try:
        ground_truth_data[gene_type] = load_assembly_annotations(gene_type, path=ground_truth_path)
        log.info(f"Loaded ground truth for {gene_type}: {len(ground_truth_data[gene_type])} samples")
    except Exception as e:
        # Best effort: a gene type that fails to load gets an empty mapping, so
        # later cells find no ground-truth samples for it.
        log.error(f"Failed to load ground truth for {gene_type} from {ground_truth_path}: {e}")
        ground_truth_data[gene_type] = {}




INFO:benchmark_utils:Loading annotations for IGHV from /home/fordmk/data/ImmunoTyper-expansion-methods/HPRC-assembly-benchmarking/digger-functional-annotations-all/
INFO:benchmark_utils:Loaded ground truth for IGHV: 47 samples
INFO:benchmark_utils:Loading annotations for IGLV from /home/fordmk/data/ImmunoTyper-expansion-methods/HPRC-assembly-benchmarking/digger-functional-annotations-all/
INFO:benchmark_utils:Loaded ground truth for IGLV: 47 samples
INFO:benchmark_utils:Loading annotations for IGKV from /home/fordmk/data/ImmunoTyper-expansion-methods/HPRC-assembly-benchmarking/digger-functional-annotations-all/
INFO:benchmark_utils:Loaded ground truth for IGKV: 47 samples
INFO:benchmark_utils:Loading annotations for TRAV from /home/fordmk/data/ImmunoTyper-expansion-methods/HPRC-assembly-benchmarking/digger-functional-annotations-all/
INFO:benchmark_utils:Loaded ground truth for TRAV: 47 samples
INFO:benchmark_utils:Loading annotations for TRBV from /home/fordmk/data/ImmunoTyper-expansi

In [6]:


def analyze_gene_type_params(gene_type, param_dir, ground_truth, cnv=False):
    """
    Score every sample genotyped under one parameter set for one gene type.

    Each sample's calls are compared against its ground-truth annotations, and
    an individual F1 is derived from that sample's own precision and recall.

    Args:
        gene_type: Gene type label (e.g., 'IGHV', 'TRAV'); passed lower-cased to
            the sample / genotype loaders, stored as given in the results.
        param_dir: Parameter directory name, decoded with parse_parameter_dir.
        ground_truth: Mapping of sample -> ground-truth annotations.
        cnv: Forwarded to calculate_accuracy_metrics (copy-number-aware scoring).

    Returns:
        DataFrame with one row per successfully scored sample, or an empty
        DataFrame when no sample could be scored.
    """
    params = parse_parameter_dir(param_dir)
    lowered = gene_type.lower()
    rows = []

    for sample in get_samples_for_params(lowered, param_dir):
        if sample not in ground_truth:
            log.warning(f"Sample {sample} not found in ground truth for {gene_type}")
            continue

        try:
            calls = load_genotype_calls(sample, lowered, param_dir)
            tp, fp, fn = calculate_accuracy_metrics(calls, ground_truth[sample], cnv=cnv)
            precision, recall = calculate_precision_recall(tp, fp, fn)

            # Per-sample F1; 0 when both precision and recall are zero.
            pr_sum = precision + recall
            f1 = 2 * precision * recall / pr_sum if pr_sum > 0 else 0

            rows.append({
                'sample': sample,
                'gene_type': gene_type,
                'param_dir': param_dir,
                'landmark_groups': params['landmark_groups'],
                'landmarks_per_group': params['landmarks_per_group'],
                'stdev_scaling': params['stdev_scaling'],
                'precision': precision,
                'recall': recall,
                'f1_score': f1,
                'num_true_positives': len(tp),
                'num_false_positives': len(fp),
                'num_false_negatives': len(fn),
            })
        except FileNotFoundError:
            log.warning(f"Genotype calls not found for {sample} with parameters {param_dir}")
        except Exception as e:
            log.error(f"Error processing {sample} for {gene_type} with {param_dir}: {e}")

    if not rows:
        log.warning(f"No results for {gene_type} with parameters {param_dir}")
        return pd.DataFrame()
    return pd.DataFrame(rows)




In [7]:


# Parameter sweep: evaluate every parameter directory found for each gene type.
# Each non-empty per-(gene type, parameter set) DataFrame is collected here.
all_results = []

for gene_type in gene_types:
    param_dirs = get_parameter_dirs(gene_type.lower())
    log.info(f"Found {len(param_dirs)} parameter sets for {gene_type}")

    # The ground-truth mapping is the same for every parameter set of this gene type.
    truth_for_gene = ground_truth_data.get(gene_type, {})

    for param_dir in param_dirs:
        log.info(f"Analyzing {gene_type} with parameters {param_dir}")
        df = analyze_gene_type_params(
            gene_type,
            param_dir,
            truth_for_gene,
            cnv=False,  # set to True for copy-number-sensitive scoring
        )
        if df.empty:
            continue
        all_results.append(df)




INFO:benchmark_utils:Found 8 parameter sets for IGHV
INFO:benchmark_utils:Analyzing IGHV with parameters lg4_lpg4_stdev1.0
INFO:benchmark_utils:Analyzing IGHV with parameters lg4_lpg4_stdev2.0
INFO:benchmark_utils:Analyzing IGHV with parameters lg4_lpg8_stdev1.0
INFO:benchmark_utils:Analyzing IGHV with parameters lg4_lpg8_stdev2.0
INFO:benchmark_utils:Analyzing IGHV with parameters lg8_lpg4_stdev1.0
INFO:benchmark_utils:Analyzing IGHV with parameters lg8_lpg4_stdev2.0
INFO:benchmark_utils:Analyzing IGHV with parameters lg8_lpg8_stdev1.0
INFO:benchmark_utils:Analyzing IGHV with parameters lg8_lpg8_stdev2.0
INFO:benchmark_utils:Found 8 parameter sets for IGLV
INFO:benchmark_utils:Analyzing IGLV with parameters lg4_lpg4_stdev1.0
INFO:benchmark_utils:Analyzing IGLV with parameters lg4_lpg4_stdev2.0
INFO:benchmark_utils:Analyzing IGLV with parameters lg4_lpg8_stdev1.0
INFO:benchmark_utils:Analyzing IGLV with parameters lg4_lpg8_stdev2.0
INFO:benchmark_utils:Analyzing IGLV with parameters lg

In [8]:


# Evaluate the default parameter set (6, 6, 1.5) for every gene type, then
# combine it with the parameter-sweep results collected in `all_results`.
log.info("Processing results for default parameters (6, 6, 1.5)")

default_results = []
for gene_type in gene_types:
    log.info(f"Analyzing {gene_type} with default parameters")
    df = analyze_default_params(gene_type, ground_truth_data.get(gene_type, {}), cnv=False)
    if not df.empty:
        default_results.append(df)
        log.info(f"Added {len(df)} samples with default parameters for {gene_type}")

# Collect the frames to combine in a fresh list instead of appending to
# `all_results`. Mutating that list made this cell non-idempotent: every re-run
# of this cell added another copy of the default-parameter rows, inflating
# their sample counts in all downstream tables.
frames_to_combine = list(all_results)
if default_results:
    default_df = pd.concat(default_results, ignore_index=True)
    log.info(f"Total default parameter results: {len(default_df)} samples")
    frames_to_combine.append(default_df)

# Columns every downstream cell relies on.
required_cols = ['gene_type', 'landmark_groups', 'landmarks_per_group', 'stdev_scaling', 'precision', 'recall', 'f1_score', 'sample']
if frames_to_combine:
    combined_df = pd.concat(frames_to_combine, ignore_index=True)
    if all(col in combined_df.columns for col in required_cols):
        log.info(f"Combined results: {len(combined_df)} rows")
    else:
        log.error(f"Combined DataFrame is missing required columns. Found: {combined_df.columns.tolist()}")
        combined_df = pd.DataFrame()  # treat an invalid combination as "no results"
else:
    log.warning("No results found for any gene type or parameter set")
    combined_df = pd.DataFrame()




INFO:benchmark_utils:Processing results for default parameters (6, 6, 1.5)
INFO:benchmark_utils:Analyzing IGHV with default parameters
INFO:benchmark_utils:Found 41 samples for IGHV with default parameters
INFO:benchmark_utils:Added 40 samples with default parameters for IGHV
INFO:benchmark_utils:Analyzing IGLV with default parameters
INFO:benchmark_utils:Found 40 samples for IGLV with default parameters
INFO:benchmark_utils:Added 40 samples with default parameters for IGLV
INFO:benchmark_utils:Analyzing IGKV with default parameters
INFO:benchmark_utils:Found 40 samples for IGKV with default parameters
INFO:benchmark_utils:Added 40 samples with default parameters for IGKV
INFO:benchmark_utils:Analyzing TRAV with default parameters
INFO:benchmark_utils:Found 40 samples for TRAV with default parameters
INFO:benchmark_utils:Added 40 samples with default parameters for TRAV
INFO:benchmark_utils:Analyzing TRBV with default parameters
INFO:benchmark_utils:Found 40 samples for TRBV with defau

In [9]:


# --- MODIFIED Cell [8] ---
# Calculate grouped metrics and recalculate F1 from means

def calculate_f1_from_means(precision, recall):
    """
    Harmonic mean (F1) of an aggregated precision and recall.

    Returns 0.0 when either input is missing (NaN/None) or when both are zero.
    Otherwise returns 2PR / (P + R); numpy scalar inputs yield a numpy scalar,
    while the early-exit paths return a plain Python float.
    """
    if pd.isna(precision) or pd.isna(recall):
        return 0.0
    total = precision + recall
    if total == 0:
        return 0.0
    return 2 * precision * recall / total

# Per (gene type, parameter set) summary statistics. `f1_score_mean` is the F1
# computed from the group's mean precision and mean recall; the mean/median/std
# of the per-sample F1 scores are kept in the *_orig columns.
grouped_metrics = pd.DataFrame()  # stays empty if combined_df is unusable
if not combined_df.empty and all(col in combined_df.columns for col in ['gene_type', 'landmark_groups', 'landmarks_per_group', 'stdev_scaling', 'precision', 'recall', 'f1_score', 'sample']):
    grouped_metrics = combined_df.groupby(
        ['gene_type', 'landmark_groups', 'landmarks_per_group', 'stdev_scaling']
    ).agg({
        'precision': ['mean', 'median', 'std'],
        'recall': ['mean', 'median', 'std'],
        'f1_score': ['mean', 'median', 'std'],  # statistics of the per-sample F1s
        'sample': 'count',  # number of samples in the group
    }).reset_index()

    # Flatten (metric, statistic) MultiIndex columns, e.g. ('precision', 'mean') -> 'precision_mean'.
    grouped_metrics.columns = [
        '_'.join(col).strip('_') for col in grouped_metrics.columns.values
    ]

    grouped_metrics = grouped_metrics.rename(columns={
        'sample_count': 'num_samples',
        'f1_score_mean': 'f1_score_mean_orig',  # mean of per-sample F1s
        'f1_score_median': 'f1_score_median_orig',
        'f1_score_std': 'f1_score_std_orig',
    })

    # F1 computed from the group's mean precision and mean recall.
    grouped_metrics['f1_score_mean'] = grouped_metrics.apply(
        lambda row: calculate_f1_from_means(row['precision_mean'], row['recall_mean']),
        axis=1,
    )

    # Rank BEFORE rounding. Sorting on 3-dp values let near-ties (e.g. 0.8134 vs
    # 0.8131) be ordered by groupby position instead of by actual score, and that
    # order decides the "best" parameter set picked by later cells.
    grouped_metrics = grouped_metrics.sort_values(
        ['gene_type', 'f1_score_mean'],
        ascending=[True, False],
    )

    # Round metric columns to 3 decimal places for presentation.
    for col in grouped_metrics.columns:
        if not col.startswith(('precision', 'recall', 'f1_score')):
            continue
        if pd.api.types.is_numeric_dtype(grouped_metrics[col]):
            grouped_metrics[col] = grouped_metrics[col].round(3)
        else:
            log.warning(f"Column {col} is not numeric, skipping rounding.")
else:
    log.warning("Combined DataFrame is empty or missing columns, cannot generate grouped_metrics.")




In [10]:


# Best parameter set per gene type, ranked by the F1 computed from mean
# precision and mean recall, plus an aggregate "ALL" row.
summary_table = pd.DataFrame()  # stays empty if grouped_metrics is empty
if not grouped_metrics.empty:
    # idxmax returns the index label of the first maximal row per gene type.
    best_params_indices = grouped_metrics.groupby('gene_type')['f1_score_mean'].idxmax()
    best_params = grouped_metrics.loc[best_params_indices]

    summary_table = best_params[
        ['gene_type', 'landmark_groups', 'landmarks_per_group', 'stdev_scaling',
         'precision_mean', 'recall_mean', 'f1_score_mean', 'num_samples']
    ].copy()

    if not combined_df.empty:
        # NOTE: the "ALL" row averages over every row of combined_df, i.e. all
        # gene types AND all parameter sets (not only the best ones above), and
        # its num_samples counts unique sample IDs rather than rows.
        overall_mean_precision = combined_df['precision'].mean()
        overall_mean_recall = combined_df['recall'].mean()
        overall_f1_from_means = calculate_f1_from_means(overall_mean_precision, overall_mean_recall)

        # Built-in round() works for numpy scalars and plain Python floats alike.
        # calculate_f1_from_means returns a plain Python 0.0 for NaN/zero inputs,
        # and `.round(3)` on a Python float raises AttributeError.
        all_gene_types_row = pd.DataFrame([{
            'gene_type': 'ALL',
            'landmark_groups': None,
            'landmarks_per_group': None,
            'stdev_scaling': None,
            'precision_mean': round(overall_mean_precision, 3),
            'recall_mean': round(overall_mean_recall, 3),
            'f1_score_mean': round(overall_f1_from_means, 3),
            'num_samples': combined_df['sample'].nunique(),
        }])

        summary_table = pd.concat([all_gene_types_row, summary_table], ignore_index=True)
    else:
        log.warning("combined_df is empty, cannot calculate 'ALL' row for summary table.")

    print("Best Parameter Combinations by Gene Type (using F1 calculated from mean P & R):")
    print(summary_table.to_string(index=False))
else:
    print("Grouped metrics are empty, cannot generate summary table.")




Best Parameter Combinations by Gene Type (using F1 calculated from mean P & R):
gene_type landmark_groups landmarks_per_group  stdev_scaling  precision_mean  recall_mean  f1_score_mean  num_samples
      ALL            None                None            NaN           0.774        0.865          0.817           40
     IGHV               4                   8            2.0           0.813        0.814          0.814           40
     IGKV               4                   4            2.0           0.798        0.730          0.763           40
     IGLV               8                   8            2.0           0.956        0.904          0.929           40
     TRAV               4                   4            2.0           0.788        0.932          0.854           40
     TRBV               4                   4            2.0           0.872        0.931          0.901           40
     TRDV               4                   4            2.0           0.372        0.962     

  summary_table = pd.concat([all_gene_types_row, summary_table], ignore_index=True)


In [11]:


# Detailed view: mean precision, recall and per-sample F1 for every parameter
# combination (columns) and gene type (rows). F1 here is the mean of the
# individual per-sample F1 scores, not the F1 computed from mean P & R.
if combined_df.empty or not all(
    col in combined_df.columns
    for col in ['gene_type', 'landmark_groups', 'landmarks_per_group', 'stdev_scaling', 'precision', 'recall', 'f1_score']
):
    print("Combined DataFrame empty or missing columns, cannot generate detailed performance table.")
else:
    try:
        performance_table = (
            combined_df
            .pivot_table(
                index=['gene_type'],
                columns=['landmark_groups', 'landmarks_per_group', 'stdev_scaling'],
                values=['precision', 'recall', 'f1_score'],
                aggfunc='mean',
            )
            .round(3)
            .reset_index()
        )

        # These display options persist for the rest of the session.
        pd.set_option('display.max_columns', None)
        pd.set_option('display.width', 200)
        print("\nDetailed Performance Across All Parameter Combinations (F1 is mean of individuals):")
        print(performance_table)
    except Exception as e:
        log.error(f"Error creating pivot table: {e}")





Detailed Performance Across All Parameter Combinations (F1 is mean of individuals):
                    gene_type f1_score                                                         precision                                                         recall                              \
landmark_groups                      4                           6      8                              4                           6      8                           4                           6   
landmarks_per_group                  4             8             6      4             8                4             8             6      4             8             4             8             6   
stdev_scaling                      1.0    2.0    1.0    2.0    1.5    1.0    2.0    1.0    2.0       1.0    2.0    1.0    2.0    1.5    1.0    2.0    1.0    2.0    1.0    2.0    1.0    2.0    1.5   
0                        IGHV    0.790  0.803  0.791  0.812  0.812  0.786  0.812    NaN    NaN     0.805  0.797  0.814 

In [12]:


# Heatmaps of mean per-sample F1 over (landmark groups x landmarks per group),
# saved to '<gene_type>_parameter_heatmap_orig_f1.png'.
# One panel is drawn per stdev_scaling value: pivoting over all stdev values at
# once silently averaged F1 scores from different stdev settings into one cell.
if not combined_df.empty and 'f1_score' in combined_df.columns:
    param_cols = ['landmark_groups', 'landmarks_per_group', 'stdev_scaling']
    for gene_type in combined_df['gene_type'].unique():
        gene_data = combined_df[combined_df['gene_type'] == gene_type]

        # Need the parameter columns and more than one distinct parameter combination.
        if not (all(col in gene_data.columns for col in param_cols)
                and len(gene_data[param_cols].drop_duplicates()) > 1):
            log.info(f"Skipping heatmap for {gene_type}: Insufficient parameter combinations or missing columns.")
            continue

        fig = None
        try:
            if len(gene_data['landmark_groups'].unique()) <= 1 or len(gene_data['landmarks_per_group'].unique()) <= 1:
                log.info(f"Skipping heatmap for {gene_type}: Not enough variation in landmark groups/per group.")
                continue

            # One (landmark_groups x landmarks_per_group) pivot per stdev value.
            panels = {
                stdev: gene_data[gene_data['stdev_scaling'] == stdev].pivot_table(
                    index='landmark_groups',
                    columns='landmarks_per_group',
                    values='f1_score',  # per-sample F1, averaged per cell
                    aggfunc='mean',
                ).round(3)
                for stdev in sorted(gene_data['stdev_scaling'].dropna().unique())
            }
            if not panels:
                log.info(f"Skipping heatmap for {gene_type}: no stdev_scaling values.")
                continue

            # Shared colour scale so the panels are directly comparable.
            vmin = min(p.min().min() for p in panels.values())
            vmax = max(p.max().max() for p in panels.values())

            fig, axes = plt.subplots(1, len(panels), figsize=(6 * len(panels), 5), squeeze=False)
            for ax, (stdev, pivot_data) in zip(axes[0], panels.items()):
                sns.heatmap(pivot_data, annot=True, cmap='viridis', fmt='.3f', vmin=vmin, vmax=vmax, ax=ax)
                ax.set_title(f'Stdev Scaling: {stdev}')
                ax.set_xlabel('Landmarks Per Group')
                ax.set_ylabel('Landmark Groups')
            fig.suptitle(f'Mean Individual F1 Score by Parameters for {gene_type}')
            fig.tight_layout()
            fig.savefig(f'{gene_type}_parameter_heatmap_orig_f1.png')
        except Exception as e:
            log.error(f"Error creating heatmap for {gene_type}: {e}")
        finally:
            # Always release the figure, even if plotting failed part-way.
            if fig is not None:
                plt.close(fig)
else:
    print("No data available to generate heatmaps.")




In [13]:


# LaTeX version of summary_table (best parameter set per gene type plus the
# aggregate "ALL" row). The output uses \toprule/\midrule/\bottomrule, which
# require the LaTeX booktabs package.
if not summary_table.empty:
    # "\\&" must be written with an escaped backslash: a bare "\&" inside a
    # normal (non-raw) string literal is an invalid escape sequence, which is a
    # SyntaxWarning on Python 3.12+. The emitted text ("\&") is unchanged.
    latex_table = """\\begin{table}[ht]
\\centering
\\caption{Best Parameter Combinations for Each Gene Type (F1 from Mean P \\& R)}
\\label{tab:best_params} % Added a label for referencing
\\begin{tabular}{lcccccc}
\\toprule
Gene Type & Landmark & Landmarks & Stdev & Precision & Recall & F1 Score \\\\
& Groups & Per Group & Scaling & (mean) & (mean) & (mean) \\\\
\\midrule
"""

    for _, row in summary_table.iterrows():
        # The 'ALL' row has no parameter values; render those cells as a dash.
        gene_type = row['gene_type']
        lg = int(row['landmark_groups']) if pd.notna(row['landmark_groups']) else '—'
        lpg = int(row['landmarks_per_group']) if pd.notna(row['landmarks_per_group']) else '—'
        stdev = f"{row['stdev_scaling']:.1f}" if pd.notna(row['stdev_scaling']) else '—'
        precision = f"{row['precision_mean']:.3f}" if pd.notna(row['precision_mean']) else 'N/A'
        recall = f"{row['recall_mean']:.3f}" if pd.notna(row['recall_mean']) else 'N/A'
        f1 = f"{row['f1_score_mean']:.3f}" if pd.notna(row['f1_score_mean']) else 'N/A'

        latex_row = f"\\textit{{{gene_type}}} & {lg} & {lpg} & {stdev} & {precision} & {recall} & {f1} \\\\"
        latex_table += latex_row + "\n"

    latex_table += """\\bottomrule
\\end{tabular}
\\end{table}"""

    print("\nLaTeX Table for Best Parameter Combinations (using F1 from mean P & R):")
    print(latex_table)
else:
    print("Summary table is empty, cannot generate LaTeX table.")





LaTeX Table for Best Parameter Combinations (using F1 from mean P & R):
\begin{table}[ht]
\centering
\caption{Best Parameter Combinations for Each Gene Type (F1 from Mean P \& R)}
\label{tab:best_params} % Added a label for referencing
\begin{tabular}{lcccccc}
\toprule
Gene Type & Landmark & Landmarks & Stdev & Precision & Recall & F1 Score \\
& Groups & Per Group & Scaling & (mean) & (mean) & (mean) \\
\midrule
\textit{ALL} & — & — & — & 0.774 & 0.865 & 0.817 \\
\textit{IGHV} & 4 & 8 & 2.0 & 0.813 & 0.814 & 0.814 \\
\textit{IGKV} & 4 & 4 & 2.0 & 0.798 & 0.730 & 0.763 \\
\textit{IGLV} & 8 & 8 & 2.0 & 0.956 & 0.904 & 0.929 \\
\textit{TRAV} & 4 & 4 & 2.0 & 0.788 & 0.932 & 0.854 \\
\textit{TRBV} & 4 & 4 & 2.0 & 0.872 & 0.931 & 0.901 \\
\textit{TRDV} & 4 & 4 & 2.0 & 0.372 & 0.962 & 0.537 \\
\textit{TRGV} & 4 & 4 & 2.0 & 0.923 & 0.837 & 0.878 \\
\bottomrule
\end{tabular}
\end{table}


In [14]:


# Comprehensive table of every (gene type, parameter set): mean precision, mean
# recall, F1 computed from those means, and the mean of the per-sample F1s.
param_metrics = pd.DataFrame()  # stays empty if combined_df is unusable
if not combined_df.empty and all(col in combined_df.columns for col in ['gene_type', 'landmark_groups', 'landmarks_per_group', 'stdev_scaling', 'precision', 'recall', 'f1_score', 'sample']):
    param_metrics = combined_df.groupby(
        ['gene_type', 'landmark_groups', 'landmarks_per_group', 'stdev_scaling']
    ).agg(
        mean_precision=('precision', 'mean'),
        mean_recall=('recall', 'mean'),
        mean_of_individual_f1s=('f1_score', 'mean'),
        num_samples=('sample', 'count'),
    ).reset_index()

    # F1 computed from the group's mean precision and mean recall.
    param_metrics['mean_f1_score'] = param_metrics.apply(
        lambda row: calculate_f1_from_means(row['mean_precision'], row['mean_recall']),
        axis=1,
    )

    col_order = ['gene_type', 'landmark_groups', 'landmarks_per_group', 'stdev_scaling',
                 'mean_precision', 'mean_recall', 'mean_f1_score',
                 'mean_of_individual_f1s', 'num_samples']
    param_metrics = param_metrics[[c for c in col_order if c in param_metrics.columns]]

    # Sort on the unrounded F1 first, then round for presentation, so rows whose
    # F1 ties at 3 dp are still ordered by their actual score.
    param_metrics = param_metrics.sort_values(
        ['gene_type', 'mean_f1_score'],
        ascending=[True, False],
    )
    param_metrics = param_metrics.round({
        'mean_precision': 3,
        'mean_recall': 3,
        'mean_of_individual_f1s': 3,
        'mean_f1_score': 3,
    })

    # Styled HTML rendering with a colour gradient on the headline metrics.
    styled_table = param_metrics.style.set_properties(**{
        'text-align': 'center',
        'border': '1px solid black',
        'padding': '5px'
    }).background_gradient(
        subset=['mean_precision', 'mean_recall', 'mean_f1_score'],
        cmap='viridis'
    ).format({
        'mean_precision': '{:.3f}',
        'mean_recall': '{:.3f}',
        'mean_f1_score': '{:.3f}',
        'mean_of_individual_f1s': '{:.3f}',
        'stdev_scaling': '{:.1f}'
    }).set_caption('Comprehensive Performance Metrics (F1 Calculated from Mean P & R)')

    display(HTML("<h2>Performance Metrics for All Parameter Combinations (F1 from Mean P & R)</h2>"))
    display(styled_table)
else:
    print("Combined DataFrame empty or missing columns, cannot display parameter metrics table.")




Unnamed: 0,gene_type,landmark_groups,landmarks_per_group,stdev_scaling,mean_precision,mean_recall,mean_f1_score,mean_of_individual_f1s,num_samples
3,IGHV,4,8,2.0,0.813,0.814,0.814,0.812,40
4,IGHV,6,6,1.5,0.825,0.801,0.813,0.812,40
6,IGHV,8,4,2.0,0.813,0.814,0.813,0.812,40
1,IGHV,4,4,2.0,0.797,0.813,0.805,0.803,40
2,IGHV,4,8,1.0,0.814,0.772,0.792,0.791,31
0,IGHV,4,4,1.0,0.805,0.778,0.791,0.79,39
5,IGHV,8,4,1.0,0.81,0.767,0.788,0.786,34
8,IGKV,4,4,2.0,0.798,0.73,0.763,0.76,40
10,IGKV,4,8,2.0,0.801,0.729,0.763,0.761,40
11,IGKV,6,6,1.5,0.809,0.722,0.763,0.761,40


In [15]:


# --- MODIFIED Cell [14] ---
# Create comparison table using the NEW F1 calculation for both default and best

comparison_df = pd.DataFrame() # Initialize empty
if not combined_df.empty and not grouped_metrics.empty:
    comparison_table = []

    # One comparison row per gene type present in grouped_metrics.
    for gene_type in grouped_metrics['gene_type'].unique():
        gene_data = combined_df[combined_df['gene_type'] == gene_type]

        # Default parameter set: landmark_groups=6, landmarks_per_group=6, stdev_scaling=1.5
        default_param_mask = (
            (gene_data['landmark_groups'] == 6) &
            (gene_data['landmarks_per_group'] == 6) &
            (gene_data['stdev_scaling'] == 1.5)
        )
        default_results_group = gene_data[default_param_mask]

        if default_results_group.empty:
            log.warning(f"No default parameter results found for {gene_type} in combined_df")
            continue

        # F1 for the default parameters, computed from their mean precision/recall.
        default_mean_precision = default_results_group['precision'].mean()
        default_mean_recall = default_results_group['recall'].mean()
        # Round to 3 dp to match grouped_metrics, whose f1_score_mean values are
        # rounded to 3 dp. Comparing a rounded best F1 against an unrounded default
        # F1 produced spurious non-zero differences (e.g. when the default set is
        # itself the best one, the difference was not exactly 0).
        default_f1 = round(calculate_f1_from_means(default_mean_precision, default_mean_recall), 3)

        # grouped_metrics is sorted by F1 descending within each gene type, so the
        # first row for this gene type is its best parameter set.
        best_params_row = grouped_metrics[
            grouped_metrics['gene_type'] == gene_type
        ].iloc[0]
        best_f1 = best_params_row['f1_score_mean']

        f1_difference = best_f1 - default_f1
        # inf (stored as NaN below) when the default F1 is not positive.
        percent_improvement = (f1_difference / default_f1 * 100) if default_f1 > 0 else float('inf')

        comparison_table.append({
            'gene_type': gene_type,
            'default_landmark_groups': 6,
            'default_landmarks_per_group': 6,
            'default_stdev_scaling': 1.5,
            'default_f1_score': default_f1,
            'best_landmark_groups': best_params_row['landmark_groups'],
            'best_landmarks_per_group': best_params_row['landmarks_per_group'],
            'best_stdev_scaling': best_params_row['stdev_scaling'],
            'best_f1_score': best_f1,
            'f1_difference': f1_difference,
            'percent_improvement': percent_improvement if np.isfinite(percent_improvement) else np.nan,
        })

    # Build and render outputs only when at least one gene type produced a
    # comparison row (plain-text table, styled HTML table, and LaTeX table).
    # Create DataFrame from comparison table
    if comparison_table:
        comparison_df = pd.DataFrame(comparison_table)

        # Sort by percent improvement (descending), handle NaN
        # (NaN is stored when the default F1 is not positive; those rows go last.)
        comparison_df = comparison_df.sort_values('percent_improvement', ascending=False, na_position='last')

        # Display comparison table
        print("\nComparison of Default Parameters vs. Best Parameters (using F1 from mean P & R):")
        # NOTE: pd.set_option changes display options globally for the rest of the session.
        pd.set_option('display.max_columns', None)
        pd.set_option('display.width', 150)
        # Round for display (F1 columns to 3 dp, percent improvement to 1 dp).
        # comparison_df itself stays unrounded; the LaTeX rows below use this rounded copy.
        display_comparison_df = comparison_df.round({
            'default_f1_score': 3, 'best_f1_score': 3, 'f1_difference': 3, 'percent_improvement': 1
            })
        print(display_comparison_df.to_string(index=False))

        # Create a more concise version for display
        concise_comparison = display_comparison_df[[
            'gene_type', 'default_f1_score',
            'best_landmark_groups', 'best_landmarks_per_group', 'best_stdev_scaling',
            'best_f1_score', 'f1_difference', 'percent_improvement'
        ]].copy()

        # Display using IPython display for better formatting

        # Style the table
        # NOTE(review): no na_rep is set, so NaN percent_improvement values are
        # probably rendered as "nan%" by the '{:.1f}%' formatter — confirm.
        styled_comparison = concise_comparison.style.set_properties(**{
            'text-align': 'center',
            'border': '1px solid black',
            'padding': '5px'
        }).background_gradient(
            subset=['f1_difference', 'percent_improvement'],
            cmap='RdYlGn',
            axis=0 # Color each column independently
        ).format({ # Apply formatting
            'default_f1_score': '{:.3f}',
            'best_f1_score': '{:.3f}',
            'f1_difference': '{:+.3f}', # Add sign
            'percent_improvement': '{:.1f}%',
            'best_landmark_groups': '{:.0f}', # Integers
            'best_landmarks_per_group': '{:.0f}',
            'best_stdev_scaling': '{:.1f}'
        }).set_caption('Performance Improvement: Default Parameters vs. Best Parameters (F1 from Mean P & R)')

        # Display the styled table
        display(HTML("<h2>F1 Score Improvement: Default vs. Best Parameters (F1 from Mean P & R)</h2>"))
        display(styled_comparison)

        # Generate LaTeX table for the comparison (using the consistently calculated F1s)
        # The table uses \toprule/\midrule/\bottomrule, which require the booktabs package.
        latex_comparison = """\\begin{table}[ht]
\\centering
\\caption{Comparison of Default vs Best Parameter Sets by Gene Type (F1 from Mean P \\& R)}
\\label{tab:param_comparison} % Added label
\\begin{tabular}{lccccc}
\\toprule
Gene Type & Default & Best Parameters & Best & F1 Score & Improvement \\\\
& F1 Score & (LG, LPG, Stdev) & F1 Score & Difference & (\\%) \\\\
\\midrule
"""

        # Add rows for each gene type from the rounded display df
        # NOTE: these loop variables rebind module-level names used by earlier
        # cells (e.g. `gene_type`, `default_f1`, `best_f1`, and `best_params`,
        # which the summary cell bound to a DataFrame) to strings.
        for _, row in display_comparison_df.iterrows():
            gene_type = row['gene_type']
            default_f1 = f"{row['default_f1_score']:.3f}"
            # Ensure best params are integers/formatted correctly for display
            lg = int(row['best_landmark_groups']) if pd.notna(row['best_landmark_groups']) else 'N/A'
            lpg = int(row['best_landmarks_per_group']) if pd.notna(row['best_landmarks_per_group']) else 'N/A'
            stdev = f"{row['best_stdev_scaling']:.1f}" if pd.notna(row['best_stdev_scaling']) else 'N/A'
            best_params = f"({lg}, {lpg}, {stdev})"
            best_f1 = f"{row['best_f1_score']:.3f}"
            diff = f"{row['f1_difference']:+.3f}" # Add sign
            # A missing improvement renders as "N/A\%" in the LaTeX row below.
            pct = f"{row['percent_improvement']:.1f}" if pd.notna(row['percent_improvement']) else 'N/A'

            latex_row = f"\\textit{{{gene_type}}} & {default_f1} & {best_params} & {best_f1} & {diff} & {pct}\\% \\\\"
            latex_comparison += latex_row + "\n"

        # Add table footer
        latex_comparison += """\\bottomrule
\\end{tabular}
\\end{table}"""

        print("\nLaTeX Table for Parameter Comparison (using F1 from mean P & R):")
        print(latex_comparison)

    else:
        print("Could not generate comparison table (no data in comparison_table list).")
else:
    print("Combined DataFrame or Grouped Metrics are empty, cannot perform parameter comparison.")





Comparison of Default Parameters vs. Best Parameters (using F1 from mean P & R):
gene_type  default_landmark_groups  default_landmarks_per_group  default_stdev_scaling  default_f1_score  best_landmark_groups  best_landmarks_per_group  best_stdev_scaling  best_f1_score  f1_difference  percent_improvement
     TRDV                        6                            6                    1.5             0.519                     4                         4                 2.0          0.537          0.018                  3.5
     IGLV                        6                            6                    1.5             0.926                     8                         8                 2.0          0.929          0.003                  0.4
     TRGV                        6                            6                    1.5             0.875                     4                         4                 2.0          0.878          0.003                  0.4
     TRAV             

Unnamed: 0,gene_type,default_f1_score,best_landmark_groups,best_landmarks_per_group,best_stdev_scaling,best_f1_score,f1_difference,percent_improvement
5,TRDV,0.519,4,4,2.0,0.537,0.018,3.5%
2,IGLV,0.926,8,8,2.0,0.929,0.003,0.4%
6,TRGV,0.875,4,4,2.0,0.878,0.003,0.4%
3,TRAV,0.852,4,4,2.0,0.854,0.002,0.2%
0,IGHV,0.813,4,8,2.0,0.814,0.001,0.1%
4,TRBV,0.9,4,4,2.0,0.901,0.001,0.1%
1,IGKV,0.763,4,4,2.0,0.763,-0.0,-0.0%



LaTeX Table for Parameter Comparison (using F1 from mean P & R):
\begin{table}[ht]
\centering
\caption{Comparison of Default vs Best Parameter Sets by Gene Type (F1 from Mean P \& R)}
\label{tab:param_comparison} % Added label
\begin{tabular}{lccccc}
\toprule
Gene Type & Default & Best Parameters & Best & F1 Score & Improvement \\
& F1 Score & (LG, LPG, Stdev) & F1 Score & Difference & (\%) \\
\midrule
\textit{TRDV} & 0.519 & (4, 4, 2.0) & 0.537 & +0.018 & 3.5\% \\
\textit{IGLV} & 0.926 & (8, 8, 2.0) & 0.929 & +0.003 & 0.4\% \\
\textit{TRGV} & 0.875 & (4, 4, 2.0) & 0.878 & +0.003 & 0.4\% \\
\textit{TRAV} & 0.852 & (4, 4, 2.0) & 0.854 & +0.002 & 0.2\% \\
\textit{IGHV} & 0.813 & (4, 8, 2.0) & 0.814 & +0.001 & 0.1\% \\
\textit{TRBV} & 0.900 & (4, 4, 2.0) & 0.901 & +0.001 & 0.1\% \\
\textit{IGKV} & 0.763 & (4, 4, 2.0) & 0.763 & -0.000 & -0.0\% \\
\bottomrule
\end{tabular}
\end{table}


In [16]:


# Parameter-sweep visualizations: per-gene heatmaps (one panel per stdev_scaling
# value) and a per-gene 3D scatter, saved as PNG files in the working directory.
# Both use the mean of the per-sample `f1_score` values in `combined_df`. To plot
# F1 computed from mean precision/recall instead, use the 'param_metrics'
# DataFrame calculated in Cell [13].
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib

# Columns that identify one parameter combination in the sweep.
_PARAM_COLS = ['landmark_groups', 'landmarks_per_group', 'stdev_scaling']
# Marker cycle used to tell stdev_scaling values apart in the 3D plot.
_STDEV_MARKERS = ['o', '^', 's', 'D', 'v', 'P', 'X', '*']


def _draw_stdev_heatmap(ax, panel_data, gene_type, stdev):
    """Draw the mean-F1 heatmap (landmark_groups x landmarks_per_group) for one stdev on `ax`.

    Falls back to an explanatory text panel when `panel_data` is empty, has
    fewer than two distinct values on either axis, or plotting raises.
    """
    title = f'Stdev Scaling: {stdev}'

    if panel_data.empty:
        ax.text(0.5, 0.5, f"No data for stdev={stdev}",
                ha='center', va='center', transform=ax.transAxes)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        return

    if panel_data['landmark_groups'].nunique() < 2 or panel_data['landmarks_per_group'].nunique() < 2:
        # Not a 2D grid: list the sample-averaged F1 per combination as text instead
        # of dumping every per-sample row.
        summary = (panel_data
                   .groupby(['landmark_groups', 'landmarks_per_group'], as_index=False)['f1_score']
                   .mean()
                   .round(3))
        pivot_text = summary.to_string(index=False)
        ax.text(0.5, 0.5, f"Insufficient parameter combinations\nfor heatmap\n{pivot_text}",
                ha='center', va='center', transform=ax.transAxes, fontsize=8)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        return

    try:
        # Average the per-sample F1 values within each grid cell.
        pivot_data = panel_data.pivot_table(
            index='landmark_groups',
            columns='landmarks_per_group',
            values='f1_score',
            aggfunc='mean'
        )
        sns.heatmap(pivot_data, annot=True, fmt='.3f', cmap='viridis', cbar=True, ax=ax)
        ax.set_title(title)
        ax.set_xlabel('Landmarks Per Group')
        ax.set_ylabel('Landmark Groups')
    except Exception as e:  # best-effort: keep the remaining panels/figures going
        log.error(f"Error creating heatmap for {gene_type}, stdev {stdev}: {e}")
        ax.text(0.5, 0.5, "Error generating heatmap",
                ha='center', va='center', transform=ax.transAxes)
        ax.set_title(title)


def _plot_param_heatmaps(gene_type, gene_data, stdev_values):
    """Save a 1 x len(stdev_values) row of mean-F1 heatmaps for one gene type."""
    n_panels = len(stdev_values)
    # squeeze=False: axes is always 2D, so axes[0] is the row of panels even when n_panels == 1.
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 6), squeeze=False)
    try:
        fig.suptitle(f'Mean Individual F1 Score Heatmap for {gene_type}', fontsize=16)
        for ax, stdev in zip(axes[0], stdev_values):
            panel_data = gene_data[gene_data['stdev_scaling'] == stdev]
            _draw_stdev_heatmap(ax, panel_data, gene_type, stdev)
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
        fig.savefig(f'{gene_type}_parameter_heatmaps_orig_f1.png', dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)  # figures are only saved, never shown; free memory inside the loop


def _plot_param_3d(gene_type, gene_data):
    """Save a 3D scatter of mean F1 per parameter combination for one gene type.

    Each point is one (landmark_groups, landmarks_per_group, stdev_scaling)
    combination with its per-sample F1 values averaged. Colour encodes F1 and
    marker shape encodes stdev_scaling. The plot is skipped when there are 5 or
    fewer combinations.
    """
    mean_f1 = gene_data.groupby(_PARAM_COLS, as_index=False)['f1_score'].mean()
    n_combos = len(mean_f1)
    if n_combos <= 5:
        log.info(f"Skipping 3D plot for {gene_type}: Not enough unique parameter combinations ({n_combos} found).")
        return

    fig = plt.figure(figsize=(12, 10))
    try:
        ax = fig.add_subplot(111, projection='3d')
        # Shared colour scale so every stdev group maps F1 to the same colours.
        norm = plt.Normalize(vmin=mean_f1['f1_score'].min(), vmax=mean_f1['f1_score'].max())
        scatter = None
        for k, (stdev, group) in enumerate(mean_f1.groupby('stdev_scaling')):
            scatter = ax.scatter(
                group['landmark_groups'], group['landmarks_per_group'], group['f1_score'],
                c=group['f1_score'], cmap='viridis', norm=norm,
                marker=_STDEV_MARKERS[k % len(_STDEV_MARKERS)],
                s=100, alpha=0.7, label=f'stdev={stdev}'
            )

        cbar = fig.colorbar(scatter, ax=ax, shrink=0.5, aspect=10)
        cbar.set_label('Mean Individual F1 Score')
        ax.legend(title='Stdev Scaling')

        ax.set_xlabel('Landmark Groups')
        ax.set_ylabel('Landmarks Per Group')
        ax.set_zlabel('Mean Individual F1 Score')
        ax.set_title(f'3D Visualization of Parameter Impact on Mean Individual F1 Score for {gene_type}')

        # Plain floats keep the log readable (no np.float64(...) reprs).
        unique_stdevs_plot = sorted(float(s) for s in mean_f1['stdev_scaling'].unique())
        log.info(f"3D plot for {gene_type} includes stdev values: {unique_stdevs_plot}")

        fig.tight_layout()
        fig.savefig(f'{gene_type}_3d_parameter_visualization_orig_f1.png', dpi=300, bbox_inches='tight')
    except Exception as e:  # best-effort: continue with the next gene type
        log.error(f"Error creating 3D plot for {gene_type}: {e}")
    finally:
        plt.close(fig)  # ensure the figure is closed even on error


if not combined_df.empty and 'f1_score' in combined_df.columns:  # relies on per-sample f1_score
    gene_types_present = combined_df['gene_type'].unique()
    # Sorted and NaN-free: gives a stable left-to-right panel order. A NaN stdev would
    # never match the == filter anyway and would only produce an empty panel.
    stdev_values_present = np.sort(combined_df['stdev_scaling'].dropna().unique())

    log.info("Generating heatmaps based on mean of individual F1 scores...")
    if len(stdev_values_present) == 0:
        log.warning("No stdev_scaling values found; skipping heatmaps.")
    else:
        for gene_type in gene_types_present:
            _plot_param_heatmaps(
                gene_type,
                combined_df[combined_df['gene_type'] == gene_type],
                stdev_values_present,
            )

    log.info("Generating 3D plots based on mean of individual F1 scores...")
    for gene_type in gene_types_present:
        _plot_param_3d(gene_type, combined_df[combined_df['gene_type'] == gene_type])

else:
    print("Combined DataFrame empty or missing f1_score column, cannot create visualizations.")

INFO:benchmark_utils:Generating heatmaps based on mean of individual F1 scores...
INFO:benchmark_utils:Generating 3D plots based on mean of individual F1 scores...
INFO:benchmark_utils:3D plot for IGHV includes stdev values: [np.float64(1.0), np.float64(1.5), np.float64(2.0)]
INFO:benchmark_utils:3D plot for IGLV includes stdev values: [np.float64(1.0), np.float64(1.5), np.float64(2.0)]
INFO:benchmark_utils:3D plot for IGKV includes stdev values: [np.float64(1.0), np.float64(1.5), np.float64(2.0)]
INFO:benchmark_utils:3D plot for TRAV includes stdev values: [np.float64(1.0), np.float64(1.5), np.float64(2.0)]
INFO:benchmark_utils:3D plot for TRBV includes stdev values: [np.float64(1.0), np.float64(1.5), np.float64(2.0)]
INFO:benchmark_utils:3D plot for TRGV includes stdev values: [np.float64(1.0), np.float64(1.5), np.float64(2.0)]
INFO:benchmark_utils:3D plot for TRDV includes stdev values: [np.float64(1.0), np.float64(1.5), np.float64(2.0)]
