In [1]:
import torch
from mamba_ssm import Mamba

# Smoke test for the Mamba block: feed a random (batch, length, dim) tensor
# through a single block on the GPU and check that the shape is preserved.
batch = 2
length = 64
dim = 16

x = torch.randn(batch, length, dim)
x = x.to("cuda")

# The block uses roughly 3 * expand * d_model^2 parameters.
#   d_model -- model dimension
#   d_state -- SSM state expansion factor
#   d_conv  -- local convolution width
#   expand  -- block expansion factor
model = Mamba(d_model=dim, d_state=16, d_conv=4, expand=2)
model = model.to("cuda")

y = model(x)
assert x.shape == y.shape

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Cell 1: Imports and Dummy Dataset

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import SwinModel, SwinConfig
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import os

# Dummy dataset loader (replace with ML-JET loader)
class DummyDataset(torch.utils.data.Dataset):
    """Randomly generated placeholder dataset with three label sets.

    Every item is an ``(image, targets)`` pair:

    * ``image`` is a float tensor of shape ``(1, 32, 32)`` drawn from
      ``torch.rand``.
    * ``targets`` is a dict with the keys
      ``'energy_loss_output'`` (float tensor of shape ``(1,)`` holding 0.0 or 1.0),
      ``'alpha_output'`` (integer class index in ``[0, 3)``) and
      ``'q0_output'`` (integer class index in ``[0, 4)``).

    Args:
        size: Number of samples to generate.
    """

    def __init__(self, size=1000):
        n_items = size
        self.data = torch.rand(n_items, 1, 32, 32)
        # Binary label is stored as float with a trailing singleton dimension.
        binary_flags = torch.randint(0, 2, (n_items, 1))
        self.labels_energy = binary_flags.float()
        self.labels_alpha = torch.randint(0, 3, (n_items,))
        self.labels_q0 = torch.randint(0, 4, (n_items,))

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        image = self.data[idx]
        targets = {
            'energy_loss_output': self.labels_energy[idx],
            'alpha_output': self.labels_alpha[idx],
            'q0_output': self.labels_q0[idx],
        }
        return image, targets


In [38]:
class MambaVisionMultiHead(nn.Module):
    """Convolutional stem, one Mamba block, and three linear prediction heads.

    For an input of shape ``(B, in_chans, img_size, img_size)``:

    1. ``proj`` applies a 3x3 convolution to ``embed_dim`` channels, flattens
       the spatial grid, and linearly maps the ``img_size * img_size``
       positions down to ``img_size`` tokens, giving ``(B, embed_dim, img_size)``.
    2. The tensor is rearranged to batch-first ``(B, img_size, embed_dim)`` and
       passed through a single Mamba block.
    3. The features of the last token feed the three heads.

    Args:
        in_chans: Number of input image channels.
        img_size: Height/width of the (square) input; also the token count.
        embed_dim: Channel width of the stem and ``d_model`` of the Mamba block.
        mamba_layers: Forwarded to ``Mamba`` as ``d_conv``. Despite the name,
            this does not stack several layers; exactly one block is built.
        mamba_hidden: Forwarded to ``Mamba`` as ``d_state``.
    """

    def __init__(self, in_chans=1, img_size=32, embed_dim=128, mamba_layers=4, mamba_hidden=256):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, kernel_size=3, padding=1),
            nn.Flatten(2),
            nn.Linear(img_size * img_size, img_size),
        )
        # `norm` and `pool` are not used by forward(). They are kept so the
        # module's attributes and state_dict keys stay exactly as before.
        self.norm = nn.LayerNorm(embed_dim)
        self.mamba = Mamba(d_model=embed_dim, d_state=mamba_hidden, d_conv=mamba_layers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.head_energy = nn.Linear(embed_dim, 1)
        self.head_alpha = nn.Linear(embed_dim, 3)
        self.head_q0 = nn.Linear(embed_dim, 4)

    def forward(self, x):
        """Return a dict of raw head outputs for a batch of images.

        Args:
            x: Tensor of shape ``(B, in_chans, img_size, img_size)``.

        Returns:
            Dict with ``'energy_loss_output'`` of shape ``(B, 1)``,
            ``'alpha_output'`` of shape ``(B, 3)`` and ``'q0_output'`` of
            shape ``(B, 4)``.
        """
        z = self.proj(x)                        # (B, embed_dim, seq_len)
        # Mamba is batch-first: it takes (batch, length, dim), as in the usage
        # example at the top of this notebook. Feeding (seq_len, B, embed_dim)
        # would make it scan across the samples of the mini-batch instead of
        # across the tokens of each sample.
        z = z.permute(0, 2, 1).contiguous()     # (B, seq_len, embed_dim)
        out_seq = self.mamba(z)                 # (B, seq_len, embed_dim)
        feat = out_seq[:, -1]                   # last token per sample: (B, embed_dim)
        return {
            'energy_loss_output': self.head_energy(feat),
            'alpha_output':  self.head_alpha(feat),
            'q0_output':     self.head_q0(feat)
        }


In [39]:
import torch.optim as optim            # ← add this line

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the multi-head model and move it to the selected device.
model = MambaVisionMultiHead(
    in_chans=1,
    img_size=32,
    embed_dim=128,
    mamba_layers=4,
    mamba_hidden=256,
)
model = model.to(device)

# One criterion per head: binary cross-entropy on logits for the energy head,
# categorical cross-entropy for the alpha and q0 heads.
crit_e = nn.BCEWithLogitsLoss()
crit_a = nn.CrossEntropyLoss()
crit_q = nn.CrossEntropyLoss()

def composite_loss(preds, labels):
    """Add up the three per-head losses using the notebook-level criteria.

    Args:
        preds: Dict with the keys ``'energy_loss_output'``, ``'alpha_output'``
            and ``'q0_output'`` holding the model outputs.
        labels: Dict with the same keys holding the targets. The energy target
            is cast to float before being passed to ``crit_e``.

    Returns:
        ``(total, (energy, alpha, q0))`` where ``total`` is the summed loss
        tensor and the inner tuple holds each term as a Python float.
    """
    energy_target = labels['energy_loss_output'].float()
    energy_term = crit_e(preds['energy_loss_output'], energy_target)
    alpha_term = crit_a(preds['alpha_output'], labels['alpha_output'])
    q0_term = crit_q(preds['q0_output'], labels['q0_output'])

    total = energy_term + alpha_term + q0_term
    breakdown = (energy_term.item(), alpha_term.item(), q0_term.item())
    return total, breakdown

# Adam over all model parameters with a starting learning rate of 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# StepLR: multiply the learning rate by 0.1 after every 15 scheduler steps
# (the training loop steps the scheduler once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=15, gamma=0.1)


In [40]:
from tqdm.auto import tqdm

def train_one_epoch(loader):
    """Run one optimisation pass over ``loader``.

    Relies on the notebook-level ``model``, ``optimizer``, ``device`` and
    ``composite_loss``.

    Args:
        loader: Iterable of ``(inputs, label_dict)`` batches that also
            supports ``len()``.

    Returns:
        Mean of the per-batch total losses (sum divided by ``len(loader)``).
    """
    model.train()
    epoch_loss = 0.0
    for inputs, targets in tqdm(loader, desc='Train'):
        inputs = inputs.to(device)
        targets = {name: tensor.to(device) for name, tensor in targets.items()}

        optimizer.zero_grad()
        batch_total, _breakdown = composite_loss(model(inputs), targets)
        batch_total.backward()
        optimizer.step()

        epoch_loss += batch_total.item()
    return epoch_loss / len(loader)

@torch.no_grad()
def validate(loader):
    """Evaluate the notebook-level ``model`` on ``loader`` without gradients.

    Relies on the notebook-level ``model``, ``device`` and ``composite_loss``.

    Args:
        loader: Iterable of ``(inputs, label_dict)`` batches that also
            supports ``len()``.

    Returns:
        Dict with ``'loss'`` (mean per-batch total loss) and the per-head
        accuracies ``'acc_e'``, ``'acc_a'`` and ``'acc_q'`` (correct
        predictions divided by the number of samples seen).
    """
    model.eval()
    loss_sum = 0.0
    hits_e = 0
    hits_a = 0
    hits_q = 0
    n_seen = 0

    for inputs, targets in tqdm(loader, desc='Val'):
        inputs = inputs.to(device)
        targets = {name: tensor.to(device) for name, tensor in targets.items()}
        outputs = model(inputs)
        batch_total, _breakdown = composite_loss(outputs, targets)
        loss_sum += batch_total.item()

        # Energy head: threshold the sigmoid at 0.5. Other heads: arg-max class.
        energy_guess = (torch.sigmoid(outputs['energy_loss_output']) > 0.5).long()
        alpha_guess = outputs['alpha_output'].argmax(dim=1)
        q0_guess = outputs['q0_output'].argmax(dim=1)

        hits_e += (energy_guess == targets['energy_loss_output']).sum().item()
        hits_a += (alpha_guess == targets['alpha_output']).sum().item()
        hits_q += (q0_guess == targets['q0_output']).sum().item()
        n_seen += inputs.size(0)

    return {
        'loss':   loss_sum / len(loader),
        'acc_e':  hits_e / n_seen,
        'acc_a':  hits_a / n_seen,
        'acc_q':  hits_q / n_seen
    }


In [30]:
# Each split gets its own independently generated DummyDataset.
# Only the training loader shuffles. The evaluation loaders keep a fixed
# order so the per-batch averaged validation loss is reproducible and no
# random numbers are drawn just to reorder evaluation data.
train_loader = DataLoader(DummyDataset(), batch_size=32, shuffle=True)
val_loader = DataLoader(DummyDataset(), batch_size=32, shuffle=False)
test_loader = DataLoader(DummyDataset(), batch_size=32, shuffle=False)

In [41]:

# Train for `num_epochs`, validating after every epoch and checkpointing the
# weights whenever the validation loss reaches a new minimum.
num_epochs = 1
best_val = float('inf')

for epoch in range(1, num_epochs + 1):
    train_loss = train_one_epoch(train_loader)
    val_metrics = validate(val_loader)
    scheduler.step()

    # The learning rate is read after scheduler.step(), so the printed value
    # is the one the scheduler reports following this epoch's step.
    print("".join((
        f"Epoch {epoch:02d} | ",
        f"Train Loss: {train_loss:.4f} | ",
        f"Val Loss: {val_metrics['loss']:.4f} | ",
        f"E Acc: {val_metrics['acc_e']:.2%} | ",
        f"A Acc: {val_metrics['acc_a']:.2%} | ",
        f"Q Acc: {val_metrics['acc_q']:.2%} | ",
        f"LR: {scheduler.get_last_lr()[0]:.2e}",
    )))

    # New best validation loss: remember it and overwrite the checkpoint.
    if best_val > val_metrics['loss']:
        best_val = val_metrics['loss']
        torch.save(model.state_dict(), 'best_mamba_vision.pth')


Train: 100%|██████████| 32/32 [00:02<00:00, 14.90it/s]
Val: 100%|██████████| 32/32 [00:00<00:00, 63.55it/s]


Epoch 01 | Train Loss: 3.1787 | Val Loss: 3.1780 | E Acc: 48.60% | A Acc: 33.00% | Q Acc: 25.50% | LR: 1.00e-04


In [13]:
# Cell 3: Loss Computation

def compute_loss(outputs, targets):
    """Return the sum of the energy, alpha and q0 losses as one tensor.

    Args:
        outputs: Dict with the keys ``'energy_loss_output'``, ``'alpha_output'``
            and ``'q0_output'`` holding the model outputs.
        targets: Dict with the same keys holding the targets.

    Returns:
        Scalar tensor: binary cross-entropy on logits for the energy head plus
        categorical cross-entropy for the alpha and q0 heads.
    """
    bce = nn.BCEWithLogitsLoss()
    ce = nn.CrossEntropyLoss()
    # Cast the binary target to float, as composite_loss does, so integer
    # 0/1 labels are accepted by BCEWithLogitsLoss.
    energy_target = targets['energy_loss_output'].float()
    loss_energy = bce(outputs['energy_loss_output'], energy_target)
    loss_alpha = ce(outputs['alpha_output'], targets['alpha_output'])
    loss_q0 = ce(outputs['q0_output'], targets['q0_output'])
    return loss_energy + loss_alpha + loss_q0



In [14]:
# Cell 4: Training Loop Function

def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` with explicit dependencies.

    Note: an earlier cell defines a one-argument function with the same name;
    whichever definition is executed last is the one in effect.

    Args:
        model: Module whose output is accepted by ``compute_loss``.
        loader: Iterable of ``(inputs, target_dict)`` batches that also
            supports ``len()``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    accumulated = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = {name: tensor.to(device) for name, tensor in targets.items()}

        optimizer.zero_grad()
        batch_loss = compute_loss(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        accumulated += batch_loss.item()
    return accumulated / len(loader)


In [17]:
# Cell 5: Main Training Script


# Stand-alone training run. This rebinds the notebook-level `device`, `model`,
# `optimizer` and `train_loader` names.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = MambaVisionMultiHead()
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
train_loader = DataLoader(dataset=DummyDataset(), batch_size=32, shuffle=True)

# Five passes over the dummy data, reporting the mean loss after each one.
for epoch in range(5):
    loss = train_one_epoch(model, train_loader, optimizer, device)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")


RuntimeError: Given normalized_shape=[128], expected input with shape [*, 128], but got input of size[32, 128, 32]