In [2]:

import pandas as pd
import re
from typing import List, Dict, Tuple
from collections import defaultdict
import os


In [2]:
def parse_antibody_data(file_path: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Parse antibody generation data from text file and create two DataFrames:
    1. Detailed sequences DataFrame
    2. Summary statistics DataFrame
    
    Args:
        file_path (str): Path to the input text file
        
    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: (sequences_df, summary_df)
    """
    
    with open(file_path, 'r') as file:
        content = file.read()
    
    # Initialize lists to store parsed data
    sequences_data = []
    summary_data = []
    
    # Split content by heavy chains
    heavy_chain_sections = re.split(r'Heavy chain (\d+)', content)[1:]  # Skip first empty element
    
    for i in range(0, len(heavy_chain_sections), 2):
        heavy_chain_num = int(heavy_chain_sections[i])
        section_content = heavy_chain_sections[i + 1]
        
        # Parse individual sequences for this heavy chain
        sequence_blocks = re.findall(
            r'Generated sequence (\d+):\s*\n'
            r'True Light Sequence: ([A-Z]+)\s*\n'
            r'Generated Light Sequence: ([A-Z]+)\s*\n'
            r'Input Heavy Sequence: ([A-Z]+)\s*\n'
            r'BLOSUM Score: ([0-9.]+)\s*\n'
            r'Similarity Percentage: ([0-9.]+)%\s*\n'
            r'Perplexity: ([0-9.]+)',
            section_content
        )
        
        # Store sequence data
        for seq_data in sequence_blocks:
            sequences_data.append({
                'heavy_chain_number': heavy_chain_num,
                'gen_light_chain_number': int(seq_data[0]),
                'true_light_seq': seq_data[1],
                'gen_light_seq': seq_data[2],
                'input_heavy_seq': seq_data[3],
                'BLOSUM': float(seq_data[4]),
                'similarity': float(seq_data[5]),
                'perplexity': float(seq_data[6])
            })
        
        # Parse summary data for this heavy chain
        summary_match = re.search(
            r'--- Summary for Heavy Chain \d+ ---\s*\n'
            r'Average BLOSUM Score for \d+ sequences: ([0-9.]+)\s*\n'
            r'Best BLOSUM Score: ([0-9.]+)\s*\n'
            r'Average Similarity for \d+ sequences: ([0-9.]+)%\s*\n'
            r'Best Similarity: ([0-9.]+)%\s*\n'
            r'Average Perplexity for \d+ sequences: ([0-9.]+)\s*\n'
            r'Best Perplexity: ([0-9.]+)',
            section_content
        )
        
        if summary_match:
            summary_data.append({
                'heavy_chain_number': heavy_chain_num,
                'avg_blosum': float(summary_match.group(1)),
                'best_blosum': float(summary_match.group(2)),
                'avg_similarity': float(summary_match.group(3)),
                'best_similarity': float(summary_match.group(4)),
                'average_perplexity': float(summary_match.group(5)),
                'best_perplexity': float(summary_match.group(6))
            })
    
    # Create DataFrames
    sequences_df = pd.DataFrame(sequences_data)
    summary_df = pd.DataFrame(summary_data)
    
    return sequences_df, summary_df


In [3]:
def save_to_csv(sequences_df: pd.DataFrame, summary_df: pd.DataFrame, 
                sequences_file: str = 'antibody_sequences.csv', 
                summary_file: str = 'antibody_summary.csv',
                output_dir: str = './') -> None:
    """
    Save DataFrames to CSV files
    
    Args:
        sequences_df (pd.DataFrame): Detailed sequences data
        summary_df (pd.DataFrame): Summary statistics data
        sequences_file (str): Output filename for sequences CSV
        summary_file (str): Output filename for summary CSV
    """
    sequences_df.to_csv(output_dir + sequences_file, index=False)
    summary_df.to_csv(output_dir + summary_file, index=False)
    
    print(f"Sequences data saved to: {sequences_file}")
    print(f"Summary data saved to: {summary_file}")
    print(f"\nSequences DataFrame shape: {sequences_df.shape}")
    print(f"Summary DataFrame shape: {summary_df.shape}")

In [4]:

#input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/logs/full_eval_generate_multiple_light_seqs_203267.o"  
    
#input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_eval_generate_multiple_light_seqs_203276_10k.o"  

input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/logs/full_eval_generate_multiple_light_seqs_203276.o"  



In [5]:

sequences_df, summary_df = parse_antibody_data(input_file)
        
# Display preview of the data
print("=== SEQUENCES DATA PREVIEW ===")
print(sequences_df.head())
print(f"\nColumns: {list(sequences_df.columns)}")
        
print("\n=== SUMMARY DATA PREVIEW ===")
print(summary_df.head())
print(f"\nColumns: {list(summary_df.columns)}")
        
# Save to CSV files
save_to_csv(sequences_df, summary_df, sequences_file = "full_eval_generate_multiple_light_seqs_203276.csv", summary_file = "summary_full_eval_generate_multiple_light_seqs_203276.csv" ,output_dir='/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/')
        
# Optional: Display some basic statistics
print(f"\n=== BASIC STATISTICS ===")
print(f"Total sequences processed: {len(sequences_df)}")
print(f"Number of heavy chains: {len(summary_df)}")
print(f"Average BLOSUM score across all sequences: {sequences_df['BLOSUM'].mean():.2f}")
print(f"Average similarity across all sequences: {sequences_df['similarity'].mean():.2f}%")


=== SEQUENCES DATA PREVIEW ===
   heavy_chain_number  gen_light_chain_number  \
0                   1                       1   
1                   1                       2   
2                   1                       3   
3                   1                       4   
4                   1                       5   

                                      true_light_seq  \
0  DIQMTQSPSTLSASVGDRVTITCRASHSINTWLAWYQQKPGKAPKL...   
1  DIQMTQSPSTLSASVGDRVTITCRASHSINTWLAWYQQKPGKAPKL...   
2  DIQMTQSPSTLSASVGDRVTITCRASHSINTWLAWYQQKPGKAPKL...   
3  DIQMTQSPSTLSASVGDRVTITCRASHSINTWLAWYQQKPGKAPKL...   
4  DIQMTQSPSTLSASVGDRVTITCRASHSINTWLAWYQQKPGKAPKL...   

                                       gen_light_seq  \
0  QSALTQPPSASGSPGQSVTISCTGTSSDIGGYNFVSWYQQHPGKAP...   
1  SSELTQDPAVSVALGQTVRITCQGDSLRSYYASWYQQKPGQAPLLV...   
2  DIVMTQSPDSLAVSLGERATINCKSSQSVLYSPNNKNYLGWYQQKP...   
3  EIVLTQSPATLSLSPGERATLSCRASQSVGTYLAWYQQKPGQAPRL...   
4  SSELTQDPAVSVALGQTVRITCQGDSLRSYYASWYQQKPGQAPVVV...   



In [8]:
import pandas as pd
import os

def split_sequences_to_csv(input_csv_path: str, 
                          matching_csv_path: str = None,
                          non_matching_csv_path: str = None) -> tuple:
    """
    Split sequences into two CSV files based on whether predicted_gen_light_seq_label 
    and predicted_input_heavy_seq_label match or not.
    
    Args:
        input_csv_path (str): Path to the input CSV file
        matching_csv_path (str, optional): Path for matching sequences CSV file
        non_matching_csv_path (str, optional): Path for non-matching sequences CSV file
    
    Returns:
        tuple: (matching_df, non_matching_df) - DataFrames for matching and non-matching sequences
    """
    
    # Read the CSV file
    print(f"Reading data from: {input_csv_path}")
    df = pd.read_csv(input_csv_path)
    
    print(f"Original dataset shape: {df.shape}")
    print(f"Original columns: {list(df.columns)}")
    
    # Check if required columns exist
    required_columns = ['predicted_gen_light_seq_label', 'predicted_input_heavy_seq_label']
    missing_columns = [col for col in required_columns if col not in df.columns]
    
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")
    
    # Show distribution of labels before filtering
    print(f"\n=== LABEL DISTRIBUTION BEFORE SPLITTING ===")
    print("predicted_gen_light_seq_label distribution:")
    print(df['predicted_gen_light_seq_label'].value_counts())
    print("\npredicted_input_heavy_seq_label distribution:")
    print(df['predicted_input_heavy_seq_label'].value_counts())
    
    # Split sequences into matching and non-matching
    matching_df = df[df['predicted_gen_light_seq_label'] == df['predicted_input_heavy_seq_label']].copy()
    non_matching_df = df[df['predicted_gen_light_seq_label'] != df['predicted_input_heavy_seq_label']].copy()
    
    print(f"\n=== SPLITTING RESULTS ===")
    print(f"Matching sequences: {len(matching_df)}")
    print(f"Non-matching sequences: {len(non_matching_df)}")
    print(f"Total sequences: {len(df)}")
    print(f"Percentage matching: {(len(matching_df) / len(df)) * 100:.2f}%")
    print(f"Percentage non-matching: {(len(non_matching_df) / len(df)) * 100:.2f}%")
    
    # Set default output paths if not provided
    if matching_csv_path is None:
        base_name = os.path.splitext(input_csv_path)[0]
        matching_csv_path = f"{base_name}_matching.csv"
    
    if non_matching_csv_path is None:
        base_name = os.path.splitext(input_csv_path)[0]
        non_matching_csv_path = f"{base_name}_non_matching.csv"
    
    # Save matching sequences to CSV
    matching_df.to_csv(matching_csv_path, index=False)
    print(f"\nMatching sequences saved to: {matching_csv_path}")
    
    # Save non-matching sequences to CSV
    non_matching_df.to_csv(non_matching_csv_path, index=False)
    print(f"Non-matching sequences saved to: {non_matching_csv_path}")
    
    # Show distribution of matching labels
    if len(matching_df) > 0:
        print(f"\n=== MATCHING LABELS DISTRIBUTION ===")
        print("Distribution of matching labels:")
        print(matching_df['predicted_gen_light_seq_label'].value_counts())
        
        # Show some examples of the matching data
        print(f"\n=== SAMPLE OF MATCHING DATA ===")
        print(matching_df[['heavy_chain_number', 'gen_light_chain_number', 
                          'predicted_gen_light_seq_label', 'predicted_input_heavy_seq_label']].head(5))
    
    # Show distribution of non-matching labels
    if len(non_matching_df) > 0:
        print(f"\n=== NON-MATCHING LABELS DISTRIBUTION ===")
        print("Light chain labels in non-matching:")
        print(non_matching_df['predicted_gen_light_seq_label'].value_counts())
        print("\nHeavy chain labels in non-matching:")
        print(non_matching_df['predicted_input_heavy_seq_label'].value_counts())
        
        # Show some examples of the non-matching data
        print(f"\n=== SAMPLE OF NON-MATCHING DATA ===")
        print(non_matching_df[['heavy_chain_number', 'gen_light_chain_number', 
                              'predicted_gen_light_seq_label', 'predicted_input_heavy_seq_label']].head(5))
    
    return matching_df, non_matching_df

In [9]:
input_csv_path = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/full_eval_generate_multiple_light_seqs_203276_cls_predictions.csv"
matching_csv_path = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/matching_seqs_multiple_light_seqs_203276_cls_predictions.csv"
non_matching_csv_path = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions.csv"

split_sequences_to_csv(input_csv_path, 
                       matching_csv_path=matching_csv_path, 
                       non_matching_csv_path=non_matching_csv_path)

Reading data from: /ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/full_eval_generate_multiple_light_seqs_203276_cls_predictions.csv
Original dataset shape: (588289, 12)
Original columns: ['heavy_chain_number', 'gen_light_chain_number', 'true_light_seq', 'gen_light_seq', 'input_heavy_seq', 'BLOSUM', 'similarity', 'perplexity', 'label', 'predicted_true_light_seq_label', 'predicted_gen_light_seq_label', 'predicted_input_heavy_seq_label']

=== LABEL DISTRIBUTION BEFORE SPLITTING ===
predicted_gen_light_seq_label distribution:
predicted_gen_light_seq_label
0    358453
1    229836
Name: count, dtype: int64

predicted_input_heavy_seq_label distribution:
predicted_input_heavy_seq_label
1    449521
0    138768
Name: count, dtype: int64

=== SPLITTING RESULTS ===
Matching sequences: 341304
Non-matching sequences: 246985
Total sequences: 588289
Percentage matching: 58.02%
Percentage non-matching: 41.98%

Matchi

(        heavy_chain_number  gen_light_chain_number  \
 0                        1                       1   
 2                        1                       3   
 3                        1                       4   
 6                        1                       7   
 7                        1                       8   
 ...                    ...                     ...   
 588277               58838                       9   
 588278               58838                      10   
 588282               58839                       4   
 588284               58839                       6   
 588286               58839                       8   
 
                                            true_light_seq  \
 0       DIQMTQSPSTLSASVGDRVTITCRASHSINTWLAWYQQKPGKAPKL...   
 2       DIQMTQSPSTLSASVGDRVTITCRASHSINTWLAWYQQKPGKAPKL...   
 3       DIQMTQSPSTLSASVGDRVTITCRASHSINTWLAWYQQKPGKAPKL...   
 6       DIQMTQSPSTLSASVGDRVTITCRASHSINTWLAWYQQKPGKAPKL...   
 7       DIQMTQSPSTLSASVGDRV

In [None]:
def filter_matching_labels(input_csv_path: str, output_csv_path: str = None) -> pd.DataFrame:
    """
    Filter sequences to keep only those where predicted_gen_light_label 
    and predicted_input_heavy_label are the same.
    
    Args:
        input_csv_path (str): Path to the input CSV file
        output_csv_path (str, optional): Path for the filtered output CSV file.
                                       If None, defaults to adding '_filtered' to input filename.
    
    Returns:
        pd.DataFrame: Filtered DataFrame containing only matching label sequences
    """
    
    # Read the CSV file
    print(f"Reading data from: {input_csv_path}")
    df = pd.read_csv(input_csv_path)
    
    print(f"Original dataset shape: {df.shape}")
    print(f"Original columns: {list(df.columns)}")
    
    # Check if required columns exist
    required_columns = ['predicted_gen_light_seq_label', 'predicted_input_heavy_seq_label']
    missing_columns = [col for col in required_columns if col not in df.columns]
    
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")
    
    # Show distribution of labels before filtering
    print(f"\n=== LABEL DISTRIBUTION BEFORE FILTERING ===")
    print("predicted_gen_light_seq_label distribution:")
    print(df['predicted_gen_light_seq_label'].value_counts())
    print("\predicted_input_heavy_seq_label distribution:")
    print(df['predicted_input_heavy_seq_label'].value_counts())
    
    # Filter sequences where both predicted labels match
    filtered_df = df[df['predicted_gen_light_seq_label'] == df['predicted_input_heavy_seq_label']].copy()
    
    print(f"\n=== FILTERING RESULTS ===")
    print(f"Filtered dataset shape: {filtered_df.shape}")
    print(f"Sequences removed: {len(df) - len(filtered_df)}")
    print(f"Sequences kept: {len(filtered_df)}")
    print(f"Percentage kept: {(len(filtered_df) / len(df)) * 100:.2f}%")
    
    # Show distribution of matching labels
    if len(filtered_df) > 0:
        print(f"\n=== MATCHING LABELS DISTRIBUTION ===")
        print("Distribution of matching labels:")
        print(filtered_df['predicted_gen_light_seq_label'].value_counts())
        
        # Show some examples of the filtered data
        print(f"\n=== SAMPLE OF FILTERED DATA ===")
        print(filtered_df[['heavy_chain_number', 'gen_light_chain_number', 
                          'predicted_gen_light_seq_label', 'predicted_input_heavy_seq_label']].head(10))
    else:
        print("WARNING: No sequences found with matching labels!")
    
    # Save filtered data to CSV
    if output_csv_path is None:
        # Create default output filename
        if input_csv_path.endswith('.csv'):
            output_csv_path = input_csv_path.replace('.csv', '_filtered.csv')
        else:
            output_csv_path = input_csv_path + '_filtered.csv'
    
    filtered_df.to_csv(output_csv_path, index=False)
    print(f"\nFiltered data saved to: {output_csv_path}")
    
    return filtered_df

In [5]:
def analyze_label_mismatches(df: pd.DataFrame) -> pd.DataFrame:
    """
    Analyze sequences where labels don't match to understand the differences.
    
    Args:
        df (pd.DataFrame): Original DataFrame with all sequences
        
    Returns:
        pd.DataFrame: DataFrame containing only mismatched sequences
    """
    mismatched_df = df[df['predicted_gen_light_label'] != df['predicted_input_heavy_label']].copy()
    
    print(f"\n=== MISMATCH ANALYSIS ===")
    print(f"Number of mismatched sequences: {len(mismatched_df)}")
    
    if len(mismatched_df) > 0:
        print("\nMismatch patterns:")
        mismatch_patterns = mismatched_df.groupby(['predicted_gen_light_label', 'predicted_input_heavy_label']).size()
        for (gen_label, heavy_label), count in mismatch_patterns.items():
            print(f"  Gen Light: {gen_label} vs Input Heavy: {heavy_label} -> {count} sequences")
        
        print("\nSample of mismatched sequences:")
        print(mismatched_df[['heavy_chain_number', 'gen_light_chain_number', 
                            'predicted_gen_light_label', 'predicted_input_heavy_label', 
                            'BLOSUM', 'similarity']].head())
    
    return mismatched_df

def get_filtering_statistics(original_df: pd.DataFrame, filtered_df: pd.DataFrame) -> dict:
    """
    Generate detailed statistics about the filtering process.
    
    Args:
        original_df (pd.DataFrame): Original DataFrame
        filtered_df (pd.DataFrame): Filtered DataFrame
        
    Returns:
        dict: Dictionary containing filtering statistics
    """
    stats = {
        'original_count': len(original_df),
        'filtered_count': len(filtered_df),
        'removed_count': len(original_df) - len(filtered_df),
        'retention_rate': (len(filtered_df) / len(original_df)) * 100 if len(original_df) > 0 else 0
    }
    
    # Statistics by heavy chain
    heavy_chains_original = original_df['heavy_chain_number'].unique()
    heavy_chains_filtered = filtered_df['heavy_chain_number'].unique()
    
    stats['heavy_chains_original'] = len(heavy_chains_original)
    stats['heavy_chains_retained'] = len(heavy_chains_filtered)
    
    # Average scores comparison
    if len(filtered_df) > 0:
        stats['avg_blosum_original'] = original_df['BLOSUM'].mean()
        stats['avg_blosum_filtered'] = filtered_df['BLOSUM'].mean()
        stats['avg_similarity_original'] = original_df['similarity'].mean()
        stats['avg_similarity_filtered'] = filtered_df['similarity'].mean()
    
    return stats

In [6]:
#input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_eval_generate_multiple_light_seqs_203267_predictions.csv"  
input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_eval_generate_multiple_light_seqs_203276_10k_predictions.csv"  

output_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/matching_prediction_full_eval_generate_multiple_light_seqs_203276_10k_predictions.csv"  

# Filter the data
filtered_df = filter_matching_labels(input_file, output_file)
        
# Read original data for comparison
original_df = pd.read_csv(input_file)
        
# Analyze mismatches
mismatched_df = analyze_label_mismatches(original_df)
        
# Get detailed statistics
stats = get_filtering_statistics(original_df, filtered_df)
        
print(f"\n=== DETAILED STATISTICS ===")
print(f"Original sequences: {stats['original_count']}")
print(f"Filtered sequences: {stats['filtered_count']}")
print(f"Removed sequences: {stats['removed_count']}")
print(f"Retention rate: {stats['retention_rate']:.2f}%")
print(f"Heavy chains in original: {stats['heavy_chains_original']}")
print(f"Heavy chains retained: {stats['heavy_chains_retained']}")
        
if 'avg_blosum_original' in stats:
    print(f"\nAverage BLOSUM score (original): {stats['avg_blosum_original']:.2f}")
    print(f"Average BLOSUM score (filtered): {stats['avg_blosum_filtered']:.2f}")
    print(f"Average similarity (original): {stats['avg_similarity_original']:.2f}%")
    print(f"Average similarity (filtered): {stats['avg_similarity_filtered']:.2f}%")



Reading data from: /ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_eval_generate_multiple_light_seqs_203276_10k_predictions.csv
Original dataset shape: (107229, 11)
Original columns: ['heavy_chain_number', 'gen_light_chain_number', 'true_light_seq', 'gen_light_seq', 'input_heavy_seq', 'BLOSUM', 'similarity', 'perplexity', 'label', 'predicted_input_heavy_label', 'predicted_gen_light_label']

=== LABEL DISTRIBUTION BEFORE FILTERING ===
predicted_gen_light_label distribution:
predicted_gen_light_label
0    71959
1    35270
Name: count, dtype: int64

predicted_input_heavy_label distribution:
predicted_input_heavy_label
1    84509
0    22720
Name: count, dtype: int64

=== FILTERING RESULTS ===
Filtered dataset shape: (53544, 11)
Sequences removed: 53685
Sequences kept: 53544
Percentage kept: 49.93%

=== MATCHING LABELS DISTRIBUTION ===
Distribution of matching labels:
predicted_gen_light_label
1    33047
0    20497
Name: count, dtype:

In [10]:


def csv_to_fasta(csv_file_path: str, output_fasta_path: str = None) -> None:
    """
    Convert CSV antibody data to FASTA format.
    
    For each heavy chain:
    1. Heavy chain sequence
    2. True light chain sequence  
    3. All generated light chain sequences
    
    Args:
        csv_file_path (str): Path to the input CSV file
        output_fasta_path (str, optional): Path for output FASTA file.
                                         If None, uses input filename with .fasta extension
    """
    
    # Read the CSV file
    print(f"Reading data from: {csv_file_path}")
    df = pd.read_csv(csv_file_path)
    
    print(f"Dataset shape: {df.shape}")
    print(f"Columns: {list(df.columns)}")
    
    # Check required columns
    required_columns = ['heavy_chain_number', 'gen_light_chain_number', 
                       'true_light_seq', 'gen_light_seq', 'input_heavy_seq']
    missing_columns = [col for col in required_columns if col not in df.columns]
    
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")
    
    # Group data by heavy chain
    heavy_chain_data = defaultdict(lambda: {
        'input_heavy_seq': '',
        'true_light_seq': '',
        'generated_sequences': []
    })
    
    # Process each row
    for _, row in df.iterrows():
        heavy_chain_num = row['heavy_chain_number']
        
        # Store heavy chain and true light sequences (they should be the same for all rows of same heavy chain)
        heavy_chain_data[heavy_chain_num]['input_heavy_seq'] = row['input_heavy_seq']
        heavy_chain_data[heavy_chain_num]['true_light_seq'] = row['true_light_seq']
        
        # Add generated sequence with its number
        heavy_chain_data[heavy_chain_num]['generated_sequences'].append({
            'gen_number': row['gen_light_chain_number'],
            'sequence': row['gen_light_seq']
        })
    
    # Sort generated sequences by generation number for each heavy chain
    for heavy_chain_num in heavy_chain_data:
        heavy_chain_data[heavy_chain_num]['generated_sequences'].sort(
            key=lambda x: x['gen_number']
        )
    
    # Generate output filename if not provided
    if output_fasta_path is None:
        if csv_file_path.endswith('.csv'):
            output_fasta_path = csv_file_path.replace('.csv', '.fasta')
        else:
            output_fasta_path = csv_file_path + '.fasta'
    
    # Write FASTA file
    print(f"Writing FASTA file to: {output_fasta_path}")
    
    with open(output_fasta_path, 'w') as fasta_file:
        # Process heavy chains in sorted order
        for heavy_chain_num in sorted(heavy_chain_data.keys()):
            data = heavy_chain_data[heavy_chain_num]
            
            # Write heavy chain sequence
            fasta_file.write(f">heavy_chain_{heavy_chain_num}\n")
            fasta_file.write(f"{data['input_heavy_seq']}\n")
            
            # Write true light chain sequence
            fasta_file.write(f">true_light_chain_heavy_chain_{heavy_chain_num}\n")
            fasta_file.write(f"{data['true_light_seq']}\n")
            
            # Write all generated light chain sequences
            for gen_data in data['generated_sequences']:
                fasta_file.write(f">gen_light_{gen_data['gen_number']}_heavy_chain_{heavy_chain_num}\n")
                fasta_file.write(f"{gen_data['sequence']}\n")
    
    # Print summary
    total_sequences = 0
    heavy_chains_processed = len(heavy_chain_data)
    
    for heavy_chain_num, data in heavy_chain_data.items():
        gen_count = len(data['generated_sequences'])
        total_sequences += 2 + gen_count  # heavy + true light + generated sequences
        print(f"Heavy chain {heavy_chain_num}: 1 heavy + 1 true light + {gen_count} generated = {2 + gen_count} sequences")
    
    print(f"\n=== FASTA GENERATION SUMMARY ===")
    print(f"Heavy chains processed: {heavy_chains_processed}")
    print(f"Total sequences written: {total_sequences}")
    print(f"FASTA file saved to: {output_fasta_path}")



In [8]:
def validate_fasta_output(fasta_file_path: str) -> None:
    """
    Validate the generated FASTA file by reading and analyzing it.
    
    Args:
        fasta_file_path (str): Path to the FASTA file to validate
    """
    print(f"\n=== VALIDATING FASTA FILE ===")
    
    sequence_counts = defaultdict(int)
    heavy_chains_found = set()
    
    with open(fasta_file_path, 'r') as fasta_file:
        current_header = None
        sequence_count = 0
        
        for line in fasta_file:
            line = line.strip()
            if line.startswith('>'):
                current_header = line[1:]  # Remove '>'
                sequence_count += 1
                
                # Parse header to categorize
                if current_header.startswith('heavy_chain_'):
                    sequence_counts['heavy_chains'] += 1
                    heavy_chain_num = current_header.split('_')[-1]
                    heavy_chains_found.add(int(heavy_chain_num))
                elif current_header.startswith('true_light_chain_'):
                    sequence_counts['true_light_chains'] += 1
                elif current_header.startswith('gen_light_'):
                    sequence_counts['generated_light_chains'] += 1
    
    print(f"Total sequences in FASTA: {sequence_count}")
    print(f"Heavy chains: {sequence_counts['heavy_chains']}")
    print(f"True light chains: {sequence_counts['true_light_chains']}")
    print(f"Generated light chains: {sequence_counts['generated_light_chains']}")
    print(f"Heavy chain numbers found: {sorted(heavy_chains_found)}")
    
    # Check consistency
    if sequence_counts['heavy_chains'] == sequence_counts['true_light_chains']:
        print("✓ Each heavy chain has a corresponding true light chain")
    else:
        print("✗ Mismatch between heavy chains and true light chains")


In [11]:

def create_fasta_with_metadata(csv_file_path: str, output_fasta_path: str = None, 
                              include_scores: bool = False) -> None:
    """
    Create FASTA file with optional metadata in headers.
    
    Args:
        csv_file_path (str): Path to the input CSV file
        output_fasta_path (str, optional): Path for output FASTA file
        include_scores (bool): Whether to include BLOSUM and similarity scores in headers
    """
    
    df = pd.read_csv(csv_file_path)
    
    # Group data by heavy chain
    heavy_chain_data = defaultdict(lambda: {
        'input_heavy_seq': '',
        'true_light_seq': '',
        'generated_sequences': []
    })
    
    for _, row in df.iterrows():
        heavy_chain_num = row['heavy_chain_number']
        
        heavy_chain_data[heavy_chain_num]['input_heavy_seq'] = row['input_heavy_seq']
        heavy_chain_data[heavy_chain_num]['true_light_seq'] = row['true_light_seq']
        
        # Include metadata if requested
        gen_info = {
            'gen_number': row['gen_light_chain_number'],
            'sequence': row['gen_light_seq']
        }
        
        if include_scores and 'BLOSUM' in df.columns and 'similarity' in df.columns:
            gen_info['blosum'] = row['BLOSUM']
            gen_info['similarity'] = row['similarity']
        
        heavy_chain_data[heavy_chain_num]['generated_sequences'].append(gen_info)
    
    # Sort generated sequences
    for heavy_chain_num in heavy_chain_data:
        heavy_chain_data[heavy_chain_num]['generated_sequences'].sort(
            key=lambda x: x['gen_number']
        )
    
    # Generate output filename
    if output_fasta_path is None:
        suffix = '_with_scores.fasta' if include_scores else '.fasta'
        if csv_file_path.endswith('.csv'):
            output_fasta_path = csv_file_path.replace('.csv', suffix)
        else:
            output_fasta_path = csv_file_path + suffix
    
    # Write FASTA file
    with open(output_fasta_path, 'w') as fasta_file:
        for heavy_chain_num in sorted(heavy_chain_data.keys()):
            data = heavy_chain_data[heavy_chain_num]
            
            # Heavy chain
            fasta_file.write(f">heavy_chain_{heavy_chain_num}\n")
            fasta_file.write(f"{data['input_heavy_seq']}\n")
            
            # True light chain
            fasta_file.write(f">true_light_chain_heavy_chain_{heavy_chain_num}\n")
            fasta_file.write(f"{data['true_light_seq']}\n")
            
            # Generated sequences
            for gen_data in data['generated_sequences']:
                if include_scores and 'blosum' in gen_data:
                    header = f">gen_light_{gen_data['gen_number']}_heavy_chain_{heavy_chain_num}_BLOSUM_{gen_data['blosum']}_similarity_{gen_data['similarity']:.2f}"
                else:
                    header = f">gen_light_{gen_data['gen_number']}_heavy_chain_{heavy_chain_num}"
                
                fasta_file.write(f"{header}\n")
                fasta_file.write(f"{gen_data['sequence']}\n")
    
    print(f"FASTA file with {'metadata' if include_scores else 'standard format'} saved to: {output_fasta_path}")


In [13]:
input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions.csv"  
  
output_path =  "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs"

# Optional: Generate FASTA with scores in headers
create_fasta_with_metadata(input_file, include_scores=False, 
                            output_fasta_path=f"{output_path}/non_matching_seqs_multiple_light_seqs_203276_cls_predictions.fasta")

FASTA file with standard format saved to: /ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions.fasta


In [10]:
 # Input CSV file path 
#input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/matching_prediction_summary_full_eval_generate_multiple_light_seqs_203267.csv"
input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/matching_prediction_full_eval_generate_multiple_light_seqs_203276_10k_predictions.csv"  
  
output_path =  "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy"

# Optional: Generate FASTA with scores in headers
create_fasta_with_metadata(input_file, include_scores=False, 
                            output_fasta_path=f"{output_path}/matching_prediction_full_eval_generate_multiple_light_seqs_203276_10k_predictions.fasta")

FASTA file with standard format saved to: /ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/matching_prediction_full_eval_generate_multiple_light_seqs_203276_10k_predictions.fasta


In [5]:

import json
import pandas as pd
import re
from typing import List, Dict, Any

In [10]:
def parse_single_record(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Parse a single JSON record and extract all required fields.
    
    Args:
        data (Dict[str, Any]): Single JSON record
        
    Returns:
        Dict[str, Any]: Parsed record with extracted fields
    """
    
    record = {}
    
    # Basic fields
    record['sequence_id'] = data.get('Sequence ID', '')
    record['raw_sequence'] = data.get('Raw Sequence', '')
    record['sequence_length'] = data.get('Sequence Length', 0)
    record['domain_classification'] = data.get('Domain Classification', '')
    record['nt_trimmed'] = data.get('NT-Trimmed', '')
    
    # Extract first hit information
    hits = data.get('Hits', [])
    if hits and len(hits) > 0:
        first_hit = hits[0]
        record['first_hit_gene'] = first_hit.get('gene', '')
        record['first_hit_bit_score'] = first_hit.get('bit_score', 0.0)
        record['first_hit_e_value'] = first_hit.get('e_value', 0.0)
        
        # Extract gene name (before the first space or "unnamed")
        gene_full = first_hit.get('gene', '')
        gene_name = extract_gene_name(gene_full)
        record['gene_name'] = gene_name
        
        # Extract light locus
        record['light_locus'] = extract_light_locus(gene_name)
    else:
        record['first_hit_gene'] = ''
        record['first_hit_bit_score'] = 0.0
        record['first_hit_e_value'] = 0.0
        record['gene_name'] = ''
        record['light_locus'] = ''
    
    # Extract region sequences
    regions = ['FR1', 'CDR1', 'FR2', 'CDR2', 'FR3', 'CDR3']
    
    for region in regions:
        region_data = data.get(region, {})
        if isinstance(region_data, dict):
            # Extract AA sequence
            aa_seq = region_data.get('AA', '')
            record[f'{region}_sequence'] = aa_seq
            
            # Extract additional region information
            record[f'{region}_length'] = region_data.get('length', 0.0)
            record[f'{region}_percent_identity'] = region_data.get('percent identity', 0.0)
            record[f'{region}_matches'] = region_data.get('matches', 0.0)
            record[f'{region}_mismatches'] = region_data.get('mismatches', 0.0)
            record[f'{region}_gaps'] = region_data.get('gaps', 0.0)
        else:
            # Handle case where region data might be missing or malformed
            record[f'{region}_sequence'] = ''
            record[f'{region}_length'] = 0.0
            record[f'{region}_percent_identity'] = 0.0
            record[f'{region}_matches'] = 0.0
            record[f'{region}_mismatches'] = 0.0
            record[f'{region}_gaps'] = 0.0
    
    return record


In [6]:
def parse_antibody_json_to_csv(json_file_path: str, output_csv_path: str = None) -> pd.DataFrame:
    """
    Parse JSON file with antibody analysis data and convert to CSV format.
    
    Args:
        json_file_path (str): Path to the input JSON file
        output_csv_path (str, optional): Path for output CSV file.
                                       If None, uses input filename with .csv extension
    
    Returns:
        pd.DataFrame: Parsed data as DataFrame
    """
    
    print(f"Reading JSON data from: {json_file_path}")
    
    # Read JSON file line by line (each line is a separate JSON object)
    parsed_data = []
    
    with open(json_file_path, 'r') as file:
        for line_num, line in enumerate(file, 1):
            line = line.strip()
            if not line:  # Skip empty lines
                continue
                
            try:
                # Parse JSON object from line
                data = json.loads(line)
                parsed_record = parse_single_record(data)
                parsed_data.append(parsed_record)
                
            except json.JSONDecodeError as e:
                print(f"Warning: Could not parse line {line_num}: {e}")
                continue
            except Exception as e:
                print(f"Warning: Error processing line {line_num}: {e}")
                continue
    
    print(f"Successfully parsed {len(parsed_data)} records")
    
    # Create DataFrame
    df = pd.DataFrame(parsed_data)
    
    # Generate output filename if not provided
    if output_csv_path is None:
        if json_file_path.endswith('.json'):
            output_csv_path = json_file_path.replace('.json', '.csv')
        else:
            output_csv_path = json_file_path + '.csv'
    
    # Save to CSV
    df.to_csv(output_csv_path, index=False)
    print(f"CSV file saved to: {output_csv_path}")
    
    # Display summary
    print(f"\n=== PARSING SUMMARY ===")
    print(f"Total records processed: {len(df)}")
    print(f"Columns created: {len(df.columns)}")
    print(f"Column names: {list(df.columns)}")
    
    if len(df) > 0:
        print(f"\n=== SAMPLE DATA ===")
        print(df.head())
        
        # Show locus distribution
        if 'light_locus' in df.columns:
            print(f"\n=== LIGHT LOCUS DISTRIBUTION ===")
            print(df['light_locus'].value_counts())
    
    return df

In [8]:
def extract_gene_name(gene_full: str) -> str:
    """
    Extract gene name from full gene description.
    
    Examples:
    "IGLV1-51*01 unnamed protein product" -> "IGLV1-51*01"
    "IGKV3-20*01 immunoglobulin kappa" -> "IGKV3-20*01"
    
    Args:
        gene_full (str): Full gene description
        
    Returns:
        str: Extracted gene name
    """
    if not gene_full:
        return ''
    
    # Split by space and take the first part (before "unnamed" or other descriptors)
    parts = gene_full.split()
    if parts:
        return parts[0]
    return gene_full

def extract_light_locus(gene_name: str) -> str:
    """
    Extract light chain locus from gene name.
    
    Args:
        gene_name (str): Gene name (e.g., "IGLV1-51*01")
        
    Returns:
        str: Light locus ("IGL", "IGK", or "")
    """
    if not gene_name:
        return ''
    
    # Check for common light chain patterns
    if gene_name.startswith('IGL'):
        return 'IGL'
    elif gene_name.startswith('IGK'):
        return 'IGK'
    elif gene_name.startswith('IGKV'):
        return 'IGK'
    elif gene_name.startswith('IGLV'):
        return 'IGL'
    
    return ''

def analyze_parsed_data(df: pd.DataFrame) -> None:
    """
    Analyze the parsed data and provide detailed statistics.
    
    Args:
        df (pd.DataFrame): Parsed antibody data
    """
    print(f"\n=== DETAILED ANALYSIS ===")
    
    if len(df) == 0:
        print("No data to analyze")
        return
    
    # Basic statistics
    print(f"Total sequences: {len(df)}")
    print(f"Unique sequence IDs: {df['sequence_id'].nunique()}")
    
    # Sequence length statistics
    if 'sequence_length' in df.columns:
        print(f"Average sequence length: {df['sequence_length'].mean():.1f}")
        print(f"Min sequence length: {df['sequence_length'].min()}")
        print(f"Max sequence length: {df['sequence_length'].max()}")
    
    # Light locus distribution
    if 'light_locus' in df.columns:
        print(f"\nLight locus distribution:")
        locus_counts = df['light_locus'].value_counts()
        for locus, count in locus_counts.items():
            percentage = (count / len(df)) * 100
            print(f"  {locus}: {count} ({percentage:.1f}%)")
    
    # Gene name patterns
    if 'gene_name' in df.columns:
        print(f"\nTop 10 most common genes:")
        top_genes = df['gene_name'].value_counts().head(10)
        for gene, count in top_genes.items():
            print(f"  {gene}: {count}")
    
    # Region sequence lengths
    regions = ['FR1', 'CDR1', 'FR2', 'CDR2', 'FR3', 'CDR3']
    print(f"\nAverage region lengths:")
    for region in regions:
        length_col = f'{region}_length'
        if length_col in df.columns:
            avg_length = df[length_col].mean()
            print(f"  {region}: {avg_length:.1f}")

In [9]:
def create_simplified_csv(df: pd.DataFrame, output_path: str) -> None:
    """
    Create a simplified CSV with only the most important columns.
    
    Args:
        df (pd.DataFrame): Full parsed data
        output_path (str): Output path for simplified CSV
    """
    
    # Select key columns
    key_columns = [
        'sequence_id', 'raw_sequence', 'gene_name', 'light_locus',
        'FR1_sequence', 'CDR1_sequence', 'FR2_sequence', 
        'CDR2_sequence', 'FR3_sequence', 'CDR3_sequence',
        'nt_trimmed'
    ]
    
    # Filter to only existing columns
    available_columns = [col for col in key_columns if col in df.columns]
    simplified_df = df[available_columns].copy()
    
    # Save simplified version
    simplified_df.to_csv(output_path, index=False)
    print(f"Simplified CSV saved to: {output_path}")

In [11]:
#input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/matching_prediction_summary_full_eval_generate_multiple_light_seqs_203267.json"  

#input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/matching_seqs_multiple_light_seqs_203276_cls_predictions.json"
#output_csv_path = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed.csv"    


input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions.json"
output_csv_path = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed.csv"    

    
# Parse JSON to CSV
df = parse_antibody_json_to_csv(input_file, 
                                 output_csv_path=output_csv_path)
        
# Analyze the parsed data
analyze_parsed_data(df)
        

Reading JSON data from: /ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions.json
Successfully parsed 339147 records
CSV file saved to: /ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed.csv

=== PARSING SUMMARY ===
Total records processed: 339147
Columns created: 46
Column names: ['sequence_id', 'raw_sequence', 'sequence_length', 'domain_classification', 'nt_trimmed', 'first_hit_gene', 'first_hit_bit_score', 'first_hit_e_value', 'gene_name', 'light_locus', 'FR1_sequence', 'FR1_length', 'FR1_percent_identity', 'FR1_matches', 'FR1_mismatches', 'FR1_gaps', 'CDR1_sequence', 'CDR1_length', 'CDR1_percent_identity', 'CDR1_matches', 'CDR1_mismatches', 'CDR1_gaps', 'FR2_sequence', 'FR2_length', 'FR2_percen

In [12]:
import pandas as pd
import re
from collections import defaultdict
from typing import Dict, List, Any

In [13]:
def extract_heavy_chain_number(sequence_id: str) -> int:
    """
    Extract heavy chain number from sequence ID.
    
    Examples:
    "heavy_chain_32" -> 32
    "true_light_chain_heavy_chain_32" -> 32  
    "gen_light_3_heavy_chain_32" -> 32
    
    Args:
        sequence_id (str): Sequence identifier
        
    Returns:
        int: Heavy chain number, or -1 if not found
    """
    # Look for pattern "heavy_chain_X" at the end
    match = re.search(r'heavy_chain_(\d+)$', sequence_id)
    if match:
        return int(match.group(1))
    
    # Look for pattern "heavy_chain_X" anywhere
    match = re.search(r'heavy_chain_(\d+)', sequence_id)
    if match:
        return int(match.group(1))
    
    return -1

def extract_sequence_type_and_number(sequence_id: str) -> tuple:
    """
    Extract sequence type and generation number from sequence ID.
    
    Args:
        sequence_id (str): Sequence identifier
        
    Returns:
        tuple: (sequence_type, generation_number)
               sequence_type: 'heavy', 'true_light', 'gen_light'
               generation_number: int for generated sequences, None for others
    """
    if sequence_id.startswith('heavy_chain_'):
        return ('heavy', None)
    elif sequence_id.startswith('true_light_chain_'):
        return ('true_light', None)
    elif sequence_id.startswith('gen_light_'):
        # Extract generation number
        match = re.search(r'gen_light_(\d+)_', sequence_id)
        if match:
            return ('gen_light', int(match.group(1)))
        else:
            return ('gen_light', None)
    
    return ('unknown', None)

def reformat_csv_by_heavy_chain(input_csv_path: str, output_csv_path: str = None) -> pd.DataFrame:
    """
    Reformat CSV to group all sequences belonging to the same heavy chain on one row.
    
    Args:
        input_csv_path (str): Path to input CSV file
        output_csv_path (str, optional): Path for output CSV file
        
    Returns:
        pd.DataFrame: Reformatted DataFrame with grouped sequences
    """
    
    print(f"Reading data from: {input_csv_path}")
    df = pd.read_csv(input_csv_path)
    
    print(f"Original dataset shape: {df.shape}")
    print(f"Original columns: {list(df.columns)}")
    
    # Group sequences by heavy chain number
    heavy_chain_groups = defaultdict(lambda: {
        'heavy': None,
        'true_light': None,
        'generated': {}
    })
    
    # Process each row and group by heavy chain
    for idx, row in df.iterrows():
        sequence_id = row['sequence_id']
        heavy_chain_num = extract_heavy_chain_number(sequence_id)
        seq_type, gen_num = extract_sequence_type_and_number(sequence_id)
        
        if heavy_chain_num == -1:
            print(f"Warning: Could not extract heavy chain number from: {sequence_id}")
            continue
        
        # Store the row data based on sequence type
        if seq_type == 'heavy':
            heavy_chain_groups[heavy_chain_num]['heavy'] = row
        elif seq_type == 'true_light':
            heavy_chain_groups[heavy_chain_num]['true_light'] = row
        elif seq_type == 'gen_light':
            heavy_chain_groups[heavy_chain_num]['generated'][gen_num] = row
    
    print(f"Found {len(heavy_chain_groups)} heavy chain groups")
    
    # Create reformatted data
    reformatted_data = []
    
    for heavy_chain_num in sorted(heavy_chain_groups.keys()):
        group_data = heavy_chain_groups[heavy_chain_num]
        
        # Start with overall_id
        row_data = {'overall_id': heavy_chain_num}
        
        # Add heavy chain data
        if group_data['heavy'] is not None:
            heavy_row = group_data['heavy']
            for col in heavy_row.index:
                if col != 'sequence_id':  # Skip sequence_id to avoid confusion
                    row_data[f'heavy_{col}'] = heavy_row[col]
        
        # Add true light chain data
        if group_data['true_light'] is not None:
            true_light_row = group_data['true_light']
            for col in true_light_row.index:
                if col != 'sequence_id':
                    row_data[f'true_light_{col}'] = true_light_row[col]
        
        # Add generated light chain data
        generated_sequences = group_data['generated']
        max_gen_sequences = len(generated_sequences)
        
        # Sort generated sequences by generation number
        sorted_gen_nums = sorted(generated_sequences.keys())
        
        for i, gen_num in enumerate(sorted_gen_nums, 1):
            gen_row = generated_sequences[gen_num]
            for col in gen_row.index:
                if col != 'sequence_id':
                    row_data[f'gen_light_{i}_{col}'] = gen_row[col]
            # Also store the original generation number
            row_data[f'gen_light_{i}_original_number'] = gen_num
        
        # Add count information
        row_data['num_generated_sequences'] = len(generated_sequences)
        
        reformatted_data.append(row_data)
    
    # Create DataFrame
    reformatted_df = pd.DataFrame(reformatted_data)
    
    # Generate output filename if not provided
    if output_csv_path is None:
        if input_csv_path.endswith('.csv'):
            output_csv_path = input_csv_path.replace('.csv', '_reformatted.csv')
        else:
            output_csv_path = input_csv_path + '_reformatted.csv'
    
    # Save reformatted data
    reformatted_df.to_csv(output_csv_path, index=False)
    
    print(f"\n=== REFORMATTING SUMMARY ===")
    print(f"Heavy chain groups processed: {len(heavy_chain_groups)}")
    print(f"Reformatted dataset shape: {reformatted_df.shape}")
    print(f"Output saved to: {output_csv_path}")
    
    # Show sample of reformatted data
    print(f"\n=== SAMPLE OF REFORMATTED DATA ===")
    print("Column names:")
    for i, col in enumerate(reformatted_df.columns):
        print(f"  {i+1}. {col}")
    
    if len(reformatted_df) > 0:
        print(f"\nFirst few rows (showing key columns):")
        key_cols = ['overall_id', 'num_generated_sequences']
        # Add some sequence columns if they exist
        for col in reformatted_df.columns:
            if 'raw_sequence' in col:
                key_cols.append(col)
        
        available_key_cols = [col for col in key_cols if col in reformatted_df.columns]
        print(reformatted_df[available_key_cols].head())
    
    return reformatted_df

def analyze_reformatted_data(df: pd.DataFrame) -> None:
    """
    Analyze the reformatted data and provide statistics.
    
    Args:
        df (pd.DataFrame): Reformatted DataFrame
    """
    print(f"\n=== DETAILED ANALYSIS ===")
    
    if len(df) == 0:
        print("No data to analyze")
        return
    
    print(f"Total heavy chain groups: {len(df)}")
    
    # Analyze number of generated sequences per heavy chain
    if 'num_generated_sequences' in df.columns:
        gen_seq_stats = df['num_generated_sequences'].describe()
        print(f"\nGenerated sequences per heavy chain:")
        print(f"  Average: {gen_seq_stats['mean']:.1f}")
        print(f"  Min: {int(gen_seq_stats['min'])}")
        print(f"  Max: {int(gen_seq_stats['max'])}")
        print(f"  Distribution:")
        
        gen_counts = df['num_generated_sequences'].value_counts().sort_index()
        for count, freq in gen_counts.items():
            percentage = (freq / len(df)) * 100
            print(f"    {int(count)} generated sequences: {freq} heavy chains ({percentage:.1f}%)")
    
    # Check for missing data
    missing_heavy = df[[col for col in df.columns if col.startswith('heavy_')]].isnull().all(axis=1).sum()
    missing_true_light = df[[col for col in df.columns if col.startswith('true_light_')]].isnull().all(axis=1).sum()
    
    print(f"\nMissing data:")
    print(f"  Heavy chains without heavy sequence data: {missing_heavy}")
    print(f"  Heavy chains without true light sequence data: {missing_true_light}")

def create_sequence_only_format(df: pd.DataFrame, output_path: str) -> None:
    """
    Create a simplified version with only sequence data (no analysis results).
    
    Args:
        df (pd.DataFrame): Reformatted DataFrame
        output_path (str): Output path for sequence-only CSV
    """
    
    # Select only sequence-related columns
    sequence_columns = ['overall_id', 'num_generated_sequences']
    
    # Add sequence columns
    for col in df.columns:
        if any(seq_col in col for seq_col in ['raw_sequence', 'gene_name', 'light_locus']):
            sequence_columns.append(col)
        elif any(region in col for region in ['FR1_sequence', 'CDR1_sequence', 'FR2_sequence', 
                                             'CDR2_sequence', 'FR3_sequence', 'CDR3_sequence']):
            sequence_columns.append(col)
    
    # Filter to only existing columns
    available_columns = [col for col in sequence_columns if col in df.columns]
    sequence_df = df[available_columns].copy()
    
    # Save sequence-only version
    sequence_df.to_csv(output_path, index=False)
    print(f"Sequence-only CSV saved to: {output_path}")

def validate_reformatted_data(original_df: pd.DataFrame, reformatted_df: pd.DataFrame) -> None:
    """
    Validate that the reformatting preserved all data correctly.
    
    Args:
        original_df (pd.DataFrame): Original DataFrame
        reformatted_df (pd.DataFrame): Reformatted DataFrame
    """
    print(f"\n=== VALIDATION ===")
    
    # Count total sequences in original data
    original_heavy_chains = set()
    for _, row in original_df.iterrows():
        heavy_chain_num = extract_heavy_chain_number(row['sequence_id'])
        if heavy_chain_num != -1:
            original_heavy_chains.add(heavy_chain_num)
    
    print(f"Heavy chains in original data: {len(original_heavy_chains)}")
    print(f"Heavy chains in reformatted data: {len(reformatted_df)}")
    
    if len(original_heavy_chains) == len(reformatted_df):
        print("✓ All heavy chains preserved")
    else:
        print("✗ Some heavy chains may be missing")
        missing = original_heavy_chains - set(reformatted_df['overall_id'])
        if missing:
            print(f"  Missing heavy chains: {sorted(missing)}")



In [14]:
# Input CSV file path - update this to your actual file path
#input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/matching_prediction_summary_full_eval_generate_multiple_light_seqs_203267_parsed.csv" 
#input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed.csv" 

input_file = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed.csv" 

output_csv_path = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted.csv"  

# Read original data for validation
original_df = pd.read_csv(input_file)
        
# Reformat the CSV
reformatted_df = reformat_csv_by_heavy_chain(input_file, output_csv_path=None)
        
# Analyze reformatted data
analyze_reformatted_data(reformatted_df)
        
# Validate the reformatting
validate_reformatted_data(original_df, reformatted_df)

Reading data from: /ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed.csv
Original dataset shape: (339147, 46)
Original columns: ['sequence_id', 'raw_sequence', 'sequence_length', 'domain_classification', 'nt_trimmed', 'first_hit_gene', 'first_hit_bit_score', 'first_hit_e_value', 'gene_name', 'light_locus', 'FR1_sequence', 'FR1_length', 'FR1_percent_identity', 'FR1_matches', 'FR1_mismatches', 'FR1_gaps', 'CDR1_sequence', 'CDR1_length', 'CDR1_percent_identity', 'CDR1_matches', 'CDR1_mismatches', 'CDR1_gaps', 'FR2_sequence', 'FR2_length', 'FR2_percent_identity', 'FR2_matches', 'FR2_mismatches', 'FR2_gaps', 'CDR2_sequence', 'CDR2_length', 'CDR2_percent_identity', 'CDR2_matches', 'CDR2_mismatches', 'CDR2_gaps', 'FR3_sequence', 'FR3_length', 'FR3_percent_identity', 'FR3_matches', 'FR3_mismatches', 'FR3_gaps', 'CDR3_sequence', 'CDR3_length', '

In [15]:
def quick_extract(input_csv: str, output_csv: str = None):
    """
    Quick extraction function.
    
    Args:
        input_csv (str): Input CSV file path
        output_csv (str, optional): Output CSV file path
    """
    df = pd.read_csv(input_csv)
    
    columns_to_keep = [
        'overall_id',
        'heavy_raw_sequence', 
        'true_light_raw_sequence',
        'true_light_first_hit_gene',
        'true_light_gene_name',
        'true_light_light_locus',
        'gen_light_1_raw_sequence',
        'gen_light_1_nt_trimmed',
        'gen_light_1_first_hit_gene',
        'gen_light_1_light_locus',
        'similarity',
        'predicted_input_heavy_label',
        'predicted_gen_light_label',
    ]
    
    existing_columns = [col for col in columns_to_keep if col in df.columns]
    filtered_df = df[existing_columns]
    
    if output_csv is None:
        output_csv = input_csv.replace('.csv', '_simple.csv')
    
    filtered_df.to_csv(output_csv, index=False)
    print(f"Saved {len(filtered_df)} rows with {len(existing_columns)} columns to: {output_csv}")
    
    return filtered_df




In [16]:

#quick_extract(input_csv="/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/matching_prediction_summary_full_eval_generate_multiple_light_seqs_203267_parsed_reformatted.csv",  output_csv = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/matching_prediction_summary_full_eval_generate_multiple_light_seqs_203267_parsed_reformatted_rel_cols.csv")

#quick_extract(input_csv="/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted.csv",  output_csv = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted_rel_cols.csv")

quick_extract(input_csv="/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted.csv",  output_csv = "/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted_rel_cols.csv")




  df = pd.read_csv(input_csv)


Saved 46081 rows with 10 columns to: /ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/full_test_set_multiple_light_seqs/non_matching_seqs_multiple_light_seqs_203276_cls_predictions_parsed_reformatted_rel_cols.csv


Unnamed: 0,overall_id,heavy_raw_sequence,true_light_raw_sequence,true_light_first_hit_gene,true_light_gene_name,true_light_light_locus,gen_light_1_raw_sequence,gen_light_1_nt_trimmed,gen_light_1_first_hit_gene,gen_light_1_light_locus
0,1,QLQVQESGPGLVKPSETLSLTCTVSGASSSIKKYYWGWIRQSPGKG...,DIQMTQSPSTLSASVGDRVTITCRASHSINTWLAWYQQKPGKAPKL...,IGKV1-5*03 unnamed protein product,IGKV1-5*03,IGK,SSELTQDPAVSVALGQTVRITCQGDSLRSYYASWYQQKPGQAPLLV...,SSELTQDPAVSVALGQTVRITCQGDSLRSYYASWYQQKPGQAPLLV...,IGLV3-19*01 unnamed protein product,IGL
1,2,QLQLQESGPGLVKPSETLSLICSVSGGSITTSSYYWAWIRQSPGKG...,QSALTQPASVSGSPGQSITISCSGTSDDIGDYNYVSWYQQHPGKAP...,IGLV2-14*03 unnamed protein product,IGLV2-14*03,IGL,DIVMTQSPDSLAVSLGERATINCKSSQSVLYSSDNKNYLGWYQQKP...,DIVMTQSPDSLAVSLGERATINCKSSQSVLYSSDNKNYLGWYQQKP...,IGKV3-20*01 unnamed protein product,IGK
2,3,EVQLVESGGDLVRPGGSLRLSCAASGFPFSRAWMTWVRQAPGKGLD...,DIQMTQSPSSLSAFMGDRVTITCRASQSPKTYLHWYQQRPGGVPKL...,IGKV1-39*01 unnamed protein product,IGKV1-39*01,IGK,EIVLTQSPGTLSLSPGERATLSCRASQSVSRSYFAWYQQKPGQAPR...,EIVLTQSPGTLSLSPGERATLSCRASQSVSRSYFAWYQQKPGQAPR...,IGKV3-20*01 unnamed protein product,IGK
3,4,EVQLLESGGGLVQPGGSLRLSCAASGFNFANYDMSWVRQAPGKGLE...,QTVVTQEPSFSVSPGGTVTLTCGLSSGSVSTKYYPSWYQQTPGQAP...,IGLV8-61*01 unnamed protein product,IGLV8-61*01,IGL,DIQMTQSPSTLSASVGDRVTITCRASQSISNWLAWYQQKPGKAPKF...,DIQMTQSPSTLSASVGDRVTITCRASQSISNWLAWYQQKPGKAPKF...,IGKV1-5*03 unnamed protein product,IGK
4,5,QVQLQESGPGLVKPSGTLSLTCVVSGGSISTNNWWSWVRQPPGKGL...,EIVLTQSPGTLSLSPGERATLSCRASQSISNTYLAWYRQKPGQAPR...,IGKV3-20*01 unnamed protein product,IGKV3-20*01,IGK,EIVLTQSPGTLSLSPGERATLSCRASQSVSSRYLAWYQQKPGQAPR...,EIVLTQSPGTLSLSPGERATLSCRASQSVSSRYLAWYQQKPGQAPR...,IGKV3-20*01 unnamed protein product,IGK
...,...,...,...,...,...,...,...,...,...,...
46076,58832,QVQLVQSGAEVKKPGASVKVSCKASGYTFDVYGISWVRQAPGQGLE...,DIQMTQSPSTLSASVGDRVTITCQASQSINNWLAWYQQKPGKAPKL...,IGKV1-5*03 unnamed protein product,IGKV1-5*03,IGK,SYVLTQPPSVSVAPGQTARITCGGNNIGSKSVHWYQQKPGQAPVLV...,SYVLTQPPSVSVAPGQTARITCGGNNIGSKSVHWYQQKPGQAPVLV...,IGLV3-21*02 unnamed protein product,IGL
46077,58834,EVQLVESGGGLVQPGGSLRLSCAASGFIFSSFGMHWVRKAPGKGLE...,DIQMTQSPSSLSASVGDRVTITCKSSLSLLNSGNQKNYLTWYQEKP...,IGKV1-39*01 unnamed protein product,IGKV1-39*01,IGK,DIVMTQTPLSLPVTPGEPASISCRSSQSLLDSDDGNTYLDWYLQKP...,DIVMTQTPLSLPVTPGEPASISCRSSQSLLDSDDGNTYLDWYLQKP...,IGKV2-28*01 unnamed protein product,IGK
46078,58835,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSDNDFSWVRQAPGQGLE...,QSALTQPRSVSGSPGQSVTISCTGTSSDVGGYNYVSWYQQHPGKAP...,IGLV2-11*01 unnamed protein product,IGLV2-11*01,IGL,DIQMTQSPSTLSASVGDRVTITCRASQSISSWLAWYQQKPGKAPKL...,DIQMTQSPSTLSASVGDRVTITCRASQSISSWLAWYQQKPGKAPKL...,IGKV1-5*03 unnamed protein product,IGK
46079,58837,EVQLVESGGGLVQPGGSLRLSCAASGFTFSTYYMSWVRQAPGKGLE...,DIQMTQSPSSLSASVGDRVTITCRASQGISNILAWYQQKPGKAPKL...,IGKV1-NL1*01 unnamed protein product,IGKV1-NL1*01,IGK,SYELTQPPSVSVSPGQTARITCSGDALPKKYAYWYQQKSGQAPVLV...,SYELTQPPSVSVSPGQTARITCSGDALPKKYAYWYQQKSGQAPVLV...,IGLV3-10*01 unnamed protein product,IGL


In [31]:
df = pd.read_csv("/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2GPT/multiple_light_seqs_from_single_heavy/matching_prediction_summary_full_eval_generate_multiple_light_seqs_203267_parsed_reformatted.csv")

In [36]:
# print first 2 rows the column gen_light_1_nt_trimmed
print(df.columns)
df['gen_light_1_nt_trimmed'].head(2)


Index(['overall_id', 'heavy_raw_sequence', 'heavy_sequence_length',
       'heavy_domain_classification', 'heavy_nt_trimmed',
       'heavy_first_hit_gene', 'heavy_first_hit_bit_score',
       'heavy_first_hit_e_value', 'heavy_gene_name', 'heavy_light_locus',
       ...
       'gen_light_10_FR3_matches', 'gen_light_10_FR3_mismatches',
       'gen_light_10_FR3_gaps', 'gen_light_10_CDR3_sequence',
       'gen_light_10_CDR3_length', 'gen_light_10_CDR3_percent_identity',
       'gen_light_10_CDR3_matches', 'gen_light_10_CDR3_mismatches',
       'gen_light_10_CDR3_gaps', 'gen_light_10_original_number'],
      dtype='object', length=552)


0    DIQLTQSPSFLSASVGDRVTITCRASQGISNFLAWYQQKPGKAPEL...
1    DIVMTQSPDSLAVSLGERATINCKSSQSLFHSSNNKNYLAWYQQKP...
Name: gen_light_1_nt_trimmed, dtype: object