# 🧬💊 Boltz-2 Pipeline for Drug Screening, Mutant Discovery, and Protein Structure Prediction

A comprehensive interface for protein-ligand screening and structure prediction using Boltz-2, with integrated mutation discovery from public
databases for rapid assessment of mutation effects on drug binding.

## 🔬 What This Tool Does

This notebook provides a complete workflow for computational drug screening:

1. 🧬 Protein Structure Prediction: Uses Boltz-2, a state-of-the-art AI model, to predict 3D protein structures from amino acid sequences
2. 🧪 Ligand Binding Analysis: Predicts how small molecules (drugs) bind to proteins and estimates binding strength
3. ⚗️ Co-factor Integration: Includes essential biological molecules (ATP, NAD, heme) that proteins need to function
4. 🎯 Targeted Design: Constrains binding to specific protein regions for precision drug design
5. 🔍 Mutation Discovery Mode: Automatically queries public databases (UniProt) to find known mutations for your protein
6. 🧬 Mutation Studies: Generates protein variants to evaluate how mutations affect drug binding
7. 📊 Comprehensive Analysis: Provides binding affinities, confidence scores, and downloadable 3D structures

## 📚 Quick Tutorial: How to Use This Tool

### Step 1: 🔧 Setup & Environment
**Run the Installation & Setup cell first:**
- ☁️ **Google Colab mode** (default): Automatically installs dependencies and enables file uploads
- 📁 **Local mode**: For running on your own computer with pre-installed packages

**Local Mode Setup:**
1. Check the "Running locally" checkbox in the setup cell
2. Place these files in your working directory:
   - `protein.fasta` - Protein sequences in FASTA format
   - `drug.fasta` - SMILES strings in FASTA format (optional)
   - `template.cif` - Template structure file (optional)
3. Optionally uncheck "Download prerequisites" if packages are already installed
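
Before running in local mode, it can help to confirm which of these files are actually present. The following is a minimal sketch, not part of the notebook: the helper name `check_local_inputs` is hypothetical, and only the three file names come from the list above.

```python
from pathlib import Path

# File names taken from the Local Mode Setup list above.
EXPECTED_FILES = ("protein.fasta", "drug.fasta", "template.cif")

def check_local_inputs(workdir="."):
    """Return {filename: True/False} for each expected input file in workdir."""
    base = Path(workdir)
    return {name: (base / name).is_file() for name in EXPECTED_FILES}

# Report what is available in the current working directory.
for name, present in check_local_inputs().items():
    print(f"{name}: {'found' if present else 'missing'}")
```

A missing `drug.fasta` or `template.cif` is not necessarily a problem, since both are listed as optional.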

### Step 2: 🧬 Configure Your Screening
Choose your operation mode and configure inputs:

🔸 **Input Modes:**
- **Protein-Drug Screening**: Paste a single wild-type protein sequence directly
- **Protein-Drug Screening - Upload FASTA File**:
  - ☁️ Colab: Upload file with multiple protein sequences
  - 📁 Local: Uses `protein.fasta` from working directory
- **Mutations Mode**: Start with a wild-type sequence and generate specific mutant variants
- **Structure-only Mode**: Predict protein structure without ligands (for protein folding studies)

🔸 **Mutation Discovery Mode:**
- **Database Query**: Automatically search UniProt for known mutations of your protein
- **Manual Selection**: Choose specific mutations from the suggested list or enter your own
- **Detailed Analysis**: Get mutation impact predictions and severity assessments
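
Mutations in this notebook are written as codes such as `A4V` (wild-type residue, position, mutant residue). As an illustration of how such a code maps onto a sequence, here is a small sketch. The function `apply_mutation` is a hypothetical helper written for this example and is not the notebook's own variant generator; the `chain_start` argument assumes residue numbering begins at that value.

```python
import re

STANDARD_AA = set("ACDEFGHIKLMNPQRSTVWY")

def apply_mutation(sequence, mutation_code, chain_start=1):
    """Apply a point mutation such as 'A4V' to a sequence numbered from chain_start."""
    match = re.fullmatch(r"([A-Z])(\d+)([A-Z])", mutation_code.strip().upper())
    if not match:
        raise ValueError(f"Unrecognized mutation code: {mutation_code!r}")
    wt_aa, position, mut_aa = match.group(1), int(match.group(2)), match.group(3)
    if mut_aa not in STANDARD_AA:
        raise ValueError(f"{mut_aa!r} is not a standard amino acid")
    idx = position - chain_start  # residue number -> 0-based string index
    if not 0 <= idx < len(sequence):
        raise ValueError(f"Position {position} is outside the sequence")
    if sequence[idx] != wt_aa:
        raise ValueError(
            f"Expected {wt_aa} at position {position}, found {sequence[idx]}"
        )
    return sequence[:idx] + mut_aa + sequence[idx + 1:]

print(apply_mutation("MKTAYIAKQR", "A4V"))  # MKTVYIAKQR
```

Checking the wild-type residue before substituting catches numbering mismatches between a database entry and the sequence you pasted.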

🔸 **Ligand Input Options:**
- **Text Input**: Enter SMILES strings
- **Upload FASTA File**:
  - ☁️ Colab: Upload FASTA format file with named SMILES entries
  - 📁 Local: Uses `drug.fasta` from working directory
- **Co-factor Integration**: Add biological co-factors (ATP, NAD, heme) for realistic binding
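
The exact layout of a SMILES FASTA file is not spelled out here. The sketch below assumes the usual FASTA convention: a `>name` header line followed by the SMILES string for that ligand. The parser `parse_smiles_fasta` is a hypothetical helper for illustration, not the notebook's loader.

```python
def parse_smiles_fasta(text):
    """Parse '>name' headers, each followed by a SMILES string, into {name: smiles}."""
    entries = {}
    name = None
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue
        if line.startswith(">"):
            # Fall back to a generated name if the header is empty
            name = line[1:].strip() or f"ligand_{len(entries) + 1}"
            entries[name] = ""
        elif name is not None:
            entries[name] += line  # tolerate SMILES wrapped over several lines
    # Drop headers that had no SMILES string
    return {k: v for k, v in entries.items() if v}

example = """>aspirin
CC(=O)OC1=CC=CC=C1C(=O)O
>caffeine
CN1C=NC2=C1C(=O)N(C(=O)N2C)C
"""
for ligand_name, smiles in parse_smiles_fasta(example).items():
    print(ligand_name, smiles)
```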

### Step 3: ⚙️ Advanced Parameters (Optional)
Configure additional features:
- **Co-factors**: Add essential molecules like ATP, NAD, or heme
- **Binding Constraints**: Target specific protein regions
- **Post-Translational Modifications**: Add chemical modifications to proteins
- **Template Structures**:
  - ☁️ Colab: Upload .cif template files
  - 📁 Local: Uses `template.cif` from working directory
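
To make these options concrete, the sketch below assembles an input dictionary of the kind Boltz reads from YAML. Treat it as an assumption-laden illustration: the key names (`version`, `sequences`, `ligand`, `smiles`, `ccd`, `constraints`, `pocket`, `properties`, `affinity`) reflect my understanding of the Boltz input schema and should be checked against the Boltz documentation. The function `build_boltz_input` and the co-factor chain ID `Y` are hypothetical; chain `X` for the ligand follows this notebook's convention.

```python
import json

def build_boltz_input(protein_seqs, ligand_smiles, cofactor_ccd=None, pocket_residues=None):
    """Assemble a Boltz-style input dict (sketch; verify keys against the Boltz docs)."""
    sequences = []
    for i, seq in enumerate(protein_seqs):
        # Protein chains are lettered A, B, C, ... in input order
        sequences.append({"protein": {"id": chr(65 + i), "sequence": seq.upper()}})
    # Chain X is reserved for the ligand in this notebook
    sequences.append({"ligand": {"id": "X", "smiles": ligand_smiles}})
    if cofactor_ccd:
        # Chain ID Y for the co-factor is an assumption made for this example
        sequences.append({"ligand": {"id": "Y", "ccd": cofactor_ccd}})
    config = {
        "version": 1,
        "sequences": sequences,
        "properties": [{"affinity": {"binder": "X"}}],
    }
    if pocket_residues:
        # Each contact is a [chain_id, residue_number] pair
        config["constraints"] = [{
            "pocket": {
                "binder": "X",
                "contacts": [[chain, resnum] for chain, resnum in pocket_residues],
            }
        }]
    return config

config = build_boltz_input(
    ["MKTAYIAKQRQISFVKSHFSRQ"], "CCO", cofactor_ccd="ATP", pocket_residues=[("A", 5)]
)
print(json.dumps(config, indent=2))  # JSON is also valid YAML
```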

### Step 4: 🚀 Configure & Generate
Set computational parameters:
- **GPU Settings**: Enable/disable GPU acceleration
- **Sampling Parameters**: Control prediction quality vs speed
- **Error Handling**: Automatic retry settings

### Step 5: 🌠 Run Predictions
- Execute the Structure Predictions cell
- Monitor progress (typically 5-10 minutes per protein-ligand pair using Boltz-2)
- The system generates 3D structures and binding predictions
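
For planning a larger screen, the per-pair figure above gives a rough total. The sketch below simply multiplies it out. The function `estimate_runtime_minutes` is a hypothetical helper, and counting one job per protein when there are no ligands (structure-only mode) is an assumption.

```python
def estimate_runtime_minutes(n_proteins, n_ligands, minutes_per_pair=(5, 10)):
    """Return a (low, high) estimate in minutes for a full protein x ligand screen."""
    # With no ligands (structure-only), assume one job per protein
    jobs = n_proteins * max(1, n_ligands)
    low, high = minutes_per_pair
    return jobs * low, jobs * high

low, high = estimate_runtime_minutes(n_proteins=3, n_ligands=4)
print(f"12 pairs: roughly {low}-{high} minutes")  # 12 pairs: roughly 60-120 minutes
```

Actual times depend on sequence length, sampling settings, and the GPU available, so treat the result as an order-of-magnitude guide.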

### Step 6: 📊 View Results
- Check the Results & Download cell for:
  - Binding affinity predictions (how strongly drugs bind)
  - 3D structure files (.pdb format)
  - Confidence scores and quality metrics
  - Downloadable results package

## 🏠 Local Execution vs ☁️ Cloud Execution

| Feature | Local Mode 📁 | Colab Mode ☁️ |
|---------|---------------|---------------|
| **Setup** | Check "Running locally" | Default mode |
| **Dependencies** | Pre-installed (install step can be skipped) | Auto-installs |
| **File Input** | Predefined files | Interactive uploads |
| **Protein Data** | `protein.fasta` | Upload or manual entry |
| **Drug Data** | `drug.fasta` | Upload or text input |
| **Templates** | `template.cif` | Upload .cif files |
| **Resources** | Your hardware | Google's servers |

---

In [None]:
#@title 1️⃣ 🔧 Installation & Setup { display-mode: "form" }
#@markdown Run this cell first to install required dependencies and load all functions:

#@markdown **Download prerequisites**: Install required packages. Disable only if running locally with packages pre-installed. It is normal to be disconnected temporarily from the Colab runtime during installation.
download_prerequisites = True  #@param {type:"boolean"}

#@markdown **Running locally**: Skip Colab-specific installations and use local files (protein.fasta, drug.fasta, template.cif)
running_locally = False  #@param {type:"boolean"}

# Minimal approach - work with what's already available in Colab
import sys
import subprocess
import importlib.util

print("🔄 Working with Colab's existing environment...")

# Import core libraries that are pre-installed
import json
import os
import re
import math
import glob
import time
import zipfile
from typing import List, Dict, Tuple, Optional, Union
from collections import OrderedDict
from datetime import datetime
import copy
import requests
import urllib.parse
import yaml

# Try importing pandas/numpy without installing anything new
try:
    import pandas as pd
    import numpy as np
    print("✅ Pandas and NumPy loaded from Colab")
    PANDAS_AVAILABLE = True
    NUMPY_AVAILABLE = True
except Exception as e:
    print(f"⚠️ Pandas/NumPy issue: {e}")
    print("🔄 Continuing without pandas - using basic Python data structures")
    PANDAS_AVAILABLE = False
    NUMPY_AVAILABLE = False
    # Create minimal substitutes
    class MockPandas:
        def __init__(self):
            self.__version__ = "unavailable"
    pd = MockPandas()

    class MockNumPy:
        def __init__(self):
            self.__version__ = "unavailable"
    np = MockNumPy()

# Google Colab specific imports - skip if running locally
if not running_locally:
    try:
        from google.colab import files
        from google.colab import widgets
        COLAB_AVAILABLE = True
        print("✅ Google Colab environment detected")
    except ImportError:
        COLAB_AVAILABLE = False
        print("⚠️ Warning: Google Colab not available. Some features may not work.")
else:
    COLAB_AVAILABLE = False
    print("📁 Local mode: Using local files instead of Colab uploads")
    # Create mock files module for local execution
    class MockFiles:
        def upload(self):
            print("📁 Local mode: Looking for predefined files in current directory")
            return {}
    files = MockFiles()

# Try optional packages without forcing installation if download_prerequisites is disabled
PLOTLY_AVAILABLE = False
YAML_AVAILABLE = False
BOLTZ_AVAILABLE = False

# Try to import RDKit for enhanced SMILES validation
RDKIT_AVAILABLE = False
try:
    from rdkit import Chem
    from rdkit.Chem import Descriptors, Crippen
    RDKIT_AVAILABLE = True
    print("✅ RDKit available for enhanced SMILES validation")
except ImportError:
    print("⚠️ RDKit not available - will use basic SMILES validation")

if download_prerequisites:
    try:
        import plotly.graph_objects as go
        import plotly.express as px
        PLOTLY_AVAILABLE = True
        print("✅ Plotly available")
    except ImportError:
        print("⚠️ Plotly not available - will use text-based output")

    try:
        import yaml
        YAML_AVAILABLE = True
        print("✅ YAML available")
    except ImportError:
        print("⚠️ YAML not available - will use custom YAML generation")

    try:
        import torch
        print(f"✅ PyTorch {torch.__version__} available")
    except ImportError:
        print("⚠️ PyTorch not available")

    # Try to install and import Boltz-2 with detailed error logging
    print("🔍 Attempting Boltz-2 installation with detailed logging:")
    try:
        print("Step 1: Installing Boltz-2...")
        result = subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', 'boltz'],
                              capture_output=True, text=True, check=True)
        print("✅ Boltz-2 pip install successful")

        print("Step 2: Trying to import boltz...")
        import boltz
        BOLTZ_AVAILABLE = True
        print("✅ Boltz-2 import successful")
        print(f"   Boltz version: {getattr(boltz, '__version__', 'unknown')}")

    except subprocess.CalledProcessError as install_error:
        print("❌ Boltz-2 pip install failed:")
        print(f"   Return code: {install_error.returncode}")
        print(f"   STDOUT: {install_error.stdout}")
        print(f"   STDERR: {install_error.stderr}")
        BOLTZ_AVAILABLE = False

    except ImportError as import_error:
        print("❌ Boltz-2 import failed:")
        print(f"   Error: {import_error}")
        print("   Pip install succeeded but import failed")
        BOLTZ_AVAILABLE = False

    except Exception as other_error:
        print("❌ Boltz-2 unexpected error:")
        print(f"   Error type: {type(other_error).__name__}")
        print(f"   Error message: {other_error}")
        BOLTZ_AVAILABLE = False

    # If Boltz failed, try alternative installation methods
    if not BOLTZ_AVAILABLE:
        print("🔄 Trying alternative Boltz-2 installation methods:")

        # Method 1: Try installing without dependencies first
        try:
            print("Method 1: Installing boltz without dependencies...")
            result = subprocess.run([sys.executable, '-m', 'pip', 'install', '--no-deps', 'boltz'],
                                  capture_output=True, text=True, check=True)
            print("✅ No-deps install successful, trying import...")
            import boltz
            BOLTZ_AVAILABLE = True
            print("✅ Boltz-2 working with no-deps install")
        except Exception as e:
            print(f"❌ No-deps method failed: {e}")

        # Method 2: Try installing with specific index
        if not BOLTZ_AVAILABLE:
            try:
                print("Method 2: Installing from PyPI with specific flags...")
                result = subprocess.run([sys.executable, '-m', 'pip', 'install', 'boltz', '--force-reinstall'],
                                      capture_output=True, text=True, check=True)
                print("✅ Force reinstall successful, trying import...")
                import boltz
                BOLTZ_AVAILABLE = True
                print("✅ Boltz-2 working with force reinstall")
            except Exception as e:
                print(f"❌ Force reinstall method failed: {e}")
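else:
    print("⚠️ Skipping dependency downloads (download_prerequisites=False)")
    # Still detect optional packages that are already installed
    try:
        import plotly.graph_objects as go
        import plotly.express as px
        PLOTLY_AVAILABLE = True
    except ImportError:
        pass
    YAML_AVAILABLE = 'yaml' in sys.modules  # yaml is imported unconditionally above
    try:
        import boltz
        BOLTZ_AVAILABLE = True
        print("✅ Boltz-2 already available locally")
    except ImportError:
        BOLTZ_AVAILABLE = False
        print("⚠️ Boltz-2 not available locally")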

# Print environment info
print(f"📊 Environment Info:")
print(f"  - Python: {sys.version.split()[0]}")
if NUMPY_AVAILABLE:
    print(f"  - NumPy: {np.__version__}")
if PANDAS_AVAILABLE:
    print(f"  - Pandas: {pd.__version__}")
print(f"  - Requests: {requests.__version__}")
print(f"  - Local mode: {running_locally}")
print(f"  - Colab available: {COLAB_AVAILABLE}")

# Constants
ESTIMATED_TIME_PER_JOB = 300  # 5 minutes per job in seconds
RESULTS_DIR = "drug_screening_results"

# Additional chemical data
COMMON_CCD_CODES = {
    'HEM': 'Heme (protoporphyrin IX)',
    'ATP': 'Adenosine triphosphate',
    'ADP': 'Adenosine diphosphate',
    'AMP': 'Adenosine monophosphate',
    'NAD': 'Nicotinamide adenine dinucleotide',
    'NADP': 'Nicotinamide adenine dinucleotide phosphate',
    'FAD': 'Flavin adenine dinucleotide',
    'FMN': 'Flavin mononucleotide',
    'COA': 'Coenzyme A',
    'PLP': 'Pyridoxal 5\'-phosphate',
    'THF': 'Tetrahydrofolate',
    'B12': 'Vitamin B12 (cobalamin)'
}

# Standard amino acids
STANDARD_AA = set('ACDEFGHIKLMNPQRSTVWY')

# Amino acid properties for mutation analysis
AA_PROPERTIES = {
    'A': {'name': 'Alanine', 'type': 'Hydrophobic', 'charge': 'Neutral', 'size': 'Small'},
    'R': {'name': 'Arginine', 'type': 'Polar', 'charge': 'Positive', 'size': 'Large'},
    'N': {'name': 'Asparagine', 'type': 'Polar', 'charge': 'Neutral', 'size': 'Medium'},
    'D': {'name': 'Aspartic acid', 'type': 'Polar', 'charge': 'Negative', 'size': 'Medium'},
    'C': {'name': 'Cysteine', 'type': 'Polar', 'charge': 'Neutral', 'size': 'Small'},
    'Q': {'name': 'Glutamine', 'type': 'Polar', 'charge': 'Neutral', 'size': 'Medium'},
    'E': {'name': 'Glutamic acid', 'type': 'Polar', 'charge': 'Negative', 'size': 'Medium'},
    'G': {'name': 'Glycine', 'type': 'Special', 'charge': 'Neutral', 'size': 'Small'},
    'H': {'name': 'Histidine', 'type': 'Polar', 'charge': 'Positive', 'size': 'Medium'},
    'I': {'name': 'Isoleucine', 'type': 'Hydrophobic', 'charge': 'Neutral', 'size': 'Medium'},
    'L': {'name': 'Leucine', 'type': 'Hydrophobic', 'charge': 'Neutral', 'size': 'Medium'},
    'K': {'name': 'Lysine', 'type': 'Polar', 'charge': 'Positive', 'size': 'Medium'},
    'M': {'name': 'Methionine', 'type': 'Hydrophobic', 'charge': 'Neutral', 'size': 'Medium'},
    'F': {'name': 'Phenylalanine', 'type': 'Hydrophobic', 'charge': 'Neutral', 'size': 'Large'},
    'P': {'name': 'Proline', 'type': 'Special', 'charge': 'Neutral', 'size': 'Small'},
    'S': {'name': 'Serine', 'type': 'Polar', 'charge': 'Neutral', 'size': 'Small'},
    'T': {'name': 'Threonine', 'type': 'Polar', 'charge': 'Neutral', 'size': 'Small'},
    'W': {'name': 'Tryptophan', 'type': 'Hydrophobic', 'charge': 'Neutral', 'size': 'Large'},
    'Y': {'name': 'Tyrosine', 'type': 'Polar', 'charge': 'Neutral', 'size': 'Large'},
    'V': {'name': 'Valine', 'type': 'Hydrophobic', 'charge': 'Neutral', 'size': 'Small'}
}

# Helper functions for mutation analysis
def classify_mutation_type(wt_aa: str, mut_aa: str) -> str:
    """Classify mutation based on amino acid properties."""
    if wt_aa not in AA_PROPERTIES or mut_aa not in AA_PROPERTIES:
        return "Unknown"

    wt_props = AA_PROPERTIES[wt_aa]
    mut_props = AA_PROPERTIES[mut_aa]

    # Check charge change
    if wt_props['charge'] != mut_props['charge']:
        return "Charge change"

    # Check chemical property change
    if wt_props['type'] != mut_props['type']:
        return "Chemical change"

    # Check size change
    if wt_props['size'] != mut_props['size']:
        return "Size change"

    return "Conservative"

def assess_mutation_severity(wt_aa: str, mut_aa: str) -> str:
    """Assess the predicted severity of a mutation."""
    if wt_aa not in AA_PROPERTIES or mut_aa not in AA_PROPERTIES:
        return "Unknown"

    wt_props = AA_PROPERTIES[wt_aa]
    mut_props = AA_PROPERTIES[mut_aa]

    # Charge changes are typically severe
    if wt_props['charge'] != mut_props['charge']:
        return "High"

    # Proline or glycine involved
    if wt_aa == 'P' or mut_aa == 'P' or wt_aa == 'G' or mut_aa == 'G':
        return "High"

    # Large size changes
    if (wt_props['size'] == 'Large' and mut_props['size'] == 'Small') or \
       (wt_props['size'] == 'Small' and mut_props['size'] == 'Large'):
        return "Moderate"

    # Chemical property changes
    if wt_props['type'] != mut_props['type']:
        return "Moderate"

    return "Low"

def predict_structural_impact(wt_aa: str, mut_aa: str) -> str:
    """Predict structural impact of mutation."""
    if wt_aa not in AA_PROPERTIES or mut_aa not in AA_PROPERTIES:
        return "Unknown"

    wt_props = AA_PROPERTIES[wt_aa]
    mut_props = AA_PROPERTIES[mut_aa]

    # Proline changes
    if wt_aa == 'P' and mut_aa != 'P':
        return "Loss of rigidity"
    elif wt_aa != 'P' and mut_aa == 'P':
        return "Increased rigidity"

    # Glycine changes
    if wt_aa == 'G' and mut_aa != 'G':
        return "Loss of flexibility"
    elif wt_aa != 'G' and mut_aa == 'G':
        return "Increased flexibility"

    # Cysteine changes (disulfide bonds)
    if wt_aa == 'C' and mut_aa != 'C':
        return "Disulfide loss"
    elif wt_aa != 'C' and mut_aa == 'C':
        return "Potential disulfide"

    # Size changes
    if wt_props['size'] != mut_props['size']:
        if wt_props['size'] == 'Large' and mut_props['size'] == 'Small':
            return "Cavity formation"
        elif wt_props['size'] == 'Small' and mut_props['size'] == 'Large':
            return "Steric clash"
        else:
            return "Size change"

    return "Minimal"

def generate_example_mutations_detailed(sequence: str, chain_start: int, max_mutations: int) -> List[Dict]:
    """Generate example mutations with detailed analysis when database lookup fails."""
    mutations = []

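    # Sample positions evenly across the sequence, skipping 10 residues at each end.
    # Sequences of 20 residues or fewer yield no positions.
    if max_mutations <= 0:
        return mutations  # also avoids division by zero below
    positions = list(range(10, len(sequence) - 10, max(1, len(sequence) // max_mutations)))[:max_mutations]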

    common_mutations = ['A', 'V', 'L', 'I', 'F', 'Y', 'W', 'S', 'T', 'N', 'Q', 'R', 'K', 'D', 'E', 'C', 'G', 'P', 'H', 'M']

    for i, pos in enumerate(positions):
        if pos < len(sequence):
            wt_aa = sequence[pos]
            mut_aa = common_mutations[i % len(common_mutations)]

            if wt_aa != mut_aa:  # Only include if different
                input_pos = chain_start + pos
                mutation_entry = {
                    'mutation_code': f"{wt_aa}{input_pos}{mut_aa}",
                    'canonical_code': f"{wt_aa}{input_pos}{mut_aa}",
                    'input_position': input_pos,
                    'canonical_position': input_pos,
                    'wild_type_aa': wt_aa,
                    'mutant_aa': mut_aa,
                    'description': f"Example mutation {wt_aa} to {mut_aa} at position {input_pos}",
                    'evidence_count': 0,
                    'change_type': classify_mutation_type(wt_aa, mut_aa),
                    'severity': assess_mutation_severity(wt_aa, mut_aa),
                    'structural_impact': predict_structural_impact(wt_aa, mut_aa),
                    'matches_input': True
                }
                mutations.append(mutation_entry)

    return mutations

# Protein parsing functions (from utils.py)
def parse_protein_chains(protein_sequence):
    """
    Parse protein sequence into individual chains.
    Converts input to uppercase before splitting and storing.
    Based on utils.py implementation.

    Args:
        protein_sequence (str): Protein sequence(s) separated by colons

    Returns:
        list: List of dictionaries with chain information
    """
    protein_sequence = protein_sequence.upper()
    protein_chains = []
    if protein_sequence.strip():
        # Split by colon to get individual chains
        chains = protein_sequence.split(':')

        # Validate number of chains (max 23: A-W, X reserved for ligand)
        if len(chains) > 23:
            raise ValueError(f"Too many protein chains ({len(chains)}). Maximum allowed is 23 chains (A-W). Chain X is reserved for ligands.")

        for i, chain in enumerate(chains):
            chain_id = chr(65 + i)  # A, B, C, ..., W (65-87)
            protein_chains.append({
                "protein": {
                    "id": chain_id,
                    "sequence": chain.strip()
                }
            })
    return protein_chains

# Mutation Discovery Functions
def identify_protein_from_sequence(protein_seq: str) -> Dict:
    """
    Identify protein using multiple search strategies via UniProt API.
    Returns basic protein information for mutation lookup.
    """
    try:
        # Clean sequence
        clean_seq = re.sub(r'[^A-Z]', '', protein_seq.upper())
        if len(clean_seq) < 20:
            return {"error": "Sequence too short for reliable identification"}

        seq_length = len(clean_seq)

        # Strategy 1: Try exact sequence match first
        url = "https://rest.uniprot.org/uniprotkb/search"

        # Search for exact sequence first
        exact_params = {
            'query': f'sequence:"{clean_seq}"',
            'format': 'json',
            'size': 1,
            'fields': 'accession,id,protein_name,organism_name,length,sequence'
        }

        response = requests.get(url, params=exact_params, timeout=15)
        if response.status_code == 200:
            data = response.json()
            results = data.get('results', [])
            if results:
                result = results[0]
                return {
                    "accession": result.get('primaryAccession', ''),
                    "protein_name": result.get('proteinDescription', {}).get('recommendedName', {}).get('fullName', {}).get('value', 'Unknown'),
                    "organism": result.get('organism', {}).get('scientificName', 'Unknown'),
                    "similarity": 1.0,
                    "length": result.get('sequence', {}).get('length', 0)
                }

        # Strategy 2: Search by protein family keywords for common proteins
        common_protein_patterns = [
            ('hemoglobin', ['hemoglobin', 'alpha', 'beta']),
            ('myoglobin', ['myoglobin']),
            ('insulin', ['insulin']),
            ('lysozyme', ['lysozyme']),
            ('trypsin', ['trypsin']),
            ('chymotrypsin', ['chymotrypsin']),
            ('pepsin', ['pepsin']),
            ('albumin', ['albumin']),
            ('immunoglobulin', ['immunoglobulin', 'antibody']),
            ('collagen', ['collagen']),
            ('elastin', ['elastin']),
            ('keratin', ['keratin']),
            ('actin', ['actin']),
            ('myosin', ['myosin']),
            ('tubulin', ['tubulin']),
            ('histone', ['histone'])
        ]

        # Try to identify by known protein signatures
        for protein_type, keywords in common_protein_patterns:
            keyword_params = {
                'query': f'reviewed:true AND ({" OR ".join(keywords)}) AND length:[{seq_length-20} TO {seq_length+20}]',
                'format': 'json',
                'size': 20,
                'fields': 'accession,id,protein_name,organism_name,length,sequence'
            }

            response = requests.get(url, params=keyword_params, timeout=15)
            if response.status_code == 200:
                data = response.json()
                results = data.get('results', [])

                # Find best sequence match from keyword results
                best_match = None
                best_similarity = 0

                for result in results:
                    seq = result.get('sequence', {}).get('value', '')
                    if seq and len(seq) > 0:
                        # Ungapped, position-by-position identity over the shared prefix
                        shorter_len = min(len(clean_seq), len(seq))
                        longer_len = max(len(clean_seq), len(seq))

                        # zip stops at the shorter sequence, so no bounds checks are needed
                        matches = sum(1 for a, b in zip(clean_seq, seq) if a == b)

                        # Penalize length differences
                        length_penalty = abs(len(clean_seq) - len(seq)) / longer_len
                        similarity = (matches / shorter_len) * (1 - length_penalty * 0.5)

                        if similarity > best_similarity and similarity > 0.6:
                            best_similarity = similarity
                            best_match = result

                if best_match:
                    return {
                        "accession": best_match.get('primaryAccession', ''),
                        "protein_name": best_match.get('proteinDescription', {}).get('recommendedName', {}).get('fullName', {}).get('value', 'Unknown'),
                        "organism": best_match.get('organism', {}).get('scientificName', 'Unknown'),
                        "similarity": best_similarity,
                        "length": best_match.get('sequence', {}).get('length', 0)
                    }

        # Strategy 3: Broad search with sequence fragments
        # Use first and last 20 amino acids as signature
        if len(clean_seq) >= 40:
            n_terminus = clean_seq[:20]
            c_terminus = clean_seq[-20:]

            fragment_params = {
                'query': f'reviewed:true AND length:[{seq_length-50} TO {seq_length+50}]',
                'format': 'json',
                'size': 100,
                'fields': 'accession,id,protein_name,organism_name,length,sequence'
            }

            response = requests.get(url, params=fragment_params, timeout=20)
            if response.status_code == 200:
                data = response.json()
                results = data.get('results', [])

                best_match = None
                best_similarity = 0

                for result in results:
                    seq = result.get('sequence', {}).get('value', '')
                    if seq and len(seq) >= 40:
                        # Check N and C terminal similarity
                        n_matches = sum(1 for a, b in zip(n_terminus, seq[:20]) if a == b)
                        c_matches = sum(1 for a, b in zip(c_terminus, seq[-20:]) if a == b)

                        # Calculate overall similarity with terminal emphasis
                        terminal_similarity = (n_matches + c_matches) / 40.0

                        # Also check internal regions
                        internal_matches = 0
                        internal_total = 0
                        step = max(1, len(clean_seq) // 20)  # Sample every step residues

                        for i in range(20, min(len(clean_seq), len(seq)) - 20, step):
                            if clean_seq[i] == seq[i]:
                                internal_matches += 1
                            internal_total += 1

                        internal_similarity = internal_matches / max(1, internal_total)

                        # Combined similarity score
                        combined_similarity = terminal_similarity * 0.6 + internal_similarity * 0.4

                        if combined_similarity > best_similarity and combined_similarity > 0.3:
                            best_similarity = combined_similarity
                            best_match = result

                if best_match:
                    return {
                        "accession": best_match.get('primaryAccession', ''),
                        "protein_name": best_match.get('proteinDescription', {}).get('recommendedName', {}).get('fullName', {}).get('value', 'Unknown'),
                        "organism": best_match.get('organism', {}).get('scientificName', 'Unknown'),
                        "similarity": best_similarity,
                        "length": best_match.get('sequence', {}).get('length', 0)
                    }

        return {"error": "No protein matches found using multiple search strategies"}

    except Exception as e:
        return {"error": f"Error identifying protein: {str(e)}"}

def query_mutations_for_protein_detailed(protein_info: Dict, input_sequence: str, chain_start: int = 1, max_mutations: int = 20) -> List[Dict]:
    """
    Query known mutations for a protein from public databases with detailed information.
    Returns a list of mutation dictionaries with comprehensive details.
    """
    mutations = []

    if "error" in protein_info:
        return []

    try:
        accession = protein_info.get("accession", "")
        if not accession:
            return []

        # Clean input sequence for position mapping
        clean_input_seq = re.sub(r'[^A-Z]', '', input_sequence.upper())

        # Query UniProt for variants of this protein
        url = "https://rest.uniprot.org/uniprotkb/search"
        params = {
            'query': f'accession:{accession}',
            'format': 'json',
            'size': 1,
            'fields': 'accession,sequence,ft_variant'
        }

        response = requests.get(url, params=params, timeout=15)
        if response.status_code == 200:
            data = response.json()
            results = data.get('results', [])

            if results:
                result = results[0]
                canonical_seq = result.get('sequence', {}).get('value', '')
                features = result.get('features', [])

                # Create sequence alignment mapping
                seq_mapping = create_sequence_mapping(clean_input_seq, canonical_seq, chain_start)

                for feature in features:
                    if feature.get('type') == 'Natural variant':
                        # Extract mutation details
                        location = feature.get('location', {})
                        canonical_pos = location.get('start', {}).get('value')
                        description = feature.get('description', '')
                        evidence = feature.get('evidences', [])

                        if canonical_pos and description:
                            # Map canonical position to input sequence position
                            input_pos = map_canonical_to_input_position(canonical_pos, seq_mapping, chain_start)

                            # Extract wild-type and mutant amino acids
                            wt_aa, mut_aa = extract_mutation_residues(description, canonical_seq, canonical_pos)

                            if wt_aa and mut_aa and input_pos:
                                # Verify the wild-type residue matches the input sequence.
                                # input_pos is numbered from chain_start, so convert it to a 0-based index.
                                seq_idx = input_pos - chain_start
                                if 0 <= seq_idx < len(clean_input_seq):
                                    actual_wt = clean_input_seq[seq_idx]

                                    mutation_entry = {
                                        'mutation_code': f"{wt_aa}{input_pos}{mut_aa}",
                                        'canonical_code': f"{wt_aa}{canonical_pos}{mut_aa}",
                                        'input_position': input_pos,
                                        'canonical_position': canonical_pos,
                                        'wild_type_aa': wt_aa,
                                        'mutant_aa': mut_aa,
                                        'description': description,
                                        'evidence_count': len(evidence),
                                        'change_type': classify_mutation_type(wt_aa, mut_aa),
                                        'severity': assess_mutation_severity(wt_aa, mut_aa),
                                        'structural_impact': predict_structural_impact(wt_aa, mut_aa),
                                        'matches_input': actual_wt == wt_aa if actual_wt else False
                                    }
                                    mutations.append(mutation_entry)

        # If no variants found, generate example mutations with detailed info
        if not mutations:
            mutations = generate_example_mutations_detailed(clean_input_seq, chain_start, max_mutations)

    except Exception as e:
        print(f"Error querying detailed mutations: {e}")
        # Fallback to example mutations
        clean_input_seq = re.sub(r'[^A-Z]', '', input_sequence.upper())
        mutations = generate_example_mutations_detailed(clean_input_seq, chain_start, max_mutations)

    # Sort by input position and limit results
    mutations.sort(key=lambda x: x['input_position'])
    return mutations[:max_mutations]

def create_sequence_mapping(input_seq: str, canonical_seq: str, chain_start: int) -> Dict[int, int]:
    """Create mapping between canonical and input sequence positions."""
    mapping = {}

    if not canonical_seq:
        # Simple 1:1 mapping if no canonical sequence
        for i in range(len(input_seq)):
            mapping[i + 1] = chain_start + i
        return mapping

    # Simple alignment - find best offset
    best_offset = 0
    best_matches = 0

    max_shift = min(50, len(canonical_seq) // 4)
    for offset in range(-max_shift, max_shift + 1):  # test shifts symmetrically, both bounds included
        matches = 0
        for i in range(min(len(input_seq), len(canonical_seq) - abs(offset))):
            canonical_idx = i + offset if offset >= 0 else i
            input_idx = i - offset if offset < 0 else i

            if (0 <= canonical_idx < len(canonical_seq) and
                0 <= input_idx < len(input_seq) and
                canonical_seq[canonical_idx] == input_seq[input_idx]):
                matches += 1

        if matches > best_matches:
            best_matches = matches
            best_offset = offset

    # Create mapping with best offset
    for i in range(len(input_seq)):
        canonical_pos = i + best_offset + 1
        input_pos = chain_start + i
        if 1 <= canonical_pos <= len(canonical_seq):
            mapping[canonical_pos] = input_pos

    return mapping

def map_canonical_to_input_position(canonical_pos: int, mapping: Dict[int, int], chain_start: int) -> Optional[int]:
    """Map canonical position to input sequence position."""
    if canonical_pos in mapping:
        return mapping[canonical_pos]

    # Fallback: assume 1:1 mapping if not found
    return canonical_pos - 1 + chain_start if canonical_pos > 0 else None

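# Three-letter to one-letter amino acid codes, used for descriptions such as p.Ala50Val
THREE_TO_ONE = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
}

def extract_mutation_residues(description: str, canonical_seq: str, position: int) -> Tuple[Optional[str], Optional[str]]:
    """Extract wild-type and mutant amino acids from variant description."""
    text = description.upper()

    # Three-letter forms first, so "Ala50Val" is not misread as single letters
    three_letter_patterns = [
        r'\b([A-Z]{3})\d+([A-Z]{3})\b',       # p.Ala50Val
        r'\b([A-Z]{3})\s*->\s*([A-Z]{3})\b',  # Ala -> Val
    ]
    for pattern in three_letter_patterns:
        for match in re.finditer(pattern, text):
            wt_aa = THREE_TO_ONE.get(match.group(1))
            mut_aa = THREE_TO_ONE.get(match.group(2))
            if wt_aa and mut_aa:
                return wt_aa, mut_aa

    # Single-letter forms; word boundaries keep letters inside longer words from matching
    single_letter_patterns = [
        r'\b([A-Z])\s*->\s*([A-Z])\b',  # A -> V
        r'\b([A-Z])\d+([A-Z])\b',       # A50V
    ]
    for pattern in single_letter_patterns:
        match = re.search(pattern, text)
        if match:
            return match.group(1), match.group(2)

    # Fallback: use canonical sequence if position is valid
    if canonical_seq and 1 <= position <= len(canonical_seq):
        wt_aa = canonical_seq[position - 1]
        # Look for a standalone residue letter in the description that's not the wild-type
        for aa in re.findall(r'\b[ACDEFGHIKLMNPQRSTVWY]\b', text):
            if aa != wt_aa:
                return wt_aa, aa

    return None, None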

def format_mutations_for_input(mutations: List[Dict]) -> str:
    """Format mutations as a simple comma-separated list for user input."""
    if not mutations:
        return ""

    mutation_codes = [m['mutation_code'] for m in mutations if 'mutation_code' in m]
    return ", ".join(mutation_codes[:10])  # Limit to first 10 mutations

def format_mutations_table(mutations: List[Dict]) -> str:
    """Format mutations as a readable text table with consistent spacing."""
    if not mutations:
        return "No mutations found."

    # Calculate actual maximum widths for each column
    mutation_width = max(len("Mutation"), max(len(m['mutation_code']) for m in mutations)) + 1
    canonical_width = max(len("Canonical"), max(len(m['canonical_code']) for m in mutations)) + 1
    input_pos_width = max(len("Input Pos"), max(len(str(m['input_position'])) for m in mutations)) + 1
    canonical_pos_width = max(len("Canon Pos"), max(len(str(m['canonical_position'])) for m in mutations)) + 1
    change_width = max(len("Change Type"), max(len(m.get('change_type', '')) for m in mutations)) + 1
    severity_width = max(len("Severity"), max(len(m.get('severity', '')) for m in mutations)) + 1
    impact_width = max(len("Impact"), max(len(m.get('structural_impact', '')) for m in mutations)) + 1

    # Create header
    header = (f"{'Mutation':<{mutation_width}} "
             f"{'Canonical':<{canonical_width}} "
             f"{'Input Pos':<{input_pos_width}} "
             f"{'Canon Pos':<{canonical_pos_width}} "
             f"{'Change Type':<{change_width}} "
             f"{'Severity':<{severity_width}} "
             f"{'Impact':<{impact_width}}")

    # Create separator
    separator = "=" * len(header)

    # Create rows
    rows = []
    for m in mutations:
        row = (f"{m['mutation_code']:<{mutation_width}} "
               f"{m['canonical_code']:<{canonical_width}} "
               f"{m['input_position']:<{input_pos_width}} "
               f"{m['canonical_position']:<{canonical_pos_width}} "
               f"{m.get('change_type', ''):<{change_width}} "
               f"{m.get('severity', ''):<{severity_width}} "
               f"{m.get('structural_impact', ''):<{impact_width}}")
        rows.append(row)

    return "\n".join([header, separator] + rows)

def parse_fasta_sequences(fasta_content: str) -> List[Tuple[str, str]]:
    """Parse FASTA format sequences into a list of (name, sequence) tuples."""
    sequences = []
    current_id = None
    current_seq = []

    for line in fasta_content.strip().split('\n'):
        line = line.strip()
        if line.startswith('>'):
            # Save previous sequence if exists
            if current_id:
                sequences.append((current_id, ''.join(current_seq)))

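            # Start new sequence; fall back to a generated ID when the header is just '>'
            header_fields = line[1:].split()
            current_id = header_fields[0] if header_fields else f"sequence_{len(sequences) + 1}"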
            current_seq = []
        elif line and current_id:
            # Add to current sequence
            current_seq.append(line)

    # Save last sequence
    if current_id:
        sequences.append((current_id, ''.join(current_seq)))

    return sequences

def validate_smiles(smiles: str) -> Tuple[bool, str]:
    """Validate SMILES string and return validation results."""
    if not smiles or not smiles.strip():
        return False, "Empty SMILES string"

    smiles = smiles.strip()

    # Use RDKit if available for enhanced validation
    if RDKIT_AVAILABLE:
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return False, "Invalid SMILES - RDKit cannot parse"

            # Additional checks
            try:
                mw = Descriptors.MolWt(mol)
                logp = Crippen.MolLogP(mol)
                return True, f"Valid (MW: {mw:.1f}, LogP: {logp:.1f})"
            except Exception:
                # Descriptor calculation failed, but the SMILES itself parsed
                return True, "Valid SMILES"
        except Exception as e:
            return False, f"RDKit validation error: {str(e)}"

    # Basic validation without RDKit
    try:
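        # Check for basic SMILES characters: element letters (including lowercase
        # aromatic atoms such as c, n, o), digits, and bond/branch/charge symbols
        valid_symbols = set('[]()=#$:+-.@\\/%*')
        if not all((c.isascii() and c.isalnum()) or c in valid_symbols for c in smiles):
            return False, "Contains invalid characters for SMILES"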

        # Check balanced parentheses and brackets
        paren_count = 0
        bracket_count = 0
        for char in smiles:
            if char == '(':
                paren_count += 1
            elif char == ')':
                paren_count -= 1
            elif char == '[':
                bracket_count += 1
            elif char == ']':
                bracket_count -= 1

            if paren_count < 0 or bracket_count < 0:
                return False, "Unbalanced parentheses or brackets"

        if paren_count != 0 or bracket_count != 0:
            return False, "Unmatched parentheses or brackets"

        return True, "Valid SMILES (basic validation)"

    except Exception as e:
        return False, f"Basic validation error: {str(e)}"

def validate_ccd_code(ccd_code: str) -> bool:
    """Validate PDB Chemical Component Dictionary (CCD) code."""
    if not ccd_code or not ccd_code.strip():
        return False

    ccd_code = ccd_code.strip().upper()

    # Basic format validation
    if len(ccd_code) < 1 or len(ccd_code) > 5:
        return False

    # Check if it's a known common code
    if ccd_code in COMMON_CCD_CODES:
        return True

    # Check for valid characters (letters and numbers)
    if not ccd_code.replace('_', '').isalnum():
        return False

    return True

def parse_mutations(mutation_string: str) -> List[List[Tuple[str, int, str]]]:
    """Parse mutation string into individual mutations."""
    if not mutation_string or not mutation_string.strip():
        return []

    # Split on any mix of commas, semicolons, and whitespace; uppercase so 'a50v' is accepted
    mutations = re.split(r'[,;\s]+', mutation_string.strip().upper())

    # Filter out empty strings and validate format
    mutation_lists = []
    for mut in mutations:
        mut = mut.strip()
        if mut and re.match(r'^[A-Z]\d+[A-Z](-[A-Z]\d+[A-Z])*$', mut):
            # Parse individual or compound mutations
            parts = mut.split('-')
            mutation_list = []
            for part in parts:
                match = re.match(r'^([A-Z])(\d+)([A-Z])$', part)
                if match:
                    wt_aa, pos_str, mut_aa = match.groups()
                    mutation_list.append(('A', int(pos_str), mut_aa))  # Default to chain A
            if mutation_list:
                mutation_lists.append(mutation_list)

    return mutation_lists

def generate_mutant_name(mutations: List[Tuple[str, int, str]]) -> str:
    """Generate a descriptive name for the mutant protein."""
    if not mutations:
        return "WT"

    # Convert mutations to standard format
    mut_strs = []
    for chain_id, position, mut_aa in mutations:
        # The tuple carries no wild-type residue, so names use position + mutant residue (e.g. 50V)
        mut_strs.append(f"{position}{mut_aa}")

    if len(mut_strs) == 1:
        return mut_strs[0]
    elif len(mut_strs) <= 3:
        return "-".join(mut_strs)
    else:
        return "-".join(mut_strs[:2]) + f"-plus{len(mut_strs)-2}more"

def apply_mutations_to_sequence(wt_sequence: str, mutations: List[Tuple[str, int, str]], chain_starts: Dict[str, int], chains_dict: Optional[Dict[str, str]] = None) -> str:
    """Apply mutations to a protein sequence."""
    if not mutations or not wt_sequence:
        return wt_sequence

    # Handle multi-chain vs single-chain
    if ':' in wt_sequence and chains_dict:
        # Multi-chain sequence
        chain_sequences = {chain_id: list(seq) for chain_id, seq in chains_dict.items()}

        for chain_id, position, mut_aa in mutations:
            if chain_id in chain_sequences and chain_id in chain_starts:
                chain_start = chain_starts[chain_id]
                seq_index = position - chain_start

                if 0 <= seq_index < len(chain_sequences[chain_id]):
                    chain_sequences[chain_id][seq_index] = mut_aa

        # Reconstruct sequence with colons
        return ':'.join(''.join(chain_sequences[chain_id]) for chain_id in sorted(chain_sequences.keys()))
    else:
        # Single chain sequence
        seq_list = list(wt_sequence)
        chain_start = chain_starts.get('A', 1)

        for chain_id, position, mut_aa in mutations:
            seq_index = position - chain_start
            if 0 <= seq_index < len(seq_list):
                seq_list[seq_index] = mut_aa

        return ''.join(seq_list)

def create_screening_yaml(workspace_name: str, design_name: str, protein_sequence: str, ligand_smiles: str,
                         project_name: str, **kwargs) -> str:
    """
    Create YAML configuration for drug screening using the same logic as utils.py.
    Based on create_boltz_yaml from utils.py and create_screening_boltz_yaml from drug_screening.py.
    """

    # Extract additional parameters
    cofactor_info = kwargs.get('cofactor_info', [])
    binding_pocket_constraints = kwargs.get('binding_pocket_constraints')
    ptm_modifications = kwargs.get('ptm_modifications')
    template_cif_path = kwargs.get('template_cif_path')
    structure_only = kwargs.get('structure_only', False)

    # Create project directory (similar to drug_screening.py)
    project_dir = os.path.join(RESULTS_DIR, project_name)
    os.makedirs(project_dir, exist_ok=True)

    # Create filename
    filename = f"{workspace_name}_{design_name}.yaml"
    filepath = os.path.join(project_dir, filename)

    # Parse protein sequence into chains (from utils.py)
    try:
        protein_chains = parse_protein_chains(protein_sequence)
    except ValueError as e:
        raise ValueError(f"Invalid protein sequence format: {str(e)}") from e

    # Add PTM modifications to protein chains if provided (from utils.py)
    if ptm_modifications and ptm_modifications.get('modifications'):
        modifications = ptm_modifications.get('modifications', [])
        for mod in modifications:
            chain_id = mod.get('chain_id', 'A')
            position = mod.get('position')
            ccd = mod.get('ccd')

            if chain_id and position and ccd:
                # Find the corresponding protein chain and add modification
                for chain in protein_chains:
                    if chain.get('protein', {}).get('id') == chain_id:
                        if 'modifications' not in chain['protein']:
                            chain['protein']['modifications'] = []
                        chain['protein']['modifications'].append({
                            'position': position,
                            'ccd': ccd
                        })
                        break

    # Create YAML content with protein chains (from utils.py)
    yaml_content = {"sequences": protein_chains}

    # Add main ligand if not structure-only mode
    if not structure_only and ligand_smiles:
        yaml_content["sequences"].append({
            "ligand": {
                "id": "X",
                "smiles": ligand_smiles
            }
        })

    # Add co-factors if provided (from utils.py)
    if cofactor_info and isinstance(cofactor_info, list) and len(cofactor_info) > 0:
        for i, cofactor in enumerate(cofactor_info):
            if i >= 4:  # Limit to 4 cofactors
                break
            if cofactor and (cofactor.get('smiles') or cofactor.get('ccd')):
                # Use chain IDs T, U, V, W
                chain_id = chr(ord('T') + i)
                cofactor_entry = {
                    "ligand": {
                        "id": chain_id
                    }
                }

                # Add either SMILES or CCD code
                if cofactor.get('smiles'):
                    cofactor_entry["ligand"]["smiles"] = cofactor['smiles']
                elif cofactor.get('ccd'):
                    cofactor_entry["ligand"]["ccd"] = cofactor['ccd']

                # Add co-factor to sequences
                yaml_content["sequences"].append(cofactor_entry)
    # Backward compatibility: handle single cofactor dict (from utils.py)
    elif cofactor_info and isinstance(cofactor_info, dict) and (cofactor_info.get('smiles') or cofactor_info.get('ccd')):
        cofactor_entry = {
            "ligand": {
                "id": "T"  # Use T for backward compatibility
            }
        }

        # Add either SMILES or CCD code
        if cofactor_info.get('smiles'):
            cofactor_entry["ligand"]["smiles"] = cofactor_info['smiles']
        elif cofactor_info.get('ccd'):
            cofactor_entry["ligand"]["ccd"] = cofactor_info['ccd']

        # Add co-factor to sequences
        yaml_content["sequences"].append(cofactor_entry)

    # Add templates section if template_cif_path is provided (from utils.py)
    if template_cif_path:
        yaml_content["templates"] = [{"cif": os.path.abspath(template_cif_path)}]

    # Add constraints if provided and valid (from utils.py)
    if binding_pocket_constraints and binding_pocket_constraints.get('contacts'):
        contacts = []
        for c in binding_pocket_constraints.get('contacts', []):
            if len(c) >= 2:
                # Try to cast residue index to int if possible
                try:
                    res_idx = int(c[1])
                except (ValueError, TypeError):
                    res_idx = c[1]
                contacts.append([c[0], res_idx])
        pocket_constraint = {
            "pocket": {
                "binder": binding_pocket_constraints.get('binder', 'X'),
                "contacts": contacts,
                "max_distance": float(binding_pocket_constraints.get('max_distance', 5.0))
            }
        }
        # Constraints are written by hand below, so the rest of the document is dumped unchanged
        constraints = [pocket_constraint]

        # Write YAML with custom constraints formatting (from utils.py)
        with open(filepath, 'w') as f:
            yaml.dump(yaml_content, f, default_flow_style=False)
            f.write("constraints:\n")
            for constraint in constraints:
                f.write("  - pocket:\n")
                f.write(f"      binder: {constraint['pocket']['binder']}\n")
                contacts_str = yaml.dump(constraint['pocket']['contacts'], default_flow_style=True).strip()
                f.write(f"      contacts: {contacts_str}\n")
                f.write(f"      max_distance: {constraint['pocket']['max_distance']}\n")
            # Append the properties block for affinity prediction only when ligand X was added (from utils.py)
            if not structure_only and ligand_smiles:
                f.write("properties:\n")
                f.write("  - affinity:\n")
                f.write("      binder: X\n")
    else:
        # Write standard YAML without constraints (from utils.py)
        with open(filepath, 'w') as f:
            yaml.dump(yaml_content, f, default_flow_style=False)
            # Append the properties block for affinity prediction only when ligand X was added (from utils.py)
            if not structure_only and ligand_smiles:
                f.write("properties:\n")
                f.write("  - affinity:\n")
                f.write("      binder: X\n")

    return filepath

# Local file handling functions
def read_local_file(filename: str) -> Optional[str]:
    """Read a local file if it exists"""
    if os.path.exists(filename):
        try:
            with open(filename, 'r') as f:
                return f.read()
        except Exception as e:
            print(f"⚠️ Error reading {filename}: {e}")
    return None

def parse_fasta_from_local_file(filename: str) -> List[Tuple[str, str]]:
    """Parse FASTA format from a local file"""
    content = read_local_file(filename)
    if not content:
        return []

    sequences = []
    current_name = None
    current_seq = []

    for line in content.split('\n'):
        line = line.strip()
        if line.startswith('>'):
            if current_name and current_seq:
                sequences.append((current_name, ''.join(current_seq)))
            current_name = line[1:]  # Remove '>'
            current_seq = []
        elif line:
            current_seq.append(line)

    # Add last sequence
    if current_name and current_seq:
        sequences.append((current_name, ''.join(current_seq)))

    return sequences

print("✅ All functions ready! YAML creation updated to match utils.py reference implementation:")
print("🔧 Features implemented:")
print("   - Proper protein chain parsing (A-W, X reserved for ligand)")
print("   - PTM modifications support")
print("   - Multi-cofactor support (T, U, V, W chain IDs)")
print("   - Binding pocket constraints with custom YAML formatting")
print("   - Template CIF file support")
print("   - Properties block for affinity prediction")
print("   - Structure-only mode support")
if running_locally:
    print("📁 Local files expected:")
    print("   - protein.fasta: Protein sequences in FASTA format")
    print("   - drug.fasta: SMILES strings in FASTA format")
    print("   - template.cif: Template structure file")
else:
    print("☁️ Colab mode: File uploads available")

In [None]:
#@title 2️⃣ 📱 Basic Input Configuration { display-mode: "form" }

#@markdown ## 📦 Project Settings
#@markdown Configure your screening project with a unique name for organizing results and outputs.

#@markdown **Project name**: Unique identifier for your screening project (used in output files and results organization)
project_name = "my_drug_screening"  #@param {type:"string"}

#@markdown ---

#@markdown ## 🧬 Protein Input
#@markdown Define how to specify your protein sequences for screening. Choose between different input modes and mutation discovery options:
#@markdown
#@markdown 🔴 **Protein-Drug Screening - Manual Protein Entry** = paste single wild-type sequence directly (for screening one protein against multiple ligands)
#@markdown
#@markdown 🔴 **Protein-Drug Screening - Upload FASTA File** = upload a file with multiple protein sequences (for screening many proteins against multiple ligands; in local mode, the script looks for **protein.fasta** in the working directory)
#@markdown
#@markdown 🔵 **Mutants-Drug Screening** = start with wild-type sequence and generate specific mutant variants using database discovery or manual entry
input_method = "Protein-Drug Screening - Manual Protein Entry"  #@param ["Protein-Drug Screening - Manual Protein Entry", "Protein-Drug Screening - Upload FASTA File", "Mutants-Drug Screening"]

#@markdown **Structure-only mode**: Predict protein structure without ligands. Enable for protein folding studies only.
structure_only = False  #@param {type:"boolean"}

#@markdown ---
#@markdown ### 🔴 Protein-Drug Screening - Manual Protein Entry (if selected as input method)
#@markdown Paste a single protein sequence directly. Use this for one wild-type protein or when you want to enter the sequence manually.

#@markdown Protein sequence: Single letter amino acid codes (A-Z). For multi-chain: use colon-separated format like 'SEQUENCE1:SEQUENCE2:SEQUENCE3'.
protein_sequence = "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"  #@param {type:"string"}

#@markdown **Protein name**: Descriptive name for your protein (used in output files and results).
protein_name = "Hemoglobin_Alpha"  #@param {type:"string"}

#@markdown ---
#@markdown ### 🔵 Mutants-Drug Screening (if selected as input method)
#@markdown Start with a wild-type sequence and automatically generate specific mutant variants. Use this to study the effects of specific amino acid changes.

#@markdown **Wild-type sequence**: Original protein sequence for generating mutants. For multi-chain proteins, use colon (:) to separate chains (e.g., CHAIN1:CHAIN2:CHAIN3).
wt_protein_sequence = "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"  #@param {type:"string"}

#@markdown 🔍 **Discover mutations from databases (optional)**: If you don't know what mutations to test, enable this to search public databases for known mutations of your protein sequence. After running the cell, copy the suggested mutations to the mutations input field below.
query_mutations_from_database = True  #@param {type:"boolean"}

#@markdown 🔍 **Discover mutations from databases - Residue range filter (optional)**: Filter mutations by residue ranges per chain. Format: (start-end),(start-end),... for each chain. Example: (1-50),(100-200) finds mutations in residues 1-50 of chain 1 and 100-200 of chain 2. Leave empty to include all residues.
residue_range_filter = ""  #@param {type:"string"}

#@markdown **Chain start numbers**: Comma-separated starting residue numbers for each chain (default: 1 for all chains). Format: '1,1,1' for 3 chains all starting at 1, or '1,150,300' for chains starting at different positions.
chain_start_numbers = "1"  #@param {type:"string"}

#@markdown **Mutations**: Comma-separated list of mutations. Single mutations: A50V. Double mutations: A50V-G60P (format: A50V means Alanine at position 50 changed to Valine). 💡 If you don't know what mutations to try, check the box above, run this cell to get suggested mutations, then paste them here.
mutations_input = "A50V, L60P, V70I"  #@param {type:"string"}

#@markdown ---

#@markdown ## 🧪 Ligand SMILES Input
#@markdown Specify small molecules for binding analysis (skip if structure-only mode). Choose between:
#@markdown
#@markdown **Text Input** = enter SMILES strings separated by semicolons (;)
#@markdown
#@markdown **Upload FASTA File** = upload FASTA format file containing multiple named SMILES entries for easier multi-line input (local mode: looks for **drug.fasta** in working directory).

ligand_input_method = "Text Input"  #@param ["Text Input", "Upload FASTA File"]

#@markdown **SMILES strings**: Separate multiple SMILES with semicolons (;). Format options: 'SMILES1;SMILES2;SMILES3' or 'Name1,SMILES1;Name2,SMILES2'. Examples: 'CCO;CCN;CC(C)O' or 'Ethanol,CCO;Ethylamine,CCN;Isopropanol,CC(C)O'.
ligand_smiles_input = "CCO;"  #@param {type:"string"}

#@markdown ---
#@markdown Click RUN ▶️ to validate and upload your protein/ligand inputs. Then proceed to the next cell for additional parameters.

# Additional helper functions for improved mutation analysis
def parse_chain_start_numbers(chain_start_input: str, num_chains: int) -> Dict[str, int]:
    """Parse chain start numbers input into a dictionary mapping chain IDs to start numbers."""
    chain_starts = {}

    if not chain_start_input.strip():
        # Default to 1 for all chains
        for i in range(num_chains):
            if i < 23:  # A-W
                chain_id = chr(65 + i)
            else:  # Y-Z
                chain_id = chr(89 + (i - 23))
            chain_starts[chain_id] = 1
        return chain_starts

    # Parse comma-separated values
    start_numbers = [s.strip() for s in chain_start_input.split(',')]

    for i in range(num_chains):
        if i < 23:  # A-W
            chain_id = chr(65 + i)
        else:  # Y-Z
            chain_id = chr(89 + (i - 23))

        # Use provided start number or default to 1
        if i < len(start_numbers):
            try:
                chain_starts[chain_id] = int(start_numbers[i])
            except ValueError:
                chain_starts[chain_id] = 1
        else:
            chain_starts[chain_id] = 1

    return chain_starts

def parse_residue_range_filter(range_filter_input: str, num_chains: int) -> Dict[str, List[Tuple[int, int]]]:
    """Parse residue range filter input into chain-specific ranges."""
    if not range_filter_input.strip():
        return {}

    ranges_per_chain = {}

    # Extract ranges in format (start-end),(start-end),...
    range_pattern = r'\((\d+)-(\d+)\)'
    matches = re.findall(range_pattern, range_filter_input)

    for i, (start_str, end_str) in enumerate(matches):
        if i < num_chains:
            if i < 23:  # A-W
                chain_id = chr(65 + i)
            else:  # Y-Z
                chain_id = chr(89 + (i - 23))

            try:
                start_pos = int(start_str)
                end_pos = int(end_str)
                if start_pos <= end_pos:
                    if chain_id not in ranges_per_chain:
                        ranges_per_chain[chain_id] = []
                    ranges_per_chain[chain_id].append((start_pos, end_pos))
            except ValueError:
                continue

    return ranges_per_chain

def filter_mutations_by_range(mutations: List[Dict], range_filter: Dict[str, List[Tuple[int, int]]], chain_starts: Dict[str, int]) -> List[Dict]:
    """Filter mutations based on residue range criteria."""
    if not range_filter:
        return mutations

    filtered_mutations = []

    for mutation in mutations:
        input_pos = mutation.get('input_position', 0)

        # Determine which chain this position belongs to
        target_chain = None
        for chain_id, chain_start in chain_starts.items():
            if chain_id in range_filter:
                # Check if position falls within any range for this chain
                for start_range, end_range in range_filter[chain_id]:
                    adjusted_start = start_range + chain_start - 1
                    adjusted_end = end_range + chain_start - 1
                    if adjusted_start <= input_pos <= adjusted_end:
                        target_chain = chain_id
                        break
                if target_chain:
                    break

        if target_chain:
            filtered_mutations.append(mutation)

    return filtered_mutations

def format_mutations_table_improved(mutations: List[Dict]) -> str:
    """Format mutations as a readable text table without redundant position columns."""
    if not mutations:
        return "No mutations found."

    # Calculate maximum widths for each column (removed Input Pos and Canon Pos)
    mutation_width = max(len("Mutation"), max(len(m['mutation_code']) for m in mutations)) + 1
    change_width = max(len("Change Type"), max(len(m.get('change_type', '')) for m in mutations)) + 1
    severity_width = max(len("Severity"), max(len(m.get('severity', '')) for m in mutations)) + 1
    impact_width = max(len("Structural Impact"), max(len(m.get('structural_impact', '')) for m in mutations)) + 1
    evidence_width = max(len("Evidence"), max(len(str(m.get('evidence_count', 0))) for m in mutations)) + 1

    # Create header (removed redundant position columns)
    header = (f"{'Mutation':<{mutation_width}} "
             f"{'Change Type':<{change_width}} "
             f"{'Severity':<{severity_width}} "
             f"{'Structural Impact':<{impact_width}} "
             f"{'Evidence':<{evidence_width}}")

    # Create separator
    separator = "=" * len(header)

    # Create rows
    rows = []
    for m in mutations:
        row = (f"{m['mutation_code']:<{mutation_width}} "
               f"{m.get('change_type', ''):<{change_width}} "
               f"{m.get('severity', ''):<{severity_width}} "
               f"{m.get('structural_impact', ''):<{impact_width}} "
               f"{m.get('evidence_count', 0):<{evidence_width}}")
        rows.append(row)

    return "\n".join([header, separator] + rows)

# Imports used by the validation and database-query helpers (logic adapted from drug_screening_input.py)
import re
import requests
import urllib.parse

def validate_protein_sequence(protein_seq: str):
    """
    Validate protein sequence input for multi-chain format using colon separation.
    Based on drug_screening_input.py validation logic.
    """
    protein_seq = re.sub(r'\s+', '', protein_seq.upper())
    if not protein_seq.strip():
        return True, "", {}, protein_seq

    # Check for invalid characters (only uppercase letters and : allowed)
    invalid_chars = re.findall(r'[^A-Z:]', protein_seq)
    if invalid_chars:
        invalid_chars_str = ', '.join(set(invalid_chars))
        return False, f"Invalid characters found: {invalid_chars_str}. Only uppercase letters (A-Z) and colon (:) are allowed.", {}, protein_seq

    # Split by colon to get chains
    chains = protein_seq.split(':')

    # Check if we have too many chains (max 25 chains: A-W, Y-Z, X is reserved for ligand)
    if len(chains) > 25:
        return False, f"Too many protein chains ({len(chains)}). Maximum allowed is 25 chains (A-W, Y-Z). Chain X is reserved for ligands.", {}, protein_seq

    # Validate each chain contains only uppercase letters
    for i, chain in enumerate(chains):
        if not chain.strip():
            return False, f"Empty chain found at position {i+1}. Each chain must contain amino acid sequence.", {}, protein_seq
        if not re.match(r'^[A-Z]+$', chain.strip()):
            return False, f"Chain {i+1} contains invalid characters. Only uppercase letters are allowed.", {}, protein_seq

    # Create chains dictionary with incremental IDs (A-W, Y-Z, X reserved for ligand)
    chains_dict = {}
    for i, chain in enumerate(chains):
        if i < 23:  # A-W (0-22)
            chain_id = chr(65 + i)  # A, B, C, ..., W (65-87)
        else:  # Y-Z (23-24)
            chain_id = chr(89 + (i - 23))  # Y, Z (89-90)
        chains_dict[chain_id] = chain.strip()

    return True, "", chains_dict, protein_seq

def parse_smiles_list_colab(smiles_input: str):
    """
    Parse SMILES input that can be either semicolon-separated or newline-separated.
    Optimized for Google Colab where semicolons are preferred.
    """
    smiles_list = []

    # Handle both semicolon and newline separation
    if ';' in smiles_input:
        # Semicolon-separated (preferred for Colab)
        lines = smiles_input.strip().split(';')
    else:
        # Fallback to newline-separated
        lines = smiles_input.strip().split('\n')

    for i, line in enumerate(lines):
        line = line.strip()
        if not line:
            continue

        if ',' in line:
            # "name, SMILES" format: split on the first comma only
            name, smiles = (part.strip() for part in line.split(',', 1))
            smiles_list.append((name, smiles))
        else:
            # Bare SMILES without a name: fall back to a positional name
            smiles_list.append((f"Compound_{i+1}", line))

    return smiles_list

# Validate inputs
print("🔍 Validating Basic Inputs...\n")

# Handle mutation discovery if requested
discovered_mutations = []
detailed_mutations_data = []
if input_method == "Mutants-Drug Screening" and query_mutations_from_database and wt_protein_sequence.strip():
    print("🔍 Discovering mutations from public databases...\n")

    # Clean and validate the sequence first
    is_valid, error_msg, chains_dict, clean_seq = validate_protein_sequence(wt_protein_sequence.strip())

    if not is_valid:
        print(f"❌ Invalid wild-type sequence: {error_msg}")
    else:
        # Parse chain start numbers
        num_chains = len(chains_dict) if chains_dict else 1
        chain_starts = parse_chain_start_numbers(chain_start_numbers, num_chains)

        # Parse residue range filter
        range_filter = parse_residue_range_filter(residue_range_filter, num_chains)

        print(f"🧬 Multi-chain analysis:")
        print(f"   Chains detected: {list(chains_dict.keys()) if chains_dict else ['A (single chain)']}")
        print(f"   Chain start numbers: {chain_starts}")
        if range_filter:
            print(f"   Residue range filter: {range_filter}")

        # Use first/longest chain for identification
        if chains_dict and len(chains_dict) > 1:
            # Use the longest chain for protein identification
            first_chain = max(chains_dict.values(), key=len)
            chain_id_for_analysis = max(chains_dict.items(), key=lambda x: len(x[1]))[0]
            print(f"   Using chain {chain_id_for_analysis} (longest, {len(first_chain)} residues) for protein identification")
        else:
            first_chain = clean_seq.replace(':', '') if ':' in clean_seq else clean_seq
            chain_id_for_analysis = 'A'
            print(f"   Single chain analysis ({len(first_chain)} residues)")

        print(f"\n🔍 Analyzing sequence (length: {len(first_chain)} residues)...")

        # Identify protein
        protein_info = identify_protein_from_sequence(first_chain)

        if "error" not in protein_info:
            print(f"✅ Protein identified:")
            print(f"   📛 Name: {protein_info.get('protein_name', 'Unknown')}")
            print(f"   🧬 Organism: {protein_info.get('organism', 'Unknown')}")
            print(f"   🔗 UniProt: {protein_info.get('accession', 'Unknown')}")
            print(f"   📊 Similarity: {protein_info.get('similarity', 0):.1%}")

            print(f"\n🔍 Searching for known mutations...")

            # Use the chain start number for the analysis chain
            analysis_chain_start = chain_starts.get(chain_id_for_analysis, 1)

            # Use the detailed mutation query function
            detailed_mutations_data = query_mutations_for_protein_detailed(
                protein_info, first_chain, analysis_chain_start, 50  # Increased default to show more mutations
            )

            if detailed_mutations_data:
                # Apply residue range filter if specified
                if range_filter:
                    print(f"   🎯 Applying residue range filter...")
                    original_count = len(detailed_mutations_data)
                    detailed_mutations_data = filter_mutations_by_range(detailed_mutations_data, range_filter, chain_starts)
                    filtered_count = len(detailed_mutations_data)
                    print(f"   📊 Filtered {original_count} → {filtered_count} mutations based on range criteria")

                discovered_mutations = [m['mutation_code'] for m in detailed_mutations_data]
                formatted_mutations = format_mutations_for_input(detailed_mutations_data)

                print(f"✅ Found {len(detailed_mutations_data)} mutations:")

                # Display detailed table with improved format (no redundant columns)
                print(f"\n📊 Detailed Mutation Analysis:")
                mutation_table = format_mutations_table_improved(detailed_mutations_data)
                print(mutation_table)

                print(f"\n📋 Copy this to the mutations input field below:")
                print(f"🔗 {formatted_mutations}")

                print(f"\n📚 Column Explanations:")
                print(f"• Mutation: Mutation in your sequence numbering (WT_AA + Position + MUT_AA)")
                print(f"• Change Type: Chemical property change classification")
                print(f"• Severity: Predicted functional impact (High/Moderate/Low)")
                print(f"• Structural Impact: Expected structural consequences")
                print(f"• Evidence: Number of experimental observations in databases")

                if range_filter:
                    print(f"\n🎯 Range Filter Applied:")
                    for chain_id, ranges in range_filter.items():
                        for start_pos, end_pos in ranges:
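                            # Default to start number 1 if the filter names a chain without a configured start
                            chain_start = chain_starts.get(chain_id, 1)
                            actual_start = start_pos + chain_start - 1
                            actual_end = end_pos + chain_start - 1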
                            print(f"   Chain {chain_id}: residues {start_pos}-{end_pos} (absolute positions {actual_start}-{actual_end})")

            else:
                print(f"⚠️ No mutations found in the specified criteria.")
        else:
            print(f"❌ Protein identification failed: {protein_info['error']}")
            print(f"💡 You can still manually enter mutations in the standard format (e.g., A50V, L60P)")

# Handle file upload if selected
protein_sequences = []
if input_method == "Protein-Drug Screening - Upload FASTA File":
    if COLAB_AVAILABLE and not running_locally:
        print("📁 Please upload your FASTA file:")
        uploaded = files.upload()

        for filename, content in uploaded.items():
            if isinstance(content, bytes):
                content = content.decode('utf-8')

            sequences = parse_fasta_sequences(content)
            protein_sequences.extend(sequences)
            print(f"✅ Parsed {len(sequences)} sequences from {filename}")

            for name, seq in sequences:
                print(f"  - {name}: {len(seq)} residues")
    elif running_locally:
        print("📁 Local mode: Looking for protein.fasta in current directory...")
        sequences = parse_fasta_from_local_file("protein.fasta")
        if sequences:
            protein_sequences.extend(sequences)
            print(f"✅ Parsed {len(sequences)} sequences from protein.fasta")
            for name, seq in sequences:
                print(f"  - {name}: {len(seq)} residues")
        else:
            print("❌ protein.fasta not found or empty in current directory")
    else:
        print("❌ Cannot upload files - not in Colab environment and not in local mode")

elif input_method == "Protein-Drug Screening - Manual Protein Entry":
    # Validate manual protein input using colon-separated format
    is_valid, error_msg, chains_dict, clean_seq = validate_protein_sequence(protein_sequence)
    if is_valid and not clean_seq:
        # validate_protein_sequence() accepts empty input, so reject it explicitly here
        is_valid, error_msg = False, "Protein sequence is empty."
    if is_valid:
        # Store the cleaned sequence (uppercase, whitespace removed), as the mutant mode does
        protein_sequences.append((protein_name, clean_seq))
        print(f"✅ Valid protein sequence: {protein_name}")
        if len(chains_dict) > 1:
            print(f"   Multi-chain detected: {len(chains_dict)} chains ({list(chains_dict.keys())})")
            for chain_id, chain_seq in chains_dict.items():
                print(f"     Chain {chain_id}: {len(chain_seq)} residues")
        else:
            # Single chain
            print(f"   Length: {len(clean_seq)} residues")
    else:
        print(f"❌ Invalid protein sequence: {error_msg}")
        protein_sequences = []

elif input_method == "Mutants-Drug Screening":
    # Process mutation mode
    if wt_protein_sequence.strip() and mutations_input.strip():
        is_valid, error_msg, chains_dict, clean_seq = validate_protein_sequence(wt_protein_sequence.strip())

        if is_valid:
            # Parse chain start numbers
            num_chains = len(chains_dict) if chains_dict else 1
            chain_starts = parse_chain_start_numbers(chain_start_numbers, num_chains)

            # Parse mutations
            mutation_lists = parse_mutations(mutations_input)

            # Create wild-type entry
            protein_sequences.append(("WT", clean_seq))

            # Create mutant entries
            mutant_strings = [s.strip() for s in mutations_input.split(',')]

            for i, mutant_str in enumerate(mutant_strings):
                if not mutant_str:
                    continue

                mutant_name = generate_mutant_name(mutation_lists[i] if i < len(mutation_lists) else [])

                if i < len(mutation_lists):
                    mutations = mutation_lists[i]
                    mutated_seq = apply_mutations_to_sequence(clean_seq, mutations, chain_starts, chains_dict)
                    protein_sequences.append((mutant_name, mutated_seq))

            print(f"✅ Generated {len(protein_sequences)} protein variants (WT + {len(protein_sequences)-1} mutants)")

            # Display chain information if multi-chain
            if chains_dict and len(chains_dict) > 1:
                print(f"   Multi-chain setup: {len(chains_dict)} chains with start numbers {chain_starts}")
        else:
            print(f"❌ Invalid wild-type sequence: {error_msg}")

# Validate SMILES input
ligand_smiles = []
if not structure_only:
    if ligand_input_method == "Protein-Drug Screening":
        if COLAB_AVAILABLE and not running_locally:
            print("📁 Please upload your SMILES FASTA file:")
            uploaded_smiles = files.upload()

            for filename, content in uploaded_smiles.items():
                if isinstance(content, bytes):
                    content = content.decode('utf-8')

                parsed_smiles = parse_smiles_list(content)
                print(f"✅ Parsed {len(parsed_smiles)} SMILES from {filename}")
                ligand_smiles.extend(parsed_smiles)
        elif running_locally:
            print("📁 Local mode: Looking for drug.fasta in current directory...")
            sequences = parse_fasta_from_local_file("drug.fasta")
            if sequences:
                # Convert FASTA sequences to SMILES format
                parsed_smiles = [(name, seq) for name, seq in sequences]
                ligand_smiles.extend(parsed_smiles)
                print(f"✅ Parsed {len(parsed_smiles)} SMILES from drug.fasta")
            else:
                print("❌ drug.fasta not found or empty in current directory")
        else:
            print("❌ Cannot upload files - not in Colab environment and not in local mode")

    elif ligand_input_method == "Text Input" and ligand_smiles_input.strip():
        parsed_smiles = parse_smiles_list_colab(ligand_smiles_input)
        print(f"\n🧪 Validating {len(parsed_smiles)} SMILES strings:")

        for name, smiles in parsed_smiles:
            is_valid, error_msg = validate_smiles(smiles)
            if is_valid:
                ligand_smiles.append((name, smiles))
                print(f"✅ {name}: {smiles}")
                if RDKIT_AVAILABLE and "Valid (" in error_msg:
                    print(f"   {error_msg}")
            else:
                print(f"❌ {name}: {error_msg}")

# Store validated inputs for next cell
print(f"\n📋 Basic Input Summary:")
print(f"📦 Project: {project_name}")
print(f"🧬 Valid proteins: {len(protein_sequences)}")
print(f"🧪 Valid ligands: {len(ligand_smiles) if not structure_only else 'N/A (structure only)'}")

if discovered_mutations:
    print(f"🔍 Discovered mutations: {len(discovered_mutations)}")
    if detailed_mutations_data:
        severity_counts = {}
        for mut in detailed_mutations_data:
            severity = mut.get('severity', 'Unknown')
            severity_counts[severity] = severity_counts.get(severity, 0) + 1

        severity_summary = ", ".join([f"{count} {sev}" for sev, count in severity_counts.items()])
        print(f"   Severity breakdown: {severity_summary}")

if protein_sequences and (ligand_smiles or structure_only):
    print(f"\n✅ Basic inputs validated! Proceed to the next cell for additional parameters.")
else:
    if structure_only:
        needed = "at least 1 protein"
    else:
        needed = "at least 1 protein and 1 ligand"
    print(f"\n❌ Cannot proceed: Need {needed}")
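
The input formats accepted by this cell can be illustrated with a condensed, standalone sketch. It restates two conventions from the code above: colon-separated chains receive IDs A-W, then Y and Z (X stays reserved for the ligand), and ligands are given as `name,SMILES` entries separated by semicolons or newlines. The helper names `assign_chain_ids` and `split_named_smiles` are hypothetical and exist only for this illustration; the notebook itself uses `validate_protein_sequence` and `parse_smiles_list_colab`.

```python
import re
import string

# Chain X is reserved for the ligand, so protein chains use A-W, then Y and Z.
PROTEIN_CHAIN_IDS = [c for c in string.ascii_uppercase if c != "X"]


def assign_chain_ids(protein_seq: str) -> dict:
    """Map a colon-separated multi-chain sequence to chain IDs (illustrative helper)."""
    cleaned = re.sub(r"\s+", "", protein_seq.upper())
    chains = cleaned.split(":")
    if len(chains) > len(PROTEIN_CHAIN_IDS):
        raise ValueError("at most 25 protein chains are supported")
    return dict(zip(PROTEIN_CHAIN_IDS, chains))


def split_named_smiles(text: str) -> list:
    """Parse 'name,SMILES' entries separated by semicolons or newlines (illustrative helper)."""
    separator = ";" if ";" in text else "\n"
    entries = []
    for i, item in enumerate(text.strip().split(separator)):
        item = item.strip()
        if not item:
            continue
        if "," in item:
            # Split on the first comma only, so the name comes first
            name, smiles = (part.strip() for part in item.split(",", 1))
        else:
            # Unnamed entries get a positional name
            name, smiles = f"Compound_{i + 1}", item
        entries.append((name, smiles))
    return entries


print(assign_chain_ids("mkt aya : GSH"))
print(split_named_smiles("aspirin,CC(=O)OC1=CC=CC=C1C(=O)O; CCO"))
```

Whitespace and lowercase letters are normalized before the chains are split, which is why pasted sequences with line breaks are still accepted.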

In [None]:
#@title 3️⃣ ⚙️ Additional Parameters (modifications are optional, run to proceed to next step) { display-mode: "form" }

#@markdown ## 💎 Co-factors (Optional)
#@markdown Add essential co-factors like ATP, NAD, or heme that are required for protein function. Leave empty if not needed.

#@markdown ### 🔸 Co-factor 1
#@markdown **Method**: 'SMILES' = provide chemical structure, 'CCD Code' = use standard biochemical codes (ATP, NAD, HEM, etc.).
cofactor1_method = "None"  #@param ["None", "SMILES", "CCD Code"]
#@markdown **SMILES**: Chemical structure notation (only if SMILES method selected)
cofactor1_smiles = ""  #@param {type:"string"}
#@markdown **CCD Code**: Standard biochemical code (only if CCD Code method selected)
cofactor1_ccd = ""  #@param {type:"string"}

#@markdown ### 🔸 Co-factor 2
#@markdown **Method**: Choose input method for second co-factor
cofactor2_method = "None"  #@param ["None", "SMILES", "CCD Code"]
#@markdown **SMILES**: Chemical structure notation
cofactor2_smiles = ""  #@param {type:"string"}
#@markdown **CCD Code**: Standard biochemical code
cofactor2_ccd = ""  #@param {type:"string"}

#@markdown ### 🔸 Co-factor 3
#@markdown **Method**: Choose input method for third co-factor
cofactor3_method = "None"  #@param ["None", "SMILES", "CCD Code"]
#@markdown **SMILES**: Chemical structure notation
cofactor3_smiles = ""  #@param {type:"string"}
#@markdown **CCD Code**: Standard biochemical code
cofactor3_ccd = ""  #@param {type:"string"}

#@markdown ### 🔸 Co-factor 4
#@markdown **Method**: Choose input method for fourth co-factor
cofactor4_method = "None"  #@param ["None", "SMILES", "CCD Code"]
#@markdown **SMILES**: Chemical structure notation
cofactor4_smiles = ""  #@param {type:"string"}
#@markdown **CCD Code**: Standard biochemical code
cofactor4_ccd = ""  #@param {type:"string"}

#@markdown ---

#@markdown ## 🎯 Binding Pocket Constraints (Optional)
#@markdown Define specific residues that should interact with ligands. Use for targeted drug design.

#@markdown **Enable constraints**: Force ligands to bind near specific protein residues.
enable_binding_constraints = False  #@param {type:"boolean"}

#@markdown **Binder ID**: Which molecule should satisfy constraints ('X' = main ligand).
binder_id = "X"  #@param {type:"string"}

#@markdown **Contact residues**: Comma-separated list (format: 'A:25,A:48,B:52' = Chain A residue 25, Chain A residue 48, Chain B residue 52).
pocket_contacts = ""  #@param {type:"string"}

#@markdown **Max distance**: Maximum allowed distance in Angstroms between ligand and specified residues.
max_distance = 5.0  #@param {type:"number"}

#@markdown ---

#@markdown ## ⚗️ Post-Translational Modifications - PTM (Optional)
#@markdown Add chemical modifications to proteins (phosphorylation, glycosylation, etc.).

#@markdown **Enable PTM**: Add post-translational modifications to specific residues.
enable_ptm = False  #@param {type:"boolean"}

#@markdown **PTM chain**: Which protein chain to modify (usually 'A' for single chains).
ptm_chain_id = "A"  #@param {type:"string"}

#@markdown **PTM position**: Residue number where modification should be added.
ptm_position = 100  #@param {type:"integer"}

#@markdown **PTM type**: CCD code of the modified residue (e.g., 'SEP' for phosphoserine, 'TPO' for phosphothreonine, 'PTR' for phosphotyrosine, 'NAG' for glycosylation).
ptm_ccd_code = ""  #@param {type:"string"}

#@markdown ---
#@markdown Click RUN ▶️ to validate additional parameters. Then proceed to the next cell for computational settings and generation.

# Check if basic inputs were validated
if 'protein_sequences' not in globals() or not protein_sequences:
    print("❌ Please run the Basic Input Configuration cell first!")
else:
    print("🔍 Validating Additional Parameters...\n")

    # Process co-factors automatically based on user input
    cofactor_info = []
    cofactor_configs = [
        (cofactor1_method, cofactor1_smiles, cofactor1_ccd),
        (cofactor2_method, cofactor2_smiles, cofactor2_ccd),
        (cofactor3_method, cofactor3_smiles, cofactor3_ccd),
        (cofactor4_method, cofactor4_smiles, cofactor4_ccd)
    ]

    valid_cofactors = 0
    for i, (method, smiles_val, ccd_val) in enumerate(cofactor_configs):
        slot = i + 1  # Form slot (1-4), used in every message so numbering stays consistent
        if method == "None":
            continue

        if method == "SMILES" and smiles_val.strip():
            is_valid, _ = validate_smiles(smiles_val.strip())
            if is_valid:
                cofactor_info.append({'smiles': smiles_val.strip()})
                valid_cofactors += 1
                print(f"✅ Co-factor {slot}: SMILES - {smiles_val.strip()}")
            else:
                print(f"❌ Co-factor {slot}: Invalid SMILES - {smiles_val.strip()}")
        elif method == "CCD Code" and ccd_val.strip():
            if validate_ccd_code(ccd_val.strip()):
                cofactor_info.append({'ccd': ccd_val.strip().upper()})
                ccd_name = COMMON_CCD_CODES.get(ccd_val.strip().upper(), "Unknown")
                valid_cofactors += 1
                print(f"✅ Co-factor {slot}: {ccd_val.strip().upper()} - {ccd_name}")
            else:
                print(f"❌ Co-factor {slot}: Invalid CCD code - {ccd_val.strip()}")
        else:
            # A method was selected but its matching field was left empty
            print(f"⚠️ Co-factor {slot}: '{method}' selected but no value provided, skipping")

    if valid_cofactors > 0:
        print(f"\n💎 Total valid co-factors: {valid_cofactors}")

    # Process binding pocket constraints
    binding_pocket_constraints = None
    if enable_binding_constraints and pocket_contacts.strip():
        print(f"\n🎯 Processing binding pocket constraints:")
        contacts = []
        for contact_str in pocket_contacts.split(','):
            contact_str = contact_str.strip()
            if ':' in contact_str:
                chain_id, pos_str = contact_str.split(':', 1)
                try:
                    position = int(pos_str.strip())
                    contacts.append([chain_id.strip(), position])
                    print(f"✅ Contact: Chain {chain_id.strip()} position {position}")
                except ValueError:
                    print(f"❌ Invalid contact format: {contact_str}")
            else:
                try:
                    position = int(contact_str)
                    contacts.append(['A', position])
                    print(f"✅ Contact: Chain A position {position}")
                except ValueError:
                    print(f"❌ Invalid contact format: {contact_str}")

        if contacts:
            binding_pocket_constraints = {
                'binder': binder_id,
                'contacts': contacts,
                'max_distance': max_distance
            }

    # Process PTM modifications
    ptm_modifications = None
    if enable_ptm and ptm_ccd_code.strip():
        print(f"\n⚗️ Processing PTM modifications:")
        if validate_ccd_code(ptm_ccd_code.strip()):
            ptm_modifications = {
                'modifications': [{
                    'chain_id': ptm_chain_id,
                    'position': ptm_position,
                    'ccd': ptm_ccd_code.strip().upper()
                }]
            }
            ptm_name = COMMON_CCD_CODES.get(ptm_ccd_code.strip().upper(), "Unknown")
            print(f"✅ PTM: {ptm_ccd_code.strip().upper()} at Chain {ptm_chain_id} position {ptm_position} - {ptm_name}")
        else:
            print(f"❌ Invalid PTM CCD code: {ptm_ccd_code.strip()}")

    # Summary
    print(f"\n📋 Complete Input Summary:")
    print(f"📦 Project: {project_name}")
    print(f"🧬 Valid proteins: {len(protein_sequences)}")
    print(f"🧪 Valid ligands: {len(ligand_smiles) if not structure_only else 'N/A (structure only)'}")
    print(f"💎 Co-factors: {len(cofactor_info)}")
    print(f"🎯 Binding constraints: {'Yes' if binding_pocket_constraints else 'No'}")
    print(f"⚗️ PTM modifications: {'Yes' if ptm_modifications else 'No'}")

    if protein_sequences and (ligand_smiles or structure_only):
        total_combinations = len(protein_sequences) * (len(ligand_smiles) if not structure_only else 1)
        estimated_time = total_combinations * ESTIMATED_TIME_PER_JOB / 60
        print(f"🔢 Total combinations: {total_combinations}")
        print(f"⏱️ Estimated time: {estimated_time:.1f} minutes")
        print(f"\n✅ All parameters validated! Proceed to the next cell to configure and generate screening.")
    else:
        if structure_only:
            needed = "at least 1 protein"
        else:
            needed = "at least 1 protein and 1 ligand"
        print(f"\n❌ Cannot proceed: Need {needed}")
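
The contact-residue format used in this cell ('A:25,A:48,B:52', with a bare number defaulting to chain A) can be sketched as a small standalone parser. This is a simplified illustration, not the notebook's code: `parse_pocket_contacts` is a hypothetical name, and unlike the cell above it raises on malformed tokens instead of printing a message and continuing.

```python
def parse_pocket_contacts(text: str, default_chain: str = "A") -> list:
    """Turn 'A:25,A:48,B:52' into [[chain_id, position], ...] (illustrative helper)."""
    contacts = []
    for token in text.split(","):
        token = token.strip()
        if not token:
            continue  # ignore empty entries such as a trailing comma
        chain_id, sep, position = token.rpartition(":")
        # Without a colon, rpartition leaves chain_id empty, so use the default chain
        chain_id = chain_id.strip() if sep else default_chain
        contacts.append([chain_id, int(position)])  # int() raises ValueError on bad input
    return contacts


# Same dictionary shape the cell stores in binding_pocket_constraints
constraint = {
    "binder": "X",
    "contacts": parse_pocket_contacts("A:25, A:48,B:52, 77"),
    "max_distance": 5.0,
}
print(constraint)
```

Raising on bad input is a design choice for the sketch; the interactive cell reports each invalid token and keeps the valid ones, which is friendlier in a form-driven notebook.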

In [None]:
#@title 4️⃣ 🚀 Configure Computational Settings & Generate Screening (modifications are optional, run to proceed to next step) { display-mode: "form" }

#@markdown ## ⚙️ Computation Settings
#@markdown Configure the computational parameters for structure prediction and analysis.

#@markdown ### 🔸 Basic Settings

#@markdown **GPU acceleration**: Enable for faster computation (~10x speedup). Disable only if GPU unavailable.
use_gpu = True  #@param {type:"boolean"}

#@markdown **Use existing results**: Load previously computed results to save time. Disable to force fresh computation.
use_existing_results = True  #@param {type:"boolean"}

#@markdown **Parallel samples**: Number of simultaneous predictions. Higher values use more memory but run faster.
max_parallel_samples = 5  #@param {type:"integer"}

#@markdown ### 🔸 Advanced Sampling Parameters

#@markdown **Recycling steps**: Number of iterative refinement cycles (1-10). More steps = higher accuracy, longer time.
recycling_steps = 4  #@param {type:"integer"}

#@markdown **Sampling steps**: Diffusion sampling iterations (50-500). More steps = better quality, longer computation.
sampling_steps = 300  #@param {type:"integer"}

#@markdown **Diffusion samples**: Multiple structure samples per prediction (1-5). More samples = better statistics.
diffusion_samples = 1  #@param {type:"integer"}

#@markdown **Step scale**: Controls sampling temperature (1.0-2.0). Lower = more diversity, higher = more precision.
step_scale = 1.638  #@param {type:"number"}

#@markdown ### 🔸 Multiple Sequence Alignment

#@markdown **MSA sequences**: Maximum evolutionary sequences to use (1024-16384). More sequences = better accuracy.
max_msa_seqs = 8192  #@param {type:"integer"}

#@markdown **Subsample MSA**: Randomly reduce MSA size to increase diversity while saving memory.
subsample_msa = False  #@param {type:"boolean"}

#@markdown **Subsampled count**: Number of sequences when subsampling (only if subsampling enabled).
num_subsampled_msa = 1024  #@param {type:"integer"}

#@markdown ### 🔸 Affinity Prediction

#@markdown **Molecular weight correction**: Apply size-based correction to binding affinity predictions.
affinity_mw_correction = False  #@param {type:"boolean"}

#@markdown **Affinity sampling steps**: Dedicated steps for binding affinity calculation.
sampling_steps_affinity = 300  #@param {type:"integer"}

#@markdown **Affinity samples**: Multiple samples for robust affinity estimation.
diffusion_samples_affinity = 7  #@param {type:"integer"}

#@markdown ### 🔸 Error Handling

#@markdown **Enable retries**: Automatically retry failed predictions with exponential backoff delay.
enable_retries = True  #@param {type:"boolean"}

#@markdown **Max retries**: Maximum retry attempts per failed prediction (1-5).
max_retry_attempts = 2  #@param {type:"integer"}

#@markdown **Retry delay**: Base delay in seconds between retries (doubles each attempt).
retry_delay_base = 5  #@param {type:"integer"}

#@markdown ---

#@markdown ## 📋 Template Upload (Optional)
#@markdown **Upload template structure**: Provide a .cif template file to guide protein folding (advanced users only). Local mode: looks for template.cif in working directory.
upload_template = False  #@param {type:"boolean"}

#@markdown ---
#@markdown Click RUN ▶️ to generate comprehensive screening configurations.

# Check prerequisites
if 'protein_sequences' not in globals() or 'ligand_smiles' not in globals():
    print("❌ Please run the previous configuration cells first!")
elif not protein_sequences:
    print("❌ No valid protein sequences found. Check your inputs.")
elif not structure_only and not ligand_smiles:
    print("❌ No valid ligands found and not in structure-only mode. Check your inputs.")
else:
    print(f"🔄 Generating Comprehensive Screening Configurations\n")
    print(f"📦 Project: {project_name}")
    print(f"🧬 Proteins: {len(protein_sequences)}")
    print(f"🧪 Ligands: {len(ligand_smiles) if not structure_only else 'N/A (structure only)'}")
    print(f"🟢 Co-factors: {len(cofactor_info) if 'cofactor_info' in globals() else 0}")
    print(f"🎯 Binding constraints: {'Yes' if 'binding_pocket_constraints' in globals() and binding_pocket_constraints else 'No'}")
    print(f"🧪 PTM modifications: {'Yes' if 'ptm_modifications' in globals() and ptm_modifications else 'No'}")

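    # Handle template file upload
    template_cif_path = None
    if upload_template and COLAB_AVAILABLE and not running_locally:
        print(f"\n📁 Template File Upload:")
        uploaded_template = files.upload()

        for filename, content in uploaded_template.items():
            if filename.endswith('.cif'):
                # Save template file
                project_dir = os.path.join(RESULTS_DIR, project_name)
                os.makedirs(project_dir, exist_ok=True)
                cif_filename = f"template_{datetime.now().strftime('%Y%m%d_%H%M%S')}.cif"
                template_cif_path = os.path.join(project_dir, cif_filename)
                with open(template_cif_path, "wb") as f:
                    f.write(content)
                template_cif_path = os.path.abspath(template_cif_path)
                print(f"✅ Template saved: {cif_filename}")
            else:
                print(f"❌ Unsupported file type: {filename} (only .cif supported)")
    elif upload_template and running_locally:
        # Local mode: use template.cif from the working directory, as the form text describes
        local_template = os.path.abspath("template.cif")
        if os.path.isfile(local_template):
            template_cif_path = local_template
            print(f"\n✅ Using local template: {local_template}")
        else:
            print("\n❌ template.cif not found in current directory; continuing without a template")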

    # Store configuration for next cells
    screening_config = {
        'use_gpu': use_gpu,
        'recycling_steps': recycling_steps,
        'sampling_steps': sampling_steps,
        'diffusion_samples': diffusion_samples,
        'step_scale': step_scale,
        'max_msa_seqs': max_msa_seqs,
        'subsample_msa': subsample_msa,
        'num_subsampled_msa': num_subsampled_msa if subsample_msa else None,
        'affinity_mw_correction': affinity_mw_correction,
        'sampling_steps_affinity': sampling_steps_affinity,
        'diffusion_samples_affinity': diffusion_samples_affinity,
        'enable_retries': enable_retries,
        'max_retry_attempts': max_retry_attempts,
        'retry_delay_base': retry_delay_base,
        'use_existing_results': use_existing_results,
        'max_parallel_samples': max_parallel_samples
    }

    # Prepare configuration with all advanced parameters
    advanced_config = screening_config.copy()

    # Add optional components
    if 'cofactor_info' in globals():
        advanced_config['cofactor_info'] = cofactor_info
    if 'binding_pocket_constraints' in globals():
        advanced_config['binding_pocket_constraints'] = binding_pocket_constraints
    if 'ptm_modifications' in globals():
        advanced_config['ptm_modifications'] = ptm_modifications
    if template_cif_path:
        advanced_config['template_cif_path'] = template_cif_path

    advanced_config['structure_only'] = structure_only

    print(f"\n⚙️ Configuration Summary:")
    print(f"🔧 GPU: {advanced_config['use_gpu']}")
    print(f"🔄 Recycling steps: {advanced_config['recycling_steps']}")
    print(f"📊 Sampling steps: {advanced_config['sampling_steps']}")
    print(f"🌊 Diffusion samples: {advanced_config['diffusion_samples']}")
    print(f"⚡ Step scale: {advanced_config['step_scale']}")
    print(f"🧬 MSA sequences: {advanced_config['max_msa_seqs']}")
    if advanced_config.get('subsample_msa'):
        print(f"📉 MSA subsampling: {advanced_config['num_subsampled_msa']} sequences")
    if advanced_config.get('affinity_mw_correction'):
        print(f"⚗️ Affinity MW correction: Enabled")
    if advanced_config.get('enable_retries'):
        print(f"🔁 Retry attempts: {advanced_config['max_retry_attempts']}")

    # Create YAML files for each combination
    yaml_files = []
    print(f"\n📄 Creating Configuration Files:")

    # Structure-only mode or ligand mode
    if structure_only:
        combinations = [(prot_name, prot_seq, None, None) for prot_name, prot_seq in protein_sequences]
    else:
        combinations = [(prot_name, prot_seq, lig_name, lig_smiles)
                       for prot_name, prot_seq in protein_sequences
                       for lig_name, lig_smiles in ligand_smiles]

    for i, (prot_name, prot_seq, lig_name, lig_smiles) in enumerate(combinations):
        workspace_name = f"screening_{i+1:03d}"
        if structure_only:
            design_name = f"{prot_name}".replace(' ', '_')
        else:
            design_name = f"{prot_name}_{lig_name}".replace(' ', '_')

        try:
            yaml_path = create_screening_yaml(
                workspace_name=workspace_name,
                design_name=design_name,
                protein_sequence=prot_seq,
                ligand_smiles=lig_smiles or "",  # Empty for structure-only
                project_name=project_name,
                **advanced_config
            )
            yaml_files.append(yaml_path)

            # Show progress with details
            config_details = []
            if advanced_config.get('cofactor_info'):
                config_details.append(f"{len(advanced_config['cofactor_info'])} co-factors")
            if advanced_config.get('binding_pocket_constraints'):
                constraint_count = len(advanced_config['binding_pocket_constraints'].get('contacts', []))
                config_details.append(f"{constraint_count} pocket constraints")
            if advanced_config.get('ptm_modifications'):
                ptm_count = len(advanced_config['ptm_modifications'].get('modifications', []))
                config_details.append(f"{ptm_count} PTM modifications")
            if template_cif_path:
                config_details.append("template structure")

            details_str = f" ({', '.join(config_details)})" if config_details else ""
            print(f"✅ {design_name}{details_str}")

        except Exception as e:
            print(f"❌ Error creating {design_name}: {str(e)}")

    print(f"\n📁 Configuration Complete!")
    print(f"📄 Generated {len(yaml_files)} YAML files")
    print(f"📂 Location: {RESULTS_DIR}/{project_name}/")

    # Display configuration file preview if any were created
    if yaml_files:
        print(f"\n🔍 Sample Configuration Preview (first file):")
        try:
            with open(yaml_files[0], 'r') as f:
                preview_content = f.read()
                # Show first 20 lines
                preview_lines = preview_content.split('\n')[:20]
                print("```yaml")
                print('\n'.join(preview_lines))
                if len(preview_content.split('\n')) > 20:
                    print("... (truncated)")
                print("```")
        except Exception as e:
            print(f"❌ Error reading preview: {e}")

    # Store results for next steps
    screening_yaml_files = yaml_files
    screening_advanced_config = advanced_config

    # Estimate computational requirements
    total_combinations = len(yaml_files)
    estimated_gpu_hours = total_combinations * (ESTIMATED_TIME_PER_JOB / 3600)  # Convert to hours
    estimated_cost_estimate = estimated_gpu_hours * 0.50  # Rough estimate at $0.50/GPU hour
    print(f"\n⏱️ Estimated compute: ~{estimated_gpu_hours:.2f} GPU hours (~${estimated_cost_estimate:.2f} at $0.50/GPU hour)")


In [None]:
#@title 5️⃣ 🌠 Run Structure Predictions { display-mode: "form" }

#@markdown ---
#@markdown Click RUN ▶️ to proceed with structure predictions for all designs. This will execute the Boltz-2 prediction workflow using the configurations created in the previous steps.

def run_boltz_prediction(yaml_filepath, use_gpu=True, override=False, recycling_steps=3, sampling_steps=200, diffusion_samples=1, max_parallel_samples=5, step_scale=1.638, affinity_mw_correction=False, max_msa_seqs=8192, sampling_steps_affinity=200, diffusion_samples_affinity=5, subsample_msa=False, num_subsampled_msa=1024):
    """Run Boltz prediction on the YAML file."""
    try:
        # Run from the directory containing the YAML file.
        # Resolve to an absolute path first so a bare filename does not yield cwd="".
        yaml_dir = os.path.dirname(os.path.abspath(yaml_filepath))
        yaml_filename = os.path.basename(yaml_filepath)
        # Run boltz predict command
        cmd = ["boltz", "predict", yaml_filename, "--use_msa_server", "--output_format", "pdb"]
        # Add override flag if specified
        if override:
            cmd.append("--override")
        # Add CPU accelerator flag if GPU is disabled
        if not use_gpu:
            cmd.append("--accelerator")
            cmd.append("cpu")
        # Add Boltz parameters
        cmd.extend(["--recycling_steps", str(int(recycling_steps))])
        cmd.extend(["--sampling_steps", str(int(sampling_steps))])
        cmd.extend(["--diffusion_samples", str(int(diffusion_samples))])
        cmd.extend(["--max_parallel_samples", str(int(max_parallel_samples))])
        cmd.extend(["--step_scale", str(float(step_scale))])
        if affinity_mw_correction:
            cmd.append("--affinity_mw_correction")
        cmd.extend(["--max_msa_seqs", str(int(max_msa_seqs))])
        cmd.extend(["--sampling_steps_affinity", str(int(sampling_steps_affinity))])
        cmd.extend(["--diffusion_samples_affinity", str(int(diffusion_samples_affinity))])
        # Optional MSA subsampling
        if subsample_msa:
            cmd.append("--subsample_msa")
            cmd.extend(["--num_subsampled_msa", str(int(num_subsampled_msa))])
        # Print the command line for user reference
        print("[DEBUG]", " ".join(cmd))
        result = subprocess.run(
            cmd,
            cwd=yaml_dir,
            capture_output=True,
            text=True,
            timeout=300  # 5 minute timeout
        )
        if result.returncode != 0:
            raise RuntimeError(f"Boltz command failed: {result.stderr}")
        return result.stdout
    except subprocess.TimeoutExpired as e:
        raise RuntimeError("Boltz prediction timed out after 5 minutes") from e
    except Exception as e:
        # Chain the original exception so the full traceback stays available
        raise RuntimeError(f"Error running Boltz prediction: {e}") from e

def parse_boltz_results(yaml_filepath, structure_only=False):
    """Parse Boltz results from the JSON output files."""
    try:
        # Construct the path to the results files
        yaml_dir = os.path.dirname(yaml_filepath)
        yaml_filename = os.path.basename(yaml_filepath)
        yaml_name = os.path.splitext(yaml_filename)[0]
        
        # Path to the affinity results JSON file:
        # <yaml_dir>/boltz_results_<name>/predictions/<name>/affinity_<name>.json
        affinity_results_path = os.path.join(yaml_dir, f"boltz_results_{yaml_name}", "predictions", yaml_name, f"affinity_{yaml_name}.json")

        # Path to the confidence results JSON file for the first model (model_0)
        confidence_results_path = os.path.join(yaml_dir, f"boltz_results_{yaml_name}", "predictions", yaml_name, f"confidence_{yaml_name}_model_0.json")
        
        results = {}
        
        if not structure_only:
            # Parse affinity results
            if os.path.exists(affinity_results_path):
                with open(affinity_results_path, 'r') as f:
                    affinity_results = json.load(f)
                # Extract the key values from affinity results
                results.update({
                    "affinity_pred_value": affinity_results.get("affinity_pred_value"),
                    "affinity_probability_binary": affinity_results.get("affinity_probability_binary"),
                    "affinity_pred_value1": affinity_results.get("affinity_pred_value1"),
                    "affinity_probability_binary1": affinity_results.get("affinity_probability_binary1"),
                    "affinity_pred_value2": affinity_results.get("affinity_pred_value2"),
                    "affinity_probability_binary2": affinity_results.get("affinity_probability_binary2")
                })
            else:
                raise Exception(f"Affinity results file not found: {affinity_results_path}")
        # Parse confidence results
        if os.path.exists(confidence_results_path):
            with open(confidence_results_path, 'r') as f:
                confidence_results = json.load(f)
            # Extract the key values from confidence results
            results.update({
                "confidence_score": confidence_results.get("confidence_score"),
                "ptm": confidence_results.get("ptm"),
                "iptm": confidence_results.get("iptm"),
                "ligand_iptm": confidence_results.get("ligand_iptm"),
                "protein_iptm": confidence_results.get("protein_iptm"),
                "complex_plddt": confidence_results.get("complex_plddt"),
                "complex_iplddt": confidence_results.get("complex_iplddt"),
                "complex_pde": confidence_results.get("complex_pde"),
                "complex_ipde": confidence_results.get("complex_ipde"),
                "chains_ptm": confidence_results.get("chains_ptm"),
                "pair_chains_iptm": confidence_results.get("pair_chains_iptm")
            })
        else:
            # If confidence file doesn't exist, use default values
            results.update({
                "confidence_score": 0.0,
                "ptm": 0.0,
                "iptm": 0.0,
                "ligand_iptm": 0.0,
                "protein_iptm": 0.0,
                "complex_plddt": 0.0,
                "complex_iplddt": 0.0,
                "complex_pde": 0.0,
                "complex_ipde": 0.0,
                "chains_ptm": {"0": 0.0, "1": 0.0},
                "pair_chains_iptm": {"0": {"0": 0.0, "1": 0.0}, "1": {"0": 0.0, "1": 0.0}}
            })
        return results
    except Exception as e:
        raise Exception(f"Error parsing Boltz results: {str(e)}")

print("✅ Utility functions loaded for structure prediction")

# Use the computational settings already configured above
print(f"Using GPU: {use_gpu}")
print(f"Recycling steps: {recycling_steps}")
print(f"Sampling steps: {sampling_steps}")
print(f"Diffusion samples: {diffusion_samples}")
print(f"Max parallel samples: {max_parallel_samples}")
print(f"Step scale: {step_scale}")
print(f"Affinity sampling steps: {sampling_steps_affinity}")
print(f"Affinity diffusion samples: {diffusion_samples_affinity}")
print(f"Affinity MW correction: {affinity_mw_correction}")
print(f"Max MSA seqs: {max_msa_seqs}")
print(f"Subsample MSA: {subsample_msa}")
print(f"Num subsampled MSA: {num_subsampled_msa}")

# Run predictions for all designs
prediction_results = []

for yaml_path in yaml_files:
    try:
        # Extract design info from YAML filename
        yaml_filename = os.path.basename(yaml_path)
        design_name = os.path.splitext(yaml_filename)[0].replace("drug_screening_", "")
        
        # Extract design information directly from YAML file
        try:
            with open(yaml_path, "r") as f:
                yaml_content = yaml.safe_load(f)
            
            # Extract protein sequence and ligand SMILES from YAML
            protein_sequence = ""
            ligand_smiles = ""
            
            for seq_entry in yaml_content.get("sequences", []):
                if "protein" in seq_entry:
                    protein_sequence = seq_entry["protein"]["sequence"]
                elif "ligand" in seq_entry:
                    ligand_smiles = seq_entry["ligand"]["smiles"]
            
            if not protein_sequence:
                protein_sequence = "Unknown"
            if not ligand_smiles:
                ligand_smiles = "Unknown"
                
        except Exception as e:
            print(f"Warning: Could not extract design info from YAML: {e}")
            protein_sequence = "Unknown"
            ligand_smiles = "Unknown"

        print(f"🔬 Running prediction for {design_name}...")

        # Run Boltz prediction with configured parameters
        output = run_boltz_prediction(
            yaml_path,
            use_gpu=use_gpu,
            override=True,
            recycling_steps=recycling_steps,
            sampling_steps=sampling_steps,
            diffusion_samples=diffusion_samples,
            max_parallel_samples=max_parallel_samples,
            step_scale=step_scale,
            affinity_mw_correction=affinity_mw_correction,
            max_msa_seqs=max_msa_seqs,
            sampling_steps_affinity=sampling_steps_affinity,
            diffusion_samples_affinity=diffusion_samples_affinity,
            subsample_msa=subsample_msa,
            num_subsampled_msa=num_subsampled_msa
        )

        print(f"✅ Boltz prediction completed for {design_name}")

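        # Parse results; skip the affinity file in structure-only mode
        # (defaults to False if structure_only was not set by an earlier cell)
        results = parse_boltz_results(yaml_path, structure_only=globals().get("structure_only", False))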

        # Store results
        result_entry = {
            "Design Name": design_name,
            "Protein Sequence": protein_sequence,
            "Ligand SMILES": ligand_smiles,
            "Affinity Pred Value": results.get("affinity_pred_value"),
            "Affinity Probability Binary": results.get("affinity_probability_binary"),
            "Confidence Score": results.get("confidence_score"),
            "pTM": results.get("ptm"),
            "ipTM": results.get("iptm"),
            "Ligand ipTM": results.get("ligand_iptm"),
            "Protein ipTM": results.get("protein_iptm"),
            "Complex pLDDT": results.get("complex_plddt"),
            "Complex ipLDDT": results.get("complex_iplddt"),
            "Complex PDE": results.get("complex_pde"),
            "Complex ipDE": results.get("complex_ipde")
        }

        # Calculate pIC50 if an affinity value is available.
        # Boltz-2 documents affinity_pred_value as log10(IC50) with IC50 in μM,
        # so pIC50 = -log10(IC50 in M) = 6 - affinity_pred_value.
        affinity_value = results.get("affinity_pred_value")
        if affinity_value is not None:
            result_entry["pIC50"] = 6.0 - affinity_value

        prediction_results.append(result_entry)

        print(f"📊 Results parsed for {design_name}")
        if affinity_value is not None:
            ic50_nm = (10.0 ** affinity_value) * 1000.0  # μM -> nM
            print(f"   Predicted IC50: {ic50_nm:.2f} nM")
            print(f"   Predicted pIC50: {result_entry['pIC50']:.2f}")
        if results.get("confidence_score"):
            print(f"   Confidence Score: {results['confidence_score']:.3f}")

    except Exception as e:
        print(f"❌ Error processing {design_name}: {str(e)}")
        # Add entry with error information
        prediction_results.append({
            "Design Name": design_name,
            "Error": str(e)
        })

# Create results DataFrame
df_results = pd.DataFrame(prediction_results)

# Display results summary
print(f"🎯 Prediction Summary:")
print(f"Total designs processed: {len(yaml_files)}")
successful_predictions = len([r for r in prediction_results if "Error" not in r])
print(f"Successful predictions: {successful_predictions}")
print(f"Failed predictions: {len(prediction_results) - successful_predictions}")

# Display results table
if successful_predictions > 0:
    print("📋 Prediction Results:")
    display_columns = ["Design Name", "Affinity Pred Value", "pIC50", "Confidence Score", "Ligand ipTM"]
    available_columns = [col for col in display_columns if col in df_results.columns]
    display(df_results[available_columns])
else:
    print("⚠️ No successful predictions to display.")
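
How the predicted affinity maps to pIC50 depends on the units of `affinity_pred_value`. The following is a minimal sketch, assuming (per the Boltz-2 documentation) that the value is log10(IC50) with IC50 expressed in μM. The helper names `affinity_to_pic50` and `affinity_to_ic50_nm` are illustrative and are not part of Boltz.

```python
def affinity_to_pic50(affinity_pred_value: float) -> float:
    """Convert log10(IC50 in uM) to pIC50 = -log10(IC50 in M)."""
    # 1 uM = 1e-6 M, so log10(IC50 in M) = affinity_pred_value - 6
    return 6.0 - affinity_pred_value


def affinity_to_ic50_nm(affinity_pred_value: float) -> float:
    """Convert log10(IC50 in uM) to an IC50 in nM."""
    return (10.0 ** affinity_pred_value) * 1000.0


# Lower (more negative) values correspond to stronger predicted binding
for value in (-3.0, 0.0, 2.0):
    print(value, affinity_to_pic50(value), affinity_to_ic50_nm(value))
```

Under this assumption a value of 0 corresponds to an IC50 of 1 μM (pIC50 of 6), and each unit decrease is a tenfold gain in predicted potency.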


In [None]:
#@title 6️⃣ 📊 Visualize Results & 3D Structures { display-mode: "form" }

#@markdown ### 🎯 Plot Configuration

#@markdown **Show all designs**: Display results for all designs or only successful predictions.
show_all_designs = False #@param {type:"boolean"}

#@markdown **Plot style**: Choose the visual style for your plots.
plot_style = "plotly_white" #@param ["plotly", "plotly_white", "plotly_dark", "ggplot2", "seaborn", "simple_white"]

#@markdown **Color scheme**: Select color palette for visualizations.
color_scheme = "viridis" #@param ["viridis", "plasma", "inferno", "magma", "cividis", "rainbow", "turbo"]

#@markdown ### 🧬 3D Structure Settings

#@markdown **Structure style**: Choose how to display the 3D molecular structures.
structure_style = "cartoon" #@param ["cartoon", "stick", "sphere", "line", "surface"]

#@markdown **Show ligand**: Highlight the ligand in the 3D structure.
show_ligand = True #@param {type:"boolean"}

#@markdown **Ligand style**: Choose how to display the ligand.
ligand_style = "stick" #@param ["stick", "sphere", "line", "surface"]

#@markdown **Background color**: Set the background color for 3D viewer.
background_color = "white" #@param ["white", "black", "gray", "lightgray"]

# Import required libraries
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
import os
import glob

# Install and import py3Dmol for 3D visualization
try:
    import py3Dmol
except ImportError:
    !pip install py3Dmol
    import py3Dmol

print("🎨 Setting up visualizations...")

# Filter data based on user preference
if show_all_designs:
    viz_data = df_results.copy()
    print(f"📊 Visualizing all {len(viz_data)} designs")
else:
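    # Keep rows without an error; the "Error" column only exists if at least one run failed
    if "Error" in df_results.columns:
        viz_data = df_results[df_results["Error"].isna()].copy()
    else:
        viz_data = df_results.copy()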
    print(f"📊 Visualizing {len(viz_data)} successful predictions")

if len(viz_data) == 0:
    print("⚠️ No data available for visualization.")
else:
    # Create visualization dashboard
    print("🎯 Creating interactive plots...")

    # Set the plot template
    import plotly.io as pio
    pio.templates.default = plot_style

    # 1. Affinity vs Confidence Scatter Plot
    if "Affinity Pred Value" in viz_data.columns and "Confidence Score" in viz_data.columns:
        fig1 = px.scatter(
            viz_data,
            x="Confidence Score",
            y="Affinity Pred Value",
            hover_data=["Design Name"],
            title="🎯 Predicted Affinity vs Confidence Score",
            labels={
                "Affinity Pred Value": "Predicted log10(IC50 / μM)",
                "Confidence Score": "Structure Confidence"
            },
            color_discrete_sequence=px.colors.qualitative.Set1
        )
        fig1.update_layout(
            height=500,
            showlegend=False,
            title_x=0.5,
            font=dict(size=12)
        )
        fig1.show()

    # 2. pIC50 Distribution
    if "pIC50" in viz_data.columns:
        fig2 = px.histogram(
            viz_data,
            x="pIC50",
            nbins=20,
            title="📈 Distribution of Predicted pIC50 Values",
            labels={"pIC50": "Predicted pIC50", "count": "Number of Designs"},
            color_discrete_sequence=[px.colors.qualitative.Set1[1]]
        )
        fig2.add_vline(
            x=viz_data["pIC50"].median(),
            line_dash="dash",
            line_color="red",
            annotation_text=f"Median: {viz_data['pIC50'].median():.2f}"
        )
        fig2.update_layout(
            height=400,
            showlegend=False,
            title_x=0.5,
            font=dict(size=12)
        )
        fig2.show()

    # 3. Multi-metric Comparison
    metrics_cols = ["Confidence Score", "Ligand ipTM", "Complex pLDDT"]
    available_metrics = [col for col in metrics_cols if col in viz_data.columns]

    if len(available_metrics) > 1:
        fig3 = make_subplots(
            rows=1, cols=len(available_metrics),
            subplot_titles=available_metrics,
            shared_yaxes=True
        )

        colors = px.colors.qualitative.Set1[:len(available_metrics)]

        for i, metric in enumerate(available_metrics):
            fig3.add_trace(
                go.Box(
                    y=viz_data[metric],
                    name=metric,
                    marker_color=colors[i],
                    showlegend=False
                ),
                row=1, col=i+1
            )

        fig3.update_layout(
            title_text="📊 Quality Metrics Distribution",
            height=400,
            title_x=0.5,
            font=dict(size=12)
        )
        fig3.show()

    # 4. Top Performers Table
    if "pIC50" in viz_data.columns:
        top_designs = viz_data.nlargest(5, "pIC50")
        print("🏆 Top 5 Predicted Performers:")
        display_cols = ["Design Name", "pIC50", "Affinity Pred Value", "Confidence Score"]
        available_display_cols = [col for col in display_cols if col in top_designs.columns]
        display(top_designs[available_display_cols])

    print("🧬 Preparing 3D structure visualization...")

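    # Find available PDB files. Boltz writes its output next to each input YAML,
    # so search the YAML directories from the prediction step in addition to
    # the default results folder.
    pdb_files = []
    results_dir = "catdiscovery_results"
    search_dirs = {os.path.abspath(results_dir)}
    search_dirs.update(os.path.dirname(os.path.abspath(p)) for p in yaml_files)

    for base_dir in sorted(search_dirs):
        if not os.path.exists(base_dir):
            continue
        for design_name in viz_data["Design Name"].unique():
            # Look for PDB files in Boltz results
            pattern = os.path.join(base_dir, f"boltz_results_drug_screening_{design_name}", "predictions", f"drug_screening_{design_name}", "*.pdb")
            pdb_files.extend((design_name, f) for f in sorted(glob.glob(pattern)))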

    if pdb_files:
        print(f"Found {len(pdb_files)} PDB structure files")

        # Create 3D visualization for the first few structures found (not ranked by affinity)
        max_structures = min(3, len(pdb_files))
        print(f"🎭 Displaying the first {max_structures} 3D structures...")

        for i, (design_name, pdb_file) in enumerate(pdb_files[:max_structures]):
            print(f"🔬 Structure {i+1}: {design_name}")

            try:
                # Read PDB file
                with open(pdb_file, "r") as f:
                    pdb_content = f.read()

                # Create 3D viewer
                viewer = py3Dmol.view(width=800, height=600)
                viewer.addModel(pdb_content, "pdb")

                # Set background
                viewer.setBackgroundColor(background_color)

                # Style protein structure
                if structure_style == "cartoon":
                    viewer.setStyle({"cartoon": {"color": "spectrum"}})
                elif structure_style == "stick":
                    viewer.setStyle({"stick": {"colorscheme": "chainHetatm"}})
                elif structure_style == "sphere":
                    viewer.setStyle({"sphere": {"colorscheme": "chainHetatm"}})
                elif structure_style == "line":
                    viewer.setStyle({"line": {"colorscheme": "chainHetatm"}})
                elif structure_style == "surface":
                    viewer.addSurface(py3Dmol.VDW, {"opacity": 0.7, "color": "white"})

                # Highlight ligand if requested
                if show_ligand:
                    if ligand_style == "stick":
                        viewer.setStyle({"hetflag": True}, {"stick": {"colorscheme": "greenCarbon", "radius": 0.3}})
                    elif ligand_style == "sphere":
                        viewer.setStyle({"hetflag": True}, {"sphere": {"colorscheme": "greenCarbon", "radius": 0.8}})
                    elif ligand_style == "line":
                        viewer.setStyle({"hetflag": True}, {"line": {"colorscheme": "greenCarbon"}})
                    elif ligand_style == "surface":
                        viewer.addSurface(py3Dmol.VDW, {"opacity": 0.8, "color": "green"}, {"hetflag": True})

                # Center and zoom
                viewer.zoomTo()
                viewer.show()

                # Display structure info
                if design_name in viz_data["Design Name"].values:
                    design_info = viz_data[viz_data["Design Name"] == design_name].iloc[0]
                    print(f"📊 Structure Quality Metrics:")
                    if "Confidence Score" in design_info and pd.notna(design_info["Confidence Score"]):
                        print(f"   • Confidence Score: {design_info['Confidence Score']:.3f}")
                    if "Ligand ipTM" in design_info and pd.notna(design_info["Ligand ipTM"]):
                        print(f"   • Ligand ipTM: {design_info['Ligand ipTM']:.3f}")
                    if "pIC50" in design_info and pd.notna(design_info["pIC50"]):
                        print(f"   • Predicted pIC50: {design_info['pIC50']:.2f}")

            except Exception as e:
                print(f"❌ Error displaying structure for {design_name}: {str(e)}")
    else:
        print("⚠️ No PDB structure files found. Make sure predictions have completed successfully.")
        print(f"   Looking in: {results_dir}/boltz_results_*/predictions/*/")

print("✨ Visualization complete!")


## 💡 Tips & Troubleshooting

### 🔧 Common Issues:

1. Dependencies not installed: Re-run the setup cell, then restart the runtime
2. Invalid SMILES: Check strings using online SMILES validators
3. Invalid protein sequence: Use only the 20 standard one-letter amino acid codes (ACDEFGHIKLMNPQRSTVWY)
4. GPU memory issues: Lower `max_parallel_samples` or `diffusion_samples`, or switch to CPU mode
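
Invalid protein sequences can be caught before any YAML files are generated. The following is a minimal sketch, assuming only the 20 standard one-letter codes are acceptable. The helper `find_invalid_residues` is hypothetical and not part of this notebook.

```python
STANDARD_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")


def find_invalid_residues(sequence: str) -> list[tuple[int, str]]:
    """Return (1-based position, character) for every non-standard residue."""
    # Remove whitespace and line breaks, then normalise to upper case
    cleaned = "".join(sequence.split()).upper()
    return [
        (position, residue)
        for position, residue in enumerate(cleaned, start=1)
        if residue not in STANDARD_AMINO_ACIDS
    ]


print(find_invalid_residues("MKTAYIAKQR"))  # a valid sequence gives an empty list
print(find_invalid_residues("MKXB"))        # X and B are reported with their positions
```

Reporting positions rather than a plain pass/fail makes it easier to spot stray characters picked up when pasting from a FASTA file.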

### ⚡ Performance Tips:

- Use GPU acceleration when available
- Start with fewer combinations for testing
- Monitor memory usage for large screens
- Save intermediate results regularly
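
One way to save intermediate results is to append each finished prediction to a JSON Lines file, so a crash or runtime disconnect loses at most the job in progress. The following is a sketch under that assumption. The functions `append_result` and `load_results` and the file name `results.jsonl` are illustrative and not part of this notebook.

```python
import json
import os
import tempfile


def append_result(path: str, entry: dict) -> None:
    """Append one result record as a single JSON line."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")


def load_results(path: str) -> list[dict]:
    """Load all previously saved records; return an empty list if none exist."""
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Usage: record results as they finish, then skip completed designs on restart
checkpoint_path = os.path.join(tempfile.mkdtemp(), "results.jsonl")
append_result(checkpoint_path, {"Design Name": "protA_drug1", "pIC50": 7.2})
completed = {record["Design Name"] for record in load_results(checkpoint_path)}
print("protA_drug1" in completed)
```

In the prediction loop, the same check on `completed` could be used to skip designs that already have a saved record.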

---

Happy Screening! 🧬💊