In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchmetrics
import lightning as L
from mamba import MambaImgClassifier

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE: L.Trainer picks its own accelerator; `device` is only reported here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [2]:
class KuzushijiMNISTDataset(Dataset):
    """Map-style dataset over images and labels stored in two ``.npz`` archives.

    Each archive must hold its array under the default ``'arr_0'`` key (the name
    ``np.savez`` gives an unnamed positional array).

    Args:
        imgs_npz_file: Path to the ``.npz`` archive holding the images.
        labels_npz_file: Path to the ``.npz`` archive holding the integer class labels.
        transform: Optional callable applied to each image in ``__getitem__``.

    Raises:
        ValueError: If the image and label arrays have different lengths.
    """

    def __init__(self, imgs_npz_file, labels_npz_file, transform=None):
        # np.load on an .npz returns an NpzFile that keeps the archive open; the
        # context manager closes it once the array has been read into memory.
        with np.load(imgs_npz_file) as imgs_data:
            self.images = imgs_data['arr_0']
        with np.load(labels_npz_file) as labels_data:
            self.labels = labels_data['arr_0']
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"images ({len(self.images)}) and labels ({len(self.labels)}) differ in length"
            )
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is passed through ``transform`` when one was given. The label is
        returned as a Python ``int``.
        """
        image = self.images[idx]
        # A numpy scalar label (e.g. uint8) would be collated into a uint8 tensor,
        # which cross_entropy rejects as class-index targets; a Python int collates
        # to int64 instead.
        label = int(self.labels[idx])
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [12]:
class ImgClassifierLightning(L.LightningModule):
    """LightningModule that trains a wrapped classifier with cross-entropy and Adam.

    Args:
        model: ``nn.Module`` mapping an image batch to class logits
            (assumed shape ``(batch, num_classes)`` — not checked here).
        learning_rate: Learning rate for the Adam optimizer.
        num_classes: Number of classes for the multiclass validation accuracy.
    """

    def __init__(self, model, learning_rate=1e-3, num_classes=10):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        # Logged as a Metric object (see validation_step) so Lightning computes the
        # epoch-level accuracy and resets its state at the end of each epoch.
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        """Return the wrapped model's logits for ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Return ``(logits, targets, loss)`` for one ``(x, y)`` batch."""
        x, y = batch
        # cross_entropy with class-index targets requires int64; narrower label
        # dtypes (e.g. uint8 from numpy) would otherwise raise here.
        y = y.long()
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        return logits, y, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        _, _, loss = self._shared_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy.

        Returns:
            dict with the batch ``val_loss`` and batch ``val_acc``.
        """
        logits, y, loss = self._shared_step(batch)
        # Calling the metric returns this batch's accuracy and also accumulates
        # epoch state; logging the Metric object makes 'val_acc' the epoch value.
        batch_acc = self.accuracy(logits, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.accuracy, prog_bar=True)
        return {'val_loss': loss, 'val_acc': batch_acc}

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

In [None]:
# --- configuration -----------------------------------------------------------
SEED = 42
DATA_DIR = 'dataset/npz'
BATCH_SIZE = 32
MAX_EPOCHS = 5

# Seed Python, NumPy and torch RNGs so weight init, shuffling and dropout are
# reproducible across Restart & Run All.
L.seed_everything(SEED)

# --- data --------------------------------------------------------------------
# ToTensor converts a 2-D (H, W) uint8 array into a (1, H, W) float32 tensor in
# [0, 1], i.e. channels-first with a single channel (matching n_channels = 1 below).
# NOTE(review): assumes the .npz images are uint8 (H, W) arrays — confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = KuzushijiMNISTDataset(f'{DATA_DIR}/kmnist-train-imgs.npz',
                                      f'{DATA_DIR}/kmnist-train-labels.npz',
                                      transform=transform)
# NOTE: the test split doubles as the validation set, so val metrics are not an
# unbiased held-out estimate if they are used for model selection.
val_dataset = KuzushijiMNISTDataset(f'{DATA_DIR}/kmnist-test-imgs.npz',
                                    f'{DATA_DIR}/kmnist-test-labels.npz',
                                    transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# --- model -------------------------------------------------------------------
patch_size = 4
img_size = 28
embed_dim = 256
dropout = 0.1
n_layers = 6
n_channels = 1

model = MambaImgClassifier(patch_size, img_size, n_channels, embed_dim, n_layers, dropout)

# Wrap in the Lightning module (loss, metrics, optimizer live there).
lightning_model = ImgClassifierLightning(model)

# --- training ----------------------------------------------------------------
trainer = L.Trainer(max_epochs=MAX_EPOCHS)
trainer.fit(lightning_model, train_loader, val_loader)