In [None]:
import os
import math
import matplotlib.pyplot as plt

import jax
import jax.nn as nn
import jax.numpy as jnp
import jax.lax as lax
import jax.tree as jt
import jax.random as random

from jax import jit, value_and_grad, vmap, Array
from einops import repeat, einsum
from operator import getitem
from functools import partial
from typing import Union, NamedTuple
from dataclasses import dataclass
from time import time

In [None]:
# informational only: report the NVIDIA GPU(s) and driver visible to this machine
!nvidia-smi

In [None]:
@dataclass
class ModelArgs:
    """Hyperparameters of the Mamba language model.

    ``__post_init__`` derives ``d_inner = expand * d_model``, resolves
    ``dt_rank='auto'`` to ``ceil(d_model / 16)``, records the unpadded vocabulary
    size in ``orig_vocab_size`` and pads ``vocab_size`` up to a multiple of
    ``pad_vocab_size_multiple``.

    Raises:
        ValueError: if ``dt_rank`` is a string other than ``'auto'``.
    """
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    dt_min: float = 0.001
    dt_max: float = 0.1
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False
    # dt initialization scale and floor; annotated so they are real constructor fields,
    # and declared last so the positional order of the fields above is unchanged
    dt_scale: float = 1.0
    dt_init_floor: float = 1e-4

    def __post_init__(self):
        self.d_inner = self.expand * self.d_model

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)
        elif isinstance(self.dt_rank, str):
            raise ValueError(f"dt_rank must be an int or 'auto', got {self.dt_rank!r}")

        self.orig_vocab_size = self.vocab_size

        # round vocab_size up to the next multiple of pad_vocab_size_multiple
        remainder = self.vocab_size % self.pad_vocab_size_multiple
        if remainder != 0:
            self.vocab_size += self.pad_vocab_size_multiple - remainder


class LayerParams(NamedTuple):
    """Weights of one Mamba block.

    ``initialize_params`` stacks every field along a leading ``n_layer`` axis so the
    forward passes can ``lax.scan`` over layers; the shapes noted below are the
    per-layer shapes that ``initialize_params`` allocates.
    """
    norm: Array  # (d_model,) RMSNorm weight applied before the block
    in_proj: Array  # (d_model, 2 * d_inner): SSM input and gate, split in half
    in_proj_bias: Union[None, Array]  # None unless args.bias; added to the (2 * d_inner) in_proj output
    conv: Array  # (d_inner, d_conv) depthwise causal conv kernel
    conv_bias: Union[None, Array]  # (d_inner,); None unless args.conv_bias
    x_proj: Array  # (d_inner, dt_rank + 2 * d_state): produces dt, B, C
    dt_proj: Array  # (dt_rank, d_inner)
    dt_proj_bias: Array  # (d_inner,) inverse-softplus of the initial dt
    A_log: Array  # (d_inner, d_state); the forward passes use A = -exp(A_log)
    D: Array  # (d_inner,) skip-connection weight
    out_proj: Array  # (d_inner, d_model)
    out_proj_bias: Union[None, Array]  # (d_model,); None unless args.bias


class MambaParams(NamedTuple):
    """All model weights: token embedding, stacked per-layer weights, final norm."""
    embedding: Array  # (vocab_size, d_model); its transpose is also the output head
    layers: LayerParams  # every leaf carries a leading n_layer axis
    norm_f: Array  # (d_model,) final RMSNorm weight


def initialize_params(key, args):
    """Randomly initialize all Mamba weights.

    Per-layer weights are stacked along a leading ``n_layer`` axis so the forward
    passes can ``lax.scan`` over layers.

    Args:
        key: JAX PRNG key.
        args: ModelArgs describing the architecture.

    Returns:
        MambaParams with ``embedding`` (vocab_size, d_model), stacked
        ``LayerParams`` and ``norm_f`` (d_model,). Optional biases are None when
        disabled in ``args``.
    """
    # std of a standard normal truncated to [-2, 2]; dividing by it gives the
    # truncated draws unit variance before the 1/sqrt(fan_in) scaling
    truncated_normal_stddev = .87962566103423978

    d_model_scale = 1 / (math.sqrt(args.d_model) * truncated_normal_stddev)
    d_inner_scale = 1 / (math.sqrt(args.d_inner) * truncated_normal_stddev)
    dt_init_std = args.dt_rank ** -0.5 * args.dt_scale

    embed_key, dt_key, layers_key = random.split(key, 3)
    # only 5 of these keys are used; splitting into 7 keeps seeded runs reproducing
    # the same draws as before
    layers_keys = random.split(layers_key, 7)

    def trunc_normal(k, shape, scale):
        # truncated normal on [-2, 2], rescaled by the given fan-in factor
        return random.truncated_normal(k, -2, 2, shape) * scale

    embedding = trunc_normal(embed_key, (args.vocab_size, args.d_model), d_model_scale)

    # per-channel dt drawn log-uniformly in [dt_min, dt_max], floored at dt_init_floor
    dt = jnp.exp(
        random.uniform(dt_key, (args.n_layer, args.d_inner))
        * (math.log(args.dt_max) - math.log(args.dt_min)) + math.log(args.dt_min)
    ).clip(min=args.dt_init_floor)

    layers = LayerParams(
        norm=jnp.ones((args.n_layer, args.d_model)),

        in_proj=trunc_normal(
            layers_keys[0], (args.n_layer, args.d_model, args.d_inner * 2), d_model_scale
        ),
        # must match in_proj's output width (2 * d_inner), which is later split in half
        in_proj_bias=jnp.zeros((args.n_layer, args.d_inner * 2)) if args.bias else None,

        conv=trunc_normal(
            layers_keys[1], (args.n_layer, args.d_inner, args.d_conv), d_inner_scale
        ),
        conv_bias=jnp.zeros((args.n_layer, args.d_inner)) if args.conv_bias else None,

        x_proj=trunc_normal(
            layers_keys[2],
            (args.n_layer, args.d_inner, args.dt_rank + 2 * args.d_state),
            d_inner_scale,
        ),

        dt_proj=random.uniform(
            layers_keys[3],
            (args.n_layer, args.dt_rank, args.d_inner),
            minval=-dt_init_std,
            maxval=dt_init_std
        ),
        # inverse of softplus, so softplus(dt_proj_bias) == dt at initialization
        dt_proj_bias=dt + jnp.log(-jnp.expm1(-dt)),

        # A = -exp(A_log) = -[1, 2, ..., d_state], identical for every channel and layer
        A_log=repeat(
            jnp.log(jnp.arange(1, args.d_state + 1)),
            'ds -> nl di ds',
            nl=args.n_layer,
            di=args.d_inner
        ),
        D=jnp.ones((args.n_layer, args.d_inner)),

        out_proj=trunc_normal(
            layers_keys[4], (args.n_layer, args.d_inner, args.d_model), d_inner_scale
        ),
        out_proj_bias=jnp.zeros((args.n_layer, args.d_model)) if args.bias else None,
    )

    norm_f = jnp.ones(args.d_model)

    return MambaParams(embedding=embedding, layers=layers, norm_f=norm_f)


def zero_or(x):
    """Return ``x`` unchanged, or the scalar ``0`` when ``x`` is None.

    Lets optional biases be added unconditionally: ``y + zero_or(bias)``.
    """
    if x is None:
        return 0
    return x


def rms_norm(w, x, eps):
    """RMS-normalize ``x`` over its last axis, then scale by ``w``.

    The statistics are computed in float32 and the normalized values are cast
    back to ``x``'s dtype before the elementwise multiply by ``w``.
    """
    x32 = x.astype(jnp.float32)
    mean_square = jnp.mean(x32 * x32, -1, keepdims=True)
    normalized = (x32 * lax.rsqrt(mean_square + eps)).astype(x.dtype)
    return w * normalized


# training
def mamba(args, use_associative_scan, params, tokens):
    """Full-sequence Mamba forward pass (used for training).

    Args:
        args: ModelArgs; ``d_conv``, ``d_inner``, ``dt_rank`` and ``d_state`` are read.
        use_associative_scan: if True, the SSM recurrence is evaluated with the
            parallel ``lax.associative_scan``; otherwise with a sequential
            ``lax.scan``. The two paths are equivalent up to floating-point rounding.
        params: MambaParams; every leaf of ``params.layers`` has a leading layer
            axis that the outer ``lax.scan`` iterates over.
        tokens: token ids of a single sequence, shape (l,). The notebook batches
            this function with ``vmap`` over a leading batch axis.

    Returns:
        Logits of shape (l, vocab_size) over the padded vocabulary. The output head
        reuses the transposed embedding matrix (tied weights).
    """

    def block(x, params):
        # one Mamba mixer for a single layer: x is the RMS-normalized (l, d_model)
        # residual stream and params is that layer's slice of LayerParams
        # (l, d_model) -> (l, d * 2) -> (l, d), (l, d)
        x, res = jnp.split(x @ params.in_proj + zero_or(params.in_proj_bias), 2, -1)
        # left-pad time with d_conv - 1 zeros so the convolution below is causal
        # (l, d) -> (l + c - 1, d) -> (d, l + c - 1)
        x = jnp.concatenate([jnp.zeros((args.d_conv - 1, args.d_inner)), x], 0).T
        # depthwise conv: vmap convolves each channel row with its own kernel row
        # (d, l + c - 1) -> (d, l) -> (l, d)
        x = vmap(jnp.convolve, (0, 0, None))(x, params.conv, 'valid').T + zero_or(params.conv_bias)
        x = nn.silu(x)
        # input-dependent dt, B and C
        # (l, d) -> (l, r + s + s) -> (l, r), (l, s), (l, s)
        x_dt, B, C = jnp.split(x @ params.x_proj, [args.dt_rank, args.dt_rank + args.d_state], -1)
        # softplus keeps the step size positive
        # (l, r) -> (l, d)
        dt = nn.softplus(x_dt @ params.dt_proj + zero_or(params.dt_proj_bias))
        # discretized A and B, each (l, d, s):
        # dA = exp(dt * A) with A = -exp(A_log); dBx = (x * dt) outer B
        dA = jnp.exp(einsum(dt, -jnp.exp(params.A_log), 'l d, d s -> l d s'))
        dBx = einsum(x * dt, B, 'l d, l s -> l d s')
        # see section 1.4.1 "First-Order Recurrences" in the paper "Prefix Sums and Their Applications"
        # the main loop is equivalent to
        #
        # ssm_states = []
        # s = jnp.zeros((args.d_inner, args.d_state))
        # for c in zip(dA, dBx):
        #     s = c[0] * s + c[1]
        #     ssm_states.append(s)
        # ssm_states = jnp.stack(ssm_states)
        #
        # we use the associative operator `op` below to parallelize this: a pair (a, b)
        # stands for the affine map s -> a * s + b, and `op` composes an earlier map `s`
        # with a later map `c`; the second component of each prefix is the state
        if use_associative_scan:
            op = lambda s, c: (c[0] * s[0], c[0] * s[1] + c[1])
            _, ssm_states = lax.associative_scan(op, (dA, dBx))
        # or we can implement the same loop using lax.scan, starting from a zero state
        else:
            def op(s, c):
                s = c[0] * s + c[1]
                return s, s

            _, ssm_states = lax.scan(op, jnp.zeros((args.d_inner, args.d_state)), (dA, dBx))
        # read out with C, add the D skip term, gate with silu(res), then output projection
        y = einsum(ssm_states, C, 'l d s, l s -> l d') + x * params.D
        y = y * nn.silu(res)
        # (l, d) -> (l, d_model)
        return y @ params.out_proj + zero_or(params.out_proj_bias)

    def f(x, params):
        # pre-norm residual layer; the scan carries x and emits no per-layer output
        return x + block(rms_norm(params.norm, x, 1e-8), params), None

    # embed (l,) -> (l, d_model), then apply the stacked layers in order
    h, _ = lax.scan(f, params.embedding[tokens], params.layers)

    # final RMSNorm, then project onto the tied embedding matrix
    logits = rms_norm(params.norm_f, h, 1e-8) @ params.embedding.T

    return logits


# inference
def mamba_step(args, valid_logits, params, cache, token):
    """Advance the model by one token using cached recurrent state (inference).

    Performs the same per-token computation as ``mamba``, but instead of the whole
    sequence it keeps a rolling window of the last ``d_conv - 1`` conv inputs and
    the SSM state for every layer.

    Args:
        args: ModelArgs.
        valid_logits: if True, logits are truncated to ``args.orig_vocab_size``,
            dropping the entries added by vocabulary padding; otherwise all
            ``args.vocab_size`` logits are returned.
        params: MambaParams with per-layer leaves stacked along a leading layer axis.
        cache: ``(conv_cache, ssm_state)``, each stacked per layer; ``run`` starts
            from zeros of shape (n_layer, d_inner, d_conv - 1) and
            (n_layer, d_inner, d_state).
        token: a single (scalar) token id.

    Returns:
        ``(logits, cache)``: next-token logits and the cache advanced by one token.
    """

    def block(x, params, conv_cache, ssm_state):
        # (d_model,) -> (2 * d_inner,) -> x, res each (d_inner,)
        x, res = jnp.split(x @ params.in_proj + zero_or(params.in_proj_bias), 2, -1)
        # convolve input with kernel: append the new input to the cached window
        conv_input = jnp.concatenate([conv_cache, x[:, None]], -1)  # (d_inner, d_conv)
        # flipped so this dot product matches the true convolution (jnp.convolve) in `mamba`
        kernel = jnp.flip(params.conv, -1)  # (d_inner, d_conv)
        x = nn.silu(jnp.vecdot(conv_input, kernel) + zero_or(params.conv_bias))
        # per token discretization, read-in, and read-out vectors
        x_dt, B, C = jnp.split(x @ params.x_proj, [args.dt_rank, args.dt_rank + args.d_state], -1)
        # dt should always be positive
        dt = nn.softplus(x_dt @ params.dt_proj + zero_or(params.dt_proj_bias))
        # (s,) -> (1, s), (d,) -> (d, 1)
        B, dt = B[None], dt[:, None]
        # one recurrence step, (d, s): state = exp(dt * A) * state + dt * x * B, A = -exp(A_log)
        ssm_state = jnp.exp(-jnp.exp(params.A_log) * dt) * ssm_state + B * x[:, None] * dt
        # (d, s) @ (s,).T + (d,) * (d,) -> (d,)
        y = ssm_state @ C.T + x * params.D
        # gating, output projection, then return with cache
        y = y * nn.silu(res)
        y = y @ params.out_proj + zero_or(params.out_proj_bias)
        # drop the oldest column so the window again holds the last d_conv - 1 inputs
        return y, (conv_input[:, 1:], ssm_state)

    def f(x, params_and_cache):
        # pre-norm residual layer; the scan also emits this layer's updated cache
        params, cache = params_and_cache
        h, cache = block(rms_norm(params.norm, x, 1e-8), params, *cache)
        return x + h, cache

    # scan over layers, slicing each layer's weights and cache together
    h, cache = lax.scan(f, params.embedding[token], (params.layers, cache))

    # final RMSNorm, then project onto the tied embedding matrix
    logits = rms_norm(params.norm_f, h, 1e-8) @ params.embedding.T

    return logits[:args.orig_vocab_size if valid_logits else args.vocab_size], cache


def adam(lr, b1, b2, eps, step, params, grads, state, eps_root=0.0):
    """One bias-corrected Adam update (Kingma & Ba, 2015) over arbitrary pytrees.

    Update rule: ``p <- p - lr * m_hat / (sqrt(v_hat + eps_root) + eps)``.
    With the default ``eps_root=0`` this is standard Adam. A nonzero ``eps_root``
    additionally regularizes inside the square root (optax's ``eps_root``).

    Args:
        lr: learning rate.
        b1, b2: exponential decay rates of the first and second moments.
        eps: stabilizer added to ``sqrt(v_hat)`` in the denominator.
        step: 1-based step count used for bias correction (must be >= 1).
        params: parameter pytree.
        grads: gradient pytree with the same structure as ``params``.
        state: ``(m, v)`` moment pytrees with the same structure as ``params``.
        eps_root: optional stabilizer added to ``v_hat`` before the square root.

    Returns:
        ``(new_params, (m, v))``.
    """
    m, v = state
    m = jt.map(lambda m, g: b1 * m + (1 - b1) * g, m, grads)
    v = jt.map(lambda v, g: b2 * v + (1 - b2) * g ** 2, v, grads)
    # bias correction: the moments start at zero, so early estimates are scaled up
    m_ = jt.map(lambda m: m / (1 - b1 ** step), m)
    v_ = jt.map(lambda v: v / (1 - b2 ** step), v)
    # eps goes outside the square root as in the paper; putting it inside turns
    # eps=1e-8 into an effective floor of 1e-4 and damps small-gradient updates
    params = jt.map(lambda p, m, v: p - lr * m / ((v + eps_root) ** .5 + eps), params, m_, v_)
    return params, (m, v)

In [None]:
# uncomment to delete previously downloaded .md files before re-downloading;
# note that the loading cell below reads EVERY *.md file in the working directory
# !rm *.md
!wget "https://raw.githubusercontent.com/textvs/Austen-Works/master/1813%20%7C%20Pride%20and%20Prejudice%20(PG%201342)/Jane%20Austen%2C%20Pride%20and%20Prejudice%20(1813).md" \
      "https://raw.githubusercontent.com/textvs/Austen-Works/master/1811%20%7C%20Sense%20and%20Sensibility%20(PG%2021839)/Jane%20Austen%2C%20Sense%20and%20Sensibility%20(1811).md" \
      "https://raw.githubusercontent.com/textvs/Austen-Works/master/1818%20%7C%20Persuasion%20(PG%20105)/Jane%20Austen%2C%20Persuasion%20(1818).md"

In [None]:
def _read_text(path):
    """Read a UTF-8 text file and close it before returning its contents."""
    with open(path, 'r', encoding='utf-8') as fh:
        return fh.read()


# sorted so the concatenation order (and therefore the training corpus) is reproducible;
# os.listdir() returns entries in arbitrary order
text = '\n[end]\n'.join(_read_text(f) for f in sorted(os.listdir()) if f.endswith('.md'))

# character-level vocabulary; index 0 is reserved for padding
vocab = ['<|pad|>'] + sorted(set(text))
print(f'vocabulary:\n{vocab}')

itoc = dict(enumerate(vocab))
ctoi = {c: i for i, c in enumerate(vocab)}


def encode(string):
    """Map a string to a list of vocabulary indices (KeyError on unknown characters)."""
    return [ctoi[c] for c in string]


def decode(tokens):
    """Map a sequence of vocabulary indices (Python ints) back to a string."""
    return ''.join(itoc[i] for i in tokens)


tokens = jnp.array(encode(text))
print(f'# tokens:\n{len(tokens)}')

In [None]:
def cross_entropy(logits, targets):
    """Mean softmax cross-entropy (in nats) over every position.

    ``logits`` has classes on its last axis; ``targets`` holds integer class ids
    with the remaining (leading) shape. Both are flattened before averaging.
    """
    n_classes = logits.shape[-1]
    flat_logits = logits.reshape(-1, n_classes)
    flat_targets = targets.reshape(-1)
    # per-row gather of the logit assigned to the target class
    target_logits = vmap(getitem)(flat_logits, flat_targets)
    log_normalizers = nn.logsumexp(flat_logits, -1)
    return (log_normalizers - target_logits).mean()


def get_sampler(tokens, batch_size, seq_len):
    """Build a function that samples random contiguous training windows.

    Args:
        tokens: 1-D token array (the whole corpus).
        batch_size: number of windows per batch.
        seq_len: length of each input/target sequence.

    Returns:
        ``sampler(key) -> (inputs, targets)``, both ``(batch_size, seq_len)``,
        where ``targets`` is ``inputs`` shifted left by one token.

    Raises:
        ValueError: if the corpus is shorter than one window of ``seq_len + 1`` tokens.
    """
    n_tokens = len(tokens)
    if n_tokens < seq_len + 1:
        raise ValueError(
            f'corpus of {n_tokens} tokens is too short for seq_len={seq_len}'
        )
    # a window spans seq_len + 1 tokens, so valid starts are 0 .. n_tokens - seq_len - 1
    # inclusive; randint's maxval is exclusive, hence n_tokens - seq_len
    max_start = n_tokens - seq_len
    offsets = jnp.arange(seq_len + 1)

    def sampler(key):
        start = random.randint(key, (batch_size, 1), 0, max_start)
        batch = tokens[offsets + start]
        return batch[:, :-1], batch[:, 1:]

    return sampler


def get_train_step(model, optimizer, sampler):
    """Assemble a jitted training step.

    The returned ``train_step(key, step, params, state)`` samples a batch with
    ``sampler(key)``, differentiates the cross-entropy of ``model(params, inputs)``
    against the targets, applies ``optimizer(step, params, grads, state)`` and
    returns ``(loss, params, state)``.
    """

    def objective(params, inputs, targets):
        logits = model(params, inputs)
        return cross_entropy(logits, targets)

    loss_and_grads = value_and_grad(objective)

    def train_step(key, step, params, state):
        inputs, targets = sampler(key)
        loss, grads = loss_and_grads(params, inputs, targets)
        new_params, new_state = optimizer(step, params, grads, state)
        return loss, new_params, new_state

    return jit(train_step)

In [None]:
# model: True selects the parallel lax.associative_scan path in `mamba`,
# False the sequential lax.scan path (same result)
use_associative_scan = True
args = ModelArgs(d_model=384, n_layer=6, vocab_size=len(vocab))
# batch the single-sequence forward pass: params shared, tokens mapped over axis 0
model = vmap(partial(mamba, args, use_associative_scan), in_axes=(None, 0))

# Adam hyperparameters
lr, b1, b2, eps = 1e-3, 0.9, 0.999, 1e-8
optimizer = partial(adam, lr, b1, b2, eps)

# random contiguous windows of the character corpus
batch_size, seq_len = 8, 512
sampler = get_sampler(tokens, batch_size, seq_len)

train_step = get_train_step(model, optimizer, sampler)

In [None]:
# seed the PRNG; `key` is carried into the training loop, `subkey` initializes weights
key, subkey = random.split(random.key(42))
params = initialize_params(subkey, args)
# Adam state: first and second moments, zero-initialized with the same pytree structure as params
state = (jt.map(jnp.zeros_like, params), jt.map(jnp.zeros_like, params))
# total number of trainable scalars (displayed as the cell output)
sum(leaf.size for leaf in jt.leaves(params))

In [None]:
# number of optimizer steps in the training loop
steps = 10_000

In [None]:
# NOTE: re-running this cell continues from the current params/state but restarts the
# step counter at 1 (resetting Adam's bias correction); re-run the init cell first
log_every = 100
losses = []

for step in range(1, steps + 1):
    key, subkey = random.split(key)
    loss, params, state = train_step(subkey, step, params, state)
    # kept as device arrays: converting every step would block JAX's async dispatch
    losses.append(loss)
    if step == 1 or step % log_every == 0:
        print(f'step {step:5d} loss {float(loss):11.7f}')

In [None]:
# training curve; the x-axis is the 0-based index into `losses` (index 0 = step 1)
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title='Training loss', xlabel='step', ylabel='cross-entropy (nats / char)')
plt.show()

In [None]:
def run(key, args, params, prompt, steps, temperature=1):
    """Print ``prompt`` followed by ``steps + 1`` characters sampled from the model.

    The prompt is fed through ``mamba_step`` one token at a time to build the
    recurrent cache. Characters are then sampled autoregressively and printed as
    they are generated.

    Args:
        key: JAX PRNG key. It is split before every draw, so no key is used twice.
        args: ModelArgs the params were built with.
        params: trained MambaParams.
        prompt: non-empty string; every character must be in the vocabulary.
        steps: number of characters sampled after the first one.
        temperature: softmax temperature applied to the logits (must be > 0).

    Raises:
        ValueError: if ``prompt`` is empty (there would be no logits to sample from).
    """
    if not prompt:
        raise ValueError('prompt must contain at least one character')

    # params is passed as a traced argument rather than captured in the partial,
    # so the weights are not embedded as constants in the compiled program
    step_fn = jit(partial(mamba_step, args, True))

    def sample(subkey, logits):
        return random.categorical(subkey, logits / temperature)

    cache = (
        jnp.zeros((args.n_layer, args.d_inner, args.d_conv - 1)),  # conv_cache
        jnp.zeros((args.n_layer, args.d_inner, args.d_state)),  # ssm_state
    )

    # step through the prompt tokens once to get the next-token logits and current state
    for token in jnp.array(encode(prompt)):
        logits, cache = step_fn(params, cache, token)

    print(prompt, end='', flush=True)

    key, subkey = random.split(key)
    token = sample(subkey, logits)
    print(itoc[int(token)], end='', flush=True)

    # sample tokens autoregressively
    for _ in range(steps):
        logits, cache = step_fn(params, cache, token)
        key, subkey = random.split(key)
        token = sample(subkey, logits)
        print(itoc[int(token)], end='', flush=True)

In [None]:
# dedicated names so this cell does not clobber the training globals `key` and `steps`
sample_key = random.key(42)
prompt = "Mr. Bennet was among the"
# prompt plus samples come out roughly one training context (seq_len characters) long
n_sample_steps = seq_len - len(prompt)
run(sample_key, args, params, prompt, n_sample_steps)