<a href="https://colab.research.google.com/github/neetushibu/IontheFold-Team6/blob/main/IonTheFold001.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ProteinMPNN Baseline Evaluation for Ion the Fold Project

In [None]:
!nvidia-smi

Wed Jul 30 07:19:47 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   55C    P8             13W /   72W |       0MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
#@title Cell 1: Install Dependencies and Setup
import subprocess
import sys

def install_packages():
    """Quietly pip-install every third-party package this notebook relies on.

    Each package is installed with its own ``pip install -q`` invocation run
    through the current interpreter (``sys.executable``), so the packages land
    in the environment of the running kernel.

    Raises:
        subprocess.CalledProcessError: if any pip invocation exits non-zero.
    """
    pip_command = [sys.executable, '-m', 'pip', 'install', '-q']
    required = ('plotly', 'seaborn', 'biopython', 'matplotlib', 'pandas', 'numpy', 'scipy')
    for requirement in required:
        subprocess.check_call(pip_command + [requirement])

install_packages()
print("✅ All packages installed!")

✅ All packages installed!


In [None]:
#@title Cell 2: Clone Repository and Import Libraries
import json, time, os, sys, glob
import warnings
warnings.filterwarnings('ignore')

# Clone ProteinMPNN once; later runs reuse the existing checkout.
if not os.path.isdir("ProteinMPNN"):
    clone_status = os.system("git clone -q https://github.com/dauparas/ProteinMPNN.git")
    # os.system reports failure only through its return value; fail here
    # instead of letting the protein_mpnn_utils import below raise a
    # misleading ImportError.
    if clone_status != 0:
        raise RuntimeError(f"git clone of ProteinMPNN failed (exit status {clone_status})")
# NOTE(review): the clone goes to the current working directory while the path
# below is absolute; the two only agree when the notebook runs from /content
# (presumably the Colab default) — confirm before running elsewhere.
if '/content/ProteinMPNN' not in sys.path:  # avoid duplicate entries on re-run
    sys.path.append('/content/ProteinMPNN')

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
from scipy import stats
from collections import defaultdict, Counter
import re
from google.colab import files

from protein_mpnn_utils import (
    loss_nll, loss_smoothed, gather_edges, gather_nodes,
    gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq,
    tied_featurize, parse_PDB, StructureDataset,
    StructureDatasetPDB, ProteinMPNN
)

print("✅ Repository cloned and libraries imported!")

✅ Repository cloned and libraries imported!


In [None]:
#@title Cell 3: Setup Model and Device
# Prefer the first CUDA device when one is visible, otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model configuration
model_name = "v_48_020"  # Options: v_48_002, v_48_010, v_48_020, v_48_030
backbone_noise = 0.00  # forwarded to ProteinMPNN as augment_eps

# Load model
path_to_model_weights = '/content/ProteinMPNN/vanilla_model_weights'
hidden_dim = 128
num_layers = 3
model_folder_path = path_to_model_weights
# endswith() also copes with an empty path, where indexing [-1] would raise.
if not model_folder_path.endswith('/'):
    model_folder_path = model_folder_path + '/'
checkpoint_path = model_folder_path + f'{model_name}.pt'

try:
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print('Number of edges:', checkpoint['num_edges'])
    print(f'Training noise level: {checkpoint["noise_level"]}A')

    # The neighbour count must match the value stored with the weights.
    model = ProteinMPNN(
        num_letters=21,
        node_features=hidden_dim,
        edge_features=hidden_dim,
        hidden_dim=hidden_dim,
        num_encoder_layers=num_layers,
        num_decoder_layers=num_layers,
        augment_eps=backbone_noise,
        k_neighbors=checkpoint['num_edges']
    )
    model.to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference only
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Every later cell needs `model`; stop here rather than continuing with it
    # undefined and failing later with an unrelated NameError.
    raise

Using device: cuda:0
Number of edges: 48
Training noise level: 0.2A
✅ Model loaded successfully!


In [None]:
#@title Cell 4: Helper Functions
def make_tied_positions_for_homomers(pdb_dict_list):
    """Build the tied-position specification for homomeric structures.

    Args:
        pdb_dict_list: iterable of dicts, each holding a ``'name'`` entry and
            one ``'seq_chain_<ID>'`` entry per chain. The chain ID is taken as
            the last character of each such key.

    Returns:
        dict mapping each entry's name to a list with one element per residue
        position (1-based, up to the length of the alphabetically first
        chain). Each element maps every chain ID to ``[position]``, i.e. the
        same position is tied across all chains.
    """
    tied_by_name = {}
    for entry in pdb_dict_list:
        chain_ids = [key[-1:] for key in entry if key.startswith('seq_chain')]
        chain_ids.sort()
        # The first chain (in sorted order) sets the number of tied positions.
        n_positions = len(entry["seq_chain_" + chain_ids[0]])
        tied_by_name[entry['name']] = [
            {chain_id: [position] for chain_id in chain_ids}
            for position in range(1, n_positions + 1)
        ]
    return tied_by_name

def get_pdb_file(pdb_code):
    """Obtain a PDB file locally, by download or by manual upload.

    Args:
        pdb_code: PDB identifier to fetch from RCSB. When ``None`` or an empty
            string, the user is prompted to upload a file through
            ``files.upload()`` instead.

    Returns:
        Path of the local file (``"<pdb_code>.pdb"`` for downloads,
        ``"tmp.pdb"`` for uploads), or ``None`` when the download fails or no
        file is uploaded.
    """
    if pdb_code is None or pdb_code == "":
        print("Please upload a PDB file:")
        upload_dict = files.upload()
        # An empty mapping would otherwise raise IndexError on the lookup below.
        if not upload_dict:
            print("❌ No file was uploaded")
            return None
        pdb_string = next(iter(upload_dict.values()))  # first uploaded file
        with open("tmp.pdb", "wb") as out:
            out.write(pdb_string)
        return "tmp.pdb"

    # Try to download from RCSB
    import urllib.request
    local_path = f"{pdb_code}.pdb"
    try:
        # NOTE(review): very large assemblies may not be offered in legacy
        # .pdb format at all, in which case this request fails — confirm for
        # the entries in use.
        url = f"https://files.rcsb.org/view/{pdb_code}.pdb"
        urllib.request.urlretrieve(url, local_path)
    except Exception as err:
        # `Exception` rather than a bare except, so Ctrl-C / kernel interrupts
        # still propagate; the reason is printed so the failure can be diagnosed.
        print(f"❌ Could not download {pdb_code}, please upload manually")
        print(f"   Reason: {err}")
        return None
    print(f"✅ Downloaded {pdb_code}.pdb")
    return local_path

def analyze_amino_acid_composition(sequences, labels=None):
    """Tabulate how often each of the 20 standard amino acids occurs.

    Args:
        sequences: sequence strings; chain separators ('/') and unknown
            residues ('X') are removed before counting.
        labels: optional name per sequence; defaults to ``Seq_0``, ``Seq_1``, ...
            Sequences and labels are paired with ``zip``, so the shorter of the
            two determines how many sequences are analysed.

    Returns:
        pandas.DataFrame in long format with columns ``Sequence``,
        ``Amino_Acid``, ``Count`` and ``Percentage`` (20 rows per sequence).
        ``Percentage`` is relative to the cleaned sequence length and is 0 for
        an empty sequence.
    """
    standard_residues = 'ACDEFGHIKLMNPQRSTVWY'
    if labels is None:
        labels = [f"Seq_{idx}" for idx in range(len(sequences))]

    records = []
    for raw_seq, seq_label in zip(sequences, labels):
        # Remove separators and unknown residues before counting
        residues = raw_seq.replace('/', '').replace('X', '')
        n_residues = len(residues)
        for residue in standard_residues:
            hits = residues.count(residue)
            if n_residues > 0:
                share = (hits / n_residues) * 100
            else:
                share = 0
            records.append(dict(
                Sequence=seq_label,
                Amino_Acid=residue,
                Count=hits,
                Percentage=share,
            ))

    return pd.DataFrame(records)

def calculate_sequence_metrics(native_seq, designed_seqs, scores,
                               charged_residues='DEKR',
                               hydrophobic_residues='AILMFPWY'):
    """Compare designed sequences with the native one, position by position.

    Chain separators ('/') are removed from both sequences, then residues are
    paired by position. Pairs where either side is 'X' (unknown) are skipped
    instead of being deleted from each sequence separately: deleting an 'X'
    from only one of the two sequences would shift every following position
    and corrupt all metrics.

    Args:
        native_seq: native sequence, chains separated by '/'.
        designed_seqs: iterable of designed sequences in the same layout.
        scores: one score per designed sequence (converted with ``float``).
        charged_residues: one-letter codes treated as charged.
        hydrophobic_residues: one-letter codes treated as hydrophobic.
            NOTE(review): the default omits valine ('V'); confirm whether that
            is intended before comparing with other hydrophobicity-based
            numbers.

    Returns:
        pandas.DataFrame with one row per designed sequence and the columns
        ``sequence_recovery`` and ``identity`` (the same value: percent of
        compared positions with an identical residue), ``score``, ``length``
        (number of compared positions), ``charged_residue_recovery`` and
        ``hydrophobic_recovery`` (percent of native positions of that class
        whose designed residue is in the same class, not necessarily the same
        residue; 0 when the native has no such position).
        Designed sequences with no comparable position produce no row, so the
        result can have fewer rows than ``designed_seqs``.
    """
    metrics = {
        'sequence_recovery': [],
        'identity': [],
        'score': [],
        'length': [],
        'charged_residue_recovery': [],
        'hydrophobic_recovery': []
    }

    charged = set(charged_residues)
    hydrophobic = set(hydrophobic_residues)
    native_residues = native_seq.replace('/', '')

    def class_recovery(pairs, residue_class):
        """Percent of native class members whose designed residue stays in the class."""
        designed_at_class = [designed for native, designed in pairs if native in residue_class]
        if not designed_at_class:
            return 0
        kept = sum(1 for designed in designed_at_class if designed in residue_class)
        return (kept / len(designed_at_class)) * 100

    for designed_seq, score in zip(designed_seqs, scores):
        designed_residues = designed_seq.replace('/', '')

        # Keep only positions where both residues are known; zip() stops at
        # the shorter sequence.
        pairs = [
            (native, designed)
            for native, designed in zip(native_residues, designed_residues)
            if native != 'X' and designed != 'X'
        ]
        length = len(pairs)
        if length == 0:
            continue

        identical = sum(1 for native, designed in pairs if native == designed)
        identity = (identical / length) * 100

        metrics['sequence_recovery'].append(identity)
        metrics['identity'].append(identity)
        metrics['score'].append(float(score))
        metrics['length'].append(length)
        metrics['charged_residue_recovery'].append(class_recovery(pairs, charged))
        metrics['hydrophobic_recovery'].append(class_recovery(pairs, hydrophobic))

    return pd.DataFrame(metrics)

print("✅ Helper functions defined!")

✅ Helper functions defined!


In [None]:
#@title Cell 5: Configuration and Input Setup

# Benchmark set. One row per PDB entry, so the ID list, the chain assignment
# and the summary numbers are all derived from the same place and cannot
# drift apart.
# Columns: (pdb_code, total chains, charged residues, number of designed chains)
PROTEIN_TABLE = [
    ('6ICZ', 51, 4150, 1),
    ('6ID1', 43, 3202, 1),
    ('6ID0', 42, 3119, 1),
    ('9DTR', 47, 2964, 1),
    ('8XI2', 34, 2780, 1),
    ('9N4V', 48, 2760, 1),
    ('9ES0', 28, 2464, 1),
    ('9ES4', 28, 2478, 1),
    ('8QKM', 60, 2160, 1),
    ('5XNL', 56, 1776, 1),
    ('9EZM', 18, 2340, 2),
    ('6KS6', 16, 2324, 2),
    ('9HVQ', 24, 2311, 1),
    ('9GK2', 18, 2142, 2),
    ('9BP5', 12, 2108, 2),
    ('9I1R', 50, 2106, 1),
    ('8BAP', 16, 2096, 2),
    ('9F5Y', 51, 2084, 1),
    ('9ES2', 14, 2072, 2),
    ('8IMK', 54, 2062, 1),
    ('8IMI', 52, 2047, 1),
    ('9M02', 52, 2047, 1),
    ('8VEH', 29, 2034, 1),
    ('9DTQ', 29, 2016, 1),
    ('8WXY', 20, 2000, 2),
]

# Chain IDs are assigned in this order: A-Z, then a-z (52 IDs in total).
# NOTE(review): this assumes every structure labels its chains exactly this
# way — verify against the actual files for each entry.
CHAIN_ID_POOL = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'


def _design_entry(n_chains, n_designed):
    """Split the first n_chains pool IDs into designed (leading) and fixed chains."""
    # Slicing caps the list at the pool size for entries with more chains.
    chain_ids = list(CHAIN_ID_POOL[:n_chains])
    return {
        'designed_chains': chain_ids[:n_designed],
        'fixed_chains': chain_ids[n_designed:],
    }


def _summarize_chains(chains):
    """Render a chain list compactly enough for the summary table."""
    if not chains:
        return 'None'
    if len(chains) <= 3:
        return ', '.join(chains)
    return f"{chains[0]}..{chains[-1]} ({len(chains)})"


pdb_codes = [row[0] for row in PROTEIN_TABLE]

design_config = {
    code: _design_entry(n_chains, n_designed)
    for code, n_chains, _, n_designed in PROTEIN_TABLE
}

protein_data = {
    code: {'chains': n_chains, 'charged': n_charged}
    for code, n_chains, n_charged, _ in PROTEIN_TABLE
}

# Entries reporting more chains than the pool can label: their chain lists
# above are necessarily incomplete.
entries_with_unlisted_chains = [
    code for code, n_chains, _, _ in PROTEIN_TABLE if n_chains > len(CHAIN_ID_POOL)
]

# Design parameters
num_seqs = 6
sampling_temp = "0.1"  # kept as a string; converted with float() where used
homomer = False

# Advanced configuration options
batch_size = 1
max_length = 20000
omit_AAs = 'X'  # Omit unknown amino acids
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
# One flag per alphabet letter: 1.0 where the letter is listed in omit_AAs.
omit_AAs_np = np.array([AA in omit_AAs for AA in alphabet]).astype(np.float32)

print(f"✅ Configured to process {len(pdb_codes)} proteins")
print(f"✅ Will generate {num_seqs} sequences per protein")
print(f"✅ Using sampling temperature: {sampling_temp}")
print()

print("PROTEIN CONFIGURATION SUMMARY:")
print("="*60)
print(f"{'PDB':<6} {'Chains':<12} {'Designed':<10} {'Fixed':<15} {'Charged Res':<12}")
print("-"*60)

for pdb_code in pdb_codes:
    config = design_config[pdb_code]
    data = protein_data[pdb_code]
    designed_str = ', '.join(config['designed_chains'])
    # A compact range keeps the table aligned; the full list can hold 50+ IDs.
    fixed_str = _summarize_chains(config['fixed_chains'])
    print(f"{pdb_code:<6} {data['chains']:<12} {designed_str:<10} {fixed_str:<15} {data['charged']:<12}")

print("="*60)
print()

print("🧪 TESTING RECOMMENDATIONS:")
print("1. Start with 2-3 proteins first to test the pipeline")
print("2. Recommended testing order:")
# Recommend the entries with the fewest chains, taken from the configured set
# so the advice always refers to proteins that are actually listed.
for code, n_chains, n_charged, _ in sorted(PROTEIN_TABLE, key=lambda row: row[1])[:2]:
    print(f"   - {code} ({n_chains} chains, {n_charged} charged residues)")
print("3. Once working, add more proteins gradually")
if entries_with_unlisted_chains:
    print(f"4. For {', '.join(entries_with_unlisted_chains)} (more chains than the "
          f"{len(CHAIN_ID_POOL)} single-letter IDs listed), you may need to specify exact chains later")
print()


print("⚠️  IMPORTANT NOTES:")
print("- If any chain IDs are wrong, the notebook will tell you what chains exist")
print("- You can modify the design_config above based on those error messages")
print("- For complex multi-chain proteins, you may want to check PDB files manually")
print("- Start with fewer proteins and increase batch size once everything works")

print()
print("🔍 VALIDATION CHECKLIST:")
print("- All proteins have reasonable chain configurations")
print("- Single chains are designed entirely")
print("- Multi-chain proteins have one or two designed chains, others fixed")
print("- Ready to run Cell 7 after this configuration")

✅ Configured to process 11 proteins
✅ Will generate 6 sequences per protein
✅ Using sampling temperature: 0.1

PROTEIN CONFIGURATION SUMMARY:
PDB    Chains       Designed   Fixed           Charged Res 
------------------------------------------------------------
8W2U   20           A          B               1600        
9CDF   2            A          B               316         
9H39   9            A          B               127         
9IR2   1            A          None            63          
9DWY   1            A          None            40          
9VIC   2            A          B               72          
9CDJ   2            A          B               266         
9FU4   2            A          B               261         
9IMR   2            A          B               164         
9M0R   6            A          B               255         
9G04   2            A          B               652         

🧪 TESTING RECOMMENDATIONS:
1. Start with 2-3 proteins first to test the pipe

In [None]:
#@title Cell 6: Process Single Protein Function
def _format_chain_sequence(flat_seq, masked_chain_length_list, masked_list):
    """Cut a flat designed-region sequence into chains, order them by chain ID and join with '/'.

    Assumes ``len(flat_seq)`` equals ``sum(masked_chain_length_list)`` and that
    the lengths are listed in the same order as the chains appear in
    ``flat_seq`` — TODO confirm against ``tied_featurize``.
    """
    chain_pieces = []
    start = 0
    for chain_len in masked_chain_length_list:
        chain_pieces.append(flat_seq[start:start + chain_len])
        start += chain_len
    return '/'.join(chain_pieces[idx] for idx in np.argsort(masked_list))


def process_single_protein(pdb_code, designed_chains, fixed_chains, num_sequences=4, temperature=0.1):
    """Design sequences for one structure and collect scores and recovery rates.

    Reads the notebook-level ``model``, ``device``, ``max_length``,
    ``homomer``, ``batch_size``, ``alphabet`` and ``omit_AAs_np``.

    Args:
        pdb_code: PDB identifier handed to ``get_pdb_file``.
        designed_chains: list of chain IDs to redesign.
        fixed_chains: list of chain IDs supplied as fixed context.
        num_sequences: number of sequences requested; ``num_sequences //
            batch_size`` sampling rounds of ``batch_size`` sequences are run,
            so a remainder is dropped.
        temperature: sampling temperature passed to the model.

    Returns:
        dict with keys ``pdb_code``, ``sequences``, ``scores``,
        ``recovery_rates``, ``temperatures`` (parallel lists, one item per
        designed sequence), ``native_sequence`` and ``native_score``; or
        ``None`` when the file cannot be obtained or any step raises (the
        error and traceback are printed).
    """
    print(f"\n{'='*50}")
    print(f"Processing {pdb_code}")
    print(f"{'='*50}")

    pdb_path = get_pdb_file(pdb_code)
    if pdb_path is None:
        return None

    try:
        # Order-preserving de-duplication: unlike set(), this gives the same
        # chain order (and therefore the same log output) on every run.
        chain_list = list(dict.fromkeys(designed_chains + fixed_chains))

        # Parse PDB
        pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)
        if not pdb_dict_list:
            raise ValueError(f"No structure could be parsed from {pdb_path}")
        dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)

        # Setup chain configuration
        chain_id_dict = {pdb_dict_list[0]['name']: (designed_chains, fixed_chains)}

        print(f"Chain configuration: {chain_id_dict}")

        # Report absent chains explicitly instead of failing with a bare
        # KeyError on the per-chain length lookup below.
        found_chains = [chain for chain in chain_list if f"seq_chain_{chain}" in pdb_dict_list[0]]
        missing_chains = [chain for chain in chain_list if chain not in found_chains]
        if missing_chains:
            raise ValueError(
                f"Chains {missing_chains} were not found in {pdb_path}; "
                f"requested chains that do exist: {found_chains}"
            )

        for chain in chain_list:
            l = len(pdb_dict_list[0][f"seq_chain_{chain}"])
            print(f"Length of chain {chain}: {l}")

        # Setup tied positions if homomer
        tied_positions_dict = None
        if homomer:
            tied_positions_dict = make_tied_positions_for_homomers(pdb_dict_list)

        # Initialize parameters
        NUM_BATCHES = num_sequences // batch_size
        BATCH_COPIES = batch_size
        temperatures = [temperature]

        # Optional per-residue constraints: none are used in this baseline.
        fixed_positions_dict = None
        pssm_dict = None
        omit_AA_dict = None
        bias_by_res_dict = None
        bias_AAs_np = np.zeros(len(alphabet))

        # Storage for results
        results = {
            'pdb_code': pdb_code,
            'sequences': [],
            'scores': [],
            'recovery_rates': [],
            'temperatures': [],
            'native_sequence': '',
            'native_score': 0
        }

        # Process protein
        with torch.no_grad():
            for protein in dataset_valid:
                batch_clones = [copy.deepcopy(protein) for _ in range(BATCH_COPIES)]

                # Featurize
                X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(
                    batch_clones, device, chain_id_dict, fixed_positions_dict,
                    omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict
                )

                pssm_log_odds_mask = (pssm_log_odds_all > 0.0).float()

                # Calculate native score
                # NOTE(review): the score presumably is a per-residue negative
                # log-likelihood (lower is better) — confirm in protein_mpnn_utils.
                randn_1 = torch.randn(chain_M.shape, device=X.device)
                log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
                mask_for_loss = mask*chain_M*chain_M_pos
                scores = _scores(S, log_probs, mask_for_loss)
                native_score = scores.cpu().data.numpy().mean()
                results['native_score'] = float(native_score)

                # Generate sequences
                for temp in temperatures:
                    for j in range(NUM_BATCHES):
                        randn_2 = torch.randn(chain_M.shape, device=X.device)

                        # Sample sequences
                        if tied_positions_dict is None:
                            sample_dict = model.sample(
                                X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                                mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np,
                                bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos,
                                omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef,
                                pssm_bias=pssm_bias, pssm_multi=0.0,
                                pssm_log_odds_flag=False, pssm_log_odds_mask=pssm_log_odds_mask,
                                pssm_bias_flag=False, bias_by_res=bias_by_res_all
                            )
                        else:
                            sample_dict = model.tied_sample(
                                X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                                mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np,
                                bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos,
                                omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef,
                                pssm_bias=pssm_bias, pssm_multi=0.0,
                                pssm_log_odds_flag=False, pssm_log_odds_mask=pssm_log_odds_mask,
                                pssm_bias_flag=False, tied_pos=tied_pos_list_of_lists_list[0],
                                tied_beta=tied_beta, bias_by_res=bias_by_res_all
                            )

                        S_sample = sample_dict["S"]

                        # Score designed sequences, re-using the decoding
                        # order that was used while sampling.
                        log_probs = model(X, S_sample, mask, chain_M*chain_M_pos, residue_idx,
                                        chain_encoding_all, randn_2, use_input_decoding_order=True,
                                        decoding_order=sample_dict["decoding_order"])
                        scores = _scores(S_sample, log_probs, mask_for_loss)
                        scores = scores.cpu().data.numpy()

                        # Process results
                        for b_ix in range(BATCH_COPIES):
                            masked_chain_length_list = masked_chain_length_list_list[b_ix]
                            masked_list = masked_list_list[b_ix]

                            # Recovery: fraction of positions counted in the
                            # loss whose sampled residue equals the native one.
                            seq_recovery_rate = torch.sum(
                                torch.sum(torch.nn.functional.one_hot(S[b_ix], 21) *
                                         torch.nn.functional.one_hot(S_sample[b_ix], 21), axis=-1) *
                                mask_for_loss[b_ix]
                            ) / torch.sum(mask_for_loss[b_ix])

                            # Convert sequences
                            seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
                            native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])

                            # The native sequence only needs formatting once.
                            if results['native_sequence'] == '':  # First time
                                results['native_sequence'] = _format_chain_sequence(
                                    native_seq, masked_chain_length_list, masked_list
                                )

                            # Format designed sequence
                            seq_formatted = _format_chain_sequence(
                                seq, masked_chain_length_list, masked_list
                            )

                            # Store results
                            results['sequences'].append(seq_formatted)
                            results['scores'].append(float(scores[b_ix]))
                            results['recovery_rates'].append(float(seq_recovery_rate.detach().cpu().numpy()))
                            results['temperatures'].append(temp)

                            print(f"Generated sequence {len(results['sequences'])}: Recovery={seq_recovery_rate:.3f}, Score={scores[b_ix]:.4f}")

        print(f"✅ Successfully processed {pdb_code}: {len(results['sequences'])} sequences generated")
        return results

    except Exception as e:
        # Best effort per protein: report the failure and let the caller
        # continue with the remaining entries.
        print(f"❌ Error processing {pdb_code}: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

print("✅ Single protein processing function ready!")

✅ Single protein processing function ready!


In [None]:
#@title Cell 7: Process All Proteins
# Results keyed by PDB code; only proteins that were processed successfully
# are stored.
all_protein_results = {}

for pdb_code in pdb_codes:
    # Entries without a chain assignment cannot be processed.
    if pdb_code not in design_config:
        print(f"⚠️ No configuration found for {pdb_code}, skipping...")
        continue

    config = design_config[pdb_code]
    result = process_single_protein(
        pdb_code,
        config['designed_chains'],
        config['fixed_chains'],
        num_sequences=num_seqs,
        temperature=float(sampling_temp),  # configured as a string
    )
    # process_single_protein returns None on failure.
    if result is None:
        continue
    all_protein_results[pdb_code] = result

print(f"\n✅ Processing complete! Successfully processed {len(all_protein_results)} proteins")


Processing 8W2U
✅ Downloaded 8W2U.pdb
Chain configuration: {'8W2U': (['A'], ['B'])}
Length of chain A: 284
Length of chain B: 284
Generated sequence 1: Recovery=0.433, Score=0.8633
Generated sequence 2: Recovery=0.423, Score=0.8555
Generated sequence 3: Recovery=0.405, Score=0.8488
Generated sequence 4: Recovery=0.440, Score=0.8480
Generated sequence 5: Recovery=0.415, Score=0.8632
Generated sequence 6: Recovery=0.408, Score=0.8460
✅ Successfully processed 8W2U: 6 sequences generated

Processing 9CDF
✅ Downloaded 9CDF.pdb
Chain configuration: {'9CDF': (['A'], ['B'])}
Length of chain A: 547
Length of chain B: 551
Generated sequence 1: Recovery=0.344, Score=1.0049
Generated sequence 2: Recovery=0.354, Score=0.9817
Generated sequence 3: Recovery=0.366, Score=0.9950
Generated sequence 4: Recovery=0.344, Score=1.0054
Generated sequence 5: Recovery=0.331, Score=1.0127
Generated sequence 6: Recovery=0.343, Score=1.0179
✅ Successfully processed 9CDF: 6 sequences generated

Processing 9H39
✅ D

In [None]:
#@title Cell 8: Analyze Results and Create Visualizations
# Purpose: turn the per-protein design results stored in `all_protein_results`
# into per-sequence metrics, a per-protein summary table, a six-panel Plotly
# dashboard, and three timestamped export files (two CSVs and one FASTA)
# written to the current working directory.
#
# Names left in the notebook namespace when the analysis succeeds:
#   all_metrics_df - one row per designed sequence, all proteins combined
#   summary_df     - one row per protein
#   fig            - the dashboard figure
#   timestamp      - suffix shared by the three export files

# Improvement margins used for the "baseline targets" printed at the end.
RECOVERY_TARGET_GAIN = 10    # percentage points added to mean sequence recovery
CHARGED_TARGET_GAIN = 15     # percentage points added to mean charged recovery
SCORE_TARGET_DROP = 0.1      # amount subtracted from the mean score
MAX_PERCENT = 100.0          # a recovery target above 100% is unreachable

if len(all_protein_results) == 0:
    print("❌ No results to analyze. Please check the previous cells.")
else:
    print("📊 Analyzing results...")

    # Combine all results
    combined_metrics = []   # per-protein metric frames, concatenated below
    summary_stats = []      # one summary record (dict) per protein

    for pdb_code, results in all_protein_results.items():
        if len(results['sequences']) == 0:
            # Say so explicitly: the protein is left out of the tables and
            # plots below, which would otherwise happen without any notice.
            print(f"⚠️ No sequences generated for {pdb_code}, skipping...")
            continue

        # Calculate metrics for this protein
        metrics_df = calculate_sequence_metrics(
            results['native_sequence'],
            results['sequences'],
            results['scores']
        )
        metrics_df['pdb_code'] = pdb_code
        metrics_df['temperature'] = results['temperatures']
        metrics_df['native_score'] = results['native_score']

        combined_metrics.append(metrics_df)

        # Summary statistics
        summary_stats.append({
            'PDB': pdb_code,
            'Sequences': len(results['sequences']),
            'Mean_Recovery': metrics_df['sequence_recovery'].mean(),
            # NOTE: pandas' sample std is NaN when only one sequence exists.
            'Std_Recovery': metrics_df['sequence_recovery'].std(),
            'Mean_Score': metrics_df['score'].mean(),
            # The minimum is reported as "best", i.e. lower scores are
            # treated as better throughout this cell.
            'Best_Score': metrics_df['score'].min(),
            'Native_Score': results['native_score'],
            'Charged_Recovery': metrics_df['charged_residue_recovery'].mean(),
            'Hydrophobic_Recovery': metrics_df['hydrophobic_recovery'].mean()
        })

    if not combined_metrics:
        # Every entry had an empty sequence list; without this message the
        # cell would finish silently after "Analyzing results...".
        print("❌ None of the processed proteins produced any designed sequences; nothing to analyze.")
    else:
        # Combine all metrics
        all_metrics_df = pd.concat(combined_metrics, ignore_index=True)
        summary_df = pd.DataFrame(summary_stats)

        print("📈 SUMMARY STATISTICS")
        print("="*60)
        print(summary_df.round(2).to_string(index=False))

        # Create visualizations: 2 x 3 grid of panels
        fig = make_subplots(
            rows=2, cols=3,
            subplot_titles=(
                'Sequence Recovery by Protein',
                'Score Distribution by Protein',
                'Recovery vs Score',
                'Charged Residue Recovery',
                'Overall Performance Summary',
                'Score Comparison (Native vs Designed)'
            ),
            specs=[[{"type": "box"}, {"type": "box"}, {"type": "scatter"}],
                   [{"type": "box"}, {"type": "bar"}, {"type": "bar"}]]
        )

        # Row 1
        fig.add_trace(
            go.Box(x=all_metrics_df['pdb_code'], y=all_metrics_df['sequence_recovery'], name='Recovery'),
            row=1, col=1
        )

        fig.add_trace(
            go.Box(x=all_metrics_df['pdb_code'], y=all_metrics_df['score'], name='Scores'),
            row=1, col=2
        )

        fig.add_trace(
            go.Scatter(x=all_metrics_df['score'], y=all_metrics_df['sequence_recovery'],
                      text=all_metrics_df['pdb_code'], mode='markers+text', name='Recovery vs Score'),
            row=1, col=3
        )

        # Row 2
        fig.add_trace(
            go.Box(x=all_metrics_df['pdb_code'], y=all_metrics_df['charged_residue_recovery'], name='Charged Recovery'),
            row=2, col=1
        )

        fig.add_trace(
            go.Bar(x=summary_df['PDB'], y=summary_df['Mean_Recovery'], name='Mean Recovery'),
            row=2, col=2
        )

        # Score comparison: two adjacent bars per protein
        # (native = blue, mean of designed sequences = red).
        score_data = []
        score_labels = []
        score_colors = []
        for _, row in summary_df.iterrows():
            score_data.extend([row['Native_Score'], row['Mean_Score']])
            score_labels.extend([f"{row['PDB']}_Native", f"{row['PDB']}_Designed"])
            score_colors.extend(['blue', 'red'])

        fig.add_trace(
            go.Bar(x=score_labels, y=score_data, name='Score Comparison',
                  marker_color=score_colors),
            row=2, col=3
        )

        # Axis titles so each panel can be read on its own. Recoveries are
        # labelled as percentages because the summary below prints them
        # with a '%' sign.
        y_axis_titles = {
            (1, 1): 'Sequence recovery (%)',
            (1, 2): 'Score',
            (1, 3): 'Sequence recovery (%)',
            (2, 1): 'Charged residue recovery (%)',
            (2, 2): 'Mean sequence recovery (%)',
            (2, 3): 'Score',
        }
        for (panel_row, panel_col), y_title in y_axis_titles.items():
            fig.update_yaxes(title_text=y_title, row=panel_row, col=panel_col)
        # Only the scatter panel has a numeric x axis that needs a label;
        # the other panels use PDB codes / bar labels as categories.
        fig.update_xaxes(title_text='Score', row=1, col=3)

        fig.update_layout(height=800, showlegend=False,
                         title_text="ProteinMPNN Baseline Performance Dashboard")
        fig.show()

        # Export results (one shared timestamp keeps the three files together)
        timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")

        # Save detailed metrics
        all_metrics_df.to_csv(f'baseline_detailed_metrics_{timestamp}.csv', index=False)
        summary_df.to_csv(f'baseline_summary_{timestamp}.csv', index=False)

        # Save sequences: for every protein the native record first, then
        # each designed sequence with its score and recovery in the header.
        with open(f'baseline_sequences_{timestamp}.fasta', 'w') as f:
            for pdb_code, results in all_protein_results.items():
                f.write(f">NATIVE_{pdb_code}_score_{results['native_score']:.4f}\n")
                f.write(f"{results['native_sequence']}\n")

                for i, (seq, score, recovery) in enumerate(zip(
                    results['sequences'], results['scores'], results['recovery_rates']
                )):
                    f.write(f">DESIGNED_{pdb_code}_{i+1}_score_{score:.4f}_recovery_{recovery:.3f}\n")
                    f.write(f"{seq}\n")

        # Final summary: compute each baseline once, then derive its target.
        mean_recovery = all_metrics_df['sequence_recovery'].mean()
        mean_charged = all_metrics_df['charged_residue_recovery'].mean()
        mean_score = all_metrics_df['score'].mean()

        # Percentage targets are capped so a high baseline never yields an
        # impossible ">100%" goal.
        recovery_target = min(mean_recovery + RECOVERY_TARGET_GAIN, MAX_PERCENT)
        charged_target = min(mean_charged + CHARGED_TARGET_GAIN, MAX_PERCENT)
        score_target = mean_score - SCORE_TARGET_DROP

        print(f"\n🎯 BASELINE TARGETS FOR ESM-2 ENHANCEMENT:")
        print(f"Overall Mean Recovery: {mean_recovery:.2f}% → Target: >{recovery_target:.2f}%")
        print(f"Charged Recovery: {mean_charged:.2f}% → Target: >{charged_target:.2f}%")
        print(f"Mean Score: {mean_score:.4f} → Target: <{score_target:.4f}")

        print(f"\n💾 Results saved:")
        print(f"- Detailed metrics: baseline_detailed_metrics_{timestamp}.csv")
        print(f"- Summary: baseline_summary_{timestamp}.csv")
        print(f"- Sequences: baseline_sequences_{timestamp}.fasta")

        print(f"\n🚀 Ready for ESM-2 integration!")


📊 Analyzing results...
📈 SUMMARY STATISTICS
 PDB  Sequences  Mean_Recovery  Std_Recovery  Mean_Score  Best_Score  Native_Score  Charged_Recovery  Hydrophobic_Recovery
8W2U          6          42.08          1.39        0.85        0.85          1.66             67.46                 69.50
9CDF          6          34.71          1.20        1.00        0.98          1.96             56.49                 68.63
9H39          6          54.26          1.78        0.65        0.63          1.32             51.56                 87.56
9IR2          6          45.16          0.64        0.86        0.84          1.63             66.09                 77.90
9DWY          6          38.71          1.74        0.94        0.92          1.74             47.55                 84.09
9VIC          6          45.86          0.58        0.80        0.78          1.57             46.79                 81.92
9CDJ          6          41.51          0.97        0.87        0.86          1.68             


🎯 BASELINE TARGETS FOR ESM-2 ENHANCEMENT:
Overall Mean Recovery: 43.05% → Target: >53.05%
Charged Recovery: 59.53% → Target: >74.53%
Mean Score: 0.8611 → Target: <0.7611

💾 Results saved:
- Detailed metrics: baseline_detailed_metrics_20250728_192902.csv
- Summary: baseline_summary_20250728_192902.csv
- Sequences: baseline_sequences_20250728_192902.fasta

🚀 Ready for ESM-2 integration!
