In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
from scipy import stats

In [None]:

def simplify_gene_name(gene_name):
    """Return the gene name without its allele suffix.

    Everything from the first ``*`` onwards is dropped, so ``IGHV3-23*01``
    becomes ``IGHV3-23``. Missing values (anything ``pd.isna`` reports as
    missing) are handed back unchanged; any other input is converted with
    ``str`` first.
    """
    if not pd.isna(gene_name):
        base_name, _, _ = str(gene_name).partition('*')
        return base_name
    return gene_name

def extract_gene_family(gene_name):
    """Return the family part of a gene name (e.g. ``IGHV3`` for ``IGHV3-23*01``).

    The allele suffix (from the first ``*``) is removed, then everything from
    the first ``-`` onwards is dropped. A name without a dash is returned with
    only the allele suffix removed. Missing values are handed back unchanged.
    """
    if pd.isna(gene_name):
        return gene_name
    allele_free = str(gene_name).split('*', 1)[0]
    # With no dash present, split() yields the whole string as its only item.
    return allele_free.split('-', 1)[0]

# Aggregation levels in reporting order, with the labels used in the output.
_LEVELS = ('full', 'simple', 'family')
_LEVEL_LABELS = {'full': 'Full Names', 'simple': 'Simplified', 'family': 'Families'}
_TRUE_COVERAGE_LABELS = {'full': 'True Full', 'simple': 'True Simplified', 'family': 'True Families'}
# (level, name used in the title, max genes per axis) for the complete heatmaps.
_COMPLETE_HEATMAPS = (('family', 'Gene Family', 15), ('simple', 'Simplified', 20), ('full', 'Full Name', 25))


def _pair_levels(pairs_full):
    """Aggregate full-name (heavy, light) pairs at every level in ``_LEVELS``.

    Returns a dict mapping each level to ``(dist, heavy_counts, light_counts)``
    where ``dist`` is a nested ``heavy -> light -> count`` defaultdict and the
    two counters hold per-gene pair counts.
    """
    pairs_by_level = {
        'full': pairs_full,
        'simple': [(simplify_gene_name(heavy), simplify_gene_name(light))
                   for heavy, light in pairs_full],
        'family': [(extract_gene_family(heavy), extract_gene_family(light))
                   for heavy, light in pairs_full],
    }
    levels = {}
    for level in _LEVELS:
        pairs = pairs_by_level[level]
        dist = defaultdict(lambda: defaultdict(int))
        for heavy, light in pairs:
            if not pd.isna(heavy) and not pd.isna(light):
                dist[heavy][light] += 1
        heavy_counts = Counter(heavy for heavy, _ in pairs if not pd.isna(heavy))
        light_counts = Counter(light for _, light in pairs if not pd.isna(light))
        levels[level] = (dist, heavy_counts, light_counts)
    return levels


def _pair_count(dist, row_key, col_key):
    """Read a nested count without inserting missing keys into the defaultdicts."""
    return dist.get(row_key, {}).get(col_key, 0)


def _contingency(row_keys, col_keys, dist):
    """Build a ``len(row_keys) x len(col_keys)`` count matrix from ``dist``."""
    table = np.zeros((len(row_keys), len(col_keys)))
    for i, row_key in enumerate(row_keys):
        for j, col_key in enumerate(col_keys):
            table[i, j] = _pair_count(dist, row_key, col_key)
    return table


def _annotate_cells(ax, table, **text_kwargs):
    """Write each cell's integer count on the heatmap, white on dark cells."""
    if table.size == 0:
        return
    threshold = np.max(table) * 0.5
    for (i, j), value in np.ndenumerate(table):
        ax.text(j, i, int(value), ha='center', va='center',
                color='white' if value > threshold else 'black', **text_kwargs)


def _analyze_coverage(heavy_counts, light_counts, vgene_dist, level_name):
    """Print and return ``(n_heavy, n_light, pct)`` for one aggregation level.

    ``pct`` is the percentage of all counted pairs that fall inside the block
    formed by the 5 most common heavy and 5 most common light genes (0 when
    there are no pairs).
    """
    total_heavy_genes = len(heavy_counts)
    total_light_genes = len(light_counts)
    total_associations = sum(sum(light_dist.values()) for light_dist in vgene_dist.values())

    top_5_heavy = [gene for gene, _ in heavy_counts.most_common(5)]
    top_5_light = [gene for gene, _ in light_counts.most_common(5)]
    top_5_associations = sum(_pair_count(vgene_dist, heavy, light)
                             for heavy in top_5_heavy for light in top_5_light)

    coverage_pct = (top_5_associations / total_associations * 100) if total_associations > 0 else 0

    print(f"{level_name}: {total_heavy_genes}x{total_light_genes} genes, 5x5 covers {coverage_pct:.1f}%")

    return total_heavy_genes, total_light_genes, coverage_pct


def _create_heatmap(heavy_counts, light_counts, vgene_dist, title, cmap='YlOrRd', complete=False):
    """Show a heavy x light count heatmap and return its matrix.

    With ``complete=False`` only the 5 most common genes per axis are shown
    and at least 2 per axis are required. With ``complete=True`` every gene is
    shown. Returns ``None`` (after printing a notice) when there is not
    enough data to plot.
    """
    if complete:
        heavy_genes = list(heavy_counts)
        light_genes = list(light_counts)
        enough_data = bool(heavy_genes) and bool(light_genes)
    else:
        heavy_genes = [gene for gene, _ in heavy_counts.most_common(5)]
        light_genes = [gene for gene, _ in light_counts.most_common(5)]
        enough_data = len(heavy_genes) >= 2 and len(light_genes) >= 2

    if not enough_data:
        print(f"Insufficient data for {title}")
        return None

    contingency = _contingency(heavy_genes, light_genes, vgene_dist)

    # Grow the figure with the number of genes so tick labels stay legible.
    fig_width = max(8, len(light_genes) * 0.6)
    fig_height = max(6, len(heavy_genes) * 0.5)
    fig, ax = plt.subplots(1, 1, figsize=(fig_width, fig_height))

    im = ax.imshow(contingency, cmap=cmap, aspect='auto')
    ax.set_xticks(range(len(light_genes)))
    ax.set_yticks(range(len(heavy_genes)))
    ax.set_xticklabels(light_genes, rotation=45, ha='right', fontsize=8)
    ax.set_yticklabels(heavy_genes, fontsize=16)
    ax.set_title(title)
    ax.set_xlabel('Light V Genes')
    ax.set_ylabel('Heavy V Genes')

    # Per-cell counts are only readable on small grids.
    if len(heavy_genes) <= 15 and len(light_genes) <= 15:
        _annotate_cells(ax, contingency, fontsize=14)

    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()

    return contingency


def _plot_level_heatmaps(levels, coverage, prefix, cmap):
    """Show the top-5x5 heatmaps, then the complete ones that are small enough."""
    for level in _LEVELS:
        dist, heavy_counts, light_counts = levels[level]
        _create_heatmap(heavy_counts, light_counts, dist,
                        f"{prefix} V Gene Associations ({_LEVEL_LABELS[level]} - Top 5x5)",
                        cmap=cmap)

    print(f"\n--- COMPLETE {prefix.upper()} V GENE HEATMAPS ---")

    for level, level_name, max_genes in _COMPLETE_HEATMAPS:
        n_heavy, n_light, _ = coverage[level]
        if n_heavy <= max_genes and n_light <= max_genes:
            dist, heavy_counts, light_counts = levels[level]
            _create_heatmap(heavy_counts, light_counts, dist,
                            f"Complete {prefix} {level_name} Associations ({n_heavy}x{n_light})",
                            cmap=cmap, complete=True)


def _plot_locus_heatmap(true_gen_locus_pairs, locus_dist):
    """Show the true-vs-generated locus heatmap when both axes have >= 2 loci."""
    unique_true_loci = list(Counter(true for true, _ in true_gen_locus_pairs))
    unique_gen_loci = list(Counter(gen for _, gen in true_gen_locus_pairs))

    if len(unique_true_loci) < 2 or len(unique_gen_loci) < 2:
        return

    locus_contingency = _contingency(unique_true_loci, unique_gen_loci, locus_dist)

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    im = ax.imshow(locus_contingency, cmap='Blues', aspect='auto')
    ax.set_xticks(range(len(unique_gen_loci)))
    ax.set_yticks(range(len(unique_true_loci)))
    ax.set_xticklabels(unique_gen_loci, rotation=45, ha='right')
    ax.set_yticklabels(unique_true_loci)
    ax.set_title('Locus Association Heatmap (True vs Generated)')
    ax.set_xlabel('Generated Light Chain Loci')
    ax.set_ylabel('True Light Chain Loci')

    _annotate_cells(ax, locus_contingency)

    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()


def analyze_vgene_distribution(csv_file):
    """
    Analyze the distribution of V genes and loci between heavy chains and generated light sequences.

    The CSV must contain the columns ``heavy_gene_name`` and
    ``true_light_light_locus``. Generated light chains are read from every
    column named ``gen_light_*_gene_name`` / ``gen_light_*_light_locus``. A
    true light V gene column is optional; the first of
    ``true_light_gene_name``, ``true_light_v_gene``, ``light_gene_name``,
    ``true_light_v_gene_name`` that exists is used.

    Heavy/light pairs are counted at three levels (full names, allele-free
    names, gene families), a summary is printed, and heatmaps are shown with
    matplotlib.

    Parameters:
    csv_file (str): Path to the CSV file

    Returns:
    dict: per-level counters (``heavy_counts_*``, ``light_counts_*`` and their
        ``true_`` variants), nested pair distributions (``vgene_dist_*``,
        ``true_vgene_dist_*`` - ``None`` when no true pairs exist),
        ``locus_dist``, ``locus_accuracy``, ``overall_locus_accuracy``
        (percent), ``coverage_analysis`` and ``true_light_vgene_col``.
        Entries under ``coverage_analysis['true']`` are ``None`` when the true
        column exists but produced no valid pairs.

    Raises:
    KeyError: if ``heavy_gene_name`` or ``true_light_light_locus`` is missing.
    """
    df = pd.read_csv(csv_file)

    print("Dataset Overview:")
    print(f"Total rows: {len(df)}")
    print(f"Columns: {df.columns.tolist()}")

    # The true light V gene column goes by several names; take the first match.
    possible_true_light_cols = ['true_light_gene_name', 'true_light_v_gene', 'light_gene_name', 'true_light_v_gene_name']
    true_light_vgene_col = next(
        (col for col in possible_true_light_cols if col in df.columns), None)

    print(f"True light V gene column: {true_light_vgene_col if true_light_vgene_col else 'NOT FOUND'}")

    light_gene_cols = [col for col in df.columns if col.startswith('gen_light_') and col.endswith('_gene_name')]
    light_locus_cols = [col for col in df.columns if col.startswith('gen_light_') and col.endswith('_light_locus')]

    print(f"Generated light gene columns found: {len(light_gene_cols)}")
    print(f"Generated light locus columns found: {len(light_locus_cols)}")

    # ============================
    # V GENE ANALYSIS - GENERATED
    # ============================

    print("\n" + "="*80 + "\n")
    print("GENERATED V GENE ANALYSIS")
    print("="*50)

    # One pair per (row, generated light column) where both genes are present.
    heavy_light_pairs_full = []
    gene_rows = df[['heavy_gene_name'] + light_gene_cols].itertuples(index=False, name=None)
    for heavy_gene, *light_genes in gene_rows:
        if pd.isna(heavy_gene):
            continue
        heavy_light_pairs_full.extend(
            (heavy_gene, light_gene) for light_gene in light_genes if not pd.isna(light_gene))

    print(f"Total heavy-light V gene pairs (full): {len(heavy_light_pairs_full)}")

    gen_levels = _pair_levels(heavy_light_pairs_full)
    for level in _LEVELS:
        _, heavy_counts, light_counts = gen_levels[level]
        print(f"Generated - {_LEVEL_LABELS[level]}: {len(heavy_counts)} heavy, {len(light_counts)} light")

    # ============================
    # V GENE ANALYSIS - TRUE
    # ============================

    # Defaults used when there is no true light V gene column.
    true_levels = {level: (None, Counter(), Counter()) for level in _LEVELS}

    if true_light_vgene_col:
        print("\nTRUE V GENE ANALYSIS")
        print("="*50)

        true_rows = df[['heavy_gene_name', true_light_vgene_col]].itertuples(index=False, name=None)
        heavy_true_pairs_full = [
            (heavy_gene, true_light_gene)
            for heavy_gene, true_light_gene in true_rows
            if not (pd.isna(heavy_gene) or pd.isna(true_light_gene))
        ]

        print(f"Total heavy-true light V gene pairs: {len(heavy_true_pairs_full)}")

        true_levels = _pair_levels(heavy_true_pairs_full)
        for level in _LEVELS:
            _, heavy_counts, light_counts = true_levels[level]
            print(f"True - {_LEVEL_LABELS[level]}: {len(heavy_counts)} heavy, {len(light_counts)} light")

    # ============================
    # LOCUS ANALYSIS
    # ============================

    print("\n" + "="*80 + "\n")
    print("LOCUS ANALYSIS")
    print("="*50)

    true_gen_locus_pairs = []
    locus_dist = defaultdict(lambda: defaultdict(int))
    locus_accuracy = defaultdict(lambda: {'total': 0, 'correct': 0})

    locus_rows = df[['true_light_light_locus'] + light_locus_cols].itertuples(index=False, name=None)
    for true_locus, *gen_loci in locus_rows:
        if pd.isna(true_locus):
            continue
        for gen_locus in gen_loci:
            if pd.isna(gen_locus):
                continue
            true_gen_locus_pairs.append((true_locus, gen_locus))
            locus_dist[true_locus][gen_locus] += 1
            tally = locus_accuracy[true_locus]
            tally['total'] += 1
            if gen_locus == true_locus:
                tally['correct'] += 1

    print(f"Total true-generated locus pairs: {len(true_gen_locus_pairs)}")

    total_generated = sum(tally['total'] for tally in locus_accuracy.values())
    total_correct = sum(tally['correct'] for tally in locus_accuracy.values())
    overall_accuracy = (total_correct / total_generated * 100) if total_generated > 0 else 0

    print(f"Overall locus accuracy: {overall_accuracy:.1f}% ({total_correct}/{total_generated})")

    # ============================
    # COVERAGE ANALYSIS
    # ============================

    print("\nCoverage Analysis:")
    gen_coverage = {}
    for level in _LEVELS:
        dist, heavy_counts, light_counts = gen_levels[level]
        gen_coverage[level] = _analyze_coverage(heavy_counts, light_counts, dist, _LEVEL_LABELS[level])

    # ============================
    # HEATMAPS
    # ============================

    print("\n" + "="*80 + "\n")
    print("CREATING HEATMAPS...")

    print("\n--- GENERATED V GENE ASSOCIATION HEATMAPS ---")
    _plot_level_heatmaps(gen_levels, gen_coverage, 'Generated', 'YlOrRd')

    # Stays all-None unless true pairs exist, so the result dict below can
    # always be built.
    true_coverage = {level: None for level in _LEVELS}
    if true_light_vgene_col and true_levels['full'][0]:
        print("\n--- TRUE V GENE ASSOCIATION HEATMAPS ---")

        for level in _LEVELS:
            dist, heavy_counts, light_counts = true_levels[level]
            true_coverage[level] = _analyze_coverage(
                heavy_counts, light_counts, dist, _TRUE_COVERAGE_LABELS[level])

        _plot_level_heatmaps(true_levels, true_coverage, 'TRUE', 'Reds')

    print("\n--- LOCUS ASSOCIATION HEATMAP ---")
    _plot_locus_heatmap(true_gen_locus_pairs, locus_dist)

    # ============================
    # RETURN RESULTS
    # ============================

    vgene_dist_full, heavy_counts_full, light_counts_full = gen_levels['full']
    vgene_dist_simple, heavy_counts_simple, light_counts_simple = gen_levels['simple']
    vgene_dist_family, heavy_counts_family, light_counts_family = gen_levels['family']
    true_vgene_dist_full, true_heavy_counts_full, true_light_counts_full = true_levels['full']
    true_vgene_dist_simple, true_heavy_counts_simple, true_light_counts_simple = true_levels['simple']
    true_vgene_dist_family, true_heavy_counts_family, true_light_counts_family = true_levels['family']

    return {
        'heavy_counts_full': heavy_counts_full,
        'light_counts_full': light_counts_full,
        'heavy_counts_simple': heavy_counts_simple,
        'light_counts_simple': light_counts_simple,
        'heavy_counts_family': heavy_counts_family,
        'light_counts_family': light_counts_family,
        'true_heavy_counts_full': true_heavy_counts_full,
        'true_light_counts_full': true_light_counts_full,
        'true_heavy_counts_simple': true_heavy_counts_simple,
        'true_light_counts_simple': true_light_counts_simple,
        'true_heavy_counts_family': true_heavy_counts_family,
        'true_light_counts_family': true_light_counts_family,
        'vgene_dist_full': dict(vgene_dist_full),
        'vgene_dist_simple': dict(vgene_dist_simple),
        'vgene_dist_family': dict(vgene_dist_family),
        'true_vgene_dist_full': dict(true_vgene_dist_full) if true_vgene_dist_full else None,
        'true_vgene_dist_simple': dict(true_vgene_dist_simple) if true_vgene_dist_simple else None,
        'true_vgene_dist_family': dict(true_vgene_dist_family) if true_vgene_dist_family else None,
        'locus_dist': dict(locus_dist),
        'locus_accuracy': dict(locus_accuracy),
        'overall_locus_accuracy': overall_accuracy,
        'coverage_analysis': {
            'generated': {
                'full': gen_coverage['full'],
                'simple': gen_coverage['simple'],
                'family': gen_coverage['family']
            },
            'true': {
                'full': true_coverage['full'],
                'simple': true_coverage['simple'],
                'family': true_coverage['family']
            } if true_light_vgene_col else None
        },
        'true_light_vgene_col': true_light_vgene_col
    }


In [None]:
# Alternate input (matching sequences); swap with the assignment below to analyze it instead.
#csv_file = '/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted.csv'

csv_file = '/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted.csv'

try:
    results = analyze_vgene_distribution(csv_file)
except FileNotFoundError:
    print(f"Error: Could not find the file '{csv_file}'")
    print("Please make sure the file path is correct.")
except Exception as e:
    # Top-level notebook boundary: report instead of raising. The exception
    # type is included because str(KeyError) is only the bare key name.
    print(f"An error occurred: {type(e).__name__}: {e}")
else:
    print("\nAnalysis completed successfully!")
    print("\nReturned results contain:")
    # List the keys that were actually returned rather than a hard-coded
    # description that can drift from the function.
    for key in results:
        print(f"- {key}")

In [None]:
import pandas as pd

# Left-join the pairing results (df1) with the parsed predictions (df2) on the
# shared identifying columns and save the combined table.
df1 = pd.read_csv('/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/pairing_result_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted_rel_cols.csv')
df2 = pd.read_csv('/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted.csv')


# Columns that must match in both DataFrames for rows to be joined
merge_cols = ['overall_id', 'heavy_raw_sequence','true_light_raw_sequence', 
              'true_light_first_hit_gene', 'true_light_gene_name', 'true_light_light_locus',
              'gen_light_1_raw_sequence', 'gen_light_1_nt_trimmed']

print("Columns in df1:", df1.columns.tolist())
print("Columns in df2:", df2.columns.tolist())

# Check for missing columns
missing_df1 = [col for col in merge_cols if col not in df1.columns]
missing_df2 = [col for col in merge_cols if col not in df2.columns]

if missing_df1:
    print(f"Missing columns in df1: {missing_df1}")
if missing_df2:
    print(f"Missing columns in df2: {missing_df2}")
if missing_df1 or missing_df2:
    # The merge below would fail with a KeyError anyway; stop with a clear message.
    raise KeyError(f"Cannot merge: missing in df1 {missing_df1}, missing in df2 {missing_df2}")

# Check data types (a dtype mismatch on a key column prevents rows from matching)
print("\nData types comparison:")
for col in merge_cols:
    print(f"{col}: df1={df1[col].dtype}, df2={df2[col].dtype}")

# Perform merge; the indicator column records whether each row found a match in df2
merged = df1.merge(df2, on=merge_cols, how='left', suffixes=('', '_df2'),
                   indicator='_merge_status')

print(f"\nMerged: {len(merged)} rows")
if len(merged) != len(df1):
    # A left join only grows when a df1 row matches several df2 rows.
    print(f"Note: df1 has {len(df1)} rows; duplicate keys in df2 added {len(merged) - len(df1)} rows")

# Count df1 rows with no df2 match. Checking for nulls instead is unreliable
# because the key columns are always filled from df1.
unmatched = (merged['_merge_status'] == 'left_only').sum()
print(f"Unmatched rows: {unmatched}")

# The indicator is bookkeeping only; keep it out of the saved file
merged = merged.drop(columns='_merge_status')

# Save the merged DataFrame
merged.to_csv('/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/pairing_result_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted_rel_cols_merged_complete.csv', index=False)


In [None]:
def simplify_gene_name(gene_name):
    """Drop the allele designation (the first ``*`` and everything after it).

    Missing values, as reported by ``pd.isna``, are returned unchanged; other
    inputs are converted with ``str`` before trimming.
    """
    if pd.isna(gene_name):
        return gene_name
    text = str(gene_name)
    star_at = text.find('*')
    # find() returns -1 when the name carries no allele suffix.
    return text if star_at == -1 else text[:star_at]

def extract_gene_family(gene_name):
    """Reduce a gene name to its family, e.g. ``IGHV3-23*01`` -> ``IGHV3``.

    The allele suffix (from the first ``*``) is cut first, then the text
    before the first ``-`` is kept. Names without a dash come back with only
    the allele suffix removed, and missing values come back unchanged.
    """
    if not pd.isna(gene_name):
        without_allele, _, _ = str(gene_name).partition('*')
        family, _, _ = without_allele.partition('-')
        return family
    return gene_name

def analyze_vgene_by_pairing_score(csv_file):
    """
    Analyze V gene distributions grouped by pairing scores (<0.5 vs ≥0.5)
    and compare generated vs true light chain V genes.
    
    Parameters:
    csv_file (str): Path to the CSV file
    """
    
    # Read the CSV file
    df = pd.read_csv(csv_file)
    
    print("Dataset Overview:")
    print(f"Total rows: {len(df)}")
    print(f"Columns: {df.columns.tolist()}")
    
    # Check for required columns
    required_cols = ['pairing_scores', 'true_light_gene_name', 'gen_light_1_first_hit_gene']
    
    # Check for heavy chain V gene column
    heavy_gene_col = None
    possible_heavy_cols = ['heavy_gene_name', 'true_v_gene_simple', 'heavy_v_gene', 'true_heavy_gene']
    
    for col in possible_heavy_cols:
        if col in df.columns:
            heavy_gene_col = col
            break
    
    if heavy_gene_col:
        required_cols.append(heavy_gene_col)
        print(f"✅ Using heavy chain V gene column: {heavy_gene_col}")
    else:
        print(f"⚠️ No heavy chain V gene column found. Searched for: {possible_heavy_cols}")
        print(f"Available columns: {df.columns.tolist()}")
        print(f"Will skip heavy-light V gene pairing analysis and focus on light chain only.")
    
    missing_cols = [col for col in required_cols if col not in df.columns]
    
    if missing_cols:
        print(f"❌ Missing required columns: {missing_cols}")
        return None
    
    print(f"✅ Required columns found")
    
    # ============================
    # DATA PREPARATION
    # ============================
    
    print("\n" + "="*80 + "\n")
    print("DATA PREPARATION")
    print("="*50)
    
    # Remove rows with missing pairing scores
    df_clean = df.dropna(subset=['pairing_scores']).copy()
    print(f"Rows with pairing scores: {len(df_clean)}/{len(df)} ({len(df_clean)/len(df)*100:.1f}%)")
    
    # Group by pairing score threshold
    df_low = df_clean[df_clean['pairing_scores'] < 0.5].copy()
    df_high = df_clean[df_clean['pairing_scores'] >= 0.5].copy()
    
    print(f"Low pairing scores (<0.5): {len(df_low)} rows ({len(df_low)/len(df_clean)*100:.1f}%)")
    print(f"High pairing scores (≥0.5): {len(df_high)} rows ({len(df_high)/len(df_clean)*100:.1f}%)")
    
    # Pairing score statistics
    print(f"\nPairing Score Statistics:")
    print(f"Overall - Mean: {df_clean['pairing_scores'].mean():.3f}, Median: {df_clean['pairing_scores'].median():.3f}")
    if len(df_low) > 0:
        print(f"Low group - Mean: {df_low['pairing_scores'].mean():.3f}, Range: {df_low['pairing_scores'].min():.3f}-{df_low['pairing_scores'].max():.3f}")
    if len(df_high) > 0:
        print(f"High group - Mean: {df_high['pairing_scores'].mean():.3f}, Range: {df_high['pairing_scores'].min():.3f}-{df_high['pairing_scores'].max():.3f}")
    
    def analyze_group(df_group, group_name):
        """Analyze V gene distributions for a specific pairing score group"""
        
        if len(df_group) == 0:
            print(f"\n⚠️ No data for {group_name} group")
            return None
        
        print(f"\n{'='*60}")
        print(f"{group_name.upper()} PAIRING SCORE GROUP ANALYSIS")
        print(f"{'='*60}")
        
        results = {}
        
        # ============================
        # GENERATED V GENE ANALYSIS
        # ============================
        
        print(f"\nGENERATED LIGHT CHAIN V GENE ANALYSIS - {group_name}")
        print("="*50)
        
        # Collect generated light chain V genes
        gen_light_genes = []
        for idx, row in df_group.iterrows():
            gen_light_gene = row['gen_light_1_first_hit_gene']
            if not pd.isna(gen_light_gene):
                gen_light_genes.append(gen_light_gene)
        
        print(f"Generated light V genes: {len(gen_light_genes)}")
        
        # Create simplified and family versions for generated
        gen_light_simple = [simplify_gene_name(gene) for gene in gen_light_genes if not pd.isna(gene)]
        gen_light_family = [extract_gene_family(gene) for gene in gen_light_genes if not pd.isna(gene)]
        
        # Count frequencies for generated
        gen_light_counts_full = Counter(gen_light_genes)
        gen_light_counts_simple = Counter([gene for gene in gen_light_simple if not pd.isna(gene)])
        gen_light_counts_family = Counter([gene for gene in gen_light_family if not pd.isna(gene)])
        
        print(f"Generated - Full: {len(gen_light_counts_full)} unique genes")
        print(f"Generated - Simple: {len(gen_light_counts_simple)} unique genes")
        print(f"Generated - Family: {len(gen_light_counts_family)} unique families")
        
        # ============================
        # TRUE V GENE ANALYSIS
        # ============================
        
        print(f"\nTRUE LIGHT CHAIN V GENE ANALYSIS - {group_name}")
        print("="*50)
        
        # Collect true light chain V genes
        true_light_genes = []
        for idx, row in df_group.iterrows():
            true_light_gene = row['true_light_gene_name']
            if not pd.isna(true_light_gene):
                true_light_genes.append(true_light_gene)
        
        print(f"True light V genes: {len(true_light_genes)}")
        
        # Create simplified and family versions for true
        true_light_simple = [simplify_gene_name(gene) for gene in true_light_genes if not pd.isna(gene)]
        true_light_family = [extract_gene_family(gene) for gene in true_light_genes if not pd.isna(gene)]
        
        # Count frequencies for true
        true_light_counts_full = Counter(true_light_genes)
        true_light_counts_simple = Counter([gene for gene in true_light_simple if not pd.isna(gene)])
        true_light_counts_family = Counter([gene for gene in true_light_family if not pd.isna(gene)])
        
        print(f"True - Full: {len(true_light_counts_full)} unique genes")
        print(f"True - Simple: {len(true_light_counts_simple)} unique genes") 
        print(f"True - Family: {len(true_light_counts_family)} unique families")
        
        # ============================
        # HEAVY-LIGHT PAIRING ANALYSIS (if heavy chain data available)
        # ============================
        
        if heavy_gene_col:
            print(f"\nHEAVY-LIGHT V GENE PAIRING ANALYSIS - {group_name}")
            print("="*50)
            
            # Generated heavy-light pairs
            gen_pairs_full = []
            gen_vgene_dist_full = defaultdict(lambda: defaultdict(int))
            
            for idx, row in df_group.iterrows():
                heavy_gene = row[heavy_gene_col]
                gen_light_gene = row['gen_light_1_first_hit_gene']
                
                if pd.isna(heavy_gene) or pd.isna(gen_light_gene):
                    continue
                
                gen_pairs_full.append((heavy_gene, gen_light_gene))
                gen_vgene_dist_full[heavy_gene][gen_light_gene] += 1
            
            # True heavy-light pairs
            true_pairs_full = []
            true_vgene_dist_full = defaultdict(lambda: defaultdict(int))
            
            for idx, row in df_group.iterrows():
                heavy_gene = row[heavy_gene_col]
                true_light_gene = row['true_light_gene_name']
                
                if pd.isna(heavy_gene) or pd.isna(true_light_gene):
                    continue
                
                true_pairs_full.append((heavy_gene, true_light_gene))
                true_vgene_dist_full[heavy_gene][true_light_gene] += 1
            
            print(f"Generated heavy-light pairs: {len(gen_pairs_full)}")
            print(f"True heavy-light pairs: {len(true_pairs_full)}")
            
            # Create simplified and family versions for pairing
            gen_pairs_simple = [(simplify_gene_name(pair[0]), simplify_gene_name(pair[1])) 
                               for pair in gen_pairs_full]
            gen_pairs_family = [(extract_gene_family(pair[0]), extract_gene_family(pair[1])) 
                               for pair in gen_pairs_full]
            
            true_pairs_simple = [(simplify_gene_name(pair[0]), simplify_gene_name(pair[1])) 
                                for pair in true_pairs_full]
            true_pairs_family = [(extract_gene_family(pair[0]), extract_gene_family(pair[1])) 
                                for pair in true_pairs_full]
            
            # Create distribution dictionaries
            gen_vgene_dist_simple = defaultdict(lambda: defaultdict(int))
            gen_vgene_dist_family = defaultdict(lambda: defaultdict(int))
            true_vgene_dist_simple = defaultdict(lambda: defaultdict(int))
            true_vgene_dist_family = defaultdict(lambda: defaultdict(int))
            
            for heavy_gene, light_gene in gen_pairs_simple:
                if not pd.isna(heavy_gene) and not pd.isna(light_gene):
                    gen_vgene_dist_simple[heavy_gene][light_gene] += 1
            
            for heavy_family, light_family in gen_pairs_family:
                if not pd.isna(heavy_family) and not pd.isna(light_family):
                    gen_vgene_dist_family[heavy_family][light_family] += 1
            
            for heavy_gene, light_gene in true_pairs_simple:
                if not pd.isna(heavy_gene) and not pd.isna(light_gene):
                    true_vgene_dist_simple[heavy_gene][light_gene] += 1
            
            for heavy_family, light_family in true_pairs_family:
                if not pd.isna(heavy_family) and not pd.isna(light_family):
                    true_vgene_dist_family[heavy_family][light_family] += 1
            
            # Count heavy chain frequencies
            gen_heavy_counts_full = Counter([pair[0] for pair in gen_pairs_full])
            gen_heavy_counts_simple = Counter([pair[0] for pair in gen_pairs_simple if not pd.isna(pair[0])])
            gen_heavy_counts_family = Counter([pair[0] for pair in gen_pairs_family if not pd.isna(pair[0])])
            
            true_heavy_counts_full = Counter([pair[0] for pair in true_pairs_full])
            true_heavy_counts_simple = Counter([pair[0] for pair in true_pairs_simple if not pd.isna(pair[0])])
            true_heavy_counts_family = Counter([pair[0] for pair in true_pairs_family if not pd.isna(pair[0])])
            
        else:
            # No heavy chain data - create empty distributions
            gen_vgene_dist_full = defaultdict(lambda: defaultdict(int))
            gen_vgene_dist_simple = defaultdict(lambda: defaultdict(int))
            gen_vgene_dist_family = defaultdict(lambda: defaultdict(int))
            true_vgene_dist_full = defaultdict(lambda: defaultdict(int))
            true_vgene_dist_simple = defaultdict(lambda: defaultdict(int))
            true_vgene_dist_family = defaultdict(lambda: defaultdict(int))
            
            gen_heavy_counts_full = Counter()
            gen_heavy_counts_simple = Counter()
            gen_heavy_counts_family = Counter()
            true_heavy_counts_full = Counter()
            true_heavy_counts_simple = Counter()
            true_heavy_counts_family = Counter()
        
        # ============================
        # LOCUS ANALYSIS
        # ============================
        
        print(f"\nLOCUS ANALYSIS - {group_name}")
        print("="*50)
        
        # Analyze true vs generated loci
        locus_pairs = []
        locus_dist = defaultdict(lambda: defaultdict(int))
        locus_accuracy = defaultdict(lambda: {'total': 0, 'correct': 0})
        
        for idx, row in df_group.iterrows():
            true_locus = row['true_light_light_locus']
            gen_locus = row['gen_light_1_light_locus']
            
            if pd.isna(true_locus) or pd.isna(gen_locus):
                continue
            
            locus_pairs.append((true_locus, gen_locus))
            locus_dist[true_locus][gen_locus] += 1
            locus_accuracy[true_locus]['total'] += 1
            if gen_locus == true_locus:
                locus_accuracy[true_locus]['correct'] += 1
        
        print(f"True-generated locus pairs: {len(locus_pairs)}")
        
        # Calculate overall locus accuracy
        total_generated = sum([stats['total'] for stats in locus_accuracy.values()])
        total_correct = sum([stats['correct'] for stats in locus_accuracy.values()])
        overall_accuracy = (total_correct / total_generated * 100) if total_generated > 0 else 0
        
        print(f"Locus accuracy: {overall_accuracy:.1f}% ({total_correct}/{total_generated})")
        
        # ============================
        # COVERAGE ANALYSIS
        # ============================
        
        def analyze_coverage(heavy_counts, light_counts, vgene_dist, level_name):
            """Measure how much of the pairing mass the top-5 x top-5 grid captures.

            Parameters:
            heavy_counts (Counter): heavy gene -> count; only its size and
                `most_common(5)` are used.
            light_counts (Counter): light gene -> count; only its size and
                `most_common(5)` are used.
            vgene_dist (mapping): heavy gene -> {light gene -> pair count}.
            level_name (str): label of the naming level. Not used by the
                computation; kept so existing call sites stay valid.

            Returns:
            tuple: (number of heavy genes, number of light genes, percentage of
                all associations that fall on a top-5 heavy x top-5 light cell).
                (0, 0, 0) is returned when either counter is empty.
            """
            total_heavy_genes = len(heavy_counts)
            total_light_genes = len(light_counts)

            if total_heavy_genes == 0 or total_light_genes == 0:
                return 0, 0, 0

            # Count total associations
            total_associations = sum(sum(light_dist.values()) for light_dist in vgene_dist.values())

            # Get top 5 genes
            top_5_heavy = [gene for gene, _ in heavy_counts.most_common(5)]
            top_5_light = [gene for gene, _ in light_counts.most_common(5)]

            # Use .get() instead of [] so that reading a missing cell neither
            # inserts a zero entry into a (nested) defaultdict nor raises
            # KeyError when a plain dict is passed in.
            top_5_associations = sum(
                vgene_dist.get(heavy, {}).get(light, 0)
                for heavy in top_5_heavy
                for light in top_5_light
            )

            coverage_pct = (top_5_associations / total_associations * 100) if total_associations > 0 else 0

            return total_heavy_genes, total_light_genes, coverage_pct
        
        # Per naming level (full / simplified / family), compute the coverage
        # tuple for generated and true pairings and print a one-line summary:
        # "<heavy genes>x<light genes>, 5x5 covers <pct>%".
        print(f"\nCoverage Analysis - {group_name}:")
        gen_full_cov = analyze_coverage(gen_heavy_counts_full, gen_light_counts_full, gen_vgene_dist_full, "Gen Full")
        gen_simple_cov = analyze_coverage(gen_heavy_counts_simple, gen_light_counts_simple, gen_vgene_dist_simple, "Gen Simple")
        gen_family_cov = analyze_coverage(gen_heavy_counts_family, gen_light_counts_family, gen_vgene_dist_family, "Gen Family")
        true_full_cov = analyze_coverage(true_heavy_counts_full, true_light_counts_full, true_vgene_dist_full, "True Full")
        true_simple_cov = analyze_coverage(true_heavy_counts_simple, true_light_counts_simple, true_vgene_dist_simple, "True Simple")
        true_family_cov = analyze_coverage(true_heavy_counts_family, true_light_counts_family, true_vgene_dist_family, "True Family")
        
        # The light-gene dimension is the number of distinct genes, i.e. the
        # length of the Counter. This line used to interpolate the Counter
        # object itself, dumping every gene and count into the summary.
        print(f"Generated - Full: {gen_full_cov[0]}x{len(gen_light_counts_full)}, 5x5 covers {gen_full_cov[2]:.1f}%")
        print(f"Generated - Simple: {gen_simple_cov[0]}x{len(gen_light_counts_simple)}, 5x5 covers {gen_simple_cov[2]:.1f}%")
        print(f"Generated - Family: {gen_family_cov[0]}x{len(gen_light_counts_family)}, 5x5 covers {gen_family_cov[2]:.1f}%")
        print(f"True - Full: {true_full_cov[0]}x{len(true_light_counts_full)}, 5x5 covers {true_full_cov[2]:.1f}%")
        print(f"True - Simple: {true_simple_cov[0]}x{len(true_light_counts_simple)}, 5x5 covers {true_simple_cov[2]:.1f}%")
        print(f"True - Family: {true_family_cov[0]}x{len(true_light_counts_family)}, 5x5 covers {true_family_cov[2]:.1f}%")
        
        # Store results
        results = {
            'group_name': group_name,
            'n_rows': len(df_group),
            'heavy_gene_col': heavy_gene_col,
            'gen_light_counts_full': gen_light_counts_full,
            'gen_light_counts_simple': gen_light_counts_simple,
            'gen_light_counts_family': gen_light_counts_family,
            'true_light_counts_full': true_light_counts_full,
            'true_light_counts_simple': true_light_counts_simple,
            'true_light_counts_family': true_light_counts_family,
            'gen_heavy_counts_full': gen_heavy_counts_full,
            'gen_heavy_counts_simple': gen_heavy_counts_simple,
            'gen_heavy_counts_family': gen_heavy_counts_family,
            'true_heavy_counts_full': true_heavy_counts_full,
            'true_heavy_counts_simple': true_heavy_counts_simple,
            'true_heavy_counts_family': true_heavy_counts_family,
            'gen_vgene_dist_full': dict(gen_vgene_dist_full),
            'gen_vgene_dist_simple': dict(gen_vgene_dist_simple),
            'gen_vgene_dist_family': dict(gen_vgene_dist_family),
            'true_vgene_dist_full': dict(true_vgene_dist_full),
            'true_vgene_dist_simple': dict(true_vgene_dist_simple),
            'true_vgene_dist_family': dict(true_vgene_dist_family),
            'locus_dist': dict(locus_dist),
            'locus_accuracy': dict(locus_accuracy),
            'overall_locus_accuracy': overall_accuracy,
            'coverage': {
                'gen_full': gen_full_cov,
                'gen_simple': gen_simple_cov,
                'gen_family': gen_family_cov,
                'true_full': true_full_cov,
                'true_simple': true_simple_cov,
                'true_family': true_family_cov
            }
        }
        
        return results
    
    # Analyze both groups
    low_results = analyze_group(df_low, "Low Pairing Score (<0.5)")
    high_results = analyze_group(df_high, "High Pairing Score (≥0.5)")
    
    # Skip if either group has no data
    if low_results is None or high_results is None:
        print("\n⚠️ One or both groups have no data. Skipping heatmap creation.")
        return {
            'overall_stats': {
                'total_rows': len(df),
                'rows_with_scores': len(df_clean),
                'low_score_rows': len(df_low),
                'high_score_rows': len(df_high)
            },
            'low_score_group': low_results,
            'high_score_group': high_results
        }
    
    # ============================
    # HEATMAPS
    # ============================
    
    def create_complete_contingency(heavy_counts, light_counts, vgene_dist):
        """Build the full heavy x light count matrix (no top-N truncation).

        Parameters:
        heavy_counts (mapping): heavy gene -> count; its key order defines the rows.
        light_counts (mapping): light gene -> count; its key order defines the columns.
        vgene_dist (mapping): heavy gene -> {light gene -> pair count}.

        Returns:
        tuple: (matrix, heavy genes, light genes) where matrix is a float
            numpy array of shape (len(heavy), len(light)). When either gene
            list is empty the result is (None, [], []).
        """
        all_heavy = list(heavy_counts.keys())
        all_light = list(light_counts.keys())
        
        if len(all_heavy) == 0 or len(all_light) == 0:
            return None, [], []
        
        contingency = np.zeros((len(all_heavy), len(all_light)))
        
        for i, heavy_gene in enumerate(all_heavy):
            # .get() keeps the lookup read-only: a missing heavy gene in a plain
            # dict yields an all-zero row instead of KeyError, and a defaultdict
            # is not grown with zero entries as a side effect.
            light_dist = vgene_dist.get(heavy_gene, {})
            for j, light_gene in enumerate(all_light):
                contingency[i, j] = light_dist.get(light_gene, 0)
        
        return contingency, all_heavy, all_light
    
    def create_heatmap(heavy_counts, light_counts, vgene_dist, title, cmap='YlOrRd', complete=False):
        """Plot a heavy x light V gene association heatmap and return its matrix.

        Parameters:
        heavy_counts (Counter): heavy gene -> count.
        light_counts (Counter): light gene -> count.
        vgene_dist (mapping): heavy gene -> {light gene -> pair count}.
        title (str): figure title, also used in the "no data" messages.
        cmap (str): matplotlib colormap name.
        complete (bool): when True every gene is plotted (matrix built by
            create_complete_contingency); otherwise only the 5 most common
            heavy and light genes are shown.

        Returns:
        numpy.ndarray or None: the plotted count matrix, or None when there
            is no data or, in top-5 mode, fewer than 2 genes on either axis.
        """
        if len(heavy_counts) == 0 or len(light_counts) == 0:
            print(f"No data for {title}")
            return None
        
        if complete:
            contingency, heavy_genes, light_genes = create_complete_contingency(
                heavy_counts, light_counts, vgene_dist)
            if contingency is None:
                print(f"No data for {title}")
                return None
        else:
            heavy_genes = [gene for gene, _ in heavy_counts.most_common(5)]
            light_genes = [gene for gene, _ in light_counts.most_common(5)]
            
            if len(heavy_genes) < 2 or len(light_genes) < 2:
                print(f"Insufficient data for {title}")
                return None
            
            contingency = np.zeros((len(heavy_genes), len(light_genes)))
            for i, heavy_gene in enumerate(heavy_genes):
                # Read-only lookup: no KeyError on a plain dict and no zero
                # entries inserted into a defaultdict while plotting.
                light_dist = vgene_dist.get(heavy_gene, {})
                for j, light_gene in enumerate(light_genes):
                    contingency[i, j] = light_dist.get(light_gene, 0)
        
        # Dynamic figure sizing: grow with the number of genes, with a floor of 8x6
        fig_width = max(8, len(light_genes) * 0.6)
        fig_height = max(6, len(heavy_genes) * 0.5)
        
        fig, ax = plt.subplots(1, 1, figsize=(fig_width, fig_height))
        
        im = ax.imshow(contingency, cmap=cmap, aspect='auto')
        ax.set_xticks(range(len(light_genes)))
        ax.set_yticks(range(len(heavy_genes)))
        ax.set_xticklabels(light_genes, rotation=45, ha='right', fontsize=16)
        ax.set_yticklabels(heavy_genes, fontsize=16)
        ax.set_title(title)
        ax.set_xlabel('Light V Genes')
        ax.set_ylabel('Heavy V Genes')
        
        # Add annotations for smaller heatmaps
        if len(heavy_genes) <= 15 and len(light_genes) <= 15:
            # Cells above half of the maximum get white text for contrast; the
            # threshold is loop-invariant, so compute it once.
            contrast_threshold = np.max(contingency) * 0.5
            for i in range(len(heavy_genes)):
                for j in range(len(light_genes)):
                    text_color = 'white' if contingency[i, j] > contrast_threshold else 'black'
                    ax.text(j, i, int(contingency[i, j]), ha='center', va='center', 
                           fontsize=28, color=text_color)
        
        plt.colorbar(im, ax=ax)
        plt.tight_layout()
        plt.show()
        
        return contingency
    
    def create_heatmaps_for_group(results, data_type="Generated"):
        """Draw every heavy/light V gene heatmap for one score group.

        Parameters:
        results (dict): per-group result dictionary; the 'group_name',
            'coverage' and '<prefix>{heavy,light}_counts_<level>' /
            '<prefix>vgene_dist_<level>' entries are read.
        data_type (str): "Generated" selects the 'gen_' entries and the
            'YlOrRd' colormap; any other value selects 'true_' and 'Reds'.

        The heavy chain column name (heavy_gene_col) is taken from the
        enclosing scope; without it, or without heavy counts, only a
        "skipping" message is printed.
        """
        group_name = results['group_name']
        is_generated = data_type == "Generated"
        prefix = 'gen_' if is_generated else 'true_'
        cmap = 'YlOrRd' if is_generated else 'Reds'

        print(f"\n--- {data_type.upper()} V GENE HEATMAPS - {group_name.upper()} ---")

        # Fetch every table up front, so a missing key fails before any plotting.
        levels = ('full', 'simple', 'family')
        counts = {
            level: (results[f'{prefix}heavy_counts_{level}'],
                    results[f'{prefix}light_counts_{level}'])
            for level in levels
        }
        dists = {level: results[f'{prefix}vgene_dist_{level}'] for level in levels}
        coverage = results['coverage']

        # Heavy-light heatmaps need heavy chain data; bail out otherwise.
        if not heavy_gene_col or len(counts['full'][0]) == 0:
            print(f"Skipping heavy-light heatmaps for {data_type} - {group_name} (no heavy chain data)")
            return

        def draw(level, title, **options):
            heavy_counts, light_counts = counts[level]
            create_heatmap(heavy_counts, light_counts, dists[level], title, cmap, **options)

        # Top 5x5 heatmaps, one per naming level.
        top5_labels = (('full', 'Full Names'), ('simple', 'Simplified'), ('family', 'Families'))
        for level, label in top5_labels:
            draw(level, f"{data_type} V Gene Associations - {group_name} ({label}, Top 5x5)")

        # Complete heatmaps, only when both dimensions stay within the size
        # limit that applies to the level.
        complete_specs = (('family', 'Family', 15), ('simple', 'Simplified', 20), ('full', 'Full Name', 25))
        for level, label, limit in complete_specs:
            dims = coverage[f'{prefix[:-1]}_{level}']
            if dims[0] <= limit and dims[1] <= limit:
                draw(level,
                     f"Complete {data_type} {label} Associations - {group_name} ({dims[0]}x{dims[1]})",
                     complete=True)
    
    print("\n" + "="*80 + "\n")
    print("CREATING HEATMAPS...")
    
    # Create heatmaps for both groups
    for results in [low_results, high_results]:
        create_heatmaps_for_group(results, "Generated")
        create_heatmaps_for_group(results, "True")
    
    # Create locus heatmaps for both groups
    print("\n--- LOCUS ASSOCIATION HEATMAPS ---")
    
    for results in [low_results, high_results]:
        group_name = results['group_name']
        locus_dist = results['locus_dist']
        
        if len(locus_dist) == 0:
            print(f"No locus data for {group_name}")
            continue
        
        true_loci = list(locus_dist.keys())
        gen_loci = list(set([gen for true_dist in locus_dist.values() for gen in true_dist.keys()]))
        
        if len(true_loci) >= 2 and len(gen_loci) >= 2:
            locus_contingency = np.zeros((len(true_loci), len(gen_loci)))
            
            for i, true_locus in enumerate(true_loci):
                for j, gen_locus in enumerate(gen_loci):
                    locus_contingency[i, j] = locus_dist[true_locus][gen_locus]
            
            fig, ax = plt.subplots(1, 1, figsize=(8, 6))
            im = ax.imshow(locus_contingency, cmap='Blues', aspect='auto')
            ax.set_xticks(range(len(gen_loci)))
            ax.set_yticks(range(len(true_loci)))
            ax.set_xticklabels(gen_loci, rotation=45, ha='right')
            ax.set_yticklabels(true_loci)
            ax.set_title(f'Locus Associations - {group_name}')
            ax.set_xlabel('Generated Light Chain Loci')
            ax.set_ylabel('True Light Chain Loci')
            
            # Add annotations
            for i in range(len(true_loci)):
                for j in range(len(gen_loci)):
                    text_color = 'white' if locus_contingency[i, j] > np.max(locus_contingency) * 0.5 else 'black'
                    ax.text(j, i, int(locus_contingency[i, j]), ha='center', va='center', color=text_color)
            
            plt.colorbar(im, ax=ax)
            plt.tight_layout()
            plt.show()
        else:
            print(f"Insufficient locus data for {group_name}")
    
    # ============================
    # COMPARISON ANALYSIS
    # ============================
    
    print("\n" + "="*80 + "\n")
    print("COMPARISON ANALYSIS BETWEEN GROUPS")
    print("="*50)
    
    print(f"\nLocus Accuracy Comparison:")
    print(f"Low pairing scores: {low_results['overall_locus_accuracy']:.1f}%")
    print(f"High pairing scores: {high_results['overall_locus_accuracy']:.1f}%")
    print(f"Difference: {high_results['overall_locus_accuracy'] - low_results['overall_locus_accuracy']:.1f} percentage points")
    
    print(f"\nV Gene Diversity Comparison:")
    for level in ['full', 'simple', 'family']:
        low_gen = len(low_results[f'gen_light_counts_{level}'])
        high_gen = len(high_results[f'gen_light_counts_{level}'])
        low_true = len(low_results[f'true_light_counts_{level}'])
        high_true = len(high_results[f'true_light_counts_{level}'])
        
        print(f"{level.capitalize()} - Generated light genes: Low={low_gen}, High={high_gen}")
        print(f"{level.capitalize()} - True light genes: Low={low_true}, High={high_true}")
    
    # Top genes comparison
    print(f"\nTop 5 Generated Light V Genes Comparison:")
    print(f"{'Low Score Group':<30} {'High Score Group':<30}")
    print("-" * 60)
    
    low_top5_gen = low_results['gen_light_counts_full'].most_common(5)
    high_top5_gen = high_results['gen_light_counts_full'].most_common(5)
    
    max_rows = max(len(low_top5_gen), len(high_top5_gen))
    for i in range(max_rows):
        low_entry = f"{low_top5_gen[i][0]} ({low_top5_gen[i][1]})" if i < len(low_top5_gen) else ""
        high_entry = f"{high_top5_gen[i][0]} ({high_top5_gen[i][1]})" if i < len(high_top5_gen) else ""
        print(f"{low_entry:<30} {high_entry:<30}")
    
    print(f"\nTop 5 True Light V Genes Comparison:")
    print(f"{'Low Score Group':<30} {'High Score Group':<30}")
    print("-" * 60)
    
    low_top5_true = low_results['true_light_counts_full'].most_common(5)
    high_top5_true = high_results['true_light_counts_full'].most_common(5)
    
    max_rows = max(len(low_top5_true), len(high_top5_true))
    for i in range(max_rows):
        low_entry = f"{low_top5_true[i][0]} ({low_top5_true[i][1]})" if i < len(low_top5_true) else ""
        high_entry = f"{high_top5_true[i][0]} ({high_top5_true[i][1]})" if i < len(high_top5_true) else ""
        print(f"{low_entry:<30} {high_entry:<30}")
    
    # ============================
    # RETURN RESULTS
    # ============================
    
    return {
        'overall_stats': {
            'total_rows': len(df),
            'rows_with_scores': len(df_clean),
            'low_score_rows': len(df_low),
            'high_score_rows': len(df_high),
            'score_stats': {
                'overall_mean': df_clean['pairing_scores'].mean(),
                'overall_median': df_clean['pairing_scores'].median(),
                'low_mean': df_low['pairing_scores'].mean() if len(df_low) > 0 else None,
                'high_mean': df_high['pairing_scores'].mean() if len(df_high) > 0 else None
            }
        },
        'low_score_group': low_results,
        'high_score_group': high_results,
        'heavy_gene_col': heavy_gene_col
    }


In [None]:

# Input table for the pairing-score analysis. The literal is split across
# lines for readability only; the pieces concatenate to a single path.
csv_file = (
    '/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/'
    'multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/'
    'pairing_result_matching_seqs_multiple_light_seqs_203276_cls_predictions'
    '_parsed_reformatted_rel_cols_merged_complete.csv'
)

# Run the analysis; failures are reported as messages instead of tracebacks.
try:
    results = analyze_vgene_by_pairing_score(csv_file)
    if results:
        for message in (
            "\nAnalysis completed successfully!",
            "Results contain separate analyses for low (<0.5) and high (≥0.5) pairing score groups.",
            "Generated vs True light chain V gene associations are compared for each group.",
        ):
            print(message)

except FileNotFoundError:
    print(f"Error: Could not find the file '{csv_file}'")
    print("Please make sure the file path is correct.")
except Exception as error:
    print(f"An error occurred: {error!s}")



In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
from scipy import stats

def simplify_gene_name(gene_name):
    """Drop the allele suffix: return the text before the first '*'.

    Missing values (as judged by pd.isna) are returned unchanged; any other
    value is converted to str first.
    """
    if pd.isna(gene_name):
        return gene_name
    base_name, _, _allele = str(gene_name).partition('*')
    return base_name

def extract_gene_family(gene_name):
    """Return the gene family, e.g. 'IGHV3' for 'IGHV3-23*01'.

    The allele part (from the first '*') is removed, then everything from
    the first '-' onwards. Missing values (pd.isna) are returned unchanged.
    """
    if pd.isna(gene_name):
        return gene_name
    without_allele = str(gene_name).partition('*')[0]
    # partition() hands back the whole string when there is no dash,
    # so names without a gene number pass through untouched.
    return without_allele.partition('-')[0]

def analyze_vgene_by_pairing_score(csv_file):
    """
    Analyze V gene distributions grouped by pairing scores (<0.5 vs ≥0.5)
    and compare generated vs true light chain V genes.
    
    Parameters:
    csv_file (str): Path to the CSV file
    """
    
    # Read the CSV file
    df = pd.read_csv(csv_file)
    
    print("Dataset Overview:")
    print(f"Total rows: {len(df)}")
    print(f"Columns: {df.columns.tolist()}")
    
    # Check for required columns
    required_cols = ['pairing_scores', 'true_light_gene_name', 'gen_light_1_first_hit_gene']
    
    # Check for heavy chain V gene column
    heavy_gene_col = None
    possible_heavy_cols = ['heavy_gene_name', 'true_v_gene_simple', 'heavy_v_gene', 'true_heavy_gene']
    
    for col in possible_heavy_cols:
        if col in df.columns:
            heavy_gene_col = col
            break
    
    if heavy_gene_col:
        required_cols.append(heavy_gene_col)
        print(f"✅ Using heavy chain V gene column: {heavy_gene_col}")
    else:
        print(f"⚠️ No heavy chain V gene column found. Searched for: {possible_heavy_cols}")
        print(f"Available columns: {df.columns.tolist()}")
        print(f"Will skip heavy-light V gene pairing analysis and focus on light chain only.")
    
    missing_cols = [col for col in required_cols if col not in df.columns]
    
    if missing_cols:
        print(f"❌ Missing required columns: {missing_cols}")
        return None
    
    print(f"✅ Required columns found")
    
    # ============================
    # DATA PREPARATION
    # ============================
    
    print("\n" + "="*80 + "\n")
    print("DATA PREPARATION")
    print("="*50)
    
    # Remove rows with missing pairing scores
    df_clean = df.dropna(subset=['pairing_scores']).copy()
    print(f"Rows with pairing scores: {len(df_clean)}/{len(df)} ({len(df_clean)/len(df)*100:.1f}%)")
    
    # Group by pairing score threshold
    df_low = df_clean[df_clean['pairing_scores'] < 0.5].copy()
    df_high = df_clean[df_clean['pairing_scores'] >= 0.5].copy()
    
    print(f"Low pairing scores (<0.5): {len(df_low)} rows ({len(df_low)/len(df_clean)*100:.1f}%)")
    print(f"High pairing scores (≥0.5): {len(df_high)} rows ({len(df_high)/len(df_clean)*100:.1f}%)")
    
    # Pairing score statistics
    print(f"\nPairing Score Statistics:")
    print(f"Overall - Mean: {df_clean['pairing_scores'].mean():.3f}, Median: {df_clean['pairing_scores'].median():.3f}")
    if len(df_low) > 0:
        print(f"Low group - Mean: {df_low['pairing_scores'].mean():.3f}, Range: {df_low['pairing_scores'].min():.3f}-{df_low['pairing_scores'].max():.3f}")
    if len(df_high) > 0:
        print(f"High group - Mean: {df_high['pairing_scores'].mean():.3f}, Range: {df_high['pairing_scores'].min():.3f}-{df_high['pairing_scores'].max():.3f}")
    
    def analyze_group(df_group, group_name):
        """Analyze V gene distributions for a specific pairing score group"""
        
        if len(df_group) == 0:
            print(f"\n⚠️ No data for {group_name} group")
            return None
        
        print(f"\n{'='*60}")
        print(f"{group_name.upper()} PAIRING SCORE GROUP ANALYSIS")
        print(f"{'='*60}")
        
        results = {}
        
        # ============================
        # GENERATED V GENE ANALYSIS
        # ============================
        
        print(f"\nGENERATED LIGHT CHAIN V GENE ANALYSIS - {group_name}")
        print("="*50)
        
        # Collect generated light chain V genes
        gen_light_genes = []
        for idx, row in df_group.iterrows():
            gen_light_gene = row['gen_light_1_first_hit_gene']
            if not pd.isna(gen_light_gene):
                gen_light_genes.append(gen_light_gene)
        
        print(f"Generated light V genes: {len(gen_light_genes)}")
        
        # Create simplified and family versions for generated
        gen_light_simple = [simplify_gene_name(gene) for gene in gen_light_genes if not pd.isna(gene)]
        gen_light_family = [extract_gene_family(gene) for gene in gen_light_genes if not pd.isna(gene)]
        
        # Count frequencies for generated
        gen_light_counts_full = Counter(gen_light_genes)
        gen_light_counts_simple = Counter([gene for gene in gen_light_simple if not pd.isna(gene)])
        gen_light_counts_family = Counter([gene for gene in gen_light_family if not pd.isna(gene)])
        
        print(f"Generated - Full: {len(gen_light_counts_full)} unique genes")
        print(f"Generated - Simple: {len(gen_light_counts_simple)} unique genes")
        print(f"Generated - Family: {len(gen_light_counts_family)} unique families")
        
        # ============================
        # TRUE V GENE ANALYSIS
        # ============================
        
        print(f"\nTRUE LIGHT CHAIN V GENE ANALYSIS - {group_name}")
        print("="*50)
        
        # Collect true light chain V genes
        true_light_genes = []
        for idx, row in df_group.iterrows():
            true_light_gene = row['true_light_gene_name']
            if not pd.isna(true_light_gene):
                true_light_genes.append(true_light_gene)
        
        print(f"True light V genes: {len(true_light_genes)}")
        
        # Create simplified and family versions for true
        true_light_simple = [simplify_gene_name(gene) for gene in true_light_genes if not pd.isna(gene)]
        true_light_family = [extract_gene_family(gene) for gene in true_light_genes if not pd.isna(gene)]
        
        # Count frequencies for true
        true_light_counts_full = Counter(true_light_genes)
        true_light_counts_simple = Counter([gene for gene in true_light_simple if not pd.isna(gene)])
        true_light_counts_family = Counter([gene for gene in true_light_family if not pd.isna(gene)])
        
        print(f"True - Full: {len(true_light_counts_full)} unique genes")
        print(f"True - Simple: {len(true_light_counts_simple)} unique genes") 
        print(f"True - Family: {len(true_light_counts_family)} unique families")
        
        # ============================
        # HEAVY-LIGHT PAIRING ANALYSIS (if heavy chain data available)
        # ============================
        
        if heavy_gene_col:
            print(f"\nHEAVY-LIGHT V GENE PAIRING ANALYSIS - {group_name}")
            print("="*50)
            
            # Generated heavy-light pairs
            gen_pairs_full = []
            gen_vgene_dist_full = defaultdict(lambda: defaultdict(int))
            
            for idx, row in df_group.iterrows():
                heavy_gene = row[heavy_gene_col]
                gen_light_gene = row['gen_light_1_first_hit_gene']
                
                if pd.isna(heavy_gene) or pd.isna(gen_light_gene):
                    continue
                
                gen_pairs_full.append((heavy_gene, gen_light_gene))
                gen_vgene_dist_full[heavy_gene][gen_light_gene] += 1
            
            # True heavy-light pairs
            true_pairs_full = []
            true_vgene_dist_full = defaultdict(lambda: defaultdict(int))
            
            for idx, row in df_group.iterrows():
                heavy_gene = row[heavy_gene_col]
                true_light_gene = row['true_light_gene_name']
                
                if pd.isna(heavy_gene) or pd.isna(true_light_gene):
                    continue
                
                true_pairs_full.append((heavy_gene, true_light_gene))
                true_vgene_dist_full[heavy_gene][true_light_gene] += 1
            
            print(f"Generated heavy-light pairs: {len(gen_pairs_full)}")
            print(f"True heavy-light pairs: {len(true_pairs_full)}")
            
            # Create simplified and family versions for pairing
            gen_pairs_simple = [(simplify_gene_name(pair[0]), simplify_gene_name(pair[1])) 
                               for pair in gen_pairs_full]
            gen_pairs_family = [(extract_gene_family(pair[0]), extract_gene_family(pair[1])) 
                               for pair in gen_pairs_full]
            
            true_pairs_simple = [(simplify_gene_name(pair[0]), simplify_gene_name(pair[1])) 
                                for pair in true_pairs_full]
            true_pairs_family = [(extract_gene_family(pair[0]), extract_gene_family(pair[1])) 
                                for pair in true_pairs_full]
            
            # Create distribution dictionaries
            gen_vgene_dist_simple = defaultdict(lambda: defaultdict(int))
            gen_vgene_dist_family = defaultdict(lambda: defaultdict(int))
            true_vgene_dist_simple = defaultdict(lambda: defaultdict(int))
            true_vgene_dist_family = defaultdict(lambda: defaultdict(int))
            
            for heavy_gene, light_gene in gen_pairs_simple:
                if not pd.isna(heavy_gene) and not pd.isna(light_gene):
                    gen_vgene_dist_simple[heavy_gene][light_gene] += 1
            
            for heavy_family, light_family in gen_pairs_family:
                if not pd.isna(heavy_family) and not pd.isna(light_family):
                    gen_vgene_dist_family[heavy_family][light_family] += 1
            
            for heavy_gene, light_gene in true_pairs_simple:
                if not pd.isna(heavy_gene) and not pd.isna(light_gene):
                    true_vgene_dist_simple[heavy_gene][light_gene] += 1
            
            for heavy_family, light_family in true_pairs_family:
                if not pd.isna(heavy_family) and not pd.isna(light_family):
                    true_vgene_dist_family[heavy_family][light_family] += 1
            
            # Count heavy chain frequencies
            gen_heavy_counts_full = Counter([pair[0] for pair in gen_pairs_full])
            gen_heavy_counts_simple = Counter([pair[0] for pair in gen_pairs_simple if not pd.isna(pair[0])])
            gen_heavy_counts_family = Counter([pair[0] for pair in gen_pairs_family if not pd.isna(pair[0])])
            
            true_heavy_counts_full = Counter([pair[0] for pair in true_pairs_full])
            true_heavy_counts_simple = Counter([pair[0] for pair in true_pairs_simple if not pd.isna(pair[0])])
            true_heavy_counts_family = Counter([pair[0] for pair in true_pairs_family if not pd.isna(pair[0])])
            
        else:
            # No heavy chain data - create empty distributions
            gen_vgene_dist_full = defaultdict(lambda: defaultdict(int))
            gen_vgene_dist_simple = defaultdict(lambda: defaultdict(int))
            gen_vgene_dist_family = defaultdict(lambda: defaultdict(int))
            true_vgene_dist_full = defaultdict(lambda: defaultdict(int))
            true_vgene_dist_simple = defaultdict(lambda: defaultdict(int))
            true_vgene_dist_family = defaultdict(lambda: defaultdict(int))
            
            gen_heavy_counts_full = Counter()
            gen_heavy_counts_simple = Counter()
            gen_heavy_counts_family = Counter()
            true_heavy_counts_full = Counter()
            true_heavy_counts_simple = Counter()
            true_heavy_counts_family = Counter()
        
        # ============================
        # LOCUS ANALYSIS
        # ============================
        
        print(f"\nLOCUS ANALYSIS - {group_name}")
        print("="*50)
        
        # Analyze true vs generated loci
        locus_pairs = []
        locus_dist = defaultdict(lambda: defaultdict(int))
        locus_accuracy = defaultdict(lambda: {'total': 0, 'correct': 0})
        
        for idx, row in df_group.iterrows():
            true_locus = row['true_light_light_locus']
            gen_locus = row['gen_light_1_light_locus']
            
            if pd.isna(true_locus) or pd.isna(gen_locus):
                continue
            
            locus_pairs.append((true_locus, gen_locus))
            locus_dist[true_locus][gen_locus] += 1
            locus_accuracy[true_locus]['total'] += 1
            if gen_locus == true_locus:
                locus_accuracy[true_locus]['correct'] += 1
        
        print(f"True-generated locus pairs: {len(locus_pairs)}")
        
        # Calculate overall locus accuracy
        total_generated = sum([stats['total'] for stats in locus_accuracy.values()])
        total_correct = sum([stats['correct'] for stats in locus_accuracy.values()])
        overall_accuracy = (total_correct / total_generated * 100) if total_generated > 0 else 0
        
        print(f"Locus accuracy: {overall_accuracy:.1f}% ({total_correct}/{total_generated})")
        
        # ============================
        # COVERAGE ANALYSIS
        # ============================
        
        def analyze_coverage(heavy_counts, light_counts, vgene_dist, level_name):
            """
            Measure how much of the heavy-light pairing mass the top-5 x top-5 grid covers.

            Parameters:
            heavy_counts (Counter): heavy gene -> occurrence count
            light_counts (Counter): light gene -> occurrence count
            vgene_dist (mapping): heavy gene -> {light gene -> pair count}
            level_name (str): label of the naming level; not used in the
                computation, kept so existing call sites stay valid

            Returns:
            tuple: (number of heavy genes, number of light genes, percentage of
            all pair counts that fall on the 5 most common heavy x 5 most common
            light genes). (0, 0, 0) when either counter is empty.
            """
            total_heavy_genes = len(heavy_counts)
            total_light_genes = len(light_counts)

            if total_heavy_genes == 0 or total_light_genes == 0:
                return 0, 0, 0

            # Count total associations
            total_associations = sum(sum(light_dist.values()) for light_dist in vgene_dist.values())

            # Get top 5 genes
            top_5_heavy = [gene for gene, _ in heavy_counts.most_common(5)]
            top_5_light = [gene for gene, _ in light_counts.most_common(5)]

            # Read with .get(): indexing the nested defaultdicts would insert a
            # zero entry for every probed pair and silently alter vgene_dist,
            # which is stored in the results and analysed again later.
            top_5_associations = sum(
                vgene_dist.get(heavy, {}).get(light, 0)
                for heavy in top_5_heavy
                for light in top_5_light
            )

            coverage_pct = (top_5_associations / total_associations * 100) if total_associations > 0 else 0

            return total_heavy_genes, total_light_genes, coverage_pct
        
        print(f"\nCoverage Analysis - {group_name}:")
        gen_full_cov = analyze_coverage(gen_heavy_counts_full, gen_light_counts_full, gen_vgene_dist_full, "Gen Full")
        gen_simple_cov = analyze_coverage(gen_heavy_counts_simple, gen_light_counts_simple, gen_vgene_dist_simple, "Gen Simple")
        gen_family_cov = analyze_coverage(gen_heavy_counts_family, gen_light_counts_family, gen_vgene_dist_family, "Gen Family")
        true_full_cov = analyze_coverage(true_heavy_counts_full, true_light_counts_full, true_vgene_dist_full, "True Full")
        true_simple_cov = analyze_coverage(true_heavy_counts_simple, true_light_counts_simple, true_vgene_dist_simple, "True Simple")
        true_family_cov = analyze_coverage(true_heavy_counts_family, true_light_counts_family, true_vgene_dist_family, "True Family")
        
        # Report "<heavy genes>x<light genes>" for each naming level. The light-side
        # figure is the number of distinct light genes, i.e. len() of the Counter
        # (interpolating the Counter itself would dump its whole contents).
        print(f"Generated - Full: {gen_full_cov[0]}x{len(gen_light_counts_full)}, 5x5 covers {gen_full_cov[2]:.1f}%")
        print(f"Generated - Simple: {gen_simple_cov[0]}x{len(gen_light_counts_simple)}, 5x5 covers {gen_simple_cov[2]:.1f}%")
        print(f"Generated - Family: {gen_family_cov[0]}x{len(gen_light_counts_family)}, 5x5 covers {gen_family_cov[2]:.1f}%")
        print(f"True - Full: {true_full_cov[0]}x{len(true_light_counts_full)}, 5x5 covers {true_full_cov[2]:.1f}%")
        print(f"True - Simple: {true_simple_cov[0]}x{len(true_light_counts_simple)}, 5x5 covers {true_simple_cov[2]:.1f}%")
        print(f"True - Family: {true_family_cov[0]}x{len(true_light_counts_family)}, 5x5 covers {true_family_cov[2]:.1f}%")
        
        # Store results
        results = {
            'group_name': group_name,
            'n_rows': len(df_group),
            'heavy_gene_col': heavy_gene_col,
            'gen_light_counts_full': gen_light_counts_full,
            'gen_light_counts_simple': gen_light_counts_simple,
            'gen_light_counts_family': gen_light_counts_family,
            'true_light_counts_full': true_light_counts_full,
            'true_light_counts_simple': true_light_counts_simple,
            'true_light_counts_family': true_light_counts_family,
            'gen_heavy_counts_full': gen_heavy_counts_full,
            'gen_heavy_counts_simple': gen_heavy_counts_simple,
            'gen_heavy_counts_family': gen_heavy_counts_family,
            'true_heavy_counts_full': true_heavy_counts_full,
            'true_heavy_counts_simple': true_heavy_counts_simple,
            'true_heavy_counts_family': true_heavy_counts_family,
            'gen_vgene_dist_full': dict(gen_vgene_dist_full),
            'gen_vgene_dist_simple': dict(gen_vgene_dist_simple),
            'gen_vgene_dist_family': dict(gen_vgene_dist_family),
            'true_vgene_dist_full': dict(true_vgene_dist_full),
            'true_vgene_dist_simple': dict(true_vgene_dist_simple),
            'true_vgene_dist_family': dict(true_vgene_dist_family),
            'locus_dist': dict(locus_dist),
            'locus_accuracy': dict(locus_accuracy),
            'overall_locus_accuracy': overall_accuracy,
            'coverage': {
                'gen_full': gen_full_cov,
                'gen_simple': gen_simple_cov,
                'gen_family': gen_family_cov,
                'true_full': true_full_cov,
                'true_simple': true_simple_cov,
                'true_family': true_family_cov
            }
        }
        
        return results
    
    # Analyze both groups
    low_results = analyze_group(df_low, "Low Pairing Score (<0.5)")
    high_results = analyze_group(df_high, "High Pairing Score (≥0.5)")
    
    # Skip if either group has no data
    if low_results is None or high_results is None:
        print("\n⚠️ One or both groups have no data. Skipping heatmap creation.")
        return {
            'overall_stats': {
                'total_rows': len(df),
                'rows_with_scores': len(df_clean),
                'low_score_rows': len(df_low),
                'high_score_rows': len(df_high)
            },
            'low_score_group': low_results,
            'high_score_group': high_results
        }
    
    # ============================
    # HEATMAPS
    # ============================
    
    def create_complete_contingency(heavy_counts, light_counts, vgene_dist):
        """
        Build the full heavy x light pair-count matrix.

        Parameters:
        heavy_counts (Counter): heavy gene -> occurrence count; its keys become the rows
        light_counts (Counter): light gene -> occurrence count; its keys become the columns
        vgene_dist (mapping): heavy gene -> {light gene -> pair count}

        Returns:
        tuple: (matrix, heavy gene labels, light gene labels), or
        (None, [], []) when either counter is empty.
        """
        all_heavy = list(heavy_counts)
        all_light = list(light_counts)

        if not all_heavy or not all_light:
            return None, [], []

        contingency = np.zeros((len(all_heavy), len(all_light)))

        for i, heavy_gene in enumerate(all_heavy):
            # .get() keeps the lookup read-only: indexing a nested defaultdict
            # would add a zero entry for every unseen pair, and a plain outer
            # dict would raise KeyError for a heavy gene without pairings.
            light_dist = vgene_dist.get(heavy_gene, {})
            for j, light_gene in enumerate(all_light):
                contingency[i, j] = light_dist.get(light_gene, 0)

        return contingency, all_heavy, all_light
    
    def create_heatmap(heavy_counts, light_counts, vgene_dist, title, cmap='YlOrRd', complete=False):
        """
        Plot a heavy x light V gene association heatmap and show it.

        Parameters:
        heavy_counts (Counter): heavy gene -> occurrence count; determines the rows
        light_counts (Counter): light gene -> occurrence count; determines the columns
        vgene_dist (mapping): heavy gene -> {light gene -> pair count}
        title (str): figure title, also used in the "no data" messages
        cmap (str): matplotlib colormap name
        complete (bool): plot every gene when True, otherwise only the 5 most
            common heavy and light genes

        Returns:
        numpy.ndarray or None: the plotted count matrix, or None when there is
        no data (or fewer than 2 genes on an axis in top-5 mode).
        """
        if len(heavy_counts) == 0 or len(light_counts) == 0:
            print(f"No data for {title}")
            return None

        if complete:
            contingency, heavy_genes, light_genes = create_complete_contingency(
                heavy_counts, light_counts, vgene_dist)
            if contingency is None:
                print(f"No data for {title}")
                return None
        else:
            heavy_genes = [gene for gene, _ in heavy_counts.most_common(5)]
            light_genes = [gene for gene, _ in light_counts.most_common(5)]

            if len(heavy_genes) < 2 or len(light_genes) < 2:
                print(f"Insufficient data for {title}")
                return None

            contingency = np.zeros((len(heavy_genes), len(light_genes)))
            for i, heavy_gene in enumerate(heavy_genes):
                # Read-only lookup: indexing the nested defaultdicts would insert
                # zero entries for unseen pairs into the caller's distribution.
                light_dist = vgene_dist.get(heavy_gene, {})
                for j, light_gene in enumerate(light_genes):
                    contingency[i, j] = light_dist.get(light_gene, 0)

        # Dynamic figure sizing
        fig_width = max(8, len(light_genes) * 0.6)
        fig_height = max(6, len(heavy_genes) * 0.5)

        fig, ax = plt.subplots(1, 1, figsize=(fig_width, fig_height))

        im = ax.imshow(contingency, cmap=cmap, aspect='auto')
        ax.set_xticks(range(len(light_genes)))
        ax.set_yticks(range(len(heavy_genes)))
        ax.set_xticklabels(light_genes, rotation=45, ha='right', fontsize=16)
        ax.set_yticklabels(heavy_genes, fontsize=16)
        ax.set_title(title)
        ax.set_xlabel('Light V Genes')
        ax.set_ylabel('Heavy V Genes')

        # Add annotations for smaller heatmaps
        if len(heavy_genes) <= 15 and len(light_genes) <= 15:
            # Cells above half of the maximum sit on a dark background, so they
            # get white text; the threshold is the same for every cell.
            half_max = np.max(contingency) * 0.5
            for (i, j), cell in np.ndenumerate(contingency):
                text_color = 'white' if cell > half_max else 'black'
                ax.text(j, i, int(cell), ha='center', va='center',
                        fontsize=14, color=text_color)

        plt.colorbar(im, ax=ax)
        plt.tight_layout()
        plt.show()

        return contingency
    
    def create_heatmaps_for_group(results, data_type="Generated"):
        """
        Draw the heavy-light heatmaps of one group for one data source.

        Parameters:
        results (dict): per-group result dictionary (counts, distributions,
            coverage) keyed with 'gen_' / 'true_' prefixes
        data_type (str): "Generated" selects the 'gen_' entries and the
            'YlOrRd' colormap; any other value selects 'true_' and 'Reds'

        Three top-5x5 heatmaps are always attempted (full, simplified, family).
        Complete heatmaps are added per level only while both axes stay within
        15 (family), 20 (simplified) or 25 (full) genes.
        """
        is_generated = data_type == "Generated"
        prefix = 'gen_' if is_generated else 'true_'
        cmap = 'YlOrRd' if is_generated else 'Reds'
        group_name = results['group_name']

        print(f"\n--- {data_type.upper()} V GENE HEATMAPS - {group_name.upper()} ---")

        # Pull everything this data source needs out of the result dictionary
        heavy_counts_full = results[f'{prefix}heavy_counts_full']
        light_counts_full = results[f'{prefix}light_counts_full']
        heavy_counts_simple = results[f'{prefix}heavy_counts_simple']
        light_counts_simple = results[f'{prefix}light_counts_simple']
        heavy_counts_family = results[f'{prefix}heavy_counts_family']
        light_counts_family = results[f'{prefix}light_counts_family']

        vgene_dist_full = results[f'{prefix}vgene_dist_full']
        vgene_dist_simple = results[f'{prefix}vgene_dist_simple']
        vgene_dist_family = results[f'{prefix}vgene_dist_family']

        coverage = results['coverage']

        # Without heavy chain data there is nothing to pair against
        if not (heavy_gene_col and len(heavy_counts_full) > 0):
            print(f"Skipping heavy-light heatmaps for {data_type} - {group_name} (no heavy chain data)")
            return

        # Top 5x5 views
        create_heatmap(heavy_counts_full, light_counts_full, vgene_dist_full,
                       f"{data_type} V Gene Associations - {group_name} (Full Names, Top 5x5)", cmap)
        create_heatmap(heavy_counts_simple, light_counts_simple, vgene_dist_simple,
                       f"{data_type} V Gene Associations - {group_name} (Simplified, Top 5x5)", cmap)
        create_heatmap(heavy_counts_family, light_counts_family, vgene_dist_family,
                       f"{data_type} V Gene Associations - {group_name} (Families, Top 5x5)", cmap)

        # Complete views, only while the matrix stays readable
        coverage_tag = prefix[:-1]  # 'gen' or 'true', as used in the coverage keys

        n_heavy, n_light = coverage[f'{coverage_tag}_family'][:2]
        if n_heavy <= 15 and n_light <= 15:
            create_heatmap(heavy_counts_family, light_counts_family, vgene_dist_family,
                           f"Complete {data_type} Family Associations - {group_name} ({n_heavy}x{n_light})",
                           cmap, complete=True)

        n_heavy, n_light = coverage[f'{coverage_tag}_simple'][:2]
        if n_heavy <= 20 and n_light <= 20:
            create_heatmap(heavy_counts_simple, light_counts_simple, vgene_dist_simple,
                           f"Complete {data_type} Simplified Associations - {group_name} ({n_heavy}x{n_light})",
                           cmap, complete=True)

        n_heavy, n_light = coverage[f'{coverage_tag}_full'][:2]
        if n_heavy <= 25 and n_light <= 25:
            create_heatmap(heavy_counts_full, light_counts_full, vgene_dist_full,
                           f"Complete {data_type} Full Name Associations - {group_name} ({n_heavy}x{n_light})",
                           cmap, complete=True)
    
    print("\n" + "="*80 + "\n")
    print("CREATING HEATMAPS...")
    
    # Create heatmaps for both groups
    for results in [low_results, high_results]:
        create_heatmaps_for_group(results, "Generated")
        create_heatmaps_for_group(results, "True")
    
    # Create locus heatmaps for both groups
    print("\n--- LOCUS ASSOCIATION HEATMAPS ---")
    
    for results in [low_results, high_results]:
        group_name = results['group_name']
        locus_dist = results['locus_dist']
        
        if len(locus_dist) == 0:
            print(f"No locus data for {group_name}")
            continue
        
        true_loci = list(locus_dist.keys())
        gen_loci = list(set([gen for true_dist in locus_dist.values() for gen in true_dist.keys()]))
        
        if len(true_loci) >= 2 and len(gen_loci) >= 2:
            locus_contingency = np.zeros((len(true_loci), len(gen_loci)))
            
            for i, true_locus in enumerate(true_loci):
                for j, gen_locus in enumerate(gen_loci):
                    locus_contingency[i, j] = locus_dist[true_locus][gen_locus]
            
            fig, ax = plt.subplots(1, 1, figsize=(8, 6))
            im = ax.imshow(locus_contingency, cmap='Blues', aspect='auto')
            ax.set_xticks(range(len(gen_loci)))
            ax.set_yticks(range(len(true_loci)))
            ax.set_xticklabels(gen_loci, rotation=45, ha='right')
            ax.set_yticklabels(true_loci)
            ax.set_title(f'Locus Associations - {group_name}')
            ax.set_xlabel('Generated Light Chain Loci')
            ax.set_ylabel('True Light Chain Loci')
            
            # Add annotations
            for i in range(len(true_loci)):
                for j in range(len(gen_loci)):
                    text_color = 'white' if locus_contingency[i, j] > np.max(locus_contingency) * 0.5 else 'black'
                    ax.text(j, i, int(locus_contingency[i, j]), ha='center', va='center', color=text_color)
            
            plt.colorbar(im, ax=ax)
            plt.tight_layout()
            plt.show()
        else:
            print(f"Insufficient locus data for {group_name}")
    
    # ============================
    # V GENE CONSISTENCY ANALYSIS
    # ============================
    
    def analyze_vgene_consistency(results_list, analysis_name):
        """
        Run the heavy-light consistency analysis for every group at every naming level.

        Parameters:
        results_list (list): per-group result dictionaries; None entries are skipped
        analysis_name (str): label printed in the section banner

        Returns:
        dict: group name -> level ('full', 'simple', 'family') ->
        {'generated': [...], 'true': [...]}. Groups without heavy-light
        pairing data are reported and left out.
        """
        print(f"\n{'='*80}")
        print(f"V GENE CONSISTENCY ANALYSIS - {analysis_name}")
        print(f"{'='*80}")

        per_group = {}

        for group_results in results_list:
            if group_results is None:
                continue

            label = group_results['group_name']
            print(f"\n--- {label.upper()} ---")

            # Pairing data exists only when a heavy gene column was found and
            # this group actually produced heavy-light pairs
            has_pairing = bool(heavy_gene_col) and len(group_results['gen_heavy_counts_full']) != 0
            if not has_pairing:
                print(f"No heavy-light pairing data for {label}")
                continue

            level_stats = {}
            for level in ('full', 'simple', 'family'):
                print(f"\n{level.capitalize()} Level Analysis:")

                generated_dist = group_results[f'gen_vgene_dist_{level}']
                reference_dist = group_results[f'true_vgene_dist_{level}']

                # Generated first, then true, so the printed tables keep that order
                level_stats[level] = {
                    'generated': analyze_heavy_light_consistency(generated_dist, f"Generated {level}"),
                    'true': analyze_heavy_light_consistency(reference_dist, f"True {level}"),
                }

            per_group[label] = level_stats

        return per_group
    
    def analyze_heavy_light_consistency(vgene_dist, analysis_type):
        """
        Measure how concentrated each heavy V gene's light-gene pairings are.

        For every heavy gene this records the dominant light gene and its share
        of the pairings, the entropy (log base 2) and the Gini coefficient of
        the light-gene counts, and a consistency level derived from the share.
        It prints a table of the 20 most consistent heavy genes and a summary.

        Parameters:
        vgene_dist (mapping): heavy gene -> {light gene -> pair count}
        analysis_type (str): label used in the printed header

        Returns:
        list[dict]: one record per heavy gene that has at least one observed
        pairing, sorted by 'most_common_pct' in descending order.
        """
        consistency_results = []

        for heavy_gene, light_dist in vgene_dist.items():
            # Keep only light genes that were actually observed. Zero-count
            # entries (which appear when a nested defaultdict is indexed for an
            # unseen pair) would inflate the unique-gene count, distort the Gini
            # coefficient and, if nothing else is present, divide by zero.
            observed = {light: count for light, count in light_dist.items() if count > 0}
            if not observed:
                continue

            total_associations = sum(observed.values())

            # Dominant light gene and its share of this heavy gene's pairings
            most_common_light = max(observed.items(), key=lambda x: x[1])
            most_common_count = most_common_light[1]
            most_common_pct = (most_common_count / total_associations) * 100

            # Calculate entropy (lower = more consistent)
            proportions = np.array(list(observed.values())) / total_associations
            entropy = -np.sum(proportions * np.log2(proportions + 1e-10))

            # Calculate Gini coefficient (higher = more concentrated). The
            # cumulative-sum formula requires ascending order; with descending
            # order it yields the negated value.
            sorted_counts = sorted(observed.values())
            n = len(sorted_counts)
            cumsum = np.cumsum(sorted_counts)
            gini = (n + 1 - 2 * np.sum(cumsum) / cumsum[-1]) / n

            # Consistency classification
            if most_common_pct >= 90:
                consistency_level = "Very High"
            elif most_common_pct >= 75:
                consistency_level = "High"
            elif most_common_pct >= 50:
                consistency_level = "Moderate"
            else:
                consistency_level = "Low"

            consistency_results.append({
                'heavy_gene': heavy_gene,
                'total_associations': total_associations,
                'unique_light_genes': len(observed),
                'most_common_light': most_common_light[0],
                'most_common_count': most_common_count,
                'most_common_pct': most_common_pct,
                'entropy': entropy,
                'gini_coefficient': gini,
                'consistency_level': consistency_level,
                'light_distribution': observed
            })

        # Sort by consistency (highest percentage first)
        consistency_results.sort(key=lambda x: x['most_common_pct'], reverse=True)

        # Print results
        print(f"\n{analysis_type} Consistency Results:")
        print(f"{'Heavy V Gene':<25} {'Most Common Light':<25} {'Count':<8} {'%':<8} {'Unique':<8} {'Level':<12} {'Entropy':<8}")
        print("-" * 110)

        for result in consistency_results[:20]:  # Show top 20
            print(f"{result['heavy_gene']:<25} {result['most_common_light']:<25} "
                  f"{result['most_common_count']:<8} {result['most_common_pct']:<8.1f} "
                  f"{result['unique_light_genes']:<8} {result['consistency_level']:<12} {result['entropy']:<8.2f}")

        # Summary statistics
        if consistency_results:
            # A Counter returns 0 for levels that never occurred
            level_counts = Counter(r['consistency_level'] for r in consistency_results)

            print("\nConsistency Summary:")
            print(f"  Very High (≥90%): {level_counts['Very High']} heavy V genes")
            print(f"  High (75-89%): {level_counts['High']} heavy V genes")
            print(f"  Moderate (50-74%): {level_counts['Moderate']} heavy V genes")
            print(f"  Low (<50%): {level_counts['Low']} heavy V genes")

            # Find perfect matches (100% consistency)
            perfect_matches = [r for r in consistency_results if r['most_common_pct'] == 100.0]
            if perfect_matches:
                print("\n🎯 PERFECT MATCHES (100% consistency):")
                for match in perfect_matches:
                    print(f"  {match['heavy_gene']} → {match['most_common_light']} "
                          f"({match['total_associations']} associations)")

            # Find highly consistent genes (≥90% but not perfect)
            highly_consistent = [r for r in consistency_results if 90 <= r['most_common_pct'] < 100]
            if highly_consistent:
                print("\n⭐ HIGHLY CONSISTENT (≥90%):")
                for match in highly_consistent[:10]:  # Top 10
                    print(f"  {match['heavy_gene']} → {match['most_common_light']} "
                          f"({match['most_common_pct']:.1f}%, {match['total_associations']} total)")

        return consistency_results
    
    def create_consistency_heatmap(consistency_results, title, top_n=15):
        """
        Plot, for the first top_n heavy genes, the percentage split of their light-gene pairings.

        Parameters:
        consistency_results (list[dict]): records carrying 'heavy_gene',
            'total_associations' and 'light_distribution'; the first top_n
            entries are plotted in the order given
        title (str): figure title, also used in the "no data" messages
        top_n (int): maximum number of heavy genes (rows) to show

        Nothing is drawn when the list is empty or fewer than 2 rows remain.
        """
        if not consistency_results:
            print(f"No data for {title}")
            return

        selected = consistency_results[:top_n]
        if len(selected) < 2:
            print(f"Insufficient data for {title}")
            return

        # Rows: heavy genes in the given order; columns: every light gene seen
        # in any selected row, sorted
        row_labels = [entry['heavy_gene'] for entry in selected]
        col_labels = sorted({light for entry in selected for light in entry['light_distribution']})

        # Percentage of each heavy gene's pairings per light gene; rows without
        # any pairing stay at zero
        share_matrix = np.zeros((len(row_labels), len(col_labels)))
        for row, entry in enumerate(selected):
            denominator = entry['total_associations']
            if denominator > 0:
                counts = entry['light_distribution']
                share_matrix[row] = [(counts.get(light, 0) / denominator) * 100 for light in col_labels]

        width = max(12, len(col_labels) * 0.8)
        height = max(8, len(row_labels) * 0.6)
        fig, ax = plt.subplots(1, 1, figsize=(width, height))

        im = ax.imshow(share_matrix, cmap='RdYlBu_r', aspect='auto', vmin=0, vmax=100)
        ax.set_xticks(range(len(col_labels)))
        ax.set_yticks(range(len(row_labels)))
        ax.set_xticklabels(col_labels, rotation=45, ha='right', fontsize=16)
        ax.set_yticklabels(row_labels, fontsize=16)
        ax.set_title(f'{title}\n(% of associations for each heavy V gene)')
        ax.set_xlabel('Light V Genes')
        ax.set_ylabel('Heavy V Genes (sorted by consistency)')

        # Label only the cells holding at least a quarter of a row's pairings
        for (row, col), value in np.ndenumerate(share_matrix):
            if value < 25:
                continue
            text_color = 'white' if value > 50 else 'black'
            ax.text(col, row, f'{value:.0f}%', ha='center', va='center',
                    fontsize=7, color=text_color, weight='bold')

        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label('Percentage of Associations (%)')

        plt.tight_layout()
        plt.show()
    
    # Run consistency analysis
    print("\n" + "="*80 + "\n")
    print("ANALYZING V GENE CONSISTENCY...")
    
    consistency_results = analyze_vgene_consistency([low_results, high_results], "BY PAIRING SCORE GROUPS")
    
    # Create consistency heatmaps
    print("\n--- CONSISTENCY HEATMAPS ---")
    
    for group_results in [low_results, high_results]:
        if group_results is None or not heavy_gene_col:
            continue
            
        group_name = group_results['group_name']
        
        for level in ['full', 'simple', 'family']:
            gen_vgene_dist = group_results[f'gen_vgene_dist_{level}']
            true_vgene_dist = group_results[f'true_vgene_dist_{level}']
            
            if len(gen_vgene_dist) > 0:
                gen_consistency = analyze_heavy_light_consistency(gen_vgene_dist, f"Generated {level}")
                create_consistency_heatmap(gen_consistency, 
                                         f"Generated V Gene Consistency - {group_name} ({level.capitalize()})")
            
            if len(true_vgene_dist) > 0:
                true_consistency = analyze_heavy_light_consistency(true_vgene_dist, f"True {level}")
                create_consistency_heatmap(true_consistency, 
                                         f"True V Gene Consistency - {group_name} ({level.capitalize()})")
    
    # ============================
    # COMPARISON ANALYSIS BETWEEN GROUPS
    # ============================
    
    print("\n" + "="*80 + "\n")
    print("COMPARISON ANALYSIS BETWEEN GROUPS")
    print("="*50)
    
    print(f"\nLocus Accuracy Comparison:")
    print(f"Low pairing scores: {low_results['overall_locus_accuracy']:.1f}%")
    print(f"High pairing scores: {high_results['overall_locus_accuracy']:.1f}%")
    print(f"Difference: {high_results['overall_locus_accuracy'] - low_results['overall_locus_accuracy']:.1f} percentage points")
    
    print(f"\nV Gene Diversity Comparison:")
    for level in ['full', 'simple', 'family']:
        low_gen = len(low_results[f'gen_light_counts_{level}'])
        high_gen = len(high_results[f'gen_light_counts_{level}'])
        low_true = len(low_results[f'true_light_counts_{level}'])
        high_true = len(high_results[f'true_light_counts_{level}'])
        
        print(f"{level.capitalize()} - Generated light genes: Low={low_gen}, High={high_gen}")
        print(f"{level.capitalize()} - True light genes: Low={low_true}, High={high_true}")
    
    # Top genes comparison
    print(f"\nTop 5 Generated Light V Genes Comparison:")
    print(f"{'Low Score Group':<30} {'High Score Group':<30}")
    print("-" * 60)
    
    low_top5_gen = low_results['gen_light_counts_full'].most_common(5)
    high_top5_gen = high_results['gen_light_counts_full'].most_common(5)
    
    max_rows = max(len(low_top5_gen), len(high_top5_gen))
    for i in range(max_rows):
        low_entry = f"{low_top5_gen[i][0]} ({low_top5_gen[i][1]})" if i < len(low_top5_gen) else ""
        high_entry = f"{high_top5_gen[i][0]} ({high_top5_gen[i][1]})" if i < len(high_top5_gen) else ""
        print(f"{low_entry:<30} {high_entry:<30}")
    
    print(f"\nTop 5 True Light V Genes Comparison:")
    print(f"{'Low Score Group':<30} {'High Score Group':<30}")
    print("-" * 60)
    
    low_top5_true = low_results['true_light_counts_full'].most_common(5)
    high_top5_true = high_results['true_light_counts_full'].most_common(5)
    
    max_rows = max(len(low_top5_true), len(high_top5_true))
    for i in range(max_rows):
        low_entry = f"{low_top5_true[i][0]} ({low_top5_true[i][1]})" if i < len(low_top5_true) else ""
        high_entry = f"{high_top5_true[i][0]} ({high_top5_true[i][1]})" if i < len(high_top5_true) else ""
        print(f"{low_entry:<30} {high_entry:<30}")
    
    # ============================
    # RETURN RESULTS
    # ============================
    
    return {
        'overall_stats': {
            'total_rows': len(df),
            'rows_with_scores': len(df_clean),
            'low_score_rows': len(df_low),
            'high_score_rows': len(df_high),
            'score_stats': {
                'overall_mean': df_clean['pairing_scores'].mean(),
                'overall_median': df_clean['pairing_scores'].median(),
                'low_mean': df_low['pairing_scores'].mean() if len(df_low) > 0 else None,
                'high_mean': df_high['pairing_scores'].mean() if len(df_high) > 0 else None
            }
        },
        'low_score_group': low_results,
        'high_score_group': high_results,
        'heavy_gene_col': heavy_gene_col,
        'consistency_analysis': consistency_results
    }



In [None]:
# Input table for the pairing-score analysis
csv_file = '/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/pairing_result_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted_rel_cols_merged_complete.csv'

results = analyze_vgene_by_pairing_score(csv_file)

# Print the closing notes only when the analysis returned something
if results:
    for closing_note in (
        "\nAnalysis completed successfully!",
        "Results contain separate analyses for low (<0.5) and high (≥0.5) pairing score groups.",
        "Generated vs True light chain V gene associations are compared for each group.",
    ):
        print(closing_note)


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
from scipy import stats

def simplify_gene_name(gene_name):
    """Return the gene name with the allele designation (from the first '*') removed.

    Values that ``pd.isna`` reports as missing are returned unchanged; any
    other value is converted with ``str`` before the suffix is cut off.
    """
    if not pd.isna(gene_name):
        base_name, _, _ = str(gene_name).partition('*')
        return base_name
    return gene_name

def extract_gene_family(gene_name):
    """Return the gene family, e.g. ``IGHV3`` for ``IGHV3-23*01``.

    The allele part (from the first '*') is dropped first, then everything
    from the first '-' onwards. Values that ``pd.isna`` reports as missing
    are returned unchanged.
    """
    if not pd.isna(gene_name):
        without_allele = str(gene_name).partition('*')[0]
        # partition() yields the whole string when there is no dash, which is
        # exactly the "no dash -> return as is" case.
        return without_allele.partition('-')[0]
    return gene_name

def analyze_vgene_consistency(csv_file):
    """
    Analyze V gene consistency - whether certain heavy V genes always generate
    the same light chain V genes, using all generated sequences.

    The function reads ``csv_file``, picks the first available heavy-chain
    V gene column and (optionally) true light-chain V gene column from fixed
    candidate lists, and collects every ``gen_light_*_gene_name`` column as a
    generated light-chain V gene. Heavy -> light associations are tallied at
    three resolutions: full name (as stored), simplified (allele after ``*``
    removed) and family (text before the first ``-``). For each resolution it
    prints a consistency table, draws a heatmap and, when a true light column
    exists, compares generated against true associations.

    Parameters:
    csv_file (str): Path to the CSV file

    Returns:
    dict or None: ``None`` when no heavy-chain V gene column or no generated
        light-chain V gene column is found. Otherwise a dict with the keys
        ``'dataset_info'``, ``'generated_consistency'``, ``'true_consistency'``
        (``None`` without a true light column) and ``'vgene_distributions'``.
    """

    # Read the CSV file
    df = pd.read_csv(csv_file)

    print("Dataset Overview:")
    print(f"Total rows: {len(df)}")
    print(f"Columns: {df.columns.tolist()}")

    # Check for true light V gene column (first candidate present wins)
    true_light_vgene_col = None
    possible_true_light_cols = ['true_light_gene_name', 'true_light_v_gene', 'light_gene_name', 'true_light_v_gene_name']

    for col in possible_true_light_cols:
        if col in df.columns:
            true_light_vgene_col = col
            break

    print(f"True light V gene column: {true_light_vgene_col if true_light_vgene_col else 'NOT FOUND'}")

    # Get heavy chain V gene column (first candidate present wins)
    heavy_gene_col = None
    possible_heavy_cols = ['heavy_gene_name', 'true_v_gene_simple', 'heavy_v_gene', 'true_heavy_gene']

    for col in possible_heavy_cols:
        if col in df.columns:
            heavy_gene_col = col
            break

    print(f"Heavy chain V gene column: {heavy_gene_col if heavy_gene_col else 'NOT FOUND'}")

    # Get generated light chain V gene columns
    light_gene_cols = [col for col in df.columns if col.startswith('gen_light_') and col.endswith('_gene_name')]
    light_locus_cols = [col for col in df.columns if col.startswith('gen_light_') and col.endswith('_light_locus')]

    print(f"Generated light gene columns found: {len(light_gene_cols)}")
    print(f"Generated light locus columns found: {len(light_locus_cols)}")

    if not heavy_gene_col:
        print("❌ No heavy chain V gene column found - cannot analyze heavy-light consistency")
        return None

    if len(light_gene_cols) == 0:
        print("❌ No generated light chain V gene columns found")
        return None

    # ============================
    # COLLECT V GENE PAIRS
    # ============================

    print("\n" + "="*80 + "\n")
    print("COLLECTING V GENE ASSOCIATIONS")
    print("="*50)

    # Generated heavy-light pairs: one pair per non-missing generated light
    # gene in each row that has a heavy gene.
    gen_pairs_full = []
    gen_vgene_dist_full = defaultdict(lambda: defaultdict(int))

    # itertuples over just the needed columns yields the same values as
    # iterrows but without building a Series per row.
    needed_cols = [heavy_gene_col] + light_gene_cols
    for heavy_gene, *light_genes in df[needed_cols].itertuples(index=False, name=None):
        if pd.isna(heavy_gene):
            continue

        for light_gene in light_genes:
            if not pd.isna(light_gene):
                gen_pairs_full.append((heavy_gene, light_gene))
                gen_vgene_dist_full[heavy_gene][light_gene] += 1

    print(f"Generated heavy-light V gene pairs: {len(gen_pairs_full)}")

    # Create simplified and family versions for generated
    gen_pairs_simple = [(simplify_gene_name(pair[0]), simplify_gene_name(pair[1]))
                       for pair in gen_pairs_full]
    gen_pairs_family = [(extract_gene_family(pair[0]), extract_gene_family(pair[1]))
                       for pair in gen_pairs_full]

    gen_vgene_dist_simple = defaultdict(lambda: defaultdict(int))
    gen_vgene_dist_family = defaultdict(lambda: defaultdict(int))

    for heavy_gene, light_gene in gen_pairs_simple:
        if not pd.isna(heavy_gene) and not pd.isna(light_gene):
            gen_vgene_dist_simple[heavy_gene][light_gene] += 1

    for heavy_family, light_family in gen_pairs_family:
        if not pd.isna(heavy_family) and not pd.isna(light_family):
            gen_vgene_dist_family[heavy_family][light_family] += 1

    # True heavy-light pairs (if available). The containers are always
    # created so later code can reference them unconditionally.
    true_pairs_full = []
    true_vgene_dist_full = defaultdict(lambda: defaultdict(int))
    true_vgene_dist_simple = defaultdict(lambda: defaultdict(int))
    true_vgene_dist_family = defaultdict(lambda: defaultdict(int))

    if true_light_vgene_col:
        for heavy_gene, true_light_gene in zip(df[heavy_gene_col], df[true_light_vgene_col]):
            if pd.isna(heavy_gene) or pd.isna(true_light_gene):
                continue

            true_pairs_full.append((heavy_gene, true_light_gene))
            true_vgene_dist_full[heavy_gene][true_light_gene] += 1

        print(f"True heavy-light V gene pairs: {len(true_pairs_full)}")

        # Create simplified and family versions for true
        true_pairs_simple = [(simplify_gene_name(pair[0]), simplify_gene_name(pair[1]))
                            for pair in true_pairs_full]
        true_pairs_family = [(extract_gene_family(pair[0]), extract_gene_family(pair[1]))
                            for pair in true_pairs_full]

        for heavy_gene, light_gene in true_pairs_simple:
            if not pd.isna(heavy_gene) and not pd.isna(light_gene):
                true_vgene_dist_simple[heavy_gene][light_gene] += 1

        for heavy_family, light_family in true_pairs_family:
            if not pd.isna(heavy_family) and not pd.isna(light_family):
                true_vgene_dist_family[heavy_family][light_family] += 1

    # Count frequencies (only the number of distinct genes is reported)
    gen_heavy_counts_full = Counter(pair[0] for pair in gen_pairs_full)
    gen_light_counts_full = Counter(pair[1] for pair in gen_pairs_full)
    gen_heavy_counts_simple = Counter(pair[0] for pair in gen_pairs_simple if not pd.isna(pair[0]))
    gen_light_counts_simple = Counter(pair[1] for pair in gen_pairs_simple if not pd.isna(pair[1]))
    gen_heavy_counts_family = Counter(pair[0] for pair in gen_pairs_family if not pd.isna(pair[0]))
    gen_light_counts_family = Counter(pair[1] for pair in gen_pairs_family if not pd.isna(pair[1]))

    print(f"Generated - Full Names: {len(gen_heavy_counts_full)} heavy, {len(gen_light_counts_full)} light")
    print(f"Generated - Simplified: {len(gen_heavy_counts_simple)} heavy, {len(gen_light_counts_simple)} light")
    print(f"Generated - Families: {len(gen_heavy_counts_family)} heavy, {len(gen_light_counts_family)} light")

    if true_light_vgene_col:
        true_heavy_counts_full = Counter(pair[0] for pair in true_pairs_full)
        true_light_counts_full = Counter(pair[1] for pair in true_pairs_full)
        true_heavy_counts_simple = Counter(pair[0] for pair in true_pairs_simple if not pd.isna(pair[0]))
        true_light_counts_simple = Counter(pair[1] for pair in true_pairs_simple if not pd.isna(pair[1]))
        true_heavy_counts_family = Counter(pair[0] for pair in true_pairs_family if not pd.isna(pair[0]))
        true_light_counts_family = Counter(pair[1] for pair in true_pairs_family if not pd.isna(pair[1]))

        print(f"True - Full Names: {len(true_heavy_counts_full)} heavy, {len(true_light_counts_full)} light")
        print(f"True - Simplified: {len(true_heavy_counts_simple)} heavy, {len(true_light_counts_simple)} light")
        print(f"True - Families: {len(true_heavy_counts_family)} heavy, {len(true_light_counts_family)} light")

    # ============================
    # CONSISTENCY ANALYSIS FUNCTIONS
    # ============================

    def analyze_heavy_light_consistency(vgene_dist, analysis_type, min_associations=5):
        """Analyze consistency of heavy-light V gene associations.

        Parameters:
        vgene_dist (dict): heavy gene -> {light gene -> count}
        analysis_type (str): Label of the analysis; currently unused.
        min_associations (int): Heavy genes with fewer total associations
            are skipped.

        Returns:
        list of dict: One entry per retained heavy gene, sorted by the share
            of its most common light gene (highest first).
        """

        consistency_results = []

        for heavy_gene, light_dist in vgene_dist.items():
            if len(light_dist) == 0:
                continue

            total_associations = sum(light_dist.values())

            # Skip genes with too few associations
            if total_associations < min_associations:
                continue

            # Share of the single most frequent light gene
            most_common_light = max(light_dist.items(), key=lambda x: x[1])
            most_common_count = most_common_light[1]
            most_common_pct = (most_common_count / total_associations) * 100

            # Calculate entropy in bits (lower = more consistent); the small
            # epsilon guards log2 against a zero proportion.
            proportions = np.array(list(light_dist.values())) / total_associations
            entropy = -np.sum(proportions * np.log2(proportions + 1e-10))

            # Calculate Gini coefficient (higher = more concentrated).
            # The cumulative-sum formula below is only valid for counts in
            # ASCENDING order; with descending order it yields the negated
            # value (e.g. counts [3, 1] would give -0.25 instead of 0.25).
            sorted_counts = sorted(light_dist.values())
            n = len(sorted_counts)
            cumsum = np.cumsum(sorted_counts)
            gini = (n + 1 - 2 * np.sum(cumsum) / cumsum[-1]) / n if cumsum[-1] > 0 else 0

            # Consistency classification by share of the top light gene
            if most_common_pct >= 95:
                consistency_level = "Perfect"
            elif most_common_pct >= 90:
                consistency_level = "Very High"
            elif most_common_pct >= 75:
                consistency_level = "High"
            elif most_common_pct >= 50:
                consistency_level = "Moderate"
            else:
                consistency_level = "Low"

            consistency_results.append({
                'heavy_gene': heavy_gene,
                'total_associations': total_associations,
                'unique_light_genes': len(light_dist),
                'most_common_light': most_common_light[0],
                'most_common_count': most_common_count,
                'most_common_pct': most_common_pct,
                'entropy': entropy,
                'gini_coefficient': gini,
                'consistency_level': consistency_level,
                'light_distribution': dict(light_dist)
            })

        # Sort by consistency (highest percentage first)
        consistency_results.sort(key=lambda x: x['most_common_pct'], reverse=True)

        return consistency_results

    def print_consistency_results(consistency_results, analysis_type):
        """Print the top 25 rows of a consistency result list plus a summary."""

        print(f"\n{analysis_type} Consistency Results:")
        print(f"{'Heavy V Gene':<25} {'Most Common Light':<25} {'Count':<8} {'%':<8} {'Unique':<8} {'Level':<12} {'Entropy':<8}")
        print("-" * 110)

        for result in consistency_results[:25]:  # Show top 25
            print(f"{result['heavy_gene']:<25} {result['most_common_light']:<25} "
                  f"{result['most_common_count']:<8} {result['most_common_pct']:<8.1f} "
                  f"{result['unique_light_genes']:<8} {result['consistency_level']:<12} {result['entropy']:<8.2f}")

        # Summary statistics
        if consistency_results:
            level_counts = Counter(r['consistency_level'] for r in consistency_results)
            perfect = level_counts['Perfect']
            very_high = level_counts['Very High']
            high = level_counts['High']
            moderate = level_counts['Moderate']
            low = level_counts['Low']

            total = len(consistency_results)

            print(f"\nConsistency Summary ({total} heavy V genes analyzed):")
            print(f"  Perfect (≥95%): {perfect} genes ({perfect/total*100:.1f}%)")
            print(f"  Very High (90-94%): {very_high} genes ({very_high/total*100:.1f}%)")
            print(f"  High (75-89%): {high} genes ({high/total*100:.1f}%)")
            print(f"  Moderate (50-74%): {moderate} genes ({moderate/total*100:.1f}%)")
            print(f"  Low (<50%): {low} genes ({low/total*100:.1f}%)")

            # Find perfect/near-perfect matches
            perfect_matches = [r for r in consistency_results if r['most_common_pct'] >= 95]
            if perfect_matches:
                print(f"\n🎯 PERFECT/NEAR-PERFECT MATCHES (≥95% consistency):")
                for match in perfect_matches[:15]:  # Top 15
                    print(f"  {match['heavy_gene']} → {match['most_common_light']} "
                          f"({match['most_common_pct']:.1f}%, {match['total_associations']} total)")

            # Find highly consistent genes
            highly_consistent = [r for r in consistency_results if 90 <= r['most_common_pct'] < 95]
            if highly_consistent:
                print(f"\n⭐ HIGHLY CONSISTENT (90-94%):")
                for match in highly_consistent[:10]:  # Top 10
                    print(f"  {match['heavy_gene']} → {match['most_common_light']} "
                          f"({match['most_common_pct']:.1f}%, {match['total_associations']} total)")

    def create_consistency_heatmap(consistency_results, title, top_n=20):
        """Draw a heatmap (heavy genes x light genes) for the ``top_n`` most
        consistent heavy genes; cells hold the percentage of that heavy
        gene's associations going to each light gene."""

        if not consistency_results:
            print(f"No data for {title}")
            return

        # Get top N most consistent heavy genes (input is already sorted)
        top_results = consistency_results[:top_n]

        if len(top_results) < 2:
            print(f"Insufficient data for {title}")
            return

        # Create matrix for heatmap
        heavy_genes = [r['heavy_gene'] for r in top_results]
        all_light_genes = set()
        for r in top_results:
            all_light_genes.update(r['light_distribution'].keys())
        all_light_genes = sorted(all_light_genes)

        # Create consistency matrix (percentages)
        consistency_matrix = np.zeros((len(heavy_genes), len(all_light_genes)))

        for i, result in enumerate(top_results):
            total = result['total_associations']
            for j, light_gene in enumerate(all_light_genes):
                count = result['light_distribution'].get(light_gene, 0)
                consistency_matrix[i, j] = (count / total) * 100 if total > 0 else 0

        # Create heatmap; the figure grows with the number of genes shown
        fig, ax = plt.subplots(1, 1, figsize=(max(14, len(all_light_genes) * 0.8), max(10, len(heavy_genes) * 0.6)))

        im = ax.imshow(consistency_matrix, cmap='RdYlBu_r', aspect='auto', vmin=0, vmax=100)
        ax.set_xticks(range(len(all_light_genes)))
        ax.set_yticks(range(len(heavy_genes)))
        ax.set_xticklabels(all_light_genes, rotation=45, ha='right', fontsize=8)
        ax.set_yticklabels(heavy_genes, fontsize=8)
        ax.set_title(f'{title}\n(% of associations for each heavy V gene)')
        ax.set_xlabel('Light V Genes')
        ax.set_ylabel('Heavy V Genes (sorted by consistency)')

        # Add annotations for high percentages
        for i in range(len(heavy_genes)):
            for j in range(len(all_light_genes)):
                value = consistency_matrix[i, j]
                if value >= 30:  # Only annotate if ≥30%
                    text_color = 'white' if value > 50 else 'black'
                    ax.text(j, i, f'{value:.0f}%', ha='center', va='center',
                           fontsize=7, color=text_color, weight='bold')

        # Add colorbar
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label('Percentage of Associations (%)')

        plt.tight_layout()
        plt.show()

    # ============================
    # RUN CONSISTENCY ANALYSIS
    # ============================

    print("\n" + "="*80 + "\n")
    print("V GENE CONSISTENCY ANALYSIS")
    print("="*50)

    # Analyze generated V gene consistency at all levels
    gen_consistency_full = analyze_heavy_light_consistency(gen_vgene_dist_full, "Generated Full Names")
    gen_consistency_simple = analyze_heavy_light_consistency(gen_vgene_dist_simple, "Generated Simplified")
    gen_consistency_family = analyze_heavy_light_consistency(gen_vgene_dist_family, "Generated Families")

    print_consistency_results(gen_consistency_full, "GENERATED - FULL NAMES")
    print_consistency_results(gen_consistency_simple, "GENERATED - SIMPLIFIED")
    print_consistency_results(gen_consistency_family, "GENERATED - FAMILIES")

    # Analyze true V gene consistency if available
    if true_light_vgene_col:
        true_consistency_full = analyze_heavy_light_consistency(true_vgene_dist_full, "True Full Names")
        true_consistency_simple = analyze_heavy_light_consistency(true_vgene_dist_simple, "True Simplified")
        true_consistency_family = analyze_heavy_light_consistency(true_vgene_dist_family, "True Families")

        print_consistency_results(true_consistency_full, "TRUE - FULL NAMES")
        print_consistency_results(true_consistency_simple, "TRUE - SIMPLIFIED")
        print_consistency_results(true_consistency_family, "TRUE - FAMILIES")

    # ============================
    # CREATE CONSISTENCY HEATMAPS
    # ============================

    print("\n" + "="*80 + "\n")
    print("CREATING CONSISTENCY HEATMAPS...")

    # Generated consistency heatmaps
    create_consistency_heatmap(gen_consistency_full, "Generated V Gene Consistency (Full Names)")
    create_consistency_heatmap(gen_consistency_simple, "Generated V Gene Consistency (Simplified)")
    create_consistency_heatmap(gen_consistency_family, "Generated V Gene Consistency (Families)")

    # True consistency heatmaps (if available)
    if true_light_vgene_col:
        create_consistency_heatmap(true_consistency_full, "True V Gene Consistency (Full Names)")
        create_consistency_heatmap(true_consistency_simple, "True V Gene Consistency (Simplified)")
        create_consistency_heatmap(true_consistency_family, "True V Gene Consistency (Families)")

    # ============================
    # COMPARE GENERATED VS TRUE CONSISTENCY
    # ============================

    if true_light_vgene_col:
        print("\n" + "="*80 + "\n")
        print("GENERATED vs TRUE CONSISTENCY COMPARISON")
        print("="*50)

        def compare_consistency(gen_results, true_results, level_name):
            """Print how often generated and true data agree on the most
            common light gene for heavy genes present in both result lists."""
            print(f"\n{level_name} Consistency Comparison:")

            # Create dictionaries for easier lookup
            gen_dict = {r['heavy_gene']: r for r in gen_results}
            true_dict = {r['heavy_gene']: r for r in true_results}

            # Find common heavy genes
            common_genes = set(gen_dict.keys()).intersection(set(true_dict.keys()))

            if len(common_genes) == 0:
                print("  No common heavy genes found")
                return

            print(f"  Common heavy genes: {len(common_genes)}")

            # Compare consistency levels (counted over ALL genes of each
            # result list, not only the common ones)
            gen_high_consistency = len([r for r in gen_results if r['most_common_pct'] >= 90])
            true_high_consistency = len([r for r in true_results if r['most_common_pct'] >= 90])

            print(f"  High consistency (≥90%): Generated={gen_high_consistency}, True={true_high_consistency}")

            # Find genes with matching top light chains
            matching_top_light = 0
            for gene in common_genes:
                if gen_dict[gene]['most_common_light'] == true_dict[gene]['most_common_light']:
                    matching_top_light += 1

            print(f"  Genes with matching top light chains: {matching_top_light}/{len(common_genes)} ({matching_top_light/len(common_genes)*100:.1f}%)")

            # Show top mismatches
            print(f"\n  Top mismatches (different preferred light chains):")
            mismatches = []
            for gene in common_genes:
                if gen_dict[gene]['most_common_light'] != true_dict[gene]['most_common_light']:
                    mismatches.append({
                        'heavy': gene,
                        'gen_light': gen_dict[gene]['most_common_light'],
                        'true_light': true_dict[gene]['most_common_light'],
                        'gen_pct': gen_dict[gene]['most_common_pct'],
                        'true_pct': true_dict[gene]['most_common_pct']
                    })

            # Sort by generated consistency (most consistent first)
            mismatches.sort(key=lambda x: x['gen_pct'], reverse=True)

            for mismatch in mismatches[:10]:  # Show top 10
                print(f"    {mismatch['heavy']}: Gen→{mismatch['gen_light']} ({mismatch['gen_pct']:.1f}%) vs True→{mismatch['true_light']} ({mismatch['true_pct']:.1f}%)")

        # Compare at all levels
        compare_consistency(gen_consistency_full, true_consistency_full, "Full Names")
        compare_consistency(gen_consistency_simple, true_consistency_simple, "Simplified")
        compare_consistency(gen_consistency_family, true_consistency_family, "Families")

    # ============================
    # RETURN RESULTS
    # ============================

    return {
        'dataset_info': {
            'total_rows': len(df),
            'heavy_gene_col': heavy_gene_col,
            'true_light_vgene_col': true_light_vgene_col,
            'gen_light_cols': light_gene_cols,
            'gen_pairs_count': len(gen_pairs_full),
            'true_pairs_count': len(true_pairs_full)
        },
        'generated_consistency': {
            'full': gen_consistency_full,
            'simple': gen_consistency_simple,
            'family': gen_consistency_family
        },
        'true_consistency': {
            'full': true_consistency_full,
            'simple': true_consistency_simple,
            'family': true_consistency_family
        } if true_light_vgene_col else None,
        'vgene_distributions': {
            'generated': {
                'full': dict(gen_vgene_dist_full),
                'simple': dict(gen_vgene_dist_simple),
                'family': dict(gen_vgene_dist_family)
            },
            'true': {
                'full': dict(true_vgene_dist_full),
                'simple': dict(true_vgene_dist_simple),
                'family': dict(true_vgene_dist_family)
            } if true_light_vgene_col else None
        }
    }


In [None]:

# Example usage: run the consistency analysis on the reformatted predictions
# table and report failures instead of raising.
csv_file = '/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted.csv'

try:
    results = analyze_vgene_consistency(csv_file)
    if results:
        print(
            "\nV Gene Consistency Analysis completed successfully!",
            "Results show which heavy V genes consistently generate the same light V genes.",
            sep="\n",
        )
except FileNotFoundError:
    print(
        f"Error: Could not find the file '{csv_file}'",
        "Please make sure the file path is correct.",
        sep="\n",
    )
except Exception as error:
    # Notebook-level boundary: show the message rather than a traceback.
    print(f"An error occurred: {error}")

In [None]:
import pandas as pd
import re

def merge_gene_names_from_split_files(first_csv_path, matching_csv_path, non_matching_csv_path, output_path):
    """
    Merge gene_name information from matching and non-matching CSV files to first CSV based on sequence_id patterns.

    For every row of the first CSV three sequence ids are derived from its
    ``heavy_chain_number`` (H) and ``gen_light_chain_number`` (G) columns and
    looked up in the combined ``sequence_id`` -> ``gene_name`` mapping:

    * ``gen_light_{G}_heavy_chain_{H}``   -> column ``gen_light_gene_name``
    * ``true_light_chain_heavy_chain_{H}`` -> column ``true_light_gene_name``
    * ``heavy_chain_{H}``                  -> column ``heavy_chain_gene_name``

    Ids without a match leave ``None`` in the new column. When a
    ``sequence_id`` occurs more than once, the last occurrence (non-matching
    file after matching file) wins.

    Args:
        first_csv_path (str): Path to the first CSV file with heavy/light chain data
        matching_csv_path (str): Path to the matching sequences CSV file with gene_name information
        non_matching_csv_path (str): Path to the non-matching sequences CSV file with gene_name information
        output_path (str): Path for the output merged CSV file

    Returns:
        pandas.DataFrame: The first CSV's data with the three new columns
        (also written to ``output_path``).
    """

    # Read the CSV files
    df1 = pd.read_csv(first_csv_path)
    df_matching = pd.read_csv(matching_csv_path)
    df_non_matching = pd.read_csv(non_matching_csv_path)

    # Combine both matching and non-matching dataframes
    df2_combined = pd.concat([df_matching, df_non_matching], ignore_index=True)

    # Create a mapping dictionary from sequence_id to gene_name
    gene_name_mapping = dict(zip(df2_combined['sequence_id'], df2_combined['gene_name']))

    # Take the numbers column-wise so they keep their own dtype. Reading them
    # from iterrows() rows is unsafe: when every column of the first CSV is
    # numeric, the row Series is upcast to float and the ids would be
    # formatted as e.g. "heavy_chain_1.0", which never matches.
    heavy_nums = df1['heavy_chain_number'].tolist()
    gen_light_nums = df1['gen_light_chain_number'].tolist()

    ids_by_column = {
        'gen_light_gene_name': [
            f"gen_light_{gen_light_num}_heavy_chain_{heavy_chain_num}"
            for gen_light_num, heavy_chain_num in zip(gen_light_nums, heavy_nums)
        ],
        'true_light_gene_name': [
            f"true_light_chain_heavy_chain_{heavy_chain_num}" for heavy_chain_num in heavy_nums
        ],
        'heavy_chain_gene_name': [
            f"heavy_chain_{heavy_chain_num}" for heavy_chain_num in heavy_nums
        ],
    }

    # One dictionary lookup per row; unmatched ids stay None (object dtype).
    for column, sequence_ids in ids_by_column.items():
        df1[column] = pd.Series(
            [gene_name_mapping.get(sequence_id) for sequence_id in sequence_ids],
            index=df1.index,
            dtype=object,
        )

    # Save the merged dataframe
    df1.to_csv(output_path, index=False)

    # Print summary statistics
    print(f"Total rows processed: {len(df1)}")
    print(f"Generated light chain gene names found: {df1['gen_light_gene_name'].notna().sum()}")
    print(f"True light chain gene names found: {df1['true_light_gene_name'].notna().sum()}")
    print(f"Heavy chain gene names found: {df1['heavy_chain_gene_name'].notna().sum()}")
    print(f"Output saved to: {output_path}")

    return df1

def merge_gene_names_single_column_from_split_files(first_csv_path, matching_csv_path, non_matching_csv_path, output_path):
    """
    Merge gene_name information from matching and non-matching CSV files to first CSV, focusing on generated light chains.

    For every row of the first CSV the id
    ``gen_light_{gen_light_chain_number}_heavy_chain_{heavy_chain_number}``
    is looked up in the combined ``sequence_id`` -> ``gene_name`` mapping and
    the result stored in a new ``gene_name`` column (``None`` when the id is
    not present). When a ``sequence_id`` occurs more than once, the last
    occurrence (non-matching file after matching file) wins.

    Args:
        first_csv_path (str): Path to the first CSV file with heavy/light chain data
        matching_csv_path (str): Path to the matching sequences CSV file with gene_name information
        non_matching_csv_path (str): Path to the non-matching sequences CSV file with gene_name information
        output_path (str): Path for the output merged CSV file

    Returns:
        pandas.DataFrame: The first CSV's data with the new ``gene_name``
        column (also written to ``output_path``).
    """

    # Read the CSV files
    df1 = pd.read_csv(first_csv_path)
    df_matching = pd.read_csv(matching_csv_path)
    df_non_matching = pd.read_csv(non_matching_csv_path)

    # Combine both matching and non-matching dataframes
    df2_combined = pd.concat([df_matching, df_non_matching], ignore_index=True)

    # Create a mapping dictionary from sequence_id to gene_name
    gene_name_mapping = dict(zip(df2_combined['sequence_id'], df2_combined['gene_name']))

    # Build the ids from the columns themselves so integer numbers keep their
    # dtype. Values taken from iterrows() rows are upcast to float when all
    # columns are numeric, producing ids like "gen_light_1.0_heavy_chain_1.0"
    # that never match.
    gen_light_ids = [
        f"gen_light_{gen_light_num}_heavy_chain_{heavy_chain_num}"
        for gen_light_num, heavy_chain_num in zip(
            df1['gen_light_chain_number'].tolist(), df1['heavy_chain_number'].tolist()
        )
    ]

    # One dictionary lookup per row; unmatched ids stay None (object dtype).
    df1['gene_name'] = pd.Series(
        [gene_name_mapping.get(gen_light_id) for gen_light_id in gen_light_ids],
        index=df1.index,
        dtype=object,
    )

    # Save the merged dataframe
    df1.to_csv(output_path, index=False)

    # Print summary statistics
    print(f"Total rows processed: {len(df1)}")
    print(f"Gene names found: {df1['gene_name'].notna().sum()}")
    print(f"Output saved to: {output_path}")

    return df1

def analyze_sequence_coverage(first_csv_path, matching_csv_path, non_matching_csv_path):
    """
    Analyze which sequences from the first CSV are covered in the matching/non-matching files.
    Useful for debugging and understanding data coverage.

    Expected sequence ids are derived from the ``heavy_chain_number`` (H) and
    ``gen_light_chain_number`` (G) columns of the first CSV as
    ``gen_light_{G}_heavy_chain_{H}``, ``true_light_chain_heavy_chain_{H}``
    and ``heavy_chain_{H}``, and compared with the ``sequence_id`` values of
    the two other files. The report is printed; nothing is returned.

    Args:
        first_csv_path (str): Path to the first CSV file with heavy/light chain data
        matching_csv_path (str): Path to the matching sequences CSV file
        non_matching_csv_path (str): Path to the non-matching sequences CSV file
    """

    # Read the CSV files
    df1 = pd.read_csv(first_csv_path)
    df_matching = pd.read_csv(matching_csv_path)
    df_non_matching = pd.read_csv(non_matching_csv_path)

    # Combine both matching and non-matching dataframes
    df2_combined = pd.concat([df_matching, df_non_matching], ignore_index=True)

    # Get all sequence IDs from the combined data
    available_sequence_ids = set(df2_combined['sequence_id'])

    # Create expected sequence IDs from the first CSV. The numbers are read
    # column-wise so integers keep their dtype; iterrows() would upcast them
    # to float when every column is numeric ("heavy_chain_1.0"), making all
    # expected ids appear missing.
    heavy_nums = df1['heavy_chain_number'].tolist()
    gen_light_nums = df1['gen_light_chain_number'].tolist()

    expected_gen_light_ids = {
        f"gen_light_{gen_light_num}_heavy_chain_{heavy_chain_num}"
        for gen_light_num, heavy_chain_num in zip(gen_light_nums, heavy_nums)
    }
    expected_true_light_ids = {
        f"true_light_chain_heavy_chain_{heavy_chain_num}" for heavy_chain_num in heavy_nums
    }
    expected_heavy_ids = {f"heavy_chain_{heavy_chain_num}" for heavy_chain_num in heavy_nums}

    # Check coverage
    print("=== SEQUENCE COVERAGE ANALYSIS ===")
    print(f"Total rows in first CSV: {len(df1)}")
    print(f"Total sequences in matching file: {len(df_matching)}")
    print(f"Total sequences in non-matching file: {len(df_non_matching)}")
    print(f"Total combined sequences: {len(df2_combined)}")

    coverage_sections = (
        ("\nGenerated light chain coverage:", expected_gen_light_ids),
        ("\nTrue light chain coverage:", expected_true_light_ids),
        ("\nHeavy chain coverage:", expected_heavy_ids),
    )
    for heading, expected_ids in coverage_sections:
        print(heading)
        print(f"  Expected: {len(expected_ids)}")
        print(f"  Found: {len(expected_ids & available_sequence_ids)}")
        print(f"  Missing: {len(expected_ids - available_sequence_ids)}")

    # Show some examples of missing sequences (if any); sets are unordered,
    # so which five are shown is arbitrary.
    missing_gen_light = expected_gen_light_ids - available_sequence_ids
    if missing_gen_light:
        print(f"\nExample missing generated light chains: {list(missing_gen_light)[:5]}")

    missing_true_light = expected_true_light_ids - available_sequence_ids
    if missing_true_light:
        print(f"Example missing true light chains: {list(missing_true_light)[:5]}")

    missing_heavy = expected_heavy_ids - available_sequence_ids
    if missing_heavy:
        print(f"Example missing heavy chains: {list(missing_heavy)[:5]}")



In [None]:

# Inputs: the table of generated light chains per heavy chain, plus the two
# parsed annotation files (matching / non-matching) that carry the
# sequence_id -> gene_name information. The merged table goes to output_csv.
first_csv = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/full_eval_generate_multiple_light_seqs_203276_cls_predictions.csv"
matching_csv = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed.csv"
non_matching_csv = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed.csv"
output_csv = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/full_eval_generate_multiple_light_seqs_203276_cls_predictions_merged_genes.csv"

# First, analyze coverage to understand your data (prints a report only)
analyze_sequence_coverage(first_csv, matching_csv, non_matching_csv)
    
# Option 1: Add separate gene name columns for each sequence type
# (generated light, true light, heavy); also writes output_csv.
merged_df = merge_gene_names_from_split_files(first_csv, matching_csv, non_matching_csv, output_csv)
    
# Option 2: Add only generated light chain gene names
# merged_df = merge_gene_names_single_column_from_split_files(first_csv, matching_csv, non_matching_csv, output_csv)
    
# Display first few rows to verify
print("\nFirst few rows of merged data:")
print(merged_df.head())

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import re
from scipy.stats import entropy

In [None]:
# Load the merged predictions table and print its headline counts.
df = pd.read_csv('/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/full_eval_generate_multiple_light_seqs_203276_cls_predictions_merged_genes.csv')

print("Dataset Overview:")
print(f"Total rows: {len(df)}")

# (label, column, reducer) -- printed in this order, one line per entry.
overview_fields = (
    ("Unique heavy chains", 'heavy_chain_number', pd.Series.nunique),
    ("Generated light chains per heavy chain", 'gen_light_chain_number', pd.Series.max),
    ("Unique heavy V genes", 'heavy_chain_gene_name', pd.Series.nunique),
    ("Unique generated light V genes", 'gen_light_gene_name', pd.Series.nunique),
    ("Unique true light V genes", 'true_light_gene_name', pd.Series.nunique),
)
for overview_label, overview_column, overview_reducer in overview_fields:
    print(f"{overview_label}: {overview_reducer(df[overview_column])}")


In [None]:
# Strip allele designations from a gene name
def simplify_gene_name(gene_name):
    """Return ``gene_name`` with every ``*<digits>`` allele suffix removed.

    Values that ``pd.isna`` reports as missing are handed back unchanged;
    anything else is converted with ``str`` before the suffixes are stripped
    (e.g. ``IGHV3-23*01`` becomes ``IGHV3-23``).
    """
    if not pd.isna(gene_name):
        allele_free = re.sub(r'\*\d+', '', str(gene_name))
        return allele_free
    return gene_name

# Reduce a gene name to its family prefix (e.g. IGHV3 from IGHV3-23)
def get_gene_family(gene_name):
    """Return the leading ``<UPPERCASE letters><digits>`` part of ``gene_name``.

    Missing values (per ``pd.isna``) are returned unchanged. When the string
    form of the input does not start with that pattern, the string form
    itself is returned.
    """
    if pd.isna(gene_name):
        return gene_name

    text = str(gene_name)
    family = re.match(r'([A-Z]+\d+)', text)
    if family is None:
        return text
    return family.group(1)

# Add simplified and family columns.
# Each entry is (new column, source column, transform); the columns are
# created in the listed order.
derived_gene_columns = (
    ('heavy_gene_simplified', 'heavy_chain_gene_name', simplify_gene_name),
    ('heavy_gene_family', 'heavy_chain_gene_name', get_gene_family),
    ('gen_light_gene_simplified', 'gen_light_gene_name', simplify_gene_name),
    ('gen_light_gene_family', 'gen_light_gene_name', get_gene_family),
    ('true_light_gene_simplified', 'true_light_gene_name', simplify_gene_name),
    ('true_light_gene_family', 'true_light_gene_name', get_gene_family),
)
for derived_name, source_name, gene_transform in derived_gene_columns:
    df[derived_name] = df[source_name].apply(gene_transform)


In [None]:

# Section banner for the V gene consistency analysis
print("\n" + "="*80, "V GENE CONSISTENCY ANALYSIS", "="*80, sep="\n")

def analyze_consistency(heavy_col, light_col, analysis_name, data=None):
    """Summarise how consistently each heavy V gene pairs with one light V gene.

    For every distinct non-null value in ``heavy_col`` the values of
    ``light_col`` are tallied. The share held by the most frequent light gene
    is that heavy gene's "consistency". A text summary is printed and the
    per-heavy-gene table is returned.

    Rows whose light gene is missing are ignored: they count neither toward
    the most common light gene nor toward ``total_associations``, so the
    percentage and the entropy are computed over the same observations.

    Parameters
    ----------
    heavy_col : str
        Column holding the heavy-chain gene label to group by.
    light_col : str
        Column holding the light-chain gene label to tally.
    analysis_name : str
        Heading printed (and underlined) above the summary.
    data : pandas.DataFrame, optional
        Table to analyse. Defaults to the notebook-level ``df`` so existing
        three-argument calls behave as before.

    Returns
    -------
    pandas.DataFrame
        One row per heavy gene, sorted by ``consistency_pct`` descending, with
        columns ``heavy_gene``, ``most_common_light``, ``consistency_pct``,
        ``most_common_count``, ``total_associations``, ``unique_light_genes``,
        ``entropy``, ``light_distribution`` and ``consistency_category``.
        Empty (but with these columns) when no heavy gene has a light value.
    """
    frame = df if data is None else data

    result_columns = [
        'heavy_gene', 'most_common_light', 'consistency_pct', 'most_common_count',
        'total_associations', 'unique_light_genes', 'entropy', 'light_distribution',
    ]

    print(f"\n{analysis_name}")
    print("-" * len(analysis_name))

    # Calculate consistency for each heavy V gene.
    # groupby skips null heavy genes; sort=False keeps first-appearance order,
    # matching a walk over ``unique()``.
    consistency_results = []
    for heavy_gene, light_values in frame.groupby(heavy_col, sort=False)[light_col]:
        # value_counts drops missing light genes and sorts by frequency
        light_counts = light_values.value_counts()
        if light_counts.empty:
            continue

        total_associations = int(light_counts.sum())
        most_common_count = light_counts.iloc[0]

        # Entropy of the light-gene distribution (diversity measure)
        probs = light_counts.values / total_associations

        consistency_results.append({
            'heavy_gene': heavy_gene,
            'most_common_light': light_counts.index[0],
            'consistency_pct': (most_common_count / total_associations) * 100,
            'most_common_count': most_common_count,
            'total_associations': total_associations,
            'unique_light_genes': len(light_counts),
            'entropy': entropy(probs),
            'light_distribution': dict(light_counts)
        })

    # Without this guard an empty result list yields a column-less frame and
    # the sort below raises KeyError.
    if not consistency_results:
        print("\nNo heavy/light gene associations found; nothing to summarise.")
        return pd.DataFrame(columns=result_columns + ['consistency_category'])

    consistency_df = pd.DataFrame(consistency_results, columns=result_columns)
    consistency_df = consistency_df.sort_values('consistency_pct', ascending=False)

    # Categorize consistency levels
    def categorize_consistency(pct):
        if pct >= 95:
            return 'Perfect (≥95%)'
        elif pct >= 90:
            return 'Very High (90-94%)'
        elif pct >= 75:
            return 'High (75-89%)'
        elif pct >= 50:
            return 'Moderate (50-74%)'
        else:
            return 'Low (<50%)'

    consistency_df['consistency_category'] = consistency_df['consistency_pct'].apply(categorize_consistency)

    # Summary statistics
    print("\nConsistency Level Distribution:")
    category_counts = consistency_df['consistency_category'].value_counts()
    n_heavy_genes = len(consistency_df)
    for category in ['Perfect (≥95%)', 'Very High (90-94%)', 'High (75-89%)', 'Moderate (50-74%)', 'Low (<50%)']:
        count = category_counts.get(category, 0)
        pct = (count / n_heavy_genes) * 100
        print(f"  {category}: {count} genes ({pct:.1f}%)")

    print("\nOverall Statistics:")
    print(f"  Average consistency: {consistency_df['consistency_pct'].mean():.1f}%")
    print(f"  Median consistency: {consistency_df['consistency_pct'].median():.1f}%")
    print(f"  Average unique light genes per heavy: {consistency_df['unique_light_genes'].mean():.1f}")
    print(f"  Average entropy: {consistency_df['entropy'].mean():.3f}")

    # Show top consistent genes (the frame is already sorted by consistency)
    print("\nTop 10 Most Consistent Heavy V Genes:")
    for _, row in consistency_df.head(10).iterrows():
        print(f"  {row['heavy_gene']}: {row['consistency_pct']:.1f}% → {row['most_common_light']} "
              f"({row['most_common_count']}/{row['total_associations']})")

    # Show most diverse genes
    print("\nTop 10 Most Diverse Heavy V Genes:")
    for _, row in consistency_df.nlargest(10, 'unique_light_genes').iterrows():
        print(f"  {row['heavy_gene']}: {row['unique_light_genes']} different light genes, "
              f"{row['consistency_pct']:.1f}% consistency")

    return consistency_df

In [None]:
# Analyze at different levels: allele-specific names, allele-free gene
# names, and gene families -- first for generated, then for true light chains.
def _print_banner(title, width=50):
    """Print ``title`` framed by '=' rules of ``width``, preceded by a blank line."""
    rule = "=" * width
    print("\n" + rule)
    print(title)
    print(rule)


_print_banner("1. FULL NAME ANALYSIS (Allele-Specific)")
gen_full_consistency = analyze_consistency(
    'heavy_chain_gene_name', 'gen_light_gene_name',
    "Generated Light Chains - Full Names")

_print_banner("2. SIMPLIFIED ANALYSIS (Gene-Level)")
gen_simplified_consistency = analyze_consistency(
    'heavy_gene_simplified', 'gen_light_gene_simplified',
    "Generated Light Chains - Simplified")

_print_banner("3. FAMILY ANALYSIS (Family-Level)")
gen_family_consistency = analyze_consistency(
    'heavy_gene_family', 'gen_light_gene_family',
    "Generated Light Chains - Family Level")

_print_banner("4. TRUE LIGHT CHAIN ANALYSIS")

_print_banner("4a. TRUE - FULL NAME ANALYSIS", width=40)
true_full_consistency = analyze_consistency(
    'heavy_chain_gene_name', 'true_light_gene_name',
    "True Light Chains - Full Names")

_print_banner("4b. TRUE - SIMPLIFIED ANALYSIS", width=40)
true_simplified_consistency = analyze_consistency(
    'heavy_gene_simplified', 'true_light_gene_simplified',
    "True Light Chains - Simplified")

_print_banner("4c. TRUE - FAMILY ANALYSIS", width=40)
true_family_consistency = analyze_consistency(
    'heavy_gene_family', 'true_light_gene_family',
    "True Light Chains - Family Level")

# Compare generated vs true preferences
print("\n" + "="*50, "5. GENERATED vs TRUE COMPARISON", "="*50, sep="\n")

# Line up the generated and true results per heavy gene. The outer join keeps
# genes that appear on only one side (their other-side columns are NaN).
preference_columns = ['heavy_gene', 'most_common_light', 'consistency_pct']
comparison_df = gen_full_consistency[preference_columns].merge(
    true_full_consistency[preference_columns],
    how='outer', on='heavy_gene', suffixes=('_gen', '_true'),
)

# A preference "matches" when both sides name the same most common light gene
comparison_df['preferences_match'] = (
    comparison_df['most_common_light_gen'] == comparison_df['most_common_light_true']
)
comparison_df['consistency_diff'] = (
    comparison_df['consistency_pct_gen'] - comparison_df['consistency_pct_true']
)

print("\nPreference Matching:")
matches = comparison_df['preferences_match'].sum()
total = len(comparison_df)
print(f"  Heavy V genes with matching preferences: {matches}/{total} ({matches/total*100:.1f}%)")

print("\nConsistency Comparison:")
gen_avg = comparison_df['consistency_pct_gen'].mean()
true_avg = comparison_df['consistency_pct_true'].mean()
print(f"  Generated average consistency: {gen_avg:.1f}%")
print(f"  True average consistency: {true_avg:.1f}%")
print(f"  Difference: {gen_avg - true_avg:.1f} percentage points")

print("\nTop Mismatches (Generated ≠ True Preferences):")
mismatches = comparison_df.loc[~comparison_df['preferences_match']].head(10)
for mismatch in mismatches.itertuples(index=False):
    print(f"  {mismatch.heavy_gene}: Gen→{mismatch.most_common_light_gen} vs True→{mismatch.most_common_light_true}")


In [None]:
# Create comparison visualization heatmaps
def create_comparison_heatmap(comparison_df, metric='consistency_pct', top_n=20):
    """Draw three single-column heatmaps: generated, true, and their difference.

    Rows lacking either ``{metric}_gen`` or ``{metric}_true`` are dropped.
    Of the remaining heavy genes, the ``top_n`` most frequent in the
    notebook-level ``df`` (column ``heavy_chain_gene_name``) are shown,
    ordered by the true value, highest first. The third panel always plots
    the ``consistency_diff`` column.

    Returns the filtered, sorted rows that were plotted.
    """
    gen_col = f'{metric}_gen'
    true_col = f'{metric}_true'

    # Keep only genes that have a value on both sides
    valid_comparison = comparison_df.dropna(subset=[gen_col, true_col])

    # Rank heavy genes by how often they occur in the full table
    heavy_gene_counts = (
        df.groupby('heavy_chain_gene_name').size().sort_values(ascending=False)
    )

    # Most frequent heavy genes that are also present in the comparison
    compared_genes = valid_comparison['heavy_gene'].values
    top_heavy_genes = [
        gene for gene in heavy_gene_counts.index if gene in compared_genes
    ][:top_n]

    plot_data = valid_comparison[valid_comparison['heavy_gene'].isin(top_heavy_genes)]

    fig, axes = plt.subplots(1, 3, figsize=(20, 10))

    plot_data_sorted = plot_data.sort_values(true_col, ascending=False)
    by_gene = plot_data_sorted.set_index('heavy_gene')

    # (column, heatmap keyword arguments, panel title, y-axis label), left to right
    panels = (
        (gen_col,
         {'cmap': 'RdYlBu_r', 'cbar_kws': {'label': 'Consistency (%)'}},
         'Generated Light Chain\nConsistency (%)', 'Heavy V Gene'),
        (true_col,
         {'cmap': 'RdYlBu_r', 'cbar_kws': {'label': 'Consistency (%)'}},
         'True Light Chain\nConsistency (%)', ''),
        ('consistency_diff',
         {'cmap': 'RdBu_r', 'center': 0, 'cbar_kws': {'label': 'Difference (%)'}},
         'Difference\n(Generated - True)', ''),
    )
    for ax, (column, heatmap_kwargs, panel_title, y_label) in zip(axes, panels):
        sns.heatmap(by_gene[[column]], annot=True, fmt='.1f', ax=ax, **heatmap_kwargs)
        ax.set_title(panel_title)
        ax.set_xlabel('')
        ax.set_ylabel(y_label)

    plt.tight_layout()
    plt.show()

    return plot_data_sorted

def create_preference_matching_heatmap(comparison_df, top_n=20):
    """Plot, per heavy V gene, how well generated pairings reproduce the true ones.

    The ``top_n`` heavy genes most frequent in the notebook-level ``df``
    (column ``heavy_chain_gene_name``) that also occur in ``comparison_df``
    are classified as:

    * ``Excellent Match`` / ``Good Match`` / ``Poor Match`` -- same most
      common light gene, consistency differing by <=10, <=20, or more points;
    * ``Wrong Preference`` -- different most common light gene;
    * ``No Data`` -- a consistency value is missing on either side.

    Each class is drawn in a fixed colour (gray, red, orange, light green,
    dark green for scores 0-4), annotated with "generated% vs true%".

    Parameters
    ----------
    comparison_df : pandas.DataFrame
        Must provide ``heavy_gene``, ``consistency_pct_gen``,
        ``consistency_pct_true``, ``most_common_light_gen``,
        ``most_common_light_true`` and ``preferences_match``.
    top_n : int, default 20
        Maximum number of heavy genes to show.

    Returns
    -------
    pandas.DataFrame
        One row per plotted heavy gene, sorted by true consistency
        (descending), with columns ``heavy_gene``, ``generated_consistency``,
        ``true_consistency``, ``performance``, ``gen_light``, ``true_light``
        and ``performance_score``. Empty when no gene qualifies.
    """
    # Imported here so the block does not rely on the ``plt.matplotlib``
    # attribute or on matplotlib.colors having been loaded elsewhere.
    from matplotlib.colors import ListedColormap

    summary_columns = [
        'heavy_gene', 'generated_consistency', 'true_consistency',
        'performance', 'gen_light', 'true_light', 'performance_score',
    ]

    # Get top heavy genes by frequency
    heavy_gene_counts = df.groupby('heavy_chain_gene_name').size().sort_values(ascending=False)
    compared_genes = set(comparison_df['heavy_gene'].dropna())
    top_heavy_genes = [gene for gene in heavy_gene_counts.index
                       if gene in compared_genes][:top_n]

    # Filter and prepare data
    plot_data = comparison_df[comparison_df['heavy_gene'].isin(top_heavy_genes)].copy()

    # Nothing to draw: an empty heatmap cannot be rendered or sorted
    if plot_data.empty:
        print("\nPerformance Summary:")
        print("  No heavy genes available for comparison.")
        return pd.DataFrame(columns=summary_columns)

    # Create performance categories
    def performance_category(row):
        if pd.isna(row['consistency_pct_gen']) or pd.isna(row['consistency_pct_true']):
            return 'No Data'

        if not row['preferences_match']:
            return 'Wrong Preference'

        diff = abs(row['consistency_pct_gen'] - row['consistency_pct_true'])
        if diff <= 10:
            return 'Excellent Match'
        if diff <= 20:
            return 'Good Match'
        return 'Poor Match'

    plot_data['performance'] = plot_data.apply(performance_category, axis=1)

    # Create summary table for the heatmap
    summary_df = (
        plot_data[['heavy_gene', 'consistency_pct_gen', 'consistency_pct_true',
                   'performance', 'most_common_light_gen', 'most_common_light_true']]
        .rename(columns={
            'consistency_pct_gen': 'generated_consistency',
            'consistency_pct_true': 'true_consistency',
            'most_common_light_gen': 'gen_light',
            'most_common_light_true': 'true_light',
        })
        .reset_index(drop=True)
    )

    # Sort by true consistency for better visualization
    summary_df = summary_df.sort_values('true_consistency', ascending=False)

    # Create performance score for color mapping
    performance_scores = {
        'Excellent Match': 4,
        'Good Match': 3,
        'Poor Match': 2,
        'Wrong Preference': 1,
        'No Data': 0
    }
    summary_df['performance_score'] = summary_df['performance'].map(performance_scores)

    # Create the heatmap
    fig, ax = plt.subplots(figsize=(12, 8))
    heatmap_data = summary_df.set_index('heavy_gene')[['performance_score']]

    # One colour per score 0..4, in score order
    colors = ['lightgray', 'red', 'orange', 'lightgreen', 'darkgreen']
    cmap = ListedColormap(colors)

    # Pin the colour range to the full 0..4 scale. Without vmin/vmax the five
    # colours are stretched over whatever scores happen to be present, so the
    # colours stop matching the legend in the title when a category is absent.
    sns.heatmap(heatmap_data, annot=False, cmap=cmap, vmin=-0.5, vmax=4.5,
                cbar_kws={'label': 'Performance Level', 'ticks': [0, 1, 2, 3, 4]}, ax=ax)

    # Add text annotations (row i of the sorted table is heatmap row i)
    for i, (_, row) in enumerate(summary_df.iterrows()):
        gen_pct = row['generated_consistency']
        true_pct = row['true_consistency']
        if pd.notna(gen_pct) and pd.notna(true_pct):
            ax.text(0.5, i + 0.5, f'{gen_pct:.1f}% vs {true_pct:.1f}%',
                    ha='center', va='center', fontsize=8, weight='bold')

    ax.set_title('Model Performance: Generated vs True V Gene Pairing\n' +
                 'Green=Excellent Match, Light Green=Good Match, Orange=Poor Match, Red=Wrong Preference')
    ax.set_xlabel('')
    ax.set_ylabel('Heavy V Gene (sorted by true consistency)')

    plt.tight_layout()
    plt.show()

    # Print summary statistics
    print("\nPerformance Summary:")
    perf_counts = summary_df['performance'].value_counts()
    for perf, count in perf_counts.items():
        pct = (count / len(summary_df)) * 100
        print(f"  {perf}: {count} genes ({pct:.1f}%)")

    return summary_df

def create_simple_consistency_comparison(gen_consistency_df, true_consistency_df, top_n=15):
    """Create a bar chart and scatter plot comparing generated vs true consistency.

    The two tables are outer-joined on ``heavy_gene``; a gene missing from one
    side gets a consistency of 0 on that side. The ``top_n`` genes with the
    highest consistency on either side are plotted and listed in a printed
    text table.

    Parameters:
    gen_consistency_df (DataFrame): needs columns ``heavy_gene``,
        ``consistency_pct`` and ``most_common_light`` (generated light chains)
    true_consistency_df (DataFrame): same columns, for the true light chains
    top_n (int): number of heavy genes to plot and print

    Returns:
    DataFrame: the full merged table (not only the plotted genes), with the
        added ``preferences_match`` and ``max_consistency`` columns
    """
    
    # Merge the data
    comparison = pd.merge(
        gen_consistency_df[['heavy_gene', 'consistency_pct', 'most_common_light']],
        true_consistency_df[['heavy_gene', 'consistency_pct', 'most_common_light']],
        on='heavy_gene', suffixes=('_gen', '_true'), how='outer'
    )
    
    # Fill NaN values with 0 for missing data
    # (only the percentages are filled; most_common_light_* stay NaN)
    comparison['consistency_pct_gen'] = comparison['consistency_pct_gen'].fillna(0)
    comparison['consistency_pct_true'] = comparison['consistency_pct_true'].fillna(0)
    
    # Calculate which genes have matching preferences
    # (a gene present on only one side compares unequal, i.e. no match)
    comparison['preferences_match'] = comparison['most_common_light_gen'] == comparison['most_common_light_true']
    
    # Get top genes by either generated or true consistency
    comparison['max_consistency'] = comparison[['consistency_pct_gen', 'consistency_pct_true']].max(axis=1)
    top_genes = comparison.nlargest(top_n, 'max_consistency')
    
    # Create the visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))
    
    # Plot 1: Bar chart comparison
    # Paired bars per gene: generated to the left of the tick, true to the right
    x = np.arange(len(top_genes))
    width = 0.35
    
    bars1 = ax1.bar(x - width/2, top_genes['consistency_pct_gen'], width, 
                    label='Generated', color='skyblue', alpha=0.8)
    bars2 = ax1.bar(x + width/2, top_genes['consistency_pct_true'], width, 
                    label='True', color='lightcoral', alpha=0.8)
    
    ax1.set_xlabel('Heavy V Gene')
    ax1.set_ylabel('Consistency Percentage (%)')
    ax1.set_title('Generated vs True Consistency Comparison')
    ax1.set_xticks(x)
    ax1.set_xticklabels(top_genes['heavy_gene'], rotation=45, ha='right')
    ax1.legend()
    ax1.grid(axis='y', alpha=0.3)
    
    # Add value labels on bars
    # (zero-height bars, including the 0-filled missing side, get no label)
    for i, (bar1, bar2) in enumerate(zip(bars1, bars2)):
        height1 = bar1.get_height()
        height2 = bar2.get_height()
        if height1 > 0:
            ax1.text(bar1.get_x() + bar1.get_width()/2., height1 + 1,
                    f'{height1:.1f}%', ha='center', va='bottom', fontsize=8)
        if height2 > 0:
            ax1.text(bar2.get_x() + bar2.get_width()/2., height2 + 1,
                    f'{height2:.1f}%', ha='center', va='bottom', fontsize=8)
    
    # Plot 2: Scatter plot with preference matching
    # x = true consistency, y = generated consistency; colour = same/different preference
    matching_genes = top_genes[top_genes['preferences_match']]
    mismatched_genes = top_genes[~top_genes['preferences_match']]
    
    ax2.scatter(matching_genes['consistency_pct_true'], matching_genes['consistency_pct_gen'], 
               color='green', s=100, alpha=0.7, label='Same Light Chain Preference')
    ax2.scatter(mismatched_genes['consistency_pct_true'], mismatched_genes['consistency_pct_gen'], 
               color='red', s=100, alpha=0.7, label='Different Light Chain Preference')
    
    # Add diagonal line for perfect match
    max_val = max(top_genes['consistency_pct_gen'].max(), top_genes['consistency_pct_true'].max())
    ax2.plot([0, max_val], [0, max_val], 'k--', alpha=0.5, label='Perfect Match')
    
    ax2.set_xlabel('True Consistency (%)')
    ax2.set_ylabel('Generated Consistency (%)')
    ax2.set_title('Consistency Correlation with Preference Matching')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # Add labels for interesting points
    # (any gene above 30% consistency on either side)
    for _, row in top_genes.iterrows():
        if row['consistency_pct_gen'] > 30 or row['consistency_pct_true'] > 30:
            ax2.annotate(row['heavy_gene'], 
                        (row['consistency_pct_true'], row['consistency_pct_gen']),
                        xytext=(5, 5), textcoords='offset points', fontsize=8)
    
    plt.tight_layout()
    plt.show()
    
    # Print detailed comparison
    print("\nDetailed Comparison (Top genes):")
    print("="*80)
    for _, row in top_genes.iterrows():
        gen_pct = row['consistency_pct_gen']
        true_pct = row['consistency_pct_true']
        gen_light = row['most_common_light_gen']
        true_light = row['most_common_light_true']
        match_status = "✓" if row['preferences_match'] else "✗"
        
        print(f"{row['heavy_gene']:<15} | Gen: {gen_pct:5.1f}% → {gen_light:<15} | "
              f"True: {true_pct:5.1f}% → {true_light:<15} | Match: {match_status}")
    
    return comparison

def create_top_genes_heatmap(gen_consistency_df, true_consistency_df, top_n=15):
    """Draw a two-column heatmap (Generated, True) of consistency for top genes.

    The gene list is the union of the ``top_n // 2`` most consistent heavy
    genes from each table. A gene absent from one table is shown as 0 in that
    column. Rows are ordered by the larger of their two values, highest
    first. Returns the plotted table, indexed by ``Heavy_Gene``.
    """
    per_side = top_n // 2

    # Union of the leaders from both tables, generated leaders first
    leaders = [
        table.nlargest(per_side, 'consistency_pct')['heavy_gene']
        for table in (gen_consistency_df, true_consistency_df)
    ]
    all_top_genes = pd.concat(leaders).unique()

    def first_pct(table, gene):
        # Consistency of the first row recorded for this gene, or 0 when absent
        hits = table[table['heavy_gene'] == gene]
        return hits['consistency_pct'].iloc[0] if len(hits) > 0 else 0

    records = [
        {
            'Heavy_Gene': gene,
            'Generated': first_pct(gen_consistency_df, gene),
            'True': first_pct(true_consistency_df, gene),
        }
        for gene in all_top_genes
    ]
    heatmap_df = pd.DataFrame(records).set_index('Heavy_Gene')

    # Order rows by the larger of the two values, via a temporary sort key
    heatmap_df = (
        heatmap_df
        .assign(max_consistency=heatmap_df.max(axis=1))
        .sort_values('max_consistency', ascending=False)
        .drop('max_consistency', axis=1)
    )

    plt.figure(figsize=(8, 12))
    sns.heatmap(heatmap_df, annot=True, fmt='.1f', cmap='RdYlBu_r',
                cbar_kws={'label': 'Consistency Percentage (%)'})
    plt.title('Top Heavy V Genes: Generated vs True Consistency')
    plt.xlabel('Dataset')
    plt.ylabel('Heavy V Gene')
    plt.tight_layout()
    plt.show()

    return heatmap_df

# Create the visualizations
print("\n" + "="*60, "CONSISTENCY COMPARISON VISUALIZATIONS", "="*60, sep="\n")

# Full-name (allele-specific) results feed both comparison figures
comparison_data = create_simple_consistency_comparison(
    gen_full_consistency, true_full_consistency)
heatmap_data = create_top_genes_heatmap(
    gen_full_consistency, true_full_consistency)

# Create visualizations
print("\n" + "="*50, "6. CREATING VISUALIZATIONS", "="*50, sep="\n")

# Create consistency heatmaps
def create_consistency_heatmap(heavy_col, light_col, title, max_genes=20):
    """Create a heatmap showing heavy→light gene associations.

    Rows are the ``max_genes`` most frequent values of ``heavy_col`` in the
    notebook-level ``df``; cells hold the percentage of that heavy gene's
    rows paired with each ``light_col`` value (each row sums to 100 before
    column filtering). Only cells of at least 30% carry a numeric label.

    Parameters
    ----------
    heavy_col : str
        Column of ``df`` used for the heatmap rows.
    light_col : str
        Column of ``df`` used for the heatmap columns.
    title : str
        First line of the figure title and label of the printed statistics.
    max_genes : int, default 20
        Maximum number of heavy genes (rows) to show.

    Returns
    -------
    pandas.DataFrame
        The percentage table that was plotted (empty when there is nothing
        to plot).
    """
    # Get top heavy genes by frequency
    top_heavy = df[heavy_col].value_counts().head(max_genes).index

    # Create crosstab (row-normalised, in percent)
    crosstab = pd.crosstab(df[heavy_col], df[light_col], normalize='index') * 100

    # A frequent heavy gene whose light values are all missing has no crosstab
    # row; indexing with it directly would raise KeyError.
    crosstab = crosstab.loc[[gene for gene in top_heavy if gene in crosstab.index]]

    if crosstab.empty:
        print(f"Heatmap for {title}: no heavy/light associations to plot.")
        print()
        return crosstab

    # Keep only columns with substantial associations (>10% for any heavy gene)
    col_mask = (crosstab > 10).any(axis=0)
    crosstab_filtered = crosstab.loc[:, col_mask]

    # If too few columns meet the threshold, relax it
    if crosstab_filtered.empty or crosstab_filtered.shape[1] < 5:
        # The top light gene of every heavy gene ...
        important_light_genes = set(crosstab.idxmax(axis=1))
        # ... plus any column with >5% association
        additional_light_genes = set(crosstab.columns[(crosstab > 5).any(axis=0)])
        all_important_light = important_light_genes | additional_light_genes

        # Select with a mask instead of list(set): set iteration order varies
        # between runs and would shuffle the heatmap columns.
        crosstab_filtered = crosstab.loc[:, crosstab.columns.isin(all_important_light)]

    # Create heatmap
    plt.figure(figsize=(16, 10))

    # Only label values >= 30% to avoid cluttering; other cells stay blank.
    # seaborn uses an annot array of matching shape as the cell text; fmt=''
    # is needed because the labels are already formatted strings.
    annot_labels = crosstab_filtered.apply(
        lambda column: column.map(lambda pct: f'{pct:.1f}' if pct >= 30 else '')
    )

    sns.heatmap(crosstab_filtered, annot=annot_labels, fmt='', cmap='RdYlBu_r',
                cbar_kws={'label': 'Percentage (%)'},
                annot_kws={'size': 8})

    plt.title(f'{title}\nHeavy V Gene → Light V Gene Associations\n(Values ≥30% are annotated)')
    plt.xlabel('Light V Gene')
    plt.ylabel('Heavy V Gene')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

    # Print some statistics about what's shown
    print(f"Heatmap for {title}:")
    print(f"  Heavy genes shown: {len(crosstab_filtered)}")
    print(f"  Light genes shown: {len(crosstab_filtered.columns)}")
    print(f"  Maximum association: {crosstab_filtered.max().max():.1f}%")
    print(f"  Associations ≥50%: {(crosstab_filtered >= 50).sum().sum()}")
    print(f"  Associations ≥30%: {(crosstab_filtered >= 30).sum().sum()}")
    print()

    return crosstab_filtered

# Create heatmaps for different analysis levels
print("\nCreating heatmaps...")

# Generated light chains: full name, simplified, family (drawn in this order)
print("Generated Light Chain Heatmaps:")
gen_full_heatmap, gen_simplified_heatmap, gen_family_heatmap = [
    create_consistency_heatmap(heavy_column, light_column, figure_title)
    for heavy_column, light_column, figure_title in (
        ('heavy_chain_gene_name', 'gen_light_gene_name',
         'Generated Light Chains - Full Names'),
        ('heavy_gene_simplified', 'gen_light_gene_simplified',
         'Generated Light Chains - Simplified'),
        ('heavy_gene_family', 'gen_light_gene_family',
         'Generated Light Chains - Family Level'),
    )
]

# True light chains: same three levels
print("\nTrue Light Chain Heatmaps:")
true_full_heatmap, true_simplified_heatmap, true_family_heatmap = [
    create_consistency_heatmap(heavy_column, light_column, figure_title)
    for heavy_column, light_column, figure_title in (
        ('heavy_chain_gene_name', 'true_light_gene_name',
         'True Light Chains - Full Names'),
        ('heavy_gene_simplified', 'true_light_gene_simplified',
         'True Light Chains - Simplified'),
        ('heavy_gene_family', 'true_light_gene_family',
         'True Light Chains - Family Level'),
    )
]

# Consistency distribution plots: a 3x2 grid, generated on the left and true
# on the right for the first two rows, combined comparisons in the last row.
fig, axes = plt.subplots(3, 2, figsize=(15, 18))

hist_style = {'bins': 20, 'alpha': 0.7, 'edgecolor': 'black'}
panel_sources = (
    ('Generated', gen_full_consistency, 'skyblue'),
    ('True', true_full_consistency, 'lightcoral'),
)

# Plots 1-2: consistency percentage distributions, with the mean marked
for panel, (source_label, source_table, fill_colour) in zip(axes[0], panel_sources):
    panel.hist(source_table['consistency_pct'], color=fill_colour, **hist_style)
    panel.set_title(f'{source_label}: Consistency Percentages\n(Full Names)')
    panel.set_xlabel('Consistency Percentage')
    panel.set_ylabel('Number of Heavy V Genes')
    panel.axvline(source_table['consistency_pct'].mean(), color='red', linestyle='--', label='Mean')
    panel.legend()

# Plots 3-4: number of unique light genes per heavy gene
for panel, (source_label, source_table, fill_colour) in zip(axes[1], panel_sources):
    panel.hist(source_table['unique_light_genes'], color=fill_colour, **hist_style)
    panel.set_title(f'{source_label}: Light Gene Diversity\n(Full Names)')
    panel.set_xlabel('Number of Unique Light Genes')
    panel.set_ylabel('Number of Heavy V Genes')

# Plot 5: Generated vs True consistency comparison (genes present on both sides)
scatter_panel = axes[2, 0]
valid_comparison = comparison_df.dropna(subset=['consistency_pct_gen', 'consistency_pct_true'])
scatter_panel.scatter(valid_comparison['consistency_pct_true'],
                      valid_comparison['consistency_pct_gen'], alpha=0.6)
scatter_panel.plot([0, 100], [0, 100], 'r--', label='Perfect Match')
scatter_panel.set_xlabel('True Consistency (%)')
scatter_panel.set_ylabel('Generated Consistency (%)')
scatter_panel.set_title('Generated vs True Consistency')
scatter_panel.legend()

# Plot 6: Consistency categories comparison
gen_category_counts = gen_full_consistency['consistency_category'].value_counts()
true_category_counts = true_full_consistency['consistency_category'].value_counts()

categories = ['Perfect (≥95%)', 'Very High (90-94%)', 'High (75-89%)', 'Moderate (50-74%)', 'Low (<50%)']
gen_counts = [gen_category_counts.get(cat, 0) for cat in categories]
true_counts = [true_category_counts.get(cat, 0) for cat in categories]

x = np.arange(len(categories))
width = 0.35

category_panel = axes[2, 1]
category_panel.bar(x - width/2, gen_counts, width, label='Generated', color='skyblue')
category_panel.bar(x + width/2, true_counts, width, label='True', color='lightcoral')
category_panel.set_xticks(x)
category_panel.set_xticklabels(categories, rotation=45, ha='right')
category_panel.set_title('Consistency Categories Comparison\n(Full Names)')
category_panel.set_ylabel('Number of Heavy V Genes')
category_panel.legend()

plt.tight_layout()
plt.show()

In [None]:
# Section banner for the visualization cell
print("\n" + "="*50, "6. CREATING VISUALIZATIONS", "="*50, sep="\n")

# Create consistency heatmaps
def create_consistency_heatmap(heavy_col, light_col, title, max_genes=20):
    """Create a heatmap showing heavy→light gene associations.

    Reads the notebook-level DataFrame ``df``. Rows are the ``max_genes`` most
    frequent values of ``heavy_col``; every cell is the percentage of that
    heavy gene's rows that carry the light gene in ``light_col``. Only cells
    with a value of at least 30% get a text annotation.

    Parameters:
    heavy_col (str): Column of ``df`` holding the heavy gene labels.
    light_col (str): Column of ``df`` holding the light gene labels.
    title (str): Prefix for the figure title and for the printed statistics.
    max_genes (int): Maximum number of heavy genes (rows) to show.

    Returns:
    pandas.DataFrame: The row-normalised percentage table that was plotted,
        or an empty table when there is nothing to plot.
    """

    # Get top heavy genes by frequency
    top_heavy = df[heavy_col].value_counts().head(max_genes).index

    # Row-normalised crosstab, expressed in percent
    crosstab = pd.crosstab(df[heavy_col], df[light_col], normalize='index') * 100

    # A frequent heavy gene can be absent from the crosstab (e.g. when all of
    # its light values are missing); selecting it with .loc would raise
    # KeyError, so only labels that are actually present are kept.
    present_heavy = [gene for gene in top_heavy if gene in crosstab.index]
    crosstab = crosstab.loc[present_heavy]

    if crosstab.empty:
        # Nothing to draw; seaborn cannot render an empty table.
        print(f"Heatmap for {title}: no heavy/light pairs available")
        print()
        return crosstab

    # Keep only columns with substantial associations (>10% for any heavy gene)
    col_mask = (crosstab > 10).any(axis=0)
    crosstab_filtered = crosstab.loc[:, col_mask]

    # If fewer than 5 columns meet the threshold, relax the selection
    if crosstab_filtered.shape[1] < 5:
        # Top light gene of every heavy gene ...
        important_light_genes = set(crosstab.idxmax(axis=1))
        # ... plus any column with >5% association
        col_mask_relaxed = (crosstab > 5).any(axis=0)
        important_light_genes.update(crosstab.columns[col_mask_relaxed])

        # Select in the crosstab's own column order so the plot is
        # reproducible (iterating a set gives an arbitrary order).
        keep_columns = [col for col in crosstab.columns if col in important_light_genes]
        crosstab_filtered = crosstab[keep_columns]

    # Create heatmap
    plt.figure(figsize=(16, 10))

    # Only show annotations for values >= 30% to avoid cluttering:
    # cells below the threshold get an empty label.
    annot_mask = crosstab_filtered >= 30
    annot_labels = crosstab_filtered.apply(
        lambda column: column.map('{:.1f}'.format)
    ).where(annot_mask, '')

    # fmt='' because the labels are already formatted strings
    sns.heatmap(crosstab_filtered, annot=annot_labels, fmt='', cmap='RdYlBu_r',
                cbar_kws={'label': 'Percentage (%)'},
                annot_kws={'size': 8})

    plt.title(f'{title}\nHeavy V Gene → Light V Gene Associations\n(Values ≥30% are annotated)')
    plt.xlabel('Light V Gene')
    plt.ylabel('Heavy V Gene')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

    # Print some statistics about what's shown
    print(f"Heatmap for {title}:")
    print(f"  Heavy genes shown: {len(crosstab_filtered)}")
    print(f"  Light genes shown: {len(crosstab_filtered.columns)}")
    print(f"  Maximum association: {crosstab_filtered.max().max():.1f}%")
    print(f"  Associations ≥50%: {(crosstab_filtered >= 50).sum().sum()}")
    print(f"  Associations ≥30%: {(crosstab_filtered >= 30).sum().sum()}")
    print()

    return crosstab_filtered

# Render one heatmap per naming level (full name, allele-stripped, family),
# first for the generated light chains and then for the true light chains.
print("\nCreating heatmaps...")

print("Generated Light Chain Heatmaps:")
gen_full_heatmap = create_consistency_heatmap(
    heavy_col='heavy_chain_gene_name',
    light_col='gen_light_gene_name',
    title='Generated Light Chains - Full Names',
)
gen_simplified_heatmap = create_consistency_heatmap(
    heavy_col='heavy_gene_simplified',
    light_col='gen_light_gene_simplified',
    title='Generated Light Chains - Simplified',
)
gen_family_heatmap = create_consistency_heatmap(
    heavy_col='heavy_gene_family',
    light_col='gen_light_gene_family',
    title='Generated Light Chains - Family Level',
)

print("\nTrue Light Chain Heatmaps:")
true_full_heatmap = create_consistency_heatmap(
    heavy_col='heavy_chain_gene_name',
    light_col='true_light_gene_name',
    title='True Light Chains - Full Names',
)
true_simplified_heatmap = create_consistency_heatmap(
    heavy_col='heavy_gene_simplified',
    light_col='true_light_gene_simplified',
    title='True Light Chains - Simplified',
)
true_family_heatmap = create_consistency_heatmap(
    heavy_col='heavy_gene_family',
    light_col='true_light_gene_family',
    title='True Light Chains - Family Level',
)

# Consistency distribution plots: 3 rows x 2 columns of panels
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(15, 18))

# Panel [0, 0]: histogram of generated consistency percentages, mean marked
axes[0, 0].hist(gen_full_consistency['consistency_pct'], bins=20,
                color='skyblue', edgecolor='black', alpha=0.7)
axes[0, 0].set(title='Generated: Consistency Percentages\n(Full Names)',
               xlabel='Consistency Percentage',
               ylabel='Number of Heavy V Genes')
axes[0, 0].axvline(gen_full_consistency['consistency_pct'].mean(),
                   linestyle='--', color='red', label='Mean')
axes[0, 0].legend()

# Panel [0, 1]: histogram of true consistency percentages, mean marked
axes[0, 1].hist(true_full_consistency['consistency_pct'], bins=20,
                color='lightcoral', edgecolor='black', alpha=0.7)
axes[0, 1].set(title='True: Consistency Percentages\n(Full Names)',
               xlabel='Consistency Percentage',
               ylabel='Number of Heavy V Genes')
axes[0, 1].axvline(true_full_consistency['consistency_pct'].mean(),
                   linestyle='--', color='red', label='Mean')
axes[0, 1].legend()

# Panel [1, 0]: number of distinct generated light genes per heavy gene
axes[1, 0].hist(gen_full_consistency['unique_light_genes'], bins=20,
                color='skyblue', edgecolor='black', alpha=0.7)
axes[1, 0].set(title='Generated: Light Gene Diversity\n(Full Names)',
               xlabel='Number of Unique Light Genes',
               ylabel='Number of Heavy V Genes')

# Panel [1, 1]: number of distinct true light genes per heavy gene
axes[1, 1].hist(true_full_consistency['unique_light_genes'], bins=20,
                color='lightcoral', edgecolor='black', alpha=0.7)
axes[1, 1].set(title='True: Light Gene Diversity\n(Full Names)',
               xlabel='Number of Unique Light Genes',
               ylabel='Number of Heavy V Genes')

# Panel [2, 0]: generated vs true consistency, rows with both values present
valid_comparison = comparison_df.dropna(subset=['consistency_pct_gen', 'consistency_pct_true'])
axes[2, 0].scatter(valid_comparison['consistency_pct_true'],
                   valid_comparison['consistency_pct_gen'],
                   alpha=0.6)
axes[2, 0].plot([0, 100], [0, 100], 'r--', label='Perfect Match')  # y = x reference line
axes[2, 0].set(xlabel='True Consistency (%)',
               ylabel='Generated Consistency (%)',
               title='Generated vs True Consistency')
axes[2, 0].legend()

# Panel [2, 1]: side-by-side bars of consistency category counts
gen_category_counts = gen_full_consistency['consistency_category'].value_counts()
true_category_counts = true_full_consistency['consistency_category'].value_counts()

categories = ['Perfect (≥95%)', 'Very High (90-94%)', 'High (75-89%)', 'Moderate (50-74%)', 'Low (<50%)']
# Categories that never occur count as 0
gen_counts = [gen_category_counts.get(label, 0) for label in categories]
true_counts = [true_category_counts.get(label, 0) for label in categories]

x = np.arange(len(categories))
width = 0.35

axes[2, 1].bar(x - width / 2, gen_counts, width, color='skyblue', label='Generated')
axes[2, 1].bar(x + width / 2, true_counts, width, color='lightcoral', label='True')
axes[2, 1].set_xticks(x)
axes[2, 1].set_xticklabels(categories, rotation=45, ha='right')
axes[2, 1].set(title='Consistency Categories Comparison\n(Full Names)',
               ylabel='Number of Heavy V Genes')
axes[2, 1].legend()

plt.tight_layout()
plt.show()


In [None]:


# Summary report: prints headline numbers from the consistency tables and the
# generated-vs-true comparison values computed in earlier cells.
print("\n" + "=" * 80)
print("SUMMARY REPORT")
print("=" * 80)

print("\nKey Findings:")
print("1. V Gene Pairing Consistency:")
print(f"   - Average consistency (full names): {gen_full_consistency['consistency_pct'].mean():.1f}%")
print(f"   - Average consistency (simplified): {gen_simplified_consistency['consistency_pct'].mean():.1f}%")
print(f"   - Average consistency (family): {gen_family_consistency['consistency_pct'].mean():.1f}%")

# Heavy genes whose consistency is at least 75%
high_consistency = gen_full_consistency['consistency_pct'].ge(75).sum()
total_genes = gen_full_consistency.shape[0]
print(f"2. High Consistency Genes: {high_consistency}/{total_genes} ({high_consistency/total_genes*100:.1f}%)")

print("3. Generated vs True Comparison:")
print(f"   - Matching preferences: {matches}/{total} ({matches/total*100:.1f}%)")
print(f"   - Generated consistency: {gen_avg:.1f}%")
print(f"   - True consistency: {true_avg:.1f}%")

print("\n4. Most Consistent Heavy V Genes (Generated):")
# First five rows of the table as it is currently ordered
top_5 = gen_full_consistency.iloc[:5]
for _, row in top_5.iterrows():
    print(f"   {row['heavy_gene']}: {row['consistency_pct']:.1f}% → {row['most_common_light']}")

print("\n5. Most Diverse Heavy V Genes (Generated):")
diverse_5 = gen_full_consistency.nlargest(5, 'unique_light_genes')
for _, row in diverse_5.iterrows():
    print(f"   {row['heavy_gene']}: {row['unique_light_genes']} different light genes")

print(f"\nAnalysis complete! The model shows {'high' if gen_avg >= 75 else 'moderate' if gen_avg >= 50 else 'low'} overall consistency in V gene pairing.")
print(f"Generated sequences {'match' if matches/total > 0.5 else 'differ from'} true pairing preferences in most cases.")

In [None]:
import pandas as pd
import numpy as np
from collections import Counter, defaultdict


def simplify_gene_name(gene_name):
    """Return the gene name without its allele suffix (text from the first '*' on).

    Missing values (as judged by ``pd.isna``) are returned unchanged.
    """
    if pd.isna(gene_name):
        return gene_name
    base_name, _, _ = str(gene_name).partition('*')
    return base_name

def extract_gene_family(gene_name):
    """Return the gene family, e.g. 'IGHV3' for 'IGHV3-23*01'.

    The allele suffix (from the first '*') is dropped, then everything from
    the first '-' on. Missing values (``pd.isna``) are returned unchanged.
    """
    if pd.isna(gene_name):
        return gene_name
    without_allele = str(gene_name).partition('*')[0]
    # partition leaves the text untouched when there is no dash
    return without_allele.partition('-')[0]

def create_test_data():
    """Build a small synthetic pairing table with known consistency patterns.

    Five heavy chains with ten generated light chains each (50 rows). The
    expected pattern of every heavy chain is printed for manual comparison.

    Returns:
    pandas.DataFrame: Columns 'heavy_chain_number', 'gen_light_chain_number',
        'heavy_chain_gene_name', 'true_light_gene_name', 'gen_light_gene_name'.
    """

    print("Creating test data...")

    # One entry per heavy chain:
    # (heavy chain number, heavy V gene, true light V gene, generated light V genes in order)
    scenarios = [
        # 8/10 IGKV1-39*01; true light differs from the most common generated gene
        (1, "IGHV3-23*01", "IGKV1-5*03",
         ["IGKV1-39*01"] * 8 + ["IGKV2-28*01"] * 2),
        # 10/10 IGLV2-14*01; true light is the same gene
        (2, "IGHV4-34*01", "IGLV2-14*01",
         ["IGLV2-14*01"] * 10),
        # five different genes, two of each: no clear preference
        (3, "IGHV1-2*02", "IGKV3-20*01",
         ["IGKV1-39*01", "IGKV2-28*01", "IGLV1-44*01", "IGLV2-14*01", "IGKV3-11*01"] * 2),
        # different allele of heavy chain 1's gene; 7/10 IGKV1-39*01, which is also the true light
        (4, "IGHV3-23*02", "IGKV1-39*01",
         ["IGKV1-39*01"] * 7 + ["IGKV2-28*01"] * 2 + ["IGLV1-44*01"]),
        # 6/10 IGLV3-19*01, which is also the true light
        (5, "IGHV5-51*01", "IGLV3-19*01",
         ["IGLV3-19*01"] * 6 + ["IGKV2-28*01"] * 4),
    ]

    # Flatten to one record per generated light chain (numbered from 1)
    records = [
        {
            "heavy_chain_number": heavy_number,
            "gen_light_chain_number": position,
            "heavy_chain_gene_name": heavy_gene,
            "true_light_gene_name": true_light,
            "gen_light_gene_name": generated_gene,
        }
        for heavy_number, heavy_gene, true_light, generated in scenarios
        for position, generated_gene in enumerate(generated, start=1)
    ]

    df = pd.DataFrame(records)

    print("Test data created:")
    print(f"Total rows: {len(df)}")
    print(f"Heavy chains: {df['heavy_chain_number'].nunique()}")
    print("Expected patterns:")
    print("  Heavy 1 (IGHV3-23*01): 80% IGKV1-39*01, True: IGKV1-5*03")
    print("  Heavy 2 (IGHV4-34*01): 100% IGLV2-14*01, True: IGLV2-14*01 (perfect match)")
    print("  Heavy 3 (IGHV1-2*02): Diverse (20% each), True: IGKV3-20*01")
    print("  Heavy 4 (IGHV3-23*02): 70% IGKV1-39*01, True: IGKV1-39*01 (matches most common)")
    print("  Heavy 5 (IGHV5-51*01): 60% IGLV3-19*01, True: IGLV3-19*01 (matches most common)")

    return df

def analyze_vgene_consistency_numbers_only(df):
    """
    Analyze V gene consistency - numbers only, no plots.

    Generated light V genes are pooled per heavy V gene. For every heavy gene
    the share of its most common generated light gene ("consistency") and the
    share of its true light gene are computed at three naming levels: full
    names, allele-stripped names and gene families. All tables and summaries
    are printed.

    Parameters:
    df (pandas.DataFrame): Needs the columns 'heavy_chain_number',
        'heavy_chain_gene_name', 'gen_light_gene_name' and
        'true_light_gene_name'.

    Returns:
    dict: Keys 'full_results', 'simple_results' and 'family_results'; each is
        a list of per-heavy-gene dicts sorted by descending 'most_common_pct'.
        Every dict carries 'n_heavy_chains', the number of heavy chains that
        contributed at least one generated light gene.
    """

    print("\n" + "="*80)
    print("V GENE CONSISTENCY ANALYSIS")
    print("="*80)

    # ============================
    # HANDLE DUPLICATES PROPERLY
    # ============================

    # Collapse to one row per heavy chain to count chains and true light chains
    heavy_chains = df.groupby('heavy_chain_number').first()

    unique_heavy_chains = len(heavy_chains)
    total_generated_sequences = len(df)

    # Get the true light chains (one per heavy chain)
    true_light_chains = heavy_chains['true_light_gene_name'].dropna()

    print(f"Total heavy chains: {unique_heavy_chains}")
    print(f"Total generated sequences: {total_generated_sequences}")
    print(f"Total true light chains: {len(true_light_chains)} (should equal heavy chains)")

    if len(true_light_chains) != unique_heavy_chains:
        print("⚠️  Warning: Number of true light chains doesn't match number of heavy chains")
        print(f"   Expected: {unique_heavy_chains}, Got: {len(true_light_chains)}")
    else:
        print("✅ Correct: One true light chain per heavy chain")

    # ============================
    # COLLECT V GENE ASSOCIATIONS
    # ============================

    heavy_gen_associations = defaultdict(list)  # heavy V gene -> generated light V genes
    heavy_true_mapping = {}                     # heavy V gene -> true light V gene
    heavy_chain_counts = Counter()              # heavy V gene -> contributing heavy chains

    for _, group in df.groupby('heavy_chain_number'):
        # The first row stands for the whole group (values are expected to
        # be identical within one heavy chain).
        heavy_gene = group['heavy_chain_gene_name'].iloc[0]
        true_light_gene = group['true_light_gene_name'].iloc[0]

        if pd.isna(heavy_gene) or pd.isna(true_light_gene):
            continue

        gen_light_genes = group['gen_light_gene_name'].dropna().tolist()

        heavy_gen_associations[heavy_gene].extend(gen_light_genes)
        # NOTE(review): when several heavy chains share a heavy V gene, the
        # true light gene of the last chain processed overwrites earlier ones.
        heavy_true_mapping[heavy_gene] = true_light_gene
        # Count the chain itself rather than inferring it from a fixed number
        # of generated sequences per chain.
        if gen_light_genes:
            heavy_chain_counts[heavy_gene] += 1

    print(f"Unique heavy V genes: {len(heavy_gen_associations)}")

    # Verify the data structure
    total_gen_collected = sum(len(genes) for genes in heavy_gen_associations.values())
    print(f"Total generated sequences collected: {total_gen_collected}")

    if total_gen_collected != total_generated_sequences:
        print(f"⚠️  Warning: Collected sequences ({total_gen_collected}) don't match total ({total_generated_sequences})")

    # ============================
    # ANALYZE CONSISTENCY AT DIFFERENT LEVELS
    # ============================

    # (minimum percentage, label), checked from the strictest down; below 50 is "Low"
    level_thresholds = ((95, "Perfect"), (90, "Very High"), (75, "High"), (50, "Moderate"))

    def analyze_level(associations, true_mapping, level_name, transform_func, chain_counts):
        """Print and return the consistency table for one naming level.

        ``transform_func`` maps a gene name to its label at this level; heavy
        genes that share a label are merged. ``chain_counts`` maps each
        original heavy gene to its number of contributing heavy chains.
        """

        print(f"\n{level_name} LEVEL ANALYSIS:")
        print("-" * 50)

        transformed_associations = defaultdict(list)
        transformed_true_mapping = {}
        transformed_chain_counts = Counter()

        for heavy_gene, light_genes in associations.items():
            transformed_heavy = transform_func(heavy_gene)
            transformed_true = transform_func(true_mapping[heavy_gene])

            # Heavy genes that collapse to the same label pool their
            # generated light genes and their heavy chain counts.
            transformed_associations[transformed_heavy].extend(transform_func(lg) for lg in light_genes)
            transformed_chain_counts[transformed_heavy] += chain_counts[heavy_gene]

            # For true mapping, if there's a conflict, keep the first one
            transformed_true_mapping.setdefault(transformed_heavy, transformed_true)

        # Calculate consistency for each heavy gene
        consistency_results = []

        for heavy_gene, light_genes in transformed_associations.items():
            if not light_genes:
                continue

            light_counts = Counter(light_genes)
            total_generated = len(light_genes)

            # Most common generated light gene and its share
            most_common_gene, most_common_count = light_counts.most_common(1)[0]
            most_common_pct = (most_common_count / total_generated) * 100

            # How often the true light gene itself was generated
            true_light_gene = transformed_true_mapping[heavy_gene]
            true_generated_count = light_counts.get(true_light_gene, 0)
            true_generated_pct = (true_generated_count / total_generated) * 100

            consistency_level = next(
                (label for cutoff, label in level_thresholds if most_common_pct >= cutoff),
                "Low",
            )

            consistency_results.append({
                'heavy_gene': heavy_gene,
                'total_generated': total_generated,
                'n_heavy_chains': transformed_chain_counts[heavy_gene],
                'unique_light_genes': len(light_counts),
                'most_common_light': most_common_gene,
                'most_common_count': most_common_count,
                'most_common_pct': most_common_pct,
                'true_light_gene': true_light_gene,
                'true_generated_count': true_generated_count,
                'true_generated_pct': true_generated_pct,
                'true_matches_most_common': true_light_gene == most_common_gene,
                'consistency_level': consistency_level
            })

        # Sort by consistency percentage
        consistency_results.sort(key=lambda x: x['most_common_pct'], reverse=True)

        # Print results
        print(f"{'Heavy Gene':<20} {'N':<5} {'Most Common Generated':<25} {'%':<6} {'True Light':<25} {'True Gen %':<10} {'Match':<6} {'Level':<10}")
        print("-" * 120)

        for result in consistency_results:
            match_symbol = "✓" if result['true_matches_most_common'] else "✗"
            # N is the number of heavy chains behind this row (sample size)
            n_heavy_chains = result['n_heavy_chains']
            print(f"{result['heavy_gene']:<20} {n_heavy_chains:<5} {result['most_common_light']:<25} "
                  f"{result['most_common_pct']:<6.1f} {result['true_light_gene']:<25} "
                  f"{result['true_generated_pct']:<10.1f} {match_symbol:<6} {result['consistency_level']:<10}")

        # Summary statistics with sample size info
        if consistency_results:
            total_heavy_genes = len(consistency_results)
            level_tally = Counter(r['consistency_level'] for r in consistency_results)
            perfect = level_tally['Perfect']
            very_high = level_tally['Very High']
            high = level_tally['High']
            moderate = level_tally['Moderate']
            low = level_tally['Low']

            true_matches = sum(1 for r in consistency_results if r['true_matches_most_common'])

            # Sample size statistics
            sample_sizes = [r['n_heavy_chains'] for r in consistency_results]
            min_sample = min(sample_sizes)
            max_sample = max(sample_sizes)
            avg_sample = sum(sample_sizes) / len(sample_sizes)

            print(f"\nSUMMARY ({total_heavy_genes} heavy V genes):")
            print(f"  Sample sizes: Min={min_sample}, Max={max_sample}, Avg={avg_sample:.1f} heavy chains per gene")
            print(f"  Perfect (≥95%): {perfect} ({perfect/total_heavy_genes*100:.1f}%)")
            print(f"  Very High (90-94%): {very_high} ({very_high/total_heavy_genes*100:.1f}%)")
            print(f"  High (75-89%): {high} ({high/total_heavy_genes*100:.1f}%)")
            print(f"  Moderate (50-74%): {moderate} ({moderate/total_heavy_genes*100:.1f}%)")
            print(f"  Low (<50%): {low} ({low/total_heavy_genes*100:.1f}%)")
            print(f"  True matches most common: {true_matches}/{total_heavy_genes} ({true_matches/total_heavy_genes*100:.1f}%)")

            # Unweighted averages over heavy genes
            avg_consistency = sum(r['most_common_pct'] for r in consistency_results) / total_heavy_genes
            avg_true_generation = sum(r['true_generated_pct'] for r in consistency_results) / total_heavy_genes

            print(f"  Average consistency: {avg_consistency:.1f}%")
            print(f"  Average true light generation: {avg_true_generation:.1f}%")

            # Show genes with largest sample sizes
            top_samples = sorted(consistency_results, key=lambda x: x['total_generated'], reverse=True)[:10]
            print("\nTop 10 genes by sample size:")
            for result in top_samples:
                n_chains = result['n_heavy_chains']
                print(f"  {result['heavy_gene']:<20} {n_chains} heavy chains ({result['consistency_level']}, {result['most_common_pct']:.1f}%)")

        return consistency_results

    # Analyze at all levels
    full_results = analyze_level(heavy_gen_associations, heavy_true_mapping, "FULL NAMES",
                                 lambda x: x, heavy_chain_counts)
    simple_results = analyze_level(heavy_gen_associations, heavy_true_mapping, "SIMPLIFIED",
                                   simplify_gene_name, heavy_chain_counts)
    family_results = analyze_level(heavy_gen_associations, heavy_true_mapping, "GENE FAMILIES",
                                   extract_gene_family, heavy_chain_counts)

    return {
        'full_results': full_results,
        'simple_results': simple_results,
        'family_results': family_results
    }

def analyze_csv_file(csv_file):
    """Load a CSV file and run the V gene consistency analysis on it.

    Parameters:
    csv_file (str): Path to the CSV file.

    Returns:
    dict or None: The result of ``analyze_vgene_consistency_numbers_only``,
        or None when a required column is missing from the file.
    """

    print("Reading CSV file...")
    df = pd.read_csv(csv_file)

    print("Dataset loaded:")
    print(f"Total rows: {len(df)}")
    # Guarded: a file without this column must reach the missing-column
    # report below instead of failing here with KeyError.
    if 'heavy_chain_number' in df.columns:
        print(f"Heavy chains: {df['heavy_chain_number'].nunique()}")
    print(f"Columns: {df.columns.tolist()}")

    # Check required columns
    required_cols = ['heavy_chain_number', 'heavy_chain_gene_name', 'gen_light_gene_name', 'true_light_gene_name']
    missing_cols = [col for col in required_cols if col not in df.columns]

    if missing_cols:
        print(f"❌ Missing required columns: {missing_cols}")
        return None

    print("✅ All required columns found")

    return analyze_vgene_consistency_numbers_only(df)




In [None]:

# Driver cell: runs the consistency analysis on the real prediction CSV.
#
# Optional smoke test on the synthetic data (disabled). Uncomment to compare
# the printed tables against the expected patterns listed by create_test_data().
# print("TESTING WITH SAMPLE DATA")
# print("="*50)

# test_df = create_test_data()
# test_results = analyze_vgene_consistency_numbers_only(test_df)

# print("\n" + "="*80)
# print("TEST COMPLETED - Check if results match expected patterns above")
# print("="*80)


# Hard-coded input path; edit it to point at the CSV to analyse.
csv_file = '/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/full_eval_generate_multiple_light_seqs_203276_cls_predictions_merged_genes.csv'  


# analyze_csv_file returns None when a required column is missing, in which
# case the success message is skipped.
real_results = analyze_csv_file(csv_file)
if real_results:
    print("\nReal data analysis completed successfully!")
        


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import re

def load_and_prepare_data(csv_file):
    """Read ``csv_file`` into a DataFrame, print its size, and return it.

    The file must contain a 'heavy_chain_number' column; its number of
    distinct values is part of the printed overview.
    """
    frame = pd.read_csv(csv_file)

    print(f"Dataset shape: {frame.shape}")
    heavy_ids = frame['heavy_chain_number']
    print(f"Unique heavy chains: {heavy_ids.nunique()}")
    print(f"Total rows: {frame.shape[0]}")

    return frame

def create_groups(df):
    """Split ``df`` into four groups by the predicted heavy / generated-light labels.

    Label 1 is treated as memory B cell origin and label 0 as naive. Rows
    whose labels are neither 0 nor 1 fall into no group. Group sizes are printed.

    Returns:
    dict: Group title -> DataFrame subset, in the order Group 1 to Group 4.
    """
    light_label = df['predicted_gen_light_seq_label']
    heavy_label = df['predicted_input_heavy_seq_label']

    # One boolean mask per chain and predicted origin
    light_memory, light_naive = light_label == 1, light_label == 0
    heavy_memory, heavy_naive = heavy_label == 1, heavy_label == 0

    groups = {
        'Group 1 (H:Memory, L:Memory)': df[heavy_memory & light_memory],
        'Group 2 (H:Naive, L:Naive)': df[heavy_naive & light_naive],
        'Group 3 (H:Memory, L:Naive)': df[heavy_memory & light_naive],
        'Group 4 (H:Naive, L:Memory)': df[heavy_naive & light_memory],
    }

    print("\nGroup sizes:")
    for label, subset in groups.items():
        n_rows = len(subset)
        print(f"{label}: {n_rows} rows")
        if n_rows > 0:
            print(f"  Unique heavy chains: {subset['heavy_chain_number'].nunique()}")

    return groups

def extract_v_gene_info(gene_name):
    """Return the gene name at three levels of detail.

    Returns:
    dict: 'full' (the name as given), 'simplified' (every '*<digits>' allele
        tag removed) and 'family' (a leading 'IG[HKL]V<digits>' prefix, or
        the text before the first '-' when the name has no such prefix).
        All three are 'Unknown' for a missing or empty name.
    """
    if pd.isna(gene_name) or gene_name == '':
        return dict.fromkeys(('full', 'simplified', 'family'), 'Unknown')

    # Strip allele tags such as '*01'
    allele_free = re.sub(r'\*\d+', '', gene_name)

    # Family prefix anchored at the start of the name
    family_hit = re.match(r'(IG[HKL]V\d+)', gene_name)
    if family_hit:
        family = family_hit.group(1)
    else:
        family = gene_name.split('-')[0]

    return {'full': gene_name, 'simplified': allele_free, 'family': family}

def create_heavy_light_pairing_matrix(group_df, gene_level='full'):
    """Cross-tabulate heavy V genes against generated light V genes.

    Parameters:
    group_df (pandas.DataFrame): Rows with 'heavy_chain_number',
        'heavy_chain_gene_name' and 'gen_light_gene_name'.
    gene_level (str): Key of ``extract_v_gene_info``'s result to use
        ('full', 'simplified' or 'family').

    Returns:
    tuple: (count matrix, row-percentage matrix), heavy genes as rows and
        light genes as columns; two empty DataFrames for an empty input.
    """

    if len(group_df) == 0:
        return pd.DataFrame(), pd.DataFrame()

    # One record per input row, with both gene names reduced to the requested level
    pairing_df = pd.DataFrame({
        'heavy_chain_number': group_df['heavy_chain_number'].tolist(),
        'heavy_gene': [
            extract_v_gene_info(name)[gene_level]
            for name in group_df['heavy_chain_gene_name']
        ],
        'light_gene': [
            extract_v_gene_info(name)[gene_level]
            for name in group_df['gen_light_gene_name']
        ],
    })

    # Co-occurrence counts, no margin totals
    cooccurrence_matrix = pd.crosstab(
        pairing_df['heavy_gene'],
        pairing_df['light_gene'],
        margins=False
    )

    # Share of each heavy gene's sequences per light gene, in percent;
    # undefined shares become 0
    row_totals = cooccurrence_matrix.sum(axis=1)
    percentage_matrix = (cooccurrence_matrix.div(row_totals, axis=0) * 100).fillna(0)

    # Report the ten most frequent heavy genes and how many heavy chains back them
    heavy_gene_counts = pairing_df['heavy_gene'].value_counts()
    chains_per_gene = pairing_df.groupby('heavy_gene')['heavy_chain_number'].nunique()

    print("\nSample sizes for each heavy gene (total generated sequences):")
    for gene, count in heavy_gene_counts.head(10).items():
        print(f"  {gene}: {count} total sequences from {chains_per_gene[gene]} heavy chains")

    return cooccurrence_matrix, percentage_matrix

def get_top_genes(matrix, top_n=15, axis='both'):
    """Return the labels with the largest totals in ``matrix``.

    Parameters:
    matrix (pandas.DataFrame): Heavy genes as rows, light genes as columns.
    top_n (int): Number of labels to keep per axis.
    axis (str): 'heavy' (row totals), 'light' (column totals) or 'both'.

    Returns:
    list or tuple: A list of labels for 'heavy' / 'light', a
        (heavy list, light list) tuple for 'both', None for any other value.
    """
    def ranked_labels(sum_axis):
        # sum_axis=1 totals each row (heavy gene), sum_axis=0 each column (light gene)
        return matrix.sum(axis=sum_axis).nlargest(top_n).index.tolist()

    if axis == 'heavy':
        return ranked_labels(1)
    if axis == 'light':
        return ranked_labels(0)
    if axis == 'both':
        return ranked_labels(1), ranked_labels(0)
    return None

def create_pairing_heatmap(percentage_matrix, count_matrix, title, top_heavy=None, top_light=None, figsize=(14, 10)):
    """Draw the heavy/light V gene pairing heatmap and return its figure.

    Cells are coloured by percentage and labelled '<pct>%' over '(<count>)'.
    ``top_heavy`` / ``top_light`` optionally restrict and order the rows /
    columns; labels not present in the matrix are ignored, and a filter that
    matches nothing is skipped. An empty ``percentage_matrix`` yields a
    placeholder figure reading 'No data available'.
    """

    if percentage_matrix.empty:
        placeholder_fig, placeholder_ax = plt.subplots(figsize=figsize)
        placeholder_ax.text(0.5, 0.5, 'No data available', ha='center', va='center',
                            transform=placeholder_ax.transAxes, fontsize=16)
        placeholder_ax.set_title(title)
        return placeholder_fig

    # Work on copies so the caller's matrices stay untouched
    shown_pct = percentage_matrix.copy()
    shown_counts = count_matrix.copy()

    # Optional row filter
    if top_heavy:
        row_labels = [gene for gene in top_heavy if gene in shown_pct.index]
        if row_labels:
            shown_pct = shown_pct.loc[row_labels]
            shown_counts = shown_counts.loc[row_labels]

    # Optional column filter
    if top_light:
        column_labels = [gene for gene in top_light if gene in shown_pct.columns]
        if column_labels:
            shown_pct = shown_pct[column_labels]
            shown_counts = shown_counts[column_labels]

    # Cell text: percentage on the first line, raw count in parentheses below
    pct_text = shown_pct.round(1).astype(str)
    count_text = shown_counts.astype(str)
    cell_labels = pct_text + '%\n(' + count_text + ')'

    fig, ax = plt.subplots(figsize=figsize)

    # fmt='' because the labels are already strings
    sns.heatmap(shown_pct,
                ax=ax,
                annot=cell_labels,
                fmt='',
                cmap='YlOrRd',
                linewidths=0.5,
                cbar_kws={'label': 'Percentage of Heavy Gene Usage'})

    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    ax.set_xlabel('Generated Light Chain V Gene', fontsize=12)
    ax.set_ylabel('Heavy Chain V Gene', fontsize=12)

    # Rotate labels for better readability
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)

    # Summary box in the top-left corner of the axes
    total_sequences = shown_counts.sum().sum()
    n_heavy = len(shown_pct.index)
    n_light = len(shown_pct.columns)
    sample_info = (
        f"Total sequences: {total_sequences}\n"
        f"Heavy genes: {n_heavy}, Light genes: {n_light}"
    )
    ax.text(0.02, 0.98, sample_info, transform=ax.transAxes,
            verticalalignment='top', fontsize=10,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()
    return fig

def analyze_pairing_preferences(percentage_matrix, count_matrix, gene_level):
    """
    Print a text summary of heavy-to-light pairing preferences.

    Two reports are written to stdout:
    1. Heavy genes whose most frequent light partner exceeds 50% of the
       heavy gene's row, ordered by that percentage (highest first).
    2. The five heavy genes with the most light partners above 5%.

    Parameters:
    percentage_matrix (DataFrame): Row percentages, heavy genes as index and
        light genes as columns.
    count_matrix (DataFrame): Raw counts carrying the same labels.
    gene_level (str): Text inserted into the report header.
    """
    print(f"\n=== Pairing Analysis for {gene_level} level ===")

    # Report 1: heavy genes dominated by a single light gene
    dominant = []
    for heavy in percentage_matrix.index:
        pct_row = percentage_matrix.loc[heavy]
        top_pct = pct_row.max()
        if not top_pct > 50:
            continue
        partner = pct_row.idxmax()
        dominant.append({
            'heavy_gene': heavy,
            'preferred_light': partner,
            'percentage': top_pct,
            'count': count_matrix.loc[heavy, partner],
            'total_sequences': count_matrix.loc[heavy].sum(),
        })

    if not dominant:
        print("No heavy genes show strong preferences (>50%) for specific light genes")
    else:
        print("Heavy genes with strong light gene preferences (>50%):")
        dominant.sort(key=lambda entry: entry['percentage'], reverse=True)
        for entry in dominant:
            print(f"  {entry['heavy_gene']} → {entry['preferred_light']}: "
                  f"{entry['percentage']:.1f}% ({entry['count']}/{entry['total_sequences']} sequences)")

    # Report 2: heavy genes spread over many light genes (each above 5%)
    breadth = [
        {
            'heavy_gene': heavy,
            'num_partners': (percentage_matrix.loc[heavy] > 5).sum(),
            'total_sequences': count_matrix.loc[heavy].sum(),
        }
        for heavy in percentage_matrix.index
    ]
    breadth.sort(key=lambda entry: entry['num_partners'], reverse=True)

    print("\nMost diverse heavy genes (pairing with multiple light genes >5%):")
    for entry in breadth[:5]:
        print(f"  {entry['heavy_gene']}: pairs with {entry['num_partners']} light genes "
              f"({entry['total_sequences']} total sequences)")

def create_all_pairing_heatmaps(groups):
    """
    Build an "all genes" and a "top 15 genes" pairing heatmap for every
    group at every gene level.

    Parameters:
    groups (dict): Maps a group name to that group's DataFrame.

    Returns:
    dict: {group name: {'<level>_all': Figure, '<level>_top': Figure}}.
        Groups without rows are omitted; levels whose count matrix is empty
        are skipped.
    """
    level_labels = (
        ('full', 'Full V Gene'),
        ('simplified', 'Simplified V Gene'),
        ('family', 'V Gene Family'),
    )
    banner = '=' * 50
    figures_by_group = {}

    for group_name, group_df in groups.items():
        print(f"\n{banner}")
        print(f"Processing {group_name}...")
        print(banner)

        if len(group_df) == 0:
            print(f"No data for {group_name}")
            continue

        per_level = {}
        for level, label in level_labels:
            print(f"\n--- {label} ---")

            counts, percentages = create_heavy_light_pairing_matrix(group_df, level)
            if counts.empty:
                continue

            analyze_pairing_preferences(percentages, counts, label)
            heavy_top, light_top = get_top_genes(counts, top_n=15)

            # The full heatmap grows with the number of genes on each axis
            full_size = (max(14, len(percentages.columns) * 0.8),
                         max(10, len(percentages.index) * 0.5))
            per_level[f"{level}_all"] = create_pairing_heatmap(
                percentages, counts,
                f"{group_name} - {label}\nHeavy-Light V Gene Pairing (All Genes)",
                figsize=full_size
            )

            # The reduced heatmap keeps only the most frequent genes
            per_level[f"{level}_top"] = create_pairing_heatmap(
                percentages, counts,
                f"{group_name} - {label}\nHeavy-Light V Gene Pairing (Top 15 Genes)",
                top_heavy=heavy_top,
                top_light=light_top,
                figsize=(14, 10)
            )

        figures_by_group[group_name] = per_level

    return figures_by_group



In [None]:

def main(csv_file):
    """
    Run the heavy-light V gene pairing analysis from CSV to displayed figures.

    Parameters:
    csv_file (str): Path to the CSV file

    Returns:
    dict: The nested {group name: {figure key: Figure}} mapping produced by
        create_all_pairing_heatmaps.
    """
    wide_rule = '=' * 60

    print("Heavy-Light V Gene Pairing Analysis")
    print('=' * 50)

    # Load -> split into groups -> build every heatmap
    groups = create_groups(load_and_prepare_data(csv_file))
    all_figures = create_all_pairing_heatmaps(groups)

    # Show each figure, group by group
    for group_name, group_figures in all_figures.items():
        print(f"\n{wide_rule}")
        print(f"DISPLAYING HEATMAPS FOR {group_name}")
        print(wide_rule)
        for fig in group_figures.values():
            plt.figure(fig.number)
            plt.show()

    heatmap_total = sum(map(len, all_figures.values()))
    print(f"\n{wide_rule}")
    print("ANALYSIS COMPLETE")
    print(f"Total heatmaps created: {heatmap_total}")
    print(wide_rule)

    return all_figures


# Run the pairing analysis on the merged-genes prediction CSV. `figures` keeps
# the nested {group name: {figure key: Figure}} mapping returned by main().
figures = main('/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/full_eval_generate_multiple_light_seqs_203276_cls_predictions_merged_genes.csv')


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import re

def load_and_prepare_data(csv_file):
    """
    Read the analysis CSV and print a short size summary.

    Parameters:
    csv_file (str): Path to the CSV file; a 'heavy_chain_number' column is
        required for the summary.

    Returns:
    DataFrame: The table as parsed by pd.read_csv.
    """
    frame = pd.read_csv(csv_file)

    print("Dataset shape:", frame.shape)
    print("Unique heavy chains:", frame['heavy_chain_number'].nunique())
    print("Total rows:", len(frame))

    return frame

def create_groups(df):
    """
    Split the rows into four groups by predicted heavy / generated-light
    labels (1 = memory B cell origin, 0 = naive B cell origin) and print the
    size of each group.

    Parameters:
    df (DataFrame): Needs 'predicted_gen_light_seq_label',
        'predicted_input_heavy_seq_label' and 'heavy_chain_number'.

    Returns:
    dict: Group name -> DataFrame with the matching rows.
    """
    light_label = df['predicted_gen_light_seq_label']
    heavy_label = df['predicted_input_heavy_seq_label']

    # One boolean mask per label value, reused across the four combinations
    light_memory, light_naive = light_label == 1, light_label == 0
    heavy_memory, heavy_naive = heavy_label == 1, heavy_label == 0

    groups = {
        'Group 1 (H:Memory, L:Memory)': df[light_memory & heavy_memory],
        'Group 2 (H:Naive, L:Naive)': df[light_naive & heavy_naive],
        'Group 3 (H:Memory, L:Naive)': df[heavy_memory & light_naive],
        'Group 4 (H:Naive, L:Memory)': df[heavy_naive & light_memory],
    }

    print("\nGroup sizes:")
    for label, members in groups.items():
        n_rows = len(members)
        print(f"{label}: {n_rows} rows")
        if n_rows > 0:
            print(f"  Unique heavy chains: {members['heavy_chain_number'].nunique()}")

    return groups

def extract_v_gene_info(gene_name):
    """
    Return the full, allele-stripped and family-level forms of a V gene name.

    Parameters:
    gene_name: V gene call such as 'IGHV3-23*01'. Missing values (NaN/None)
        and the empty string map to 'Unknown' at every level; any other
        non-string value is converted with str() first.

    Returns:
    dict: {'full': ..., 'simplified': ..., 'family': ...}, e.g.
        {'full': 'IGHV3-23*01', 'simplified': 'IGHV3-23', 'family': 'IGHV3'}
    """
    if pd.isna(gene_name) or gene_name == '':
        return {'full': 'Unknown', 'simplified': 'Unknown', 'family': 'Unknown'}

    # Full gene name (coerced so the regex calls below never see a non-string)
    full = str(gene_name)

    # Simplified (remove *XX allele part)
    simplified = re.sub(r'\*\d+', '', full)

    # Family: IGHV/IGKV/IGLV followed by the family number. For names that do
    # not follow that pattern, fall back to the text before the first dash of
    # the allele-stripped name, so the family never carries a '*XX' suffix.
    family_match = re.match(r'(IG[HKL]V\d+)', full)
    family = family_match.group(1) if family_match else simplified.split('-')[0]

    return {'full': full, 'simplified': simplified, 'family': family}

def create_heavy_light_pairing_matrix(group_df, gene_level='full'):
    """
    Cross-tabulate heavy V genes against generated light V genes.

    Parameters:
    group_df (DataFrame): Rows to analyse; needs 'heavy_chain_number',
        'heavy_chain_gene_name' and 'gen_light_gene_name'.
    gene_level (str): Which key of extract_v_gene_info's result to use
        ('full', 'simplified' or 'family').

    Returns:
    tuple: (count_matrix, percentage_matrix). Rows are heavy genes, columns
        are light genes; percentages are relative to each heavy gene's row
        total. Two empty DataFrames are returned when group_df has no rows.
    """
    if len(group_df) == 0:
        return pd.DataFrame(), pd.DataFrame()

    def gene_labels(column):
        # Walk the column values directly; this avoids building a Series per
        # row, which is what DataFrame.iterrows() would do.
        return [extract_v_gene_info(name)[gene_level] for name in group_df[column]]

    pairing_df = pd.DataFrame({
        'heavy_chain_number': group_df['heavy_chain_number'].to_numpy(),
        'heavy_gene': gene_labels('heavy_chain_gene_name'),
        'light_gene': gene_labels('gen_light_gene_name'),
    })

    # Create co-occurrence matrix (counts)
    cooccurrence_matrix = pd.crosstab(
        pairing_df['heavy_gene'],
        pairing_df['light_gene'],
        margins=False
    )

    # Percentage of each heavy gene's sequences that use each light gene;
    # NaN cells are reported as 0.
    row_totals = cooccurrence_matrix.sum(axis=1)
    percentage_matrix = (cooccurrence_matrix.div(row_totals, axis=0) * 100).fillna(0)

    # Report sample sizes for the ten most frequent heavy genes
    heavy_gene_counts = pairing_df['heavy_gene'].value_counts()
    chains_per_gene = pairing_df.groupby('heavy_gene')['heavy_chain_number'].nunique()

    print("\nSample sizes for each heavy gene (total generated sequences):")
    for gene, count in heavy_gene_counts.head(10).items():
        print(f"  {gene}: {count} total sequences from {chains_per_gene[gene]} heavy chains")

    return cooccurrence_matrix, percentage_matrix

def get_top_genes(matrix, top_n=15, axis='both'):
    """
    Return the labels of the most frequent genes in a pairing count matrix.

    Parameters:
    matrix (DataFrame): Counts with heavy genes as rows and light genes as
        columns.
    top_n (int): Maximum number of labels returned per axis.
    axis (str): 'heavy' ranks rows by their row sums, 'light' ranks columns
        by their column sums, 'both' returns both rankings.

    Returns:
    list or tuple: A list of labels for 'heavy' or 'light'; a
        (top_heavy, top_light) tuple of lists for 'both'.

    Raises:
    ValueError: If axis is not 'both', 'heavy' or 'light'. Previously such a
        value silently produced None.
    """
    if axis not in ('both', 'heavy', 'light'):
        raise ValueError(f"axis must be 'both', 'heavy' or 'light', got {axis!r}")

    def top_labels(sum_axis):
        # sum_axis=1 totals each row (heavy gene), sum_axis=0 each column (light gene)
        return matrix.sum(axis=sum_axis).nlargest(top_n).index.tolist()

    if axis == 'heavy':
        return top_labels(1)
    if axis == 'light':
        return top_labels(0)
    return top_labels(1), top_labels(0)

def create_pairing_heatmap(percentage_matrix, count_matrix, title, top_heavy=None, top_light=None, figsize=(14, 10)):
    """
    Draw a heavy-vs-light V gene heatmap coloured by row percentage.

    Every cell is annotated with "<percentage>%" and, on a second line, the
    raw count in parentheses. A box in the upper-left corner of the axes
    reports the number of sequences and genes shown.

    Parameters:
    percentage_matrix (DataFrame): Row percentages (heavy genes x light genes).
    count_matrix (DataFrame): Raw counts with the same labels.
    title (str): Figure title.
    top_heavy (list): Optional heavy genes to keep (rows), in this order.
    top_light (list): Optional light genes to keep (columns), in this order.
    figsize (tuple): Figure size passed to plt.subplots.

    Returns:
    Figure: The heatmap figure, or a placeholder figure reading
        'No data available' when percentage_matrix is empty.
    """
    # Placeholder figure when there is nothing to plot
    if percentage_matrix.empty:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, 'No data available', ha='center', va='center',
                transform=ax.transAxes, fontsize=16)
        ax.set_title(title)
        return fig

    shown_pct = percentage_matrix.copy()
    shown_counts = count_matrix.copy()

    # Restrict rows / columns to the requested genes that are actually present
    if top_heavy:
        rows = [gene for gene in top_heavy if gene in shown_pct.index]
        if rows:
            shown_pct, shown_counts = shown_pct.loc[rows], shown_counts.loc[rows]

    if top_light:
        cols = [gene for gene in top_light if gene in shown_pct.columns]
        if cols:
            shown_pct, shown_counts = shown_pct[cols], shown_counts[cols]

    # Cell text: percentage on the first line, count in parentheses below
    pct_text = shown_pct.round(1).astype(str)
    count_text = shown_counts.astype(str)
    cell_labels = pct_text + '%\n(' + count_text + ')'

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(
        shown_pct,
        annot=cell_labels,
        cmap='YlOrRd',
        fmt='',
        ax=ax,
        cbar_kws={'label': 'Percentage of Heavy Gene Usage'},
        linewidths=0.5,
    )

    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    ax.set_xlabel('Generated Light Chain V Gene', fontsize=12)
    ax.set_ylabel('Heavy Chain V Gene', fontsize=12)

    # Slanted x labels, horizontal y labels
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)

    # Summary box for the displayed (possibly filtered) matrix
    n_sequences = shown_counts.sum().sum()
    n_heavy = len(shown_pct.index)
    n_light = len(shown_pct.columns)
    summary = (f"Total sequences: {n_sequences}\n"
               f"Heavy genes: {n_heavy}, Light genes: {n_light}")
    ax.text(0.02, 0.98, summary, transform=ax.transAxes,
            verticalalignment='top', fontsize=10,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()
    return fig

def analyze_pairing_preferences(percentage_matrix, count_matrix, gene_level):
    """
    Print a text summary of heavy-to-light pairing preferences.

    Two reports are written to stdout:
    1. Heavy genes whose most frequent light partner exceeds 50% of the
       heavy gene's row, ordered by that percentage (highest first).
    2. The five heavy genes with the most light partners above 5%.

    Parameters:
    percentage_matrix (DataFrame): Row percentages, heavy genes as index and
        light genes as columns.
    count_matrix (DataFrame): Raw counts carrying the same labels.
    gene_level (str): Text inserted into the report header.
    """
    print(f"\n=== Pairing Analysis for {gene_level} level ===")

    # Report 1: heavy genes dominated by a single light gene
    strong = []
    for heavy in percentage_matrix.index:
        shares = percentage_matrix.loc[heavy]
        peak = shares.max()
        if not peak > 50:
            continue
        favourite = shares.idxmax()
        strong.append({
            'heavy_gene': heavy,
            'preferred_light': favourite,
            'percentage': peak,
            'count': count_matrix.loc[heavy, favourite],
            'total_sequences': count_matrix.loc[heavy].sum(),
        })

    if not strong:
        print("No heavy genes show strong preferences (>50%) for specific light genes")
    else:
        print("Heavy genes with strong light gene preferences (>50%):")
        strong.sort(key=lambda item: item['percentage'], reverse=True)
        for item in strong:
            print(f"  {item['heavy_gene']} → {item['preferred_light']}: "
                  f"{item['percentage']:.1f}% ({item['count']}/{item['total_sequences']} sequences)")

    # Report 2: heavy genes spread over many light genes (each above 5%)
    spread = [
        {
            'heavy_gene': heavy,
            'num_partners': (percentage_matrix.loc[heavy] > 5).sum(),
            'total_sequences': count_matrix.loc[heavy].sum(),
        }
        for heavy in percentage_matrix.index
    ]
    spread.sort(key=lambda item: item['num_partners'], reverse=True)

    print("\nMost diverse heavy genes (pairing with multiple light genes >5%):")
    for item in spread[:5]:
        print(f"  {item['heavy_gene']}: pairs with {item['num_partners']} light genes "
              f"({item['total_sequences']} total sequences)")

def create_all_pairing_heatmaps(groups, save_plots=True, dpi=500):
    """
    Build the "all genes" pairing heatmap for every group at every gene
    level, optionally saving each one as a PNG in the working directory.

    Parameters:
    groups (dict): Maps a group name to that group's DataFrame.
    save_plots (bool): When True, each figure is written to
        '<sanitised group name>_<level>_all_genes_heatmap.png'.
    dpi (int): Resolution passed to Figure.savefig.

    Returns:
    dict: {group name: {'<level>_all': Figure}}. Groups without rows are
        omitted; levels whose count matrix is empty are skipped.
    """
    level_labels = (
        ('full', 'Full V Gene'),
        ('simplified', 'Simplified V Gene'),
        ('family', 'V Gene Family'),
    )
    # Drop ':', '(' and ')' and turn spaces into underscores for file names
    filename_cleanup = str.maketrans({':': None, '(': None, ')': None, ' ': '_'})
    banner = '=' * 50
    figures_by_group = {}

    for group_name, group_df in groups.items():
        print(f"\n{banner}")
        print(f"Processing {group_name}...")
        print(banner)

        if len(group_df) == 0:
            print(f"No data for {group_name}")
            continue

        per_level = {}
        for level, label in level_labels:
            print(f"\n--- {label} ---")

            counts, percentages = create_heavy_light_pairing_matrix(group_df, level)
            if counts.empty:
                continue

            analyze_pairing_preferences(percentages, counts, label)

            # Figure size grows with the number of genes on each axis
            width = max(14, len(percentages.columns) * 0.8)
            height = max(10, len(percentages.index) * 0.5)
            heatmap = create_pairing_heatmap(
                percentages, counts,
                f"{group_name} - {label}\nHeavy-Light V Gene Pairing (All Genes)",
                figsize=(width, height)
            )

            if save_plots:
                filename = f"{group_name.translate(filename_cleanup)}_{level}_all_genes_heatmap.png"
                heatmap.savefig(filename, dpi=dpi, bbox_inches='tight', facecolor='white')
                print(f"  Saved: {filename}")

            per_level[f"{level}_all"] = heatmap

        figures_by_group[group_name] = per_level

    return figures_by_group


In [None]:
def main(csv_file, save_plots=True, dpi=500):
    """
    Run the heavy-light V gene pairing analysis from CSV to displayed figures.

    Parameters:
    csv_file (str): Path to the CSV file
    save_plots (bool): Forwarded to create_all_pairing_heatmaps.
    dpi (int): Forwarded to create_all_pairing_heatmaps.

    Returns:
    dict: The nested {group name: {figure key: Figure}} mapping produced by
        create_all_pairing_heatmaps.
    """
    wide_rule = '=' * 60

    print("Heavy-Light V Gene Pairing Analysis")
    print('=' * 50)

    # Load -> split into groups -> build (and optionally save) the heatmaps
    groups = create_groups(load_and_prepare_data(csv_file))
    all_figures = create_all_pairing_heatmaps(groups, save_plots=save_plots, dpi=dpi)

    # Show each figure, group by group
    for group_name, group_figures in all_figures.items():
        print(f"\n{wide_rule}")
        print(f"DISPLAYING HEATMAPS FOR {group_name}")
        print(wide_rule)
        for fig in group_figures.values():
            plt.figure(fig.number)
            plt.show()

    heatmap_total = sum(map(len, all_figures.values()))
    print(f"\n{wide_rule}")
    print("ANALYSIS COMPLETE")
    print(f"Total heatmaps created: {heatmap_total}")
    if save_plots:
        print(f"All plots saved with DPI: {dpi}")
    print(wide_rule)

    return all_figures

# Run the analysis and save every heatmap at 200 DPI (this overrides main's
# default of dpi=500)
figures = main('/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/full_eval_generate_multiple_light_seqs_203276_cls_predictions_merged_genes.csv', save_plots=True, dpi=200)

# Or run without saving plots
# figures = main('your_file.csv', save_plots=False)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import re

def load_and_prepare_data(csv_file):
    """
    Read the analysis CSV and print a short size summary.

    Parameters:
    csv_file (str): Path to the CSV file; a 'heavy_chain_number' column is
        required for the summary.

    Returns:
    DataFrame: The table as parsed by pd.read_csv.
    """
    table = pd.read_csv(csv_file)

    print("Dataset shape:", table.shape)
    print("Unique heavy chains:", table['heavy_chain_number'].nunique())
    print("Total rows:", len(table))

    return table

def create_groups(df):
    """
    Split the rows into four groups by predicted heavy / generated-light
    labels (1 = memory B cell origin, 0 = naive B cell origin).

    Parameters:
    df (DataFrame): Needs 'predicted_gen_light_seq_label' and
        'predicted_input_heavy_seq_label'.

    Returns:
    dict: Group name -> DataFrame with the matching rows.
    """
    light_label = df['predicted_gen_light_seq_label']
    heavy_label = df['predicted_input_heavy_seq_label']

    # One boolean mask per label value, reused across the four combinations
    light_memory, light_naive = light_label == 1, light_label == 0
    heavy_memory, heavy_naive = heavy_label == 1, heavy_label == 0

    return {
        'Group 1 (H:Memory, L:Memory)': df[light_memory & heavy_memory],
        'Group 2 (H:Naive, L:Naive)': df[light_naive & heavy_naive],
        'Group 3 (H:Memory, L:Naive)': df[heavy_memory & light_naive],
        'Group 4 (H:Naive, L:Memory)': df[heavy_naive & light_memory],
    }

def analyze_specific_switch(df, target_heavy='IGHV3-20*04', target_light='IGKV1-39*01'):
    """
    Measure, per group, how often a heavy V gene is paired with a light V gene.

    "Generated" switches count every row whose generated light gene equals
    target_light. "True" switches count each heavy_chain_number once and
    compare its true light gene with target_light.

    Parameters:
    df (DataFrame): Needs 'heavy_chain_gene_name', 'heavy_chain_number',
        'gen_light_gene_name', 'gen_light_chain_number',
        'true_light_gene_name' and the columns used by create_groups.
    target_heavy (str): Heavy V gene name, matched exactly.
    target_light (str): Light V gene name, matched exactly.

    Returns:
    dict: Group name -> dict with counts, percentages and per-row details.
        An empty dict is returned when no row carries target_heavy.
    """
    rule = '=' * 80
    print(f"\n{rule}")
    print(f"ANALYZING SWITCH: {target_heavy} → {target_light}")
    print(rule)

    # Overall availability of the target heavy gene
    with_target_heavy = df[df['heavy_chain_gene_name'] == target_heavy]

    print(f"\nSequences with {target_heavy}:")
    print(f"  Total sequences: {len(with_target_heavy)}")
    print(f"  Unique heavy chains: {with_target_heavy['heavy_chain_number'].nunique()}")

    if len(with_target_heavy) == 0:
        print(f"No sequences found with heavy gene {target_heavy}")
        return {}

    def as_percentage(part, whole):
        # Guarded division; reports 0 when the denominator is not positive
        return (part / whole) * 100 if whole > 0 else 0

    results = {}
    for group_name, group_df in create_groups(df).items():
        print(f"\n--- {group_name} ---")

        subset = group_df[group_df['heavy_chain_gene_name'] == target_heavy]
        n_sequences = len(subset)

        if n_sequences == 0:
            print(f"  No {target_heavy} sequences in this group")
            results[group_name] = {
                'total_sequences': 0,
                'unique_heavy_chains': 0,
                'generated_switch_count': 0,
                'generated_switch_percentage': 0,
                'true_switch_count': 0,
                'true_switch_percentage': 0,
                'generated_switch_details': [],
                'true_switch_details': []
            }
            continue

        n_heavy_chains = subset['heavy_chain_number'].nunique()
        print(f"  Total {target_heavy} sequences: {n_sequences}")
        print(f"  Unique {target_heavy} heavy chains: {n_heavy_chains}")

        # Generated light chains: every row counts
        generated_hits = subset[subset['gen_light_gene_name'] == target_light]
        n_generated = len(generated_hits)
        generated_pct = as_percentage(n_generated, n_sequences)
        print(f"  Generated light switches to {target_light}: {n_generated}/{n_sequences} ({generated_pct:.1f}%)")

        generated_details = [
            {
                'heavy_chain_number': row['heavy_chain_number'],
                'gen_light_chain_number': row['gen_light_chain_number'],
                'true_light_gene': row['true_light_gene_name'],
            }
            for _, row in generated_hits.iterrows()
        ]

        # True light chains: one row per heavy chain
        one_per_heavy = subset.drop_duplicates(subset=['heavy_chain_number'])
        true_hits = one_per_heavy[one_per_heavy['true_light_gene_name'] == target_light]
        n_true = len(true_hits)
        true_pct = as_percentage(n_true, n_heavy_chains)
        print(f"  True light switches to {target_light}: {n_true}/{n_heavy_chains} ({true_pct:.1f}%)")

        true_details = [
            {
                'heavy_chain_number': row['heavy_chain_number'],
                'true_light_gene': row['true_light_gene_name'],
            }
            for _, row in true_hits.iterrows()
        ]

        results[group_name] = {
            'total_sequences': n_sequences,
            'unique_heavy_chains': n_heavy_chains,
            'generated_switch_count': n_generated,
            'generated_switch_percentage': generated_pct,
            'true_switch_count': n_true,
            'true_switch_percentage': true_pct,
            'generated_switch_details': generated_details,
            'true_switch_details': true_details
        }

    return results

def analyze_heavy_chain_specific_patterns(df, target_heavy='IGHV3-20*04'):
    """
    Print, per group, which light genes accompany the target heavy gene.

    For each group the ten most frequent generated light genes are listed
    (percentages over all rows with target_heavy), followed by the complete
    true light gene distribution counted once per heavy_chain_number.

    Parameters:
    df (DataFrame): Needs 'heavy_chain_gene_name', 'heavy_chain_number',
        'gen_light_gene_name', 'true_light_gene_name' and the columns used
        by create_groups.
    target_heavy (str): Heavy V gene name, matched exactly.
    """
    rule = '=' * 80
    print(f"\n{rule}")
    print(f"LIGHT GENE PREFERENCES FOR {target_heavy}")
    print(rule)

    if len(df[df['heavy_chain_gene_name'] == target_heavy]) == 0:
        print(f"No sequences found with heavy gene {target_heavy}")
        return

    for group_name, group_df in create_groups(df).items():
        print(f"\n--- {group_name} ---")

        subset = group_df[group_df['heavy_chain_gene_name'] == target_heavy]
        if len(subset) == 0:
            print(f"  No {target_heavy} sequences in this group")
            continue

        # Generated light genes: every row counts
        generated_counts = subset['gen_light_gene_name'].value_counts()
        print("  Generated light gene distribution (top 10):")
        for rank, (gene, count) in enumerate(generated_counts.head(10).items(), start=1):
            share = (count / len(subset)) * 100
            print(f"    {rank}. {gene}: {count} ({share:.1f}%)")

        # True light genes: one row per heavy chain
        one_per_heavy = subset.drop_duplicates(subset=['heavy_chain_number'])
        true_counts = one_per_heavy['true_light_gene_name'].value_counts()
        print("  True light gene distribution:")
        for rank, (gene, count) in enumerate(true_counts.items(), start=1):
            share = (count / len(one_per_heavy)) * 100
            print(f"    {rank}. {gene}: {count} ({share:.1f}%)")

def create_switch_summary_table(results, target_heavy='IGHV3-20*04', target_light='IGKV1-39*01'):
    """
    Tabulate the per-group switch results and print the table.

    Parameters:
    results (dict): Group name -> result dict as produced by
        analyze_specific_switch.
    target_heavy (str): Heavy V gene name, used in the printed heading only.
    target_light (str): Light V gene name, used in the printed heading only.

    Returns:
    DataFrame: One row per group with the count and percentage columns.
    """
    # Output column -> key in each per-group result dict
    column_sources = (
        ('Total_Sequences', 'total_sequences'),
        ('Unique_Heavy_Chains', 'unique_heavy_chains'),
        ('Generated_Switches', 'generated_switch_count'),
        ('Generated_Switch_Pct', 'generated_switch_percentage'),
        ('True_Switches', 'true_switch_count'),
        ('True_Switch_Pct', 'true_switch_percentage'),
    )

    rows = []
    for group_name, result in results.items():
        record = {'Group': group_name}
        for column, key in column_sources:
            record[column] = result[key]
        rows.append(record)

    summary_df = pd.DataFrame(rows)

    rule = '=' * 80
    print(f"\n{rule}")
    print(f"SUMMARY TABLE: {target_heavy} → {target_light} SWITCHES")
    print(rule)
    print(summary_df.to_string(index=False))

    return summary_df

def visualize_switch_analysis(results, target_heavy='IGHV3-20*04', target_light='IGKV1-39*01'):
    """
    Plot generated and true switch percentages per group as two bar charts.

    Parameters:
    results (dict): Group name -> result dict as produced by
        analyze_specific_switch.
    target_heavy (str): Heavy V gene name, used in the panel titles.
    target_light (str): Light V gene name, used in the panel titles.

    Returns:
    Figure: A figure with the generated-switch panel on the left and the
        true-switch panel on the right.
    """
    groups = list(results.keys())
    generated_percentages = [results[group]['generated_switch_percentage'] for group in groups]
    true_percentages = [results[group]['true_switch_percentage'] for group in groups]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    positions = range(len(groups))
    # Break group names at spaces so the tick labels stack vertically
    tick_labels = [g.replace(' ', '\n') for g in groups]

    def draw_panel(ax, percentages, heading, count_key, total_key):
        # One bar per group, annotated with "pct% (count/total)" above non-zero bars
        bars = ax.bar(positions, percentages,
                      color=['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A'], alpha=0.8)
        ax.set_xlabel('Groups')
        ax.set_ylabel('Switch Percentage (%)')
        ax.set_title(f'{heading}\n{target_heavy} → {target_light}')
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, fontsize=10)
        # The [1] keeps the axis from collapsing when every percentage is 0
        ax.set_ylim(0, max(percentages + [1]) * 1.1)

        for group, bar, pct in zip(groups, bars, percentages):
            if not pct > 0:
                continue
            count = results[group][count_key]
            total = results[group][total_key]
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                    f'{pct:.1f}%\n({count}/{total})',
                    ha='center', va='bottom', fontsize=9)

    # Plot 1: generated light chains (denominator: all sequences)
    draw_panel(ax1, generated_percentages, 'Generated Light Chain Switches',
               'generated_switch_count', 'total_sequences')
    # Plot 2: true light chains (denominator: unique heavy chains)
    draw_panel(ax2, true_percentages, 'True Light Chain Switches',
               'true_switch_count', 'unique_heavy_chains')

    plt.tight_layout()
    return fig



In [None]:
def main(csv_file, target_heavy='IGHV3-20*04', target_light='IGKV1-39*01'):
    """
    Run the complete analysis for one heavy-to-light gene switch.

    Parameters:
    csv_file (str): Path to the CSV file
    target_heavy (str): Heavy V gene name to analyse.
    target_light (str): Light V gene name to analyse.

    Returns:
    tuple: (results, summary_df, fig) - the per-group result dict, the
        summary DataFrame and the bar-chart figure.
    """
    rule = '=' * 80

    print("SPECIFIC SWITCH ANALYSIS")
    print(f"Target Switch: {target_heavy} → {target_light}")
    print(rule)

    df = load_and_prepare_data(csv_file)

    # Switch statistics, general pairing patterns, then the summary table
    results = analyze_specific_switch(df, target_heavy, target_light)
    analyze_heavy_chain_specific_patterns(df, target_heavy)
    summary_df = create_switch_summary_table(results, target_heavy, target_light)

    fig = visualize_switch_analysis(results, target_heavy, target_light)
    plt.show()

    print(f"\n{rule}")
    print("DETAILED SWITCH ANALYSIS RESULTS")
    print(rule)

    preview_limit = 5  # number of generated-switch rows printed per group
    for group_name, result in results.items():
        if not result['total_sequences'] > 0:
            continue

        print(f"\n{group_name}:")
        print(f"  Sample size: {result['total_sequences']} sequences from {result['unique_heavy_chains']} unique heavy chains")
        print(f"  Generated switches: {result['generated_switch_count']}/{result['total_sequences']} ({result['generated_switch_percentage']:.1f}%)")
        print(f"  True switches: {result['true_switch_count']}/{result['unique_heavy_chains']} ({result['true_switch_percentage']:.1f}%)")

        if result['generated_switch_count'] > 0:
            details = result['generated_switch_details']
            print("  Generated switch details:")
            for detail in details[:preview_limit]:
                print(f"    Heavy chain {detail['heavy_chain_number']}, Gen light {detail['gen_light_chain_number']}, True light gene: {detail['true_light_gene']}")
            hidden = len(details) - preview_limit
            if hidden > 0:
                print(f"    ... and {hidden} more")

    return results, summary_df, fig

# Example usage: analyse the IGHV3-20*04 → IGKV1-39*01 switch on the
# merged-genes prediction CSV; main() returns (results, summary table, figure).
results, summary, fig = main('/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/full_eval_generate_multiple_light_seqs_203276_cls_predictions_merged_genes.csv', 'IGHV3-20*04', 'IGKV1-39*01')

# A different switch can be analysed by changing the last two arguments:
#results, summary, fig = main('/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/full_eval_generate_multiple_light_seqs_203276_cls_predictions_merged_genes.csv', 'IGHV1-69-2*01', 'IGKV6D-41*01')
