In [None]:
#!/usr/bin/env python3
"""
Complete BindingDB Binding Site Extraction Script
=================================================

This script automatically:
1. Loads BindingDB data from the project root directory ./BindingDB_All.tsv
2. Extracts PDB structures
3. Identifies binding sites with full sequence information
4. Creates binding arrays showing which residues participate in binding
5. Exports comprehensive results to CSV

Requirements: pip install biopython pandas numpy matplotlib

Usage:
    python binding_site_extraction.py

Author: Auto-generated binding site extraction system
"""

import pandas as pd
import numpy as np
import os
import time
import warnings
import json
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
warnings.filterwarnings('ignore')

# Check and import BioPython
try:
    from Bio.PDB import PDBParser, PDBList
    from Bio.PDB.NeighborSearch import NeighborSearch
    from Bio.SeqUtils import seq1
    BIOPYTHON_AVAILABLE = True
    print("✓ BioPython available")
except ImportError:
    print("✗ BioPython not found. Install with: pip install biopython")
    BIOPYTHON_AVAILABLE = False

class BindingDBLoader:
    """Robust BindingDB data loader with sampling options.

    Both methods are static; the class only groups them under one name.
    """

    @staticmethod
    def get_total_lines(filepath):
        """Return the number of data rows in *filepath* (line count minus header).

        The file is read in binary mode, so the count does not depend on the
        text encoding and no encoding fallback is needed. Quoted fields that
        contain embedded newlines would be over-counted; the value is only used
        as an estimate when converting ``sample_fraction`` into a row count.

        Returns:
            int >= 0, or None if the file cannot be opened or read.
        """
        try:
            with open(filepath, 'rb') as f:
                total_lines = sum(1 for _ in f)
        except OSError:
            return None
        # Subtract the header line; an empty file has 0 data rows, not -1.
        return max(total_lines - 1, 0)

    @staticmethod
    def load_bindingdb_data(filepath, max_rows=None, sample_fraction=None):
        """Load a BindingDB TSV file, optionally only its first rows.

        Args:
            filepath: Path to the tab-separated BindingDB export.
            max_rows: Maximum number of data rows to read (None = all rows).
            sample_fraction: If given, overrides *max_rows* with
                ``int(total_rows * sample_fraction)``. Rows are taken from the
                top of the file (``nrows``), not randomly sampled.

        Returns:
            pandas.DataFrame with the loaded rows.

        Raises:
            ValueError: if *sample_fraction* is negative.
            Exception: if every loading strategy fails (chained to the last
                underlying error).
        """
        print(f"Loading BindingDB data from: {filepath}")

        # Determine how many rows to load
        if sample_fraction is not None:
            if sample_fraction < 0:
                raise ValueError(f"sample_fraction must be >= 0, got {sample_fraction}")
            total_lines = BindingDBLoader.get_total_lines(filepath)
            if total_lines is not None:
                max_rows = int(total_lines * sample_fraction)
                print(f"  File has ~{total_lines:,} data rows")
                print(f"  Loading {sample_fraction:.1%} = {max_rows:,} rows")
            else:
                print("  Could not count lines, assuming ~1,000,000 rows")
                max_rows = int(sample_fraction * 1000000)

        # `is not None` so that max_rows=0 is reported as 0 rows, not "all".
        if max_rows is not None:
            print(f"  Target rows to load: {max_rows:,}")
        else:
            print("  Loading all rows...")

        strategies = [
            # Strategy 1: Standard loading with error skipping
            lambda: pd.read_csv(filepath, sep='\t', on_bad_lines='skip',
                                low_memory=False, nrows=max_rows, encoding='utf-8'),

            # Strategy 2: Try different encoding
            lambda: pd.read_csv(filepath, sep='\t', on_bad_lines='skip',
                                low_memory=False, nrows=max_rows, encoding='latin1'),

            # Strategy 3: Python engine (slower but more flexible)
            lambda: pd.read_csv(filepath, sep='\t', engine='python',
                                on_bad_lines='skip', nrows=max_rows, encoding='utf-8'),
        ]

        last_error = None
        for i, strategy in enumerate(strategies, 1):
            try:
                print(f"  Trying loading strategy {i}...")
                data = strategy()
            except Exception as e:  # pandas raises many unrelated types here
                print(f"  Strategy {i} failed: {e}")
                last_error = e
                continue
            print(f"✓ Successfully loaded {len(data):,} rows with {len(data.columns)} columns")
            return data

        raise Exception("All loading strategies failed. Check file format and path.") from last_error

class ThreadedBindingSiteExtractor:
    """
    Threaded binding site extraction with full sequence information.

    Per PDB ID (see ``process_single_pdb``):
      1. download the PDB-format file into ``output_dir`` (reused if present),
      2. parse it and use only the FIRST model, so NMR ensembles do not
         duplicate every ligand and contact,
      3. pick hetero groups with >= 5 atoms that are not in ``EXCLUDED_LIGANDS``,
      4. collect standard (non-hetero) residues with any atom within
         ``contact_cutoff`` Å of any ligand atom,
      5. map those residues onto per-chain sequences as 0/1 binding arrays.

    Results accumulate in ``self.binding_sites`` (one dict per ligand site) and
    ``self.failed_pdbs`` (one dict per PDB ID that did not succeed). Worker
    threads append to both lists under ``self.lock``.
    """

    # Hetero-group names that are not treated as drug-like ligands.
    # Residue names in PDB-format files are at most 3 characters, so the longer
    # spellings ('TCEP', 'HEPES', 'TRIS', 'GLYC') never match and are kept only
    # for backward compatibility; 'EPE' is the PDB component ID for HEPES.
    EXCLUDED_LIGANDS = frozenset({
        'HOH', 'WAT', 'H2O',                                         # water
        'NA', 'CL', 'MG', 'CA', 'ZN', 'FE', 'MN', 'CO', 'NI', 'CU',  # ions
        'SO4', 'PO4', 'HPO', 'PO3',                                  # phosphates/sulfates
        'GOL', 'EDO', 'PEG', 'MPD', 'DMS', 'ACT', 'ACE',             # solvents/additives
        'MSE', 'SEP', 'TPO',                                         # modified amino acids
        'UNK', 'UNL',                                                # unknown molecules
        'BME', 'DTT', 'TCEP',                                        # reducing agents
        'TRS', 'HEPES', 'EPE', 'MES', 'TRIS',                        # buffers
        'GLY', 'GLYC',                                               # free glycine / legacy entry
    })

    # Characters that separate several PDB IDs inside one BindingDB cell.
    _PDB_ID_SEPARATORS = str.maketrans(',;|', '   ')

    def __init__(self, output_dir="pdb_structures", contact_cutoff=5.0, max_workers=4):
        self.output_dir = output_dir
        self.contact_cutoff = contact_cutoff
        self.max_workers = max_workers

        os.makedirs(output_dir, exist_ok=True)

        # Shared result storage; appended to under self.lock by worker threads.
        self.binding_sites = []
        self.failed_pdbs = []
        self.lock = threading.Lock()

        if BIOPYTHON_AVAILABLE:
            # self.parser is kept for callers that use it directly, but
            # process_single_pdb does not share it between threads: PDBParser
            # keeps per-parse state (its StructureBuilder) on the instance.
            self.parser = PDBParser(QUIET=True)
            self.pdb_list = PDBList()

        print(f"Initialized extractor with {max_workers} threads, {contact_cutoff}Å cutoff")

    @staticmethod
    def _first_model(structure):
        """Return the first model of *structure*, or None if it has none."""
        return next(iter(structure), None)

    @staticmethod
    def _one_letter(resname):
        """Return exactly one sequence character for a residue name ('X' if unknown)."""
        try:
            code = seq1(resname)
        except KeyError:
            return 'X'
        # One character per residue keeps sequence indices aligned with
        # residue_numbers and binding_array positions.
        return code if len(code) == 1 else 'X'

    def clean_pdb_ids(self, pdb_list):
        """Extract unique 4-character alphanumeric PDB IDs from raw cell values.

        Values may hold several IDs separated by ',', ';', '|' or whitespace;
        non-string values (e.g. NaN) are skipped. IDs are upper-cased and
        returned in first-seen order, so slicing the result is reproducible
        across runs (a set would reorder it per process).
        """
        unique_ids = {}
        for raw in pdb_list:
            if not isinstance(raw, str):
                continue
            for token in raw.translate(self._PDB_ID_SEPARATORS).split():
                token = token.strip().upper()
                if len(token) == 4 and token.isalnum():
                    unique_ids.setdefault(token, None)
        return list(unique_ids)

    def download_pdb_structure(self, pdb_id):
        """Return the local path of the PDB-format file for *pdb_id*.

        A file already present as ``output_dir/pdb<id>.ent`` is reused;
        otherwise it is fetched with ``PDBList.retrieve_pdb_file``. Returns
        None if the download fails or produces no file.
        """
        expected_filename = os.path.join(self.output_dir, f"pdb{pdb_id.lower()}.ent")

        if os.path.exists(expected_filename):
            return expected_filename

        try:
            filename = self.pdb_list.retrieve_pdb_file(
                pdb_id, pdir=self.output_dir, file_format='pdb'
            )
            return filename if filename and os.path.exists(filename) else None
        except Exception:
            return None

    def extract_full_sequences(self, structure):
        """Build one-letter sequences for every chain of the first model.

        Returns:
            dict chain_id -> {'sequence', 'length', 'residue_numbers',
            'residue_ids'} where residue_ids are (number, insertion_code)
            pairs aligned index-for-index with the sequence. Chains without
            standard residues are omitted.
        """
        sequences = {}
        model = self._first_model(structure)
        if model is None:
            return sequences

        for chain in model:
            # Standard residues ordered by number; the sort is stable, so
            # insertion-coded residues (52, 52A, ...) keep their file order.
            residues = sorted((res for res in chain if res.id[0] == ' '),
                              key=lambda res: res.id[1])
            if not residues:
                continue
            sequence = ''.join(self._one_letter(res.resname) for res in residues)
            sequences[chain.id] = {
                'sequence': sequence,
                'length': len(sequence),
                'residue_numbers': [res.id[1] for res in residues],
                'residue_ids': [(res.id[1], res.id[2]) for res in residues],
            }

        return sequences

    def find_ligands(self, structure):
        """Return hetero residues of the first model that look like real ligands.

        A residue qualifies if it is a hetero group (``id[0] != ' '``), its
        name is not in ``EXCLUDED_LIGANDS`` and it has at least 5 atoms.
        """
        model = self._first_model(structure)
        if model is None:
            return []
        return [
            residue
            for chain in model
            for residue in chain
            if residue.id[0] != ' '
            and residue.resname not in self.EXCLUDED_LIGANDS
            and sum(1 for _ in residue.get_atoms()) >= 5
        ]

    def find_contact_residues(self, structure, ligand, neighbor_search=None):
        """Return standard residues with any atom within ``contact_cutoff`` Å of *ligand*.

        Args:
            structure: Bio.PDB Structure containing *ligand*.
            ligand: hetero Residue (as returned by ``find_ligands``).
            neighbor_search: optional prebuilt NeighborSearch over the first
                model's atoms; pass one to avoid rebuilding the index for every
                ligand of the same structure.

        Returns:
            list of distinct Residue objects (order not meaningful).
        """
        if neighbor_search is None:
            model = self._first_model(structure)
            atoms = list(model.get_atoms()) if model is not None else []
            if not atoms:
                return []
            neighbor_search = NeighborSearch(atoms)

        # dict: de-duplicates residues while keeping a deterministic order.
        contacts = {}
        for lig_atom in ligand.get_atoms():
            for atom in neighbor_search.search(lig_atom.coord, self.contact_cutoff):
                residue = atom.parent
                if residue.id[0] == ' ':  # standard (non-hetero) residue
                    contacts[residue] = None
        return list(contacts)

    def map_contacts_to_sequence(self, contact_residues, chain_sequences):
        """Create per-chain 0/1 binding arrays from contact residues.

        Residues are matched by (number, insertion code) when
        ``chain_sequences`` carries 'residue_ids' (as built by
        ``extract_full_sequences``), so 52 and 52A map to different positions.
        Otherwise, matching falls back to the residue number alone (first
        occurrence). Contacts not present in a chain's sequence are ignored.

        Returns:
            dict chain_id -> {'sequence', 'binding_array', 'binding_positions'
            (sorted 0-based indices), 'binding_residue_numbers',
            'num_binding_residues'}.
        """
        binding_info = {}

        for chain_id, seq_info in chain_sequences.items():
            residue_numbers = seq_info['residue_numbers']
            use_icode = 'residue_ids' in seq_info
            keys = seq_info['residue_ids'] if use_icode else residue_numbers

            # O(1) lookup instead of list.index per contact; first occurrence
            # wins, matching list.index semantics.
            position_of = {}
            for position, key in enumerate(keys):
                position_of.setdefault(key, position)

            hit_positions = set()
            for residue in contact_residues:
                if residue.parent.id != chain_id:
                    continue
                key = (residue.id[1], residue.id[2]) if use_icode else residue.id[1]
                position = position_of.get(key)
                if position is not None:
                    hit_positions.add(position)

            binding_positions = sorted(hit_positions)
            binding_array = [0] * seq_info['length']
            for position in binding_positions:
                binding_array[position] = 1

            binding_info[chain_id] = {
                'sequence': seq_info['sequence'],
                'binding_array': binding_array,
                'binding_positions': binding_positions,
                'binding_residue_numbers': [residue_numbers[p] for p in binding_positions],
                'num_binding_residues': len(binding_positions),
            }

        return binding_info

    def extract_binding_site_info(self, pdb_id, structure, ligand, contact_residues, binding_info):
        """Assemble the result record for one ligand binding site.

        Args:
            pdb_id: PDB identifier stored in the record.
            structure: unused; kept for interface compatibility.
            ligand: hetero Residue.
            contact_residues: residues returned by ``find_contact_residues``.
            binding_info: output of ``map_contacts_to_sequence``.

        Returns:
            dict with ligand identity, ligand centroid, per-residue minimum
            atom-atom distance to the ligand, a bounding-box "volume" of the
            contact residues' atoms (a crude proxy, not a pocket volume) and
            the per-chain binding arrays.
        """
        lig_coords = np.array([atom.coord for atom in ligand.get_atoms()])
        if len(lig_coords):
            center = lig_coords.mean(axis=0)
            ligand_center = {'x': float(center[0]), 'y': float(center[1]), 'z': float(center[2])}
        else:
            ligand_center = {'x': 0, 'y': 0, 'z': 0}

        # Sort by chain, residue number, insertion code for stable output.
        sorted_contacts = sorted(contact_residues,
                                 key=lambda res: (res.parent.id, res.id[1], res.id[2]))

        contact_residue_info = []
        for residue in sorted_contacts:
            res_coords = np.array([atom.coord for atom in residue.get_atoms()])
            if len(res_coords) and len(lig_coords):
                # All residue-atom x ligand-atom distances in one vectorized step.
                pair_dists = np.linalg.norm(
                    res_coords[:, None, :] - lig_coords[None, :, :], axis=-1
                )
                min_dist = float(pair_dists.min())
            else:
                min_dist = float('inf')

            contact_residue_info.append({
                'chain': residue.parent.id,
                'resname': residue.resname,
                'resnum': residue.id[1],
                'insertion_code': residue.id[2] if residue.id[2] != ' ' else '',
                'distance_to_ligand': min_dist
            })

        # Axis-aligned bounding box of all contact-residue atoms.
        volume = 0.0
        if contact_residues:
            site_coords = np.array([atom.coord for residue in contact_residues
                                    for atom in residue.get_atoms()])
            if len(site_coords) >= 4:
                extent = site_coords.max(axis=0) - site_coords.min(axis=0)
                volume = float(np.prod(extent))

        return {
            'pdb_id': pdb_id,
            'ligand_name': ligand.resname,
            'ligand_chain': ligand.parent.id,
            'ligand_number': ligand.id[1],
            'num_contact_residues': len(contact_residues),
            'contact_residues': contact_residue_info,
            # Concatenated 3-letter residue names in sorted contact order.
            'contact_sequence': ''.join(res['resname'] for res in contact_residue_info),
            'ligand_center': ligand_center,
            'binding_site_volume': volume,
            'chain_sequences': binding_info
        }

    def _record_failure(self, pdb_id, reason):
        """Append a failure record to ``self.failed_pdbs`` (under the lock) and return a copy."""
        failure = {'pdb_id': pdb_id, 'status': 'error', 'reason': reason}
        with self.lock:
            self.failed_pdbs.append(failure)
        return dict(failure)

    def process_single_pdb(self, pdb_id):
        """Download, parse and analyse one PDB entry; safe to run in worker threads.

        Every non-success outcome (missing BioPython, download failure, no
        ligands, parse/analysis exception) is recorded in ``self.failed_pdbs``.

        Returns:
            {'pdb_id', 'status': 'success', 'binding_sites', 'num_sites'} or
            {'pdb_id', 'status': 'error', 'reason'}.
        """
        if not BIOPYTHON_AVAILABLE:
            return self._record_failure(pdb_id, 'BioPython not available')

        try:
            filename = self.download_pdb_structure(pdb_id)
            if not filename:
                return self._record_failure(pdb_id, 'Download failed')

            # Fresh parser per call: a shared PDBParser keeps per-parse state on
            # the instance and can mix atoms from files parsed concurrently.
            # QUIET=True suppresses its construction warnings (no thread-unsafe
            # warnings.catch_warnings() needed).
            parser = PDBParser(QUIET=True)
            structure = parser.get_structure(pdb_id, filename)

            chain_sequences = self.extract_full_sequences(structure)
            ligands = self.find_ligands(structure)
            if not ligands:
                return self._record_failure(pdb_id, 'No ligands found')

            # One spatial index per structure, reused for every ligand. Ligands
            # exist, so the first model exists and has atoms.
            model = self._first_model(structure)
            neighbor_search = NeighborSearch(list(model.get_atoms()))

            pdb_binding_sites = []
            for ligand in ligands:
                contact_residues = self.find_contact_residues(structure, ligand, neighbor_search)
                if not contact_residues:
                    continue
                binding_info = self.map_contacts_to_sequence(contact_residues, chain_sequences)
                pdb_binding_sites.append(self.extract_binding_site_info(
                    pdb_id, structure, ligand, contact_residues, binding_info
                ))

            with self.lock:
                self.binding_sites.extend(pdb_binding_sites)

            return {
                'pdb_id': pdb_id,
                'status': 'success',
                'binding_sites': pdb_binding_sites,
                'num_sites': len(pdb_binding_sites)
            }

        except Exception as e:
            error_msg = "Duplicate residue numbering" if "defined twice" in str(e) else str(e)
            return self._record_failure(pdb_id, error_msg)

    def process_threaded(self, pdb_ids, max_structures=None, verbose=True):
        """Process raw PDB ID strings with a thread pool.

        Args:
            pdb_ids: iterable of raw cell values (see ``clean_pdb_ids``).
            max_structures: keep only the first N unique IDs (falsy = all).
            verbose: print progress every 10 structures and non-download failures.

        Returns:
            (self.binding_sites, self.failed_pdbs)
        """
        clean_ids = self.clean_pdb_ids(pdb_ids)

        if max_structures:
            clean_ids = clean_ids[:max_structures]

        total = len(clean_ids)
        print(f"Processing {total} unique PDB IDs with {self.max_workers} threads...")

        completed = 0
        succeeded = 0
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_pdb = {executor.submit(self.process_single_pdb, pdb_id): pdb_id
                             for pdb_id in clean_ids}

            for future in as_completed(future_to_pdb):
                pdb_id = future_to_pdb[future]
                completed += 1

                try:
                    # as_completed only yields finished futures; no timeout needed.
                    result = future.result()
                except Exception as e:
                    self._record_failure(pdb_id, str(e))
                    if verbose:
                        print(f"Exception processing {pdb_id}: {str(e)}")
                    continue

                if result['status'] == 'success':
                    succeeded += 1

                if verbose and completed % 10 == 0:
                    with self.lock:
                        sites_so_far = len(self.binding_sites)
                    print(f"Completed {completed}/{total} structures... "
                          f"Found {sites_so_far} binding sites so far")

                if result['status'] != 'success' and verbose and 'Download failed' not in result['reason']:
                    print(f"Failed {pdb_id}: {result['reason']}")

        print(f"\nCompleted! Found {len(self.binding_sites)} binding sites")
        print(f"Successfully processed: {succeeded}/{total} structures")
        print(f"Failed: {total - succeeded} structures")

        return self.binding_sites, self.failed_pdbs

    def export_enhanced_results(self, output_file="binding_sites_with_sequences.csv"):
        """Write one CSV row per binding site and return the DataFrame (None if empty).

        Per-chain columns are named ``chain_<id>_*``; different rows can carry
        different chains, so cells for chains absent from a site are NaN. List
        values are JSON-encoded, and contact residues are '|'-joined.
        """
        if not self.binding_sites:
            print("No binding sites to export")
            return None

        print(f"Exporting {len(self.binding_sites)} binding sites...")

        flattened_data = []
        for site in self.binding_sites:
            contacts = site['contact_residues']
            row = {
                'pdb_id': site['pdb_id'],
                'ligand_name': site['ligand_name'],
                'ligand_chain': site['ligand_chain'],
                'ligand_number': site['ligand_number'],
                'num_contact_residues': site['num_contact_residues'],
                'contact_sequence': site['contact_sequence'],
                'ligand_center_x': site['ligand_center']['x'],
                'ligand_center_y': site['ligand_center']['y'],
                'ligand_center_z': site['ligand_center']['z'],
                'binding_site_volume': site['binding_site_volume'],
                'contact_residues_list': '|'.join(
                    f"{res['resname']}{res['resnum']}{res['insertion_code']}" for res in contacts
                ),
                'contact_residue_numbers': '|'.join(str(res['resnum']) for res in contacts),
                'contact_distances': '|'.join(f"{res['distance_to_ligand']:.2f}" for res in contacts),
            }

            for chain_id, chain_info in site['chain_sequences'].items():
                row.update({
                    f'chain_{chain_id}_sequence': chain_info['sequence'],
                    f'chain_{chain_id}_length': len(chain_info['sequence']),
                    f'chain_{chain_id}_binding_array': json.dumps(chain_info['binding_array']),
                    f'chain_{chain_id}_binding_positions': json.dumps(chain_info['binding_positions']),
                    f'chain_{chain_id}_binding_residue_numbers': json.dumps(chain_info['binding_residue_numbers']),
                    f'chain_{chain_id}_num_binding_residues': chain_info['num_binding_residues']
                })

            flattened_data.append(row)

        df = pd.DataFrame(flattened_data)
        df.to_csv(output_file, index=False)

        print(f"✓ Exported to {output_file}")
        return df

def extract_binding_sites_from_bindingdb(bindingdb_file, max_structures=100, max_workers=4,
                                        output_file="binding_sites_with_sequences.csv",
                                        sample_fraction=0.5):
    """
    Complete pipeline: Load BindingDB data and extract binding sites

    Args:
        bindingdb_file: Path to BindingDB TSV file
        max_structures: Maximum PDB structures to process
        max_workers: Number of threads to use
        output_file: Output CSV filename
        sample_fraction: Fraction of dataset to load (0.5 = half, 1.0 = all);
            rows are taken from the top of the file.

    Returns:
        (binding_sites, results_df, summary). All three are None if BioPython
        is missing, loading fails, or no PDB IDs are found. results_df is None
        when nothing was extracted, and summary is None when there are no
        binding sites.
    """
    print("=" * 60)
    print("AUTOMATED BINDING SITE EXTRACTION FROM BINDINGDB")
    print("=" * 60)

    if not BIOPYTHON_AVAILABLE:
        print("ERROR: BioPython is required. Install with: pip install biopython")
        return None, None, None

    # Step 1: Load BindingDB data (with sampling)
    try:
        loader = BindingDBLoader()
        data = loader.load_bindingdb_data(bindingdb_file, sample_fraction=sample_fraction)
    except Exception as e:
        print(f"Failed to load BindingDB data: {e}")
        return None, None, None

    # Step 2: Extract PDB IDs
    pdb_columns = ['PDB ID(s) for Ligand-Target Complex', 'PDB ID(s) of Target Chain 1']
    all_pdb_ids = []
    for col in pdb_columns:
        if col in data.columns:
            all_pdb_ids.extend(data[col].dropna().tolist())

    # Only the PDB ID strings are needed from here on. Drop the (very wide)
    # DataFrame so it is not kept alive during the long extraction.
    del data

    if not all_pdb_ids:
        print("No PDB IDs found in the data")
        return None, None, None

    print(f"Found {len(all_pdb_ids)} PDB entries in the loaded data")

    # Step 3: Extract binding sites
    extractor = ThreadedBindingSiteExtractor(max_workers=max_workers)
    binding_sites, failed_pdbs = extractor.process_threaded(
        all_pdb_ids, max_structures=max_structures
    )

    # Step 4: Export results
    results_df = extractor.export_enhanced_results(output_file)

    # Step 5: Generate summary
    summary = None
    if binding_sites:
        ligand_counts = pd.Series([site['ligand_name'] for site in binding_sites]).value_counts()
        contact_stats = pd.Series(
            [site['num_contact_residues'] for site in binding_sites]
        ).describe()
        unique_pdbs = {site['pdb_id'] for site in binding_sites}

        print("\n" + "=" * 40)
        print("SUMMARY STATISTICS")
        print("=" * 40)
        print(f"Total binding sites: {len(binding_sites)}")
        print(f"Unique PDBs processed: {len(unique_pdbs)}")
        print(f"Average contact residues: {contact_stats['mean']:.1f}")
        print(f"Contact residue range: {contact_stats['min']:.0f} - {contact_stats['max']:.0f}")

        print("\nTop 5 most common ligands:")
        for ligand, count in ligand_counts.head(5).items():
            print(f"  {ligand}: {count} binding sites")

        summary = {
            'binding_sites': len(binding_sites),
            'failed_pdbs': len(failed_pdbs),
            'ligand_counts': ligand_counts,
            'contact_stats': contact_stats,
            'data_fraction_loaded': sample_fraction
        }

    return binding_sites, results_df, summary

def analyze_results(results_df):
    """Print a short human-readable summary of extracted binding sites.

    Prints the total row count. It then walks the first 5 rows of
    *results_df* (the DataFrame produced by ``export_enhanced_results``). For
    each row, it prints the PDB/ligand identifiers and the first chain that
    has both a sequence and a binding array. It stops after 3 rows with chain
    data have been shown.

    Args:
        results_df: DataFrame with ``pdb_id``, ``ligand_name``,
            ``num_contact_residues`` and ``chain_<id>_*`` columns, or None.
    """
    if results_df is None:
        print("No results to analyze")
        return

    print("\n" + "=" * 40)
    print("BINDING SITE ANALYSIS")
    print("=" * 40)

    print(f"Total binding sites extracted: {len(results_df)}")

    prefix, suffix = 'chain_', '_sequence'
    examples_found = 0
    # Number examples by row so rows without chain data don't reuse a label.
    for example_number, (_, row) in enumerate(results_df.head(5).iterrows(), start=1):
        print(f"\nExample {example_number}:")
        print(f"  PDB: {row['pdb_id']}, Ligand: {row['ligand_name']}")
        print(f"  Contact residues: {row['num_contact_residues']}")

        for col in row.index:
            # Only chain_<id>_sequence columns (skips e.g. contact_sequence).
            if not (col.startswith(prefix) and col.endswith(suffix)) or pd.isna(row[col]):
                continue
            chain_id = col[len(prefix):-len(suffix)]
            binding_array_col = f'chain_{chain_id}_binding_array'
            if binding_array_col not in row or pd.isna(row[binding_array_col]):
                continue

            sequence = row[col]
            binding_array = json.loads(row[binding_array_col])
            binding_positions = json.loads(row[f'chain_{chain_id}_binding_positions'])
            truncated = '...' if len(binding_positions) > 10 else ''
            binding_pct = sum(binding_array) / len(binding_array) * 100 if binding_array else 0.0

            print(f"  Chain {chain_id}: {len(sequence)} residues")
            print(f"  Binding positions: {binding_positions[:10]}{truncated}")
            print(f"  Binding %: {binding_pct:.1f}%")
            examples_found += 1
            break

        if examples_found >= 3:
            break

def main():
    """Run the extraction pipeline with the configuration below and print a report.

    Edit the constants at the top of this function to change the input file,
    the amount of data loaded, and the parallelism.
    """
    # Configuration
    BINDINGDB_FILE = "BindingDB_All.tsv"  # Path to the BindingDB TSV export
    MAX_STRUCTURES = 100000   # Cap on unique PDB IDs processed (lower it, e.g. 100, for a quick test)
    MAX_WORKERS = 10          # Number of download/parse threads
    SAMPLE_FRACTION = 0.04    # Load the first 4% of rows (0.5 = half, 1.0 = all)
    OUTPUT_FILE = "binding_sites_with_sequences.csv"

    print("Starting automated binding site extraction...")
    print(f"File: {BINDINGDB_FILE}")
    print(f"Sample fraction: {SAMPLE_FRACTION:.1%} of dataset")
    print(f"Max structures: {MAX_STRUCTURES}")
    print(f"Threads: {MAX_WORKERS}")

    # Run extraction
    binding_sites, results_df, summary = extract_binding_sites_from_bindingdb(
        BINDINGDB_FILE,
        max_structures=MAX_STRUCTURES,
        max_workers=MAX_WORKERS,
        output_file=OUTPUT_FILE,
        sample_fraction=SAMPLE_FRACTION
    )

    # Analyze results
    analyze_results(results_df)

    print("\n" + "=" * 60)
    if results_df is not None:
        print(f"SUCCESS! Extracted {len(results_df)} binding sites")
        print(f"Results saved to: {OUTPUT_FILE}")
        print(f"Data fraction used: {SAMPLE_FRACTION:.1%}")
        print("\nNext steps:")
        print("1. Increase SAMPLE_FRACTION to 1.0 for full dataset")
        print("2. Increase MAX_STRUCTURES for more PDB processing")
        print("3. Analyze binding arrays using json.loads()")
        print("4. Use binding positions for machine learning")
    else:
        print("FAILED! Check error messages above")
    print("=" * 60)

if __name__ == "__main__":
    main()

✓ BioPython available
Starting automated binding site extraction...
File: BindingDB_All.tsv
Sample fraction: 4.0% of dataset
Max structures: 100000
Threads: 10
AUTOMATED BINDING SITE EXTRACTION FROM BINDINGDB
Loading BindingDB data from: BindingDB_All.tsv
  File has ~3,046,134 data rows
  Loading 4.0% = 121,845 rows
  Target rows to load: 121,845
  Trying loading strategy 1...
✓ Successfully loaded 121,845 rows with 640 columns
Found 104553 PDB entries in the loaded data
Initialized extractor with 10 threads, 5.0Å cutoff
Processing 32845 unique PDB IDs with 10 threads...
Downloading PDB structure '8sor'...
Downloading PDB structure '2qet'...
Downloading PDB structure '8e59'...
Downloading PDB structure '6w8h'...
Downloading PDB structure '4txo'...
Downloading PDB structure '1yw8'...
Downloading PDB structure '4gme'...
Failed 5XDH: Duplicate residue numbering
Downloading PDB structure '1pyg'...
Failed 6E3Z: Duplicate residue numbering
Failed 1YW8: Duplicate residue numbering
Downloading