# Baseline Fine-tuning For Comparison

## Data setup

In [2]:
import torch
from transformers import AutoTokenizer
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, random_split

In [3]:
input_dir = Path('yeast_data')
df = pd.read_csv(str(input_dir / 'avrpikC_full.csv'))
df

Unnamed: 0,aa_sequence,enrichment_score
0,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,1.468796
1,GLKRIIVIKVAREGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV...,1.415944
2,GLKRIIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV...,1.389615
3,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,1.359651
4,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV...,1.343857
...,...,...
3955,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,-1.041749
3956,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEA...,-1.041749
3957,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,-1.057543
3958,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKTEV...,-1.057543


In [4]:
class BindingDataset(Dataset):
    """Dataset for the final binding prediction task."""
    def __init__(self, dataframe):
        self.df = dataframe
        # Ensure the columns exist
        assert 'aa_sequence' in self.df.columns, "DataFrame must have 'aa_sequence' column."
        assert 'enrichment_score' in self.df.columns, "DataFrame must have 'enrichment_score' column."
        
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sequence = self.df.iloc[idx]['aa_sequence']
        # The label is a single float value
        label = torch.tensor(self.df.iloc[idx]['enrichment_score'], dtype=torch.float)

        # Tokenize the sequence
        tokenized_output = self.tokenizer(
            sequence,
            max_length=160, # Use a fixed max length. 158 residues plus 2 extra tokens
            return_tensors='pt'
        )
        
        # Return input_ids: attention masks, removing batch dimension
        inputs = {key: val.squeeze(0) for key, val in tokenized_output.items()}
        
        return inputs, label

In [5]:
full_dataset = BindingDataset(df)

train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

## Baseline 1: Fine-tuning the full ESM2 8M model
This is the default method that Alexander used, though he used the huggingface transformer library, so results are slightly different.

In [None]:
import torch
from torch import nn
from torch.optim import AdamW
from transformers import EsmForSequenceClassification
from scipy.stats import spearmanr
import tqdm

# --- SETUP ---
torch.manual_seed(42)
torch.cuda.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- MODEL INITIALIZATION ---
print("Initializing ESM-2 8M for Sequence Classification...")

model = EsmForSequenceClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    num_labels=1,                # We are predicting one continuous value.
    problem_type="regression"    # Configure the model for regression.
).to(device)

# --- LOSS AND OPTIMIZER ---
# We will let the model calculate its own loss during training, but define it for validation
loss_function = nn.MSELoss() 

# The optimizer will automatically ignore frozen parameters
optimizer = AdamW(model.parameters(), lr=1e-4)

# Count trainable parameters to confirm the base is frozen
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch + 1}/{num_epochs}\n----------------------------')
    # =======================================
    #               TRAINING
    # =======================================
    model.train()
    total_train_loss = 0
    for inputs, labels in tqdm(train_loader, desc=f"[Training]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        # The model automatically calculates loss when labels are provided
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    
    # =======================================
    #              VALIDATION
    # =======================================
    model.eval()
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"[Validation]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            
            # Get model predictions (logits)
            outputs = model(**inputs)
            predictions = outputs.logits
            
            # Calculate validation loss manually
            total_val_loss += loss_function(predictions.squeeze(), labels).item()
            
            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu())
            epoch_labels.append(labels.cpu())
    
    avg_val_loss = total_val_loss / len(val_loader)
    
    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Training Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")


Using device: cuda
Initializing ESM-2 35M for Sequence Classification...


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total trainable parameters: 7,840,442
Starting fine-tuning for 5 epochs...

Epoch 1/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:11<00:00, 18.69it/s]
[Validation]: 100%|██████████| 25/25 [00:00<00:00, 55.55it/s]


Training Loss: 0.1978 | Validation Loss: 0.1723 | Spearman Correlation: 0.5596

Epoch 2/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:14<00:00, 15.11it/s]
[Validation]: 100%|██████████| 25/25 [00:00<00:00, 41.27it/s]


Training Loss: 0.1420 | Validation Loss: 0.1197 | Spearman Correlation: 0.6244

Epoch 3/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:11<00:00, 19.16it/s]
[Validation]: 100%|██████████| 25/25 [00:00<00:00, 52.15it/s]


Training Loss: 0.1210 | Validation Loss: 0.1192 | Spearman Correlation: 0.6939

Epoch 4/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:15<00:00, 14.38it/s]
[Validation]: 100%|██████████| 25/25 [00:00<00:00, 33.62it/s]


Training Loss: 0.1044 | Validation Loss: 0.0923 | Spearman Correlation: 0.7229

Epoch 5/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:18<00:00, 11.80it/s]
[Validation]: 100%|██████████| 25/25 [00:00<00:00, 48.93it/s]

Training Loss: 0.0966 | Validation Loss: 0.0859 | Spearman Correlation: 0.7586





## Baseline 2: 35M ESM2-only feature extractor
For a simple comparison, what happens when you only use ESM2 as the feature extractor for the binding head above. This uses the special class token.

In [None]:
import torch
from torch import nn
from torch.optim import AdamW
from transformers import AutoTokenizer, EsmForSequenceClassification
import tqdm
from scipy.stats import spearmanr


# --- SETUP ---
torch.manual_seed(42)
torch.cuda.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- MODEL INITIALIZATION ---
print("Initializing ESM-2 35M for Sequence Classification...")

model = EsmForSequenceClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D",
    num_labels=1,                # We are predicting one continuous value.
    problem_type="regression"    # Configure the model for regression.
).to(device)

# --- FREEZE THE BASE MODEL (for fair comparison) ---
print("Freezing the ESM-2 base model layers...")
for name, param in model.named_parameters():
    if name.startswith("esm."): # This freezes all parameters of the main ESM body
        param.requires_grad = False

# --- LOSS AND OPTIMIZER ---
# We will let the model calculate its own loss during training, but define it for validation
loss_function = nn.MSELoss() 

# The optimizer will automatically ignore frozen parameters
optimizer = AdamW(model.parameters(), lr=1e-4)

# Count trainable parameters to confirm the base is frozen
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch + 1}/{num_epochs}\n----------------------------')
    # =======================================
    #               TRAINING
    # =======================================
    model.train()
    total_train_loss = 0
    for inputs, labels in tqdm(train_loader, desc=f"[Training]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        # The model automatically calculates loss when labels are provided
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    
    # =======================================
    #              VALIDATION
    # =======================================
    model.eval()
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"[Validation]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            
            # Get model predictions (logits)
            outputs = model(**inputs)
            predictions = outputs.logits
            
            # Calculate validation loss manually
            total_val_loss += loss_function(predictions.squeeze(), labels).item()
            
            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu())
            epoch_labels.append(labels.cpu())
    
    avg_val_loss = total_val_loss / len(val_loader)
    
    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Training Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")


Using device: cuda
Initializing ESM-2 35M for Sequence Classification...


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Freezing the ESM-2 base model layers...
Total trainable parameters: 231,361
Starting fine-tuning for 5 epochs...

Epoch 1/5
----------------------------


[Train]: 100%|██████████| 223/223 [00:05<00:00, 37.48it/s] 
[Val]: 100%|██████████| 25/25 [00:00<00:00, 29.42it/s]


Training Loss: 0.2091 | Validation Loss: 0.1943 | Spearman Correlation: 0.5128

Epoch 2/5
----------------------------


[Train]: 100%|██████████| 223/223 [00:08<00:00, 26.61it/s]
[Val]: 100%|██████████| 25/25 [00:00<00:00, 26.79it/s]


Training Loss: 0.2015 | Validation Loss: 0.1979 | Spearman Correlation: 0.5311

Epoch 3/5
----------------------------


[Train]: 100%|██████████| 223/223 [00:08<00:00, 27.20it/s]
[Val]: 100%|██████████| 25/25 [00:00<00:00, 25.71it/s]


Training Loss: 0.1930 | Validation Loss: 0.1779 | Spearman Correlation: 0.5412

Epoch 4/5
----------------------------


[Train]: 100%|██████████| 223/223 [00:05<00:00, 42.07it/s]
[Val]: 100%|██████████| 25/25 [00:00<00:00, 28.01it/s]


Training Loss: 0.1841 | Validation Loss: 0.1641 | Spearman Correlation: 0.5499

Epoch 5/5
----------------------------


[Train]: 100%|██████████| 223/223 [00:08<00:00, 27.57it/s]
[Val]: 100%|██████████| 25/25 [00:00<00:00, 26.22it/s]

Training Loss: 0.1734 | Validation Loss: 0.1513 | Spearman Correlation: 0.5563





## Baseline 3: 35M ESM2 Model with Last Two Layers Unfrozen
This is to test another strategy for fine-tuning that was detailed in the ESMEffect paper.

In [None]:
from transformers import AutoTokenizer, EsmForSequenceClassification

# --- SETUP ---
torch.manual_seed(42)
torch.cuda.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- MODEL INITIALIZATION ---
print("Initializing ESM-2 35M for Sequence Classification...")

model = EsmForSequenceClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D",
    num_labels=1,                # We are predicting one continuous value.
    problem_type="regression"    # Configure the model for regression.
).to(device)

# FREEZE THE BASE, UNFREEZE LAST TWO LAYERS ---
print("Freezing the ESM-2 base model and unfreezing the final 2 layers...")

# 1. First, freeze all parameters of the entire base model
for param in model.esm.parameters():
    param.requires_grad = False

# 2. Now, unfreeze only the parameters of the last two transformer layers
# The ESM-2 35M model has 12 layers (0-11) in model.esm.encoder.layer
num_layers_to_unfreeze = 2
for layer in model.esm.encoder.layer[-num_layers_to_unfreeze:]:
    for param in layer.parameters():
        param.requires_grad = True

# --- LOSS AND OPTIMIZER ---
# We will let the model calculate its own loss during training, but define it for validation
loss_function = nn.MSELoss() 

# The optimizer will automatically ignore frozen parameters
optimizer = AdamW(model.parameters(), lr=1e-4)

# Count trainable parameters to confirm the base is frozen
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch + 1}/{num_epochs}\n----------------------------')
    # =======================================
    #               TRAINING
    # =======================================
    model.train()
    total_train_loss = 0
    for inputs, labels in tqdm(train_loader, desc=f"[Training]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        # The model automatically calculates loss when labels are provided
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    
    # =======================================
    #              VALIDATION
    # =======================================
    model.eval()
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"[Validation]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            
            # Get model predictions (logits)
            outputs = model(**inputs)
            predictions = outputs.logits
            
            # Calculate validation loss manually
            total_val_loss += loss_function(predictions.squeeze(), labels).item()
            
            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu())
            epoch_labels.append(labels.cpu())
    
    avg_val_loss = total_val_loss / len(val_loader)
    
    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Training Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")

Using device: cuda
Initializing ESM-2 35M for Sequence Classification...


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Freezing the ESM-2 base model and unfreezing the final 2 layers...
Total trainable parameters: 5,773,441
Starting fine-tuning for 5 epochs...

Epoch 1/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:13<00:00, 16.42it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 19.41it/s]


Training Loss: 0.1906 | Validation Loss: 0.1245 | Spearman Correlation: 0.6330

Epoch 2/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:11<00:00, 18.79it/s]
[Validation]: 100%|██████████| 25/25 [00:00<00:00, 30.68it/s]


Training Loss: 0.1268 | Validation Loss: 0.1068 | Spearman Correlation: 0.7100

Epoch 3/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:10<00:00, 20.96it/s]
[Validation]: 100%|██████████| 25/25 [00:00<00:00, 28.77it/s]


Training Loss: 0.1140 | Validation Loss: 0.0906 | Spearman Correlation: 0.7507

Epoch 4/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:08<00:00, 27.46it/s] 
[Validation]: 100%|██████████| 25/25 [00:00<00:00, 25.91it/s]


Training Loss: 0.1022 | Validation Loss: 0.0885 | Spearman Correlation: 0.7503

Epoch 5/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:10<00:00, 20.68it/s]
[Validation]: 100%|██████████| 25/25 [00:00<00:00, 28.79it/s]

Training Loss: 0.0930 | Validation Loss: 0.1062 | Spearman Correlation: 0.7602





## Baseline 4: Using ESMDance only as a feature extractor
For another comparison to approach one, I'm using just ESMDance predictions as a feature extractor. Essentially using predicted molecular dynamics features. I should revisit this, as pooling these features for the binding head makes zero sense.

In [None]:
import torch
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from tqdm import tqdm
from scipy.stats import spearmanr
import importlib.util

# --- Import the base model class ---
from scripts.esmdance_base_model import ESMwrap
from scripts.base_config import config as base_config

class ESMDanceForBinding(nn.Module):
    """
    A wrapper model that uses the pre-trained ESMDance as a frozen feature extractor
    and adds a new, trainable regression head on top.
    """
    def __init__(self, config, feature_vector_size: int):
        super().__init__()
        self.config = config
        print("Initializing base ESMDance model with original 50/13 heads...")
        self.esmdance_base = ESMwrap(esm2_select='model_35M', model_select='esmdance')
        
        original_weights_path = 'pretrained_weights/esmdance_update_60000.pt'
        print(f"Loading original weights from {original_weights_path}...")
        self.esmdance_base.load_state_dict(torch.load(original_weights_path, map_location='cpu'))

        # Freeze all parameters of the base model
        for param in self.esmdance_base.parameters():
            param.requires_grad = False
        
        # Initialize binding head
        print(f"Initializing trainable binding head with input size {feature_vector_size}...")
        self.binding_head = nn.Sequential(
            nn.Linear(feature_vector_size, feature_vector_size // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(feature_vector_size // 2, 1)
        )
    
    def forward(self, inputs):
        # The forward pass now just needs to extract features and pass to the head
        with torch.no_grad():
            # Get predictions and embeddings from the frozen base
            base_preds = self.esmdance_base(inputs)
            raw_embeddings = self.esmdance_base.esm2(**inputs).last_hidden_state
            
            # Pooling logic (this can be a helper method if you prefer)
            attention_mask = inputs['attention_mask'].unsqueeze(-1)
            pooled_embed = (raw_embeddings * attention_mask).sum(1) / attention_mask.sum(1)
            
            res_keys = self.config['training']['res_feature_idx'].keys()
            tensors_to_cat = []
            for k in res_keys:
                tensor = base_preds[k]
                if tensor.dim() == 2:
                    tensors_to_cat.append(tensor.unsqueeze(-1))
                else:
                    tensors_to_cat.append(tensor)
            res_features = torch.cat(tensors_to_cat, dim=-1)
            pooled_res_features = (res_features * attention_mask).sum(1) / attention_mask.sum(1)
            
            feature_vector = torch.cat([pooled_embed, pooled_res_features], dim=-1)
        
        # Pass the extracted features through the trainable head
        return self.binding_head(feature_vector)

# --- SETUP ---
torch.manual_seed(42)
torch.cuda.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- MODEL INITIALIZATION ---
# 1. Perform a "dry run" with the base model to get the feature size
print("Performing dry run to determine feature vector size...")
temp_base_model = ESMwrap(esm2_select='model_35M', model_select='esmdance').to(device)
temp_base_model.eval()
with torch.no_grad():
    dummy_inputs, _ = next(iter(train_loader))
    dummy_inputs = {k: v.to(device) for k, v in dummy_inputs.items()}
    
    # Manually run the feature extraction logic once
    base_preds = temp_base_model(dummy_inputs)
    raw_embeddings = temp_base_model.esm2(**dummy_inputs).last_hidden_state
    attention_mask = dummy_inputs['attention_mask'].unsqueeze(-1)
    pooled_embed = (raw_embeddings * attention_mask).sum(1) / attention_mask.sum(1)
    res_keys = base_config['training']['res_feature_idx'].keys()
    tensors_to_cat = [base_preds[k].unsqueeze(-1) if base_preds[k].dim() == 2 else base_preds[k] for k in res_keys]
    res_features = torch.cat(tensors_to_cat, dim=-1)
    pooled_res_features = (res_features * attention_mask).sum(1) / attention_mask.sum(1)
    feature_vector = torch.cat([pooled_embed, pooled_res_features], dim=-1)
    feature_vector_size = feature_vector.shape[1]
    
    del temp_base_model # Free up memory
    torch.cuda.empty_cache()

print(f"Determined concatenated feature vector size: {feature_vector_size}")

# Now we can initialize the final model with the correct size
model = ESMDanceForBinding(config=base_config, feature_vector_size=feature_vector_size).to(device)

# --- LOSS AND OPTIMIZER ---
loss_function = nn.MSELoss()
# Pass the parameters of the new binding_head to the optimizer
optimizer = AdamW(model.binding_head.parameters(), lr=1e-4)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch + 1}/{num_epochs}\n----------------------------')
    # =======================================
    #               TRAINING
    # =======================================
    model.train() # Set the binding head to training mode (activates dropout)
    total_train_loss = 0
    
    for inputs, labels in tqdm(train_loader, desc=f"[Training]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device).unsqueeze(1) # Reshape labels for MSELoss
        
        optimizer.zero_grad()
        
        predictions = model(inputs)
        
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    
    # =======================================
    #              VALIDATION
    # =======================================
    model.eval() # Set the binding head to evaluation mode (disables dropout)
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"[Validation]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device).unsqueeze(1)
            
            predictions = model(inputs)
            
            total_val_loss += loss_function(predictions, labels).item()

            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu().detach())
            epoch_labels.append(labels.cpu().detach())

    avg_val_loss = total_val_loss / len(val_loader)

    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Training Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")

Using device: cuda
Performing dry run to determine feature vector size...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Determined concatenated feature vector size: 530
Initializing base ESMDance model with original 50/13 heads...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading original weights from pretrained_weights/esmdance_update_60000.pt...
Initializing trainable binding head with input size 530...
Total trainable parameters: 140,981
Starting fine-tuning for 5 epochs...

Epoch 1/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:16<00:00, 13.33it/s]
[Validation]: 100%|██████████| 25/25 [00:02<00:00, 12.29it/s]


Training Loss: 0.1982 | Validation Loss: 0.1783 | Spearman Correlation: 0.5279

Epoch 2/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:19<00:00, 11.23it/s]
[Validation]: 100%|██████████| 25/25 [00:02<00:00, 10.90it/s]


Training Loss: 0.1818 | Validation Loss: 0.1612 | Spearman Correlation: 0.5305

Epoch 3/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:16<00:00, 13.38it/s]
[Validation]: 100%|██████████| 25/25 [00:02<00:00, 10.80it/s]


Training Loss: 0.1690 | Validation Loss: 0.1480 | Spearman Correlation: 0.5356

Epoch 4/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:16<00:00, 13.49it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 12.55it/s]


Training Loss: 0.1571 | Validation Loss: 0.1409 | Spearman Correlation: 0.5436

Epoch 5/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:19<00:00, 11.61it/s]
[Validation]: 100%|██████████| 25/25 [00:02<00:00, 11.71it/s]

Training Loss: 0.1514 | Validation Loss: 0.1326 | Spearman Correlation: 0.5535





## Baseline 5: ESMDance with Extra Layers Unfrozen
This model unfreezes all of the layers added by ESMDance to ESM2, as well the final two layers of ESM2.

In [None]:
import torch
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from tqdm import tqdm
from scipy.stats import spearmanr
import importlib.util

# --- Import the base model class ---
from scripts.esmdance_base_model import ESMwrap
from scripts.base_config import config as base_config

class ESMDanceForBinding(nn.Module):
    """
    A wrapper model that uses the pre-trained ESMDance as a frozen feature extractor
    and adds a new, trainable regression head on top.
    """
    def __init__(self, config, feature_vector_size: int):
        super().__init__()
        self.config = config
        print("Initializing base ESMDance model with original 50/13 heads...")
        self.esmdance_base = ESMwrap(esm2_select='model_35M', model_select='esmdance')
        
        original_weights_path = 'pretrained_weights/esmdance_update_60000.pt'
        print(f"Loading original weights from {original_weights_path}...")
        self.esmdance_base.load_state_dict(torch.load(original_weights_path, map_location='cpu'))

        # Freeze all parameters of the base model
        for param in self.esmdance_base.parameters():
            param.requires_grad = False

        # Unfreeze final layers of ESMDance
        for param in self.esmdance_base.res_pred_nn.parameters():
            param.requires_grad = True
        
        for param in self.esmdance_base.res_transform_nn.parameters():
            param.requires_grad = True

        for param in self.esmdance_base.pair_middle_linear.parameters():
            param.requires_grad = True

        for param in self.esmdance_base.pair_pred_linear.parameters():
            param.requires_grad = True

        # Unfreeze final layers of ESM2
        num_layers_to_unfreeze = 2
        for layer in self.esmdance_base.esm2.encoder.layer[-num_layers_to_unfreeze:]:
            for param in layer.parameters():
                param.requires_grad = True
        
        # Initialize the binding head
        print(f"Initializing trainable binding head with input size {feature_vector_size}...")
        self.binding_head = nn.Sequential(
            nn.Linear(feature_vector_size, feature_vector_size // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(feature_vector_size // 2, 1)
        )
    
    def forward(self, inputs):
        # The forward pass now just needs to extract features and pass to the head
        with torch.no_grad():
            # Get predictions and embeddings from the frozen base
            base_preds = self.esmdance_base(inputs)
            raw_embeddings = self.esmdance_base.esm2(**inputs).last_hidden_state
            
            # Pooling logic (this can be a helper method if you prefer)
            attention_mask = inputs['attention_mask'].unsqueeze(-1)
            pooled_embed = (raw_embeddings * attention_mask).sum(1) / attention_mask.sum(1)
            
            res_keys = self.config['training']['res_feature_idx'].keys()
            tensors_to_cat = []
            for k in res_keys:
                tensor = base_preds[k]
                if tensor.dim() == 2:
                    tensors_to_cat.append(tensor.unsqueeze(-1))
                else:
                    tensors_to_cat.append(tensor)
            res_features = torch.cat(tensors_to_cat, dim=-1)
            pooled_res_features = (res_features * attention_mask).sum(1) / attention_mask.sum(1)
            
            feature_vector = torch.cat([pooled_embed, pooled_res_features], dim=-1)
        
        # Pass the extracted features through the trainable head
        return self.binding_head(feature_vector)

# =============================================================================
#                            MAIN TRAINING SCRIPT
# =============================================================================

# --- SETUP ---
torch.manual_seed(42)
torch.cuda.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- MODEL INITIALIZATION ---
# 1. Perform a "dry run" with the base model to get the feature size
print("Performing dry run to determine feature vector size...")
temp_base_model = ESMwrap(esm2_select='model_35M', model_select='esmdance').to(device)
temp_base_model.eval()
with torch.no_grad():
    dummy_inputs, _ = next(iter(train_loader))
    dummy_inputs = {k: v.to(device) for k, v in dummy_inputs.items()}
    
    # Manually run the feature extraction logic once
    base_preds = temp_base_model(dummy_inputs)
    raw_embeddings = temp_base_model.esm2(**dummy_inputs).last_hidden_state
    attention_mask = dummy_inputs['attention_mask'].unsqueeze(-1)
    pooled_embed = (raw_embeddings * attention_mask).sum(1) / attention_mask.sum(1)
    res_keys = base_config['training']['res_feature_idx'].keys()
    tensors_to_cat = [base_preds[k].unsqueeze(-1) if base_preds[k].dim() == 2 else base_preds[k] for k in res_keys]
    res_features = torch.cat(tensors_to_cat, dim=-1)
    pooled_res_features = (res_features * attention_mask).sum(1) / attention_mask.sum(1)
    feature_vector = torch.cat([pooled_embed, pooled_res_features], dim=-1)
    feature_vector_size = feature_vector.shape[1]
    
    del temp_base_model # Free up memory
    torch.cuda.empty_cache()

print(f"Determined concatenated feature vector size: {feature_vector_size}")

# Initialize the final model with the correct size
model = ESMDanceForBinding(config=base_config, feature_vector_size=feature_vector_size).to(device)

# --- LOSS AND OPTIMIZER ---
loss_function = nn.MSELoss()
optimizer = AdamW(model.parameters(), lr=1e-4)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch + 1}/{num_epochs}\n----------------------------')
    # =======================================
    #               TRAINING
    # =======================================
    model.train() # Set the binding head to training mode (activates dropout)
    total_train_loss = 0
    
    for inputs, labels in tqdm(train_loader, desc=f"[Training]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device).unsqueeze(1) # Reshape labels for MSELoss
        
        optimizer.zero_grad()
        
        predictions = model(inputs)
        
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    
    # =======================================
    #              VALIDATION
    # =======================================
    model.eval() # Set the binding head to evaluation mode (disables dropout)
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"[Validation]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device).unsqueeze(1)
            
            predictions = model(inputs)
            
            total_val_loss += loss_function(predictions, labels).item()

            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu().detach())
            epoch_labels.append(labels.cpu().detach())

    avg_val_loss = total_val_loss / len(val_loader)

    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Training Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")



Using device: cuda
Performing dry run to determine feature vector size...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Determined concatenated feature vector size: 530
Initializing base ESMDance model with original 50/13 heads...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading original weights from pretrained_weights/esmdance_update_60000.pt...
Initializing trainable binding head with input size 530...
Total trainable parameters: 6,869,444
Starting fine-tuning for 5 epochs...

Epoch 1/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:13<00:00, 17.09it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 17.34it/s]


Training Loss: 0.1982 | Validation Loss: 0.1783 | Spearman Correlation: 0.5279

Epoch 2/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:10<00:00, 20.95it/s] 
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 17.27it/s]


Training Loss: 0.1818 | Validation Loss: 0.1612 | Spearman Correlation: 0.5305

Epoch 3/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:13<00:00, 16.87it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 17.16it/s]


Training Loss: 0.1690 | Validation Loss: 0.1480 | Spearman Correlation: 0.5356

Epoch 4/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:10<00:00, 21.00it/s] 
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 17.32it/s]


Training Loss: 0.1571 | Validation Loss: 0.1409 | Spearman Correlation: 0.5436

Epoch 5/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:13<00:00, 17.01it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 17.16it/s]

Training Loss: 0.1514 | Validation Loss: 0.1326 | Spearman Correlation: 0.5535





## Baseline 6: Trainable ESMDance with no pooling
This model changes the binding head to an attention head to process all of the information from ESMDance. It also unfreezes all parameters through the model from the end to the two final layers of ESM2.

In [None]:
import torch
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from tqdm import tqdm
from scipy.stats import spearmanr
import importlib.util

# --- Import the base model class ---
from scripts.esmdance_base_model import ESMwrap
from scripts.base_config import config as base_config

class AttentionBindingHead(nn.Module):
    """
    A binding head that processes per-residue and per-pair features using a biased attention mechanism.
    """
    def __init__(self, embed_dim, pair_dim, num_heads=10):
        super().__init__()
        
        # A small MLP to process the pairwise features into an attention bias
        self.pair_bias_net = nn.Sequential(
            nn.Linear(pair_dim, num_heads),
            nn.ReLU(),
            nn.Linear(num_heads, num_heads)
        )
        
        # A standard multi-head self-attention layer
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(embed_dim)

        # Final feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim*4, embed_dim)
        )
        self.layer_norm2 = nn.LayerNorm(embed_dim)

        # Final regression head that acts on the [CLS] token embedding
        self.regressor = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),\
            nn.Linear(embed_dim // 2, 1)
        )

    def forward(self, residue_features, pair_features, attention_mask):
        # residue_features: [batch, seq_len, embed_dim]
        # pair_features:    [batch, seq_len, seq_len, pair_dim]
        # attention_mask:   [batch, seq_len]

        # Create the attention bias from pairwise features
        # Shape: [batch, seq_len, seq_len, num_heads] -> [batch * num_heads, seq_len, seq_len]
        pair_bias = self.pair_bias_net(pair_features).permute(0, 3, 1, 2)
        batch_size, num_heads, seq_len, _ = pair_bias.shape
        pair_bias = pair_bias.reshape(batch_size * num_heads, seq_len, seq_len)

        # Create the padding mask for the attention layer
        # Shape: [batch, seq_len] -> [batch, 1, 1, seq_len] -> broadcast
        padding_mask = (attention_mask == 0)

        # Perform biased self-attention
        attn_output, _ = self.attention(
            residue_features, residue_features, residue_features,
            key_padding_mask=padding_mask,
            attn_mask=pair_bias
        )
        residue_features = self.layer_norm1(residue_features + attn_output)
        
        # Pass through the feed-forward network
        ffn_output = self.ffn(residue_features)
        residue_features = self.layer_norm2(residue_features + ffn_output)
        
        # Select the [CLS] token embedding (at index 0) for the final prediction
        cls_token_embedding = residue_features[:, 0, :]
        
        return self.regressor(cls_token_embedding)

class ESMDanceForBinding(nn.Module):
    """
    A wrapper model that uses the pre-trained ESMDance as a frozen feature extractor
    and adds a new, trainable regression head on top.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        print("Initializing base ESMDance model with original 50/13 heads...")
        self.esmdance_base = ESMwrap(esm2_select='model_35M', model_select='esmdance')
        
        original_weights_path = 'pretrained_weights/esmdance_update_60000.pt'
        print(f"Loading original weights from {original_weights_path}...")
        self.esmdance_base.load_state_dict(torch.load(original_weights_path, map_location='cpu'))

        # Freeze all parameters of the base model
        for param in self.esmdance_base.parameters():
            param.requires_grad = False

        # Unfreeze final layers of ESMDance
        for param in self.esmdance_base.res_pred_nn.parameters():
            param.requires_grad = True
        for param in self.esmdance_base.res_transform_nn.parameters():
            param.requires_grad = True
        for param in self.esmdance_base.pair_middle_linear.parameters():
            param.requires_grad = True
        for param in self.esmdance_base.pair_pred_linear.parameters():
            param.requires_grad = True

        # Unfreeze final layers of ESM2
        num_layers_to_unfreeze = 2
        for layer in self.esmdance_base.esm2.encoder.layer[-num_layers_to_unfreeze:]:
            for param in layer.parameters():
                param.requires_grad = True
        
        # Initialize the binding head
        print(f"Initializing trainable binding head...")
        # Input to the head will be raw embeddings + predicted residue features
        embed_dim = config['model_35M']['embed_dim']
        res_out_dim = config['model_35M']['res_out_dim']
        pair_out_dim = config['model_35M']['pair_out_dim']
        
        self.binding_head = AttentionBindingHead(
            embed_dim=embed_dim + res_out_dim, # 480 + 50 = 530
            pair_dim=pair_out_dim              # 13
        )
    
    def forward(self, inputs):
        # The forward pass is now trainable for some layers
        base_outputs = self.esmdance_base(inputs)
        raw_embeddings = self.esmdance_base.esm2(**inputs).last_hidden_state
        
        # --- Gather Features (No Pooling) ---
        # 1. Gather residue features
        res_keys = base_config['training']['res_feature_idx'].keys()
        res_tensors = [base_outputs[k].unsqueeze(-1) if base_outputs[k].dim() == 2 else base_outputs[k] for k in res_keys]
        predicted_res_features = torch.cat(res_tensors, dim=-1)
        
        # 2. Gather pairwise features
        pair_keys = base_config['training']['pair_feature_idx'].keys()
        predicted_pair_features = torch.stack([base_outputs[k] for k in pair_keys], dim=-1)
        
        # 3. Concatenate raw embeddings with predicted residue features
        final_residue_features = torch.cat([raw_embeddings, predicted_res_features], dim=-1)
        
        # 4. Pass everything to the new binding head
        return self.binding_head(
            residue_features=final_residue_features,
            pair_features=predicted_pair_features,
            attention_mask=inputs['attention_mask']
        )

# =============================================================================
#                            MAIN TRAINING SCRIPT
# =============================================================================

# --- SETUP ---
torch.manual_seed(42)
torch.cuda.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- MODEL INITIALIZATION ---
model = ESMDanceForBinding(config=base_config).to(device)

# --- LOSS AND OPTIMIZER ---
loss_function = nn.MSELoss()
# The optimizer will now correctly find all trainable parameters
optimizer = AdamW(model.parameters(), lr=5e-5) # Use a smaller LR for fine-tuning

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch + 1}/{num_epochs}\n----------------------------')
    # =======================================
    #               TRAINING
    # =======================================
    model.train() # Set the binding head to training mode (activates dropout)
    total_train_loss = 0
    
    for inputs, labels in tqdm(train_loader, desc=f"[Training]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device).unsqueeze(1) # Reshape labels for MSELoss
        
        optimizer.zero_grad()
        
        predictions = model(inputs)
        
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    
    # =======================================
    #              VALIDATION
    # =======================================
    model.eval() # Set the binding head to evaluation mode (disables dropout)
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"[Validation]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device).unsqueeze(1)
            
            predictions = model(inputs)
            
            total_val_loss += loss_function(predictions, labels).item()

            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu().detach())
            epoch_labels.append(labels.cpu().detach())

    avg_val_loss = total_val_loss / len(val_loader)

    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Training Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")



Using device: cuda
Initializing base ESMDance model with original 50/13 heads...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading original weights from pretrained_weights/esmdance_update_60000.pt...
Initializing trainable binding head...
Total trainable parameters: 10,247,384
Starting fine-tuning for 5 epochs...

Epoch 1/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:20<00:00, 10.81it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 15.43it/s]


Training Loss: 0.2024 | Validation Loss: 0.1537 | Spearman Correlation: 0.5905

Epoch 2/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:20<00:00, 10.92it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 15.67it/s]


Training Loss: 0.1299 | Validation Loss: 0.1152 | Spearman Correlation: 0.6771

Epoch 3/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:20<00:00, 10.74it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 15.68it/s]


Training Loss: 0.1082 | Validation Loss: 0.1074 | Spearman Correlation: 0.7111

Epoch 4/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:23<00:00,  9.54it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 15.63it/s]


Training Loss: 0.0992 | Validation Loss: 0.1002 | Spearman Correlation: 0.7402

Epoch 5/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:20<00:00, 10.86it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 15.43it/s]

Training Loss: 0.0990 | Validation Loss: 0.0957 | Spearman Correlation: 0.7465





Just out of curiousity, how much of this is being aided by ESM2 layers being tunable? Below, I'll freeze them and see how it performs

In [None]:
import torch
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from tqdm import tqdm
from scipy.stats import spearmanr
import importlib.util

# --- Import the base model class ---
from scripts.esmdance_base_model import ESMwrap
from scripts.base_config import config as base_config

class AttentionBindingHead(nn.Module):
    """
    A binding head that processes per-residue and per-pair features using a biased attention mechanism.
    """
    def __init__(self, embed_dim, pair_dim, num_heads=10):
        super().__init__()
        
        # A small MLP to process the pairwise features into an attention bias
        self.pair_bias_net = nn.Sequential(
            nn.Linear(pair_dim, num_heads),
            nn.ReLU(),
            nn.Linear(num_heads, num_heads)
        )
        
        # A standard multi-head self-attention layer
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(embed_dim)

        # Final feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim*4, embed_dim)
        )
        self.layer_norm2 = nn.LayerNorm(embed_dim)

        # Final regression head that acts on the [CLS] token embedding
        self.regressor = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),\
            nn.Linear(embed_dim // 2, 1)
        )

    def forward(self, residue_features, pair_features, attention_mask):
        # residue_features: [batch, seq_len, embed_dim]
        # pair_features:    [batch, seq_len, seq_len, pair_dim]
        # attention_mask:   [batch, seq_len]

        # Create the attention bias from pairwise features
        # Shape: [batch, seq_len, seq_len, num_heads] -> [batch * num_heads, seq_len, seq_len]
        pair_bias = self.pair_bias_net(pair_features).permute(0, 3, 1, 2)
        batch_size, num_heads, seq_len, _ = pair_bias.shape
        pair_bias = pair_bias.reshape(batch_size * num_heads, seq_len, seq_len)

        # Create the padding mask for the attention layer
        # Shape: [batch, seq_len] -> [batch, 1, 1, seq_len] -> broadcast
        padding_mask = (attention_mask == 0)

        # Perform biased self-attention
        attn_output, _ = self.attention(
            residue_features, residue_features, residue_features,
            key_padding_mask=padding_mask,
            attn_mask=pair_bias
        )
        residue_features = self.layer_norm1(residue_features + attn_output)
        
        # Pass through the feed-forward network
        ffn_output = self.ffn(residue_features)
        residue_features = self.layer_norm2(residue_features + ffn_output)
        
        # Select the [CLS] token embedding (at index 0) for the final prediction
        cls_token_embedding = residue_features[:, 0, :]
        
        return self.regressor(cls_token_embedding)

class ESMDanceForBinding(nn.Module):
    """
    A wrapper model that uses the pre-trained ESMDance as a frozen feature extractor
    and adds a new, trainable regression head on top.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        print("Initializing base ESMDance model with original 50/13 heads...")
        self.esmdance_base = ESMwrap(esm2_select='model_35M', model_select='esmdance')
        
        original_weights_path = 'pretrained_weights/esmdance_update_60000.pt'
        print(f"Loading original weights from {original_weights_path}...")
        self.esmdance_base.load_state_dict(torch.load(original_weights_path, map_location='cpu'))

        # Freeze all parameters of the base model
        for param in self.esmdance_base.parameters():
            param.requires_grad = False

        # Unfreeze final layers of ESMDance
        for param in self.esmdance_base.res_pred_nn.parameters():
            param.requires_grad = True
        for param in self.esmdance_base.res_transform_nn.parameters():
            param.requires_grad = True
        for param in self.esmdance_base.pair_middle_linear.parameters():
            param.requires_grad = True
        for param in self.esmdance_base.pair_pred_linear.parameters():
            param.requires_grad = True

        ## Unfreeze final layers of ESM2
        #num_layers_to_unfreeze = 2
        #for layer in self.esmdance_base.esm2.encoder.layer[-num_layers_to_unfreeze:]:
        #    for param in layer.parameters():
        #        param.requires_grad = True
        
        # Initialize the binding head
        print(f"Initializing trainable binding head...")
        # Input to the head will be raw embeddings + predicted residue features
        embed_dim = config['model_35M']['embed_dim']
        res_out_dim = config['model_35M']['res_out_dim']
        pair_out_dim = config['model_35M']['pair_out_dim']
        
        self.binding_head = AttentionBindingHead(
            embed_dim=embed_dim + res_out_dim, # 480 + 50 = 530
            pair_dim=pair_out_dim              # 13
        )
    
    def forward(self, inputs):
        # The forward pass is now trainable for some layers
        base_outputs = self.esmdance_base(inputs)
        raw_embeddings = self.esmdance_base.esm2(**inputs).last_hidden_state
        
        # --- Gather Features (No Pooling) ---
        # 1. Gather residue features
        res_keys = base_config['training']['res_feature_idx'].keys()
        res_tensors = [base_outputs[k].unsqueeze(-1) if base_outputs[k].dim() == 2 else base_outputs[k] for k in res_keys]
        predicted_res_features = torch.cat(res_tensors, dim=-1)
        
        # 2. Gather pairwise features
        pair_keys = base_config['training']['pair_feature_idx'].keys()
        predicted_pair_features = torch.stack([base_outputs[k] for k in pair_keys], dim=-1)
        
        # 3. Concatenate raw embeddings with predicted residue features
        final_residue_features = torch.cat([raw_embeddings, predicted_res_features], dim=-1)
        
        # 4. Pass everything to the new binding head
        return self.binding_head(
            residue_features=final_residue_features,
            pair_features=predicted_pair_features,
            attention_mask=inputs['attention_mask']
        )

# =============================================================================
#                            MAIN TRAINING SCRIPT
# =============================================================================

# --- SETUP ---
torch.manual_seed(42)
torch.cuda.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- MODEL INITIALIZATION ---
model = ESMDanceForBinding(config=base_config).to(device)

# --- LOSS AND OPTIMIZER ---
loss_function = nn.MSELoss()
# The optimizer will now correctly find all trainable parameters
optimizer = AdamW(model.parameters(), lr=5e-5) # Use a smaller LR for fine-tuning

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch + 1}/{num_epochs}\n----------------------------')
    # =======================================
    #               TRAINING
    # =======================================
    model.train() # Set the binding head to training mode (activates dropout)
    total_train_loss = 0
    
    for inputs, labels in tqdm(train_loader, desc=f"[Training]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device).unsqueeze(1) # Reshape labels for MSELoss
        
        optimizer.zero_grad()
        
        predictions = model(inputs)
        
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    
    # =======================================
    #              VALIDATION
    # =======================================
    model.eval() # Set the binding head to evaluation mode (disables dropout)
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"[Validation]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device).unsqueeze(1)
            
            predictions = model(inputs)
            
            total_val_loss += loss_function(predictions, labels).item()

            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu().detach())
            epoch_labels.append(labels.cpu().detach())

    avg_val_loss = total_val_loss / len(val_loader)

    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Training Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")



Using device: cuda
Initializing base ESMDance model with original 50/13 heads...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading original weights from pretrained_weights/esmdance_update_60000.pt...
Initializing trainable binding head...
Total trainable parameters: 4,705,304
Starting fine-tuning for 5 epochs...

Epoch 1/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:18<00:00, 12.12it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 15.85it/s]


Training Loss: 0.2125 | Validation Loss: 0.1987 | Spearman Correlation: 0.4990

Epoch 2/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:17<00:00, 12.44it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 16.05it/s]


Training Loss: 0.1842 | Validation Loss: 0.1737 | Spearman Correlation: 0.5247

Epoch 3/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:20<00:00, 10.68it/s]
[Validation]: 100%|██████████| 25/25 [-1:59:59<00:00, -21.89it/s]


Training Loss: 0.1403 | Validation Loss: 0.1424 | Spearman Correlation: 0.5836

Epoch 4/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:20<00:00, 10.71it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 15.70it/s]


Training Loss: 0.1307 | Validation Loss: 0.1361 | Spearman Correlation: 0.6150

Epoch 5/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:18<00:00, 12.31it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 15.96it/s]

Training Loss: 0.1210 | Validation Loss: 0.1242 | Spearman Correlation: 0.6408



