In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import sys

# Import the Mamba model
from mamba_wf import Mamba, RMSNorm

# Seed PyTorch's global RNG so repeated runs start from the same random state.
torch.manual_seed(1)

# Set device.
# The second GPU (index 1) is preferred, as before. On a machine that exposes only
# one CUDA device, 'cuda:1' does not exist and every later `.to(device)` would fail,
# so fall back to GPU 0 there, and to the CPU when CUDA is unavailable.
PREFERRED_GPU_INDEX = 1
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda:1


In [2]:
# Hyperparameters
BATCH_SIZE = 100
LEARNING_RATE = 5e-4
EPOCHS = 70
TRAIN_SPLIT = 0.9  # fraction of samples used for training; the remainder is validation

# Model sizes chosen by hand (not derived from the data)
d_model = 16
state_size = 64

# Load the simulated inputs and their target parameters
x = torch.load('test_sims/x_combined.pt')
theta = torch.load('test_sims/theta_combined.pt')
# Give every timestep an explicit feature axis of size 1
x = x.reshape(x.shape[0], x.shape[1], 1)

# Print shapes for verification
print("Input data shape:", x.shape)
print("Target data shape:", theta.shape)

# Dimensions read off the loaded tensor
n_samples, seq_len = x.shape[0], x.shape[1]

# Positional split: the leading TRAIN_SPLIT fraction trains, the tail validates.
# NOTE(review): nothing is shuffled before the split; if the combined file is
# ordered (e.g. by simulator) the validation tail may not be representative — confirm.
train_size = int(n_samples * TRAIN_SPLIT)
train_x, val_x = x[:train_size], x[train_size:]
train_theta, val_theta = theta[:train_size], theta[train_size:]

# Wrap the tensors in datasets; only the training loader reshuffles its batches
train_dataset = TensorDataset(train_x, train_theta)
val_dataset = TensorDataset(val_x, val_theta)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

Input data shape: torch.Size([100000, 40, 1])
Target data shape: torch.Size([100000, 7])


In [3]:
# Wrapper module: a Mamba backbone followed by a linear head that maps
# d_model features to output_size values (theta.shape[1] in this notebook).
class MambaWithOutputLayer(nn.Module):
    """Mamba backbone plus a linear read-out applied to the final timestep.

    Args:
        seq_len: forwarded to ``Mamba`` as ``seq_len``.
        d_model: forwarded to ``Mamba`` and used as the input width of the head.
        state_size: forwarded to ``Mamba`` as ``state_size``.
        output_size: number of values produced per sample by the linear head.
        batch_size: forwarded to ``Mamba`` as ``batch_size``.
        device: forwarded to ``Mamba`` as ``device``.
    """

    def __init__(self, seq_len, d_model, state_size, output_size, batch_size, device):
        super().__init__()
        # Attribute names are part of the state_dict keys stored in checkpoints.
        self.mamba = Mamba(
            seq_len=seq_len,
            d_model=d_model,
            state_size=state_size,
            batch_size=batch_size,
            device=device,
        )
        self.output_layer = nn.Linear(d_model, output_size)

    def forward(self, x):
        """Run the backbone and project its last timestep to ``output_size`` values.

        The backbone output is indexed as a 3-D tensor; per the original author's
        note it has shape [batch_size, seq_len, d_model].

        Returns:
            Tensor of shape [batch_size, output_size].
        """
        sequence_features = self.mamba(x)
        # Keep only the final position along the sequence axis: [batch_size, d_model]
        final_step = sequence_features[:, -1, :]
        return self.output_layer(final_step)

In [None]:
# Build the network; the head emits one value per column of the targets
n_params = train_theta.shape[1]
model = MambaWithOutputLayer(
    seq_len=seq_len,
    d_model=d_model,
    state_size=state_size,
    output_size=n_params,
    batch_size=BATCH_SIZE,
    device=device,
)
# NOTE(review): Mamba receives a fixed batch_size; whether a smaller final batch
# is accepted depends on mamba_wf, which is not visible here — confirm.
model = model.to(device)

# Mean-squared-error objective optimised with Adam
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop
def train():
    """Fit the notebook-level ``model`` for ``EPOCHS`` epochs.

    Each epoch makes one optimisation pass over ``train_loader`` and then a
    gradient-free pass over ``val_loader``. Both reported losses are per-sample
    averages: every batch loss is weighted by its batch size and the running
    sum is divided by the dataset length. After every 10th epoch a checkpoint
    with the model and optimizer state is written to
    ``mamba_checkpoint_epoch_<N>.pt``.

    Uses the globals ``model``, ``optimizer``, ``criterion``, ``device``,
    ``train_loader``, ``val_loader`` and ``EPOCHS``.

    Returns:
        tuple[list[float], list[float]]: per-epoch training losses and
        per-epoch validation losses, in epoch order.
    """
    train_history, val_history = [], []

    for epoch in range(EPOCHS):
        # ---- optimisation pass over the training set ----
        model.train()
        train_total = 0.0
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} - Training")
        for batch_x, batch_theta in train_bar:
            batch_x, batch_theta = batch_x.to(device), batch_theta.to(device)

            # Clear stale gradients, then forward / backward / update
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_x), batch_theta)
            batch_loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch figure is a per-sample average
            train_total += batch_loss.item() * batch_x.size(0)

        epoch_train_loss = train_total / len(train_loader.dataset)
        train_history.append(epoch_train_loss)

        # ---- validation pass (no gradients tracked) ----
        model.eval()
        val_total = 0.0
        with torch.no_grad():
            val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} - Validation")
            for batch_x, batch_theta in val_bar:
                batch_x, batch_theta = batch_x.to(device), batch_theta.to(device)
                batch_loss = criterion(model(batch_x), batch_theta)
                val_total += batch_loss.item() * batch_x.size(0)

        epoch_val_loss = val_total / len(val_loader.dataset)
        val_history.append(epoch_val_loss)

        print(f"Epoch {epoch+1}/{EPOCHS}, "
              f"Train Loss: {epoch_train_loss:.4f}, "
              f"Val Loss: {epoch_val_loss:.4f}")

        # Save model checkpoint on every 10th epoch
        is_checkpoint_epoch = (epoch + 1) % 10 == 0
        if is_checkpoint_epoch:
            checkpoint = {
                'epoch': epoch,  # zero-based index; the file name below is one-based
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': epoch_train_loss,
                'val_loss': epoch_val_loss,
            }
            torch.save(checkpoint, f'mamba_checkpoint_epoch_{epoch+1}.pt')

    return train_history, val_history

# Plot training and validation loss
def plot_loss(train_losses, val_losses):
    """Plot both loss curves on shared axes and write them to 'loss_plot.png'.

    The figure is closed after saving, so nothing is displayed inline.

    Args:
        train_losses: sequence of per-epoch training losses.
        val_losses: sequence of per-epoch validation losses.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    curves = ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss'))
    for series, curve_label in curves:
        ax.plot(series, label=curve_label)

    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.legend()

    fig.savefig('loss_plot.png')
    plt.close(fig)

# Evaluate the model on the validation set
def evaluate():
    """Compute and print the per-sample average loss of ``model`` on ``val_loader``.

    The model is switched to eval mode and run without gradient tracking. Each
    batch loss is weighted by its batch size and the total is divided by the
    number of samples in the validation dataset.

    Uses the globals ``model``, ``criterion``, ``device`` and ``val_loader``.

    Returns:
        float: the average validation loss.
    """
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc="Evaluating"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            loss = criterion(model(inputs), targets)
            val_loss += loss.item() * inputs.size(0)

    # The previous version also copied every batch of outputs and targets to the
    # host and concatenated them, but never used or returned the result; that
    # dead work has been removed.
    avg_val_loss = val_loss / len(val_loader.dataset)
    print(f"Validation Loss: {avg_val_loss:.4f}")

    return avg_val_loss

# Entry point for the training cell: train, save the loss plot, then report the
# validation loss. The results are bound to module-level names
# (train_losses, val_losses, avg_val_loss).
if __name__ == "__main__":
    print("Starting training...")
    train_losses, val_losses = train()
    # Writes 'loss_plot.png'; the figure is closed rather than displayed.
    plot_loss(train_losses, val_losses)
    
    print("\nEvaluating model...")
    avg_val_loss = evaluate()
    
    # Nothing is saved here: the only saves are the checkpoints train() writes
    # on every 10th epoch.
    print("Training and evaluation completed. Model saved.")

Starting training...


Epoch 1/70 - Training: 100%|██████████| 900/900 [00:59<00:00, 15.24it/s]
Epoch 1/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.12it/s]


Epoch 1/70, Train Loss: 2.6843, Val Loss: 1.3122


Epoch 2/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.09it/s]
Epoch 2/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.60it/s]


Epoch 2/70, Train Loss: 1.2594, Val Loss: 1.2469


Epoch 3/70 - Training:   2%|▏         | 18/900 [00:01<00:53, 16.39it/s]

In [8]:
# Test the trained model on a test set from a specific simulator
def test_mamba(sim):
    """Run the global ``model`` on the test set of simulator ``sim`` and save the predictions.

    Reads ``test_sims/test_x_{sim}.pt`` and ``test_sims/test_theta_{sim}.pt``,
    predicts in batches of ``BATCH_SIZE`` without shuffling (so row i of the
    predictions corresponds to row i of the targets), and writes the result to
    ``test_sims/predictions_{sim}_mamba_combined.pt``.

    Args:
        sim: simulator name used to build the input and output file names.

    Returns:
        tuple: ``(predictions, test_theta)`` where ``predictions`` is a CPU
        tensor with one row per test sample and ``test_theta`` is the loaded
        target tensor.
    """
    # Load files and initialize datasets
    test_x = torch.load(f'test_sims/test_x_{sim}.pt')
    test_x = test_x.reshape(test_x.shape[0], test_x.shape[1], 1)
    test_theta = torch.load(f'test_sims/test_theta_{sim}.pt')
    test_dataset = TensorDataset(test_x, test_theta)
    # Order must be preserved so predictions line up with test_theta
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Put the model in inference mode explicitly instead of relying on whatever
    # mode an earlier (possibly interrupted) cell left it in.
    model.eval()

    # Begin evaluating
    predictions = []
    with torch.no_grad():
        # Targets are not needed for prediction, so they stay where they are.
        for inputs, _ in tqdm(test_loader, desc="Evaluating"):
            outputs = model(inputs.to(device))
            # Collect on the CPU so the saved file is not tied to a specific GPU
            # and the result is directly comparable with test_theta.
            predictions.append(outputs.cpu())

    # Get all predictions
    predictions = torch.cat(predictions)
    torch.save(predictions, f'test_sims/predictions_{sim}_mamba_combined.pt')

    return predictions, test_theta

# Evaluate each listed test set. After the loop, `y` holds the
# (predictions, test_theta) tuple returned by the last call.
for sim in ['combined']:
    y = test_mamba(sim)

Evaluating: 100%|██████████| 10/10 [00:00<00:00, 237.89it/s]


In [9]:
# Count parameters in the Mamba model
# Note: this notebook shows the 2-mutation case, so the count differs from Table 1.
def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is False are not counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total
# Trailing expression: the notebook displays the returned count as the cell output.
count_parameters(model)

44687