In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import sys

# Import the Mamba model
from mamba_wf import Mamba, RMSNorm

# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(1)

# Set device: prefer the second GPU (index 1). `torch.cuda.is_available()` is
# also true on single-GPU machines, where 'cuda:1' is not a valid device, so
# fall back to 'cuda:0' there and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Load the data: model inputs `x` and the targets `theta` to regress
x = torch.load('test_sims/x_combined.pt')
# Append a trailing singleton axis so each timestep carries a single feature
x = x.reshape(x.size(0), x.size(1), 1)
theta = torch.load('test_sims/theta_combined.pt')

# Print shapes for verification
print("Input data shape:", x.shape)
print("Target data shape:", theta.shape)

# Hyperparameters
# Number of samples per batch for every DataLoader in this notebook
BATCH_SIZE = 100
# Step size handed to the Adam optimizer
LEARNING_RATE = 5e-4
# Number of full passes over the training set
EPOCHS = 70
TRAIN_SPLIT = 0.9  # 90% for training, 10% for validation

# Extract dimensions from the data
n_samples = x.shape[0]
seq_len = x.shape[1]
# Model width: passed to Mamba as d_model and used as the output layer's input size
d_model = 16
# State size passed to the Mamba constructor
state_size = 64

# Split the data into training and validation sets.
# The split is positional: the first `train_size` samples are used for training
# and the remaining ones for validation (no shuffling happens before the split).
train_size = int(n_samples * TRAIN_SPLIT)
train_x, val_x = x[:train_size], x[train_size:]
train_theta, val_theta = theta[:train_size], theta[train_size:]

# Create DataLoaders; only the training loader reshuffles its samples
train_dataset = TensorDataset(train_x, train_theta)
val_dataset = TensorDataset(val_x, val_theta)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [None]:
# Wrapper module that adds a linear output layer mapping from d_model to n_params
# (which is theta.shape[1])
class MambaWithOutputLayer(nn.Module):
    """Mamba backbone followed by a linear head applied to the last timestep.

    Args:
        seq_len: Sequence length forwarded to the ``Mamba`` constructor.
        d_model: Feature width forwarded to ``Mamba``; also the input size of
            the linear head.
        state_size: State size forwarded to the ``Mamba`` constructor.
        output_size: Number of values predicted per sample.
        device: Device forwarded to the ``Mamba`` constructor.
    """

    def __init__(self, seq_len, d_model, state_size, output_size, device):
        super().__init__()
        self.mamba = Mamba(seq_len=seq_len, d_model=d_model, state_size=state_size, device=device)
        self.output_layer = nn.Linear(in_features=d_model, out_features=output_size)

    def forward(self, x):
        """Map a batch of sequences to one prediction vector per sample.

        The backbone output is expected to have shape
        [batch_size, seq_len, d_model] (assumption: ``Mamba`` is defined in
        another module). Only the final timestep is kept and projected,
        giving a result of shape [batch_size, output_size].
        """
        sequence_features = self.mamba(x)
        # Keep only the last timestep: [batch_size, d_model]
        final_step = sequence_features[:, -1, :]
        # Project to the requested number of outputs: [batch_size, output_size]
        return self.output_layer(final_step)

In [7]:
# Create the model with output layer: one output per column of theta
n_params = train_theta.size(1)
model = MambaWithOutputLayer(
    seq_len=seq_len,
    d_model=d_model,
    state_size=state_size,
    output_size=n_params,
    device=device,
)
model = model.to(device)

# Define loss function (mean squared error) and optimizer (Adam)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop
def train():
    """Train the module-level ``model`` for ``EPOCHS`` epochs.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free pass over ``val_loader``. Both losses are averaged per
    sample (every batch loss is weighted by its batch size), printed and
    recorded.

    A checkpoint with the model and optimizer state is written every 10
    epochs and also after the final epoch, so the finished model is saved
    even when ``EPOCHS`` is not a multiple of 10.

    Relies on the module-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader``, ``val_loader``, ``device`` and ``EPOCHS``.

    Returns:
        tuple: ``(train_losses, val_losses)``, two lists holding the
        per-epoch average training and validation loss in epoch order.
    """
    train_losses = []
    val_losses = []

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0

        # Training
        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} - Training"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by the batch size so the epoch
            # average is a true per-sample mean
            running_loss += loss.item() * inputs.size(0)

        epoch_train_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_train_loss)

        # Validation
        model.eval()
        running_val_loss = 0.0

        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} - Validation"):
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, targets)

                running_val_loss += loss.item() * inputs.size(0)

        epoch_val_loss = running_val_loss / len(val_loader.dataset)
        val_losses.append(epoch_val_loss)

        print(f"Epoch {epoch+1}/{EPOCHS}, "
              f"Train Loss: {epoch_train_loss:.4f}, "
              f"Val Loss: {epoch_val_loss:.4f}")

        # Save model checkpoint every 10 epochs and always after the last
        # epoch, so the final weights are never lost.
        is_last_epoch = (epoch + 1) == EPOCHS
        if (epoch + 1) % 10 == 0 or is_last_epoch:
            torch.save({
                # NOTE: stored as the zero-based epoch index, whereas the file
                # name below uses the one-based epoch number.
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': epoch_train_loss,
                'val_loss': epoch_val_loss,
            }, f'mamba_checkpoint_epoch_{epoch+1}.pt')

    return train_losses, val_losses

# Plot training and validation loss
def plot_loss(train_losses, val_losses):
    """Draw both loss curves on one set of axes and save them to ``loss_plot.png``.

    Args:
        train_losses: Sequence of per-epoch training losses.
        val_losses: Sequence of per-epoch validation losses.

    The figure is closed after saving, so nothing is shown inline.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    curves = (
        (train_losses, 'Training Loss'),
        (val_losses, 'Validation Loss'),
    )
    for series, legend_name in curves:
        ax.plot(series, label=legend_name)
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.legend()
    fig.savefig('loss_plot.png')
    plt.close(fig)

# Evaluate the model on the validation set
def evaluate():
    """Compute the per-sample average loss of ``model`` on the validation set.

    Puts the module-level ``model`` in evaluation mode, runs it over
    ``val_loader`` with gradients disabled, and averages ``criterion`` over
    all validation samples. The result is printed and returned.

    Relies on the module-level ``model``, ``criterion``, ``val_loader`` and
    ``device``.

    Returns:
        float: Average validation loss per sample.
    """
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc="Evaluating"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Weight the batch-mean loss by the batch size so the final
            # figure is a true per-sample average
            val_loss += loss.item() * inputs.size(0)

    # Calculate average loss
    avg_val_loss = val_loss / len(val_loader.dataset)
    print(f"Validation Loss: {avg_val_loss:.4f}")

    return avg_val_loss

# Driver: train, save the loss curves, then report the final validation loss.
if __name__ == "__main__":
    print("Starting training...")
    # train() returns the per-epoch training and validation loss histories
    train_losses, val_losses = train()
    plot_loss(train_losses=train_losses, val_losses=val_losses)

    print("\nEvaluating model...")
    # Average validation loss of the model after the last epoch
    avg_val_loss = evaluate()

    print("Training and evaluation completed. Model saved.")

Using device: cuda:1
Input data shape: torch.Size([100000, 40, 1])
Target data shape: torch.Size([100000, 7])
Starting training...


Epoch 1/70 - Training: 100%|██████████| 900/900 [00:56<00:00, 15.90it/s]
Epoch 1/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.43it/s]


Epoch 1/70, Train Loss: 2.6844, Val Loss: 1.3116


Epoch 2/70 - Training: 100%|██████████| 900/900 [00:57<00:00, 15.73it/s]
Epoch 2/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.97it/s]


Epoch 2/70, Train Loss: 1.2597, Val Loss: 1.2469


Epoch 3/70 - Training: 100%|██████████| 900/900 [00:57<00:00, 15.74it/s]
Epoch 3/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.26it/s]


Epoch 3/70, Train Loss: 1.2100, Val Loss: 1.2612


Epoch 4/70 - Training: 100%|██████████| 900/900 [00:56<00:00, 15.88it/s]
Epoch 4/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.10it/s]


Epoch 4/70, Train Loss: 1.1799, Val Loss: 1.1664


Epoch 5/70 - Training: 100%|██████████| 900/900 [00:56<00:00, 15.88it/s]
Epoch 5/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.32it/s]


Epoch 5/70, Train Loss: 1.1554, Val Loss: 1.1690


Epoch 6/70 - Training: 100%|██████████| 900/900 [00:56<00:00, 15.79it/s]
Epoch 6/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.27it/s]


Epoch 6/70, Train Loss: 1.1415, Val Loss: 1.1683


Epoch 7/70 - Training: 100%|██████████| 900/900 [00:56<00:00, 16.05it/s]
Epoch 7/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 209.02it/s]


Epoch 7/70, Train Loss: 1.1273, Val Loss: 1.1109


Epoch 8/70 - Training: 100%|██████████| 900/900 [00:54<00:00, 16.39it/s]
Epoch 8/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 209.13it/s]


Epoch 8/70, Train Loss: 1.1142, Val Loss: 1.1325


Epoch 9/70 - Training: 100%|██████████| 900/900 [00:54<00:00, 16.38it/s]
Epoch 9/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 209.12it/s]


Epoch 9/70, Train Loss: 1.1063, Val Loss: 1.0953


Epoch 10/70 - Training: 100%|██████████| 900/900 [00:54<00:00, 16.40it/s]
Epoch 10/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.66it/s]


Epoch 10/70, Train Loss: 1.0943, Val Loss: 1.0969


Epoch 11/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.34it/s]
Epoch 11/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.72it/s]


Epoch 11/70, Train Loss: 1.0908, Val Loss: 1.0865


Epoch 12/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.32it/s]
Epoch 12/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 209.12it/s]


Epoch 12/70, Train Loss: 1.0814, Val Loss: 1.0810


Epoch 13/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.34it/s]
Epoch 13/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.44it/s]


Epoch 13/70, Train Loss: 1.0768, Val Loss: 1.0667


Epoch 14/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.32it/s]
Epoch 14/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.93it/s]


Epoch 14/70, Train Loss: 1.0690, Val Loss: 1.0648


Epoch 15/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.33it/s]
Epoch 15/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.97it/s]


Epoch 15/70, Train Loss: 1.0653, Val Loss: 1.0709


Epoch 16/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.32it/s]
Epoch 16/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.01it/s]


Epoch 16/70, Train Loss: 1.0608, Val Loss: 1.0584


Epoch 17/70 - Training: 100%|██████████| 900/900 [00:57<00:00, 15.77it/s]
Epoch 17/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.82it/s]


Epoch 17/70, Train Loss: 1.0548, Val Loss: 1.0527


Epoch 18/70 - Training: 100%|██████████| 900/900 [00:57<00:00, 15.78it/s]
Epoch 18/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.74it/s]


Epoch 18/70, Train Loss: 1.0521, Val Loss: 1.0698


Epoch 19/70 - Training: 100%|██████████| 900/900 [00:56<00:00, 15.81it/s]
Epoch 19/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.09it/s]


Epoch 19/70, Train Loss: 1.0445, Val Loss: 1.0485


Epoch 20/70 - Training: 100%|██████████| 900/900 [00:56<00:00, 15.83it/s]
Epoch 20/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.28it/s]


Epoch 20/70, Train Loss: 1.0437, Val Loss: 1.0310


Epoch 21/70 - Training: 100%|██████████| 900/900 [00:57<00:00, 15.78it/s]
Epoch 21/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.67it/s]


Epoch 21/70, Train Loss: 1.0368, Val Loss: 1.0395


Epoch 22/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.19it/s]
Epoch 22/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.26it/s]


Epoch 22/70, Train Loss: 1.0371, Val Loss: 1.0379


Epoch 23/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.17it/s]
Epoch 23/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.89it/s]


Epoch 23/70, Train Loss: 1.0335, Val Loss: 1.0264


Epoch 24/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.20it/s]
Epoch 24/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.03it/s]


Epoch 24/70, Train Loss: 1.0304, Val Loss: 1.0278


Epoch 25/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.18it/s]
Epoch 25/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.34it/s]


Epoch 25/70, Train Loss: 1.0260, Val Loss: 1.0257


Epoch 26/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.20it/s]
Epoch 26/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.44it/s]


Epoch 26/70, Train Loss: 1.0259, Val Loss: 1.0174


Epoch 27/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.14it/s]
Epoch 27/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.52it/s]


Epoch 27/70, Train Loss: 1.0212, Val Loss: 1.0286


Epoch 28/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.21it/s]
Epoch 28/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.57it/s]


Epoch 28/70, Train Loss: 1.0196, Val Loss: 1.0260


Epoch 29/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.16it/s]
Epoch 29/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.52it/s]


Epoch 29/70, Train Loss: 1.0173, Val Loss: 1.0311


Epoch 30/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.21it/s]
Epoch 30/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.67it/s]


Epoch 30/70, Train Loss: 1.0146, Val Loss: 1.0233


Epoch 31/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.18it/s]
Epoch 31/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.27it/s]


Epoch 31/70, Train Loss: 1.0118, Val Loss: 1.0284


Epoch 32/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.16it/s]
Epoch 32/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.20it/s]


Epoch 32/70, Train Loss: 1.0117, Val Loss: 1.0285


Epoch 33/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.13it/s]
Epoch 33/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.66it/s]


Epoch 33/70, Train Loss: 1.0087, Val Loss: 1.0185


Epoch 34/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.15it/s]
Epoch 34/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.40it/s]


Epoch 34/70, Train Loss: 1.0068, Val Loss: 1.0114


Epoch 35/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.21it/s]
Epoch 35/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 209.25it/s]


Epoch 35/70, Train Loss: 1.0062, Val Loss: 1.0106


Epoch 36/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.25it/s]
Epoch 36/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.29it/s]


Epoch 36/70, Train Loss: 1.0021, Val Loss: 1.0064


Epoch 37/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.28it/s]
Epoch 37/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.39it/s]


Epoch 37/70, Train Loss: 1.0020, Val Loss: 1.0228


Epoch 38/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.26it/s]
Epoch 38/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.78it/s]


Epoch 38/70, Train Loss: 1.0005, Val Loss: 1.0064


Epoch 39/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.22it/s]
Epoch 39/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.71it/s]


Epoch 39/70, Train Loss: 0.9968, Val Loss: 1.0067


Epoch 40/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.23it/s]
Epoch 40/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.82it/s]


Epoch 40/70, Train Loss: 0.9962, Val Loss: 0.9980


Epoch 41/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.23it/s]
Epoch 41/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.38it/s]


Epoch 41/70, Train Loss: 0.9946, Val Loss: 1.0070


Epoch 42/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.17it/s]
Epoch 42/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.50it/s]


Epoch 42/70, Train Loss: 0.9951, Val Loss: 1.0185


Epoch 43/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.21it/s]
Epoch 43/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.21it/s]


Epoch 43/70, Train Loss: 0.9939, Val Loss: 1.0122


Epoch 44/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.20it/s]
Epoch 44/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.76it/s]


Epoch 44/70, Train Loss: 0.9909, Val Loss: 1.0058


Epoch 45/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.19it/s]
Epoch 45/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.89it/s]


Epoch 45/70, Train Loss: 0.9886, Val Loss: 1.0013


Epoch 46/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.17it/s]
Epoch 46/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.85it/s]


Epoch 46/70, Train Loss: 0.9896, Val Loss: 0.9902


Epoch 47/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.21it/s]
Epoch 47/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.86it/s]


Epoch 47/70, Train Loss: 0.9867, Val Loss: 0.9912


Epoch 48/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.23it/s]
Epoch 48/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 208.08it/s]


Epoch 48/70, Train Loss: 0.9849, Val Loss: 0.9905


Epoch 49/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.19it/s]
Epoch 49/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.18it/s]


Epoch 49/70, Train Loss: 0.9851, Val Loss: 1.0025


Epoch 50/70 - Training: 100%|██████████| 900/900 [00:56<00:00, 15.98it/s]
Epoch 50/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.16it/s]


Epoch 50/70, Train Loss: 0.9837, Val Loss: 1.0102


Epoch 51/70 - Training: 100%|██████████| 900/900 [00:57<00:00, 15.78it/s]
Epoch 51/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.20it/s]


Epoch 51/70, Train Loss: 0.9828, Val Loss: 1.0042


Epoch 52/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.20it/s]
Epoch 52/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.42it/s]


Epoch 52/70, Train Loss: 0.9823, Val Loss: 0.9960


Epoch 53/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.15it/s]
Epoch 53/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.40it/s]


Epoch 53/70, Train Loss: 0.9810, Val Loss: 1.0040


Epoch 54/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.16it/s]
Epoch 54/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.64it/s]


Epoch 54/70, Train Loss: 0.9806, Val Loss: 0.9924


Epoch 55/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.19it/s]
Epoch 55/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.97it/s]


Epoch 55/70, Train Loss: 0.9783, Val Loss: 0.9921


Epoch 56/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.21it/s]
Epoch 56/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.95it/s]


Epoch 56/70, Train Loss: 0.9774, Val Loss: 1.0023


Epoch 57/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.21it/s]
Epoch 57/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.09it/s]


Epoch 57/70, Train Loss: 0.9770, Val Loss: 0.9913


Epoch 58/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.18it/s]
Epoch 58/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.95it/s]


Epoch 58/70, Train Loss: 0.9768, Val Loss: 0.9942


Epoch 59/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.20it/s]
Epoch 59/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.59it/s]


Epoch 59/70, Train Loss: 0.9755, Val Loss: 0.9858


Epoch 60/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.19it/s]
Epoch 60/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.44it/s]


Epoch 60/70, Train Loss: 0.9748, Val Loss: 1.0006


Epoch 61/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.16it/s]
Epoch 61/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.34it/s]


Epoch 61/70, Train Loss: 0.9722, Val Loss: 1.0168


Epoch 62/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.23it/s]
Epoch 62/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.26it/s]


Epoch 62/70, Train Loss: 0.9734, Val Loss: 0.9909


Epoch 63/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.21it/s]
Epoch 63/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.44it/s]


Epoch 63/70, Train Loss: 0.9715, Val Loss: 0.9778


Epoch 64/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.21it/s]
Epoch 64/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.49it/s]


Epoch 64/70, Train Loss: 0.9701, Val Loss: 0.9845


Epoch 65/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.22it/s]
Epoch 65/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.76it/s]


Epoch 65/70, Train Loss: 0.9704, Val Loss: 0.9945


Epoch 66/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.20it/s]
Epoch 66/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.68it/s]


Epoch 66/70, Train Loss: 0.9682, Val Loss: 0.9995


Epoch 67/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.24it/s]
Epoch 67/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.30it/s]


Epoch 67/70, Train Loss: 0.9682, Val Loss: 0.9913


Epoch 68/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.24it/s]
Epoch 68/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.77it/s]


Epoch 68/70, Train Loss: 0.9675, Val Loss: 0.9838


Epoch 69/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.22it/s]
Epoch 69/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.79it/s]


Epoch 69/70, Train Loss: 0.9686, Val Loss: 0.9855


Epoch 70/70 - Training: 100%|██████████| 900/900 [00:55<00:00, 16.24it/s]
Epoch 70/70 - Validation: 100%|██████████| 100/100 [00:00<00:00, 207.51it/s]


Epoch 70/70, Train Loss: 0.9655, Val Loss: 0.9941

Evaluating model...


Evaluating: 100%|██████████| 100/100 [00:00<00:00, 206.73it/s]

Validation Loss: 0.9941
Training and evaluation completed. Model saved.





In [8]:
# Test the trained model on a test set from a specific simulator
def test_mamba(sim):
    """Run the trained model over one simulator's test set and save the predictions.

    Loads ``test_sims/test_x_{sim}.pt`` and ``test_sims/test_theta_{sim}.pt``,
    feeds the inputs through the module-level ``model`` in batches of
    ``BATCH_SIZE`` (in file order, without shuffling), concatenates the
    outputs and writes them to
    ``test_sims/predictions_{sim}_mamba_combined.pt``.

    The model is put in evaluation mode for the forward passes, so the
    result does not depend on which cell ran last; its previous mode is
    restored afterwards.

    Args:
        sim: Simulator name used to build the input and output file names.

    Returns:
        tuple: ``(predictions, test_theta)``. ``predictions`` holds the
        concatenated model outputs and ``test_theta`` the targets as loaded
        from disk.
    """
    # Load files and initialize datasets
    test_x = torch.load(f'test_sims/test_x_{sim}.pt')
    test_x = test_x.reshape(test_x.shape[0], test_x.shape[1], 1)
    test_theta = torch.load(f'test_sims/test_theta_{sim}.pt')
    test_dataset = TensorDataset(test_x, test_theta)
    # No shuffling, so predictions stay aligned row-by-row with test_theta
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Begin evaluating
    was_training = model.training
    model.eval()
    predictions = []
    try:
        with torch.no_grad():
            # The targets are not needed for prediction, so they are not
            # moved to the device
            for inputs, _ in tqdm(test_loader, desc="Evaluating"):
                outputs = model(inputs.to(device))
                predictions.append(outputs)
    finally:
        # Leave the model in whatever mode the caller had it in
        model.train(was_training)

    # Get all predictions
    # NOTE(review): predictions stay on `device` while test_theta stays where
    # torch.load placed it; callers comparing the two may need to move one.
    predictions = torch.cat(predictions)
    torch.save(predictions, f'test_sims/predictions_{sim}_mamba_combined.pt')

    return predictions, test_theta

# Evaluate every listed simulator (currently only 'combined'); `y` keeps the
# (predictions, test_theta) pair of the last simulator processed.
for sim in ('combined',):
    y = test_mamba(sim)

Evaluating: 100%|██████████| 10/10 [00:00<00:00, 237.89it/s]


In [9]:
# Count the trainable parameters in the Mamba model
# Note: this notebook shows the 2-mutation case, so the count differs from the numbers in Table 1
def count_parameters(model):
    """Return the total number of trainable elements across ``model``'s parameters.

    Parameters with ``requires_grad`` set to False are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Display the number of trainable parameters of `model` as the cell output
count_parameters(model)

44687