In [54]:
import pandas as pd

In [55]:
df = pd.read_csv('sabdab_nr_summary.tsv',sep='\t')

In [56]:
df['Hchain'].unique()

array(['H', 'A', 'C', 'E', 'I', 'K', 'J', 'D', 'B', 'T', 'M', 'O', 'G',
       'P', 'Q', 'F', 'U', 'R', 'L', 'N', 'W', 'Y', 'X', 'h', 'S', 'Z',
       'V', 'b', 'c', 'f', 'd', '2', 'a', nan], dtype=object)

In [57]:
df.drop_duplicates(subset='pdb',keep='first',inplace=True)

In [58]:
test_df = df[df['resolution']<1.3]

In [11]:
import os
from Bio import PDB

def filter_pdb_keep_only(input_pdb, output_pdb, chains_to_keep):
    """
    Keeps only the specified chains in a PDB file, removing all others (including waters, salts, and non-protein molecules).

    Args:
        input_pdb (str): Path to the input PDB file.
        output_pdb (str): Path to save the modified PDB file.
        chains_to_keep (set): Set of chain IDs to retain.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("structure", input_pdb)

    # Create a new structure to store only the kept chains
    new_structure = PDB.Structure.Structure("filtered_structure")
    model = structure[0]  # Assume single model
    new_model = PDB.Model.Model(0)

    # Define standard amino acid residues (3-letter codes)
    standard_amino_acids = {
        "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
        "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"
    }

    for chain in model:
        if chain.id in chains_to_keep:
            new_chain = PDB.Chain.Chain(chain.id)
            
            for residue in chain:
                # Keep only standard amino acids
                if residue.id[0] == " " and residue.resname in standard_amino_acids:
                    new_chain.add(residue.copy())

            # Add cleaned chain to model
            if len(new_chain):
                new_model.add(new_chain)

    if len(new_model):
        new_structure.add(new_model)

    # Save the modified structure
    io = PDB.PDBIO()
    io.set_structure(new_structure)
    io.save(output_pdb)
    print(f"Saved filtered PDB (keeping only chains {chains_to_keep} with protein residues): {output_pdb}")

In [12]:
import os
from Bio import PDB

def filter_pdb_remove_chain(input_pdb, output_pdb, chain_to_remove):
    """
    Removes a specific chain from a PDB file while keeping all other chains.
    Also removes waters, salts, and non-protein molecules.

    Args:
        input_pdb (str): Path to the input PDB file.
        output_pdb (str): Path to save the modified PDB file.
        chain_to_remove (str): Chain ID to remove.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("structure", input_pdb)

    # Create a new structure
    new_structure = PDB.Structure.Structure("filtered_structure")
    model = structure[0]  # Assume single model
    new_model = PDB.Model.Model(0)

    # Define standard amino acid residues (3-letter codes)
    standard_amino_acids = {
        "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
        "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"
    }

    for chain in model:
        if chain.id != chain_to_remove:
            new_chain = PDB.Chain.Chain(chain.id)

            for residue in chain:
                # Keep only standard amino acids
                if residue.id[0] == " " and residue.resname in standard_amino_acids:
                    new_chain.add(residue.copy())

            # Add cleaned chain to model
            if len(new_chain):
                new_model.add(new_chain)

    if len(new_model):
        new_structure.add(new_model)

    # Save the modified structure
    io = PDB.PDBIO()
    io.set_structure(new_structure)
    io.save(output_pdb)
    print(f"Saved filtered PDB (removed chain {chain_to_remove}, kept only proteins): {output_pdb}")

In [28]:
df.loc[:,['pdb','Hchain','Lchain','antigen_chain']].head(5)

Unnamed: 0,pdb,Hchain,Lchain,antigen_chain
0,3u1s,H,L,
1,5uxq,H,L,
3,4nwu,H,L,
4,8ezl,H,H,A
5,6urh,H,L,C


In [None]:
❌ Error: PDB 7wor not found in DataFrame
❌ Error: PDB 8qw4 not found in DataFrame
❌ Error: PDB 6gwn not found in DataFrame
❌ Error: PDB 7nd0 not found in DataFrame
❌ Error: PDB 8fxs not found in DataFrame
❌ Error: PDB 8jep not found in DataFrame

In [13]:
import os
import numpy as np
import pandas as pd
from Bio import PDB

# Example Usage:
directory = "/home/jupyter/DATA/hyperbind_train/sabdab/all_structures/imgt/"
output_directory = "/home/jupyter/DATA/hyperbind_train/sabdab/all_structures/processed/"

# Ensure output directory exists
os.makedirs(output_directory, exist_ok=True)


def get_chains_in_pdb(pdb_file):
    """
    Extracts all chain IDs present in a PDB file.

    Args:
        pdb_file (str): Path to the PDB file.

    Returns:
        list: List of chain IDs found in the PDB. If an error occurs, returns an empty list.
    """
    parser = PDB.PDBParser(QUIET=True)

    try:
        structure = parser.get_structure("structure", pdb_file)
        model = structure[0]  # Accessing the first model

        # Extract chain IDs
        chains = [chain.id.strip() for chain in model.get_chains()]
        return chains

    except KeyError:
        print(f"❌ KeyError: Could not access model 0 in {pdb_file}. No valid chains found.")
        return []  # Return an empty list to avoid breaking downstream code
    except Exception as e:
        print(f"❌ Error reading {pdb_file}: {e}")
        return []  # Return an empty list to ensure robustness


def trim_pdb_residues(input_pdb, output_pdb, max_residue_num=128):
    """
    Trims residues with numbers greater than max_residue_num in all chains.

    Args:
        input_pdb (str): Path to input PDB file.
        output_pdb (str): Path to save trimmed PDB file.
        max_residue_num (int): Maximum residue number to keep.
    """
    parser = PDB.PDBParser(QUIET=True)
    
    try:
        structure = parser.get_structure("structure", input_pdb)
    except ValueError as e:
        print(f"❌ ValueError while parsing {input_pdb}: {e}. Skipping file.")
        return

    # Create a new structure for the trimmed version
    new_structure = PDB.Structure.Structure("trimmed_structure")
    model = structure[0]  # Assume single-model PDB
    new_model = PDB.Model.Model(0)

    for chain in model:
        new_chain = PDB.Chain.Chain(chain.id)

        for residue in chain:
            try:
                res_id = int(residue.id[1])  # Extract residue number safely
                if res_id <= max_residue_num:
                    new_chain.add(residue.copy())
            except ValueError:
                print(f"⚠️ Skipping malformed residue ID in {input_pdb}: {residue}")

        if len(new_chain):  # Only add chain if residues remain
            new_model.add(new_chain)

    if len(new_model):
        new_structure.add(new_model)

    # Save the trimmed PDB file
    io = PDB.PDBIO()
    io.set_structure(new_structure)
    io.save(output_pdb)
    print(f"✅ Trimmed PDB saved: {output_pdb} (Residues > {max_residue_num} removed)")


# Iterate over PDB files
for filename in os.listdir(directory):
    if filename.endswith(".pdb"):
        input_pdb = os.path.join(directory, filename)
        pdb_id = os.path.splitext(filename)[0]  # Extract PDB ID from filename

        # Skip if trimmed file already exists
        output_pdb_trimmed = os.path.join(output_directory, f"trimmed_{filename}")
        if os.path.exists(output_pdb_trimmed):
            print(f"⏭️ Skipping {filename} (Trimmed version already exists)")
            continue

        try:
            # Extract heavy and light chain values
            pdb_entry = df[df['pdb'] == pdb_id]
            if pdb_entry.empty:
                print(f"❌ Error: PDB {pdb_id} not found in DataFrame")
                continue

            Hchain = pdb_entry['Hchain'].values[0]
            Lchain = pdb_entry['Lchain'].values[0]
            chains_to_keep = [Hchain, Lchain]  # Ensure it's a list

            # Extract antigen chain (if present)
            antigen = pdb_entry['antigen_chain'].values[0]
            chains_to_remove = antigen if pd.notna(antigen) else None

        except Exception as e:
            print(f"❌ Error processing PDB {pdb_id}: {e}")
            continue

        # Print all detected chains in the original PDB
        original_chains = get_chains_in_pdb(input_pdb)
        print(f"🔍 Original chains in {pdb_id}: {original_chains}")

        try:
            # Keep only specific chains
            output_pdb_keep = os.path.join(output_directory, f"filtered_keep_{filename}")
            filter_pdb_keep_only(input_pdb, output_pdb_keep, chains_to_keep)

            # Print chains after chains_to_keep filtering
            remaining_chains_after_keep = get_chains_in_pdb(output_pdb_keep)
            print(f"✅ Chains after keeping {chains_to_keep}: {remaining_chains_after_keep}")

            # Validate that the only chains left are the ones specified in chains_to_keep
            if set(remaining_chains_after_keep) == set(chains_to_keep):
                print("✅ Chain validation successful: Only expected chains remain.")
            else:
                print(f"⚠️ Chain validation failed: Unexpected chains detected. Running chain removal.")
                if chains_to_remove:
                    output_pdb_remove = os.path.join(output_directory, f"filtered_remove_{filename}")
                    filter_pdb_remove_chain(output_pdb_keep, output_pdb_remove, chains_to_remove)

                    # Print chains after chains_to_remove is applied
                    remaining_chains_after_remove = get_chains_in_pdb(output_pdb_remove)
                    print(f"✅ Chains after removing {chains_to_remove}: {remaining_chains_after_remove}")
                else:
                    print("⚠️ No antigen chain specified for removal, skipping.")

            # Trim residues > 128 and save processed version
            output_pdb_trimmed = os.path.join(output_directory, f"trimmed_{filename}")
            trim_pdb_residues(output_pdb_keep, output_pdb_trimmed)

            # Validate trimmed PDB has only 1 or 2 chains
            remaining_chains_trimmed = get_chains_in_pdb(output_pdb_trimmed)

            if len(remaining_chains_trimmed) > 2:
                print(f"⚠️ Warning: Trimmed PDB {output_pdb_trimmed} contains more than 2 chains! Found chains: {remaining_chains_trimmed}")

        except Exception as e:
            print(f"❌ Error processing {filename}: {e}. Skipping this file and moving to the next one.")
            continue  # Move on to the next file

print("✅ Processing complete.")

🔍 Original chains in 7k9z: ['E', 'H', 'L', 'B', 'A']
Saved filtered PDB (keeping only chains ['H', 'L'] with protein residues): /home/jupyter/DATA/hyperbind_train/sabdab/all_structures/processed/filtered_keep_7k9z.pdb
✅ Chains after keeping ['H', 'L']: ['H', 'L']
✅ Chain validation successful: Only expected chains remain.
✅ Trimmed PDB saved: /home/jupyter/DATA/hyperbind_train/sabdab/all_structures/processed/trimmed_7k9z.pdb (Residues > 128 removed)
❌ Error: PDB 5itf not found in DataFrame
🔍 Original chains in 6mts: ['H', 'L']
Saved filtered PDB (keeping only chains ['H', 'L'] with protein residues): /home/jupyter/DATA/hyperbind_train/sabdab/all_structures/processed/filtered_keep_6mts.pdb
✅ Chains after keeping ['H', 'L']: ['H', 'L']
✅ Chain validation successful: Only expected chains remain.
✅ Trimmed PDB saved: /home/jupyter/DATA/hyperbind_train/sabdab/all_structures/processed/trimmed_6mts.pdb (Residues > 128 removed)
🔍 Original chains in 8ysf: ['B', 'A', 'C', 'D']
Saved filtered PDB

KeyboardInterrupt: 

In [None]:
DATA/hyperbind_train/sabdab/all_structures/processed/trimmed/trimmed_1nj9.pdb

In [1]:
from esm.sdk.api import ESMProtein, ProteinComplex
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation

In [2]:
import os
from Bio import PDB

In [3]:
def read_multimer_structure(pdb_input):
    """
    Reads and processes a multimer PDB structure.
    Args:
        pdb_input (str): Path to the PDB file.
    Returns:
        protein (ESMProtein): Processed protein structure.
    """
    try:
        complex = ProteinComplex.from_pdb(pdb_input)
        multimer_protein = ESMProtein.from_protein_complex(complex)
        return multimer_protein
    except Exception as e:
        print(f"❌ Error processing multimer structure: {e}")
        return None


def read_monomer_structure(pdb_input, chain):
    """
    Reads and processes a monomer PDB structure for a given chain.
    Args:
        pdb_input (str): Path to the PDB file.
        chain (str): Chain ID to process.
    Returns:
        protein (ESMProtein): Processed monomer structure.
    """
    try:
        protein_chain = ProteinChain.from_pdb(pdb_input, chain_id=chain)
        monomer_protein = ESMProtein.from_protein_chain(protein_chain)
        return monomer_protein
    except Exception as e:
        print(f"❌ Error processing monomer structure (Chain {chain}): {e}")
        return None


def detect_and_process_structure(pdb_input):
    """
    Detects if a PDB file contains a single or multiple chains and runs the appropriate function.

    Args:
        pdb_input (str): Path to the PDB file.

    Returns:
        protein (ESMProtein): Processed protein structure.
    """
    parser = PDB.PDBParser(QUIET=True)

    try:
        structure = parser.get_structure("protein", pdb_input)
        model = structure[0]  # Assume single-model PDB

        # Extract all available chain IDs
        chains = [chain.id.strip() for chain in model.get_chains()]
        num_chains = len(chains)

        if num_chains == 0:
            print(f"❌ Error: No valid protein chains found in {pdb_input}.")
            return None

        elif num_chains > 1:
            print(f"✅ Detected {num_chains} chains ({chains}). Processing as a **multimer**.")
            return read_multimer_structure(pdb_input)
        
        elif num_chains == 1:
            chain_id = chains[0]
            if chain_id:
                print(f"✅ Detected a **single chain ({chain_id})**. Processing as a **monomer**.")
                return read_monomer_structure(pdb_input, chain_id)
            else:
                print(f"⚠️ Error: Chain ID not properly detected in {pdb_input}.")
                return None

    except FileNotFoundError:
        print(f"❌ Error: PDB file not found: {pdb_input}")
    except Exception as e:
        print(f"❌ Unexpected error reading PDB file {pdb_input}: {e}")

    return None

In [4]:
# Example usage
pdb_input = "/home/jupyter/1BEY.pdb"
protein = detect_and_process_structure(pdb_input)

✅ Detected 2 chains (['L', 'H']). Processing as a **multimer**.


In [21]:
# Functions for visualizing 3D structure
import py3Dmol

def visualize_pdb(pdb_string):
    view = py3Dmol.view(width=400, height=400)
    view.addModel(pdb_string, "pdb")
    view.setStyle({"cartoon": {"color": "spectrum"}})
    view.zoomTo()
    view.render()
    view.center()
    return view


def visualize_pdb(pdb_string, width=400, height=400):
    view = py3Dmol.view(width=width, height=height)
    view.addModel(pdb_string, "pdb")
    view.setStyle({"cartoon": {"color": "spectrum"}})
    view.zoomTo()
    return view.show()


def visualize_multiple_pdbs(pdb_strings, columns=2):
    """
    Display multiple PDB structures in a grid layout using ipywidgets.
    
    Args:
        pdb_strings (list): List of PDB format strings.
        columns (int): Number of columns for layout.
    """
    n = len(pdb_strings)
    rows = -(-n // columns)  # Ceiling division

    # Create widget containers for display
    output_widgets = [[widgets.Output() for _ in range(columns)] for _ in range(rows)]

    # Render each PDB structure in the corresponding widget
    for i, pdb_string in enumerate(pdb_strings):
        row, col = divmod(i, columns)
        with output_widgets[row][col]:
            visualize_pdb(pdb_string)

    # Display in a grid layout
    grid = widgets.VBox([widgets.HBox(row) for row in output_widgets])
    display(grid)


def visualize_3D_coordinates(coordinates):
    """
    This uses all Alanines
    """
    protein_with_same_coords = ESMProtein(coordinates=coordinates)
    # pdb with all alanines
    pdb_string = protein_with_same_coords.to_pdb_string()
    return visualize_pdb(pdb_string)


def visualize_3D_protein(protein):
    pdb_string = protein.to_pdb_string()
    return visualize_pdb(pdb_string)

In [7]:
# visualize from just the coordinates
visualize_3D_coordinates(protein.coordinates)

<py3Dmol.view at 0x7fd51cb37b20>

In [23]:
import random
# Iterate over PDB files
directory = '/home/jupyter/DATA/hyperbind_train/sabdab/all_structures/processed/trimmed/'


protein_list = []
for filename in os.listdir(directory):
    if filename.endswith(".pdb"):
        print(filename)
        input_pdb = os.path.join(directory, filename)
        protein = detect_and_process_structure(input_pdb)
        protein_list.append(protein)

trimmed_1nc4.pdb
✅ Detected 2 chains (['D', 'C']). Processing as a **multimer**.
trimmed_4m1g.pdb
✅ Detected 2 chains (['H', 'L']). Processing as a **multimer**.
trimmed_7nx7.pdb
✅ Detected 2 chains (['H', 'L']). Processing as a **multimer**.
trimmed_1eo8.pdb
✅ Detected 2 chains (['H', 'L']). Processing as a **multimer**.
trimmed_1nsn.pdb
✅ Detected 2 chains (['H', 'L']). Processing as a **multimer**.
trimmed_5chn.pdb
✅ Detected 2 chains (['A', 'B']). Processing as a **multimer**.
trimmed_5zia.pdb
✅ Detected 2 chains (['G', 'L']). Processing as a **multimer**.
trimmed_9b44.pdb
✅ Detected 2 chains (['A', 'B']). Processing as a **multimer**.
trimmed_6k69.pdb
✅ Detected a **single chain (A)**. Processing as a **monomer**.
trimmed_9erw.pdb
✅ Detected a **single chain (E)**. Processing as a **monomer**.
trimmed_1eap.pdb
✅ Detected 2 chains (['B', 'A']). Processing as a **multimer**.
trimmed_7ul0.pdb
✅ Detected 2 chains (['H', 'L']). Processing as a **multimer**.
trimmed_3cxd.pdb
✅ Detected 

KeyboardInterrupt: 

In [38]:
i = random.randint(0,20)
visual_input = protein_list[i]
print(visual_input.sequence)
visualize_3D_coordinates(visual_input.coordinates)

EVQLEESGGGLVQAGGSLTLSCAASGFTFDDYAMGWYRQAPGKERVGVSCISRTDGYTYYLDSVKGRFTISTDHAKHTVYLQMNNLKPDDTGLYYCAADADPEYGSRCPDPYYGMDYWGKGILVTVS


In [45]:
test_df['resolution']

123     1.29000
257     1.22600
470     1.29600
882     1.20000
1583    1.20000
1736    1.03000
2015    1.16300
2535    1.22000
3101    1.27400
3310    1.26000
4024    1.11700
4087    1.12300
4506    1.10000
4554    1.06000
4571    1.25000
4639    1.11097
4661    1.27000
4749    1.15300
4847    1.10000
4851    1.20000
4890    1.26000
5098    1.00000
5108    1.06700
5277    1.14000
5308    1.13000
5378    1.25000
5414    1.19100
5475    1.23000
5634    0.92000
Name: resolution, dtype: float64

TODAY:

Structure finetune challenge
1. develop a train/val sabdab set DONE
2. develop a test sabdab set DONE
3. establish baseline of sequence --> esm3 infer structure ---> RMSD Angstrom MSE (something built in to ESM3 for this?) DONE

Sequence fientune challenge
1. develop train/val dataset
2. develop a test dataset
3. establish baseline of mask in-fill accuracy ---> esm3 MLM infer mask tokens ---> Binary Cross Entropy

STRETCH GOAL:
1. Run a extremely basic functional programming test on the ability to finetune (trainset is n=10, val is n=3)
   1.4B weights. Epochs = 1 or 2.

   - Can we push the weights at all?
   - How does ESM switch from Train to Validate?
   - How do we run infer on a test and get accuracy?
   - How do we monitor loss?

In [53]:
df_test

NameError: name 'df_test' is not defined

In [51]:
len(df)

2915

In [None]:
directory = '/home/jupyter/DATA/hyperbind_train/sabdab/all_structures/processed/trimmed/'

for filename in os.listdir(directory):
    if filename.endswith(".pdb"):
        input_pdb = os.path.join(directory, filename)

In [60]:
import os
import random
import shutil

# Directories
directory = "/home/jupyter/DATA/hyperbind_train/sabdab/all_structures/processed/trimmed/"
output_directory = "/home/jupyter/DATA/hyperbind_train/sabdab/all_structures/train-test-split/"
os.makedirs(output_directory, exist_ok=True)

# List of test PDBs with ultra high resolution
test_pdbs = set(test_df['pdb'].tolist())  # ✅ Convert to set

# Gather all PDB files
all_pdb_files = [f for f in os.listdir(directory) if f.endswith(".pdb")]

# Extract PDB IDs from filenames (removing "trimmed_" prefix)
all_pdb_ids = {f.replace("trimmed_", "").replace(".pdb", "") for f in all_pdb_files}

# Identify trainable PDBs (excluding test set)
trainable_pdbs = list(all_pdb_ids - test_pdbs)  # ✅ Fix applied

# Select 10% of trainable PDBs for validation
val_size = int(0.10 * len(trainable_pdbs))
val_pdbs = set(random.sample(trainable_pdbs, val_size))

# Remaining PDBs are for training
train_pdbs = all_pdb_ids - test_pdbs - val_pdbs

# File renaming and moving
for filename in all_pdb_files:
    pdb_id = filename.replace("trimmed_", "").replace(".pdb", "")
    input_pdb = os.path.join(directory, filename)

    if pdb_id in test_pdbs:
        new_filename = f"{pdb_id}_test.pdb"
    elif pdb_id in val_pdbs:
        new_filename = f"{pdb_id}_val.pdb"
    elif pdb_id in train_pdbs:
        new_filename = f"{pdb_id}_train.pdb"
    else:
        print(f"⚠️ Warning: {pdb_id} not categorized. Skipping.")
        continue

    # Move file to output directory with new name
    output_pdb = os.path.join(output_directory, new_filename)
    shutil.copy(input_pdb, output_pdb)
    print(f"✅ Renamed: {filename} → {new_filename}")

print(f"\n✅ Process Complete: {len(train_pdbs)} train, {len(val_pdbs)} val, {len(test_pdbs)} test.")

✅ Renamed: trimmed_1nc4.pdb → 1nc4_train.pdb
✅ Renamed: trimmed_4m1g.pdb → 4m1g_train.pdb
✅ Renamed: trimmed_7nx7.pdb → 7nx7_train.pdb
✅ Renamed: trimmed_1eo8.pdb → 1eo8_train.pdb
✅ Renamed: trimmed_1nsn.pdb → 1nsn_train.pdb
✅ Renamed: trimmed_5chn.pdb → 5chn_train.pdb
✅ Renamed: trimmed_5zia.pdb → 5zia_train.pdb
✅ Renamed: trimmed_9b44.pdb → 9b44_train.pdb
✅ Renamed: trimmed_6k69.pdb → 6k69_train.pdb
✅ Renamed: trimmed_9erw.pdb → 9erw_train.pdb
✅ Renamed: trimmed_1eap.pdb → 1eap_train.pdb
✅ Renamed: trimmed_7ul0.pdb → 7ul0_train.pdb
✅ Renamed: trimmed_3cxd.pdb → 3cxd_train.pdb
✅ Renamed: trimmed_8q7s.pdb → 8q7s_train.pdb
✅ Renamed: trimmed_5xqw.pdb → 5xqw_train.pdb
✅ Renamed: trimmed_5gzo.pdb → 5gzo_train.pdb
✅ Renamed: trimmed_4hs6.pdb → 4hs6_train.pdb
✅ Renamed: trimmed_7tuf.pdb → 7tuf_train.pdb
✅ Renamed: trimmed_6u6o.pdb → 6u6o_train.pdb
✅ Renamed: trimmed_6xkp.pdb → 6xkp_train.pdb
✅ Renamed: trimmed_8djg.pdb → 8djg_train.pdb
✅ Renamed: trimmed_7uvh.pdb → 7uvh_train.pdb
✅ Renamed: