In [4]:
import math

from mamba_ssm import Mamba2

In [5]:
def mamba_check(hidden_dim, num_heads, mamba_expand):
    """Validate a Mamba head layout and suggest a nearby valid one on failure.

    With ``headdim = hidden_dim // num_heads``, the layout is accepted when
    ``(hidden_dim * mamba_expand / headdim) % 8 == 0``.

    Args:
        hidden_dim: Model width; must be a positive integer.
        num_heads: Head count used to derive ``headdim``; must be a positive
            integer.
        mamba_expand: Expansion factor. It is passed to ``math.gcd`` on the
            failure path, so it must be an integer.

    Returns:
        None when the constraint holds.

    Raises:
        ValueError: If ``hidden_dim`` or ``num_heads`` is not positive, or if
            the constraint fails. In the latter case the message carries a
            suggested ``(hidden_dim, num_heads)`` pair that satisfies it.
    """
    if hidden_dim <= 0 or num_heads <= 0:
        raise ValueError(
            f"hidden_dim and num_heads must be positive, got "
            f"hidden_dim={hidden_dim}, num_heads={num_heads}."
        )

    headdim = hidden_dim // num_heads
    # headdim is 0 when num_heads > hidden_dim. Such a layout cannot satisfy
    # the constraint, and dividing by it would raise ZeroDivisionError, so it
    # goes straight to the suggestion path.
    if headdim > 0 and (hidden_dim * mamba_expand / headdim) % 8 == 0:
        return

    # When num_heads divides hidden_dim, the constrained quantity equals
    # num_heads * mamba_expand, so num_heads must be a multiple of `step`.
    step = 8 // math.gcd(mamba_expand, 8)
    n_low = (num_heads // step) * step
    candidates = [n for n in (n_low, n_low + step) if n > 0]

    # Snap hidden_dim to the nearest multiple of each candidate head count.
    # The multiplier is clamped to 1 so a small hidden_dim never rounds down
    # to a suggested width of 0.
    suggestions = [(max(1, round(hidden_dim / n)) * n, n) for n in candidates]

    # Prefer the smallest change to hidden_dim, then to num_heads.
    best_h, best_n = min(
        suggestions,
        key=lambda s: (abs(s[0] - hidden_dim), abs(s[1] - num_heads)),
    )
    raise ValueError(
        f"Mamba packed sequence constraint failed: (hidden_dim * expand / headdim) % 8 != 0.\n"
        f"Current: hidden_dim={hidden_dim}, num_heads={num_heads}, expand={mamba_expand}.\n"
        f"Suggested fix: hidden_dim={best_h}, num_heads={best_n}."
    )

In [6]:
# Layer hyperparameters. Later cells derive headdim from the first two and
# pass hidden_dim / mamba_expand to Mamba2 as d_model / expand.
hidden_dim, num_heads = 128, 8
mamba_expand = 2

In [8]:
# Validate the configuration first, so an invalid one raises before headdim
# is (re)bound and a later cell cannot pick up a headdim that was rejected.
mamba_check(hidden_dim, num_heads, mamba_expand)
headdim = hidden_dim // num_heads

In [9]:
# Build the Mamba2 layer from the configuration defined in the cells above.
m = Mamba2(
    # Dimensions taken from the notebook configuration.
    d_model=hidden_dim,
    expand=mamba_expand,
    headdim=headdim,
    # Fixed here. NOTE(review): presumably the SSM state size and the
    # convolution width; confirm against the mamba_ssm documentation.
    d_state=16,
    d_conv=4,
)