In [1]:
import torch
from mamba_ssm import Mamba as MambaLayer

In [9]:
class GLU(torch.nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, input_dim * 2)
    def forward(self, x):
        out = self.linear(x)
        return out[:, :, :x.shape[2]] * torch.sigmoid(out[:, :, x.shape[2]:])
    
class MambaBlock(torch.nn.Module):
    def __init__(self, hidden_dim, state_dim, conv_dim, expansion, dropout, prenorm):
        super().__init__()
        self.norm = torch.nn.LayerNorm(hidden_dim)
        self.mamba = MambaLayer(d_model=hidden_dim, d_state=state_dim, d_conv=conv_dim, expand=expansion)
        self.glu = GLU(hidden_dim)
        self.activation = torch.nn.GELU()
        self.dropout = torch.nn.Dropout(dropout)
        self.prenorm = prenorm
    def forward(self, x):
        skip = x
        if self.prenorm:
            x = self.norm(x)
        x = self.mamba(x)
        x = self.dropout(self.activation(x))
        x = self.glu(x)
        x = self.dropout(x)
        x = x + skip
        if not self.prenorm:
            x = self.norm(x)
        return x
    
class Mamba(torch.nn.Module):
    def __init__(self, num_blocks, input_dim, output_dim, hidden_dim, state_dim, conv_dim, expansion, dropout, prenorm):
        super().__init__()
        self.linear_encoder = torch.nn.Linear(input_dim, hidden_dim)
        self.blocks = torch.nn.Sequential(*[MambaBlock(hidden_dim, state_dim, conv_dim, expansion, dropout, prenorm) for _ in range(num_blocks)])
        self.linear_decoder = torch.nn.Linear(hidden_dim, output_dim)
    def forward(self, x):
        x = self.linear_encoder(x)
        x = self.blocks(x)
        x = torch.mean(x, dim=1)
        x = torch.softmax(self.linear_decoder(x), dim=1)
        return x

In [14]:
def train(
    seed,
	dataloader,
    num_epochs,
    learning_rate,
    wd,
    num_blocks,
	input_dim,
	output_dim,
    hidden_dim,
    state_dim,
    conv_dim,
    expansion,
    dropout,
    prenorm
    ):
    torch.manual_seed(seed)
    device = "cuda"
    model = Mamba(num_blocks, input_dim, output_dim, hidden_dim, state_dim, conv_dim, expansion, dropout, prenorm).to(device)
    print("Nr. of parameters: {0}".format(sum(p.numel() for p in model.parameters())))
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=wd)
    running_loss = 0.0
    for epoch in range(num_epochs):
        for X, y in dataloader:
            optimizer.zero_grad()
            X = X.to(device)
            y = y.to(device)
            y_hat = model(X)
            loss = torch.nn.functional.cross_entropy(y_hat, y)
            running_loss += loss.item()
            loss.backward()
            optimizer.step()
        model.eval()
        train_accuracy = 0.0
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            y_hat = model(X)
            accuracy = (y_hat.argmax(dim=1) == y).float().sum() / len(y)
            train_accuracy += accuracy
        print(train_accuracy / len(dataloader))
        model.train()

In [4]:
def split_train_val(train, val_split):
    train_len = int(len(train) * (1.0-val_split))
    train, val = torch.utils.data.random_split(
        train,
        (train_len, len(train) - train_len),
        generator=torch.Generator().manual_seed(42),
    )
    return train, val

In [5]:
import torchvision
import torchvision.transforms as transforms

In [6]:
transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean=122.6 / 255.0, std=61.0 / 255.0),
        transforms.Lambda(lambda x: x.view(1, 1024).t())
    ])

# Train with no data augmentation
transform_train = transform_test = transform

trainset = torchvision.datasets.CIFAR10(
    root='./data/cifar/', train=True, download=True, transform=transform_train)
trainset, _ = split_train_val(trainset, val_split=0.1)

valset = torchvision.datasets.CIFAR10(
    root='./data/cifar/', train=True, download=True, transform=transform_test)
_, valset = split_train_val(valset, val_split=0.1)

testset = torchvision.datasets.CIFAR10(
    root='./data/cifar/', train=False, download=True, transform=transform_test)

d_input = 1 
d_output = 10

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [7]:
BATCH_SIZE = 32

# Dataloaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
model = Mamba(4, 1, 10, 64, 64, 4, 2).to("cuda")

In [11]:
del model

In [15]:
train(
    seed=1234,
	dataloader=trainloader,
	num_epochs=30,
    learning_rate=2e-4,
    wd=0.1,
    num_blocks=6,
	input_dim=1,
	output_dim=10,
    hidden_dim=64,
    state_dim=256,
    conv_dim=4,
    expansion=2,
    dropout=0.0,
    prenorm=False
    )

Nr. of parameters: 800266
tensor(0.2942, device='cuda:0')
tensor(0.3396, device='cuda:0')
tensor(0.4153, device='cuda:0')
tensor(0.4485, device='cuda:0')
tensor(0.4818, device='cuda:0')
tensor(0.5144, device='cuda:0')
tensor(0.5279, device='cuda:0')
tensor(0.5146, device='cuda:0')
tensor(0.5332, device='cuda:0')
tensor(0.5583, device='cuda:0')
tensor(0.5776, device='cuda:0')
tensor(0.5827, device='cuda:0')
tensor(0.5898, device='cuda:0')
tensor(0.5572, device='cuda:0')
tensor(0.5951, device='cuda:0')
tensor(0.6089, device='cuda:0')
tensor(0.6126, device='cuda:0')
tensor(0.6128, device='cuda:0')
tensor(0.6267, device='cuda:0')
tensor(0.6299, device='cuda:0')
tensor(0.6269, device='cuda:0')
tensor(0.6428, device='cuda:0')
tensor(0.6449, device='cuda:0')
tensor(0.6505, device='cuda:0')
tensor(0.6584, device='cuda:0')
tensor(0.6437, device='cuda:0')
tensor(0.6536, device='cuda:0')
tensor(0.6336, device='cuda:0')
tensor(0.6667, device='cuda:0')
tensor(0.6415, device='cuda:0')
