# Fine-tuning ESMDance Models on Yeast Data

Objectives
- Fine-tune ESMDance base model and mutant NMA expert model on yeast data.
- Compare to ESM2 base models

## Data Setup

In [2]:
import torch
from transformers import AutoTokenizer
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, random_split

In [3]:
input_dir = Path('yeast_data')
df = pd.read_csv(str(input_dir / 'avrpikC_full.csv'))
df

Unnamed: 0,aa_sequence,enrichment_score
0,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,1.468796
1,GLKRIIVIKVAREGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV...,1.415944
2,GLKRIIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV...,1.389615
3,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,1.359651
4,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV...,1.343857
...,...,...
3955,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,-1.041749
3956,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEA...,-1.041749
3957,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,-1.057543
3958,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKTEV...,-1.057543


In [4]:
class BindingDataset(Dataset):
    """Dataset for the final binding prediction task."""
    def __init__(self, dataframe):
        self.df = dataframe
        # Ensure the columns exist
        assert 'aa_sequence' in self.df.columns, "DataFrame must have 'aa_sequence' column."
        assert 'enrichment_score' in self.df.columns, "DataFrame must have 'enrichment_score' column."
        
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sequence = self.df.iloc[idx]['aa_sequence']
        # The label is a single float value
        label = torch.tensor(self.df.iloc[idx]['enrichment_score'], dtype=torch.float)

        # Tokenize the sequence
        tokenized_output = self.tokenizer(
            sequence,
            max_length=160, # Use a fixed max length. 158 residues plus 2 extra tokens
            return_tensors='pt'
        )
        
        # Return input_ids: attention masks, removing batch dimension
        inputs = {key: val.squeeze(0) for key, val in tokenized_output.items()}
        
        return inputs, label

In [5]:
full_dataset = BindingDataset(df)

train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

## Approach 1: Frozen ESM and ESMDance
Use base ESM, ESMDance, and fine-tuned NMA as a feature extractor for simple linear regression head. This method pools each of the features into a single numerical representation, which doesn't actually make that much sense.

In [7]:
import torch
from torch import nn
from scripts.esmdance_flex_model import ESMwrap # Customized ESMDance model definition for dynamic config usage

class FeatureExtractor(nn.Module):
    def __init__(self, original_model_config, nma_model_config, nma_model_path):
        super().__init__()
        
        # Instantiate the original ESMDance with its 50/13 res/pair features config
        print("Initializing original ESMDance model...")
        self.original_esmdance = ESMwrap(model_config=original_model_config)
        self.original_esmdance = self.original_esmdance.from_pretrained("ChaoHou/ESMDance", model_config=original_model_config)

        # Instantiate the NMA-tuned model with the 3/3 res/pair features config
        print(f"Initializing custom NMA-tuned model from {nma_model_path}...")
        self.nma_esmdance = ESMwrap(model_config=nma_model_config)
        self.nma_esmdance.load_state_dict(torch.load(nma_model_path, map_location='cpu'))

        # Freeze parameters
        print("Freezing all parameters in the feature extractor...")
        for param in self.parameters():
            param.requires_grad = False
        self.eval()

    def forward(self, inputs):
        with torch.no_grad():
            md_preds = self.original_esmdance(inputs) # Get features from original ESMDance
            nma_preds = self.nma_esmdance(inputs) # Get features from the mutant, fine-tuned ESMDance
            raw_embeddings = self.original_esmdance.esm2(**inputs).last_hidden_state # Get features from ESM2
            attention_mask = inputs['attention_mask'].unsqueeze(-1) # Add dimension to allow matrix multiplication with raw embeddings
            
            # Pool raw embeddings, averaging across residues to give each sequence a single representation
            pooled_embed = (raw_embeddings * attention_mask).sum(1) / attention_mask.sum(1)

            # Gather and unify dimensions for all MD residue features
            original_res_keys = self.original_esmdance.config['training']['res_feature_idx'].keys()
            md_tensors_to_cat = []
            for k in original_res_keys:
                tensor = md_preds[k]
                if tensor.dim() == 2:
                    # If tensor is 2D (e.g., shape [B, L]), add a feature dimension
                    md_tensors_to_cat.append(tensor.unsqueeze(-1))
                else:
                    # If tensor is already 3D (e.g., shape [B, L, F]), add it as is
                    md_tensors_to_cat.append(tensor)
            md_res_features = torch.cat(md_tensors_to_cat, dim=-1)

            # Gather and unify dimensions for all 3 NMA residue features
            nma_res_keys = self.nma_esmdance.config['training']['res_feature_idx'].keys()
            nma_tensors_to_cat = []
            for k in nma_res_keys:
                tensor = nma_preds[k]
                if tensor.dim() == 2:
                    nma_tensors_to_cat.append(tensor.unsqueeze(-1))
                else:
                    nma_tensors_to_cat.append(tensor)
            nma_res_features = torch.cat(nma_tensors_to_cat, dim=-1)

            # Pool the correctly shaped features
            pooled_md_res = (md_res_features * attention_mask).sum(1) / attention_mask.sum(1)
            pooled_nma_res = (nma_res_features * attention_mask).sum(1) / attention_mask.sum(1)
            
            # Concatenate all features into one vector
            final_feature_vector = torch.cat([pooled_embed, pooled_md_res, pooled_nma_res], dim=-1)
            
        return final_feature_vector


class BindingHead(nn.Module):
    """
    The small regression head that we will train.
    It takes the concatenated feature vector as input.
    """
    def __init__(self, input_features):
        super().__init__()
        self.regression_head = nn.Sequential(
            nn.Linear(input_features, input_features // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(input_features // 2, 1) # Output a single value for enrichment
        )

    def forward(self, x):
        return self.regression_head(x)

In [8]:
import torch
import pandas as pd
from torch import nn
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from tqdm import tqdm
from scripts.base_config import config as base_config
from scripts.nma_finetuned_config import config as nma_config
from scipy.stats import spearmanr


# --- SETUP ---
torch.manual_seed(42)
torch.cuda.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- MODEL INITIALIZATION ---
nma_model_path = 'models/esmdance-mutant-nma-fine-tuned/esmdance_fine-tuned_with_nma_data.pth' # Path to your NMA-tuned model

# Initialize the frozen feature extractor and the trainable binding head
feature_extractor = FeatureExtractor(original_model_config=base_config, 
                                     nma_model_config=nma_config, 
                                     nma_model_path=nma_model_path).to(device)

# Determine the input size for the binding head after one forward pass
print("Determining feature vector size...")
with torch.no_grad():
    dummy_inputs, _ = next(iter(train_loader))
    dummy_inputs = {k: v.to(device) for k,v in dummy_inputs.items()}
    dummy_feature_vector = feature_extractor(dummy_inputs,)
    feature_vector_size = dummy_feature_vector.shape[1]

print(f"Concatenated feature vector size: {feature_vector_size}")
binding_head = BindingHead(feature_vector_size).to(device)

# --- LOSS AND OPTIMIZER ---
# MSE is a good loss function for regression tasks like enrichment scores
loss_function = nn.MSELoss()

# Pass ONLY the parameters of the binding_head to the optimizer
optimizer = AdamW(binding_head.parameters(), lr=1e-4)

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting training of the binding head for {num_epochs} epochs...")

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch + 1}/{num_epochs}\n----------------------------')
    # =======================================
    #               TRAINING
    # =======================================
    binding_head.train()
    total_train_loss = 0
    for inputs, labels in tqdm(train_loader, desc=f"[Train]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device).unsqueeze(1) # Reshape labels for MSELoss
        
        optimizer.zero_grad()
        
        feature_vector = feature_extractor(inputs)
        predictions = binding_head(feature_vector)
        
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    
    # =======================================
    #              VALIDATION
    # =======================================
    binding_head.eval()
    total_val_loss = 0

    epoch_predictions = []
    epoch_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"[Val]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device).unsqueeze(1)
            
            feature_vector = feature_extractor(inputs)
            predictions = binding_head(feature_vector)
            
            total_val_loss += loss_function(predictions, labels).item()
            epoch_predictions.append(predictions.cpu().detach())
            epoch_labels.append(labels.cpu().detach())
    
    avg_val_loss = total_val_loss / len(val_loader)
    
    # Concatenate all batch tensors into single, large tensors
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    
    # Calculate Spearman's rank correlation coefficient
    # spearmanr returns two values: the correlation and the p-value
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    # --- Print metrics ---
    print(f"Training Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")

Using device: cuda
Initializing original ESMDance model...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Initializing custom NMA-tuned model from models/esmdance-mutant-nma-fine-tuned/esmdance_fine-tuned_with_nma_data.pth...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Freezing all parameters in the feature extractor...
Determining feature vector size...
Concatenated feature vector size: 533
Starting training of the binding head for 5 epochs...

Epoch 1/5
----------------------------


[Train]: 100%|██████████| 223/223 [00:19<00:00, 11.27it/s]
[Val]: 100%|██████████| 25/25 [00:02<00:00, 10.04it/s]


Training Loss: 0.1986 | Validation Loss: 0.1919 | Spearman Correlation: 0.5796

Epoch 2/5
----------------------------


[Train]: 100%|██████████| 223/223 [00:19<00:00, 11.26it/s]
[Val]: 100%|██████████| 25/25 [00:02<00:00,  9.91it/s]


Training Loss: 0.1803 | Validation Loss: 0.1718 | Spearman Correlation: 0.5766

Epoch 3/5
----------------------------


[Train]: 100%|██████████| 223/223 [00:19<00:00, 11.28it/s]
[Val]: 100%|██████████| 25/25 [00:02<00:00, 10.14it/s]


Training Loss: 0.1651 | Validation Loss: 0.1576 | Spearman Correlation: 0.5845

Epoch 4/5
----------------------------


[Train]: 100%|██████████| 223/223 [00:22<00:00,  9.79it/s]
[Val]: 100%|██████████| 25/25 [00:02<00:00,  9.91it/s]


Training Loss: 0.1546 | Validation Loss: 0.1478 | Spearman Correlation: 0.5947

Epoch 5/5
----------------------------


[Train]: 100%|██████████| 223/223 [00:19<00:00, 11.29it/s]
[Val]: 100%|██████████| 25/25 [00:02<00:00,  9.86it/s]

Training Loss: 0.1488 | Validation Loss: 0.1413 | Spearman Correlation: 0.6044





## Approach 2: Fine-tuned ESMDance with Attention
This model loads the fine-tuned ESMDance model on the mutant NMA data with attention to integrate the pair-wise and residue-level outputs of ESMDance. It also unfreezes the final two layers of ESM2 for improved encodings.

In [6]:
import torch
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from tqdm import tqdm
from scipy.stats import spearmanr
from transformers import AutoTokenizer

# --- Import your custom NMA-tuned config and the base model class ---
from scripts.nma_finetuned_config import config as nma_config
from scripts.esmdance_flex_model import ESMwrap

class AttentionBindingHead(nn.Module):
    """The advanced binding head that processes per-residue and per-pair features."""
    def __init__(self, embed_dim, pair_dim, num_heads=7):
        super().__init__()
        self.pair_bias_net = nn.Sequential(
            nn.Linear(pair_dim, num_heads), nn.ReLU(), nn.Linear(num_heads, num_heads)
        )
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4), nn.ReLU(), nn.Linear(embed_dim * 4, embed_dim)
        )
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.regressor = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2), nn.ReLU(), nn.Dropout(0.2), nn.Linear(embed_dim // 2, 1)
        )

    def forward(self, residue_features, pair_features, attention_mask):
        pair_bias = self.pair_bias_net(pair_features).permute(0, 3, 1, 2)
        batch_size, num_heads, seq_len, _ = pair_bias.shape
        pair_bias = pair_bias.reshape(batch_size * num_heads, seq_len, seq_len)
        padding_mask = (attention_mask == 0)
        
        attn_output, _ = self.attention(
            residue_features, residue_features, residue_features,
            key_padding_mask=padding_mask, attn_mask=pair_bias
        )
        residue_features = self.layer_norm1(residue_features + attn_output)
        ffn_output = self.ffn(residue_features)
        residue_features = self.layer_norm2(residue_features + ffn_output)
        cls_token_embedding = residue_features[:, 0, :]
        return self.regressor(cls_token_embedding)

class NMAFineTuningForBinding(nn.Module):
    """
    A single, end-to-end model for fine-tuning ESMDance on binding data.
    """
    def __init__(self, nma_model_config: dict, nma_model_path: str):
        super().__init__()
        
        # 1. Instantiate your NMA-tuned model with its 3/3 config
        print("Initializing custom NMA-tuned model...")
        self.esmdance_base = ESMwrap(model_config=nma_model_config)
        self.esmdance_base.load_state_dict(torch.load(nma_model_path, map_location='cpu'))

        # 2. --- SELECTIVE UNFREEZING ---
        print("Freezing ESM-2 base and unfreezing top 2 layers and prediction heads...")
        # First, freeze the entire ESM-2 sub-module
        for param in self.esmdance_base.esm2.parameters():
            param.requires_grad = False
        
        # Then, unfreeze only the parameters of the last two transformer layers
        num_layers_to_unfreeze = 2
        for layer in self.esmdance_base.esm2.encoder.layer[-num_layers_to_unfreeze:]:
            for param in layer.parameters():
                param.requires_grad = True
        
        # NOTE: The ESMDance prediction heads (`res_pred_nn`, etc.) are separate from
        # `esm2` and will be trainable by default, which is what we want.

        # 3. --- INITIALIZE THE BINDING HEAD ---
        embed_dim = nma_model_config['model_35M']['embed_dim']
        res_out_dim = nma_model_config['model_35M']['res_out_dim']   # This is 3
        pair_out_dim = nma_model_config['model_35M']['pair_out_dim'] # This is 3
        
        self.binding_head = AttentionBindingHead(
            embed_dim=embed_dim + res_out_dim, # 480 + 3 = 483
            pair_dim=pair_out_dim              # 3
        )

    def forward(self, inputs):
        base_outputs = self.esmdance_base(inputs)
        raw_embeddings = self.esmdance_base.esm2(**inputs).last_hidden_state
        
        # Gather NMA residue features (3 of them)
        res_keys = self.esmdance_base.config['training']['res_feature_idx'].keys()
        res_tensors = [base_outputs[k].unsqueeze(-1) if base_outputs[k].dim() == 2 else base_outputs[k] for k in res_keys]
        predicted_res_features = torch.cat(res_tensors, dim=-1)
        
        # Gather NMA pair features
        pair_keys = self.esmdance_base.config['training']['pair_feature_idx'].keys()
        
        # Use torch.cat to join along the existing feature dimension (the last one)
        predicted_pair_features = torch.cat([base_outputs[k] for k in pair_keys], dim=-1)
        
        # Combine inputs for the head
        final_residue_features = torch.cat([raw_embeddings, predicted_res_features], dim=-1)
        
        return self.binding_head(
            residue_features=final_residue_features,
            pair_features=predicted_pair_features,
            attention_mask=inputs['attention_mask']
        )

# =============================================================================
#                            MAIN TRAINING SCRIPT
# =============================================================================
# --- SETUP, DATA LOADING ---
torch.manual_seed(42)
torch.cuda.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- MODEL INITIALIZATION ---
nma_model_path = 'models/esmdance-mutant-nma-fine-tuned_relaxed/esmdance_fine-tuned_with_nma_data.pth'

# The new model initialization is now much cleaner
model = NMAFineTuningForBinding(
    nma_model_config=nma_config,
    nma_model_path=nma_model_path
).to(device)

# --- LOSS AND OPTIMIZER ---
loss_function = nn.MSELoss()
optimizer = AdamW(model.parameters(), lr=5e-5) # Use a smaller LR for fine-tuning

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# --- TRAINING & VALIDATION LOOP ---
num_epochs = 5
print(f"Starting fine-tuning for {num_epochs} epochs...")

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch + 1}/{num_epochs}\n----------------------------')
    # =======================================
    #               TRAINING
    # =======================================
    model.train() # Set the binding head to training mode (activates dropout)
    total_train_loss = 0
    
    for inputs, labels in tqdm(train_loader, desc=f"[Training]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device).unsqueeze(1) # Reshape labels for MSELoss
        
        optimizer.zero_grad()
        
        predictions = model(inputs)
        
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    
    # =======================================
    #              VALIDATION
    # =======================================
    model.eval() # Set the binding head to evaluation mode (disables dropout)
    total_val_loss = 0
    epoch_predictions = []
    epoch_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"[Validation]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device).unsqueeze(1)
            
            predictions = model(inputs)
            
            total_val_loss += loss_function(predictions, labels).item()

            # Collect predictions and labels for Spearman correlation
            epoch_predictions.append(predictions.cpu().detach())
            epoch_labels.append(labels.cpu().detach())

    avg_val_loss = total_val_loss / len(val_loader)

    # Calculate Spearman Correlation
    all_predictions = torch.cat(epoch_predictions).numpy().flatten()
    all_labels = torch.cat(epoch_labels).numpy().flatten()
    spearman_corr, p_value = spearmanr(all_predictions, all_labels)
    
    print(f"Training Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f} | Spearman Correlation: {spearman_corr:.4f}")

Using device: cuda
Initializing custom NMA-tuned model...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Freezing ESM-2 base and unfreezing top 2 layers and prediction heads...
Total trainable parameters: 9,623,763
Starting fine-tuning for 5 epochs...

Epoch 1/5
----------------------------


[Training]:   0%|          | 0/223 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
[Training]: 100%|██████████| 223/223 [00:51<00:00,  4.34it/s]
[Validation]: 100%|██████████| 25/25 [00:03<00:00,  6.34it/s]


Training Loss: 0.1883 | Validation Loss: 0.1402 | Spearman Correlation: 0.6600

Epoch 2/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:49<00:00,  4.52it/s]
[Validation]: 100%|██████████| 25/25 [00:03<00:00,  6.41it/s]


Training Loss: 0.1297 | Validation Loss: 0.1057 | Spearman Correlation: 0.7073

Epoch 3/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:48<00:00,  4.61it/s]
[Validation]: 100%|██████████| 25/25 [00:03<00:00,  6.52it/s]


Training Loss: 0.1132 | Validation Loss: 0.1033 | Spearman Correlation: 0.7410

Epoch 4/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:23<00:00,  9.69it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 17.63it/s]


Training Loss: 0.1009 | Validation Loss: 0.0864 | Spearman Correlation: 0.7583

Epoch 5/5
----------------------------


[Training]: 100%|██████████| 223/223 [00:20<00:00, 10.82it/s]
[Validation]: 100%|██████████| 25/25 [00:01<00:00, 17.20it/s]

Training Loss: 0.0971 | Validation Loss: 0.0861 | Spearman Correlation: 0.7762





Without relax (5 epochs): Validation Loss: 0.0964 | Spearman Correlation: 0.7496

With relax (5 epochs): Validation loss: 0.0861 | Spearman correlation: 0.7762

Compared to the most similar baseline, trainable ESMDance without pooling (Validation Loss: 0.0957 | Spearman Correlation: 0.7465), the model fine-tuned on relaxed structures does a little bit better, at least on this run (pending replicates with different seeds). It also does slightly better than training the full 8M ESM2 (Validation Loss: 0.0859 | Spearman Correlation: 0.7586) and 35M ESM2 with two layers unfrozen (Validation Loss: 0.1062 | Spearman Correlation: 0.7602).