## Mamba Tutorial
### Description
A simple MNIST tutorial for Mamba Model.
We are going to use official Mamba model from https://github.com/state-spaces/mamba.




In [2]:
! pip install -q mamba-ssm lightning datasets

In [1]:
import torch
from mamba_ssm import Mamba
import lightning as L
from datasets import load_dataset
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision import transforms

In [None]:
# Define Config for MNIST
TRAIN_DATA="MNIST"
IMG_SIZE = 28
IMG_CHANNEL = 1
PATCH_SIZE = 4
BATCH_SIZE = 32
INIT_LR = 1e-4
D_MODEL = 128 # Dimension of hidden state in SSM
OUTPUT_CLASS_NUM=10
DROPOUT=0.1
N_LAYERS=3

## Dataset Configuration
ds = load_dataset("ylecun/mnist")
train_dataset = ds['train']
validation_dataset = ds['test']

print(len(train_dataset))
print(len(validation_dataset))

In [2]:
# Define Config for CIFAR-10
TRAIN_DATA="CIFAR-10"
IMG_SIZE = 32
IMG_CHANNEL = 3
PATCH_SIZE = 4
BATCH_SIZE = 64
INIT_LR = 3e-4
D_MODEL = 256 # Dimension of hidden state in SSM
OUTPUT_CLASS_NUM=10
DROPOUT=0.15
N_LAYERS=10

## Dataset Configuration
ds = load_dataset("uoft-cs/cifar10")
train_dataset = ds['train']
validation_dataset = ds['test']

print(len(train_dataset))
print(len(validation_dataset))

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


50000
10000


In [3]:
# Define a tranform function for image
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomEqualize(),
    transforms.RandomAutocontrast(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    transforms.RandomErasing()
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [4]:
## Define a collate_fn for dataset
def train_collate_fn(batch):
  if TRAIN_DATA=="MNIST":
    images = [train_transform(item['image']) for item in batch]
  elif TRAIN_DATA=="CIFAR-10":
    images = [train_transform(item['img']) for item in batch]
  labels = [item["label"] for item in batch]

  # Divide the image into (patch_size * patch_size) size patches
  patched_images = []
  for image in images:
    c, h, w = image.shape

    # check the h,w matches the IMG_SIZE
    assert c == IMG_CHANNEL
    assert h == IMG_SIZE
    assert w == IMG_SIZE

    patch_num = (h // PATCH_SIZE) * (w // PATCH_SIZE)
    image = image.view((c, patch_num, PATCH_SIZE, PATCH_SIZE))
    image = image.view((c * patch_num, PATCH_SIZE * PATCH_SIZE))
    patched_images.append(image)

  patched_images = torch.stack(patched_images)
  labels = torch.LongTensor(labels)

  return patched_images, labels

def validation_collate_fn(batch):
  if TRAIN_DATA=="MNIST":
    images = [test_transform(item['image']) for item in batch]
  elif TRAIN_DATA=="CIFAR-10":
    images = [test_transform(item['img']) for item in batch]
  labels = [item["label"] for item in batch]

  # Divide the image into (PATCH_SIZE * PATCH_SIZE) size patches
  patched_images = []
  for image in images:
    c, h, w = image.shape

    # check the h,w matches the IMG_SIZE
    assert c == IMG_CHANNEL
    assert h == IMG_SIZE
    assert w == IMG_SIZE

    patch_num = (h // PATCH_SIZE) * (w // PATCH_SIZE)
    image = image.view((c, patch_num, PATCH_SIZE, PATCH_SIZE))
    image = image.view((c * patch_num, PATCH_SIZE * PATCH_SIZE))
    patched_images.append(image)

  patched_images = torch.stack(patched_images)
  labels = torch.LongTensor(labels)

  return patched_images, labels

In [5]:
# Define Mamba Model for MNIST classification
class MambaBlock(nn.Module):
  def __init__(self):
    super().__init__()
    self.mamba = Mamba(D_MODEL)
    self.norm = nn.LayerNorm(D_MODEL)
    self.dropout = nn.Dropout(DROPOUT)

  def forward(self, x):
    x = self.norm(self.mamba(x) + x)
    x = self.dropout(x)
    return x

class Mamba_MNIST_Model(nn.Module):
  def __init__(self):
    super().__init__()
    self.linear1 = nn.Linear(PATCH_SIZE * PATCH_SIZE, D_MODEL)
    self.mambas = nn.Sequential(*[MambaBlock() for _ in range(N_LAYERS)])
    self.flatten = nn.Flatten()
    self.dropout = nn.Dropout(DROPOUT)

    self.linear2 = nn.Linear(D_MODEL * ((IMG_CHANNEL * IMG_SIZE * IMG_SIZE) // (PATCH_SIZE * PATCH_SIZE)), 1024)
    self.norm2 = nn.BatchNorm1d(1024)
    self.linear3 = nn.Linear(1024, OUTPUT_CLASS_NUM)

  def forward(self, x):
    x = self.linear1(x)
    x = self.mambas(x)
    x = self.flatten(x)
    x = self.linear2(x)
    x = self.dropout(self.norm2(x))
    x = self.linear3(x)
    return x

In [6]:
# Define a lightning module for training
class Mamba_MNIST_LightningModel(L.LightningModule):
  def __init__(self, model, config):
    super().__init__()
    self.model = model
    self.config = config

    self.lr = self.config.get("INIT_LR", 1e-4)
    self.batch_size = self.config.get("BATCH_SIZE", 16)
    self.criterion = nn.CrossEntropyLoss()


    self.train_losses = []
    self.val_losses = []
    self.val_total = 0
    self.val_correct = 0

  def on_train_epoch_start(self):
    self.train_losses = []
  def on_validation_epoch_start(self):
    self.val_losses = []
    self.val_total = 0
    self.val_correct = 0

  def training_step(self, batch, batch_idx):
    patched_images, labels = batch

    outputs = self.model(patched_images)
    losses = self.criterion(outputs, labels)

    self.train_losses.append(losses.item())

    return losses

  def validation_step(self, batch, batch_idx):
    patched_images, labels = batch

    outputs = self.model(patched_images)
    losses = self.criterion(outputs, labels)

    self.val_losses.append(losses.item())

    _, predictions = outputs.max(1)
    self.val_correct += (predictions == labels).sum()
    self.val_total += predictions.size(0)

    return losses

  def on_train_epoch_end(self):
    print(
        f"EPOCH #{self.current_epoch}: Avg Training Loss = {np.mean(self.train_losses)}"
    )
  def on_validation_epoch_end(self):
    print(
        f"EPOCH #{self.current_epoch}: Avg Validation Loss = {np.mean(self.val_losses)} | Validation Accuracy = {self.val_correct * 100 / self.val_total:.2f}%"
    )



  def configure_optimizers(self):
    optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
    return optimizer

  def train_dataloader(self):
    return DataLoader(
        train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        collate_fn=train_collate_fn,
        pin_memory=True,
        num_workers=4
    )

  def val_dataloader(self):
    return DataLoader(
        validation_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        collate_fn=validation_collate_fn,
        pin_memory=True,
        num_workers=4
    )



In [7]:
# Actual Training happens here!
config = {
    "INIT_LR": INIT_LR,
    "BATCH_SIZE": BATCH_SIZE,

}


model = Mamba_MNIST_Model()

model_module = Mamba_MNIST_LightningModel(model, config)

trainer = L.Trainer(
    max_epochs=100,
    gradient_clip_val=1.0,
    num_sanity_val_steps=5,
)

trainer.fit(model_module)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: You are using a CUDA device ('NVIDIA L4') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:lightning.pytorch.utilities.rank_zero:You are using a CUDA device ('NVIDIA L4') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more det

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

EPOCH #0: Avg Validation Loss = 2.3361221313476563 | Validation Accuracy = 8.44%


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #0: Avg Validation Loss = 1.4678503487520158 | Validation Accuracy = 49.44%
EPOCH #0: Avg Training Loss = 1.7829031392436503


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #1: Avg Validation Loss = 1.2212475967255367 | Validation Accuracy = 59.01%
EPOCH #1: Avg Training Loss = 1.3547917294227863


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #2: Avg Validation Loss = 1.0433558984926552 | Validation Accuracy = 65.20%
EPOCH #2: Avg Training Loss = 1.1793185361968281


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #3: Avg Validation Loss = 1.0235107358853528 | Validation Accuracy = 65.74%
EPOCH #3: Avg Training Loss = 1.063101741740161


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #4: Avg Validation Loss = 1.0629111858689861 | Validation Accuracy = 65.44%
EPOCH #4: Avg Training Loss = 0.9882741229003652


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #5: Avg Validation Loss = 0.8946021468776046 | Validation Accuracy = 70.52%
EPOCH #5: Avg Training Loss = 0.9210892021656036


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #6: Avg Validation Loss = 0.9061646727240009 | Validation Accuracy = 70.79%
EPOCH #6: Avg Training Loss = 0.8689691053174645


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #7: Avg Validation Loss = 0.8755370427848427 | Validation Accuracy = 71.59%
EPOCH #7: Avg Training Loss = 0.821764680697485


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #8: Avg Validation Loss = 0.8552716613575152 | Validation Accuracy = 72.29%
EPOCH #8: Avg Training Loss = 0.779408980932687


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #9: Avg Validation Loss = 0.8245862403493018 | Validation Accuracy = 73.94%
EPOCH #9: Avg Training Loss = 0.745975556314144


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #10: Avg Validation Loss = 0.8380377364766066 | Validation Accuracy = 73.58%
EPOCH #10: Avg Training Loss = 0.7180001115631265


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #11: Avg Validation Loss = 0.7875459143880067 | Validation Accuracy = 74.82%
EPOCH #11: Avg Training Loss = 0.6881972417197264


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #12: Avg Validation Loss = 0.8435224554720958 | Validation Accuracy = 73.23%
EPOCH #12: Avg Training Loss = 0.6614220188859173


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #13: Avg Validation Loss = 0.8368188978000811 | Validation Accuracy = 74.79%
EPOCH #13: Avg Training Loss = 0.6296195630603434


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #14: Avg Validation Loss = 0.8556645256318863 | Validation Accuracy = 73.94%
EPOCH #14: Avg Training Loss = 0.6073161080822616


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #15: Avg Validation Loss = 0.8066296457864677 | Validation Accuracy = 75.61%
EPOCH #15: Avg Training Loss = 0.5893463083850149


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #16: Avg Validation Loss = 0.8113892411540269 | Validation Accuracy = 75.68%
EPOCH #16: Avg Training Loss = 0.5742428798581023


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #17: Avg Validation Loss = 0.8514226963565608 | Validation Accuracy = 74.82%
EPOCH #17: Avg Training Loss = 0.548592496215535


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #18: Avg Validation Loss = 0.7743077627412832 | Validation Accuracy = 76.65%
EPOCH #18: Avg Training Loss = 0.5373373599842076


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #19: Avg Validation Loss = 0.8178247967912893 | Validation Accuracy = 76.62%
EPOCH #19: Avg Training Loss = 0.5169023672485595


Validation: |          | 0/? [00:00<?, ?it/s]

EPOCH #20: Avg Validation Loss = 0.8424289729564812 | Validation Accuracy = 75.78%
EPOCH #20: Avg Training Loss = 0.5047502024551792


INFO: 
Detected KeyboardInterrupt, attempting graceful shutdown ...
INFO:lightning.pytorch.utilities.rank_zero:
Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined