In [8]:
from ProteinTrajectoryDataset import *
import numpy as np
import torch
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import optuna
from Models import *

# GPU indices available to this run; cnn_objective wraps the model in
# torch.nn.DataParallel only when more than one index is listed here.
device_ids = [0]
# Use the first CUDA device when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(f"Using device {device}, device_ids={device_ids}")

# Autoreload mode 2: re-import edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2

# Passed as the first positional argument to EGNN_Model when the model is built.
MAX_PROTEIN_LENGTH = 50

Using device cuda:0, device_ids=[0]
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Build the training and validation datasets from 'train_trajectories' and
# 'val_trajectories' (presumably directories of trajectory files — confirm).
# NOTE(review): n_steps=1 presumably sets the prediction horizon in steps — confirm in ProteinTrajectoryDataset.
train_dataset = ProteinTrajectoryDataset('train_trajectories', n_steps=1)
val_dataset = ProteinTrajectoryDataset('val_trajectories', n_steps=1)

Using device cuda:0, device_ids=[0]
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Loading trajectories: 100%|████████████████████████████████████████████████████████████| 80/80 [03:08<00:00,  2.35s/it]
Loading trajectories: 100%|████████████████████████████████████████████████████████████| 20/20 [00:41<00:00,  2.05s/it]


NameError: name 'collate_fn' is not defined

In [4]:
def collate_fn(batch):
    """Collate variable-length samples into zero-padded batch tensors.

    Args:
        batch: Sequence of dict samples, each holding tensors under the keys
            'coords', 'residues' and 'delta'.

    Returns:
        Tuple ``(coords, residues, deltas, lengths)``. The first three are the
        per-sample tensors padded with zeros along dim 1 to the longest sample
        in the batch (batch dimension first). ``lengths`` is a 1-D tensor with
        each sample's original ``coords.shape[0]``.
    """
    fields = ('coords', 'residues', 'delta')
    # One list of per-sample tensors for every field, in the order of `fields`.
    columns = [[sample[key] for sample in batch] for key in fields]

    # Record the true (unpadded) lengths before padding, taken from the coords tensors.
    lengths = torch.tensor([seq.shape[0] for seq in columns[0]])

    # batch_first=True gives tensors shaped [batch, max_length, ...].
    coords, residues, deltas = (
        pad_sequence(column, batch_first=True) for column in columns
    )
    return coords, residues, deltas, lengths

def masked_mse_loss(pred, target, lengths):
    """Squared error over valid (non-padded) positions of a padded batch.

    Args:
        pred: Predictions, a 3-D tensor (batch_size, max_length, features).
        target: Targets, broadcast-compatible with ``pred``.
        lengths: 1-D tensor of valid lengths per sample (batch_size,). It is
            compared against indices created on ``pred.device``.

    Returns:
        Scalar tensor: the squared errors summed over all valid positions and
        features, divided by the number of valid positions. The divisor counts
        positions, not individual elements, so the errors of one position's
        features are added together rather than averaged.
    """
    n_batch, n_steps, _ = pred.shape

    # Position index grid (n_batch, n_steps); a position is valid while its
    # index is below the sample's length.
    positions = torch.arange(n_steps, device=pred.device).expand(n_batch, n_steps)
    valid = positions < lengths.unsqueeze(1)
    # Trailing singleton dim lets the mask broadcast over the feature axis.
    valid = valid.unsqueeze(2).float()

    squared_error = (pred - target) ** 2
    masked_total = (squared_error * valid).sum()
    return masked_total / valid.sum()

def train_model(model, dataloader, optimizer, num_epochs=10):
    """Train ``model`` for ``num_epochs`` passes over ``dataloader``.

    Each batch is expected to unpack to (coords, residues, deltas, lengths).
    Tensors are moved to the module-level ``device``, the model is called as
    ``model(coords, residues, lengths)`` and optimised against
    ``masked_mse_loss``. A tqdm bar shows the running and current loss, and
    the sample-weighted average loss is printed after every epoch.

    Args:
        model: Module to train; assumed to already live on ``device``.
        dataloader: Iterable of batches as described above.
        optimizer: Optimizer bound to ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        None.
    """
    for epoch in range(num_epochs):
        model.train()
        loss_sum = 0.0
        samples_seen = 0

        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch in progress:
            # Transfer the whole batch to the compute device in one go.
            coords, residues, deltas, lengths = (tensor.to(device) for tensor in batch)

            optimizer.zero_grad()

            # Predict the deltas and score them on the non-padded positions only.
            loss = masked_mse_loss(model(coords, residues, lengths), deltas, lengths)

            loss.backward()
            optimizer.step()

            # Weight each batch loss by its sample count so the running
            # average is per-sample rather than per-batch.
            step_loss = loss.item()
            n_in_batch = coords.size(0)
            loss_sum += step_loss * n_in_batch
            samples_seen += n_in_batch

            progress.set_postfix(avg_loss=loss_sum / samples_seen, current_loss=step_loss)

        # Epoch summary.
        epoch_loss = loss_sum / samples_seen
        print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {epoch_loss:.5f}")

def eval_model(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` and report the average masked loss.

    The model is switched to eval mode and every batch, expected to unpack to
    (coords, residues, deltas, lengths), is moved to the module-level
    ``device`` and scored with ``masked_mse_loss``. The whole pass runs under
    ``torch.no_grad()`` so no autograd graph is built.

    Args:
        model: Module to evaluate; assumed to already live on ``device``.
            It is left in eval mode afterwards.
        dataloader: Iterable of batches as described above.

    Returns:
        float: The sample-weighted average loss over the loader (also printed).

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    total_samples = 0

    pbar = tqdm(dataloader, desc="Evaluating")
    # Gradients are never used during evaluation; disabling them avoids
    # building a graph per batch and the memory that comes with it.
    with torch.no_grad():
        for coords, residues, deltas, lengths in pbar:

            # Move data to device
            coords = coords.to(device)
            residues = residues.to(device)
            deltas = deltas.to(device)
            lengths = lengths.to(device)

            # Forward pass: predict deltas from the current coordinates.
            pred_deltas = model(coords, residues, lengths)

            # Compute the masked MSE loss.
            batch_loss = masked_mse_loss(pred_deltas, deltas, lengths).item()

            # Weight by batch size so the running average is per-sample.
            batch_size = coords.size(0)
            running_loss += batch_loss * batch_size
            total_samples += batch_size

            avg_loss = running_loss / total_samples
            pbar.set_postfix(avg_loss=avg_loss, current_loss=batch_loss)

    # Print the summary and hand the value back so callers can log or compare it.
    epoch_loss = running_loss / total_samples
    print(f"Evaluating - Average Loss: {epoch_loss:.5f}")
    return epoch_loss

def cnn_objective(trial):
    """Optuna objective: train a ``CNN_Model`` and return its validation loss.

    Samples optimiser settings, batch size, CNN channel/kernel sizes and a
    coordinate scaling factor from ``trial``, trains on the module-level
    ``train_dataset`` and evaluates on ``val_dataset`` after every epoch.
    The per-epoch validation loss is reported to Optuna so the study's pruner
    can stop the trial early.

    Args:
        trial: The ``optuna.trial.Trial`` supplying hyperparameters.

    Returns:
        float: Validation loss of the last completed epoch, expressed in
        unscaled coordinate units.

    Raises:
        optuna.TrialPruned: When the pruner decides to stop this trial.
    """
    # Sample hyperparameters
    lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
    optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'SGD'])
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128, 256, 512, 1024])

    # Prepare data
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    # Validation is not shuffled: the loss is a batch-size-weighted mean of
    # per-batch losses, so a fixed batch composition keeps the value reported
    # to the pruner comparable from epoch to epoch.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # Build model
    model = CNN_Model(
        channels_1=trial.suggest_categorical('channels_1', [4, 8, 16, 32, 64, 128]),
        channels_2=trial.suggest_categorical('channels_2', [4, 8, 16, 32, 64, 128]),
        kernel_size_1=trial.suggest_categorical('kernel_size_1', [1, 3, 5, 7]),
        kernel_size_2=trial.suggest_categorical('kernel_size_2', [1, 3, 5, 7]),
        kernel_size_3=trial.suggest_categorical('kernel_size_3', [1, 3, 5, 7]),
    )

    # Wrap with DataParallel if multiple GPUs are specified
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)
    model = model.to(device)

    # Choose optimizer
    params = model.parameters()
    if optimizer_name == 'Adam':
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    else:
        optimizer = torch.optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9)

    # Inputs and targets are multiplied by this factor before entering the model.
    coordinate_factor = trial.suggest_float('coordinate_factor', 0.1, 10, log=True)

    # Upper bound on training length; pruning may end the trial sooner.
    max_epochs = 100
    for epoch in range(max_epochs):
        model.train()
        # `tqdm` is the callable imported via `from tqdm import tqdm`.
        for coords, residues, deltas, lengths in tqdm(train_loader):
            coords = coords.to(device) * coordinate_factor
            residues = residues.to(device)
            deltas = deltas.to(device) * coordinate_factor
            lengths = lengths.to(device)

            optimizer.zero_grad()
            pred_deltas = model(coords, residues, lengths)
            loss = masked_mse_loss(pred_deltas, deltas, lengths)
            loss.backward()
            optimizer.step()

        # Intermediate evaluation
        running_loss = 0
        total_samples = 0
        model.eval()
        # No gradients are needed for validation; skip building the graph.
        with torch.no_grad():
            for coords, residues, deltas, lengths in tqdm(val_loader):
                coords = coords.to(device) * coordinate_factor
                residues = residues.to(device)
                deltas = deltas.to(device) * coordinate_factor
                lengths = lengths.to(device)

                pred_deltas = model(coords, residues, lengths)
                loss = masked_mse_loss(pred_deltas, deltas, lengths)

                # Separate name so the sampled `batch_size` is not shadowed.
                n_in_batch = coords.size(0)
                running_loss += loss.item() * n_in_batch
                total_samples += n_in_batch

        # The loss is a squared error of scaled deltas; dividing by the
        # squared factor puts it back in unscaled units so that trials with
        # different factors can be compared.
        val_loss = running_loss / total_samples / (coordinate_factor ** 2)
        trial.report(val_loss, epoch)

        # Optuna pruning
        if trial.should_prune():
            raise optuna.TrialPruned()

    return val_loss

In [5]:
# Training loader: shuffled batches of 256 samples from train_dataset, padded per batch by collate_fn.
dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)

In [13]:
from Models import *

# Alternative model, kept commented out for quick switching:
#model = CNN_Model().to(device)
# NOTE(review): depth=1 and emb_dim=8 look like a deliberately small EGNN configuration — confirm intent.
model = EGNN_Model(MAX_PROTEIN_LENGTH, depth=1, emb_dim=8).to(device)

# Adam over all model parameters with a fixed learning rate of 0.01.
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [16]:
# Train on the training loader; the recorded run was interrupted manually during epoch 1 of 1000.
train_model(model, dataloader, optimizer, num_epochs=1000)

Epoch 1/1000:  87%|██████████████████████▌   | 271/313 [01:13<00:11,  3.70it/s, avg_loss=0.00266, current_loss=0.00245]


KeyboardInterrupt: 

In [None]:
# NOTE: `dataloader` wraps train_dataset, so this reports the training-set loss, not a validation loss.
eval_model(model, dataloader)

In [69]:
from Models import *
# Baseline: evaluate Stationary_Model on the same loader
# (presumably it predicts zero displacement — confirm in Models).
# NOTE(review): the inline figure 0.00262 disagrees with the recorded output below (0.26226);
# re-run to confirm which value is current.
eval_model(Stationary_Model().to(device), dataloader) # 0.00262


Evaluating:   0%|                                                                              | 0/399 [00:00<?, ?it/s][A
Evaluating:   0%|                                          | 0/399 [00:00<?, ?it/s, avg_loss=0.271, current_loss=0.271][A
Evaluating:   0%|                                          | 0/399 [00:00<?, ?it/s, avg_loss=0.268, current_loss=0.265][A
Evaluating:   0%|                                           | 0/399 [00:00<?, ?it/s, avg_loss=0.26, current_loss=0.242][A
Evaluating:   0%|                                          | 0/399 [00:00<?, ?it/s, avg_loss=0.263, current_loss=0.272][A
Evaluating:   1%|▎                                 | 4/399 [00:00<00:10, 37.58it/s, avg_loss=0.263, current_loss=0.272][A
Evaluating:   1%|▎                                  | 4/399 [00:00<00:10, 37.58it/s, avg_loss=0.264, current_loss=0.27][A
Evaluating:   1%|▎                                 | 4/399 [00:00<00:10, 37.58it/s, avg_loss=0.265, current_loss=0.266][A
Evaluating:   1

Evaluating - Average Loss: 0.26226





In [77]:
# Sanity check: print the model's predictions for a single batch.
# collate_fn yields (coords, residues, deltas, lengths) and the model is called
# as model(coords, residues, lengths), the same way train_model and eval_model
# call it; the previous three-value unpacking no longer matched the loader.
with torch.no_grad():  # inspection only, no gradients needed
    for coords, residues, _, lengths in dataloader:

        # Move the model inputs to the device (the target deltas are not needed here).
        coords = coords.to(device)
        residues = residues.to(device)
        lengths = lengths.to(device)

        # Forward pass: predict deltas from the current coordinates.
        pred_deltas = model(coords, residues, lengths)
        print(pred_deltas)
        break

tensor([[[ 0.0177,  0.0076,  0.0145],
         [ 0.0380, -0.0054, -0.0150],
         [ 0.0084, -0.0430, -0.0192],
         ...,
         [ 0.0533, -0.0061, -0.0016],
         [ 0.0430, -0.0043,  0.0008],
         [-0.0049,  0.0170,  0.0106]],

        [[ 0.0305, -0.0143,  0.0384],
         [ 0.0234, -0.0359,  0.0153],
         [ 0.0332, -0.0174,  0.0114],
         ...,
         [ 0.0533, -0.0061, -0.0016],
         [ 0.0430, -0.0043,  0.0008],
         [-0.0049,  0.0170,  0.0106]],

        [[ 0.0356, -0.0230,  0.0304],
         [ 0.0521, -0.0095,  0.0244],
         [ 0.0305, -0.0032, -0.0018],
         ...,
         [ 0.0533, -0.0061, -0.0016],
         [ 0.0430, -0.0043,  0.0008],
         [-0.0049,  0.0170,  0.0106]],

        ...,

        [[ 0.0338, -0.0177, -0.0106],
         [ 0.0272,  0.0040,  0.0302],
         [ 0.0475,  0.0062,  0.0303],
         ...,
         [ 0.0533, -0.0061, -0.0016],
         [ 0.0430, -0.0043,  0.0008],
         [-0.0049,  0.0170,  0.0106]],

        [[

In [27]:
# Minimise the validation loss returned by cnn_objective. The median pruner
# stops a trial whose reported intermediate value is worse than the median of
# earlier trials at the same step.
# NOTE(review): pruning stays disabled until n_startup_trials=5 trials have
# finished, so with n_trials=1 this run can never prune — raise n_trials for a
# real search.
study = optuna.create_study(
        direction='minimize',
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=2)
    )
study.optimize(cnn_objective, n_trials=1)

[I 2025-04-25 11:50:37,903] A new study created in memory with name: no-name-663c8c93-9cfe-470a-a551-431c4a0c6214
100%|████████████████████████████████████████████████████████████████████████████████| 640/640 [00:11<00:00, 55.65it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:02<00:00, 57.20it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 640/640 [00:11<00:00, 56.01it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:02<00:00, 58.66it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 640/640 [00:11<00:00, 54.77it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:02<00:00, 53.06it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 640/640 [00:13<00:00, 49.07it/s]
100%|█████████████████████████████████████████

KeyboardInterrupt: 