In [None]:
import torch
from mamba_ssm import Mamba2
from torch.optim import AdamW

# need to be careful with setting d_model, headdim, and expand
# packing had a ton of issues with casual conv1d when trying to use
# arbitrary settings for these parameters.
# what I know is that d_model = 64, headdim = 4, expand = 2 works
# seems to suggest that https://github.com/state-spaces/mamba/issues/351
# we need (d_model * expand) / headdim to be multiple of 8, that is
# `(d_model * expand) / headdim % 8`
# doing a bit of testing, it seems that is not the case? it is just that
# some combinations result in lower level errors for some f#%$#@ reason
# ok, never mind, it seems to be the case that for normal batched inputs,
# just about any combination works, but for packed inputs where we 
# specify the seq_lens, some combinations lead to the causal-conv1d
# thinking that the shape of the input (in terms of channels being last)
# is not correct. It seems it does seem to  be the case that the rule
# of (d_model * expand) / headdim % 8 == 0 holds for packed inputs.
batch, length, dim = 2, 16, 96
expand = 2
headdim = 4
if not (dim * expand / headdim) % 8 == 0:
    print("d_model * expand must be multiple of headdim * 8")
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    headdim=headdim,
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=expand,    # Block expansion factor
)
opt = AdamW(model.parameters(), lr=1e-3)
model = model.to("cuda")
y = model(x)
assert y.shape == x.shape

In [17]:
l = y.mean()
l.backward()

In [18]:
opt.step()

In [19]:
# Let's test packed sequences
# Create two tensors of different lengths
s1 = 3
s2 = 3
x1 = torch.randn(1, s1, dim).to("cuda")
x2 = torch.randn(1, s2, dim).to("cuda")

seq_idx = [0] * s1 + [1] * s2
seq_idx = torch.tensor(seq_idx, device="cuda", dtype=torch.int32).unsqueeze(0)

# Get individual outputs
with torch.no_grad():
    y1_individual = model(x1)
    y2_individual = model(x2)

# Create packed input
x_packed = torch.cat([x1, x2], dim=1)
x_packed.requires_grad_(True)
cu_seqlens = torch.tensor([0, s1, s1 + s2], dtype=torch.int32, device="cuda")

# Get packed output
with torch.no_grad():
    #y_packed = model(x_packed, seq_idx=seq_idx, cu_seqlens=cu_seqlens).squeeze(0)
    y_packed = model(x_packed, seq_idx=seq_idx).squeeze(0)

# Split packed output
y1_packed = y_packed[:s1].unsqueeze(0)
y2_packed = y_packed[s1:].unsqueeze(0)

# Check if the outputs are the same
print("y1 comparison:", torch.allclose(y1_individual, y1_packed, atol=1e-6))
print("y2 comparison:", torch.allclose(y2_individual, y2_packed, atol=1e-6))

y1 comparison: True
y2 comparison: True


In [20]:
y_packed.requires_grad_(True)

tensor([[-6.2116e-01,  2.4107e-01, -1.5129e+00,  3.8587e-01, -5.8004e-01,
         -1.6634e-01, -1.3824e+00,  4.4393e-01, -2.9099e-02,  7.7948e-01,
         -6.6768e-01, -1.0424e+00, -2.0781e-01,  2.0198e-01, -7.2246e-01,
          4.3984e-01,  6.4999e-01,  3.1778e-01,  2.6821e-01,  5.1650e-01,
          6.3873e-01, -1.6768e-01,  4.0387e-01, -2.9645e-01,  7.3715e-02,
          5.7566e-01,  5.1389e-01,  8.0220e-01,  1.2376e+00,  7.9729e-02,
         -8.4012e-01,  1.2893e-01, -2.7089e-01, -5.3847e-01,  1.0451e-01,
         -2.4162e-01, -1.5693e-02,  1.8588e-01,  6.6236e-01, -7.4285e-01,
          1.2007e+00, -6.3125e-01,  3.2688e-01,  3.3484e-01, -7.5681e-02,
         -1.9244e-02, -6.3390e-01, -9.8153e-01,  5.6085e-01,  3.7954e-01,
          3.0841e-01,  9.4786e-01,  3.1047e-01,  1.0744e+00,  8.5448e-01,
         -1.2287e-01, -4.8137e-02,  2.0482e-01,  7.8514e-01, -6.4747e-01,
         -4.9899e-06, -5.4974e-01, -1.2144e-01, -1.9491e-02, -2.0232e-01,
         -1.0877e+00, -3.2753e-01, -2.

In [21]:
y_packed.mean().backward()