In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import KFold
import numpy as np
import time
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Define a simplified Mamba model class
class SimpleMambaSSM(nn.Module):
    """Toy selective (gated) linear state-space model.

    For every time step ``t`` the model computes::

        gate_t = sigmoid(selective_gate(x_t))
        h_t    = gate_t * (h_{t-1} @ state_matrix + x_t @ input_matrix.T)
        y_t    = h_t @ output_matrix

    starting from a zero hidden state. This is a simplified illustration of a
    gated recurrence, not the reference Mamba architecture.

    Args:
        input_dim: number of features per time step in the input.
        hidden_dim: size of the recurrent hidden state.
        output_dim: number of features per time step in the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # NOTE(review): input_proj is never used in forward(). It is kept so the
        # parameter count and state_dict keys stay unchanged; it receives no
        # gradient during training.
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        # Random matrices are scaled by 1/sqrt(fan_in). With unit-variance
        # entries the recurrent matrix has spectral radius ~sqrt(hidden_dim),
        # so the hidden state grows geometrically with sequence length and
        # overflows to inf/nan on long inputs.
        self.state_matrix = nn.Parameter(
            torch.randn(hidden_dim, hidden_dim) / hidden_dim ** 0.5
        )
        self.input_matrix = nn.Parameter(
            torch.randn(hidden_dim, input_dim) / input_dim ** 0.5
        )
        self.output_matrix = nn.Parameter(
            torch.randn(hidden_dim, output_dim) / hidden_dim ** 0.5
        )
        self.selective_gate = nn.Linear(input_dim, hidden_dim)

    def forward(self, x):
        """Run the gated recurrence over every time step of ``x``.

        Args:
            x: tensor of shape (batch_size, sequence_length, input_dim).

        Returns:
            Tensor of shape (batch_size, sequence_length, output_dim).
        """
        batch_size, sequence_length, _ = x.shape
        if sequence_length == 0:
            # torch.stack rejects an empty list; return an empty sequence instead.
            return x.new_zeros(batch_size, 0, self.output_matrix.shape[1])

        # new_zeros matches both the device and the dtype of x.
        hidden_state = x.new_zeros(batch_size, self.hidden_dim)
        outputs = []
        for t in range(sequence_length):
            current_input = x[:, t, :]
            # Per-feature gate in (0, 1), computed from the current input only.
            gate = torch.sigmoid(self.selective_gate(current_input))
            hidden_state = gate * (
                hidden_state @ self.state_matrix
                + current_input @ self.input_matrix.T
            )
            outputs.append(hidden_state @ self.output_matrix)
        return torch.stack(outputs, dim=1)

class TransformerModel(nn.Module):
    """Transformer-encoder sequence model producing one output vector per step.

    A learned positional-encoding table (initialized from ``torch.randn``) is
    added to the input, the result goes through a stack of
    ``TransformerEncoderLayer``s, and a linear head maps each step to
    ``output_dim`` features.

    Args:
        input_dim: features per time step; also the encoder's ``d_model``, so it
            must be divisible by ``nhead``.
        hidden_dim: feed-forward width inside each encoder layer.
        output_dim: features per time step in the output.
        nhead: number of attention heads.
        num_layers: number of stacked encoder layers.
        max_len: longest sequence the positional-encoding table supports
            (default 100, the previously hard-coded value).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, nhead, num_layers, max_len=100):
        super().__init__()
        self.max_len = max_len
        self.positional_encoding = nn.Parameter(torch.randn(1, max_len, input_dim))
        # batch_first=True lets the encoder consume (batch, seq, feature) directly,
        # so no permutes are needed in forward().
        encoder_layers = TransformerEncoderLayer(
            d_model=input_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim,
            batch_first=True,
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
        self.fc_out = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Encode ``x`` and project every time step.

        Args:
            x: tensor of shape (batch_size, sequence_length, input_dim).

        Returns:
            Tensor of shape (batch_size, sequence_length, output_dim).

        Raises:
            ValueError: if ``sequence_length`` exceeds ``max_len``.
        """
        sequence_length = x.size(1)
        if sequence_length > self.max_len:
            raise ValueError(
                f"sequence_length {sequence_length} exceeds max_len {self.max_len} "
                "of the positional encoding"
            )
        x = x + self.positional_encoding[:, :sequence_length, :]
        x = self.transformer_encoder(x)
        return self.fc_out(x)

# Define a simple function to compare two models
def _print_model_summary(model, title):
    """Print ``model`` under a banner and return its (total, trainable) parameter counts."""
    print(f"\n========================= {title} =========================")
    print(model)
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total Parameters: {total}")
    print(f"Trainable Parameters: {trainable}")
    return total, trainable


def compare_models(model1, model2, name1="Mamba", name2="Transformer"):
    """Print both models' structure and parameter counts, then the count differences.

    Differences are reported as ``model2 - model1``.

    Args:
        model1: first ``nn.Module`` to summarize.
        model2: second ``nn.Module`` to summarize.
        name1: label shown in the first model's banner (default "Mamba").
        name2: label shown in the second model's banner (default "Transformer").
    """
    total1, trainable1 = _print_model_summary(model1, f"Model 1 Summary ({name1})")
    total2, trainable2 = _print_model_summary(model2, f"Model 2 Summary ({name2})")

    print("\n========================= Model Comparison =========================")
    print(f"Difference in Total Parameters: {total2 - total1}")
    print(f"Difference in Trainable Parameters: {trainable2 - trainable1}")

# Dummy models for comparison
# Shared hyperparameters for the two demo models.
input_dim, hidden_dim, output_dim = 10, 20, 1
num_heads, num_layers = 2, 2

# One model of each architecture with matching input/output dimensions.
mamba_model = SimpleMambaSSM(input_dim, hidden_dim, output_dim)
transformer_model = TransformerModel(
    input_dim,
    hidden_dim,
    output_dim,
    nhead=num_heads,
    num_layers=num_layers,
)

# Print both architectures and their parameter counts side by side.
compare_models(mamba_model, transformer_model)

# Dummy dataset for demonstration purposes
class DummyDataset(torch.utils.data.Dataset):
    """Random regression dataset of (sequence, per-step target) pairs.

    Inputs and targets are both drawn from ``torch.randn`` once at construction.

    Args:
        num_samples: number of examples.
        sequence_length: time steps per example.
        input_dim: features per time step in each input.
        output_dim: features per time step in each target (default 1, the
            previously hard-coded width).
    """

    def __init__(self, num_samples, sequence_length, input_dim, output_dim=1):
        self.data = torch.randn(num_samples, sequence_length, input_dim)
        self.targets = torch.randn(num_samples, sequence_length, output_dim)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Returns (input of shape (sequence_length, input_dim),
        #          target of shape (sequence_length, output_dim)).
        return self.data[idx], self.targets[idx]

# Define a training function with time measurement
def timed_train(model, dataloader, criterion, optimizer, device, num_epochs=5):
    """Train ``model`` for ``num_epochs`` and return the elapsed time in seconds.

    Each epoch prints the mean per-batch loss; an epoch with no batches
    reports ``nan`` instead of dividing by zero.

    Args:
        model: module to train; switched to train mode here.
        dataloader: iterable of ``(inputs, targets)`` batches. It does not need
            ``__len__``; batches are counted while iterating.
        criterion: loss called on the flattened outputs and flattened targets.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        num_epochs: number of passes over ``dataloader``.

    Returns:
        Elapsed time in seconds, measured with ``time.perf_counter``.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # reshape (not view) also accepts non-contiguous model outputs.
            loss = criterion(outputs.reshape(-1), targets.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}/{num_epochs} | Training Loss: {mean_loss:.4f}")
    if torch.device(device).type == "cuda":
        # CUDA kernels run asynchronously; wait for the final optimizer step
        # before stopping the clock.
        torch.cuda.synchronize(device)
    elapsed_time = time.perf_counter() - start_time
    print(f"Total Training Time: {elapsed_time:.2f} seconds\n")
    return elapsed_time

# Example usage
# Dataset configuration. The model dimensions (input_dim, hidden_dim,
# output_dim, num_heads, num_layers) are reused from where the models were
# built earlier in this cell rather than redeclared, so the dataset's feature
# size cannot drift out of sync with the models' input_dim.
num_samples = 100
sequence_length = 5  # must not exceed TransformerModel's positional-encoding length

dataset = DummyDataset(num_samples, sequence_length, input_dim)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

# ---- Train both models on the same dataloader and time each run ----
# NOTE(review): the first run may absorb one-time CUDA initialization cost;
# consider a warm-up pass before timing for a fairer comparison.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mamba-style model
mamba_model = mamba_model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(mamba_model.parameters(), lr=0.001)
print("\n========================= Training Mamba Model =========================")
mamba_training_time = timed_train(
    mamba_model, dataloader, criterion, optimizer, device
)

# Transformer model
transformer_model = transformer_model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(transformer_model.parameters(), lr=0.001)
print("\n========================= Training Transformer Model =========================")
transformer_training_time = timed_train(
    transformer_model, dataloader, criterion, optimizer, device
)

# ---- Report the measured times (difference is transformer minus mamba) ----
print("\n========================= Training Time Comparison =========================")
for label, seconds in (
    ("Mamba Model Training Time", mamba_training_time),
    ("Transformer Model Training Time", transformer_training_time),
    ("Difference in Training Time", transformer_training_time - mamba_training_time),
):
    print(f"{label}: {seconds:.2f} seconds")



SimpleMambaSSM(
  (input_proj): Linear(in_features=10, out_features=20, bias=True)
  (selective_gate): Linear(in_features=10, out_features=20, bias=True)
)
Total Parameters: 1060
Trainable Parameters: 1060

TransformerModel(
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=10, out_features=10, bias=True)
        )
        (linear1): Linear(in_features=10, out_features=20, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=20, out_features=10, bias=True)
        (norm1): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
    