# ENV

In [1]:
%%capture
!pip install git+https://github.com/google/flax.git
!pip install datasets
# !pip install -U "jax[cuda12]"

In [2]:
import jax
from jax import random
from jax import numpy as jnp
from jax.numpy.linalg import eigh, inv, matrix_power
from jax.scipy.signal import convolve

from flax import nnx
import optax

import torch
import math
from typing import Union
from dataclasses import dataclass

from einops import repeat, rearrange, einsum, reduce

In [3]:
print(jax.devices())

[CudaDevice(id=0)]


In [4]:
key = random.key(0)

# Data

In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('FacebookAI/roberta-base')
tokenizer.padding_side='left'

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



In [6]:
from datasets import load_dataset, DatasetDict

dataset = load_dataset('LTCB/enwik8', trust_remote_code=True, split='train[:5%]')
dataset = dataset.train_test_split(train_size=0.9)
dataset['validation']=dataset.pop('test')
dataset

enwik8.py:   0%|          | 0.00/2.94k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/4.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/36.4M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1128024 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 50760
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 5641
    })
})

In [7]:
def preprocessing_fn(batch):
    tokenized = tokenizer(batch['text'], max_length = 1024, truncation=True)
    tokenized.pop('attention_mask', None)
    return tokenized

tokenized_dataset = dataset.map(preprocessing_fn, batched=True, batch_size=1024, remove_columns = dataset['train'].column_names)
# tokenized_dataset.with_format("jax")
print(tokenized_dataset)
del dataset

Map:   0%|          | 0/50760 [00:00<?, ? examples/s]

Map:   0%|          | 0/5641 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 50760
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 5641
    })
})


In [8]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer, return_tensors='jax')

In [9]:
def collator(data):
    data = data_collator(data)
    input_ids = data['input_ids'][:, :-1]
    labels = data['input_ids'][:, 1:]

    return input_ids, labels

input_ids, labels = collator([tokenized_dataset["train"][i] for i in range(5)])
print(input_ids.shape)
print(type(labels))

(5, 204)
<class 'jaxlib.xla_extension.ArrayImpl'>


In [10]:
from torch.utils.data import DataLoader

def create_data_loader(batch_size):
    tokenized_dataset.set_format('torch')
    train_loader = DataLoader(tokenized_dataset['train'], batch_size=batch_size,collate_fn=collator, drop_last=True, shuffle=True,)
    val_loader = DataLoader(tokenized_dataset['validation'], batch_size=batch_size, collate_fn= collator, drop_last=True)

    return train_loader,val_loader

train_loader, val_loader = create_data_loader(32)
print(len(train_loader))
batch = next(iter(train_loader))
print(len(batch))
print(batch[0].shape)

1586
2
(32, 161)


# Scan

In [11]:
def npo2(length):
    return 2 ** math.ceil(math.log2(length))

def pad_npo2(X):
    len_npo2 = npo2(X.shape[1])
    pad_tuple = [(0, 0), (0, len_npo2 - X.shape[1]), (0, 0), (0, 0)]
    return jnp.pad(X, pad_tuple, mode="constant", constant_values=0)

class PScan:
    @staticmethod
    def pscan(A, X):
        B, D, L, _ = A.shape
        num_steps = int(math.log2(L))

        Aa, Xa = jnp.copy(A), jnp.copy(X)

        for _ in range(num_steps - 2):
            Aa, Xa = rearrange(Aa, 'b d (t2 t1) v -> b d t2 t1 v', t1=2), rearrange(Xa, 'b d (t2 t1) v -> b d t2 t1 v', t1=2)

            Xa = Xa.at[:, :, :, 1].set(Xa[:, :, :, 1] + Aa[:, :, :, 1] * Xa[:, :, :, 0])
            Aa = Aa.at[:, :, :, 1].set(Aa[:, :, :, 1] * Aa[:, :, :, 0])

            Aa, Xa = Aa[:, :, :, 1], Xa[:, :, :, 1]

        if Xa.shape[2] == 4:
            Xa = Xa.at[:, :, 1].set(Xa[:, :, 1] + Aa[:, :, 1] * Xa[:, :, 0])
            Aa = Aa.at[:, :, 1].set(Aa[:, :, 1] * Aa[:, :, 0])
            Xa = Xa.at[:, :, 3].set(Xa[:, :, 3] + Aa[:, :, 3] * (Xa[:, :, 2] + Aa[:, :, 2] * Xa[:, :, 1]))
        elif Xa.shape[2] == 2:
            Xa = Xa.at[:, :, 1].set(Xa[:, :, 1] + Aa[:, :, 1] * Xa[:, :, 0])
            return
        else:
            return

        Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
        Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
        Xa = Xa.at[:, :, 2].set(Xa[:, :, 2] + Aa[:, :, 2] * Xa[:, :, 1])
        Aa = Aa.at[:, :, 2].set(Aa[:, :, 2] * Aa[:, :, 1])

        for k in range(num_steps - 3, -1, -1):
            Aa = A[:, :, 2**k-1:L:2**k]
            Xa = X[:, :, 2**k-1:L:2**k]

            Aa, Xa = rearrange(Aa, 'b d (t2 t1) v -> b d t2 t1 v', t1=2), rearrange(Xa, 'b d (t2 t1) v -> b d t2 t1 v', t1=2)

            Xa = Xa.at[:, :, 1:, 0].set(Xa[:, :, 1:, 0] + Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1])
            Aa = Aa.at[:, :, 1:, 0].set(Aa[:, :, 1:, 0] * Aa[:, :, :-1, 1])

    @staticmethod
    def pscan_rev(A, X):
        B, D, L, _ = A.shape
        num_steps = int(math.log2(L))

        Aa, Xa = jnp.copy(A), jnp.copy(X)

        for _ in range(num_steps - 2):
            Aa, Xa = rearrange(Aa, 'b d (t2 t1) v -> b d t2 t1 v', t1=2), rearrange(Xa, 'b d (t2 t1) v -> b d t2 t1 v', t1=2)

            Xa = Xa.at[:, :, :, 0].set(Xa[:, :, :, 0] + Aa[:, :, :, 0] * Xa[:, :, :, 1])
            Aa = Aa.at[:, :, :, 0].set(Aa[:, :, :, 0] * Aa[:, :, :, 1])

            Aa, Xa = Aa[:, :, :, 0], Xa[:, :, :, 0]

        if Xa.shape[2] == 4:
            Xa = Xa.at[:, :, 2].set(Xa[:, :, 2] + Aa[:, :, 2] * Xa[:, :, 3])
            Aa = Aa.at[:, :, 2].set(Aa[:, :, 2] * Aa[:, :, 3])
            Xa = Xa.at[:, :, 0].set(Xa[:, :, 0] + Aa[:, :, 0] * (Xa[:, :, 1] + Aa[:, :, 1] * Xa[:, :, 2]))
        elif Xa.shape[2] == 2:
            Xa = Xa.at[:, :, 0].set(Xa[:, :, 0] + Aa[:, :, 0] * Xa[:, :, 1])
        else:
            return

        Aa = A[:, :, 0:L:2**(num_steps-2)]
        Xa = X[:, :, 0:L:2**(num_steps-2)]
        Xa = Xa.at[:, :, 1].set(Xa[:, :, 1] + Aa[:, :, 1] * Xa[:, :, 2])
        Aa = Aa.at[:, :, 1].set(Aa[:, :, 1] * Aa[:, :, 2])

        for k in range(num_steps - 3, -1, -1):
            Aa = A[:, :, 0:L:2**k]
            Xa = X[:, :, 0:L:2**k]

            Aa, Xa = rearrange(Aa, 'b d (t2 t1) v -> b d t2 t1 v', t1=2), rearrange(Xa, 'b d (t2 t1) v -> b d t2 t1 v', t1=2)

            Xa = Xa.at[:, :, :-1, 1].set(Xa[:, :, :-1, 1] + Aa[:, :, :-1, 1] * Xa[:, :, 1:, 0])
            Aa = Aa.at[:, :, :-1, 1].set(Aa[:, :, :-1, 1] * Aa[:, :, 1:, 0])

    @staticmethod
    @jax.custom_vjp
    def pscan_apply(A_in, X_in):
        L = X_in.shape[1]

        if L == npo2(L):
            A, X = jnp.copy(A_in), jnp.copy(X_in)
        else:
            A, X = pad_npo2(A_in), pad_npo2(X_in)

        A, X = rearrange(A, 'b d t ... -> b t d ...'), rearrange(X, 'b d t ... -> b t d ...')

        PScan.pscan(A, X)

        return rearrange(X, 'b t d ... -> b d t ...')[:, :L]

    @staticmethod
    def pscan_apply_vjp_fwd(A_in, X_in):
        # Forward pass
        L = X_in.shape[1]

        if L == npo2(L):
            A, X = jnp.copy(A_in), jnp.copy(X_in)
        else:
            A, X = pad_npo2(A_in), pad_npo2(X_in)

        A, X = rearrange(A, 'b d t ... -> b t d ...'), rearrange(X, 'b d t ... -> b t d ...')

        PScan.pscan(A, X)

        return rearrange(X, 'b t d ... -> b d t ...')[:, :L], (A_in, X)

    @staticmethod
    def pscan_apply_vjp_bwd(residual, g):
        A_in, X = residual

        L = g.shape[1]

        if L == npo2(L):
            grad_output = jnp.copy(g)
        else:
            grad_output = pad_npo2(g)
            A_in = pad_npo2(A_in)

        grad_output = rearrange(grad_output, 'b d t ... -> b t d ...')
        A_in = rearrange(A_in, 'b d t ... -> b t d ...')
        A = jnp.pad(A_in[:, :, 1:], [(0, 0), (0, 0), (0, 1), (0, 0)])

        PScan.pscan_rev(A, grad_output)

        Q = jnp.zeros_like(X)
        Q = Q.at[:, :, 1:].add(X[:, :, :-1] * grad_output[:, :, 1:])

        return (None, rearrange(Q, 'b t d ... -> b d t ...')[:, :L])

# Register forward and backward
PScan.pscan_apply.defvjp(PScan.pscan_apply_vjp_fwd, PScan.pscan_apply_vjp_bwd)

# Architecture

In [12]:
@dataclass
class ModelConfig:
    emb_dim: int
    n_layer: int
    vocab_size: int
    state_dim: int = 16
    expand: int =2
    delta_rank: Union[int,str] = 'auto'
    kernel_size: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    linear_bias: bool = False

    def __post_init__(self):
        self.inner_dim = int(self.expand*self.emb_dim)

        if self.delta_rank == 'auto':
            self.delta_rank = math.ceil(self.emb_dim/16)

        if (self.vocab_size % self.pad_vocab_size_multiple) !=0:
            self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size%self.pad_vocab_size_multiple)

config = ModelConfig(emb_dim = 128, n_layer = 2, vocab_size=80)

In [13]:
class RMSNorm(nnx.Module):
    def __init__(self, emb_dim, eps=1e-5, rngs=nnx.Rngs(0)):
        super().__init__()
        self.eps = eps
        self.weight = nnx.Param(nnx.initializers.ones(rngs.params(), [emb_dim,]))

    def __call__(self, x):
        return x*jax.lax.rsqrt(jnp.mean(x**2, axis=-1, keepdims=True) + self.eps) * self.weight

x = random.uniform(key,[1,128])
tmp = RMSNorm(128)
tmp(x).shape

(1, 128)

In [14]:
class MambaBlock(nnx.Module):
    def __init__(self, args, rngs=nnx.Rngs(0)):
        self.args = args

        linear_dim = 2*args.inner_dim + 2*args.state_dim + args.delta_rank
        self.in_proj = nnx.Linear(args.emb_dim, linear_dim, use_bias=args.linear_bias, rngs=rngs)

        conv_dim = args.inner_dim + 2*args.state_dim
        self.conv = nnx.Conv(conv_dim, conv_dim, kernel_size = args.kernel_size,
                             padding='CAUSAL', feature_group_count=conv_dim, use_bias=args.conv_bias, rngs=rngs)

        self.delta_proj = nnx.Linear(args.delta_rank, args.inner_dim, rngs=rngs)

        A = repeat(jnp.arange(1, args.state_dim+1), 'n -> d n', d=args.inner_dim)
        self.A_log = nnx.Param(jnp.log(A))
        self.D = nnx.Param(nnx.initializers.ones(rngs.params(), [args.inner_dim]))

        self.out_proj = nnx.Linear(args.inner_dim, args.emb_dim, use_bias=args.linear_bias, rngs=rngs)

    def __call__(self, x, decode=0):
        batch_size, seq_len, emb_dim = x.shape

        z_xBC_delta = self.in_proj(x)
        z, xBC, delta = jnp.split(z_xBC_delta, [self.args.inner_dim,
                                                self.args.inner_dim + self.args.inner_dim + 2*self.args.state_dim], axis=-1)

        xBC = self.conv(xBC)
        xBC = xBC/(1 + jnp.exp(-xBC))
        x, B, C = jnp.split(xBC, [self.args.inner_dim,
                                  self.args.inner_dim + self.args.state_dim], axis=-1)

        y = self.ssm(x, delta, B, C, decode)
        y *= nnx.silu(z)
        out = self.out_proj(y)

        return out

    def ssm(self, x, delta, B, C, decode=0):
        emb_dim, state_dim = self.A_log.shape
        A = -jnp.exp(self.A_log.value)
        D = self.D

        delta = nnx.softplus(self.delta_proj(delta))
        if decode:
            y = self.RNN_selective_scan(x, delta, A, B, C, D)
        else:
            y = self.CNN_selective_scan(x, delta, A, B, C, D)
        return y

    def RNN_selective_scan(self, u, delta, A, B, C, D):
        batch_size, seq_len, emb_dim = u.shape
        state_dim = A.shape[1]

        dA = jnp.exp(einsum(delta, A, 'b l d, d n -> b l d n'))
        dB_u = einsum(delta, B, u, 'b l d, b l n, b l d-> b l d n')

        h = jnp.zeros((u.shape[0], self.args.inner_dim, self.args.state_dim))
        hs = []
        for t in range(0, seq_len):
            h = dA[:, t] * h + dB_u[:, t]
            hs.append(h)
        # jax.lax.scan
        hs = jnp.stack(hs, axis=1)
        y = (hs @ jnp.expand_dims(C, axis=-1)).squeeze(axis=3)
        y = y + D * u
        return y

    def CNN_selective_scan(self, u, delta, A, B,C,D):
        dA = einsum(delta, A, 'b l d,d n -> b l d n')
        dB_u = einsum(delta, u, B, 'b l d, b l d, b l n -> b l d n')

        # dA_cumsum = jnp.flip(jnp.cumsum(jnp.flip(jnp.pad(dA[:, 1:], ((0, 0), (0, 1), (0, 0), (0, 0))), axis=1), axis=1), axis=1)

        # x = dB_u * dA_cumsum
        # x = x.cumsum(1) / (dA_cumsum + 1e-12)

        # y = einsum(x, C,'b l d n,b l n -> b l d')

        hs = PScan.pscan_apply(dA, dB_u)
        y = (hs @ jnp.expand_dims(C, axis=-1)).squeeze(axis=3)

        return y


x = random.uniform(key,[1,128,config.emb_dim])
tmp=MambaBlock(config)
x1=tmp(x)
x2=tmp(x, decode=True)
print(jnp.isclose(x1,x2).sum()/(2*128*config.emb_dim))

0.0


In [15]:
class ResidualBlock(nnx.Module):
    def __init__(self, args, rngs=nnx.Rngs(0)):
        super().__init__()
        self.mixer = MambaBlock(args, rngs)
        self.norm = RMSNorm(args.emb_dim, rngs=rngs)

    def __call__(self,x,decode=False):
        return self.mixer(self.norm(x), decode) + x

x = random.uniform(key,[1,12,config.emb_dim])
tmp=ResidualBlock(config)
tmp(x).shape

(1, 12, 128)

In [16]:
class Mamba(nnx.Module):
    def __init__(self,args, rngs=nnx.Rngs(0)):
        super().__init__()
        self.embedding = nnx.Embed(args.vocab_size, args.emb_dim, rngs=rngs)
        self.layers = ([ResidualBlock(args, rngs) for _ in range(args.n_layer)])
        self.norm = RMSNorm(args.emb_dim, rngs=rngs)

        self.head = nnx.Linear(args.emb_dim, args.vocab_size, rngs=rngs)

    def __call__(self, x, decode=False):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, decode)

        x = self.norm(x)
        logits = self.head(x)

        return logits

x = random.randint(key, [2,128], 0, 5)
tmp=Mamba(config)
jnp.isnan(tmp(x)).sum()

Array(0, dtype=int32)

# Training

In [17]:
config = ModelConfig(emb_dim = 32,
                     n_layer = 2,
                     vocab_size=tokenizer.vocab_size,
                     state_dim= 24,
                     expand=2,
                     delta_rank = 'auto',
                     kernel_size = 5,
                     pad_vocab_size_multiple = 8,
                     conv_bias = True,
                     linear_bias = False
)
model = Mamba(config)
# .to(device)



num_epochs = 10
batch_size = 256
learning_rate = 5e-4
momentum = 0.9
train_loader, val_loader = create_data_loader(16)

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(accuracy=nnx.metrics.Accuracy(),
                          loss=nnx.metrics.Average('loss'))

In [18]:
from tqdm.auto import tqdm


def loss_fn(model, batch):
    logits = model(batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
    return loss, logits

@nnx.jit
def train_step(model, optimizer, metrics, batch):
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  optimizer.update(grads)
  accuracy  = jnp.mean((jnp.argmax(logits, axis=-1) == batch[1]))
  return loss,accuracy

for epoch in tqdm(range(num_epochs)):
    p_bar = tqdm(train_loader, desc= f'Epoch {epoch+1}/{num_epochs}: ')
    for batch in p_bar:
        loss, accuracy = train_step(model, optimizer, metrics, batch)
        p_bar.set_postfix(loss = loss, accuracy=accuracy)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/3172 [00:00<?, ?it/s]

KeyboardInterrupt: 