In [12]:
from sys import path

path.append("../")

import jax
import jax.numpy as jnp
from flax import linen as nn
from matplotlib import pyplot as plt
from mamba import *

# Single seed for the notebook; `rng` is the PRNG key handed to the `.init(...)` calls below.
seed = 0
rng = jax.random.PRNGKey(seed)
# Leading (batch) dimension of the dummy inputs built in later cells.
BATCH_SIZE = 4

In [15]:
# Scratch check: ask a freshly created array which device(s) it lives on.
x = jnp.arange(5)
# NOTE(review): the TypeError recorded under this cell complains about subscripting a set,
# but nothing here is subscripted — the saved output looks stale (perhaps left over from an
# earlier `x.devices()[0]`); re-run the cell to confirm.
x.devices()

TypeError: 'set' object is not subscriptable

# Args

In [2]:
# Small model configuration used by the layer checks in the following cells.
args = ModelArgs.init(d_model=4, n_layers=8, vocab_size=2048)
# NOTE(review): the recorded repr shows dt_rank as a float32 JAX Array rather than a plain
# int — worth confirming that this is what ModelArgs.init intends.
args

ModelArgs(d_model=4, d_inner=8, n_layers=8, vocab_size=2048, d_state=16, expand=2, dt_rank=Array(1., dtype=float32, weak_type=True), d_conv=4, pad_vocab_size_multiple=8, conv_bias=True, bias=False)

# RMSNorm

In [3]:
# RMSNorm smoke test: initialise the layer on an all-zero batch, then list the shape of
# every parameter leaf.
x = jnp.zeros((BATCH_SIZE, 4))
norm_layer = RMSNorm(4)
norm_params = norm_layer.init(rng, x)
# `leaf` (not `x`) so the lambda argument does not shadow the input array above.
jax.tree.map(lambda leaf: leaf.shape, norm_params)

{'params': {'weight': (1,)}}

In [4]:
# Show the initialised parameter values themselves (the previous cell only showed shapes).
norm_params

{'params': {'weight': Array([1.], dtype=float32)}}

In [5]:
# Forward pass through RMSNorm with the initialised parameters.
# Assuming `x` is still the all-zero (BATCH_SIZE, 4) array from the init cell, the result is
# all zeros too, so this exercises shapes rather than values.
output = norm_layer.apply(norm_params, x)
output

Array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)

# MambaBlock

In [16]:
# Initialise a MambaBlock on a dummy all-zero sequence batch.
# NOTE(review): this cell needs `args` from the "Args" cell; the recorded NameError suggests it
# was run on a kernel where that cell had not been executed — confirm with Restart & Run All.
# A CallCompactUnboundModuleError is also recorded; its cause is not visible from this cell.
mamba_block = MambaBlock(args)
# input shape is (BATCH_SIZE, l, d)
length = 16
# NOTE(review): presumably has to match args.d_model (4 in the Args cell) — confirm.
d = 4
# Rebinds the notebook-global `x` that earlier cells also use.
x = jnp.zeros((BATCH_SIZE, length, d))
mamba_block_params = mamba_block.init(rng, x)

NameError: name 'args' is not defined

CallCompactUnboundModuleError: Can't call compact methods on unbound modules (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.CallCompactUnboundModuleError)

In [26]:
# Values indexed inside the scan body; captured by `fn` as a closed-over constant.
thing_to_index = jnp.arange(5)


def fn(carry, i):
    """Scan step: add element ``i`` of ``thing_to_index`` to every entry of the carry.

    Args:
        carry: running accumulator threaded through ``jax.lax.scan``.
        i: index taken from the sequence being scanned over.

    Returns:
        ``(new_carry, per_step_output)``. The per-step output is the updated
        carry, so the stacked scan output holds the running sums.
    """
    carry = carry + thing_to_index[i]
    # Emit a value produced by this step. Returning the notebook-global `x` here
    # made the cell depend on whichever cell last assigned `x`, and the stacked
    # output was just repeated copies of that unrelated array.
    return carry, carry


# `fn` is not wrapped in jax.jit: jax.lax.scan already traces and compiles its body.
result1, result2 = jax.lax.scan(fn, jnp.zeros((4, 4)), jnp.arange(5))
result1

Array([[10., 10., 10., 10.],
       [10., 10., 10., 10.],
       [10., 10., 10., 10.],
       [10., 10., 10., 10.]], dtype=float32)

In [25]:
# Scratch check of the index sequence that the scan cell iterates over.
jnp.arange(5)

Array([0, 1, 2, 3, 4], dtype=int32)

# Torch vs Flax

In [32]:
import torch.nn as tnn
import flax.linen as fnn

# Convolution settings shared by the PyTorch and Flax layers built below.
d_inner = 2048
d_conv = 4
conv_bias = True

# PyTorch: groups == in_channels gives each channel its own filter (depthwise);
# padding of d_conv - 1 is applied on both ends of the sequence.
torch_conv = tnn.Conv1d(
    in_channels=d_inner,
    out_channels=d_inner,
    kernel_size=d_conv,
    groups=d_inner,
    padding=d_conv - 1,
    bias=conv_bias,
)

# Flax: the same layer expressed with linen's option names.
flax_conv = fnn.Conv(
    features=d_inner,
    kernel_size=d_conv,
    feature_group_count=d_inner,
    padding=d_conv - 1,
    use_bias=conv_bias,
)

In [33]:
import torch
import jax.numpy as jnp

# Dummy inputs with the two axis orders swapped: the torch tensor carries the 2048-wide axis
# (presumably the conv channels, matching d_inner) on axis 1, the JAX array carries it last;
# both have a length-4 sequence axis and a batch of 1.
t_input = torch.zeros((1, 2048, 4))
j_input = jnp.zeros((1, 4, 2048))

In [34]:
# Rebinds the notebook-global `rng` (same seed value, 0, as the first cell).
rng = jax.random.PRNGKey(0)
# The Flax layer's parameters are created explicitly from a sample input, then passed to apply.
flax_params = flax_conv.init(rng, j_input)
# Print the output shape of each convolution on its dummy input for a side-by-side comparison.
print(torch_conv(t_input).shape)
print(flax_conv.apply(flax_params, j_input).shape)

torch.Size([1, 2048, 7])
(1, 7, 2048)
