In [12]:
import pandas as pd
from PIL import Image
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch import nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import os
from einops import rearrange, repeat
import einops
from glob import glob
from math import log
import math
from tqdm import tqdm
import pickle
#from staf import StafLayer
    #from staf import INR
from mamba_ssm import Mamba
from mamba_ssm.modules.block import Block
import matplotlib.pyplot as plt
from transformer import TransformerEncoderINR
import time
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import cProfile

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
class BiMamba(torch.nn.Module):
    def __init__(self, dim = 512):
        super(BiMamba, self).__init__()
        
        self.f_mamba = Mamba(d_model = dim)
        self.r_mamba = Mamba(d_model = dim)
        
    def forward(self, x, **kwargs):
        x_f = self.f_mamba(x, **kwargs)
        x_r = torch.flip(self.r_mamba(torch.flip(x, dims=[1]), **kwargs), dims=[1])
        out = (x_f + x_r)/2
        
        return out
    
class MambaINRModel(torch.nn.Module):

    def __init__(self, input_size, token_dim = 512, output_size = 3, model_type = 'stacked'):
        super(MambaINRModel, self).__init__()
        if model_type == 'stacked':
            self.mamba = MambaStack(num = 6, token_dim = token_dim)
        else:
            self.mamba = BiMamba(token_dim = token_dim)
            
        self.input = torch.nn.Linear(input_size, token_dim)
        self.output = torch.nn.Linear(token_dim, output_size)
        torch.nn.Linear(token_dim, output_size)
        self.sig = torch.nn.Sigmoid()
        
    def forward(self, x):
        x = self.input(x)
        x = self.mamba(x)
        x = self.output(x)
        x = self.sig(x)

        return x

class MambaCLS(torch.nn.Module):

    def __init__(self, input_size, token_dim = 512, output_size = 3, model_type = 'stacked', num_lp = 1):
        super(MambaCLS, self).__init__()
        self.token_dim = token_dim
        if model_type == 'stacked':
            self.mamba = MambaStack(num = 6, token_dim = self.token_dim)
        else:
            self.mamba = BiMamba(token_dim = self.token_dim)
            
        self.input = torch.nn.Linear(input_size, self.token_dim)
        self.output = torch.nn.Linear(self.token_dim, output_size)
        self.pred_out= torch.nn.Sequential(torch.nn.Linear(self.token_dim, 2*self.token_dim), torch.nn.Linear(2*self.token_dim, 2*self.token_dim),
                                               torch.nn.Linear(2*self.token_dim, output_size))

        self.sig = torch.nn.Sigmoid()

        self.num_lp = num_lp
        self.lp = torch.nn.Parameter(torch.empty((self.num_lp, self.token_dim), dtype = torch.float32))
        
        self.lp_idxs = None
    
    def set_lp_idxs(self, lp_idxs):
        self.lp_idxs = lp_idxs
        
    def add_lp(self, x):
        
        if x.ndim == 2:
            seq_len = x.shape[0]
        elif x.ndim == 3:
            seq_len = x.shape[1]
        total_len = seq_len + self.num_lp
        
        chunk_size = round(seq_len/(self.num_lp+1))
        insert_idxs = torch.clamp(torch.tensor([(chunk_size+1)*(x+1)-1 for x in range (self.num_lp)]), min = 0, max = total_len-1)
        self.set_lp_idxs(insert_idxs)
        

        mask = torch.zeros(total_len, dtype=torch.bool)
        mask[insert_idxs] = True

        out = torch.empty((x.shape[0], total_len, self.token_dim), dtype=torch.float32).to(x.device)
        out[:, mask] = self.lp
        out[:, ~mask] = x
        return out
    
    def extract_lp_tokens(self, x):
        return x[:, self.lp_idxs]
        
    def forward(self, x):
        
        x = self.input(x)
        #print(x.shape)
        x = self.add_lp(x)
        #print(x.shape)
        x = self.mamba(x)
        #print(x.shape)
        #x = self.output(x)
        x = self.extract_lp_tokens(x)
        x = self.pred_out(x.squeeze(1))

        return x

class MambaStack(torch.nn.Module):
    def __init__(self, num = 3, token_dim = 512):
        super(MambaStack, self).__init__()
        self.blocks = nn.ModuleList([
            Block(
                dim=token_dim,
                mixer_cls= lambda dim: BiMamba(dim),
                mlp_cls= lambda dim: torch.nn.Sequential(
                    nn.Linear(dim, 4 * dim),
                    nn.GELU(),
                    nn.Linear(4 * dim, dim),
                ),
                norm_cls=nn.LayerNorm,  # or RMSNorm
                fused_add_norm=False
            )
            for _ in range(num)
        ])
    
    def forward(self, x):
        residual = None
        for block in self.blocks:
            x, residual = block(x, residual=residual)
        return x


In [5]:
test_input = torch.randn((1, 5))
print(test_input)
test_input = test_input.unsqueeze(2)
print(test_input)

tensor([[-0.2638, -0.8965,  0.1289,  0.1090, -1.1527]])
tensor([[[-0.2638],
         [-0.8965],
         [ 0.1289],
         [ 0.1090],
         [-1.1527]]])


In [6]:
model = MambaCLS(input_size = 1, output_size = 10).to(device)
model.eval()

test_input = torch.randn((1, 28*28)).unsqueeze(2).to(device)
print(test_input.shape)
out = model(test_input)

print(out.shape)

torch.Size([1, 784, 1])
torch.Size([1, 10])


In [7]:
batch_size = 64
num_epochs = 5
lr = 1e-3

In [8]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_data = datasets.FashionMNIST(root='.', train=True, download=True, transform=transform)
test_data = datasets.FashionMNIST(root='.', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size)

In [9]:
x = torch.tensor([[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]])
x = torch.transpose(x, 1, 2)
print(x.reshape(1, 10,))

tensor([[ 1,  6,  2,  7,  3,  8,  4,  9,  5, 10]])


In [10]:
def scan(x):
    return x.reshape(x.shape[0], -1).unsqueeze(2)

In [15]:
def train():
    # ======= Model Setup =======
    model_dim = 128
    num_classes = 10

    model = MambaCLS(input_size = 1, output_size = num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # ======= Training Loop =======
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for imgs, labels in tqdm(train_loader):
            imgs = imgs.to(device)
            labels = labels.to(device)
            #x = patchify(imgs, patch_size)  # (B, N, D)
            x = scan(imgs)
            preds = model(x)
            loss = criterion(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"[Epoch {epoch+1}] Train Loss: {total_loss / len(train_loader):.4f}")

        # Eval
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for imgs, labels in tqdm(test_loader):
                imgs = imgs.to(device)
                labels = labels.to(device)
                x = scan(imgs)
                preds = model(x)
                pred_labels = preds.argmax(dim=1)
                correct += (pred_labels == labels).sum().item()
                total += labels.size(0)

        acc = correct / total
        print(f"[Epoch {epoch+1}] Test Accuracy: {acc:.2%}")


In [16]:
cProfile.run('train()', sort='cumtime')

  5%|▌         | 48/938 [01:04<19:59,  1.35s/it]

         816651 function calls (802132 primitive calls) in 64.978 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   64.978   64.978 {built-in method builtins.exec}
        1    0.000    0.000   64.978   64.978 <string>:1(<module>)
        1    0.021    0.021   64.978   64.978 3215347763.py:1(train)
       49    0.001    0.000   43.678    0.891 _tensor.py:592(backward)
       49    0.001    0.000   43.677    0.891 __init__.py:243(backward)
       49    0.001    0.000   43.672    0.891 graph.py:815(_engine_run_backward)
       49   43.671    0.891   43.671    0.891 {method 'run_backward' of 'torch._C._EngineBase' objects}
8820/3234    0.017    0.000   20.136    0.006 module.py:1747(_wrapped_call_impl)
8820/3234    0.028    0.000   20.130    0.006 module.py:1755(_call_impl)
       49    0.002    0.000   19.894    0.406 3962538750.py:86(forward)
       49   18.246    0.372   18.246    0.372 3962538




KeyboardInterrupt: 