<a href="https://colab.research.google.com/github/neetushibu/IontheFold-Team6/blob/main/IonTheFold001_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ProteinMPNN Baseline Evaluation for Ion the Fold Project

In [1]:
!nvidia-smi

Fri Aug 22 15:21:26 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P0             48W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
#@title Cell 1: Install Dependencies and Setup
import subprocess
import sys

def install_packages():
    """Install the notebook's third-party dependencies with pip, one at a time.

    Every package is installed quietly (``-q``) through the interpreter that
    runs this kernel (``sys.executable``). ``subprocess.check_call`` raises
    ``subprocess.CalledProcessError`` on the first installation that fails.
    """
    pip_install = [sys.executable, '-m', 'pip', 'install', '-q']
    for requirement in ('plotly', 'seaborn', 'biopython', 'matplotlib',
                        'pandas', 'numpy', 'scipy'):
        subprocess.check_call(pip_install + [requirement])

install_packages()
print("✅ All packages installed!")

✅ All packages installed!


In [3]:
#@title Cell 2: Clone Repository and Import Libraries
import json, time, os, sys, glob
import warnings
warnings.filterwarnings('ignore')
import subprocess # Import subprocess

# Directory the ProteinMPNN repository is cloned into.
proteinmpnn_path = '/content/ProteinMPNN'

# Clone only when the directory is absent, so re-running this cell skips the download.
# NOTE(review): only the directory's existence is checked; presumably an interrupted
# clone could leave an incomplete directory that is then treated as valid — confirm.
if not os.path.isdir(proteinmpnn_path):
    print(f"Cloning ProteinMPNN repository into {proteinmpnn_path}...")
    # check=True makes a non-zero git exit status raise CalledProcessError
    try:
        subprocess.run(["git", "clone", "-q", "https://github.com/dauparas/ProteinMPNN.git", proteinmpnn_path], check=True)
        print("✅ ProteinMPNN repository cloned successfully!")
    except subprocess.CalledProcessError as e:
        print(f"❌ Error cloning repository: {e}")
        # NOTE(review): the failure is only reported; execution continues and the
        # import check below is what surfaces the missing module.
else:
    print(f"ProteinMPNN directory already exists at {proteinmpnn_path}. Skipping cloning.")


# Make the cloned repository importable (appended, so existing sys.path entries win)
if proteinmpnn_path not in sys.path:
    sys.path.append(proteinmpnn_path)
    print(f"Added {proteinmpnn_path} to sys.path")

# Verify that the module can be imported; a failure is printed, not raised
try:
    import protein_mpnn_utils
    print("✅ Successfully imported protein_mpnn_utils!")
except ModuleNotFoundError as e:
    print(f"❌ ModuleNotFoundError after cloning and path update: {e}")
    print("Please check if the /content/ProteinMPNN directory exists and contains protein_mpnn_utils.py")
    # Diagnostic hint: list the directory contents in a separate cell,
    # e.g. !ls /content/ProteinMPNN

# Import remaining libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
from scipy import stats
from collections import defaultdict, Counter
import re
from google.colab import files

# Re-import the necessary functions from protein_mpnn_utils after confirming it's available
# This ensures these functions are available in the global scope of this cell
try:
    from protein_mpnn_utils import (
        loss_nll, loss_smoothed, gather_edges, gather_nodes,
        gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq,
        tied_featurize, parse_PDB, StructureDataset,
        StructureDatasetPDB, ProteinMPNN
    )
    print("✅ protein_mpnn_utils functions imported!")
except ModuleNotFoundError as e:
     print(f"❌ Failed to import protein_mpnn_utils functions: {e}")
     print("This indicates a persistent issue with finding the module.")


print("✅ Repository setup and libraries imported!")

Cloning ProteinMPNN repository into /content/ProteinMPNN...
✅ ProteinMPNN repository cloned successfully!
Added /content/ProteinMPNN to sys.path
✅ Successfully imported protein_mpnn_utils!
✅ protein_mpnn_utils functions imported!
✅ Repository setup and libraries imported!


In [4]:
#@title Cell 3: Setup Model and Device
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model configuration
model_name = "v_48_020"  # Options: v_48_002, v_48_010, v_48_020, v_48_030
backbone_noise = 0.00  # passed to ProteinMPNN as augment_eps

# Load model
path_to_model_weights = '/content/ProteinMPNN/vanilla_model_weights'
hidden_dim = 128  # used for node features, edge features and the hidden dimension
num_layers = 3    # used for both the encoder and the decoder layer count
# Normalise the folder to end with '/' before appending the checkpoint file name.
model_folder_path = path_to_model_weights
if model_folder_path[-1] != '/':
    model_folder_path = model_folder_path + '/'
checkpoint_path = model_folder_path + f'{model_name}.pt'

try:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from a
    # trusted source (here: the weights shipped with the cloned repository).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print('Number of edges:', checkpoint['num_edges'])
    print(f'Training noise level: {checkpoint["noise_level"]}A')

    model = ProteinMPNN(
        num_letters=21,
        node_features=hidden_dim,
        edge_features=hidden_dim,
        hidden_dim=hidden_dim,
        num_encoder_layers=num_layers,
        num_decoder_layers=num_layers,
        augment_eps=backbone_noise,
        k_neighbors=checkpoint['num_edges']
    )
    model.to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference mode
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Re-raise: every later cell needs `model`, so continuing would only replace
    # this error with a confusing NameError further down the notebook.
    raise

Using device: cuda:0
Number of edges: 48
Training noise level: 0.2A
✅ Model loaded successfully!


In [5]:
#@title Cell 4: Helper Functions
def make_tied_positions_for_homomers(pdb_dict_list):
    """Build the tied-position specification for homomeric structures.

    For each parsed structure, chain IDs are read as the last character of
    every key that starts with 'seq_chain' and are sorted. The sequence length
    of the first chain decides how many positions are produced; position ``i``
    (1-based) becomes a dict mapping every chain ID to ``[i]``.

    Returns a dict keyed by each structure's 'name' holding that list of
    per-position dicts.
    """
    tied_by_name = {}
    for entry in pdb_dict_list:
        chain_ids = sorted(key[-1:] for key in entry if key[:9] == 'seq_chain')
        n_positions = len(entry[f"seq_chain_{chain_ids[0]}"])
        tied_by_name[entry['name']] = [
            {chain_id: [position] for chain_id in chain_ids}
            for position in range(1, n_positions + 1)
        ]
    return tied_by_name


def get_pdb_file(pdb_code, dest_dir=".", overwrite=False, allow_upload=True):
    """
    Return a local path to a plain .pdb file, downloading it when necessary.

    Download sources, tried in order:
      1) https://files.rcsb.org/download/<code>.pdb
      2) https://files.rcsb.org/pub/pdb/data/structures/divided/pdb/<code[1:3]>/pdb<code>.ent.gz (then gunzip)
    If pdb_code is empty and allow_upload=True, prompts for upload in Google Colab.

    Args:
        pdb_code: PDB identifier (surrounding whitespace is ignored); empty/None
            selects the upload path.
        dest_dir: directory the file is stored in; created if it does not exist.
        overwrite: when False, an existing ``<CODE>.pdb`` in dest_dir is reused.
        allow_upload: whether an empty pdb_code may fall back to a Colab upload.

    Returns:
        Path of the local .pdb file, or None when both downloads fail or the
        upload dialog returns no file.

    Raises:
        ValueError: pdb_code is empty and allow_upload is False.
        RuntimeError: an upload is needed but google.colab cannot be imported.
    """
    import os, urllib.request, gzip, shutil, zlib
    from urllib.error import HTTPError, URLError

    def _discard(path):
        # Best-effort cleanup of a temporary file; a missing file is not an error.
        try:
            os.remove(path)
        except OSError:
            pass

    def _ensure_dest_dir():
        # os.makedirs("") raises, and "" already means the current directory.
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)

    code = (pdb_code or "").strip()
    if not code:
        if not allow_upload:
            raise ValueError("pdb_code is empty and uploads are disabled.")
        # Colab upload path
        try:
            from google.colab import files
        except Exception:
            raise RuntimeError("Upload only works in Google Colab. Provide a pdb_code or run in Colab.")
        print("Please upload a PDB file:")
        upload_dict = files.upload()
        if not upload_dict:
            # Dialog returned nothing: report it like a failed download.
            print("❌ No file was uploaded.")
            return None
        name, data = next(iter(upload_dict.items()))
        _ensure_dest_dir()
        out_path = os.path.join(dest_dir, name if name.lower().endswith(".pdb") else "tmp.pdb")
        with open(out_path, "wb") as f:
            f.write(data)
        return out_path

    code_l = code.lower()
    code_u = code.upper()
    out_path = os.path.join(dest_dir, f"{code_u}.pdb")

    if os.path.exists(out_path) and not overwrite:
        print(f"✔ Using existing file: {out_path}")
        return out_path

    _ensure_dest_dir()
    # Data is written to a temporary name and only renamed to out_path once it
    # is complete, so a failed download can never be mistaken for a cached file.
    tmp_path = out_path + ".part"

    # 1) Try direct .pdb
    url1 = f"https://files.rcsb.org/download/{code_l}.pdb"
    first_error = None
    try:
        urllib.request.urlretrieve(url1, tmp_path)
        os.replace(tmp_path, out_path)
    except (HTTPError, URLError) as err:
        first_error = err
        _discard(tmp_path)
    else:
        print(f"✅ Downloaded {code_u} from {url1}")
        return out_path

    # 2) Fallback: divided gz path (need to compute subfolder and gunzip)
    subdir = code_l[1:3]  # 2nd–3rd characters
    url2 = f"https://files.rcsb.org/pub/pdb/data/structures/divided/pdb/{subdir}/pdb{code_l}.ent.gz"
    gz_path = os.path.join(dest_dir, f"pdb{code_l}.ent.gz")
    try:
        urllib.request.urlretrieve(url2, gz_path)
        with gzip.open(gz_path, "rb") as f_in, open(tmp_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.replace(tmp_path, out_path)
    except (HTTPError, URLError, OSError, EOFError, zlib.error) as second_error:
        # EOFError / zlib.error cover truncated or corrupt gzip archives.
        _discard(tmp_path)
        _discard(gz_path)
        print(f"❌ Could not download {code_u}.\n - {url1} error: {first_error}\n - {url2} error: {second_error}")
        return None

    _discard(gz_path)
    print(f"✅ Downloaded and decompressed {code_u} from {url2}")
    return out_path



def analyze_amino_acid_composition(sequences, labels=None):
    """Tabulate the count and percentage of each of the 20 standard amino acids.

    Chain separators ('/') and unknown residues ('X') are removed from every
    sequence before counting. When ``labels`` is omitted, sequences are named
    ``Seq_0``, ``Seq_1``, ...

    Returns a DataFrame with one row per (sequence, amino acid) pair and the
    columns 'Sequence', 'Amino_Acid', 'Count' and 'Percentage'. An empty
    sequence yields a percentage of 0 for every amino acid.
    """
    if labels is None:
        labels = [f"Seq_{i}" for i in range(len(sequences))]

    rows = []
    for label, raw_seq in zip(labels, sequences):
        residues = raw_seq.replace('/', '').replace('X', '')
        n_residues = len(residues)
        for letter in 'ACDEFGHIKLMNPQRSTVWY':
            n_hits = residues.count(letter)
            if n_residues > 0:
                share = (n_hits / n_residues) * 100
            else:
                share = 0
            rows.append({
                'Sequence': label,
                'Amino_Acid': letter,
                'Count': n_hits,
                'Percentage': share,
            })

    return pd.DataFrame(rows)

def calculate_sequence_metrics(native_seq, designed_seqs, scores):
    """Calculate comprehensive sequence metrics"""
    metrics = {
        'sequence_recovery': [],
        'identity': [],
        'score': [],
        'length': [],
        'charged_residue_recovery': [],
        'hydrophobic_recovery': []
    }

    native_clean = native_seq.replace('/', '').replace('X', '')
    charged_residues = set('DEKR')
    hydrophobic_residues = set('AILMFPWY')

    for designed_seq, score in zip(designed_seqs, scores):
        designed_clean = designed_seq.replace('/', '').replace('X', '')

        # Basic metrics
        length = min(len(native_clean), len(designed_clean))
        if length == 0:
            continue

        identical = sum(1 for a, b in zip(native_clean[:length], designed_clean[:length]) if a == b)
        identity = (identical / length) * 100

        metrics['sequence_recovery'].append(identity)
        metrics['identity'].append(identity)
        metrics['score'].append(float(score))
        metrics['length'].append(length)

        # Charged residue recovery
        native_charged_pos = [i for i, aa in enumerate(native_clean[:length]) if aa in charged_residues]
        if native_charged_pos:
            charged_recovery = sum(1 for pos in native_charged_pos
                                 if pos < len(designed_clean) and designed_clean[pos] in charged_residues)
            charged_recovery_rate = (charged_recovery / len(native_charged_pos)) * 100
        else:
            charged_recovery_rate = 0
        metrics['charged_residue_recovery'].append(charged_recovery_rate)

        # Hydrophobic recovery
        native_hydrophobic_pos = [i for i, aa in enumerate(native_clean[:length]) if aa in hydrophobic_residues]
        if native_hydrophobic_pos:
            hydrophobic_recovery = sum(1 for pos in native_hydrophobic_pos
                                     if pos < len(designed_clean) and designed_clean[pos] in hydrophobic_residues)
            hydrophobic_recovery_rate = (hydrophobic_recovery / len(native_hydrophobic_pos)) * 100
        else:
            hydrophobic_recovery_rate = 0
        metrics['hydrophobic_recovery'].append(hydrophobic_recovery_rate)

    return pd.DataFrame(metrics)

print("✅ Helper functions defined!")

✅ Helper functions defined!


In [7]:
#@title Cell 5: Configuration and Input Setup

import numpy as np # Import numpy

pdb_codes = ['3JAY', # 5 chains, 919 charged residues
             '3JB0', # 5 chains, 919 charged residues
             '5A1A', # 4 chains, 984 charged residues
             '5FTJ', # 6 chains, 1326 charged residues
             '5FTK', # 6 chains, 1326 charged residues
             '5K12', # 6 chains, 474 charged residues
             '5L35', # 7 chains, 511 charged residues
             '5MDO', # 6 chains, 468 charged residues
             '5MDR', # 12 chains, 468 charged residues
             '5MF4', # 6 chains, 566 charged residues
             '5MFM', # 8 chains, 362 charged residues
             '5MH6', # 4 chains, 350 charged residues
             '5MHF', # 8 chains, 626 charged residues
             '5MIW', # 6 chains, 342 charged residues
             '5MJY', # 6 chains, 396 charged residues
             '5MK1', # 8 chains, 393 charged residues
             '5MK3', # 8 chains, 395 charged residues
             '5MKM', # 6 chains, 442 charged residues
             '5MKN', # 28 chains, 595 charged residues
             '5MLD', # 8 chains, 584 charged residues
             '5MNS', # 6 chains, 671 charged residues
             '5MNV', # 9 chains, 977 charged residues
             '5MQZ', # 6 chains, 469 charged residues
             '5MR0', # 6 chains, 474 charged residues
             '5MUX', # 6 chains, 822 charged residues
             '5MX5', # 14 chains, 873 charged residues
             '5MY0', # 4 chains, 722 charged residues
             '5MY2', # 4 chains, 723 charged residues
             '5MZ2', # 16 chains, 1125 charged residues
             '5MZ5', # 4 chains, 532 charged residues
             ]

design_config = {'3JAY': {'designed_chains': ['B'], 'fixed_chains': ['A', 'C', 'D', 'E']},
                 '3JB0': {'designed_chains': ['B'], 'fixed_chains': ['A', 'C', 'D', 'E']},
                 '5A1A': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D']},
                 '5FTJ': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F']},
                 '5FTK': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F']},
                 '5K12': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F']},
                 '5L35': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F', 'G']},
                 '5MDO': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F']},
                 '5MDR': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F']},
                 '5MF4': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F']},
                 '5MFM': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F', 'G', 'H']},
                 '5MH6': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D']},
                 '5MHF': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D']},
                 '5MIW': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F']},
                 '5MJY': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F']},
                 '5MK1': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F', 'G', 'H']},
                 '5MK3': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F', 'G', 'H']},
                 '5MKM': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F']},
                 '5MKN': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'd']},
                 '5MLD': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F', 'G', 'H']},
                 '5MNS': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F']},
                 '5MNV': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']},
                 '5MQZ': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F']},
                 '5MR0': {'designed_chains': ['F'], 'fixed_chains': ['A', 'B', 'C', 'D', 'E']},
                 '5MUX': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F']},
                 '5MX5': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N']},
                 '5MY0': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D']},
                 '5MY2': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D']},
                 '5MZ2': {'designed_chains': ['A'], 'fixed_chains': ['C', 'H', 'F', 'D', 'B', 'E', 'G', 'I', 'O', 'L', 'N', 'M', 'P', 'J', 'K']},
                 '5MZ5': {'designed_chains': ['A'], 'fixed_chains': ['B', 'C', 'D']}}

# Design parameters
num_seqs = 6  # sequences requested per protein; Cell 7 passes this as num_sequences
sampling_temp = "0.1"  # kept as a string; Cell 7 converts it with float() before use
homomer = False  # when True, process_single_protein builds tied positions across chains

# Advanced configuration options
batch_size = 1  # structure copies per batch; process_single_protein runs num_sequences // batch_size batches
max_length = 20000  # forwarded to StructureDatasetPDB as max_length
omit_AAs = 'X'  # Omit unknown amino acids
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
# float32 vector over `alphabet`: 1.0 where the letter appears in omit_AAs, else 0.0
omit_AAs_np = np.array([AA in omit_AAs for AA in alphabet]).astype(np.float32)

print(f"✅ Configured to process {len(pdb_codes)} proteins")
print(f"✅ Will generate {num_seqs} sequences per protein")
print(f"✅ Using sampling temperature: {sampling_temp}")
print()

print("PROTEIN CONFIGURATION SUMMARY:")
print("="*60)
print(f"{'PDB':<6} {'Chains':<12} {'Designed':<10} {'Fixed':<15} {'Charged Res':<12}")
print("-"*60)

protein_data = {'3JAY': {'chains': 5, 'charged': 919},
                              '3JB0': {'chains': 5, 'charged': 919},
                              '5A1A': {'chains': 4, 'charged': 984},
                              '5FTJ': {'chains': 6, 'charged': 1326},
                              '5FTK': {'chains': 6, 'charged': 1326},
                              '5K12': {'chains': 6, 'charged': 474},
                              '5L35': {'chains': 7, 'charged': 511},
                              '5MDO': {'chains': 6, 'charged': 468},
                              '5MDR': {'chains': 12, 'charged': 468},
                              '5MF4': {'chains': 6, 'charged': 566},
                              '5MFM': {'chains': 8, 'charged': 362},
                              '5MH6': {'chains': 4, 'charged': 350},
                              '5MHF': {'chains': 8, 'charged': 626},
                              '5MIW': {'chains': 6, 'charged': 342},
                              '5MJY': {'chains': 6, 'charged': 396},
                              '5MK1': {'chains': 8, 'charged': 393},
                              '5MK3': {'chains': 8, 'charged': 395},
                              '5MKM': {'chains': 6, 'charged': 442},
                              '5MKN': {'chains': 28, 'charged': 595},
                              '5MLD': {'chains': 8, 'charged': 584},
                              '5MNS': {'chains': 6, 'charged': 671},
                              '5MNV': {'chains': 9, 'charged': 977},
                              '5MQZ': {'chains': 6, 'charged': 469},
                              '5MR0': {'chains': 6, 'charged': 474},
                              '5MUX': {'chains': 6, 'charged': 822},
                              '5MX5': {'chains': 14, 'charged': 873},
                              '5MY0': {'chains': 4, 'charged': 722},
                              '5MY2': {'chains': 4, 'charged': 723},
                              '5MZ2': {'chains': 16, 'charged': 1125},
                              '5MZ5': {'chains': 4, 'charged': 532}
                              }

# Print one summary row per protein that has an entry in design_config;
# proteins without a design_config entry are silently left out of the table.
for pdb_code in pdb_codes:
    if pdb_code in design_config:
        config = design_config[pdb_code]
        # Fall back to '?' placeholders when no chain/charge statistics are recorded.
        data = protein_data.get(pdb_code, {'chains': '?', 'charged': '?'})
        designed_str = ', '.join(config['designed_chains'])
        fixed_str = ', '.join(config['fixed_chains']) if config['fixed_chains'] else 'None'
        # The widths are minimums: a fixed-chain list longer than 15 characters
        # pushes the last column to the right.
        print(f"{pdb_code:<6} {data['chains']:<12} {designed_str:<10} {fixed_str:<15} {data['charged']:<12}")

print("="*60)
print()

# Recommendations are derived from the configured proteins so the printed
# advice can never name a structure that is not in pdb_codes.
_configured = [
    (_code, protein_data[_code])
    for _code in pdb_codes
    if _code in design_config and _code in protein_data
]

print("🧪 TESTING RECOMMENDATIONS:")
print("1. Start with 2-3 proteins first to test the pipeline")
print("2. Recommended testing order:")
# Suggest the two proteins with the fewest charged residues as first test cases.
for _code, _info in sorted(_configured, key=lambda item: item[1]['charged'])[:2]:
    print(f"   - {_code} ({_info['chains']} chains, {_info['charged']} charged residues)")
print("3. Once working, add more proteins gradually")
if _configured:
    # Call out the protein with the most chains.
    _code, _info = max(_configured, key=lambda item: item[1]['chains'])
    print(f"4. For {_code} ({_info['chains']} chains), you may need to specify exact chains later")
print()


print("⚠️  IMPORTANT NOTES:")
print("- If any chain IDs are wrong, the notebook will tell you what chains exist")
print("- You can modify the design_config above based on those error messages")
print("- For complex multi-chain proteins, you may want to check PDB files manually")
print("- Start with fewer proteins and increase batch size once everything works")

print()
print("🔍 VALIDATION CHECKLIST:")
print("- All proteins have reasonable chain configurations")
print("- Single chains are designed entirely")
print("- Multi-chain proteins have one designed, others fixed")
print("- Ready to run Cell 6 (function definition) and Cell 7 after this configuration")

✅ Configured to process 30 proteins
✅ Will generate 6 sequences per protein
✅ Using sampling temperature: 0.1

PROTEIN CONFIGURATION SUMMARY:
PDB    Chains       Designed   Fixed           Charged Res 
------------------------------------------------------------
3JAY   5            B          A, C, D, E      919         
3JB0   5            B          A, C, D, E      919         
5A1A   4            A          B, C, D         984         
5FTJ   6            A          B, C, D, E, F   1326        
5FTK   6            A          B, C, D, E, F   1326        
5K12   6            A          B, C, D, E, F   474         
5L35   7            A          B, C, D, E, F, G 511         
5MDO   6            A          B, C, D, E, F   468         
5MDR   12           A          B, C, D, E, F   468         
5MF4   6            A          B, C, D, E, F   566         
5MFM   8            A          B, C, D, E, F, G, H 362         
5MH6   4            A          B, C, D         350         
5MHF   8    

In [8]:
#@title Cell 6: Process Single Protein Function
def _format_chain_sequence(flat_seq, masked_chain_length_list, masked_list):
    """Split a concatenated designed-chain sequence into chains joined by '/'.

    ``flat_seq`` holds the designed chains back to back in featurization order;
    ``masked_chain_length_list`` gives each chain's length in that order and
    ``masked_list`` the matching chain IDs. Chains are emitted sorted by chain ID.
    Assumes ``len(flat_seq)`` equals the sum of the chain lengths.
    """
    import numpy as np

    segments = []
    start = 0
    for chain_length in masked_chain_length_list:
        segments.append(flat_seq[start:start + chain_length])
        start += chain_length
    return "/".join(segments[i] for i in np.argsort(masked_list))


def process_single_protein(pdb_code, designed_chains, fixed_chains, num_sequences=4, temperature=0.1):
    """Design sequences for one PDB entry with ProteinMPNN and collect the results.

    Relies on notebook-level state defined in earlier cells: ``model``, ``device``,
    ``homomer``, ``batch_size``, ``max_length``, ``alphabet``, ``omit_AAs_np`` and
    the helpers ``get_pdb_file`` / ``make_tied_positions_for_homomers``.

    Args:
        pdb_code: PDB identifier handed to ``get_pdb_file``.
        designed_chains: chain IDs whose sequence is redesigned.
        fixed_chains: chain IDs kept as structural context.
        num_sequences: number of sequences requested; ``num_sequences // batch_size``
            batches are sampled, so a remainder is dropped.
        temperature: sampling temperature passed to the model.

    Returns:
        A dict with the keys 'pdb_code', 'sequences', 'scores', 'recovery_rates',
        'temperatures', 'native_sequence' and 'native_score', or None when the
        structure cannot be obtained, ``protein_mpnn_utils`` cannot be imported,
        or any exception occurs during processing (the traceback is printed).
    """
    print(f"\n{'='*50}")
    print(f"Processing {pdb_code}")
    print(f"{'='*50}")

    # Imported locally so the function still works if the setup cell's names
    # are no longer in the kernel namespace.
    import copy
    import sys
    import numpy as np
    import torch
    try:
        from protein_mpnn_utils import parse_PDB, StructureDatasetPDB, tied_featurize, _scores, _S_to_seq
    except ModuleNotFoundError:
        # Fallback/Diagnostic: if the direct import fails, add the clone path and retry once.
        proteinmpnn_path = '/content/ProteinMPNN'
        if proteinmpnn_path not in sys.path:
            sys.path.insert(0, proteinmpnn_path)
            print(f"Attempting to add {proteinmpnn_path} to sys.path from within function.")
        try:
            from protein_mpnn_utils import parse_PDB, StructureDatasetPDB, tied_featurize, _scores, _S_to_seq
            print("Successfully imported protein_mpnn_utils after adding path within function.")
        except ModuleNotFoundError:
            print("❌ Still unable to import protein_mpnn_utils even after adding path within function.")
            return None  # Cannot proceed without the module

    pdb_path = get_pdb_file(pdb_code)
    if pdb_path is None:
        return None

    try:
        # Sorted so the chain order (and the log below) is the same on every run.
        chain_list = sorted(set(designed_chains + fixed_chains))

        # Parse PDB
        pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)
        dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)
        parsed = pdb_dict_list[0]

        # Setup chain configuration
        chain_id_dict = {parsed['name']: (designed_chains, fixed_chains)}

        print(f"Chain configuration: {chain_id_dict}")
        available_chains = sorted(
            key[len("seq_chain_"):] for key in parsed if key.startswith("seq_chain_")
        )
        for chain in chain_list:
            # Check if the chain exists in the parsed PDB data before accessing its length
            if f"seq_chain_{chain}" in parsed:
                chain_length = len(parsed[f"seq_chain_{chain}"])
                print(f"Length of chain {chain}: {chain_length}")
            else:
                print(f"Warning: Chain {chain} not found in PDB file {pdb_code}.")
                # Tell the user which chains were parsed so design_config can be fixed.
                print(f"         Chains parsed from {pdb_code}: {', '.join(available_chains) or 'none'}")

        # Setup tied positions if homomer
        tied_positions_dict = None
        if homomer:
            tied_positions_dict = make_tied_positions_for_homomers(pdb_dict_list)

        # Initialize parameters (floor division: a remainder of num_sequences is dropped)
        NUM_BATCHES = num_sequences // batch_size
        BATCH_COPIES = batch_size
        temperatures = [temperature]

        # Optional ProteinMPNN inputs that this baseline leaves unset
        fixed_positions_dict = None
        pssm_dict = None
        omit_AA_dict = None
        bias_by_res_dict = None
        bias_AAs_np = np.zeros(len(alphabet))  # no per-amino-acid bias

        # Storage for results
        results = {
            'pdb_code': pdb_code,
            'sequences': [],
            'scores': [],
            'recovery_rates': [],
            'temperatures': [],
            'native_sequence': '',
            'native_score': 0
        }

        # Process protein
        with torch.no_grad():
            for protein in dataset_valid:
                batch_clones = [copy.deepcopy(protein) for _ in range(BATCH_COPIES)]

                # Featurize
                X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(
                    batch_clones, device, chain_id_dict, fixed_positions_dict,
                    omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict
                )

                pssm_log_odds_mask = (pssm_log_odds_all > 0.0).float()

                # Calculate native score
                randn_1 = torch.randn(chain_M.shape, device=X.device)
                log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
                mask_for_loss = mask*chain_M*chain_M_pos
                scores = _scores(S, log_probs, mask_for_loss)
                native_score = scores.cpu().data.numpy().mean()
                results['native_score'] = float(native_score)

                # Generate sequences
                for temp in temperatures:
                    for _ in range(NUM_BATCHES):
                        randn_2 = torch.randn(chain_M.shape, device=X.device)

                        # Sample sequences
                        if tied_positions_dict is None:
                            sample_dict = model.sample(
                                X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                                mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np,
                                bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos,
                                omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef,
                                pssm_bias=pssm_bias, pssm_multi=0.0,
                                pssm_log_odds_flag=False, pssm_log_odds_mask=pssm_log_odds_mask,
                                pssm_bias_flag=False, bias_by_res=bias_by_res_all
                            )
                        else:
                            sample_dict = model.tied_sample(
                                X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                                mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np,
                                bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos,
                                omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef,
                                pssm_bias=pssm_bias, pssm_multi=0.0,
                                pssm_log_odds_flag=False, pssm_log_odds_mask=pssm_log_odds_mask,
                                pssm_bias_flag=False, tied_pos=tied_pos_list_of_lists_list[0],
                                tied_beta=tied_beta, bias_by_res=bias_by_res_all
                            )

                        S_sample = sample_dict["S"]

                        # Score designed sequences using the decoding order they were sampled with
                        log_probs = model(X, S_sample, mask, chain_M*chain_M_pos, residue_idx,
                                        chain_encoding_all, randn_2, use_input_decoding_order=True,
                                        decoding_order=sample_dict["decoding_order"])
                        scores = _scores(S_sample, log_probs, mask_for_loss)
                        scores = scores.cpu().data.numpy()

                        # Process results
                        for b_ix in range(BATCH_COPIES):
                            masked_chain_length_list = masked_chain_length_list_list[b_ix]
                            masked_list = masked_list_list[b_ix]

                            # Recovery: fraction of scored positions where the sampled
                            # residue equals the native residue
                            seq_recovery_rate = torch.sum(
                                torch.sum(torch.nn.functional.one_hot(S[b_ix], 21) *
                                         torch.nn.functional.one_hot(S_sample[b_ix], 21), axis=-1) *
                                mask_for_loss[b_ix]
                            ) / torch.sum(mask_for_loss[b_ix])

                            # Convert sequences
                            seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
                            native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])

                            # The native sequence is formatted once, on the first sample
                            if results['native_sequence'] == '':
                                results['native_sequence'] = _format_chain_sequence(
                                    native_seq, masked_chain_length_list, masked_list
                                )

                            seq_formatted = _format_chain_sequence(
                                seq, masked_chain_length_list, masked_list
                            )

                            # Store results
                            results['sequences'].append(seq_formatted)
                            results['scores'].append(float(scores[b_ix]))
                            results['recovery_rates'].append(float(seq_recovery_rate.detach().cpu().numpy()))
                            results['temperatures'].append(temp)

                            print(f"Generated sequence {len(results['sequences'])}: Recovery={seq_recovery_rate:.3f}, Score={scores[b_ix]:.4f}")

        print(f"✅ Successfully processed {pdb_code}: {len(results['sequences'])} sequences generated")
        return results

    except Exception as e:
        # Report and skip this protein so one bad entry does not abort the whole run.
        print(f"❌ Error processing {pdb_code}: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

print("✅ Single protein processing function ready!")

✅ Single protein processing function ready!


In [None]:
#@title Cell 7: Process All Proteins
# Runs process_single_protein for every entry of `pdb_codes` that has a design
# configuration and collects the successful results in `all_protein_results`.
# Relies on names defined by earlier cells: pdb_codes, design_config,
# process_single_protein, num_seqs and sampling_temp.
import sys
# Ensure the ProteinMPNN directory is in the system path
proteinmpnn_path = '/content/ProteinMPNN'
if proteinmpnn_path not in sys.path:
    sys.path.insert(0, proteinmpnn_path) # Use insert to prioritize this path

# Rebuilt from scratch on every run so re-executing the cell never keeps stale entries.
all_protein_results = {}
# Codes that produced no result, kept so the final summary can name them
# instead of leaving them buried in the per-sequence log above it.
failed_proteins = []   # process_single_protein returned None (it caught an exception)
skipped_proteins = []  # no entry in design_config

for pdb_code in pdb_codes:
    if pdb_code not in design_config:
        print(f"⚠️ No configuration found for {pdb_code}, skipping...")
        skipped_proteins.append(pdb_code)
        continue

    config = design_config[pdb_code]
    result = process_single_protein(
        pdb_code,
        config['designed_chains'],
        config['fixed_chains'],
        num_sequences=num_seqs,
        temperature=float(sampling_temp)
    )
    if result is None:
        failed_proteins.append(pdb_code)
    else:
        all_protein_results[pdb_code] = result

print(f"\n✅ Processing complete! Successfully processed {len(all_protein_results)} proteins")
if failed_proteins:
    print(f"❌ Failed ({len(failed_proteins)}): {', '.join(map(str, failed_proteins))}")
if skipped_proteins:
    print(f"⚠️ Skipped, no configuration ({len(skipped_proteins)}): {', '.join(map(str, skipped_proteins))}")


Processing 3JAY
✅ Downloaded and decompressed 3JAY from https://files.rcsb.org/pub/pdb/data/structures/divided/pdb/ja/pdb3jay.ent.gz
Chain configuration: {'3JAY': (['B'], ['A', 'C', 'D', 'E'])}
Length of chain D: 292
Length of chain E: 292
Length of chain B: 1199
Length of chain C: 1260
Length of chain A: 1057
Generated sequence 1: Recovery=0.417, Score=0.8800
Generated sequence 2: Recovery=0.428, Score=0.8724
Generated sequence 3: Recovery=0.400, Score=0.8763
Generated sequence 4: Recovery=0.416, Score=0.8859
Generated sequence 5: Recovery=0.412, Score=0.8781
Generated sequence 6: Recovery=0.420, Score=0.8798
✅ Successfully processed 3JAY: 6 sequences generated

Processing 3JB0
✅ Downloaded and decompressed 3JB0 from https://files.rcsb.org/pub/pdb/data/structures/divided/pdb/jb/pdb3jb0.ent.gz
Chain configuration: {'3JB0': (['B'], ['A', 'C', 'D', 'E'])}
Length of chain D: 292
Length of chain E: 292
Length of chain B: 1199
Length of chain C: 1260
Length of chain A: 1057
Generated seque

In [None]:
#@title Cell 8: Analyze Results and Create Visualizations
# Reads `all_protein_results` (filled by Cell 7), builds a per-sequence metrics
# table and a per-protein summary table, renders a 2x3 Plotly dashboard, and
# writes two CSV files plus one FASTA file to the current working directory.
# Relies on names from earlier cells: pd, go, make_subplots and
# calculate_sequence_metrics.
if len(all_protein_results) == 0:
    print("❌ No results to analyze. Please check the previous cells.")
else:
    print("📊 Analyzing results...")

    # Combine all results
    combined_metrics = []  # one metrics DataFrame per protein
    summary_stats = []     # one summary record (dict) per protein

    for pdb_code, results in all_protein_results.items():
        # A protein without designed sequences contributes nothing to the metrics.
        if len(results['sequences']) == 0:
            continue

        # Calculate metrics for this protein
        metrics_df = calculate_sequence_metrics(
            results['native_sequence'],
            results['sequences'],
            results['scores']
        )
        metrics_df['pdb_code'] = pdb_code
        metrics_df['temperature'] = results['temperatures']
        metrics_df['native_score'] = results['native_score']

        combined_metrics.append(metrics_df)

        # Summary statistics
        summary_stats.append({
            'PDB': pdb_code,
            'Sequences': len(results['sequences']),
            'Mean_Recovery': metrics_df['sequence_recovery'].mean(),
            # pandas' sample std is NaN when only one sequence was generated
            'Std_Recovery': metrics_df['sequence_recovery'].std(),
            'Mean_Score': metrics_df['score'].mean(),
            # the minimum is reported as "best", consistent with the "<" score target printed below
            'Best_Score': metrics_df['score'].min(),
            'Native_Score': results['native_score'],
            'Charged_Recovery': metrics_df['charged_residue_recovery'].mean(),
            'Hydrophobic_Recovery': metrics_df['hydrophobic_recovery'].mean()
        })

    if not combined_metrics:
        # Previously this path ended silently right after "Analyzing results...".
        print("❌ No designed sequences found in any processed protein; nothing to plot or export.")
    else:
        # Combine all metrics
        all_metrics_df = pd.concat(combined_metrics, ignore_index=True)
        summary_df = pd.DataFrame(summary_stats)

        print("📈 SUMMARY STATISTICS")
        print("="*60)
        print(summary_df.round(2).to_string(index=False))

        # Create visualizations
        fig = make_subplots(
            rows=2, cols=3,
            subplot_titles=(
                'Sequence Recovery by Protein',
                'Score Distribution by Protein',
                'Recovery vs Score',
                'Charged Residue Recovery',
                'Overall Performance Summary',
                'Score Comparison (Native vs Designed)'
            ),
            specs=[[{"type": "box"}, {"type": "box"}, {"type": "scatter"}],
                   [{"type": "box"}, {"type": "bar"}, {"type": "bar"}]]
        )

        # Row 1
        fig.add_trace(
            go.Box(x=all_metrics_df['pdb_code'], y=all_metrics_df['sequence_recovery'], name='Recovery'),
            row=1, col=1
        )

        fig.add_trace(
            go.Box(x=all_metrics_df['pdb_code'], y=all_metrics_df['score'], name='Scores'),
            row=1, col=2
        )

        fig.add_trace(
            go.Scatter(x=all_metrics_df['score'], y=all_metrics_df['sequence_recovery'],
                      text=all_metrics_df['pdb_code'], mode='markers+text', name='Recovery vs Score'),
            row=1, col=3
        )

        # Row 2
        fig.add_trace(
            go.Box(x=all_metrics_df['pdb_code'], y=all_metrics_df['charged_residue_recovery'], name='Charged Recovery'),
            row=2, col=1
        )

        fig.add_trace(
            go.Bar(x=summary_df['PDB'], y=summary_df['Mean_Recovery'], name='Mean Recovery'),
            row=2, col=2
        )

        # Score comparison: two adjacent bars per protein (native = blue, designed mean = red)
        score_data = []
        score_labels = []
        score_colors = []
        for pdb, native_score, designed_score in zip(
            summary_df['PDB'], summary_df['Native_Score'], summary_df['Mean_Score']
        ):
            score_data.extend([native_score, designed_score])
            score_labels.extend([f"{pdb}_Native", f"{pdb}_Designed"])
            score_colors.extend(['blue', 'red'])

        fig.add_trace(
            go.Bar(x=score_labels, y=score_data, name='Score Comparison',
                  marker_color=score_colors),
            row=2, col=3
        )

        fig.update_layout(height=800, showlegend=False,
                         title_text="ProteinMPNN Baseline Performance Dashboard")
        fig.show()

        # Export results
        timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
        # Each file name is built once so the name written and the name reported cannot drift apart.
        metrics_csv = f'baseline_detailed_metrics_{timestamp}.csv'
        summary_csv = f'baseline_summary_{timestamp}.csv'
        fasta_path = f'baseline_sequences_{timestamp}.fasta'

        # Save detailed metrics
        all_metrics_df.to_csv(metrics_csv, index=False)
        summary_df.to_csv(summary_csv, index=False)

        # Save sequences: for every processed protein, the native record followed by its designs
        with open(fasta_path, 'w') as f:
            for pdb_code, results in all_protein_results.items():
                f.write(f">NATIVE_{pdb_code}_score_{results['native_score']:.4f}\n")
                f.write(f"{results['native_sequence']}\n")

                for i, (seq, score, recovery) in enumerate(zip(
                    results['sequences'], results['scores'], results['recovery_rates']
                )):
                    f.write(f">DESIGNED_{pdb_code}_{i+1}_score_{score:.4f}_recovery_{recovery:.3f}\n")
                    f.write(f"{seq}\n")

        # Final summary
        # NOTE(review): the "%" suffix and the +10 / +15 offsets presume that
        # calculate_sequence_metrics reports recoveries on a 0-100 scale — confirm.
        overall_recovery = all_metrics_df['sequence_recovery'].mean()
        overall_charged = all_metrics_df['charged_residue_recovery'].mean()
        overall_score = all_metrics_df['score'].mean()

        print("\n🎯 BASELINE TARGETS FOR ESM-2 ENHANCEMENT:")
        print(f"Overall Mean Recovery: {overall_recovery:.2f}% → Target: >{overall_recovery + 10:.2f}%")
        print(f"Charged Recovery: {overall_charged:.2f}% → Target: >{overall_charged + 15:.2f}%")
        print(f"Mean Score: {overall_score:.4f} → Target: <{overall_score - 0.1:.4f}")

        print("\n💾 Results saved:")
        print(f"- Detailed metrics: {metrics_csv}")
        print(f"- Summary: {summary_csv}")
        print(f"- Sequences: {fasta_path}")

        print("\n🚀 Ready for ESM-2 integration!")


❌ No results to analyze. Please check the previous cells.
