<a href="https://colab.research.google.com/github/eoinleen/PDB-tools/blob/main/which_residues_to_fix_for_ProteinMPNN_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
#!/usr/bin/env python3
"""
Motif Contact Analysis for ProteinMPNN
Analyzes PDB files with two chains and identifies contact residues for fixing.
Assumes the smaller chain is the motif to be analyzed.

Distances represent the minimum distance between any heavy atom (non-hydrogen)
in each motif residue and any heavy atom in the target protein. For each
residue, all atoms are compared against all target atoms, and the shortest
distance is reported. This approach identifies residues with the closest
physical proximity to the binding interface, regardless of whether the contact
involves backbone or sidechain atoms. Residues showing distances ≤3.5Å typically
represent critical binding contacts that should be fixed during ProteinMPNN
design.

Usage in Google Colab:
1. Upload PDB file
2. Run script
3. Get residue list for ProteinMPNN FIXED positions
"""

import numpy as np
import pandas as pd
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Install required packages (run this cell first in Colab)
def install_dependencies():
    """Pip-install the packages this notebook relies on.

    Runs one ``pip install`` per package (biopython, pandas, numpy) using the
    interpreter executing this script; a failed install raises
    ``subprocess.CalledProcessError``.
    """
    import subprocess
    import sys

    pip_base = [sys.executable, "-m", "pip", "install"]
    for pkg_name in ('biopython', 'pandas', 'numpy'):
        subprocess.check_call(pip_base + [pkg_name])

def parse_pdb_simple(pdb_file):
    """
    Parse the ATOM records of a PDB file into per-chain atom lists.

    Only the first model is read: parsing stops at the first ``ENDMDL``
    record, so multi-model files (e.g. NMR ensembles) do not mix coordinates
    from different models. HETATM records are ignored.

    Args:
        pdb_file: Path to a PDB-format text file.

    Returns:
        defaultdict(list) mapping chain ID (column 22) to a list of atom
        dicts with keys:
          'res_num'   - residue sequence number (int, columns 23-26)
          'res_name'  - residue name (columns 18-20)
          'atom_name' - atom name (columns 13-16, stripped)
          'coords'    - np.ndarray [x, y, z]
          'chain'     - chain ID
          'element'   - element symbol (columns 77-78, upper-cased; '' if absent)
    """
    chains = defaultdict(list)

    with open(pdb_file, 'r') as f:
        for line in f:
            if line.startswith('ENDMDL'):
                # Later models repeat the same atoms at different coordinates;
                # mixing them would produce cross-model distances.
                break
            if not line.startswith('ATOM'):
                continue

            chain_id = line[21]
            atom_info = {
                'res_num': int(line[22:26].strip()),
                'res_name': line[17:20].strip(),
                'atom_name': line[12:16].strip(),
                'coords': np.array([float(line[30:38]),
                                    float(line[38:46]),
                                    float(line[46:54])]),
                'chain': chain_id,
                # Element column is optional in older files; slicing a short
                # line just yields ''.
                'element': line[76:78].strip().upper(),
            }
            chains[chain_id].append(atom_info)

    return chains

def identify_chains(chains):
    """
    Identify the motif (fewest residues) and target (most residues) chains.

    Residues are counted as distinct 'res_num' values per chain. Chains are
    ordered by that count with a stable sort, so when chains are the same
    size the first one encountered is the motif and the last is the target;
    the two roles are never assigned to the same chain.

    Args:
        chains: Mapping of chain ID -> list of atom dicts with a 'res_num' key.

    Returns:
        (motif_chain, target_chain) chain IDs.

    Raises:
        ValueError: if fewer than two chains are present.
    """
    chain_sizes = {chain: len({atom['res_num'] for atom in atoms})
                   for chain, atoms in chains.items()}

    print(f"Chain sizes: {chain_sizes}")

    if len(chain_sizes) < 2:
        raise ValueError(
            f"Need at least 2 chains to assign motif and target, "
            f"found {len(chain_sizes)}: {list(chain_sizes)}"
        )

    # sorted() is stable: ties keep their original order, so the smallest
    # and largest picks are always distinct chains.
    ordered = sorted(chain_sizes, key=chain_sizes.get)
    motif_chain, target_chain = ordered[0], ordered[-1]

    print(f"Motif chain (smaller): {motif_chain} ({chain_sizes[motif_chain]} residues)")
    print(f"Target chain (larger): {target_chain} ({chain_sizes[target_chain]} residues)")

    return motif_chain, target_chain

def _is_hydrogen(atom):
    """Return True if an atom dict looks like a hydrogen (or deuterium).

    Uses the 'element' entry when it is present and non-empty; otherwise
    falls back to the atom name with leading digits removed, so PDB v2
    style names such as '1HB' or '2HG1' are recognised as well as 'HB'.
    """
    element = atom.get('element', '')
    if element:
        return element.upper() in ('H', 'D')
    return atom['atom_name'].lstrip('0123456789').startswith('H')


def calculate_distances(motif_atoms, target_atoms, heavy_atoms_only=True):
    """
    Compute, for every motif residue, its minimum distance to the target.

    For each motif residue, the Euclidean distance between every one of its
    atoms and every target atom is computed and the smallest value is kept.

    Args:
        motif_atoms: Iterable of atom dicts ('res_num', 'res_name',
            'atom_name', 'coords', optionally 'element').
        target_atoms: Iterable of atom dicts with the same keys.
        heavy_atoms_only: If True, hydrogens are excluded on both sides.

    Returns:
        dict mapping res_num -> {'min_distance': float, 'res_name': str},
        in the order residues first appear in motif_atoms. Residues with no
        remaining atoms after filtering are omitted.

    Raises:
        ValueError: if no target atoms remain after filtering.
    """
    def keep(atom):
        return not (heavy_atoms_only and _is_hydrogen(atom))

    # Group motif atoms by residue number
    motif_residues = defaultdict(list)
    for atom in motif_atoms:
        if keep(atom):
            motif_residues[atom['res_num']].append(atom)

    # (n_target_atoms, 3); reshape keeps the shape valid when the list is empty
    target_coords = np.array(
        [atom['coords'] for atom in target_atoms if keep(atom)], dtype=float
    ).reshape(-1, 3)
    if target_coords.shape[0] == 0:
        raise ValueError("No target atoms available for distance calculation")

    residue_distances = {}
    for res_num, atoms in motif_residues.items():
        residue_coords = np.array([atom['coords'] for atom in atoms], dtype=float)
        # Pairwise (n_residue_atoms, n_target_atoms) distance matrix via broadcasting
        diffs = residue_coords[:, None, :] - target_coords[None, :, :]
        min_distance = float(np.linalg.norm(diffs, axis=-1).min())

        residue_distances[res_num] = {
            'min_distance': min_distance,
            'res_name': atoms[0]['res_name']
        }

    return residue_distances

def analyze_contacts(pdb_file, distance_cutoffs=(3.0, 3.5, 4.0, 4.5, 5.0)):
    """
    Run the motif/target contact analysis on a two-chain PDB file.

    Parses the file, assigns the smaller chain as motif and the larger as
    target, computes per-residue minimum distances, and prints the motif
    residues within each cutoff.

    Args:
        pdb_file: Path to the PDB file.
        distance_cutoffs: Iterable of distance thresholds in Å. The default
            is a tuple so it cannot be mutated between calls.

    Returns:
        (results, motif_chain, target_chain) where results maps each cutoff
        to a list of (res_num, res_name, min_distance) tuples, sorted by
        residue number, for motif residues with min_distance <= cutoff.

    Raises:
        ValueError: if the file does not contain exactly two chains.
    """
    print(f"Analyzing PDB file: {pdb_file}")
    print("="*50)

    # Parse PDB
    chains = parse_pdb_simple(pdb_file)

    if len(chains) != 2:
        raise ValueError(f"Expected 2 chains, found {len(chains)}: {list(chains.keys())}")

    # Identify motif and target chains
    motif_chain, target_chain = identify_chains(chains)

    # Calculate distances
    print("\nCalculating distances...")
    distances = calculate_distances(chains[motif_chain], chains[target_chain])

    # Analyze contacts at different cutoffs
    results = {}

    print(f"\nContact Analysis for Chain {motif_chain}:")
    print("="*40)

    for cutoff in distance_cutoffs:
        # Residue numbers are unique, so tuple ordering sorts by residue number.
        contact_list = sorted(
            (res, info['res_name'], info['min_distance'])
            for res, info in distances.items()
            if info['min_distance'] <= cutoff
        )
        results[cutoff] = contact_list

        print(f"\nContacts within {cutoff}Å: {len(contact_list)} residues")
        for res, res_name, dist in contact_list:
            print(f"  {motif_chain}:{res}:{res_name} (min dist: {dist:.2f}Å)")

    return results, motif_chain, target_chain

def generate_proteinmpnn_lists(results, motif_chain, recommended_cutoff=3.5, liberal_cutoff=4.0):
    """
    Print and return residue lists to fix in ProteinMPNN.

    All printed lists use the same "<chain>:<res>,<chain>:<res>,..." format
    as the final recommended list.

    Args:
        results: Mapping cutoff -> list of (res_num, res_name, min_distance),
            as returned by analyze_contacts.
        motif_chain: Chain ID used to prefix residue numbers.
        recommended_cutoff: Cutoff used for the conservative list; must be a
            key of results.
        liberal_cutoff: Cutoff used for the liberal list; falls back to
            recommended_cutoff when not present in results.

    Returns:
        dict with keys 'conservative', 'liberal', 'structural' (PRO/CYS found
        at any cutoff, in first-seen order), 'recommended' (sorted union of
        conservative and structural) and 'chain'.

    Raises:
        KeyError: if recommended_cutoff is not a key of results.
    """
    def format_list(residues):
        return ",".join(f"{motif_chain}:{res}" for res in sorted(residues))

    print("\n" + "="*50)
    print("PROTEINMPNN RESIDUE LISTS")
    print("="*50)

    conservative_residues = [res for res, _, _ in results[recommended_cutoff]]

    if liberal_cutoff not in results:
        liberal_cutoff = recommended_cutoff
    liberal_residues = [res for res, _, _ in results[liberal_cutoff]]

    # Proline/cysteine seen within any cutoff (likely structurally important)
    structural_residues = []
    seen = set()
    for cutoff_results in results.values():
        for res, res_name, _ in cutoff_results:
            if res_name in ('PRO', 'CYS') and res not in seen:
                seen.add(res)
                structural_residues.append(res)

    print(f"\n1. CONSERVATIVE LIST ({recommended_cutoff}Å cutoff):")
    print(f"   {format_list(conservative_residues)}")

    print(f"\n2. LIBERAL LIST ({liberal_cutoff}Å cutoff):")
    print(f"   {format_list(liberal_residues)}")

    if structural_residues:
        print("\n3. STRUCTURAL RESIDUES FOUND:")
        print(f"   {format_list(structural_residues)}")
        print("   ^ Consider adding these to your FIXED list")

    # Final recommendation
    final_residues = sorted(set(conservative_residues + structural_residues))

    print("\n4. RECOMMENDED FINAL LIST:")
    print(f"   {format_list(final_residues)}")
    print(f"   ({len(final_residues)} residues total)")

    return {
        'conservative': conservative_residues,
        'liberal': liberal_residues,
        'structural': structural_residues,
        'recommended': final_residues,
        'chain': motif_chain
    }

def create_summary_dataframe(results, motif_chain):
    """
    Build a per-residue summary table of contacts across all cutoffs.

    Args:
        results: Mapping cutoff -> list of (res_num, res_name, min_distance).
        motif_chain: Chain ID used to label residues.

    Returns:
        pandas DataFrame with one row per residue appearing at any cutoff,
        sorted by residue number. Columns: 'residue' ("<chain>:<res_num>"),
        one '<cutoff>A' column per cutoff (in results' iteration order)
        holding the distance to two decimals, or "-" when the residue is not
        within that cutoff, and 'res_name'.
    """
    # Index each cutoff's contacts by residue number for O(1) lookups.
    lookup = {
        cutoff: {res: (res_name, dist) for res, res_name, dist in contact_list}
        for cutoff, contact_list in results.items()
    }
    all_residues = sorted({res for by_res in lookup.values() for res in by_res})

    summary_data = []
    for res in all_residues:
        row = {'residue': f"{motif_chain}:{res}"}
        res_name = None
        for cutoff, by_res in lookup.items():
            entry = by_res.get(res)
            if entry is None:
                row[f'{cutoff}A'] = "-"
            else:
                res_name, distance = entry
                # Explicit None check: a 0.0 distance is a real (clashing) contact.
                row[f'{cutoff}A'] = f"{distance:.2f}"
        row['res_name'] = res_name
        summary_data.append(row)

    return pd.DataFrame(summary_data)

# Main execution function for Google Colab
def analyze_motif_contacts(pdb_file_path):
    """
    Run the whole pipeline on one PDB file and print its report.

    Steps: contact analysis, ProteinMPNN residue lists, a summary table,
    and short usage notes.

    Returns:
        (results, mpnn_lists, summary_df) on success. If any step raises,
        the error is printed and (None, None, None) is returned instead.
    """
    rule = "=" * 50
    try:
        results, motif_chain, _target_chain = analyze_contacts(pdb_file_path)
        mpnn_lists = generate_proteinmpnn_lists(results, motif_chain)
        summary_df = create_summary_dataframe(results, motif_chain)

        print("\n" + rule)
        print("SUMMARY TABLE")
        print(rule)
        print(summary_df.to_string(index=False))

        print("\n" + rule)
        print("USAGE INSTRUCTIONS")
        print(rule)
        usage_steps = (
            "1. Copy the RECOMMENDED FINAL LIST above",
            "2. Use it as the --fixed_residues parameter in ProteinMPNN",
            "3. Start with 15-18 sequences per scaffold",
            "4. Analyze results and iterate if needed",
        )
        for step in usage_steps:
            print(step)

        return results, mpnn_lists, summary_df

    except Exception as err:
        print(f"Error analyzing PDB file: {err}")
        return None, None, None

# Auto-run setup for Google Colab
import os

print("🧬 MOTIF CONTACT ANALYSIS FOR PROTEINMPNN")
print("="*50)
print("📦 Setting up environment...")

# Install dependencies automatically (a pip failure is reported, not fatal)
try:
    install_dependencies()
    print("✅ Dependencies installed successfully!")
except Exception as e:
    print(f"❌ Error installing dependencies: {e}")

# Import the Colab upload helper. Outside Colab, `files` is set to None so the
# upload step below is skipped instead of failing with a NameError.
try:
    from google.colab import files
    print("✅ Google Colab modules loaded!")
except ImportError:
    files = None
    print("❌ This script requires Google Colab environment")

print("\n" + "="*50)
print("🚀 READY FOR ANALYSIS!")
print("="*50)
print("📁 Please upload your PDB file below...")
print("   • File should contain exactly 2 chains")
print("   • Smaller chain will be treated as the motif")
print("   • Click 'Choose Files' to select your PDB")

# Prompt for file upload (only possible inside Colab)
uploaded = files.upload() if files is not None else {}

if uploaded:
    # Only the first uploaded file is analysed.
    pdb_filename = next(iter(uploaded))
    print(f"\n✅ File uploaded: {pdb_filename}")

    # Validate file extension (warning only)
    if not pdb_filename.lower().endswith('.pdb'):
        print("⚠️  Warning: File doesn't have .pdb extension")
        print("Continuing with analysis anyway...")

    # Run analysis
    print(f"\n🔬 Analyzing {pdb_filename}...")
    print("="*50)

    try:
        results, mpnn_lists, summary_df = analyze_motif_contacts(pdb_filename)

        if results is not None:
            print("\n🎯 FINAL RESULTS:")
            print("="*50)
            print("✅ Analysis completed successfully!")
            print("\n📊 RECOMMENDED RESIDUES FOR PROTEINMPNN:")
            recommended_list = ",".join(
                f"{mpnn_lists['chain']}:{res}" for res in mpnn_lists['recommended']
            )
            print(f"\n   --fixed_residues {recommended_list}")
            print("\n💡 Copy the line above and use it in your ProteinMPNN command!")
            print(f"📈 Total residues to fix: {len(mpnn_lists['recommended'])}")
        else:
            print("❌ Analysis failed. Please check your PDB file format.")

    except Exception as e:
        print(f"❌ Error during analysis: {e}")
        print("Please check that your PDB file contains exactly 2 chains.")

    finally:
        # Remove the uploaded file on every path, including errors.
        if os.path.exists(pdb_filename):
            os.remove(pdb_filename)

else:
    print("❌ No file uploaded.")
    print("\n💡 To run analysis, re-run this cell and upload a PDB file when prompted.")

print("\n" + "="*60)
print("🔄 To analyze another file, simply re-run this cell!")
print("="*60)

🧬 MOTIF CONTACT ANALYSIS FOR PROTEINMPNN
📦 Setting up environment...
✅ Dependencies installed successfully!
✅ Google Colab modules loaded!

🚀 READY FOR ANALYSIS!
📁 Please upload your PDB file below...
   • File should contain exactly 2 chains
   • Smaller chain will be treated as the motif
   • Click 'Choose Files' to select your PDB


Saving 66-93.pdb to 66-93.pdb

✅ File uploaded: 66-93.pdb

🔬 Analyzing 66-93.pdb...
Analyzing PDB file: 66-93.pdb
Chain sizes: {'A': 28, 'B': 215}
Motif chain (smaller): A (28 residues)
Target chain (larger): B (215 residues)

Calculating distances...

Contact Analysis for Chain A:

Contacts within 3.0Å: 6 residues
  A:70:THR (min dist: 2.42Å)
  A:71:GLY (min dist: 2.80Å)
  A:72:LEU (min dist: 2.89Å)
  A:74:CYS (min dist: 2.83Å)
  A:75:GLU (min dist: 2.95Å)
  A:76:MET (min dist: 2.83Å)

Contacts within 3.5Å: 9 residues
  A:70:THR (min dist: 2.42Å)
  A:71:GLY (min dist: 2.80Å)
  A:72:LEU (min dist: 2.89Å)
  A:73:GLN (min dist: 3.30Å)
  A:74:CYS (min dist: 2.83Å)
  A:75:GLU (min dist: 2.95Å)
  A:76:MET (min dist: 2.83Å)
  A:81:GLN (min dist: 3.38Å)
  A:86:GLU (min dist: 3.06Å)

Contacts within 4.0Å: 13 residues
  A:70:THR (min dist: 2.42Å)
  A:71:GLY (min dist: 2.80Å)
  A:72:LEU (min dist: 2.89Å)
  A:73:GLN (min dist: 3.30Å)
  A:74:CYS (min dist: 2.83Å)
  A:75:GLU (min dist: 2.95Å)
  A:7