In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchmetrics
import lightning as L
from models import AutoEncoder, MambaImgClassifier
import matplotlib.pyplot as plt
import learn2learn as l2l

# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
print(device)

# KMNIST

In [None]:
class KuzushijiMNISTDataset(Dataset):
    """Map-style dataset over the Kuzushiji-MNIST ``.npz`` archives.

    Each archive is opened with ``np.load``, the array stored under the default
    key ``'arr_0'`` is read into memory, and the archive is closed immediately.

    Args:
        imgs_npz_file: path to the ``.npz`` file holding the image array.
        labels_npz_file: path to the ``.npz`` file holding the label array.
        transform: optional callable applied to each image before it is returned.

    Raises:
        ValueError: if the image and label arrays have different lengths.
    """

    def __init__(self, imgs_npz_file, labels_npz_file, transform=None):
        # Context managers close the NpzFile handles once the arrays have been read.
        with np.load(imgs_npz_file) as imgs_data:
            self.images = imgs_data['arr_0']
        with np.load(labels_npz_file) as labels_data:
            self.labels = labels_data['arr_0']
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"image/label count mismatch: {len(self.images)} images "
                f"vs {len(self.labels)} labels"
            )
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; the label is a plain Python ``int``."""
        image = self.images[idx]
        # A Python int collates to an int64 tensor, which is the dtype F.cross_entropy
        # expects for class-index targets; a raw numpy scalar would keep the file's dtype.
        label = int(self.labels[idx])
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:
# ToTensor converts each HxW uint8 image into a 1xHxW float tensor scaled to [0, 1],
# so samples are already channels-first; append further transforms here if needed.
transform = transforms.Compose([transforms.ToTensor()])

# The official KMNIST test split doubles as the validation set.
train_dataset = KuzushijiMNISTDataset(
    imgs_npz_file='dataset/npz/kmnist-train-imgs.npz',
    labels_npz_file='dataset/npz/kmnist-train-labels.npz',
    transform=transform,
)
val_dataset = KuzushijiMNISTDataset(
    imgs_npz_file='dataset/npz/kmnist-test-imgs.npz',
    labels_npz_file='dataset/npz/kmnist-test-labels.npz',
    transform=transform,
)

# Only the training loader shuffles; validation order stays fixed.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Autoencoder

In [None]:
class AutoEncoderLightning(L.LightningModule):
    """Lightning wrapper that trains ``AutoEncoder`` to reconstruct its flattened input.

    Batches are ``(images, labels)`` pairs and the labels are ignored. Images are
    flattened to ``(batch, -1)`` before reaching the model, and the loss is the MSE
    between the reconstruction and that flattened input.

    Args:
        learning_rate: Adam learning rate.
        n_preview: maximum number of validation images plotted by ``on_fit_end``.
        img_shape: ``(C, H, W)`` used to un-flatten images for plotting.
    """

    def __init__(self, learning_rate=1e-3, n_preview=6, img_shape=(1, 28, 28)):
        super().__init__()
        self.model = AutoEncoder()
        self.loss = nn.MSELoss()
        self.learning_rate = learning_rate
        self.n_preview = n_preview
        self.img_shape = tuple(img_shape)

    def forward(self, x):
        return self.model(x)

    def _reconstruction_loss(self, batch):
        """Flatten the images in ``batch`` and return the MSE of their reconstruction."""
        x, _ = batch  # labels are not needed for autoencoder training
        x = x.reshape(x.size(0), -1)
        x_hat = self(x)
        return self.loss(x_hat, x)

    def training_step(self, batch, batch_idx):
        loss = self._reconstruction_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        val_loss = self._reconstruction_loss(batch)
        self.log('val_loss', val_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def on_fit_end(self) -> None:
        """Plot originals (top row) against reconstructions (bottom row) for one validation batch."""
        val_loader = self.trainer.val_dataloaders
        if isinstance(val_loader, (list, tuple)):
            val_loader = val_loader[0] if val_loader else None
        # Nothing to show when fit() ran without validation data; plot only once under DDP.
        if val_loader is None or not self.trainer.is_global_zero:
            return

        x, _ = next(iter(val_loader))
        n_show = min(self.n_preview, x.size(0))
        if n_show == 0:
            return
        x = x[:n_show].reshape(n_show, -1).to(self.device)

        # Pure inference: eval mode (e.g. dropout off) and no autograd graph; restore mode after.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            x_hat = self(x)
        self.train(was_training)

        originals = x.reshape(-1, *self.img_shape).cpu()
        reconstructions = x_hat.reshape(-1, *self.img_shape).cpu()

        # squeeze=False keeps axs 2-D even when only one image is shown.
        _, axs = plt.subplots(2, n_show, figsize=(2.5 * n_show, 5), squeeze=False)
        for i in range(n_show):
            axs[0, i].imshow(originals[i].squeeze().numpy(), cmap='gray')
            axs[0, i].axis('off')
            axs[1, i].imshow(reconstructions[i].squeeze().numpy(), cmap='gray')
            axs[1, i].axis('off')
        plt.show()

In [None]:
# Train the autoencoder, validating on the held-out KMNIST split each epoch.
autoencoder_epochs = 10
lightning_model = AutoEncoderLightning()
trainer = L.Trainer(max_epochs=autoencoder_epochs)
trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Mamba

In [None]:
class ImgClassifierLightning(L.LightningModule):
    """Lightning wrapper that trains an image classifier with cross-entropy loss.

    Args:
        model: module returning per-class logits for an image batch
            (assumed ``(batch, num_classes)`` — TODO confirm for ``MambaImgClassifier``).
        learning_rate: Adam learning rate.
        num_classes: number of classes tracked by the validation accuracy metric
            (defaults to 10, the number of KMNIST classes).
    """

    def __init__(self, model, learning_rate=1e-3, num_classes=10):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

    # Forward pass
    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        """Return ``(logits, targets, loss)`` for one ``(images, labels)`` batch."""
        x, y = batch
        # F.cross_entropy needs int64 class indices; labels may arrive with a narrower
        # integer dtype (e.g. uint8 straight from a .npz array).
        y = y.long()
        logits = self(x)
        return logits, y, F.cross_entropy(logits, y)

    def training_step(self, batch, batch_idx):
        _, _, loss = self._shared_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, y, loss = self._shared_step(batch)
        acc = self.accuracy(logits, y)  # updates metric state and returns the batch accuracy
        self.log('val_loss', loss, prog_bar=True)
        # Logging the Metric object (not the batch tensor) lets Lightning compute the
        # epoch-level accuracy over all validation samples and reset it every epoch.
        self.log('val_acc', self.accuracy, prog_bar=True)
        return {'val_loss': loss, 'val_acc': acc}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        return optimizer

In [None]:
# Hyperparameters for the Mamba image classifier on 28x28 single-channel KMNIST images.
patch_size, img_size, n_channels = 4, 28, 1
embed_dim, n_layers, dropout = 256, 6, 0.1

# Build the backbone and wrap it for Lightning training.
model = MambaImgClassifier(patch_size, img_size, n_channels, embed_dim, n_layers, dropout)
lightning_model = ImgClassifierLightning(model)

trainer = L.Trainer(max_epochs=5)
trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Mamba with MAML

**Status:** this section does not work properly yet. It still needs a better task sampler, and the training loop has to be adjusted to match it.

In [None]:
class MAMLLightning(L.LightningModule):
    """Meta-trains ``model`` with learn2learn's MAML inside a Lightning training loop.

    A training batch is ``(data, labels)`` for either a single few-shot task
    (``labels`` is 1-D) or a stack of tasks produced by a ``DataLoader`` over an
    ``l2l.data.Taskset`` (``labels`` is ``(n_tasks, n_samples)``). For each task a
    fresh clone of the meta-model is adapted on the samples at even positions
    (support set) and scored on the samples at odd positions (query set). The
    meta-loss is the mean query loss over all tasks in the batch. The even/odd
    split assumes every class contributes an even number of consecutive samples
    to a task — TODO confirm against the task transforms in use.

    Reference: https://github.com/learnables/learn2learn/blob/master/examples/vision/meta_mnist.py

    Args:
        model: network to meta-train; it must return per-class logits.
        learning_rate: Adam learning rate for the outer (meta) update.
        steps: number of inner-loop adaptation steps per task.
        lr_inner: inner-loop step size passed to ``l2l.algorithms.MAML``.
    """

    def __init__(self, model, learning_rate=1e-3, steps=1, lr_inner=0.01):
        super().__init__()
        self.model = model
        # first_order=False keeps second-order terms, so the meta update
        # differentiates through the inner adaptation steps.
        self.maml = l2l.algorithms.MAML(self.model, lr=lr_inner, first_order=False)
        self.learning_rate = learning_rate
        self.steps = steps  # Number of adaptation steps for each task
        self.lr_inner = lr_inner  # Learning rate for task-specific adaptation

    # Forward pass (un-adapted meta-model)
    def forward(self, x):
        return self.maml(x)

    def _adapt_and_evaluate(self, task_x, task_y):
        """Adapt a clone on one task's support half; return ``(query_loss, query_accuracy)``."""
        if task_y.numel() < 2:
            raise ValueError("each task needs at least 2 samples to form support and query sets")
        support_x, support_y = task_x[0::2], task_y[0::2]
        query_x, query_y = task_x[1::2], task_y[1::2]

        learner = self.maml.clone()  # per-task clone; adaptation does not modify the meta-model
        for _ in range(self.steps):
            learner.adapt(F.cross_entropy(learner(support_x), support_y))

        query_logits = learner(query_x)
        query_loss = F.cross_entropy(query_logits, query_y)
        query_acc = (query_logits.argmax(dim=1) == query_y).float().mean()
        return query_loss, query_acc

    def training_step(self, batch, batch_idx):
        """Return the MAML meta-loss averaged over every task in ``batch``."""
        x, y = batch
        y = y.long()  # cross_entropy needs int64 class indices
        if y.dim() == 1:
            # A single unbatched task: add a leading task dimension.
            x, y = x.unsqueeze(0), y.unsqueeze(0)
        n_tasks = y.size(0)

        task_losses, task_accs = [], []
        for task_x, task_y in zip(x, y):
            task_loss, task_acc = self._adapt_and_evaluate(task_x, task_y)
            task_losses.append(task_loss)
            task_accs.append(task_acc)
        loss = torch.stack(task_losses).mean()
        acc = torch.stack(task_accs).mean()

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True,
                 batch_size=n_tasks)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=True,
                 batch_size=n_tasks)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the cross-entropy of the un-adapted meta-model on a validation batch.

        NOTE(review): no inner-loop adaptation happens here, so this is only a sanity
        signal, not a measure of few-shot performance.
        """
        x, y = batch
        loss = F.cross_entropy(self(x), y.long())
        self.log('val_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        # Update the meta-model parameters, not the adapted model parameters
        return torch.optim.Adam(self.maml.parameters(), lr=self.learning_rate)

In [None]:
# learn2learn samples few-shot tasks from a MetaDataset, which needs a label -> indices map.
# Build it from the raw label array so MetaDataset does not call __getitem__ (and run the
# image transform) on every training sample just to discover its label.
kmnist_labels_to_indices = {}
for sample_idx, sample_label in enumerate(train_dataset.labels.tolist()):
    kmnist_labels_to_indices.setdefault(sample_label, []).append(sample_idx)

# Bind the wrapper to a new name instead of overwriting `train_dataset`, so re-running
# this cell does not wrap an already-wrapped MetaDataset.
meta_train_dataset = l2l.data.MetaDataset(train_dataset, labels_to_indices=kmnist_labels_to_indices)

N_WAYS = 5        # classes per task
K_SHOTS = 2       # samples drawn per class for each task
NUM_TASKS = 1000  # number of task descriptions held by the Taskset

train_tasks = l2l.data.Taskset(
    meta_train_dataset,
    task_transforms=[
        l2l.data.transforms.NWays(meta_train_dataset, N_WAYS),
        l2l.data.transforms.KShots(meta_train_dataset, K_SHOTS),
        l2l.data.transforms.LoadData(meta_train_dataset),
        l2l.data.transforms.RemapLabels(meta_train_dataset),
        l2l.data.transforms.ConsecutiveLabels(meta_train_dataset),
    ],
    num_tasks=NUM_TASKS,
)

# NOTE: this rebinds `train_loader` (previously the plain image loader) because the MAML
# training cell below reads that name; each batch presumably stacks 32 tasks.
train_loader = DataLoader(train_tasks, batch_size=32, shuffle=True)

In [None]:
# Mamba hyperparameters (same values as the plain-classifier run), now meta-trained with MAML.
patch_size, img_size, n_channels = 4, 28, 1
embed_dim, n_layers, dropout = 256, 6, 0.1

model = MambaImgClassifier(patch_size, img_size, n_channels, embed_dim, n_layers, dropout)
lightning_model = MAMLLightning(model)

# `train_loader` here is the task loader built in the previous cell.
trainer = L.Trainer(max_epochs=5)
trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)