# Fine-tune ESMDance on Custom NMA Data

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from transformers import AutoTokenizer

In [2]:
class FineTuneNMADataset(Dataset):
    """A dataset for fine-tuning on custom NMA features from .npz files."""
    def __init__(self, sequences, nma_features_paths):
        self.sequences = sequences
        self.nma_features_paths = nma_features_paths
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self,idx):
        sequence = self.sequences[idx]
        tokenized_output = self.tokenizer(sequence, return_tensors='pt')
        inputs = {key: val.squeeze(0) for key, val in tokenized_output.items()}

        # Load NMA features
        nma_data = np.load(self.nma_features_paths[idx])
        gnm_msf = torch.from_numpy(nma_data['gnm_msf']).float()
        anm_cor = torch.from_numpy(nma_data['anm_cor']).float()

        labels = {
            'nma_residue1': gnm_msf[0],
            'nma_residue2': gnm_msf[1],
            'nma_residue3': gnm_msf[2],
            'nma_pair1': anm_cor[0],
            'nma_pair2': anm_cor[1],
            'nma_pair3': anm_cor[2]
        }
        
        return inputs, labels

In [3]:
def collate_fn_nma(batch):
    """
    Collator function to pad sequences and features at the batch level.
    This version correctly pads all labels to match the tokenized input length.
    """
    batch_inputs = [item[0] for item in batch]
    batch_labels = [item[1] for item in batch]

    # --- Pad Inputs (This part was always correct) ---
    padded_inputs = {}
    padded_inputs['input_ids'] = pad_sequence(
        [b['input_ids'] for b in batch_inputs], batch_first=True, padding_value=1
    )
    padded_inputs['attention_mask'] = pad_sequence(
        [b['attention_mask'] for b in batch_inputs], batch_first=True, padding_value=0
    )
    
    # This is the target length for all tensors (e.g., 160)
    max_len = padded_inputs['input_ids'].shape[1]

    # --- Pad Labels (CORRECTED LOGIC) ---
    padded_labels = {}
    for key in batch_labels[0].keys():
        if 'residue' in key:
            # --- THIS IS THE FIX ---
            padded_tensors = []
            for b in batch_labels:
                tensor = b[key]  # Shape: (num_residues,) e.g., (158,)
                # Manually pad each residue tensor to the full token length (160)
                num_padding = max_len - tensor.shape[0]
                padded_tensor = torch.nn.functional.pad(tensor, (0, num_padding), value=-1)
                padded_tensors.append(padded_tensor)
            # Stack the now correctly-sized tensors
            padded_labels[key] = torch.stack(padded_tensors)
        
        elif 'pair' in key:
            # The pairwise padding logic was already correct
            padded_tensors = []
            for b in batch_labels:
                tensor = b[key]
                n = tensor.shape[0]
                padded_tensor = torch.nn.functional.pad(tensor, (0, max_len - n, 0, max_len - n), value=-1)
                padded_tensors.append(padded_tensor)
            padded_labels[key] = torch.stack(padded_tensors)

    return padded_inputs, padded_labels

In [5]:
import pandas as pd

mut_df = pd.read_csv('mutant_library/h1-c/mutant_library.csv')
sequences = mut_df['sequence'].tolist()
nma_paths = [f"nma/{mut_string}.npz" for mut_string in mut_df['mut']]

nma_paths

['nma/G62M-V69N-R38M-C14H.npz',
 'nma/S24M.npz',
 'nma/K52E.npz',
 'nma/T25E.npz',
 'nma/K59T-A23Q-R15Y-V31Q-D49S.npz',
 'nma/S55G-V28D.npz',
 'nma/L2S-V10A-S24Y.npz',
 'nma/V7D-D63I-A64N-Y46K.npz',
 'nma/S55L-G26L-K3Y-L2K-C14V.npz',
 'nma/D63T-I6L-V34F-G45I.npz',
 'nma/V44M-K3F-L66E-D36S-A23K.npz',
 'nma/D29N-K74H-K59E-K73A.npz',
 'nma/E65T.npz',
 'nma/K40R-S16P.npz',
 'nma/Q4R-V43Y-N13A-L37E.npz',
 'nma/L21G-L57R-V69H-V31R-A32K.npz',
 'nma/D39V-K5G-Q68G-V10Y.npz',
 'nma/S30F-A11W-S70Q-V43N.npz',
 'nma/P50H-G35H-C14F-L2H.npz',
 'nma/P50W-G47M-K9H-T25P.npz',
 'nma/I41G-K40I-M12T.npz',
 'nma/K3Y-K9S-S16A.npz',
 'nma/G47I-V34Q-V31N-L2A.npz',
 'nma/T25R-A32Q.npz',
 'nma/V28L-D29T-L57G-Q68N-G26T.npz',
 'nma/L33Y.npz',
 'nma/R15E-V28M-M19F.npz',
 'nma/K5N-G45S-V61K-S30D-A32T.npz',
 'nma/D49K-G62C-I41K-M19H-R15I.npz',
 'nma/L53R.npz',
 'nma/G1A-L57I.npz',
 'nma/A23Y-K74F-G27H.npz',
 'nma/L53Q-L57W-K17Q-K40H.npz',
 'nma/V22P.npz',
 'nma/G1I-L53C.npz',
 'nma/G35E-D63E.npz',
 'nma/L33C-S55E-G1A

In [6]:
config = {
    # This section defines where your fine-tuned model and logs will be saved.
    "file_path": {
        "save_dir": "models/esmdance-mutant-nma-fine-tuned/", 
    },
    
    # General training settings.
    "training": {
        "random_seed": 42,
        "dropout": 0.1,
        
        # You should adjust these based on how often you want to save and log.
        # For a shorter fine-tuning run, you'll want to save more frequently.
        "save_per_epoch": 1, # It's easier to think in epochs for fine-tuning.
        
        # --- Feature Indices ---
        "res_feature_idx": {
            'nma_residue1': [0],
            'nma_residue2': [1],
            'nma_residue3': [2],
        },
        "pair_feature_idx": {
            'nma_pair1': [0],
            'nma_pair2': [1],
            'nma_pair3': [2],
        },
    },

    "esmdance": {
        "freeze_esm": True,      # Correct for ESMDance fine-tuning.
        "randomize_esm": False,
        
        # All your sequences are 158, so we set one max_len.
        # We add a little buffer, but it could be exactly 158.
        "max_len": 256,
        
        # Define training by epochs, which is more intuitive for a fixed dataset.
        "num_epochs": 20, # Adjust this based on how your loss behaves.
        
        # Set a single batch size. Adjust based on your GPU memory.
        "batch_size": 4,
        
        # Gradient accumulation helps simulate a larger batch size.
        # Effective batch size = batch_size * gradient_accumulation_steps
        # Example: 8 * 4 = 32
        "gradient_accumulation_steps": 4, 
    },

    # Optimizer settings. These are generally good starting points.
    "optimizer": {
        "peak_lr": 1e-4,
        "epsilon": 1e-8,
        "betas": (0.9, 0.98),
        "weight_decay": 0.01,
        "warmup_steps": 200, # Number of steps for learning rate warmup.
    },

    # --- CRITICAL CHANGE 3: Model Output Dimensions ---
    "model_35M": {
        "model_id": "facebook/esm2_t12_35M_UR50D",
        "atten_dim": 240,
        "embed_dim": 480,
        
        # These now match your NMA-only data.
        "pair_out_dim": 3, # Was 13. Now 3 for your 3 ANM correlation matrices.
        "res_out_dim": 3,  # Was 50. Now 3 for your 3 GNM fluctuation vectors.
    },
}

In [7]:
from torch.utils.data import random_split

full_dataset = FineTuneNMADataset(sequences, nma_paths)

# Split data into training and validation sets (e.g., 90% train, 10% val)
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(
    train_dataset, 
    batch_size=config['esmdance']['batch_size'], 
    shuffle=True, 
    collate_fn=collate_fn_nma
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config['esmdance']['batch_size'],
    shuffle=False, # No need to shuffle validation data
    collate_fn=collate_fn_nma
)

In [8]:
import math
import torch
from torch import nn
import torch.nn.functional as F
from transformers import EsmModel
from huggingface_hub import PyTorchModelHubMixin

In [9]:
class ESMwrap(nn.Module, PyTorchModelHubMixin):
    def __init__(self, esm2_select, model_select):
        super().__init__()
        # Load the ESM2 model
        self.esm2 = EsmModel.from_pretrained(config[esm2_select]['model_id'])
        self.freeze_esm = config[model_select]['freeze_esm']

        # Freeze self.esm2 parameters if freeze_esm is True
        if self.freeze_esm:
            for param in self.esm2.parameters():
                param.requires_grad = False
            self.esm2.eval()  # Set to evaluation mode

        # Randomize self.esm2 parameters if randomize_esm is True
        if config[model_select]['randomize_esm']:
            self.randomize_model(self.esm2)

        # dimensions of input and output
        embed_dim = config[esm2_select]['embed_dim']
        res_out_dim = config[esm2_select]['res_out_dim']
        atten_dim = config[esm2_select]['atten_dim']
        pair_out_dim = config[esm2_select]['pair_out_dim']

        # Residue-level prediction layer
        self.res_pred_nn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.LayerNorm(embed_dim),
            nn.Dropout(config['training']['dropout']),  # Apply dropout after LayerNorm
            nn.Linear(embed_dim, res_out_dim)
        )

        # transform res embedding for Pairwise prediction
        self.res_transform_nn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.LayerNorm(embed_dim),
            nn.Dropout(config['training']['dropout']),  # Apply dropout after LayerNorm
            nn.Linear(embed_dim, embed_dim*2)
        )

        # Pairwise prediction layer
        self.pair_middle_linear = nn.Linear(embed_dim*2, atten_dim)
        self.pair_pred_linear = nn.Linear(atten_dim + atten_dim, pair_out_dim)

        # Activation functions
        self.gelu = nn.GELU()
        self.softplus = nn.Softplus(beta=1.0, threshold=2.0)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=-1)

        # Feature indices from config
        self.res_feature_idx = config['training']['res_feature_idx']
        self.pair_feature_idx = config['training']['pair_feature_idx']

        # Initialize biases to zero
        self._init_bias_zero()

    def randomize_model(self, model):
        """ Randomize the parameters of the given model. """
        for module_ in model.named_modules():
            if isinstance(module_[1], (torch.nn.Linear, torch.nn.Embedding)):
                if hasattr(module_[1], 'bias') and module_[1].bias is not None:
                    module_[1].bias.data.zero_()
                if hasattr(module_[1], 'weight'):
                    if 'query' in module_[0] or 'key' in module_[0] or 'value' in module_[0]:
                        module_[1].weight = nn.init.xavier_uniform_(module_[1].weight, gain=1 / math.sqrt(2))
                    else:
                        module_[1].weight = nn.init.xavier_uniform_(module_[1].weight)
                            
            elif isinstance(module_[1], nn.LayerNorm):
                if hasattr(module_[1], 'bias'):
                    module_[1].bias.data.zero_()
                if hasattr(module_[1], 'weight'):
                    module_[1].weight.data.fill_(1.0)
                
            elif isinstance(module_[1], nn.Dropout):
                module_[1].p = config['training']['dropout']


    def _init_bias_zero(self):
        """ Set all biases in the model (excluding esm2) to zero. """
        for name, module in self.named_modules():
            if "esm2" not in name and isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)


    def forward(self, inputs, return_res_emb=False, return_attention_map=False, return_res_pred=True, return_pair_pred=True):
        output = {}

        # ESM forward pass, Ensure no gradients are stored for frozen ESM2
        if self.freeze_esm:
            with torch.no_grad():
                esm_output = self.esm2(**inputs, output_attentions=True)
        else:
            esm_output = self.esm2(**inputs, output_attentions=True)

        res_emb = esm_output['last_hidden_state']
        pair_atten = torch.cat(esm_output['attentions'], dim=1).permute(0, 2, 3, 1)

        if return_res_emb:
            output['res_emb'] = res_emb
        if return_attention_map:
            output['attention_map'] = pair_atten

        # Residue-level prediction
        if return_res_pred:
            res_pred = self.res_pred_nn(res_emb)
            for feature in self.res_feature_idx:
                if feature == 'rmsf_nor':
                    # Normalized RMSF (max = 1)
                    output[feature] = self.sigmoid(res_pred[:, :, self.res_feature_idx[feature]])
                elif feature in ['ss', 'chi', 'phi', 'psi']:
                    # Secondary structure, chi, phi, psi sum up to 1
                    output[feature] = self.softmax(res_pred[:, :, self.res_feature_idx[feature]])
                else:
                    # All other features are non-negative
                    output[feature] = self.softplus(res_pred[:, :, self.res_feature_idx[feature]])

        # Pairwise transformation
        s = self.res_transform_nn(res_emb)
        q, k = s.chunk(2, dim=-1)
        prod = q[:, None, :, :] * k[:, :, None, :]
        diff = q[:, None, :, :] - k[:, :, None, :]
        pair_middle = self.gelu(self.pair_middle_linear(torch.cat([prod, diff], dim=-1)))

        # Pairwise prediction
        if return_pair_pred:
            pair_pred = self.pair_pred_linear(torch.cat([pair_middle, pair_atten], dim=-1))

            for feature in self.pair_feature_idx:
                if feature in ['corr', 'nma_pair1', 'nma_pair2', 'nma_pair3']:
                    # Co-movement and NMA co-movement correlations: range [-1, 1]
                    output[feature] = self.sigmoid(pair_pred[:, :, :, self.pair_feature_idx[feature]]) * 2 - 1.0
                else:
                    # All interaction features are non-negative
                    output[feature] = self.softplus(pair_pred[:, :, :, self.pair_feature_idx[feature]])

        return output

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [11]:
# Set random seeds for reproducibility
torch.manual_seed(config['training']['random_seed'])
np.random.seed(config['training']['random_seed'])

In [12]:
from pathlib import Path
# Create save directory from config
save_dir = Path(config['file_path']['save_dir'])
save_dir.mkdir(parents=True, exist_ok=True)

In [13]:
model = ESMwrap(esm2_select='model_35M', model_select='esmdance').to(device)
checkpoint = torch.load('pretrained_weights/esmdance_update_60000.pt')
model_state_dict = model.state_dict()

filtered_state_dict = {
            k: v for k, v in checkpoint.items() 
            if k in model_state_dict and v.shape == model_state_dict[k].shape
        }
        
# Update our new model's state dict with the filtered weights
model_state_dict.update(filtered_state_dict)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
model.load_state_dict(model_state_dict, strict=False)

<All keys matched successfully>

In [15]:
from torchinfo import summary

batch_size = 1
sequence_length = 160 # 158 residues + 2 special tokens

dummy_input_ids = torch.randint(0, 33, (batch_size, sequence_length), dtype=torch.long)
dummy_attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long)

dummy_inputs_dict = {
    "input_ids": dummy_input_ids,
    "attention_mask": dummy_attention_mask
}

# This tells torchinfo to pass the dictionary as a single positional argument,
# which matches your model's forward(self, inputs) signature.
input_data_for_summary = (dummy_inputs_dict,)

summary(
    model,
    # Pass the TUPLE containing the dictionary
    input_data=input_data_for_summary,
    
    # The other summary arguments are for torchinfo itself and will no longer be passed to your model
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
    device="cpu"
)

Layer (type (var_name))                                                     Input Shape          Output Shape         Param #              Trainable
ESMwrap (ESMwrap)                                                           [1, 160]             [1, 160, 160, 1]     --                   Partial
├─EsmModel (esm2)                                                           --                   [1, 20, 160, 160]    241                  False
│    └─EsmEmbeddings (embeddings)                                           --                   [1, 160, 480]        492,480              False
│    │    └─Embedding (word_embeddings)                                     [1, 160]             [1, 160, 480]        (15,840)             False
│    └─EsmEncoder (encoder)                                                 [1, 160, 480]        [1, 20, 160, 160]    --                   False
│    │    └─ModuleList (layer)                                              --                   --                   (33,25

In [16]:
from torch.optim import AdamW
from torch.amp import GradScaler, autocast

loss_function = nn.MSELoss(reduction='none') # Use reduction='none' for custom masking
trainable_params = [p for p in model.parameters() if p.requires_grad]
print(f"Total trainable parameters: {sum(p.numel() for p in trainable_params):,}")
optimizer = AdamW(trainable_params, lr=config['optimizer']['peak_lr'], betas=config['optimizer']['betas'])

scaler = GradScaler() # For mixed-precision training

Total trainable parameters: 1,158,966


In [17]:
from tqdm import tqdm
from torch.amp import autocast

model.to(device)

num_epochs = config['esmdance']['num_epochs']
grad_accum_steps = config['esmdance']['gradient_accumulation_steps']

print(f"Starting fine-tuning for {num_epochs} epochs...")
for epoch in range(num_epochs):
    print(f'\nEpoch {epoch+1}/{num_epochs}\n------------------------------')
    # =======================================
    #               TRAINING
    # =======================================
    model.train() # Set the model to training mode
    total_train_loss = 0
    
    for i, (inputs, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Training]")):
        inputs = {key: val.to(device) for key, val in inputs.items()}
        labels = {key: val.to(device) for key, val in labels.items()}
        
        with autocast(device_type='cuda', dtype=torch.float16):
            predictions = model(inputs)
            
            # --- Loss Calculation ---
            res_loss = 0
            res_mask = labels['nma_residue1'] != -1
            for k in ['nma_residue1', 'nma_residue2', 'nma_residue3']:
                pred_k = predictions[k].squeeze(-1)
                label_k = labels[k]
                element_wise_loss = loss_function(pred_k, label_k)
                valid_losses = element_wise_loss[res_mask]
                res_loss += valid_losses.mean()

            pair_loss = 0
            pair_mask = labels['nma_pair1'] != -1
            for k in ['nma_pair1', 'nma_pair2', 'nma_pair3']:
                pred_pair_k = predictions[k].squeeze(-1)
                label_pair_k = labels[k]
                element_wise_loss = loss_function(pred_pair_k, label_pair_k)
                valid_losses = element_wise_loss[pair_mask]
                pair_loss += valid_losses.mean()
            
            loss = (3 * pair_loss + res_loss) / grad_accum_steps
        
        # --- Gradient Accumulation & Backpropagation ---
        scaler.scale(loss).backward()
        
        if (i + 1) % grad_accum_steps == 0 or (i + 1) == len(train_loader):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        total_train_loss += loss.item() * grad_accum_steps
        
    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Training Loss: {avg_train_loss:.4f}")

    # =======================================
    #              VALIDATION
    # =======================================
    model.eval() # Set the model to evaluation mode
    total_val_loss = 0
    
    # Disable gradient calculations for validation to save memory and compute
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(tqdm(val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Validation]")):
            inputs = {key: val.to(device) for key, val in inputs.items()}
            labels = {key: val.to(device) for key, val in labels.items()}
            
            # Forward pass only, still use autocast for consistency
            with autocast(device_type='cuda', dtype=torch.float16):
                predictions = model(inputs)

                # --- Loss Calculation (Identical to training) ---
                res_loss = 0
                res_mask = labels['nma_residue1'] != -1
                for k in ['nma_residue1', 'nma_residue2', 'nma_residue3']:
                    pred_k = predictions[k].squeeze(-1)
                    label_k = labels[k]
                    element_wise_loss = loss_function(pred_k, label_k)
                    valid_losses = element_wise_loss[res_mask]
                    res_loss += valid_losses.mean()

                pair_loss = 0
                pair_mask = labels['nma_pair1'] != -1
                for k in ['nma_pair1', 'nma_pair2', 'nma_pair3']:
                    pred_pair_k = predictions[k].squeeze(-1)
                    label_pair_k = labels[k]
                    element_wise_loss = loss_function(pred_pair_k, label_pair_k)
                    valid_losses = element_wise_loss[pair_mask]
                    pair_loss += valid_losses.mean()
                
                # Note: We do NOT divide by grad_accum_steps for validation loss
                val_loss = 3 * pair_loss + res_loss
            
            total_val_loss += val_loss.item()
            
    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Validation Loss: {avg_val_loss:.4f}")

Starting fine-tuning for 20 epochs...

Epoch 1/20
------------------------------


Epoch 1/20 [Training]: 100%|██████████| 68/68 [00:06<00:00, 10.87it/s]


Training Loss: 1.0352


Epoch 1/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 22.13it/s]


Validation Loss: 0.4626

Epoch 2/20
------------------------------


Epoch 2/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.37it/s]


Training Loss: 0.3651


Epoch 2/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 22.93it/s]


Validation Loss: 0.2625

Epoch 3/20
------------------------------


Epoch 3/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.41it/s]


Training Loss: 0.2364


Epoch 3/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 20.22it/s]


Validation Loss: 0.1970

Epoch 4/20
------------------------------


Epoch 4/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.01it/s]


Training Loss: 0.1832


Epoch 4/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 23.21it/s]


Validation Loss: 0.1592

Epoch 5/20
------------------------------


Epoch 5/20 [Training]: 100%|██████████| 68/68 [00:01<00:00, 42.40it/s]


Training Loss: 0.1487


Epoch 5/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 18.80it/s]


Validation Loss: 0.1296

Epoch 6/20
------------------------------


Epoch 6/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.09it/s]


Training Loss: 0.1218


Epoch 6/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 23.38it/s]


Validation Loss: 0.1063

Epoch 7/20
------------------------------


Epoch 7/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.89it/s]


Training Loss: 0.1008


Epoch 7/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 23.50it/s]


Validation Loss: 0.0881

Epoch 8/20
------------------------------


Epoch 8/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.67it/s]


Training Loss: 0.0847


Epoch 8/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 23.03it/s]


Validation Loss: 0.0740

Epoch 9/20
------------------------------


Epoch 9/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.23it/s]


Training Loss: 0.0725


Epoch 9/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 21.94it/s]


Validation Loss: 0.0632

Epoch 10/20
------------------------------


Epoch 10/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.38it/s]


Training Loss: 0.0633


Epoch 10/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 23.54it/s]


Validation Loss: 0.0549

Epoch 11/20
------------------------------


Epoch 11/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.49it/s]


Training Loss: 0.0557


Epoch 11/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 22.24it/s]


Validation Loss: 0.0482

Epoch 12/20
------------------------------


Epoch 12/20 [Training]: 100%|██████████| 68/68 [00:01<00:00, 53.47it/s]


Training Loss: 0.0497


Epoch 12/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 24.03it/s]


Validation Loss: 0.0429

Epoch 13/20
------------------------------


Epoch 13/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.71it/s]


Training Loss: 0.0449


Epoch 13/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 23.12it/s]


Validation Loss: 0.0383

Epoch 14/20
------------------------------


Epoch 14/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.95it/s]


Training Loss: 0.0410


Epoch 14/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 23.26it/s]


Validation Loss: 0.0348

Epoch 15/20
------------------------------


Epoch 15/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.58it/s]


Training Loss: 0.0374


Epoch 15/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 22.42it/s]


Validation Loss: 0.0316

Epoch 16/20
------------------------------


Epoch 16/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.71it/s]


Training Loss: 0.0346


Epoch 16/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 21.96it/s]


Validation Loss: 0.0292

Epoch 17/20
------------------------------


Epoch 17/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.84it/s]


Training Loss: 0.0321


Epoch 17/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 24.66it/s]


Validation Loss: 0.0267

Epoch 18/20
------------------------------


Epoch 18/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.22it/s]


Training Loss: 0.0300


Epoch 18/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 22.42it/s]


Validation Loss: 0.0249

Epoch 19/20
------------------------------


Epoch 19/20 [Training]: 100%|██████████| 68/68 [00:01<00:00, 44.30it/s]


Training Loss: 0.0281


Epoch 19/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 21.27it/s]


Validation Loss: 0.0233

Epoch 20/20
------------------------------


Epoch 20/20 [Training]: 100%|██████████| 68/68 [00:04<00:00, 15.46it/s]


Training Loss: 0.0266


Epoch 20/20 [Validation]: 100%|██████████| 8/8 [00:00<00:00, 22.28it/s]

Validation Loss: 0.0219





In [18]:
model_save_path = save_dir / "esmdance_fine-tuned_with_nma_data.pth"
torch.save(model.state_dict(), model_save_path)