# Fine-tuning ESMDance Models on Yeast Data

Objectives
- Fine-tune ESMDance base model and mutant NMA expert model on yeast data.
- Evaluate data efficiency of training.
- Evaluate immunogenicity accuracy.

In [1]:
import torch
from transformers import AutoTokenizer
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset

In [35]:
input_dir = Path('yeast_data')
df = pd.read_csv(str(input_dir / 'avrpikC_full.csv'))
df

Unnamed: 0,aa_sequence,enrichment_score
0,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,1.468796
1,GLKRIIVIKVAREGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV...,1.415944
2,GLKRIIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV...,1.389615
3,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,1.359651
4,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV...,1.343857
...,...,...
3955,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,-1.041749
3956,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEA...,-1.041749
3957,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,-1.057543
3958,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKTEV...,-1.057543


In [3]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

In [4]:
class BindingDataset(Dataset):
    """Dataset for the final binding prediction task."""
    def __init__(self, dataframe):
        self.df = dataframe
        # Ensure the columns exist
        assert 'aa_sequence' in self.df.columns, "DataFrame must have 'aa_sequence' column."
        assert 'enrichment_score' in self.df.columns, "DataFrame must have 'enrichment_score' column."
        
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sequence = self.df.iloc[idx]['aa_sequence']
        # The label is a single float value
        label = torch.tensor(self.df.iloc[idx]['enrichment_score'], dtype=torch.float)

        # Tokenize the sequence
        tokenized_output = self.tokenizer(
            sequence,
            truncation=True,
            max_length=1024, # Use a fixed max length
            return_tensors='pt'
        )
        
        inputs = {key: val.squeeze(0) for key, val in tokenized_output.items()}
        
        return inputs, label

In [22]:
import torch
from torch import nn
from scripts.esmdance_flex_model import ESMwrap

class FeatureExtractor(nn.Module):
    def __init__(self, original_model_config, nma_model_config, nma_model_path):
        super().__init__()
        
        # 1. Instantiate the original ESMDance with its 50/13 config
        print("Initializing original ESMDance model...")
        self.original_esmdance = ESMwrap(model_config=original_model_config)
        self.original_esmdance = self.original_esmdance.from_pretrained("ChaoHou/ESMDance",model_config=original_model_config)

        # 2. Instantiate your NMA-tuned model using THE SAME CLASS but with the 3/3 config
        print(f"Initializing custom NMA-tuned model from {nma_model_path}...")
        self.nma_esmdance = ESMwrap(model_config=nma_model_config)
        
        # Load your fine-tuned weights
        self.nma_esmdance.load_state_dict(torch.load(nma_model_path, map_location='cpu'))

        # 3. Freeze ALL parameters
        print("Freezing all parameters in the feature extractor...")
        for param in self.parameters():
            param.requires_grad = False
        self.eval()

    def forward(self, inputs):
        # Extract features from both models
        with torch.no_grad():
            md_preds = self.original_esmdance(inputs)
            nma_preds = self.nma_esmdance(inputs)
            raw_embeddings = self.original_esmdance.esm2(**inputs).last_hidden_state
            
            attention_mask = inputs['attention_mask'].unsqueeze(-1)
            
            # Pool raw embeddings
            pooled_embed = (raw_embeddings * attention_mask).sum(1) / attention_mask.sum(1)

            # --- THIS SECTION IS NOW CORRECTED ---
            
            # 1. Correctly gather and unify dimensions for all MD residue features
            original_res_keys = self.original_esmdance.config['training']['res_feature_idx'].keys()
            md_tensors_to_cat = []
            for k in original_res_keys:
                tensor = md_preds[k]
                if tensor.dim() == 2:
                    # If tensor is 2D (e.g., shape [B, L]), add a feature dimension
                    md_tensors_to_cat.append(tensor.unsqueeze(-1))
                else:
                    # If tensor is already 3D (e.g., shape [B, L, F]), add it as is
                    md_tensors_to_cat.append(tensor)
            md_res_features = torch.cat(md_tensors_to_cat, dim=-1)

            # 2. Correctly gather and unify dimensions for all 3 NMA residue features
            nma_res_keys = self.nma_esmdance.config['training']['res_feature_idx'].keys()
            nma_tensors_to_cat = []
            for k in nma_res_keys:
                tensor = nma_preds[k]
                if tensor.dim() == 2:
                    nma_tensors_to_cat.append(tensor.unsqueeze(-1))
                else:
                    nma_tensors_to_cat.append(tensor)
            nma_res_features = torch.cat(nma_tensors_to_cat, dim=-1)
            
            # --- END OF CORRECTION ---

            # Now, pool the correctly shaped features
            pooled_md_res = (md_res_features * attention_mask).sum(1) / attention_mask.sum(1)
            pooled_nma_res = (nma_res_features * attention_mask).sum(1) / attention_mask.sum(1)
            
            # Concatenate all features into one vector
            final_feature_vector = torch.cat([pooled_embed, pooled_md_res, pooled_nma_res], dim=-1)
            
        return final_feature_vector


class BindingHead(nn.Module):
    """
    The small regression head that we will train.
    It takes the concatenated feature vector as input.
    """
    def __init__(self, input_features):
        super().__init__()
        self.regression_head = nn.Sequential(
            nn.Linear(input_features, input_features // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(input_features // 2, 1) # Output a single value for enrichment
        )

    def forward(self, x):
        return self.regression_head(x)

In [30]:
import torch
import pandas as pd
from torch import nn
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from tqdm import tqdm
from scripts.base_config import config as base_config
from scripts.nma_finetuned_config import config as nma_config
from scipy.stats import spearmanr


# --- 1. SETUP ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- 2. DATA LOADING ---
# Load your dataframe with 'aa_sequence' and 'enrichment_score' columns
full_dataset = BindingDataset(df)

train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

# --- 3. MODEL INITIALIZATION ---
nma_model_path = 'models/esmdance-mutant-nma-fine-tuned/esmdance_fine-tuned_with_nma_data.pth' # Path to your NMA-tuned model

# Initialize the frozen feature extractor and the trainable binding head
feature_extractor = FeatureExtractor(original_model_config=base_config, 
                                     nma_model_config=nma_config, 
                                     nma_model_path=nma_model_path).to(device)

# We need to determine the input size for the binding head after one forward pass
print("Determining feature vector size...")
with torch.no_grad():
    dummy_inputs, _ = next(iter(train_loader))
    dummy_inputs = {k: v.to(device) for k,v in dummy_inputs.items()}
    dummy_feature_vector = feature_extractor(dummy_inputs,)
    feature_vector_size = dummy_feature_vector.shape[1]

print(f"Concatenated feature vector size: {feature_vector_size}")
binding_head = BindingHead(feature_vector_size).to(device)

# --- 4. LOSS AND OPTIMIZER ---
# MSE is a good loss function for regression tasks like enrichment scores
loss_function = nn.MSELoss()

# CRITICAL: Pass ONLY the parameters of the binding_head to the optimizer
optimizer = AdamW(binding_head.parameters(), lr=1e-4)

# --- 5. TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting training of the binding head for {num_epochs} epochs...")

for epoch in range(num_epochs):
    # Training
    binding_head.train()
    total_train_loss = 0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1} [Train]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device).unsqueeze(1) # Reshape labels for MSELoss
        
        optimizer.zero_grad()
        
        feature_vector = feature_extractor(inputs)
        predictions = binding_head(feature_vector)
        
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Training Loss: {avg_train_loss:.4f}")
    
    # Validation
    binding_head.eval()
    total_val_loss = 0

    epoch_predictions = []
    epoch_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch + 1} [Val]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device).unsqueeze(1)
            
            feature_vector = feature_extractor(inputs)
            predictions = binding_head(feature_vector)
            
            total_val_loss += loss_function(predictions, labels).item()
            epoch_predictions.append(predictions.cpu().detach())
            epoch_labels.append(labels.cpu().detach())
    
    avg_val_loss = total_val_loss / len(val_loader)
    
    # Concatenate all batch tensors into single, large tensors
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    
    # Calculate Spearman's rank correlation coefficient
    # spearmanr returns two values: the correlation and the p-value
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    # --- UPDATED: Print both metrics ---
    print(f"Epoch {epoch + 1} Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")

Using device: cuda
Initializing original ESMDance model...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Initializing custom NMA-tuned model from models/esmdance-mutant-nma-fine-tuned/esmdance_fine-tuned_with_nma_data.pth...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Freezing all parameters in the feature extractor...
Determining feature vector size...
Concatenated feature vector size: 533
Starting training of the binding head for 5 epochs...


Epoch 1 [Train]: 100%|██████████| 223/223 [00:18<00:00, 11.88it/s]


Training Loss: 0.1987


Epoch 1 [Val]: 100%|██████████| 25/25 [00:02<00:00, 10.21it/s]


Epoch 1 Validation Loss: 0.1991 | Spearman Correlation: 0.5811


Epoch 2 [Train]: 100%|██████████| 223/223 [00:18<00:00, 11.82it/s]


Training Loss: 0.1827


Epoch 2 [Val]: 100%|██████████| 25/25 [00:02<00:00, 10.27it/s]


Epoch 2 Validation Loss: 0.1795 | Spearman Correlation: 0.5753


Epoch 3 [Train]: 100%|██████████| 223/223 [00:21<00:00, 10.25it/s]


Training Loss: 0.1689


Epoch 3 [Val]: 100%|██████████| 25/25 [00:02<00:00, 10.31it/s]


Epoch 3 Validation Loss: 0.1645 | Spearman Correlation: 0.5827


Epoch 4 [Train]: 100%|██████████| 223/223 [00:18<00:00, 11.78it/s]


Training Loss: 0.1594


Epoch 4 [Val]: 100%|██████████| 25/25 [00:02<00:00, 10.33it/s]


Epoch 4 Validation Loss: 0.1533 | Spearman Correlation: 0.5897


Epoch 5 [Train]: 100%|██████████| 223/223 [00:18<00:00, 11.80it/s]


Training Loss: 0.1508


Epoch 5 [Val]: 100%|██████████| 25/25 [00:02<00:00, 10.29it/s]

Epoch 5 Validation Loss: 0.1478 | Spearman Correlation: 0.5977





In [32]:
from transformers import AutoTokenizer, EsmForSequenceClassification

# --- SETUP ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- DATA LOADING ---
# Load your dataframe with 'aa_sequence' and 'enrichment_score' columns
#full_dataset = BindingDataset(df)
#
#train_size = int(0.9 * len(full_dataset))
#val_size = len(full_dataset) - train_size
#train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
#
## NOTE: The collate function is handled automatically by the DataLoader for this dataset
#train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
#val_loader = DataLoader(val_dataset, batch_size=16)

# --- MODEL INITIALIZATION ---
print("Initializing ESM-2 35M for Sequence Classification...")

model = EsmForSequenceClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D",
    num_labels=1,                # We are predicting one continuous value.
    problem_type="regression"    # Configure the model for regression.
).to(device)

# --- FREEZE THE BASE MODEL (for fair comparison) ---
print("Freezing the ESM-2 base model layers...")
for name, param in model.named_parameters():
    if name.startswith("esm."): # This freezes all parameters of the main ESM body
        param.requires_grad = False

# --- LOSS AND OPTIMIZER ---
# We will let the model calculate its own loss during training, but define it for validation
loss_function = nn.MSELoss() 

# The optimizer will automatically ignore frozen parameters
optimizer = AdamW(model.parameters(), lr=1e-4)

# Count trainable parameters to confirm the base is frozen
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    # Training
    model.train()
    total_train_loss = 0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1} [Train]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        # The model automatically calculates loss when labels are provided
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Training Loss: {avg_train_loss:.4f}")
    
    # Validation
    model.eval()
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch + 1} [Val]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            
            # Get model predictions (logits)
            outputs = model(**inputs)
            predictions = outputs.logits
            
            # Calculate validation loss manually
            total_val_loss += loss_function(predictions.squeeze(), labels).item()
            
            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu())
            epoch_labels.append(labels.cpu())
    
    avg_val_loss = total_val_loss / len(val_loader)
    
    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")


Using device: cuda
Initializing ESM-2 35M for Sequence Classification...


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Freezing the ESM-2 base model layers...
Total trainable parameters: 231,361
Starting fine-tuning for 5 epochs...


Epoch 1 [Train]: 100%|██████████| 223/223 [00:02<00:00, 91.15it/s] 


Training Loss: 0.2059


Epoch 1 [Val]: 100%|██████████| 25/25 [00:00<00:00, 44.11it/s]


Validation Loss: 0.2098 | Spearman Correlation: 0.5577


Epoch 2 [Train]: 100%|██████████| 223/223 [00:05<00:00, 43.55it/s]


Training Loss: 0.1994


Epoch 2 [Val]: 100%|██████████| 25/25 [00:00<00:00, 42.97it/s]


Validation Loss: 0.2086 | Spearman Correlation: 0.5696


Epoch 3 [Train]: 100%|██████████| 223/223 [00:05<00:00, 43.76it/s]


Training Loss: 0.1931


Epoch 3 [Val]: 100%|██████████| 25/25 [00:00<00:00, 44.20it/s]


Validation Loss: 0.1928 | Spearman Correlation: 0.5754


Epoch 4 [Train]: 100%|██████████| 223/223 [00:05<00:00, 43.53it/s]


Training Loss: 0.1872


Epoch 4 [Val]: 100%|██████████| 25/25 [00:00<00:00, 44.03it/s]


Validation Loss: 0.1897 | Spearman Correlation: 0.5794


Epoch 5 [Train]: 100%|██████████| 223/223 [00:05<00:00, 43.65it/s]


Training Loss: 0.1768


Epoch 5 [Val]: 100%|██████████| 25/25 [00:00<00:00, 43.94it/s]

Validation Loss: 0.1760 | Spearman Correlation: 0.5856





In [44]:
import torch
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from tqdm import tqdm
from scipy.stats import spearmanr
import importlib.util

# --- Import the base model class ---
from scripts.esmdance_base_model import ESMwrap
from scripts.base_config import config as base_config

# =============================================================================
# 2. NEW MODEL FOR FINE-TUNING
# =============================================================================
class ESMDanceForBinding(nn.Module):
    """
    A wrapper model that uses the pre-trained ESMDance as a frozen feature extractor
    and adds a new, trainable regression head on top.
    
    This version initializes all layers in the constructor for robustness.
    """
    # CRITICAL CHANGE: We now pass the feature vector size during initialization
    def __init__(self, config, feature_vector_size: int):
        super().__init__()
        self.config = config
        print("Initializing base ESMDance model with original 50/13 heads...")
        self.esmdance_base = ESMwrap(esm2_select='model_35M', model_select='esmdance')
        
        original_weights_path = 'pretrained_weights/esmdance_update_60000.pt'
        print(f"Loading original weights from {original_weights_path}...")
        self.esmdance_base.load_state_dict(torch.load(original_weights_path, map_location='cpu'))

        # Freeze all parameters of the base model
        for param in self.esmdance_base.parameters():
            param.requires_grad = False
        
        # CRITICAL CHANGE: Initialize the binding head immediately in __init__
        print(f"Initializing trainable binding head with input size {feature_vector_size}...")
        self.binding_head = nn.Sequential(
            nn.Linear(feature_vector_size, feature_vector_size // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(feature_vector_size // 2, 1)
        )
    
    def forward(self, inputs):
        # The forward pass now just needs to extract features and pass to the head
        with torch.no_grad():
            # Get predictions and embeddings from the frozen base
            base_preds = self.esmdance_base(inputs)
            raw_embeddings = self.esmdance_base.esm2(**inputs).last_hidden_state
            
            # Pooling logic (this can be a helper method if you prefer)
            attention_mask = inputs['attention_mask'].unsqueeze(-1)
            pooled_embed = (raw_embeddings * attention_mask).sum(1) / attention_mask.sum(1)
            
            res_keys = self.config['training']['res_feature_idx'].keys()
            tensors_to_cat = []
            for k in res_keys:
                tensor = base_preds[k]
                if tensor.dim() == 2:
                    tensors_to_cat.append(tensor.unsqueeze(-1))
                else:
                    tensors_to_cat.append(tensor)
            res_features = torch.cat(tensors_to_cat, dim=-1)
            pooled_res_features = (res_features * attention_mask).sum(1) / attention_mask.sum(1)
            
            feature_vector = torch.cat([pooled_embed, pooled_res_features], dim=-1)
        
        # Pass the extracted features through the trainable head
        return self.binding_head(feature_vector)

# =============================================================================
#                            MAIN TRAINING SCRIPT
# =============================================================================

# --- SETUP ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

full_dataset = BindingDataset(df)
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

# --- MODEL INITIALIZATION ---
# 1. First, perform a "dry run" with the base model to get the feature size
print("Performing dry run to determine feature vector size...")
temp_base_model = ESMwrap(esm2_select='model_35M', model_select='esmdance').to(device)
temp_base_model.eval()
with torch.no_grad():
    dummy_inputs, _ = next(iter(train_loader))
    dummy_inputs = {k: v.to(device) for k, v in dummy_inputs.items()}
    
    # Manually run the feature extraction logic once
    base_preds = temp_base_model(dummy_inputs)
    raw_embeddings = temp_base_model.esm2(**dummy_inputs).last_hidden_state
    attention_mask = dummy_inputs['attention_mask'].unsqueeze(-1)
    pooled_embed = (raw_embeddings * attention_mask).sum(1) / attention_mask.sum(1)
    res_keys = base_config['training']['res_feature_idx'].keys()
    tensors_to_cat = [base_preds[k].unsqueeze(-1) if base_preds[k].dim() == 2 else base_preds[k] for k in res_keys]
    res_features = torch.cat(tensors_to_cat, dim=-1)
    pooled_res_features = (res_features * attention_mask).sum(1) / attention_mask.sum(1)
    feature_vector = torch.cat([pooled_embed, pooled_res_features], dim=-1)
    feature_vector_size = feature_vector.shape[1]
    
    del temp_base_model # Free up memory
    torch.cuda.empty_cache()

print(f"Determined concatenated feature vector size: {feature_vector_size}")

# 2. Now, initialize the final model with the correct size
model = ESMDanceForBinding(config=base_config, feature_vector_size=feature_vector_size).to(device)

# --- LOSS AND OPTIMIZER ---
loss_function = nn.MSELoss()
# CRITICAL: Pass ONLY the parameters of the new binding_head to the optimizer
optimizer = AdamW(model.binding_head.parameters(), lr=1e-4)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    # =======================================
    #               TRAINING
    # =======================================
    model.train() # Set the binding head to training mode (activates dropout)
    total_train_loss = 0
    
    # --- THIS LOOP WAS MISSING ---
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Training]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device).unsqueeze(1) # Reshape labels for MSELoss
        
        optimizer.zero_grad()
        
        predictions = model(inputs)
        
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    # --- END OF MISSING LOOP ---
    
    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Epoch {epoch + 1} Training Loss: {avg_train_loss:.4f}")
    
    # =======================================
    #              VALIDATION
    # =======================================
    model.eval() # Set the binding head to evaluation mode (disables dropout)
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []

    with torch.no_grad():
        # --- THIS LOOP WAS MISSING ---
        for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Validation]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device).unsqueeze(1)
            
            predictions = model(inputs)
            
            total_val_loss += loss_function(predictions, labels).item()

            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu().detach())
            epoch_labels.append(labels.cpu().detach())
        # --- END OF MISSING LOOP ---

    avg_val_loss = total_val_loss / len(val_loader)

    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Epoch {epoch + 1} Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")


Using device: cuda
Performing dry run to determine feature vector size...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Determined concatenated feature vector size: 530
Initializing base ESMDance model with original 50/13 heads...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading original weights from pretrained_weights/esmdance_update_60000.pt...
Initializing trainable binding head with input size 530...
Total trainable parameters: 140,981
Starting fine-tuning for 5 epochs...


Epoch 1/5 [Training]: 100%|██████████| 223/223 [00:10<00:00, 21.20it/s] 


Epoch 1 Training Loss: 0.1976


Epoch 1/5 [Validation]: 100%|██████████| 25/25 [00:01<00:00, 16.61it/s]


Epoch 1 Validation Loss: 0.1994 | Spearman Correlation: 0.5226


Epoch 2/5 [Training]: 100%|██████████| 223/223 [00:13<00:00, 16.44it/s]


Epoch 2 Training Loss: 0.1784


Epoch 2/5 [Validation]: 100%|██████████| 25/25 [00:01<00:00, 16.52it/s]


Epoch 2 Validation Loss: 0.1771 | Spearman Correlation: 0.5228


Epoch 3/5 [Training]: 100%|██████████| 223/223 [00:10<00:00, 20.98it/s] 


Epoch 3 Training Loss: 0.1624


Epoch 3/5 [Validation]: 100%|██████████| 25/25 [00:01<00:00, 16.46it/s]


Epoch 3 Validation Loss: 0.1638 | Spearman Correlation: 0.5301


Epoch 4/5 [Training]: 100%|██████████| 223/223 [00:13<00:00, 16.38it/s]


Epoch 4 Training Loss: 0.1517


Epoch 4/5 [Validation]: 100%|██████████| 25/25 [00:01<00:00, 16.51it/s]


Epoch 4 Validation Loss: 0.1551 | Spearman Correlation: 0.5399


Epoch 5/5 [Training]: 100%|██████████| 223/223 [00:10<00:00, 20.99it/s]


Epoch 5 Training Loss: 0.1462


Epoch 5/5 [Validation]: 100%|██████████| 25/25 [00:01<00:00, 16.55it/s]

Epoch 5 Validation Loss: 0.1500 | Spearman Correlation: 0.5493





In [None]:
from transformers import AutoTokenizer, EsmForSequenceClassification

# --- SETUP ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- DATA LOADING ---
# Load your dataframe with 'aa_sequence' and 'enrichment_score' columns
#full_dataset = BindingDataset(df)
#
#train_size = int(0.9 * len(full_dataset))
#val_size = len(full_dataset) - train_size
#train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
#
## NOTE: The collate function is handled automatically by the DataLoader for this dataset
#train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
#val_loader = DataLoader(val_dataset, batch_size=16)

# --- MODEL INITIALIZATION ---
print("Initializing ESM-2 35M for Sequence Classification...")

model = EsmForSequenceClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    num_labels=1,                # We are predicting one continuous value.
    problem_type="regression"    # Configure the model for regression.
).to(device)

# --- FREEZE THE BASE MODEL (for fair comparison) ---
#print("Freezing the ESM-2 base model layers...")
#for name, param in model.named_parameters():
#    if name.startswith("esm."): # This freezes all parameters of the main ESM body
#        param.requires_grad = False

# --- LOSS AND OPTIMIZER ---
# We will let the model calculate its own loss during training, but define it for validation
loss_function = nn.MSELoss() 

# The optimizer will automatically ignore frozen parameters
optimizer = AdamW(model.parameters(), lr=1e-4)

# Count trainable parameters to confirm the base is frozen
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    # Training
    model.train()
    total_train_loss = 0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1} [Train]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        # The model automatically calculates loss when labels are provided
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Training Loss: {avg_train_loss:.4f}")
    
    # Validation
    model.eval()
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch + 1} [Val]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            
            # Get model predictions (logits)
            outputs = model(**inputs)
            predictions = outputs.logits
            
            # Calculate validation loss manually
            total_val_loss += loss_function(predictions.squeeze(), labels).item()
            
            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu())
            epoch_labels.append(labels.cpu())
    
    avg_val_loss = total_val_loss / len(val_loader)
    
    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")


Using device: cuda
Initializing ESM-2 35M for Sequence Classification...


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total trainable parameters: 7,840,442
Starting fine-tuning for 5 epochs...


Epoch 1 [Train]: 100%|██████████| 223/223 [00:06<00:00, 36.05it/s]


Training Loss: 0.1964


Epoch 1 [Val]: 100%|██████████| 25/25 [00:00<00:00, 99.50it/s]


Validation Loss: 0.1346 | Spearman Correlation: 0.6494


Epoch 2 [Train]: 100%|██████████| 223/223 [00:05<00:00, 38.05it/s]


Training Loss: 0.1243


Epoch 2 [Val]: 100%|██████████| 25/25 [00:00<00:00, 98.88it/s] 


Validation Loss: 0.1107 | Spearman Correlation: 0.7377


Epoch 3 [Train]: 100%|██████████| 223/223 [00:05<00:00, 38.36it/s]


Training Loss: 0.0993


Epoch 3 [Val]: 100%|██████████| 25/25 [00:00<00:00, 96.96it/s]


Validation Loss: 0.0902 | Spearman Correlation: 0.7721


Epoch 4 [Train]: 100%|██████████| 223/223 [00:05<00:00, 38.55it/s]


Training Loss: 0.0923


Epoch 4 [Val]: 100%|██████████| 25/25 [00:00<00:00, 98.34it/s]


Validation Loss: 0.1009 | Spearman Correlation: 0.7779


Epoch 5 [Train]: 100%|██████████| 223/223 [00:02<00:00, 80.03it/s]


Training Loss: 0.0828


Epoch 5 [Val]: 100%|██████████| 25/25 [00:00<00:00, 98.38it/s]

Validation Loss: 0.0812 | Spearman Correlation: 0.7952





In [48]:
from transformers import AutoTokenizer, EsmForSequenceClassification

# --- SETUP ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- DATA LOADING ---
# Load your dataframe with 'aa_sequence' and 'enrichment_score' columns
#full_dataset = BindingDataset(df)
#
#train_size = int(0.9 * len(full_dataset))
#val_size = len(full_dataset) - train_size
#train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
#
## NOTE: The collate function is handled automatically by the DataLoader for this dataset
#train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
#val_loader = DataLoader(val_dataset, batch_size=16)

# --- MODEL INITIALIZATION ---
print("Initializing ESM-2 35M for Sequence Classification...")

model = EsmForSequenceClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D",
    num_labels=1,                # We are predicting one continuous value.
    problem_type="regression"    # Configure the model for regression.
).to(device)

# FREEZE THE BASE, UNFREEZE LAST TWO LAYERS ---
print("Freezing the ESM-2 base model and unfreezing the final 2 layers...")

# 1. First, freeze all parameters of the entire base model
for param in model.esm.parameters():
    param.requires_grad = False

# 2. Now, unfreeze only the parameters of the last two transformer layers
# The ESM-2 35M model has 12 layers (0-11) in model.esm.encoder.layer
num_layers_to_unfreeze = 2
for layer in model.esm.encoder.layer[-num_layers_to_unfreeze:]:
    for param in layer.parameters():
        param.requires_grad = True

# --- LOSS AND OPTIMIZER ---
# We will let the model calculate its own loss during training, but define it for validation
loss_function = nn.MSELoss() 

# The optimizer will automatically ignore frozen parameters
optimizer = AdamW(model.parameters(), lr=1e-4)

# Count trainable parameters to confirm the base is frozen
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    # Training
    model.train()
    total_train_loss = 0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1} [Train]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        # The model automatically calculates loss when labels are provided
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Training Loss: {avg_train_loss:.4f}")
    
    # Validation
    model.eval()
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch + 1} [Val]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            
            # Get model predictions (logits)
            outputs = model(**inputs)
            predictions = outputs.logits
            
            # Calculate validation loss manually
            total_val_loss += loss_function(predictions.squeeze(), labels).item()
            
            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu())
            epoch_labels.append(labels.cpu())
    
    avg_val_loss = total_val_loss / len(val_loader)
    
    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")


Using device: cuda
Initializing ESM-2 35M for Sequence Classification...


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Freezing the ESM-2 base model and unfreezing the final 2 layers...
Total trainable parameters: 5,773,441
Starting fine-tuning for 5 epochs...


Epoch 1 [Train]: 100%|██████████| 223/223 [00:07<00:00, 30.44it/s]


Training Loss: 0.1878


Epoch 1 [Val]: 100%|██████████| 25/25 [00:00<00:00, 44.48it/s]


Validation Loss: 0.1496 | Spearman Correlation: 0.6232


Epoch 2 [Train]: 100%|██████████| 223/223 [00:07<00:00, 31.35it/s]


Training Loss: 0.1269


Epoch 2 [Val]: 100%|██████████| 25/25 [00:00<00:00, 43.26it/s]


Validation Loss: 0.1198 | Spearman Correlation: 0.6874


Epoch 3 [Train]: 100%|██████████| 223/223 [00:07<00:00, 31.41it/s]


Training Loss: 0.1065


Epoch 3 [Val]: 100%|██████████| 25/25 [00:00<00:00, 42.59it/s]


Validation Loss: 0.0966 | Spearman Correlation: 0.7498


Epoch 4 [Train]: 100%|██████████| 223/223 [00:07<00:00, 31.51it/s]


Training Loss: 0.0919


Epoch 4 [Val]: 100%|██████████| 25/25 [00:00<00:00, 43.79it/s]


Validation Loss: 0.1060 | Spearman Correlation: 0.7629


Epoch 5 [Train]: 100%|██████████| 223/223 [00:04<00:00, 53.71it/s] 


Training Loss: 0.0893


Epoch 5 [Val]: 100%|██████████| 25/25 [00:00<00:00, 44.45it/s]

Validation Loss: 0.0880 | Spearman Correlation: 0.7812





In [50]:
import torch
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from tqdm import tqdm
from scipy.stats import spearmanr
import importlib.util

# --- Import the base model class ---
from scripts.esmdance_base_model import ESMwrap
from scripts.base_config import config as base_config

# =============================================================================
# 2. NEW MODEL FOR FINE-TUNING
# =============================================================================
class ESMDanceForBinding(nn.Module):
    """
    A wrapper model that uses the pre-trained ESMDance as a frozen feature extractor
    and adds a new, trainable regression head on top.
    
    This version initializes all layers in the constructor for robustness.
    """
    # CRITICAL CHANGE: We now pass the feature vector size during initialization
    def __init__(self, config, feature_vector_size: int):
        super().__init__()
        self.config = config
        print("Initializing base ESMDance model with original 50/13 heads...")
        self.esmdance_base = ESMwrap(esm2_select='model_35M', model_select='esmdance')
        
        original_weights_path = 'pretrained_weights/esmdance_update_60000.pt'
        print(f"Loading original weights from {original_weights_path}...")
        self.esmdance_base.load_state_dict(torch.load(original_weights_path, map_location='cpu'))

        # Freeze all parameters of the base model
        for param in self.esmdance_base.parameters():
            param.requires_grad = False

        # Unfreeze final layers of ESMDance
        for param in self.esmdance_base.res_pred_nn.parameters():
            param.requires_grad = True
        
        for param in self.esmdance_base.res_transform_nn.parameters():
            param.requires_grad = True

        for param in self.esmdance_base.pair_middle_linear.parameters():
            param.requires_grad = True

        for param in self.esmdance_base.pair_pred_linear.parameters():
            param.requires_grad = True

        # Unfreeze final layers of ESM2
        num_layers_to_unfreeze = 2
        for layer in self.esmdance_base.esm2.encoder.layer[-num_layers_to_unfreeze:]:
            for param in layer.parameters():
                param.requires_grad = True
        
        # CRITICAL CHANGE: Initialize the binding head immediately in __init__
        print(f"Initializing trainable binding head with input size {feature_vector_size}...")
        self.binding_head = nn.Sequential(
            nn.Linear(feature_vector_size, feature_vector_size // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(feature_vector_size // 2, 1)
        )
    
    def forward(self, inputs):
        # The forward pass now just needs to extract features and pass to the head
        with torch.no_grad():
            # Get predictions and embeddings from the frozen base
            base_preds = self.esmdance_base(inputs)
            raw_embeddings = self.esmdance_base.esm2(**inputs).last_hidden_state
            
            # Pooling logic (this can be a helper method if you prefer)
            attention_mask = inputs['attention_mask'].unsqueeze(-1)
            pooled_embed = (raw_embeddings * attention_mask).sum(1) / attention_mask.sum(1)
            
            res_keys = self.config['training']['res_feature_idx'].keys()
            tensors_to_cat = []
            for k in res_keys:
                tensor = base_preds[k]
                if tensor.dim() == 2:
                    tensors_to_cat.append(tensor.unsqueeze(-1))
                else:
                    tensors_to_cat.append(tensor)
            res_features = torch.cat(tensors_to_cat, dim=-1)
            pooled_res_features = (res_features * attention_mask).sum(1) / attention_mask.sum(1)
            
            feature_vector = torch.cat([pooled_embed, pooled_res_features], dim=-1)
        
        # Pass the extracted features through the trainable head
        return self.binding_head(feature_vector)

# =============================================================================
#                            MAIN TRAINING SCRIPT
# =============================================================================

# --- SETUP ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

full_dataset = BindingDataset(df)
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

# --- MODEL INITIALIZATION ---
# 1. First, perform a "dry run" with the base model to get the feature size
print("Performing dry run to determine feature vector size...")
temp_base_model = ESMwrap(esm2_select='model_35M', model_select='esmdance').to(device)
temp_base_model.eval()
with torch.no_grad():
    dummy_inputs, _ = next(iter(train_loader))
    dummy_inputs = {k: v.to(device) for k, v in dummy_inputs.items()}
    
    # Manually run the feature extraction logic once
    base_preds = temp_base_model(dummy_inputs)
    raw_embeddings = temp_base_model.esm2(**dummy_inputs).last_hidden_state
    attention_mask = dummy_inputs['attention_mask'].unsqueeze(-1)
    pooled_embed = (raw_embeddings * attention_mask).sum(1) / attention_mask.sum(1)
    res_keys = base_config['training']['res_feature_idx'].keys()
    tensors_to_cat = [base_preds[k].unsqueeze(-1) if base_preds[k].dim() == 2 else base_preds[k] for k in res_keys]
    res_features = torch.cat(tensors_to_cat, dim=-1)
    pooled_res_features = (res_features * attention_mask).sum(1) / attention_mask.sum(1)
    feature_vector = torch.cat([pooled_embed, pooled_res_features], dim=-1)
    feature_vector_size = feature_vector.shape[1]
    
    del temp_base_model # Free up memory
    torch.cuda.empty_cache()

print(f"Determined concatenated feature vector size: {feature_vector_size}")

# 2. Now, initialize the final model with the correct size
model = ESMDanceForBinding(config=base_config, feature_vector_size=feature_vector_size).to(device)

# --- LOSS AND OPTIMIZER ---
loss_function = nn.MSELoss()
# CRITICAL: Pass ONLY the parameters of the new binding_head to the optimizer
optimizer = AdamW(model.binding_head.parameters(), lr=1e-4)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    # =======================================
    #               TRAINING
    # =======================================
    model.train() # Set the binding head to training mode (activates dropout)
    total_train_loss = 0
    
    # --- THIS LOOP WAS MISSING ---
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Training]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device).unsqueeze(1) # Reshape labels for MSELoss
        
        optimizer.zero_grad()
        
        predictions = model(inputs)
        
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    # --- END OF MISSING LOOP ---
    
    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Epoch {epoch + 1} Training Loss: {avg_train_loss:.4f}")
    
    # =======================================
    #              VALIDATION
    # =======================================
    model.eval() # Set the binding head to evaluation mode (disables dropout)
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []

    with torch.no_grad():
        # --- THIS LOOP WAS MISSING ---
        for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Validation]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device).unsqueeze(1)
            
            predictions = model(inputs)
            
            total_val_loss += loss_function(predictions, labels).item()

            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu().detach())
            epoch_labels.append(labels.cpu().detach())
        # --- END OF MISSING LOOP ---

    avg_val_loss = total_val_loss / len(val_loader)

    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Epoch {epoch + 1} Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")


Using device: cuda
Performing dry run to determine feature vector size...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Determined concatenated feature vector size: 530
Initializing base ESMDance model with original 50/13 heads...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading original weights from pretrained_weights/esmdance_update_60000.pt...
Initializing trainable binding head with input size 530...
Total trainable parameters: 6,869,444
Starting fine-tuning for 5 epochs...


Epoch 1/5 [Training]: 100%|██████████| 223/223 [00:13<00:00, 16.42it/s]


Epoch 1 Training Loss: 0.1981


Epoch 1/5 [Validation]: 100%|██████████| 25/25 [00:01<00:00, 16.85it/s]


Epoch 1 Validation Loss: 0.2119 | Spearman Correlation: 0.5905


Epoch 2/5 [Training]: 100%|██████████| 223/223 [00:09<00:00, 24.72it/s] 


Epoch 2 Training Loss: 0.1826


Epoch 2/5 [Validation]: 100%|██████████| 25/25 [00:01<00:00, 17.20it/s]


Epoch 2 Validation Loss: 0.1943 | Spearman Correlation: 0.5917


Epoch 3/5 [Training]: 100%|██████████| 223/223 [00:13<00:00, 16.62it/s]


Epoch 3 Training Loss: 0.1675


Epoch 3/5 [Validation]: 100%|██████████| 25/25 [00:01<00:00, 16.54it/s]


Epoch 3 Validation Loss: 0.1778 | Spearman Correlation: 0.5977


Epoch 4/5 [Training]: 100%|██████████| 223/223 [00:11<00:00, 20.08it/s] 


Epoch 4 Training Loss: 0.1569


Epoch 4/5 [Validation]: 100%|██████████| 25/25 [00:01<00:00, 16.67it/s]


Epoch 4 Validation Loss: 0.1692 | Spearman Correlation: 0.6063


Epoch 5/5 [Training]: 100%|██████████| 223/223 [00:13<00:00, 16.39it/s]


Epoch 5 Training Loss: 0.1489


Epoch 5/5 [Validation]: 100%|██████████| 25/25 [00:01<00:00, 16.63it/s]

Epoch 5 Validation Loss: 0.1643 | Spearman Correlation: 0.6141



