In [2]:
import sys
import os

# Assume the current working directory is "HyperBind2-OpenSource"
project_root = os.path.abspath('/home/jupyter/HyperBind2-OpenSource')

In [3]:
# Add the model directory to sys.path
model_path = os.path.join(project_root, "scripts", "model")
if model_path not in sys.path:
    sys.path.insert(0, model_path)

from antibody_structure_encoder import AntibodyStructureEncoder

print("Imported AntibodyStructureEncoder from antibody_structure_encoder:", AntibodyStructureEncoder)

Imported AntibodyStructureEncoder from antibody_structure_encoder: <class 'antibody_structure_encoder.AntibodyStructureEncoder'>


In [4]:
import torch

from antibody_structure_encoder import AntibodyStructureEncoder
from esm.sdk.api import ESMProtein

In [6]:
# Create a dummy antibody structure (ESMProtein object)
dummy_antibody = ESMProtein()

# For this example, we create dummy data.
# In practice, you would load a real antibody structure via your parser.
dummy_length = 100  # example protein length
dummy_antibody.sequence = "A" * dummy_length  # A simple sequence of alanines
# Create dummy coordinates with shape (1, L, 37, 3) as expected by tokenize_structure.
dummy_antibody.coordinates = torch.randn(dummy_length, 37, 3)

# Instantiate the singleton encoder.a
encoder = AntibodyStructureEncoder(device="cuda")
# Optionally, set the encoder to evaluation mode to avoid certain training-specific branches.
encoder.encoder.eval()

# Use the encoder to encode the dummy antibody.
try:
    encoded_coords = encoder.encode(dummy_antibody)
    print("Encoded coordinates type:", type(encoded_coords))
    print("Encoded coordinates shape:", encoded_coords.shape)
except Exception as e:
    print("Error during encoding:", e)

  with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):  # type: ignore


Encoded coordinates type: <class 'torch.Tensor'>
Encoded coordinates shape: torch.Size([102, 37, 3])


In [33]:
#!/usr/bin/env python3
"""
Module for encoding and ingesting antibody structure data for fine-tuning.

This file includes:
  - AntibodyStructureDataset: A PyTorch Dataset for processing PDB files.
  - create_dataloaders: A convenience function to build training and validation DataLoaders.
  - custom_collate_fn: A custom collate function to handle variable-length tensors.
  
Usage:
  python antibody_structure_ingestion.py --pdb_dir <PDB_DIRECTORY> [--batch_size 2]
"""

import os
import sys
import argparse
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Import necessary functions and classes from the codebase.
from esm.sdk.api import ESMProtein
from antibody_structure_encoder import AntibodyStructureEncoder  # Import the singleton antibody encoder

# Instantiate the singleton encoder once.
# This instance will be reused throughout the application.
encoder_instance = AntibodyStructureEncoder(device="cuda")

sys.path.insert(0, os.path.abspath(os.path.join('..', '..', 'scripts', 'data_ingestion', 'antibody_structure_ingestion')))
from pdb2esm import read_monomer_structure, read_multimer_structure, detect_and_process_structure

class AntibodyStructureDataset(Dataset):
    """
    Custom Dataset for loading and encoding antibody structures.

    Each item is a tuple of (encoded_antibody, ground_truth_coordinates).
    """
    def __init__(self, pdb_directory: str, suffix: str, encoder):
        """
        Initializes the dataset by scanning for PDB files with a given suffix.

        Args:
            pdb_directory (str): Directory containing PDB files.
            suffix (str): File suffix to filter PDB files (e.g., "_train.pdb" or "_val.pdb").
            encoder: An instance of AntibodyStructureEncoder (singleton) to encode an antibody.
        """
        self.pdb_directory = pdb_directory
        self.pdb_files = [
            os.path.join(pdb_directory, f)
            for f in os.listdir(pdb_directory) if f.endswith(suffix)
        ]
        self.encoder = encoder

    def __len__(self):
        """Returns the number of PDB files found."""
        return len(self.pdb_files)

    def __getitem__(self, idx):
        """
        Retrieves and processes a single PDB file.

        Args:
            idx (int): Index of the sample.

        Returns:
            tuple: (encoded_antibody, ground_truth_coordinates)
        """
        pdb_path = self.pdb_files[idx]
        # Process the PDB file into an ESMProtein (antibody) object using the custom parser.
        antibody = detect_and_process_structure(pdb_path)
        if antibody is None:
            raise ValueError(f"Antibody processing failed for {pdb_path}")
        # Ground truth coordinates from the processed antibody.
        gt_coords = antibody.coordinates
        # Encode the antibody using the singleton encoder instance.
        encoded_antibody = self.encoder.encode(antibody)
        return encoded_antibody, gt_coords

def custom_collate_fn(batch):
    """
    Custom collate function to handle variable-length tensors by padding them.
    
    Args:
        batch (list): A list of tuples (encoded_antibody, ground_truth_coordinates).
    
    Returns:
        tuple: Padded tensors for encoded antibodies and ground truth coordinates.
    """
    # Unzip the batch.
    encoded_list, coords_list = zip(*batch)
    # Pad the list of tensors along the variable dimension (assumed dim 0).
    padded_encoded = pad_sequence(encoded_list, batch_first=True, padding_value=0)
    padded_coords = pad_sequence(coords_list, batch_first=True, padding_value=0)
    return padded_encoded, padded_coords

def create_dataloaders(pdb_directory: str, batch_size: int = 2):
    """
    Creates DataLoaders for training and validation datasets.

    Args:
        pdb_directory (str): Directory containing PDB files.
        batch_size (int): Batch size for DataLoader.

    Returns:
        tuple: (train_loader, val_loader) for training and validation data.
    """
    train_dataset = AntibodyStructureDataset(pdb_directory, "_train.pdb", encoder=encoder_instance)
    val_dataset = AntibodyStructureDataset(pdb_directory, "_val.pdb", encoder=encoder_instance)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)
    return train_loader, val_loader

In [34]:
pdb_directory='/home/jupyter/DATA/hyperbind_train/sabdab/all_structures/train-test-split/'
train_loader, val_loader = create_dataloaders(pdb_directory, batch_size=2)

In [35]:
for item in train_loader:
    print(item)
    break

✅ Detected 2 chains (['K', 'L']). Processing as a multimer.
✅ Detected 2 chains (['H', 'L']). Processing as a multimer.
(tensor([[[[     inf,      inf,      inf],
          [     inf,      inf,      inf],
          [     inf,      inf,      inf],
          ...,
          [     inf,      inf,      inf],
          [     inf,      inf,      inf],
          [     inf,      inf,      inf]],

         [[  3.6154,  18.4525,  -6.7550],
          [  4.9069,  18.2309,  -7.3973],
          [  5.0126,  16.7958,  -7.9026],
          ...,
          [     nan,      nan,      nan],
          [     nan,      nan,      nan],
          [     nan,      nan,      nan]],

         [[  5.1713,  15.8376,  -6.9893],
          [  5.1827,  14.4347,  -7.3805],
          [  3.7492,  14.0737,  -7.7474],
          ...,
          [     nan,      nan,      nan],
          [     nan,      nan,      nan],
          [     nan,      nan,      nan]],

         ...,

         [[-14.4422,  10.9439,  18.5214],
          [-15.

  with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):  # type: ignore
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):  # type: ignore
