In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Import necessary functions and classes from the codebase
from esm.sdk.api import ESMProtein
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.encoding import tokenize_structure  # to convert raw structure to tokens/tensors
from esm.tokenization.structure_tokenizer import StructureTokenizer
from esm.models.vqvae import StructureTokenEncoder

import sys
sys.path.insert(0, os.path.abspath(os.path.join('..','..', 'scripts','data_ingestion','antibody_structure_ingestion')))
from pdb2esm import read_monomer_structure, read_multimer_structure, detect_and_process_structure

#### Start by loading ESMProtein objects from the train directory

*The resulting objects are not passed directly into the model, but they are useful for validating the encoder.*

In [None]:
# Directory where training PDB files are stored
train_directory = '/home/jupyter/DATA/hyperbind_train/sabdab/all_structures/train-test-split/'

# List all files ending with "_train.pdb" in the directory
train_pdb_files = [
    os.path.join(train_directory, f)
    for f in os.listdir(train_directory) if f.endswith("_train.pdb")
]

# Process each file into an ESMProtein object
esm_protein_list = []
for pdb_path in train_pdb_files:
    protein = detect_and_process_structure(pdb_path)
    if protein is None:
        print(f"Warning: Failed to process {pdb_path}")
        continue
    esm_protein_list.append(protein)

# Uncomment to report how many structures were loaded
# print(f"Loaded {len(esm_protein_list)} ESMProtein objects from training files.")

# Instantiate ESM3 Tokenizer + Encoder

##### Instantiate the structure tokenizer and encoder. We will use pretrained parameters that match the setup native to ESM3-OS; the untrained encoder is also provided below, with a breakdown of its parameters for reference.

The `structure_encoder` cell below reproduces the pretraining configuration of the structure encoder: a d_model of 1024, n_heads of 1, v_heads of 128, n_layers of 2, a d_out of 128, and a codebook size (n_codes) of 4096. You can find this exact instantiation in the codebase (see esm/pretrained.py).

**The class definitions in esm/models/vqvae.py also reveal how parameters like d_model, n_heads, d_out, and n_codes are used to set up the structure token encoder, illustrating the design choices for handling protein structure. The notes below explain each parameter.**

• d_model:
This is the dimensionality of the model’s hidden representations. In transformer architectures, each token (or residue, in the case of protein models) is represented by a vector of length d_model. For example, if d_model is 1024, every residue is encoded as a 1024-dimensional vector. This size is a design choice that balances model capacity with computational cost.

• n_heads:
In multi-head attention, n_heads determines how many separate attention "heads" the model uses. Each head performs its own attention calculation in a lower-dimensional subspace (d_model divided by n_heads), and the results are combined. More heads let the model capture different aspects of, or relationships within, the data. In the structure encoder configuration (see esm/pretrained.py), n_heads is set to 1. Note that the printed encoder module below contains only a geometric attention layer (geom_attn) and a feed-forward network (ffn) in each block, so no standard multi-head attention layer appears in it.

• v_heads:
v_heads sets the number of heads in the geometric attention module (GeometricReasoningOriginalImpl in the printed encoder below), which is specific to the structure encoder. Instead of standard attention over residue embeddings, this module projects each residue's embedding into 3D vectors (its layer sizes are multiples of 3) that are interpreted in the residue's local backbone frame, so attention can reflect relative orientations and distances between residues. In the pretrained structure encoder, v_heads is 128.

• n_layers:
This parameter indicates the number of layers (or blocks) in the network. In a transformer, each layer typically consists of an attention mechanism followed by a feed-forward network. More layers usually allow the model to capture increasingly complex patterns. The pretrained structure encoder uses n_layers=2, which matches the `(0-1): 2 x UnifiedTransformerBlock` entry in the printout below.

• d_out:
In the structure token encoder, d_out is the output size of pre_vq_proj. After the transformer layers, each d_model-dimensional residue embedding is projected down to d_out dimensions (1024 to 128 in the printout below), and that vector is what gets matched against the codebook.

• n_codes:
This parameter is used in models that employ vector quantization (VQ), as in the VQ-VAE modules. n_codes is the number of discrete codes (entries) in the codebook. Each continuous latent vector is mapped to one of these codes (typically the nearest one), so every residue is summarized by a structure token between 0 and n_codes - 1 (0 to 4095 for n_codes=4096). This discrete bottleneck helps regularize the model and captures recurring structural states in the data.
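To connect these parameters to the printed StructureTokenEncoder further down, the sketch below recomputes its main layer sizes from the pretrained configuration. The formulas are read off that printout (for example 1920 = 15 * 128 and 384 = 3 * 128 for v_heads=128), not taken from the ESM source, so treat them as assumptions that hold for this configuration. `encoder_linear_shapes` is a hypothetical helper.

```python
def encoder_linear_shapes(d_model: int, v_heads: int, d_out: int, n_codes: int) -> dict:
    """Hypothetical helper: predict sizes of the main StructureTokenEncoder layers.

    Formulas are inferred from the printed pretrained encoder, not the ESM source.
    Linear layers are reported as (in_features, out_features).
    """
    ffn_hidden = 4 * d_model  # 4096 for d_model=1024 in the printout
    return {
        "geom_attn.proj": (d_model, 15 * v_heads),      # 1024 -> 1920
        "geom_attn.out_proj": (3 * v_heads, d_model),   # 384 -> 1024
        # SwiGLU splits this output into a value half and a gate half.
        "ffn.1": (d_model, 2 * ffn_hidden),
        "ffn.3": (ffn_hidden, d_model),
        "pre_vq_proj": (d_model, d_out),
        # Assumed: one d_out-dimensional vector per code, since codes are
        # compared against the pre_vq_proj output. Shape is (n_codes, d_out).
        "codebook": (n_codes, d_out),
    }


shapes = encoder_linear_shapes(d_model=1024, v_heads=128, d_out=128, n_codes=4096)
for name, shape in shapes.items():
    print(name, shape)
```

Comparing each printed pair with the `Linear(in_features=..., out_features=...)` entries in the module listing is a quick way to see which parameter controls which tensor.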

*Can we modify the encoder for antibodies?*

1.	Pretrained Architecture Compatibility:
The original ESM3 model's configuration (with parameters like d_model, n_heads, v_heads, n_layers, etc.) was carefully tuned during pretraining. If you change these hyperparameters, you won't be able to directly load the pretrained weights, because the architecture will no longer match: layers whose shapes change cannot receive the pretrained tensors at all, and layers that are added or removed leave weights missing or unused. Modifying these settings without retraining therefore leaves the model partly untrained, which leads to unpredictable behavior.

2.	Tailoring for Antibody Structures:
Antibodies, and especially their CDRH3 loops, often require capturing very fine, local structural details. One strategy is not necessarily to change the global architecture but rather to add specialized modules or adjust the attention mechanisms to focus more on local interactions in these regions. For example, you might incorporate a local refinement layer or tweak the geometric attention mechanism. That way, you retain the benefits of the pretrained ESM3 while enhancing its ability to model antibody-specific features.

3.	Trade-offs and Training Strategy:
If you opt for a custom configuration that significantly changes parameters (like increasing the resolution for local regions by modifying d_model or n_heads), you will need to train the model (or at least the modified parts) from scratch or with careful fine-tuning. This can be resource-intensive and might require a larger dataset that includes plenty of antibody examples.
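Point 1 can be checked without loading any weights. The sketch below compares the layer shapes a pretrained and a modified configuration would produce, using formulas read off the printed encoder further down (assumptions, not the ESM source). Any name it returns is a tensor that PyTorch's `load_state_dict` would reject with a size-mismatch error. `linear_shapes`, `incompatible_weights`, and the v_heads=256 variant are hypothetical.

```python
def linear_shapes(d_model: int, v_heads: int, d_out: int) -> dict:
    # (in_features, out_features) inferred from the printed pretrained StructureTokenEncoder.
    return {
        "geom_attn.proj": (d_model, 15 * v_heads),
        "geom_attn.out_proj": (3 * v_heads, d_model),
        "pre_vq_proj": (d_model, d_out),
    }


def incompatible_weights(pretrained: dict, custom: dict) -> list:
    """Names whose shapes differ, i.e. pretrained tensors that cannot be copied over."""
    return [name for name, shape in pretrained.items() if custom.get(name) != shape]


pretrained = linear_shapes(d_model=1024, v_heads=128, d_out=128)
# Hypothetical antibody variant: twice as many geometric heads, everything else unchanged.
custom = linear_shapes(d_model=1024, v_heads=256, d_out=128)
print(incompatible_weights(pretrained, custom))  # ['geom_attn.proj', 'geom_attn.out_proj']
```

Changing n_layers instead would not alter any shape; it adds or removes whole blocks, which shows up as missing or unexpected keys rather than size mismatches. Keeping the original shapes and adding a new module on top (point 2) keeps every pretrained tensor loadable.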

In [3]:
structure_tokenizer = StructureTokenizer()
structure_encoder = StructureTokenEncoder(
    d_model=1024, 
    n_heads=1, 
    v_heads=128, 
    n_layers=2, 
    d_out=128, 
    n_codes=4096
).train()

*Use the pretrained encoder instead and fine-tune from it.*

In [6]:
from esm.pretrained import ESM3_structure_encoder_v0
structure_tokenizer = StructureTokenizer()
encoder = ESM3_structure_encoder_v0(device="cuda")  # or "cpu"
encoder.train()  # if you need to fine-tune

StructureTokenEncoder(
  (transformer): GeometricEncoderStack(
    (blocks): ModuleList(
      (0-1): 2 x UnifiedTransformerBlock(
        (geom_attn): GeometricReasoningOriginalImpl(
          (s_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (proj): Linear(in_features=1024, out_features=1920, bias=True)
          (out_proj): Linear(in_features=384, out_features=1024, bias=True)
        )
        (ffn): Sequential(
          (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=1024, out_features=8192, bias=True)
          (2): SwiGLU()
          (3): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
    )
    (norm): Identity()
  )
  (pre_vq_proj): Linear(in_features=1024, out_features=128, bias=True)
  (codebook): EMACodebook()
  (relative_positional_embedding): RelativePositionEmbedding(
    (embedding): Embedding(66, 1024)
  )
)

### Validate that the encoder works on ESMProtein data

In [7]:
# Set the encoder to evaluation mode if you are only testing the encoding.
# (If you are fine-tuning, you'll later call .train())
# structure_encoder.eval()
encoder.eval()

for protein in esm_protein_list[0:3]:
    # Now, use the tokenize_structure utility to encode the protein.
    # The function expects the raw coordinates, the encoder, the tokenizer,
    # and the reference sequence (from the protein) as inputs.
    encoded_coords, plddt, structure_tokens = tokenize_structure(
        protein.coordinates,
        encoder,
        structure_tokenizer,
        reference_sequence=protein.sequence,
        add_special_tokens=True
    )

    # Print out shapes and tokens to verify the encoding.
    print("Encoded coordinates shape:", encoded_coords.shape)
    print("PLDDT shape:", plddt.shape)
    print("Structure tokens:", structure_tokens)

  with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):  # type: ignore


Encoded coordinates shape: torch.Size([226, 37, 3])
PLDDT shape: torch.Size([226])
Structure tokens: tensor([4098,  144,   53, 3314, 1650, 1582, 3484, 3599, 2168, 3245, 3670, 2834,
        2543, 1854, 2982, 3563, 1486, 2800, 1884,  597, 3748, 3881, 2952,  672,
        2420,  965, 3564,  897, 1585,  620, 1067,  156, 3304,  821, 2165, 2084,
        3075,  534, 3881,  747, 2227, 1600,  728, 3697, 3245, 1739, 2567,  539,
        2517,  795, 3054, 4063, 1582,  976,  318, 3554, 2552, 3619, 2640,  104,
        1429, 1842, 2938, 3915, 3372,   38, 1925, 2571,  747, 3748, 2408,  755,
        2372, 1475, 3300, 3146, 3750, 3867, 2534,  618,  178, 1568, 3090, 3831,
        2685, 2633, 1272,  280, 4028, 3462,  355, 1020,  554,  281, 3889, 2087,
        3889,  421,  471, 1310, 1117, 1579,  621,  988, 3189, 3589, 2904, 3034,
        1313, 3558, 2787, 3052, 1514, 2708, 4051, 1095, 1054, 2246, 1697, 3152,
        1143, 2791,  121, 1207,   58,    3, 3296, 1918, 2408, 3912, 1419, 2953,
        3366, 3563,

  with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):  # type: ignore
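The printed structure tokens mix codebook indices (0 to 4095 for n_codes=4096) with special tokens; the leading 4098 is one of them. The sketch below separates the two. It assumes the special-token ids follow the codebook as MASK=4096, EOS=4097, BOS=4098, PAD=4099, CHAINBREAK=4100. This is consistent with the 4101-row structure-token embeddings in the ESM3 and decoder printouts (4096 + 5), but check it against the constants in your esm version. `summarize_structure_tokens` is a hypothetical helper.

```python
# Assumed layout: 4096 codebook indices followed by 5 special tokens.
CODEBOOK_SIZE = 4096
SPECIAL_TOKENS = {4096: "MASK", 4097: "EOS", 4098: "BOS", 4099: "PAD", 4100: "CHAINBREAK"}


def summarize_structure_tokens(tokens):
    """Separate codebook indices from special tokens in a structure-token sequence."""
    codes = [t for t in tokens if t < CODEBOOK_SIZE]
    specials = [
        (i, SPECIAL_TOKENS.get(t, f"UNKNOWN({t})"))
        for i, t in enumerate(tokens)
        if t >= CODEBOOK_SIZE
    ]
    return {"n_codes": len(codes), "n_unique_codes": len(set(codes)), "specials": specials}


# First five tokens from the output above.
print(summarize_structure_tokens([4098, 144, 53, 3314, 1650]))
# {'n_codes': 4, 'n_unique_codes': 4, 'specials': [(0, 'BOS')]}
```

Applied to a full `structure_tokens` tensor (via `.tolist()`), the special-token positions show where `add_special_tokens=True` added markers around the residue tokens.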


In [8]:
def encode_protein(protein: ESMProtein):
    """
    Encodes an ESMProtein object into a tensor for structure fine-tuning.
    It calls tokenize_structure from the codebase, which returns
      - the processed atom37 coordinates (shape (L, 37, 3) in the output above),
      - per-residue pLDDT values,
      - the discrete structure tokens (codebook indices plus special tokens).
    Only the processed coordinates are returned here; the structure tokens,
    which are what ESM3's structure track embeds, are discarded.

    Assumes protein.coordinates holds the raw coordinate tensor and
    protein.sequence holds the protein sequence.

    Returning a tensor instead of an ESMProtein lets DataLoader's default
    collation handle each item, but stacking a batch still requires every
    protein in it to have the same length; otherwise a custom collate_fn
    that pads the batch is needed.
    """
    # tokenize_structure returns a tuple: (coordinates, plddt, structure_tokens).
    encoded_coords, plddt, structure_tokens = tokenize_structure(
        protein.coordinates,
        encoder,
        structure_tokenizer,
        reference_sequence=protein.sequence,
        add_special_tokens=True
    )

    # For this training setup, the processed coordinates are what the model consumes.
    return encoded_coords

•	`structure_encoder`:
The untrained encoder module (from esm/models/vqvae.py), which can be customized and trained from scratch. Its weights are randomly initialized, so it should not be used directly for fine-tuning.

•	`encoder`:
The structure encoder from the pretrained ESM3 model, loaded with its weights. It is called through ESM's `tokenize_structure` during data ingestion. The encoding step runs under `torch.no_grad()` (the line echoed in the warnings above), so the encoder's weights receive no gradients through this path.

•	`encode_protein` (wrapper function):
A higher-level function that calls ESM's `tokenize_structure` to convert an ESMProtein into the final encoded tensor.
This is the function passed to the Dataset below, so each protein is encoded as it is loaded.
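DataLoader's default collation stacks the items of a batch, which requires equal shapes. Proteins differ in length (the first one above has 226 residues), so with batch_size=2 the batch generally has to be padded by a custom collate_fn. A minimal pure-Python sketch of the padding step for 1-D structure-token sequences follows, assuming 4099 is the structure PAD id (check your esm version); `pad_batch` is hypothetical.

```python
PAD_ID = 4099  # assumed structure PAD token id


def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad variable-length token lists to a common length.

    Returns (padded, mask) where mask[i][j] is True for real tokens and
    False for padding, so a loss can ignore the padded positions.
    """
    max_len = max(len(s) for s in sequences)
    padded = [list(s) + [pad_id] * (max_len - len(s)) for s in sequences]
    mask = [[True] * len(s) + [False] * (max_len - len(s)) for s in sequences]
    return padded, mask


padded, mask = pad_batch([[4098, 144, 53], [4098, 3314]])
print(padded)  # [[4098, 144, 53], [4098, 3314, 4099]]
print(mask)    # [[True, True, True], [True, True, False]]
```

In the notebook, the same logic would sit inside a function passed as `DataLoader(..., collate_fn=...)`, converting `padded` and `mask` with `torch.tensor`. Coordinate tensors of shape (L, 37, 3) can be padded the same way along their first dimension.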


In [6]:
# # ---------------------
# # Define a custom Dataset no Encoder
# # ---------------------
# class ProteinStructureDataset(Dataset):
#     def __init__(self, pdb_directory, suffix):
#         """
#         Args:
#             pdb_directory (str): Directory where the PDB files are stored.
#             suffix (str): Suffix to filter files (e.g. "_train.pdb" or "_val.pdb").
#         """
#         self.pdb_directory = pdb_directory
#         self.pdb_files = [os.path.join(pdb_directory, f)
#                           for f in os.listdir(pdb_directory) if f.endswith(suffix)]
        
#     def __len__(self):
#         return len(self.pdb_files)
    
#     def __getitem__(self, idx):
#         pdb_path = self.pdb_files[idx]
#         protein = detect_and_process_structure(pdb_path)
#         if protein is None:
#             raise ValueError(f"Protein processing failed for {pdb_path}")
#         # Assume that after processing, protein.coordinates is a torch.Tensor
#         # representing the ground-truth structure coordinates.
#         gt_coords = protein.coordinates  # shape: (L, 37, 3) atom37 coordinates
#         return protein, gt_coords

In [9]:
# Revised Dataset that directly returns the encoded protein tensor
class ProteinStructureDataset(Dataset):
    def __init__(self, pdb_directory: str, suffix: str, encoder):
        """
        Args:
            pdb_directory (str): Directory where the PDB files are stored.
            suffix (str): Suffix to filter files (e.g. "_train.pdb" or "_val.pdb").
            encoder (callable): Function that converts an ESMProtein into a tensor.
        """
        self.pdb_directory = pdb_directory
        self.pdb_files = [
            os.path.join(pdb_directory, f)
            for f in os.listdir(pdb_directory) if f.endswith(suffix)
        ]
        self.encoder = encoder

    def __len__(self):
        return len(self.pdb_files)

    def __getitem__(self, idx):
        pdb_path = self.pdb_files[idx]
        # Use your existing function to process the PDB into an ESMProtein.
        protein = detect_and_process_structure(pdb_path)
        if protein is None:
            raise ValueError(f"Protein processing failed for {pdb_path}")
        # Ground truth coordinates remain the same.
        gt_coords = protein.coordinates  
        # Encode the protein into the tensor expected by the model.
        encoded_protein = self.encoder(protein)
        return encoded_protein, gt_coords
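`ProteinStructureDataset` raises ValueError whenever a file fails to process, which aborts an epoch partway through. One alternative is to process every file once up front and keep only the successes. The sketch below assumes a `process` callable (standing in for detect_and_process_structure followed by encode_protein) that returns None on failure; `PrefilteredDataset` and the dummy file names are hypothetical. In practice DataLoader accepts any map-style object with `__len__` and `__getitem__`, and subclassing Dataset as in the notebook works the same way.

```python
class PrefilteredDataset:
    """Map-style dataset that drops items whose processing fails, once, at construction."""

    def __init__(self, paths, process):
        self.items = []
        self.skipped = []
        for path in paths:
            item = process(path)  # assumed to return None on failure
            if item is None:
                self.skipped.append(path)
            else:
                self.items.append(item)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


# Dummy paths and a stand-in process function that "fails" on one file.
paths = ["1abc_val.pdb", "bad_val.pdb", "2xyz_val.pdb"]
ds = PrefilteredDataset(paths, process=lambda p: None if p.startswith("bad") else p.upper())
print(len(ds), ds.skipped)  # 2 ['bad_val.pdb']
```

The cost is memory and start-up time, since every structure is processed before training starts; the benefit is that failures are reported once and never interrupt an epoch.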

In [10]:
# ---------------------
# Create datasets and dataloaders
# ---------------------
train_directory = '/home/jupyter/DATA/hyperbind_train/sabdab/all_structures/train-test-split/'

train_dataset = ProteinStructureDataset(train_directory, "_train.pdb", encoder=encode_protein)
val_dataset   = ProteinStructureDataset(train_directory, "_val.pdb", encoder=encode_protein)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=2, shuffle=False)

In [9]:
# ---------------------
# Prepare the model for fine-tuning
# ---------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def create_esm3_model():
    from esm.pretrained import ESM3_sm_open_v0
    model = ESM3_sm_open_v0()  # Now returns only the model object
    
    return model

# Create the model and load the locally stored weights
model = create_esm3_model()
state_dict = torch.load(
    "/home/jupyter/DATA/model_weights/esm3_complete/esm3_sm_open_v1_state_dict.pt",
    map_location=device
)
model.load_state_dict(state_dict)
model.to(device)

# Put the model in training mode. .train() only changes layer behavior (e.g. dropout);
# parameters are trainable as long as requires_grad=True, which is the default.
model.train()

ESM3(
  (encoder): EncodeInputs(
    (sequence_embed): Embedding(64, 1536)
    (plddt_projection): Linear(in_features=16, out_features=1536, bias=True)
    (structure_per_res_plddt_projection): Linear(in_features=16, out_features=1536, bias=True)
    (structure_tokens_embed): Embedding(4101, 1536)
    (ss8_embed): Embedding(11, 1536)
    (sasa_embed): Embedding(19, 1536)
    (function_embed): ModuleList(
      (0-7): 8 x Embedding(260, 192, padding_idx=0)
    )
    (residue_embed): EmbeddingBag(1478, 1536, mode='sum', padding_idx=0)
  )
  (transformer): TransformerStack(
    (blocks): ModuleList(
      (0): UnifiedTransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=1536, out_features=4608, bias=False)
          )
          (out_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (q_ln): LayerNorm((1536,), eps=1e-05, el

In [12]:
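# After loading the model and before starting training, freeze the EMA codebook.
# The EMACodebook appears in the printed StructureTokenEncoder above and may not be
# a submodule of the ESM3 model itself, so search both.
from esm.layers.codebook import EMACodebook

for net in (model, encoder):
    for module in net.modules():
        if isinstance(module, EMACodebook):
            module.freeze_codebook = True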
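As its name suggests, EMACodebook updates its code vectors with an exponential moving average of the encoder outputs assigned to them, typically outside the optimizer, which is why freezing uses a flag rather than `requires_grad`. The sketch below is a simplified, illustrative EMA step for a single code (not ESM's implementation; the 0.99 decay is an assumption) showing what `freeze_codebook` is meant to prevent. `ema_update` is hypothetical.

```python
def ema_update(code, assigned, decay=0.99, frozen=False):
    """One simplified EMA step for a single codebook vector.

    code: current code vector; assigned: encoder outputs mapped to this code.
    The code moves a fraction (1 - decay) toward the mean of its assigned vectors,
    unless the codebook is frozen or nothing was assigned.
    """
    if frozen or not assigned:
        return list(code)
    dim = len(code)
    mean = [sum(v[d] for v in assigned) / len(assigned) for d in range(dim)]
    return [decay * c + (1 - decay) * m for c, m in zip(code, mean)]


code = [1.0, 0.0]
batch = [[0.0, 2.0], [0.0, 0.0]]            # mean is [0.0, 1.0]
print(ema_update(code, batch))               # drifts slightly toward the batch mean
print(ema_update(code, batch, frozen=True))  # unchanged
```

With a small fine-tuning set of antibodies, unfrozen EMA updates would slowly pull the pretrained codes toward that narrow distribution, changing what each structure token means to the downstream model.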

In [15]:
from esm.pretrained import ESM3_structure_decoder_v0

# This helper instantiates the structure head (decoder) with its pretrained configuration.
structure_head_pretrained = ESM3_structure_decoder_v0(device="cuda")
structure_head_pretrained.train()  # training mode; gradient computation is controlled by requires_grad, not .train()

StructureTokenDecoder(
  (embed): Embedding(4101, 1280)
  (decoder_stack): TransformerStack(
    (blocks): ModuleList(
      (0-29): 30 x UnifiedTransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=1280, out_features=3840, bias=False)
          )
          (out_proj): Linear(in_features=1280, out_features=1280, bias=False)
          (q_ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (k_ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (rotary): RotaryEmbedding()
        )
        (ffn): Sequential(
          (0): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=1280, out_features=7168, bias=False)
          (2): SwiGLU()
          (3): Linear(in_features=3584, out_features=1280, bias=False)
        )
      )
    )
    (norm): LayerNorm((1280,), eps=1e-05, elem

In [13]:
# ---------------------
# Setup training components
# ---------------------
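# Only the ESM3 model's parameters are optimized; the separately loaded `encoder`
# and `structure_head_pretrained` are not part of this optimizer.
optimizer = optim.Adam(model.parameters(), lr=1e-4)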
loss_fn = nn.MSELoss()

num_epochs = 5

In [14]:
# --------------
# Validate DataLoader
# --------------

# Get one batch from the train loader
batch = next(iter(train_loader))
encoded_batch, gt_coords = batch

print("Encoded batch shape:", encoded_batch.shape)
print("Ground truth coordinates shape:", gt_coords.shape)

✅ Detected 2 chains (['G', 'I']). Processing as a multimer.


  with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):  # type: ignore


AssertionError: Not implemented

In [16]:
# ---------------------
# Training loop
# ---------------------
# NOTE: forward_structure(model, encoded_batch) is not defined in this notebook.
# It must map a batch of encoded inputs to predicted coordinates with the same
# shape as gt_coords, otherwise MSELoss cannot compare them.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for encoded_batch, gt_coords in train_loader:
        # The revised Dataset returns (encoded tensor, coordinates); default_collate
        # has already stacked gt_coords into a single tensor.
        gt_coords = gt_coords.to(device).float()

        # Forward pass: obtain predicted structure coordinates.
        pred_coords = forward_structure(model, encoded_batch).to(device)

        loss = loss_fn(pred_coords, gt_coords)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}")

    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for encoded_batch, gt_coords in val_loader:
            gt_coords = gt_coords.to(device).float()
            pred_coords = forward_structure(model, encoded_batch).to(device)
            val_loss += loss_fn(pred_coords, gt_coords).item()
    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch+1}/{num_epochs} - Val Loss: {avg_val_loss:.4f}")

print("Fine-tuning complete!")

✅ Detected 2 chains (['D', 'C']). Processing as a multimer.
✅ Detected 2 chains (['H', 'L']). Processing as a multimer.


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'esm.sdk.api.ESMProtein'>
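Even with `forward_structure` in place, the loss assumes every ground-truth coordinate is a finite number. In the atom37 layout ((L, 37, 3) above) most residues leave many atom slots empty; if those slots hold NaN, `nn.MSELoss` over the raw tensors returns NaN. A minimal pure-Python sketch of a masked MSE follows: it skips atoms with non-finite targets and, like `nn.MSELoss`, averages over the remaining coordinate values. `masked_mse` is hypothetical; in the notebook the same idea would be `mask = torch.isfinite(gt_coords).all(-1)` followed by `loss_fn(pred_coords[mask], gt_coords[mask])`.

```python
import math


def masked_mse(pred, target):
    """Mean squared error over atoms whose target (x, y, z) values are all finite.

    pred, target: sequences of (x, y, z) triples. Missing atoms (non-finite
    target) are skipped instead of turning the whole loss into NaN.
    """
    total, count = 0.0, 0
    for p, t in zip(pred, target):
        if all(math.isfinite(c) for c in t):
            total += sum((pc - tc) ** 2 for pc, tc in zip(p, t))
            count += len(t)
    return total / count if count else float("nan")


nan = float("nan")
pred = [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (5.0, 5.0, 5.0)]
target = [(0.0, 0.0, 1.0), (1.0, 1.0, 1.0), (nan, nan, nan)]
print(masked_mse(pred, target))  # 1/6: the NaN atom is skipped
```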