In [1]:
from sys import path

path.append("../")

import jax
from jax import random
import jax.numpy as jnp
from flax import linen as nn
from matplotlib import pyplot as plt
from mamba import *

seed = 0
rng = jax.random.PRNGKey(seed)
BATCH_SIZE = 4

In [2]:
x = jnp.arange(5)
x.devices()

{CpuDevice(id=0)}

# Args

In [3]:
d_model = 2048

args = ModelArgs.init(d_model=d_model, n_layers=8, vocab_size=200)
args

ModelArgs(d_model=2048, d_inner=4096, n_layers=8, vocab_size=200, d_state=16, expand=2, dt_rank=128, d_conv=4, pad_vocab_size_multiple=8, conv_bias=True, bias=False)

# RMSNorm

In [4]:
norm_layer = RMSNorm(4)
x = jnp.zeros((BATCH_SIZE, 4))
norm_params = norm_layer.init(rng, x)
jax.tree.map(lambda x: x.shape, norm_params)

{'params': {'weight': (1,)}}

In [5]:
norm_params

{'params': {'weight': Array([1.], dtype=float32)}}

In [6]:
output = norm_layer.apply(norm_params, x)
output

Array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)

# Blocks

In [7]:
mamba_block = MambaBlock(args)
# input shape is (BATCH_SIZE, l, d)
length = 16
x = random.normal(rng, (BATCH_SIZE, length, d_model))
mamba_block_params = mamba_block.init(rng, x)
res = mamba_block.apply(mamba_block_params, x)
print(res.shape)

(4, 16, 2048)


In [8]:
residal = ResidualBlock(args)
residal_params = residal.init(rng, random.uniform(rng, (BATCH_SIZE, length, d_model)))
residal.apply(residal_params, x)

Array([[[ 0.54489195,  0.76979446,  0.285774  , ..., -0.32827622,
          0.02162874,  0.86027986],
        [-0.96390545,  2.14774   ,  1.3344753 , ...,  0.73951054,
          1.7069014 ,  0.63864386],
        [-0.19300614,  1.2431338 ,  1.2610406 , ...,  0.28740847,
          1.4620298 , -0.3728822 ],
        ...,
        [ 1.4138645 ,  1.175028  ,  0.19055331, ..., -0.95587724,
         -0.527748  , -1.3096331 ],
        [-0.65941924, -0.32283282,  0.03829893, ..., -2.9821916 ,
          1.7974606 ,  1.2640463 ],
        [-1.7579169 , -1.3390398 , -0.35797685, ...,  0.01441868,
         -0.12741882,  0.9854253 ]],

       [[-0.47561848,  0.47945794, -1.8594437 , ...,  0.45038664,
          1.3948874 ,  0.80349314],
        [-0.24722195, -0.04237077, -0.63305736, ...,  0.6097346 ,
          0.46724403,  1.9105839 ],
        [-0.97513324,  0.47870314, -0.23954386, ...,  0.89060074,
          0.19746482, -1.5330082 ],
        ...,
        [ 0.6828121 ,  0.02419403,  0.7800219 , ..., -

In [9]:
mamab = Mamba(args)
input_ids = random.randint(rng, (BATCH_SIZE, length), 0, args.vocab_size)
mamab_params = mamab.init(rng, input_ids)
logits = mamab.apply(mamab_params, input_ids)

print(logits.shape)

x shape after norm: (4, 16, 2048)
x shape after norm: (4, 16, 2048)
(4, 16, 200)


In [10]:
thing_to_index = jnp.arange(5)

@jax.jit
def fn(carry, i):
    carry = carry + thing_to_index[i]
    return carry, x 

result1, result2 = jax.lax.scan(fn, jnp.zeros((4,4)), jnp.arange(5))
result1

Array([[10., 10., 10., 10.],
       [10., 10., 10., 10.],
       [10., 10., 10., 10.],
       [10., 10., 10., 10.]], dtype=float32)

In [11]:
jnp.arange(5)

Array([0, 1, 2, 3, 4], dtype=int32)

# Torch vs Flax

In [12]:
import torch.nn as tnn
import flax.linen as fnn

d_inner = 2048
conv_bias = True
d_conv=4
torch_conv = tnn.Conv1d(
    in_channels=d_inner,
    out_channels=d_inner,
    bias=conv_bias,
    kernel_size=d_conv,
    groups=d_inner,
    padding=d_conv - 1
)
flax_conv = fnn.Conv(
    features=d_inner,
    use_bias=conv_bias,
    kernel_size=d_conv,
    feature_group_count=d_inner,
    padding=d_conv - 1
)

In [13]:
import torch
import jax.numpy as jnp

t_input = torch.zeros((1, 2048, 4))
j_input = jnp.zeros((1, 4, 2048))

In [14]:
rng = jax.random.PRNGKey(0)
flax_params = flax_conv.init(rng, j_input)
print(torch_conv(t_input).shape)
print(flax_conv.apply(flax_params, j_input).shape)

torch.Size([1, 2048, 7])
(1, 7, 2048)
