In [1]:
import torch
from mamba_ssm import Mamba2
from torch.optim import AdamW

# Mamba2 is picky about the combination of d_model, headdim, and expand:
# packing ran into a lot of issues with causal conv1d when these were set
# to arbitrary values.
# Known-good combination: d_model = 64, headdim = 4, expand = 2.
# https://github.com/state-spaces/mamba/issues/351 seems to suggest that
# (d_model * expand) / headdim must be a multiple of 8, that is
# `(d_model * expand) / headdim % 8 == 0`.
torch.manual_seed(0)  # fixed seed so the random inputs/initialisation are reproducible
batch, length, dim = 2, 16, 64
expand = 2
headdim = 4
assert (dim * expand) % (headdim * 8) == 0, "d_model * expand must be multiple of headdim * 8"
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,      # Model dimension d_model
    headdim=headdim,  # Pass the variable checked by the assert above, not a separate literal
    d_state=16,       # SSM state expansion factor
    d_conv=4,         # Local convolution width
    expand=expand,    # Block expansion factor (same variable the assert validated)
).to("cuda")
# Build the optimizer only after the module sits on its final device, as the
# torch.optim documentation recommends.
opt = AdamW(model.parameters(), lr=1e-3)
y = model(x)
# The block maps (batch, length, dim) -> (batch, length, dim).
assert y.shape == x.shape

In [2]:
# Collapse the forward output to a scalar (its mean) and backpropagate
# through the graph that produced `y`.
l = torch.mean(y)
l.backward()

In [3]:
# Apply a single optimizer update using the gradients currently stored on the parameters.
opt.step()

In [4]:
# Packed-sequence sanity check: running two sequences separately should give
# the same outputs as running them concatenated along the length axis with a
# `seq_idx` tensor marking which sequence each position belongs to.

# Two independent inputs of lengths s1 and s2 (both 3 here).
s1 = 3
s2 = 3
x1 = torch.randn(1, s1, dim).to("cuda")
x2 = torch.randn(1, s2, dim).to("cuda")

# Per-position sequence id: 0 for the first s1 positions, 1 for the next s2,
# with a leading batch axis of size 1.
seq_idx = torch.tensor(
    [0] * s1 + [1] * s2,
    device="cuda",
    dtype=torch.int32,
).unsqueeze(0)

# Packed input: both sequences back to back along dim=1.
x_packed = torch.cat([x1, x2], dim=1)
# Cumulative sequence boundaries. Built here but NOT passed to the model below;
# only `seq_idx` is used to mark the sequence boundary.
cu_seqlens = torch.tensor([0, s1, s1 + s2], dtype=torch.int32, device="cuda")

# Reference outputs (one sequence at a time) and the packed output, all
# computed without tracking gradients.
with torch.no_grad():
    y1_individual = model(x1)
    y2_individual = model(x2)
    y_packed = model(x_packed, seq_idx=seq_idx).squeeze(0)

# Cut the packed output back into its two segments and restore the batch axis.
y1_packed, y2_packed = (
    segment.unsqueeze(0) for segment in torch.split(y_packed, [s1, s2], dim=0)
)

# Each packed segment should match its individually-computed counterpart.
print("y1 comparison:", torch.allclose(y1_individual, y1_packed, atol=1e-6))
print("y2 comparison:", torch.allclose(y2_individual, y2_packed, atol=1e-6))

y1 comparison: True
y2 comparison: True
