In [2]:
import pandas as pd
import numpy as np
import random
from typing import List, Dict, Tuple
import ast
import re

class ProteinDataAugmenter:
    def __init__(self, csv_path: str):
        """
        Initialize the augmenter with the dataset path.
        
        Args:
            csv_path: Path to the merged_protein_dataset_ext.csv file
        """
        self.csv_path = csv_path
        self.df = None
        self.amino_acids = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 
                           'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
        
    def load_data(self):
        """Load the CSV file into a pandas DataFrame."""
        print("Loading dataset...")
        self.df = pd.read_csv(self.csv_path)
        print(f"Dataset loaded: {len(self.df)} entries")
        return self.df
    
    def get_chain_columns(self) -> List[str]:
        """Get all chain sequence columns from the dataset."""
        sequence_cols = [col for col in self.df.columns if col.endswith('_sequence')]
        return sequence_cols
    
    def get_binding_array_columns(self) -> List[str]:
        """Get all binding array columns from the dataset."""
        binding_cols = [col for col in self.df.columns if col.endswith('_binding_array')]
        return binding_cols
    
    def parse_binding_array(self, array_str) -> List[int]:
        """
        Parse binding array string to list of integers.
        
        Args:
            array_str: String representation of binding array
            
        Returns:
            List of integers representing binding positions
        """
        if pd.isna(array_str) or array_str == '':
            return []
        
        try:
            # Handle string representation of list
            if isinstance(array_str, str):
                # Remove brackets and split by comma
                array_str = array_str.strip('[]')
                if not array_str:
                    return []
                return [int(x.strip()) for x in array_str.split(',')]
            else:
                return ast.literal_eval(array_str)
        except:
            return []
    
    def has_no_binding(self, row) -> bool:
        """
        Check if a row has no binding interactions (all binding arrays are 0s).
        
        Args:
            row: DataFrame row
            
        Returns:
            True if no binding interactions exist
        """
        binding_cols = self.get_binding_array_columns()
        
        for col in binding_cols:
            if pd.notna(row[col]) and row[col] != '':
                binding_array = self.parse_binding_array(row[col])
                if binding_array and any(x == 1 for x in binding_array):
                    return False
        
        return True
    
    def mutate_sequence(self, sequence: str, num_mutations: int) -> str:
        """
        Introduce random point mutations in a protein sequence.
        
        Args:
            sequence: Original protein sequence
            num_mutations: Number of mutations to introduce
            
        Returns:
            Mutated protein sequence
        """
        if not sequence or pd.isna(sequence):
            return sequence
            
        sequence_list = list(sequence)
        sequence_length = len(sequence_list)
        
        if sequence_length == 0:
            return sequence
            
        # Randomly select positions to mutate
        mutation_positions = random.sample(range(sequence_length), 
                                         min(num_mutations, sequence_length))
        
        for pos in mutation_positions:
            original_aa = sequence_list[pos]
            # Choose a different amino acid
            possible_mutations = [aa for aa in self.amino_acids if aa != original_aa]
            new_aa = random.choice(possible_mutations)
            sequence_list[pos] = new_aa
        
        return ''.join(sequence_list)
    
    def augment_row(self, row, augmentation_id: int) -> pd.Series:
        """
        Create an augmented version of a row with mutations.
        
        Args:
            row: Original DataFrame row
            augmentation_id: Unique identifier for this augmentation
            
        Returns:
            Augmented row
        """
        new_row = row.copy()
        
        # Add augmentation identifier to PDB ID
        original_pdb = row['pdb_id']
        new_row['pdb_id'] = f"{original_pdb}_aug_{augmentation_id}"
        
        # Get sequence columns and introduce mutations
        sequence_cols = self.get_chain_columns()
        num_mutations = random.randint(1, 10)
        
        for col in sequence_cols:
            if pd.notna(row[col]) and row[col] != '':
                new_row[col] = self.mutate_sequence(row[col], num_mutations)
        
        return new_row
    
    def identify_no_binding_entries(self) -> pd.DataFrame:
        """
        Identify entries with no binding interactions.
        
        Returns:
            DataFrame containing only no-binding entries
        """
        print("Identifying entries with no binding interactions...")
        no_binding_mask = self.df.apply(self.has_no_binding, axis=1)
        no_binding_df = self.df[no_binding_mask].copy()
        
        print(f"Found {len(no_binding_df)} entries with no binding interactions")
        print(f"Found {len(self.df) - len(no_binding_df)} entries with binding interactions")
        
        return no_binding_df
    
    def augment_dataset(self, augmentation_factor: int = 2) -> pd.DataFrame:
        """
        Augment the dataset by creating mutated versions of no-binding entries.
        
        Args:
            augmentation_factor: How many augmented versions to create for each no-binding entry
            
        Returns:
            Augmented dataset
        """
        if self.df is None:
            raise ValueError("Dataset not loaded. Call load_data() first.")
        
        print(f"Starting augmentation with factor {augmentation_factor}...")
        
        # Identify no-binding entries
        no_binding_df = self.identify_no_binding_entries()
        
        # Create augmented versions
        augmented_rows = []
        
        for idx, (_, row) in enumerate(no_binding_df.iterrows()):
            for aug_num in range(augmentation_factor):
                augmentation_id = f"{idx}_{aug_num}"
                augmented_row = self.augment_row(row, augmentation_id)
                augmented_rows.append(augmented_row)
        
        # Combine original dataset with augmented data
        augmented_df = pd.DataFrame(augmented_rows)
        final_df = pd.concat([self.df, augmented_df], ignore_index=True)
        
        print(f"Augmentation complete:")
        print(f"  Original dataset: {len(self.df)} entries")
        print(f"  Augmented entries: {len(augmented_df)} entries")
        print(f"  Final dataset: {len(final_df)} entries")
        
        return final_df
    
    def validate_augmentation(self, augmented_df: pd.DataFrame) -> Dict[str, int]:
        """
        Validate the augmentation results.
        
        Args:
            augmented_df: The augmented dataset
            
        Returns:
            Dictionary with validation statistics
        """
        print("\nValidating augmentation...")
        
        # Count original vs augmented entries
        original_entries = len([pdb for pdb in augmented_df['pdb_id'] if '_aug_' not in str(pdb)])
        augmented_entries = len([pdb for pdb in augmented_df['pdb_id'] if '_aug_' in str(pdb)])
        
        # Count binding vs no-binding in final dataset
        binding_count = 0
        no_binding_count = 0
        
        for _, row in augmented_df.iterrows():
            if self.has_no_binding(row):
                no_binding_count += 1
            else:
                binding_count += 1
        
        validation_stats = {
            'total_entries': len(augmented_df),
            'original_entries': original_entries,
            'augmented_entries': augmented_entries,
            'binding_entries': binding_count,
            'no_binding_entries': no_binding_count
        }
        
        print(f"Validation Results:")
        print(f"  Total entries: {validation_stats['total_entries']}")
        print(f"  Original entries: {validation_stats['original_entries']}")
        print(f"  Augmented entries: {validation_stats['augmented_entries']}")
        print(f"  Entries with binding: {validation_stats['binding_entries']}")
        print(f"  Entries without binding: {validation_stats['no_binding_entries']}")
        print(f"  Binding ratio: {validation_stats['binding_entries']/(validation_stats['binding_entries']+validation_stats['no_binding_entries']):.3f}")
        
        return validation_stats
    
    def save_augmented_dataset(self, augmented_df: pd.DataFrame, output_path: str):
        """
        Save the augmented dataset to a CSV file.
        
        Args:
            augmented_df: The augmented dataset
            output_path: Path to save the augmented dataset
        """
        print(f"\nSaving augmented dataset to {output_path}...")
        augmented_df.to_csv(output_path, index=False)
        print("Dataset saved successfully!")


def main():
    """Main function to run the augmentation process."""
    
    # Configuration
    input_file = "merged_protein_dataset_ext.csv"
    output_file = "augmented_protein_dataset.csv"
    augmentation_factor = 20  # Adjust this value as needed
    
    # Initialize augmenter
    augmenter = ProteinDataAugmenter(input_file)
    
    # Load data             
    augmenter.load_data()
    
    # Perform augmentation
    augmented_dataset = augmenter.augment_dataset(augmentation_factor=augmentation_factor)
    
    # Validate results
    validation_stats = augmenter.validate_augmentation(augmented_dataset)
    
    # Save augmented dataset
    augmenter.save_augmented_dataset(augmented_dataset, output_file)
    
    return augmented_dataset, validation_stats


if __name__ == "__main__":
    # Set random seed for reproducibility
    random.seed(42)
    np.random.seed(42)
    
    # Run augmentation
    augmented_data, stats = main()
    
    print(f"\nAugmentation process completed successfully!")

Loading dataset...


  self.df = pd.read_csv(self.csv_path)


Dataset loaded: 97541 entries
Starting augmentation with factor 20...
Identifying entries with no binding interactions...
Found 1062 entries with no binding interactions
Found 96479 entries with binding interactions
Augmentation complete:
  Original dataset: 97541 entries
  Augmented entries: 21240 entries
  Final dataset: 118781 entries

Validating augmentation...
Validation Results:
  Total entries: 118781
  Original entries: 97541
  Augmented entries: 21240
  Entries with binding: 96479
  Entries without binding: 22302
  Binding ratio: 0.812

Saving augmented dataset to augmented_protein_dataset.csv...
Dataset saved successfully!

Augmentation process completed successfully!
