# Task Details
**Objective**: Create a new residue contact prediction model that leverages both the ESM2 embeddings of the
input sequence residues and structural data from similar sequences. \n
**Implementation**:
* Implement your model in Python, modifying or extending the ESM2 model to incorporate
additional structural information. The input of the model is a single sequence, the output is a binary
contact map (1 - if there is a contact between the residue pair, 0 - if there is no contact).
* Use an open-source protein database, such as the Protein Data Bank (PDB), to access sequences
and their 3D structural data for training and evaluation.
* Ensure that the dataset is split into training, validation, and test sets, with a proper handling of
sequence similarity to avoid train-test data leakage.

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import EsmTokenizer, EsmModel
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
import requests
from prody import parsePDB, buildDistMatrix
from Bio.PDB.Polypeptide import three_to_one
from sklearn.metrics import accuracy_score, f1_score
from scipy.sparse import csr_matrix, vstack
import joblib
from sklearn.linear_model import SGDClassifier

Let's use esm2_t6_8M_UR50D model.

In [2]:
model_checkpoint = "facebook/esm2_t6_8M_UR50D"

Let's be specify the range of proteins. We will pick in cell proteins in human, asymetric, with alpha and beta structures.

In [3]:
with (open("rcsb_pdb_ids_cell_HS_mem_asym_ab.txt") as f):
    pdb_ids = f.read()

pdb_ids = pdb_ids.split(',')
len(pdb_ids)

9090

Due to technical limitations let's test training pipeline on 20 proteins sample. With following filtering out proteins with mor than 10000 atoms.

In [4]:
np.random.seed(42)
pdb_ids = np.random.choice(pdb_ids,20)

In [5]:
def download_pdb_files(pdb_ids, download_dir='pdb_structures'):
    """
    Downloads PDB structures given a list of PDB IDs using the requests library.
    """
    os.makedirs(download_dir, exist_ok=True)
    for pdb_id in pdb_ids:
        entry_id = pdb_id.split('_')[0]
        file_path = os.path.join(download_dir, f'{entry_id}.pdb')
        if os.path.exists(file_path):
            continue
        url = f'https://files.rcsb.org/download/{entry_id}.pdb'
        try:
            response = requests.get(url)
            response.raise_for_status()
            file_path = os.path.join(download_dir, f'{entry_id}.pdb')
            with open(file_path, 'w') as f:
                f.write(response.text)
            print(f"Downloaded {entry_id}.pdb")
        except requests.exceptions.RequestException as e:
            print(f"Failed to download {entry_id}: {e}")

In [6]:
download_pdb_files(pdb_ids)

Let's load sequences and labels.

In [None]:
def extract_sequence_and_contacts(pdb_file):
    """
    Extracts the amino acid sequence and computes the contact map from a PDB file using ProDy.
    """
    try:
        structure = parsePDB(pdb_file)
        ca_atoms = structure.select('name CA')

        if ca_atoms is None:
            print(f"No CA atoms found in {pdb_file}")
            return None, None

        if structure.numAtoms() > 10000:
            return None, None

        # Extract sequence
        residues = ca_atoms.getResnames()
        seq = ''
        for resname in residues:
            try:
                one_letter = three_to_one(resname)
                seq += one_letter
            except KeyError:
                seq += 'X'  # Unknown amino acid

        # Compute contact map
        coords = ca_atoms.getCoords()
        num_residues = len(coords)
        contact_map = np.zeros((num_residues, num_residues), dtype=np.float32)

        dist_matrix = buildDistMatrix(coords, coords)
        contact_map = (dist_matrix < 8.0).astype(np.float32)  # Contact threshold

        return seq, contact_map
    except Exception as e:
        print(f"Error processing {pdb_file}: {e}")
        return None, None

In [8]:
sequences = []
contact_maps = []
pdb_files = [os.path.join('pdb_structures', f'{pdb_id.split("_")[0]}.pdb') for pdb_id in pdb_ids]

# Check if files exist
existing_pdb_files = []
for pdb_file in pdb_files:
    if os.path.exists(pdb_file):
        existing_pdb_files.append(pdb_file)
    else:
        print(f"File not found: {pdb_file}")

for pdb_file in existing_pdb_files:
    seq, cmap = extract_sequence_and_contacts(pdb_file)
    if seq is not None and cmap is not None:
        sequences.append(seq)
        contact_maps.append(cmap)

@> 4369 atoms and 1 coordinate set(s) were parsed in 0.03s.
@> 5396 atoms and 1 coordinate set(s) were parsed in 0.03s.
@> 2388 atoms and 1 coordinate set(s) were parsed in 0.01s.
@> 5577 atoms and 1 coordinate set(s) were parsed in 0.03s.
@> 5871 atoms and 1 coordinate set(s) were parsed in 0.03s.
@> 2054 atoms and 1 coordinate set(s) were parsed in 0.01s.
@> 4979 atoms and 1 coordinate set(s) were parsed in 0.02s.
@> 1007 atoms and 1 coordinate set(s) were parsed in 0.01s.
@> 2865 atoms and 1 coordinate set(s) were parsed in 0.02s.
@> 11859 atoms and 1 coordinate set(s) were parsed in 0.07s.
@> 2764 atoms and 1 coordinate set(s) were parsed in 0.01s.
@> 1039 atoms and 1 coordinate set(s) were parsed in 0.01s.
@> 2552 atoms and 1 coordinate set(s) were parsed in 0.01s.
@> 2141 atoms and 1 coordinate set(s) were parsed in 0.01s.
@> 5677 atoms and 1 coordinate set(s) were parsed in 0.03s.
@> 12818 atoms and 1 coordinate set(s) were parsed in 0.07s.
@> 11675 atoms and 1 coordinate set(s)

Then let's split data and make embeddings.

In [None]:
def get_esm2_embeddings(sequences):
    """
    Obtains ESM2 embeddings for a list of sequences using Hugging Face Transformers.
    """
    # Load the ESM2 model and tokenizer from the Hugging Face Hub
    model_name = model_checkpoint
    tokenizer = EsmTokenizer.from_pretrained(model_name)
    model = EsmModel.from_pretrained(model_name)
    model.eval()

    embeddings = []

    for seq in tqdm(sequences, desc="Computing ESM2 embeddings"):
        # Tokenize the sequence
        inputs = tokenizer(seq, return_tensors='pt', add_special_tokens=False)
        with torch.no_grad():
            outputs = model(**inputs)
            # Extract the embeddings from the last hidden state
            token_embeddings = outputs.last_hidden_state  # (1, sequence_length, hidden_size)
            embedding = token_embeddings[0, :, :]  # (sequence_length, hidden_size)
            embeddings.append(embedding.numpy().astype(np.float32))  # Convert to float32

    return embeddings

In [10]:
# Simple data splitting without clustering
np.random.seed(42)
num_sequences = len(sequences)
indices = list(range(num_sequences))
np.random.shuffle(indices)

train_end = int(0.7 * num_sequences)
val_end = int(0.85 * num_sequences)

train_indices = indices[:train_end]
val_indices = indices[train_end:val_end]
test_indices = indices[val_end:]

train_sequences = [sequences[i] for i in train_indices]
val_sequences = [sequences[i] for i in val_indices]
test_sequences = [sequences[i] for i in test_indices]

train_contact_maps = [contact_maps[i] for i in train_indices]
val_contact_maps = [contact_maps[i] for i in val_indices]
test_contact_maps = [contact_maps[i] for i in test_indices]

# Verify dataset lengths
print(f"Number of training samples: {len(train_sequences)}")
print(f"Number of validation samples: {len(val_sequences)}")
print(f"Number of test samples: {len(test_sequences)}")

Number of training samples: 11
Number of validation samples: 3
Number of test samples: 3


In [11]:
# Get ESM2 embeddings
train_embeddings = get_esm2_embeddings(train_sequences)
val_embeddings = get_esm2_embeddings(val_sequences)
test_embeddings = get_esm2_embeddings(test_sequences)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Computing ESM2 embeddings:   0%|          | 0/11 [00:00<?, ?it/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Computing ESM2 embeddings:   0%|          | 0/3 [00:00<?, ?it/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Computing ESM2 embeddings:   0%|          | 0/3 [00:00<?, ?it/s]

In [12]:
class ProteinDataset(Dataset):
    def __init__(self, sequences, contact_maps, esm_embeddings):
        self.sequences = sequences
        self.contact_maps = contact_maps
        self.esm_embeddings = esm_embeddings

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        embedding = self.esm_embeddings[idx]
        contact_map = self.contact_maps[idx]

        # Convert to tensors with float32 dtype
        embedding = torch.tensor(embedding, dtype=torch.float32)
        contact_map = torch.tensor(contact_map, dtype=torch.float32)

        # Ensure dimensions match
        L_emb = embedding.size(0)
        L_contact = contact_map.size(0)
        assert L_emb == L_contact, f"Dimension mismatch: L_emb={L_emb}, L_contact={L_contact}"

        return embedding, contact_map

In [13]:
# Create datasets
train_dataset = ProteinDataset(train_sequences, train_contact_maps, train_embeddings)
val_dataset = ProteinDataset(val_sequences, val_contact_maps, val_embeddings)
test_dataset = ProteinDataset(test_sequences, test_contact_maps, test_embeddings)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1)
test_loader = DataLoader(test_dataset, batch_size=1)

In [14]:
# Verify dataloader lengths
print(f"Length of train_loader: {len(train_loader)}")
print(f"Length of val_loader: {len(val_loader)}")
print(f"Length of test_loader: {len(test_loader)}")

Length of train_loader: 11
Length of val_loader: 3
Length of test_loader: 3


In [15]:
class ContactPredictionModel(nn.Module):
    def __init__(self, embedding_dim):
        super(ContactPredictionModel, self).__init__()
        self.embedding_dim = embedding_dim

        # The input channels will be 2D (concatenated embeddings)
        self.conv1 = nn.Conv2d(in_channels=embedding_dim * 2, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, embedding):
        """
        embedding: Tensor of shape (batch_size, L, D)
        """
        batch_size, L, D = embedding.size()

        # Compute pairwise features without using expand
        emb_i = embedding.unsqueeze(2)  # (batch_size, L, 1, D)
        emb_j = embedding.unsqueeze(1)  # (batch_size, 1, L, D)

        # Use broadcasting to compute pairwise concatenation
        emb_i = emb_i.repeat(1, 1, L, 1)  # (batch_size, L, L, D)
        emb_j = emb_j.repeat(1, L, 1, 1)  # (batch_size, L, L, D)

        pairwise_emb = torch.cat([emb_i, emb_j], dim=-1)  # (batch_size, L, L, 2D)

        # Transpose to match Conv2d input: (batch_size, channels, height, width)
        x = pairwise_emb.permute(0, 3, 1, 2).contiguous()  # (batch_size, 2D, L, L)

        # Pass through convolutional layers
        x = self.conv1(x)
        x = nn.ReLU()(x)
        x = self.conv2(x)
        x = self.sigmoid(x)
        x = x.squeeze(1)  # Output shape: (batch_size, L, L)

        return x

In [None]:
def train_model(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0

    for embedding, contact_map in tqdm(dataloader, desc="Training"):
        embedding = embedding.to(device)           # Shape: (batch_size, L, D)
        contact_map = contact_map.to(device)       # Shape: (batch_size, L, L)

        optimizer.zero_grad()
        outputs = model(embedding)        # Shape: (batch_size, L, L)
        loss = criterion(outputs, contact_map)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_targets = []
    all_predictions = []

    with torch.no_grad():
        for embedding, contact_map in tqdm(dataloader, desc="Evaluating"):
            embedding = embedding.to(device)           # Shape: (batch_size, L, D)
            contact_map = contact_map.to(device)       # Shape: (batch_size, L, L)

            outputs = model(embedding)        # Shape: (batch_size, L, L)
            loss = criterion(outputs, contact_map)
            total_loss += loss.item()

            # Convert outputs and targets to binary predictions
            preds = (outputs >= 0.5).float()  # Threshold at 0.5
            all_predictions.append(preds.cpu())
            all_targets.append(contact_map.cpu())

    # Concatenate all predictions and targets using reshape
    all_predictions = torch.cat([p.reshape(-1) for p in all_predictions])
    all_targets = torch.cat([t.reshape(-1) for t in all_targets])

    # Compute metrics
    accuracy = accuracy_score(all_targets.numpy(), all_predictions.numpy())
    f1 = f1_score(all_targets.numpy(), all_predictions.numpy())

    avg_loss = total_loss / len(dataloader)
    return avg_loss, accuracy, f1

In [17]:
# Initialize model, criterion, and optimizer
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

# Determine embedding dimension based on the ESM2 model output
embedding_dim = train_embeddings[0].shape[-1]  # D

In [18]:
# Initialize the model
model = ContactPredictionModel(embedding_dim).to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train_model(model, train_loader, criterion, optimizer, device)
    val_loss, val_accuracy, val_f1 = evaluate_model(model, val_loader, criterion, device)
    print(f"Epoch {epoch+1}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}, "
            f"Val Accuracy = {val_accuracy:.4f}, Val F1-score = {val_f1:.4f}")

# Test the model
test_loss, test_accuracy, test_f1 = evaluate_model(model, test_loader, criterion, device)
print(f"Test Loss = {test_loss:.4f}, Test Accuracy = {test_accuracy:.4f}, Test F1-score = {test_f1:.4f}")


Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1/10: Train Loss = 0.5079, Val Loss = 0.3618, Val Accuracy = 0.9804, Val F1-score = 0.0000


Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 2/10: Train Loss = 0.2235, Val Loss = 0.2526, Val Accuracy = 0.9804, Val F1-score = 0.0000


Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 3/10: Train Loss = 0.1581, Val Loss = 0.2409, Val Accuracy = 0.9804, Val F1-score = 0.0000


Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 4/10: Train Loss = 0.1500, Val Loss = 0.2454, Val Accuracy = 0.9804, Val F1-score = 0.0000


Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 5/10: Train Loss = 0.1502, Val Loss = 0.2452, Val Accuracy = 0.9804, Val F1-score = 0.0000


Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 6/10: Train Loss = 0.1487, Val Loss = 0.2429, Val Accuracy = 0.9804, Val F1-score = 0.0000


Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 7/10: Train Loss = 0.1474, Val Loss = 0.2406, Val Accuracy = 0.9804, Val F1-score = 0.0000


Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 8/10: Train Loss = 0.1468, Val Loss = 0.2398, Val Accuracy = 0.9804, Val F1-score = 0.0000


Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 9/10: Train Loss = 0.1466, Val Loss = 0.2400, Val Accuracy = 0.9804, Val F1-score = 0.0000


Training:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 10/10: Train Loss = 0.1462, Val Loss = 0.2413, Val Accuracy = 0.9804, Val F1-score = 0.0000


Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Test Loss = 0.1171, Test Accuracy = 0.9807, Test F1-score = 0.0000


Additional structural information could be incorporated the following way:

In [19]:
def find_similar_structures(sequence, sequences, contact_maps, threshold=0.3):
    """
    Finds structural data from sequences similar to the input sequence.
    """
    similar_structs = []
    seq_len = len(sequence)
    for seq, cmap in zip(sequences, contact_maps):
        min_len = min(len(sequence), len(seq))
        matches = sum(a == b for a, b in zip(sequence[:min_len], seq[:min_len]))
        similarity = matches / min_len
        if similarity > threshold:
            # Resize contact map to match the length of the input sequence
            cmap_resized = cmap[:seq_len, :seq_len]
            similar_structs.append(cmap_resized)

    # Aggregate structural information (e.g., mean contact map)
    if similar_structs:
        aggregated_struct = np.mean(similar_structs, axis=0).astype(np.float32)  # Ensure dtype is float32
        # Ensure the aggregated structure has the correct dimensions
        if aggregated_struct.shape[0] != seq_len:
            aggregated_struct = aggregated_struct[:seq_len, :seq_len]
    else:
        # Return a zero matrix if no similar structures are found
        aggregated_struct = np.zeros((seq_len, seq_len), dtype=np.float32)

    return aggregated_struct

Above code uses only 20 sample due to technical limitations. There are several methods to overcome memory overload:
* Switch to 16-Bit Precision
* Use out-of-core processing with libraries like Dask