In [1]:
from functools import partial
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.nn.initializers import lecun_normal, normal
from jax.numpy.linalg import eigh, inv, matrix_power
from jax.scipy.signal import convolve

import torch

from dataclasses import dataclass

from typing import Union

import matplotlib.pyplot as plt
import seaborn

In [2]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

  pid, fd = os.forkpty()


--2024-05-07 08:53:08--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: 'input.txt'


2024-05-07 08:53:08 (46.7 MB/s) - 'input.txt' saved [1115394/1115394]



In [3]:
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [5]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [6]:
data = jnp.array(encode(text), dtype=jnp.int64)
print(data.shape, data.dtype)
print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this

  data = jnp.array(encode(text), dtype=jnp.int64)


(1115394,) int32
[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43
  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43
 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49
  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10
  0 37 53 59  1 39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39
 58 46 43 56  1 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47
 57 46 12  0  0 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53
 50 60 43 42  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47
 56 57 58  6  1 63 53 59  1 49 52 53 61  1 15 39 47 59 57  1 25 39 56 41
 47 59 57  1 47 57  1 41 46 47 43 44  1 43 52 43 51 63  1 58 53  1 58 46
 43  1 54 43 53 54 50 43  8  0  0 13 50 50 10  0 35 43  1 49 52 53 61  5
 58  6  1 61 43  1 49 52 53 61  5 58  8  0  0 18 47 56 57 58  1 15 47 58
 47 64 43 52 10  0 24 43 58  1 59 57  1 49 47 50 50  1 46 47 51  6  1 39
 52 42  1 61 43  5 50 50  1 46 39 

In [7]:
train_test_split = 0.9
n = int(train_test_split*len(data))
train_data = data[:n]
test_data = data[n:]

In [8]:
block_size = 8
train_data[:block_size+1]

Array([18, 47, 56, 57, 58,  1, 15, 47, 58], dtype=int32)

In [9]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is [18] the target: 47
when input is [18 47] the target: 56
when input is [18 47 56] the target: 57
when input is [18 47 56 57] the target: 58
when input is [18 47 56 57 58] the target: 1
when input is [18 47 56 57 58  1] the target: 15
when input is [18 47 56 57 58  1 15] the target: 47
when input is [18 47 56 57 58  1 15 47] the target: 58


In [10]:
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 128
expans = 2
n_head = 4
channel_size = n_embd // n_head
n_layers = 4
dropout = 0.2
conv_k_size = 3
n_latent_dim = 16

rng_key = jax.random.PRNGKey(1564)

def get_batch(split):
  #generate a small batch of data of inputs, x, and targets, y,

    data = train_data if split == 'train' else test_data
    ix = jax.random.uniform(rng_key, shape=(batch_size,), minval=0,maxval=len(data) - block_size, dtype=jnp.float32).astype(jnp.int32)
    x = jnp.stack([data[i:i+block_size] for i in ix])
    y = jnp.stack([data[i+1:i+block_size+1] for i in ix])
    return x,y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

# print('----')

# for b in range(batch_size): # batch dimension
#     for t in range(block_size): # time dimension
#         context = xb[b, :t+1]
#         target = yb[b,t]
#         print(f"when input is {context} the target: {target}")

inputs:
(4, 32)
[[58 46 59 57  1 41 39 52  1 51 39 49 43  1 46 47 51  1 40 47 58 43  1 58
  46 43  1 50 39 61  1 40]
 [53 50 43 58 57  1 42 47 51  6  0 14 59 58  1 57 61 43 43 58 43 56  1 58
  46 39 52  1 58 46 43  1]
 [43 52  1 40 53 56 52  1 58 53  5 57 11  1 44 53 56  0 21 52  1 58 46 53
  57 43  1 59 52 44 50 43]
 [46 43  1 30 53 51 39 52 57 10  1 39 52 42  1 61 43  1 46 43 56 43  1 42
  43 50 47 60 43 56  6  0]]
targets:
(4, 32)
[[46 59 57  1 41 39 52  1 51 39 49 43  1 46 47 51  1 40 47 58 43  1 58 46
  43  1 50 39 61  1 40 63]
 [50 43 58 57  1 42 47 51  6  0 14 59 58  1 57 61 43 43 58 43 56  1 58 46
  39 52  1 58 46 43  1 50]
 [52  1 40 53 56 52  1 58 53  5 57 11  1 44 53 56  0 21 52  1 58 46 53 57
  43  1 59 52 44 50 43 42]
 [43  1 30 53 51 39 52 57 10  1 39 52 42  1 61 43  1 46 43 56 43  1 42 43
  50 47 60 43 56  6  0 31]]


In [11]:
@dataclass
class ModelArgs: # The same as torch version since this does not have any torch specific code
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4 
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False
    
    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)
        
        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)
            
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)

In [12]:
class RMSNorm(nn.Module):
    d_model: int
    eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        weight = self.param('weight', nn.initializers.ones, (self.d_model,)) # TODO, maybe use setup will be more clear
        normed = x * jax.lax.rsqrt(np.mean(np.square(x), axis=-1, keepdims=True) + self.eps)
        output = normed * weight
        return output

In [13]:
xb.shape

(4, 32)

In [14]:
class Expan_proj(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=n_embd)(x)
        return x
    
in_proj = nn.Dense(features=n_embd)
params = in_proj.init(rng_key, xb)
in_proj.apply(params, xb).shape

(4, 128)

In [15]:
jax.tree_util.tree_map(lambda x: x.shape, params)

{'params': {'bias': (128,), 'kernel': (32, 128)}}

In [16]:
print(in_proj.tabulate(jax.random.key(0), jnp.ones(xb.shape),
                   compute_flops=True, compute_vjp_flops=True))


[3m                                 Dense Summary                                  [0m
┏━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath[0m[1m [0m┃[1m [0m[1mmodule[0m[1m [0m┃[1m [0m[1minputs     [0m[1m [0m┃[1m [0m[1moutputs    [0m[1m [0m┃[1m [0m[1mflops[0m[1m [0m┃[1m [0m[1mvjp_flops[0m[1m [0m┃[1m [0m[1mparams      [0m[1m [0m┃
┡━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│      │ Dense  │ [2mfloat32[0m[4,… │ [2mfloat32[0m[4,… │ 33280 │ 103424    │ bias:        │
│      │        │             │             │       │           │ [2mfloat32[0m[128] │
│      │        │             │             │       │           │ kernel:      │
│      │        │             │             │       │           │ [2mfloat32[0m[32,… │
│      │        │             │             │       │           │              │
│      │        │             │             │       │         

In [17]:
32*32+32

1056

# Mamba Block
Dense --> Conv1D --> Silu --> SSM --> Silu --> 

In [18]:
class MambaBlock(nn.Module):
    
    def setup(self):          
        self.in_proj1 = nn.Dense(features=n_embd * expans)
        self.in_proj2 = nn.Dense(features=n_embd * expans)
        
        # Adjusted for Flax. Flax does not have nn.Conv1d, so you might need to reshape or use a different approach
        self.conv1d = nn.Conv(features=n_embd * expans,
                              kernel_size=conv_k_size,
                              padding=1,
                              )

        self.A = -1*self.param('A', nn.initializers.ones, (1, n_latent_dim, n_embd * expans, 1))
        self.B = self.param('B', nn.initializers.ones, (1, n_latent_dim, 1, block_size))
        self.C = self.param('C', jax.random.normal, (1, n_latent_dim, 1, block_size))
#         self.D = self.param('D', jax.random.normal, (1, self.args.d_state, self.args.d_model, 1))
        self.delta = self.param('delta', jax.random.normal, (1, 1, n_embd * expans, block_size))
        
        self.out_proj = nn.Dense(n_embd)
        
        self.rms_norm = nn.RMSNorm()
        
    def __call__(self, x):
        
        embeds = x
        x = self.in_proj1(embeds)
        x = self.conv1d(x)
        x = jax.nn.silu(x)
        x = x.reshape((x.shape[0],1,x.shape[2],x.shape[1]))
        x = self.ssm(x)[1]
        x = x.reshape((x.shape[0],x.shape[3],x.shape[2]))
        x = x*jax.nn.silu(self.in_proj2(embeds))

        x = self.out_proj(x)
        
        x = self.rms_norm(x)
        
        return x
    def discretize(self):
        da = self.delta * self.A
        a_ = jnp.exp(da)
        b_ = self.C * self.delta
        return a_, b_

    def ssm(self, x):
        y = []
        h = 0
        a_, b_ = self.discretize()
        for k in range(x.shape[-1]):
            h = a_[..., k] * h + b_[..., k] * x[..., k]
            y.append((self.C[..., k] * h).sum(1, keepdims=True))
        return h, jnp.stack(y, -1)

In [19]:
# model = MambaBlock()
# params = model.init(jax.random.key(0), xb)
# print(model.tabulate(jax.random.key(0), xb,
#                    compute_flops=True, compute_vjp_flops=True))
# xs = model.apply(params, xb)
# xb.shape, xs.shape

IndexError: tuple index out of range

In [None]:
xs[0][2][59]

In [21]:
xbs = jax.nn.one_hot(xb, vocab_size)
xbs.shape

(4, 32, 65)

In [22]:
class Mamba(nn.Module):
    
    def setup(self):
        self.embeddings = nn.Embed(vocab_size, n_embd)
        
        self.mamba_layers = [MambaBlock() for _ in range(n_layers)]
        
    def __call__(self, x):
        x = self.embeddings(x)
        
        for layer in self.mamba_layers:
            x = layer(x)
            
        out = self.embeddings.attend(x)
        return out

In [23]:
fin_model = Mamba()
fin_params = fin_model.init(jax.random.key(42), xb)
print(fin_model.tabulate(jax.random.key(42), xb,
                   compute_flops=True, compute_vjp_flops=True))
xf = fin_model.apply(fin_params, xb)
xb.shape, xf.shape


[3m                                 Mamba Summary                                  [0m
┏━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━┓
┃[1m [0m[1mpath    [0m[1m [0m┃[1m [0m[1mmodule  [0m[1m [0m┃[1m [0m[1minputs  [0m[1m [0m┃[1m [0m[1moutputs  [0m[1m [0m┃[1m [0m[1mflops   [0m[1m [0m┃[1m [0m[1mvjp_flops[0m[1m [0m┃[1m [0m[1mparams  [0m[1m [0m┃
┡━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━┩
│          │ Mamba    │ [2mint32[0m[4… │ [2mfloat32[0m[… │ 3132797… │ 981552128 │          │
├──────────┼──────────┼──────────┼───────────┼──────────┼───────────┼──────────┤
│ embeddi… │ Embed    │ [2mint32[0m[4… │ [2mfloat32[0m[… │ 17153    │ 41857     │ embeddi… │
│          │          │          │           │          │           │ [2mfloat32[0m… │
│          │          │          │           │          │           │          │
│          │          │          │           │        

((4, 32), (4, 32, 65))

In [114]:
jax.random.categorical(jax.random.PRNGKey(5332), 1000.0*xf[1][-1][:])

Array(60, dtype=int32)

In [30]:
import tensorflow as tf
tf.zeros([1, 50], dtype=tf.int64)[:, -block_size:]


<tf.Tensor: shape=(1, 32), dtype=int64, numpy=
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])>

In [None]:
layer = nn.Embed(num_embeddings=5, features=3)
indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
variables = layer.init(jax.random.key(0), indices_input)
variables
# get the first three and last three embeddings
nu_out = layer.apply(variables, indices_input)
nu_out.shape

In [None]:
params['params']['conv1d']['kernel'].shape

Array([[0.]], dtype=float32)