# Env

In [1]:
%%capture
!pip install einops
!pip install datasets

In [2]:
import torch
from torch import nn
from torch.nn import functional as F

import math
from typing import Union
from dataclasses import dataclass

from einops import repeat, rearrange, einsum

# Data

In [3]:
from transformers import AutoTokenizer

# Load the pretrained RoBERTa-base tokenizer from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained('FacebookAI/roberta-base')
# Pad on the left when sequences are batched to a common length.
tokenizer.padding_side='left'

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



In [4]:
from datasets import load_dataset, DatasetDict

# Load the first 5% of enwik8.
# NOTE(review): trust_remote_code=True executes the dataset's loading script
# (enwik8.py) fetched from the Hub -- only keep this for sources you trust.
dataset = load_dataset('LTCB/enwik8', trust_remote_code=True, split='train[:5%]')
# 90/10 train/validation split; the fixed seed makes the partition identical
# on every "Restart & Run All" instead of being re-drawn at random.
dataset = dataset.train_test_split(train_size=0.9, seed=42)
# train_test_split names the held-out part 'test'; expose it as 'validation'.
dataset['validation']=dataset.pop('test')
dataset

enwik8.py:   0%|          | 0.00/2.94k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/4.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/36.4M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1128024 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 50760
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 5641
    })
})

In [5]:
def preprocessing_fn(batch):
    """Tokenize a batch of raw text, keeping only the token ids.

    Args:
        batch: mapping with a 'text' entry holding the strings to tokenize.

    Returns:
        The tokenizer output (sequences truncated to at most 1024 tokens)
        with the 'attention_mask' entry removed if it was present.
    """
    encoded = tokenizer(
        batch['text'],
        truncation=True,
        max_length=1024,
    )
    # The mask is not used downstream, so drop it to keep the dataset small.
    encoded.pop('attention_mask', None)
    return encoded

# Tokenize every split in batches of 1024 rows and drop the original columns,
# so that only the tokenizer output remains.
tokenized_dataset = dataset.map(preprocessing_fn, batched=True, batch_size=1024, remove_columns = dataset['train'].column_names)
print(tokenized_dataset)
# The raw text is no longer needed once it has been tokenized.
del dataset

Map:   0%|          | 0/50760 [00:00<?, ? examples/s]

Map:   0%|          | 0/5641 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 50760
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 5641
    })
})


In [6]:
from transformers import DataCollatorWithPadding

# Collator that pads a list of examples with the tokenizer and returns PyTorch tensors.
data_collator = DataCollatorWithPadding(tokenizer, return_tensors='pt')

In [7]:
def collator(data):
    """Pad a list of examples and build next-token prediction pairs.

    Args:
        data: list of examples accepted by `data_collator`.

    Returns:
        Tuple `(input_ids, labels)`: the padded token ids without their last
        position, and the same ids shifted left by one (without the first).
    """
    token_ids = data_collator(data)['input_ids']
    # Inputs are every token but the last; targets are every token but the first.
    return token_ids[:, :-1], token_ids[:, 1:]

# Smoke test: collate the first five training examples and print both shapes.
input_ids, labels = collator([tokenized_dataset["train"][i] for i in range(5)])
print(input_ids.shape)
print(labels.shape)

torch.Size([5, 72])
torch.Size([5, 72])


In [8]:
from torch.utils.data import DataLoader

def create_data_loader(batch_size):
    """Build the training and validation DataLoaders.

    Switches the module-level `tokenized_dataset` to the 'torch' format as a
    side effect. Both loaders use `collator` and drop the last incomplete
    batch; only the training loader shuffles.

    Args:
        batch_size: number of examples per batch.

    Returns:
        Tuple `(train_loader, val_loader)`.
    """
    tokenized_dataset.set_format('torch')

    shared_kwargs = dict(batch_size=batch_size, collate_fn=collator, drop_last=True)
    train_loader = DataLoader(tokenized_dataset['train'], shuffle=True, **shared_kwargs)
    val_loader = DataLoader(tokenized_dataset['validation'], **shared_kwargs)

    return train_loader, val_loader

# Smoke test: build loaders with batch size 32, then print the number of
# training batches, the length of one batch tuple and the shape of its first element.
train_loader, val_loader = create_data_loader(32)
print(len(train_loader))
batch = next(iter(train_loader))
print(len(batch))
print(batch[0].shape)

1586
2
torch.Size([32, 182])


# Scan

In [18]:
import math
import torch
import torch.nn.functional as F
from einops import rearrange

def npo2(len):
    """Return the smallest power of two that is >= `len`.

    Uses exact integer arithmetic; the previous `math.log2`-based version
    rounded large inputs through a float and could return a value smaller
    than `len` (e.g. for 2**60 + 1).

    Args:
        len: positive sequence length. (The name shadows the builtin but is
            kept so that keyword callers keep working.)

    Returns:
        int: the next power of two (1 for `len == 1`).

    Raises:
        ValueError: if `len` is smaller than 1.
    """
    n = math.ceil(len)
    if n < 1:
        raise ValueError("npo2() expects a length >= 1")
    # (n - 1).bit_length() is the exponent of the first power of two >= n.
    return 1 << (n - 1).bit_length()

def pad_npo2(X):
    """Zero-pad dimension 1 of `X` up to the next power of two.

    The 6-element pad spec addresses the last three dimensions, so `X` is
    expected to be 4-D with the padded axis at index 1.

    Args:
        X: tensor to pad.

    Returns:
        A new tensor whose size along dimension 1 is `npo2(X.size(1))`, with
        zeros appended at the end of that dimension.
    """
    current = X.size(1)
    missing = npo2(current) - current
    # (left, right) pairs for the last, second-to-last and third-to-last dims.
    return F.pad(X, (0, 0, 0, 0, 0, missing), "constant", 0)

class PScan(torch.autograd.Function):
    """Parallel scan for the linear recurrence H[t] = A[t] * H[t-1] + X[t].

    `pscan` / `pscan_rev` work IN PLACE on tensors laid out as (B, D, L, N)
    with L a power of two. `forward` / `backward` take care of copying,
    padding to a power of two and transposing the public (B, L, D, N) layout.

    All reshapes use `Tensor.view`, which either returns an alias of the same
    storage or raises; a silent copy would make the in-place updates vanish.
    """

    @staticmethod
    def pscan(A, X):
        """Overwrite X with H, where H[t] = A[t] * H[t-1] + X[t] and H[-1] = 0.

        Args:
            A: (B, D, L, N) coefficients, L a power of two. Modified in place
                (used as scratch space for running products).
            X: (B, D, L, N) inputs. Holds the scan result on return.
        """
        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # Up-sweep. Aa / Xa must alias A / X (no clone!): the down-sweep below
        # re-slices A and X and relies on the partial results written here.
        Aa, Xa = A, X
        for _ in range(num_steps - 2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T // 2, 2, -1)
            Xa = Xa.view(B, D, T // 2, 2, -1)

            # Fold each left element of a pair into its right neighbour.
            Xa[:, :, :, 1].add_(Aa[:, :, :, 1] * Xa[:, :, :, 0])
            Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

            # Continue on the right elements only.
            Aa, Xa = Aa[:, :, :, 1], Xa[:, :, :, 1]

        # Last up-sweep levels, unrolled: 4, 2 or 1 nodes remain.
        if Xa.size(2) == 4:
            Xa[:, :, 1].add_(Aa[:, :, 1] * Xa[:, :, 0])
            Aa[:, :, 1].mul_(Aa[:, :, 0])
            Xa[:, :, 3].add_(Aa[:, :, 3] * (Xa[:, :, 2] + Aa[:, :, 2] * Xa[:, :, 1]))
        elif Xa.size(2) == 2:
            Xa[:, :, 1].add_(Aa[:, :, 1] * Xa[:, :, 0])
            return
        else:
            return

        # Down-sweep, first level unrolled.
        stride = 2 ** (num_steps - 2)
        Aa = A[:, :, stride - 1:L:stride]
        Xa = X[:, :, stride - 1:L:stride]
        Xa[:, :, 2].add_(Aa[:, :, 2] * Xa[:, :, 1])
        Aa[:, :, 2].mul_(Aa[:, :, 1])

        for k in range(num_steps - 3, -1, -1):
            Aa = A[:, :, 2 ** k - 1:L:2 ** k]
            Xa = X[:, :, 2 ** k - 1:L:2 ** k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T // 2, 2, -1)
            Xa = Xa.view(B, D, T // 2, 2, -1)

            # Propagate the finished right element of the previous pair into
            # the left element of the current pair.
            Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1])
            Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])

    @staticmethod
    def pscan_rev(A, X):
        """Overwrite X with H, where H[t] = X[t] + A[t] * H[t+1] and H[L] = 0.

        Mirror image of `pscan` running from the end of the sequence.

        Args:
            A: (B, D, L, N) coefficients, L a power of two. Modified in place.
            X: (B, D, L, N) inputs. Holds the reverse-scan result on return.
        """
        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # Up-sweep on aliases of A / X (see `pscan` for why no clone is taken).
        Aa, Xa = A, X
        for _ in range(num_steps - 2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T // 2, 2, -1)
            Xa = Xa.view(B, D, T // 2, 2, -1)

            # Fold each right element of a pair into its left neighbour.
            Xa[:, :, :, 0].add_(Aa[:, :, :, 0] * Xa[:, :, :, 1])
            Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])

            Aa, Xa = Aa[:, :, :, 0], Xa[:, :, :, 0]

        if Xa.size(2) == 4:
            Xa[:, :, 2].add_(Aa[:, :, 2] * Xa[:, :, 3])
            Aa[:, :, 2].mul_(Aa[:, :, 3])
            Xa[:, :, 0].add_(Aa[:, :, 0] * (Xa[:, :, 1] + Aa[:, :, 1] * Xa[:, :, 2]))
        elif Xa.size(2) == 2:
            Xa[:, :, 0].add_(Aa[:, :, 0] * Xa[:, :, 1])
            return
        else:
            return

        # Down-sweep, first level unrolled.
        stride = 2 ** (num_steps - 2)
        Aa = A[:, :, 0:L:stride]
        Xa = X[:, :, 0:L:stride]
        Xa[:, :, 1].add_(Aa[:, :, 1] * Xa[:, :, 2])
        Aa[:, :, 1].mul_(Aa[:, :, 2])

        for k in range(num_steps - 3, -1, -1):
            Aa = A[:, :, 0:L:2 ** k]
            Xa = X[:, :, 0:L:2 ** k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T // 2, 2, -1)
            Xa = Xa.view(B, D, T // 2, 2, -1)

            Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1] * Xa[:, :, 1:, 0])
            Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])

    @staticmethod
    def forward(ctx, A_in, X_in):
        """Compute H[t] = A_in[t] * H[t-1] + X_in[t] along dimension 1.

        Args:
            A_in: (B, L, D, N) coefficients.
            X_in: (B, L, D, N) inputs.

        Returns:
            (B, L, D, N) tensor of states H.
        """
        L = X_in.size(1)

        # The scan is in place, so always work on private copies
        # (padding already allocates a new tensor).
        if L == npo2(L):
            A, X = A_in.clone(), X_in.clone()
        else:
            A, X = pad_npo2(A_in), pad_npo2(X_in)

        # (B, L, D, N) -> (B, D, L, N); transpose returns a view.
        A, X = A.transpose(2, 1), X.transpose(2, 1)

        PScan.pscan(A, X)

        ctx.save_for_backward(A_in, X)

        # Back to (B, L, D, N) and cut off any padding.
        return X.transpose(2, 1)[:, :L]

    @staticmethod
    def backward(ctx, grad_output_in):
        """Back-propagate through the scan.

        Args:
            grad_output_in: (B, L, D, N) gradient w.r.t. the forward output.

        Returns:
            Tuple `(grad_A_in, grad_X_in)`, both (B, L, D, N).
        """
        A_in, X = ctx.saved_tensors

        L = grad_output_in.size(1)

        if L == npo2(L):
            grad_output = grad_output_in.clone()
        else:
            grad_output = pad_npo2(grad_output_in)
            A_in = pad_npo2(A_in)

        grad_output = grad_output.transpose(2, 1)
        A_in = A_in.transpose(2, 1)
        # Shift A one step to the left along time (zero at the end): the
        # gradient at step t receives the contribution of step t+1 through A[t+1].
        A = F.pad(A_in[:, :, 1:], (0, 0, 0, 1))

        # Reverse scan turns grad_output into dL/dX, in place.
        PScan.pscan_rev(A, grad_output)

        # dL/dA[t] = H[t-1] * dL/dX[t]; there is no H[-1], so index 0 stays zero.
        Q = torch.zeros_like(X)
        Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])

        return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]

# Functional entry point: call pscan(A, X) so autograd routes through PScan.forward/backward.
pscan = PScan.apply


# Architecture

In [19]:
@dataclass
class ModelConfig:
    """Hyper-parameters of the model.

    After construction the instance additionally carries `inner_dim`
    (`int(expand * emb_dim)`), a concrete integer `dt_rank` when 'auto' was
    requested (`ceil(emb_dim / 16)`), and a `vocab_size` rounded up to the
    next multiple of `pad_vocab_size_multiple`.
    """
    emb_dim: int
    n_layer: int
    vocab_size: int
    state_dim: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    kernel_size: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    linear_bias: bool = False

    def __post_init__(self):
        """Derive the dependent settings from the user-supplied fields."""
        # Width of the expanded inner representation.
        self.inner_dim = int(self.expand * self.emb_dim)

        # Resolve the symbolic rank to a concrete integer.
        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.emb_dim / 16)

        # Round the vocabulary size up to the requested multiple.
        remainder = self.vocab_size % self.pad_vocab_size_multiple
        if remainder != 0:
            self.vocab_size += self.pad_vocab_size_multiple - remainder

# Small configuration used by the shape checks that follow the module definitions below.
config = ModelConfig(emb_dim = 128, n_layer = 2, vocab_size=80)

In [20]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self,emb_dim, eps=1e-5):
        """
        Args:
            emb_dim: size of the last dimension of the inputs.
            eps: constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(emb_dim))

    def forward(self, x):
        """Scale `x` by the reciprocal RMS of its last dimension, then by `weight`."""
        mean_square = torch.mean(x**2, dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

# Shape check: normalising a (7, 128) input should give a (7, 128) output.
x = torch.rand([7,128])
tmp = RMSNorm(128)
tmp(x).shape

torch.Size([7, 128])

In [21]:
# %%timeit
class MambaBlock(nn.Module):
    """Selective state-space mixing block.

    The input is projected to a gate `z`, a convolution input `xBC` and a
    low-rank step size `dt`. `xBC` goes through a depthwise causal
    convolution and SiLU, is split into the SSM input and the B / C
    projections, and the SSM output is gated by SiLU(z) before the output
    projection.
    """

    def __init__(self, args):
        """
        Args:
            args: configuration object providing emb_dim, inner_dim,
                state_dim, dt_rank, kernel_size, conv_bias and linear_bias.
        """
        super().__init__()
        self.args = args

        # One projection yields z (inner_dim), xBC (inner_dim + 2*state_dim) and dt (dt_rank).
        self.in_proj = nn.Linear(args.emb_dim, args.inner_dim+ args.inner_dim + 2*args.state_dim + args.dt_rank, bias=args.linear_bias)

        # Depthwise convolution; padding of kernel_size - 1 plus the later
        # truncation to seq_len keeps it causal.
        conv_dim = args.inner_dim + 2 * args.state_dim
        self.conv = nn.Conv1d(conv_dim, conv_dim, args.kernel_size,
                              padding=args.kernel_size - 1, groups=conv_dim,
                              bias=args.conv_bias)

        self.dt_proj = nn.Linear(args.dt_rank, args.inner_dim)

        # A is stored as log(1..state_dim) per channel and used as -exp(A_log),
        # which keeps it strictly negative.
        A = repeat(torch.arange(1, args.state_dim + 1), 'n -> d n', d=args.inner_dim)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.inner_dim))

        self.out_proj = nn.Linear(args.inner_dim, args.emb_dim, bias=args.linear_bias)

    def forward(self, x: torch.Tensor, decode: bool = False) -> torch.Tensor:
        """Apply the block.

        Args:
            x: (batch, seq_len, emb_dim) input.
            decode: if True use the sequential scan, otherwise the parallel scan.

        Returns:
            (batch, seq_len, emb_dim) output.
        """
        _, seq_len, _ = x.shape

        z_xBC_dt = self.in_proj(x)
        z, xBC, dt = torch.split(z_xBC_dt,[self.args.inner_dim, self.args.inner_dim + 2 * self.args.state_dim, self.args.dt_rank], dim=-1)

        # Conv1d expects channels first; drop the trailing positions produced by the padding.
        xBC = rearrange(xBC, 'b l d -> b d l')
        xBC = self.conv(xBC)[:, :, :seq_len]
        xBC = rearrange(xBC, 'b d l -> b l d')
        xBC = F.silu(xBC)

        u, B, C = torch.split(xBC,[self.args.inner_dim, self.args.state_dim, self.args.state_dim], dim=-1)

        y = self.ssm(u, dt, B, C, decode)
        # Out-of-place gating: avoids mutating a tensor that is part of the autograd graph.
        y = y * F.silu(z)

        return self.out_proj(y)

    def ssm(self, x, delta, B, C, decode):
        """Run the selective SSM with either scan implementation.

        Args:
            x: (batch, seq_len, inner_dim) SSM input.
            delta: (batch, seq_len, dt_rank) low-rank step size.
            B, C: (batch, seq_len, state_dim) input / output projections.
            decode: selects the sequential (True) or parallel (False) scan.

        Returns:
            (batch, seq_len, inner_dim) output.
        """
        A = -torch.exp(self.A_log)
        D = self.D

        # Softplus keeps the step size positive.
        delta = F.softplus(self.dt_proj(delta))

        if decode:
            return self.RNN_selective_scan(x, delta, A, B, C, D)
        return self.CNN_selective_scan(x, delta, A, B, C, D)

    def RNN_selective_scan(self, u, delta, A, B, C, D):
        """Sequential scan: h[t] = exp(delta[t]*A) * h[t-1] + delta[t]*B[t]*u[t]."""
        seq_len = u.size(1)

        dA = torch.exp(einsum(delta, A, 'b l d, d n -> b l d n'))
        dB_u = einsum(delta, B, u, 'b l d, b l n, b l d-> b l d n')

        # Start from a zero state with the same dtype/device as the scan inputs.
        h = dA.new_zeros(dA.size(0), dA.size(2), dA.size(3))
        hs = []
        for t in range(seq_len):
            h = dA[:, t] * h + dB_u[:, t]
            hs.append(h)

        hs = torch.stack(hs, dim=1)
        y = (hs @ C.unsqueeze(-1)).squeeze(3)  # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED)
        return y + D * u

    def CNN_selective_scan(self, u, delta, A, B, C, D):
        """Parallel scan computing the same recurrence as `RNN_selective_scan`."""
        # exp() is required here as well: the recurrence multiplies the state
        # by exp(delta*A), not by delta*A itself.
        dA = torch.exp(einsum(delta, A, 'b l d,d n -> b l d n'))
        dB_u = einsum(delta, u, B, 'b l d, b l d, b l n -> b l d n')

        hs = pscan(dA, dB_u)
        y = (hs @ C.unsqueeze(-1)).squeeze(3)  # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED)
        return y + u * D

# Smoke test on a sequence length (129) that is not a power of two.
x = torch.rand([2,129,config.emb_dim])
tmp=MambaBlock(config)
x1=tmp(x)
# Disabled check comparing the parallel path against the sequential (decode=True) path.
# x2=tmp(x, decode=True)
# print(torch.isclose(x1,x2).sum()/(2*128*config.emb_dim))

In [22]:
class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper: `x + MambaBlock(RMSNorm(x))`."""

    def __init__(self, args):
        """
        Args:
            args: configuration passed to `MambaBlock`; `args.emb_dim` sizes the norm.
        """
        super().__init__()
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.emb_dim)

    def forward(self,x,decode=False):
        """Normalise, mix, then add the untouched input back."""
        normed = self.norm(x)
        mixed = self.mixer(normed, decode)
        return mixed + x

# Shape check: the residual block must preserve the input shape.
x = torch.rand([2,10,config.emb_dim])
tmp=ResidualBlock(config)
tmp(x).shape

torch.Size([2, 10, 128])

In [23]:
class Mamba(nn.Module):
    """Language model: embedding -> residual blocks -> final norm -> vocabulary head."""

    def __init__(self,args):
        """
        Args:
            args: configuration providing vocab_size, emb_dim and n_layer
                (and whatever `ResidualBlock` reads).
        """
        super().__init__()
        self.embedding = nn.Embedding(args.vocab_size, args.emb_dim)
        blocks = [ResidualBlock(args) for _ in range(args.n_layer)]
        self.layers = nn.ModuleList(blocks)
        self.norm = RMSNorm(args.emb_dim)

        self.head = nn.Linear(args.emb_dim, args.vocab_size)

    def forward(self, x, decode=False):
        """Map token ids to logits over the vocabulary.

        Args:
            x: integer token ids.
            decode: forwarded to every layer to pick the scan implementation.

        Returns:
            Logits with a trailing dimension of size `vocab_size`.
        """
        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden, decode)

        return self.head(self.norm(hidden))

# Sanity check: count NaNs in the logits of a freshly initialised model.
x = torch.randint(low=0, high=5,size=[2,128])
tmp=Mamba(config)
tmp(x).isnan().sum()

tensor(0)

# Training

In [24]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [25]:
# Training hyper-parameters.
num_epochs = 10
# This is the batch size the loaders are actually built with (previously a
# literal 16 was passed while this variable said 256 and was never used).
batch_size = 16
learning_rate = 5e-4
config = ModelConfig(emb_dim = 32,
                     n_layer = 2,
                     vocab_size=tokenizer.vocab_size,
                     state_dim= 24,
                     expand=2,
                     dt_rank = 'auto',
                     kernel_size = 5,
                     pad_vocab_size_multiple = 8,
                     conv_bias = True,
                     linear_bias = False
)
model = Mamba(config).to(torch.float).to(device)
train_loader, val_loader = create_data_loader(batch_size)

# Batches are padded to a common length; skip padding targets so they do not
# contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [26]:
import gc
from tqdm.auto import tqdm


# Release memory left over from earlier cells before training starts.
gc.collect()
torch.cuda.empty_cache()

for epoch in tqdm(range(num_epochs)):
    model.train()
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}: ')
    for batch in progress_bar:
        # Each batch is the (input_ids, labels) pair produced by the collator.
        input_ids, labels = (tensor.to(device) for tensor in batch)

        # Forward pass; flatten batch and time so every position is one sample.
        output = model(input_ids)
        loss = criterion(output.view(-1, output.shape[-1]), labels.reshape(-1))

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Token-level accuracy over all positions of the batch.
        pred = output.argmax(dim=-1)
        acc = (pred == labels).sum() / labels.numel()
        progress_bar.set_postfix(acc=acc.item(), loss=loss.item())

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/3172 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Evaluate on the validation set using the sequential (decode=True) scan.
model.eval()
acc_list = []
loss_list = []
# Fixed label: does not depend on the `epoch` variable leaked by the training cell.
progress_bar = tqdm(val_loader, desc='Validation: ')
for batch in progress_bar:
    # The collator yields (input_ids, labels) -- there is no attention mask.
    input_ids, labels = batch
    input_ids=input_ids.to(device)
    labels=labels.to(device)
    with torch.no_grad():
        output = model(input_ids, decode=True)
        loss = criterion(output.view(-1, output.size(-1)), labels.reshape(-1))

    pred = torch.argmax(output, dim=-1)
    acc = (pred==labels).sum()/(labels.shape[0]*labels.shape[1])
    acc_list.append(acc.item())
    loss_list.append(loss.item())
    progress_bar.set_postfix(acc=acc.item(), loss=loss.item())

# Unweighted mean over batches (all batches have the same size because of drop_last).
total_acc = sum(acc_list)/len(acc_list)
total_loss = sum(loss_list)/len(loss_list)
progress_bar.set_postfix(acc=total_acc, loss=total_loss)