In [1]:
def mamba2_params(dim_M, expand=2, d_conv=4, d_state=64, tight=False):
    # dominant-only (default) vs. slightly tighter approximation
    base = 3 * expand * (dim_M ** 2)
    if not tight:
        return base
    return base + expand * dim_M * d_conv + 2 * expand * dim_M * d_state

def target_params(n_p, hdim_p, dim_p, dim_M, expand=2, tight=False):
    return 4 * n_p * hdim_p * dim_p + mamba2_params(dim_M, expand=expand, tight=tight)

def match_layer(T, dims=(256,384,512,640,768,1024,1280,1536,2048),
                hdims=(64,96,128,160,192,256,320,384,512),
                n_max=256, tol_pct=0.01):
    best = []
    for d in dims:
        for h in hdims:
            n = round(T / (4*h*d))
            if 1 <= n <= n_max:
                P = 4*n*h*d
                err = abs(P - T) / T
                if err <= tol_pct:
                    best.append((err, n, h, d, P))
    return sorted(best)

In [4]:
from typing import Iterable, List, Dict, Optional, Tuple

def mamba2_params(dim_M: int, *, expand: int = 2, d_conv: int = 4, d_state: int = 64,
                  tight: bool = False) -> int:
    """
    Parameter count for ONE Mamba2 layer.
    - dominant: 3 * expand * dim_M^2
    - tight=True: adds small linear terms (depthwise conv & SSM projections)
    """
    base = 3 * expand * (dim_M ** 2)
    if not tight:
        return base
    return base + expand * dim_M * d_conv + 2 * expand * dim_M * d_state  # biases ignored

def find_matches_with_mamba(
    # Your layer:
    n: int, h_dim: int, dim: int,
    # Search grids for Mamba2:
    mamba_dims: Iterable[int] = (256, 384, 512, 640, 768, 1024, 1280, 1536, 2048, 2304, 2560, 3072, 4096),
    mamba_expands: Iterable[int] = (1, 2, 4),
    mamba_d_states: Iterable[int] = (64, 128),
    mamba_d_convs: Iterable[int] = (4,),
    tight_mamba: bool = False,
    # Search grids/constraints for the *other* model (n', h'_dim, dim'):
    dims_prime: Iterable[int] = range(256, 4097, 64),
    hdims_prime: Iterable[int] = (64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384, 448, 512, 640, 768),
    nprime_min: int = 1, nprime_max: int = 4096,
    require_dim_multiple_of: Optional[int] = 64,
    require_hdim_multiple_of: Optional[int] = 32,
    # Matching mode:
    exact: bool = True, tol_pct: float = 0.0,
    # Optional coupling: tie dim' to Mamba dim, or to your dim
    couple_dim_prime_to: Optional[str] = None,  # None | "mamba" | "yours"
) -> List[Dict]:
    """
    Finds tuples (Mamba2 config, n', h'_dim, dim') such that:
        4*n*h_dim*dim  ≈  4*n'*h'_dim*dim' + P_mamba2
    Exact equality if exact=True; else within tol_pct.

    Returns a list of dicts sorted by % error, then by (dim', h'_dim, n').
    """
    P_yours = 4 * n * h_dim * dim

    # pre-filter multiples
    def ok_multiple(x: int, mult: Optional[int]) -> bool:
        return True if mult is None else (x % mult == 0)

    dims_prime = [d for d in dims_prime if ok_multiple(d, require_dim_multiple_of)]
    hdims_prime = [h for h in hdims_prime if ok_multiple(h, require_hdim_multiple_of)]

    results: List[Dict] = []

    for dim_M in mamba_dims:
        for expand in mamba_expands:
            for d_state in mamba_d_states:
                for d_conv in mamba_d_convs:
                    P_m2 = mamba2_params(dim_M, expand=expand, d_conv=d_conv,
                                         d_state=d_state, tight=tight_mamba)
                    target_linear = P_yours - P_m2
                    if target_linear <= 0:
                        continue  # Mamba alone exceeds budget

                    # Determine the dim' candidates, optionally coupled
                    if couple_dim_prime_to == "mamba":
                        dims_prime_eff = [dim_M] if ok_multiple(dim_M, require_dim_multiple_of) else []
                    elif couple_dim_prime_to == "yours":
                        dims_prime_eff = [dim] if ok_multiple(dim, require_dim_multiple_of) else []
                    else:
                        dims_prime_eff = dims_prime

                    for dprime in dims_prime_eff:
                        for hprime in hdims_prime:
                            denom = 4 * hprime * dprime
                            if exact:
                                if target_linear % denom != 0:
                                    continue
                                nprime = target_linear // denom
                                if not (nprime_min <= nprime <= nprime_max):
                                    continue
                                P_other = 4 * nprime * hprime * dprime + P_m2
                                err = abs(P_other - P_yours) / P_yours
                                results.append(dict(
                                    error_pct=err * 100.0,
                                    n_prime=int(nprime), h_dim_prime=int(hprime), dim_prime=int(dprime),
                                    mamba_dim=int(dim_M), mamba_expand=int(expand),
                                    mamba_d_state=int(d_state), mamba_d_conv=int(d_conv),
                                    P_yours=int(P_yours), P_other=int(P_other), P_mamba2=int(P_m2),
                                    exact=True
                                ))
                            else:
                                nprime = int(round(target_linear / denom))
                                if not (nprime_min <= nprime <= nprime_max) or nprime <= 0:
                                    continue
                                P_other = 4 * nprime * hprime * dprime + P_m2
                                err = abs(P_other - P_yours) / P_yours
                                if err <= tol_pct:
                                    results.append(dict(
                                        error_pct=err * 100.0,
                                        n_prime=int(nprime), h_dim_prime=int(hprime), dim_prime=int(dprime),
                                        mamba_dim=int(dim_M), mamba_expand=int(expand),
                                        mamba_d_state=int(d_state), mamba_d_conv=int(d_conv),
                                        P_yours=int(P_yours), P_other=int(P_other), P_mamba2=int(P_m2),
                                        exact=False
                                    ))

    results.sort(key=lambda r: (r["error_pct"], r["dim_prime"], r["h_dim_prime"], r["n_prime"]))
    return results


In [5]:
# Your layer budget
n, h_dim, dim = 12, 64, 768

# 1) Exact matches over common Mamba settings, tying dim' to Mamba dim
# cands_exact = find_matches_with_mamba(
#     n, h_dim, dim,
#     exact=True, couple_dim_prime_to="mamba",
#     mamba_dims=(512, 640, 768, 1024),
#     mamba_expands=(2,), mamba_d_states=(64, 128),
# )

# 2) Allow 0.5% slack, search wide, and use tighter Mamba estimate
cands_approx = find_matches_with_mamba(
    n, h_dim, dim,
    exact=False, tol_pct=0.005, tight_mamba=True,
    mamba_dims=range(256, 4097, 64),
    mamba_expands=(1, 2, 4),
    mamba_d_states=(64, 128),
    dims_prime=range(256, 4097, 64),
    hdims_prime=(64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384, 448, 512, 640, 768),
)

# # 3) Pin dim' to your dim (useful if you want same width on both sides)
# same_width = find_matches_with_mamba(
#     n, h_dim, dim,
#     exact=True,
#     couple_dim_prime_to="yours",
# )


In [6]:
cands_approx

[{'error_pct': 0.03255208333333333,
  'n_prime': 1,
  'h_dim_prime': 64,
  'dim_prime': 256,
  'mamba_dim': 832,
  'mamba_expand': 1,
  'mamba_d_state': 128,
  'mamba_d_conv': 4,
  'P_yours': 2359296,
  'P_other': 2358528,
  'P_mamba2': 2292992,
  'exact': False},
 {'error_pct': 0.03255208333333333,
  'n_prime': 1,
  'h_dim_prime': 96,
  'dim_prime': 448,
  'mamba_dim': 832,
  'mamba_expand': 1,
  'mamba_d_state': 64,
  'mamba_d_conv': 4,
  'P_yours': 2359296,
  'P_other': 2358528,
  'P_mamba2': 2186496,
  'exact': False},
 {'error_pct': 0.043402777777777776,
  'n_prime': 32,
  'h_dim_prime': 64,
  'dim_prime': 256,
  'mamba_dim': 256,
  'mamba_expand': 1,
  'mamba_d_state': 128,
  'mamba_d_conv': 4,
  'P_yours': 2359296,
  'P_other': 2360320,
  'P_mamba2': 263168,
  'exact': False},
 {'error_pct': 0.043402777777777776,
  'n_prime': 16,
  'h_dim_prime': 128,
  'dim_prime': 256,
  'mamba_dim': 256,
  'mamba_expand': 1,
  'mamba_d_state': 128,
  'mamba_d_conv': 4,
  'P_yours': 2359296,
 

# final mamba oparam calculator?

In [7]:
from typing import Iterable, List, Dict, Optional, Tuple
import math

# --- Param models -------------------------------------------------------------

def mamba2_params(dim_M: int, *, expand: int = 2, d_conv: int = 4, d_state: int = 64,
                  tight: bool = False) -> int:
    """
    Parameter count for ONE Mamba2 layer.
    dominant term: 3 * expand * dim_M^2
    tight=True adds small linear terms; still dominated by dim_M^2.
    """
    base = 3 * expand * (dim_M ** 2)
    if not tight:
        return base
    return base + expand * dim_M * d_conv + 2 * expand * dim_M * d_state  # biases ignored


# --- Search with multi-objective scoring -------------------------------------

def find_balanced_configs(
    # Your layer (the budget driver)
    n: int, h_dim: int, dim: int,

    # Search grids for Mamba2 (keep these modest to ensure "smaller" blocks)
    mamba_dims: Iterable[int] = (256, 320, 384, 448, 512, 576, 640, 704, 768),
    mamba_expands: Iterable[int] = (1, 2),
    mamba_d_states: Iterable[int] = (64, 128),
    mamba_d_convs: Iterable[int] = (4,),
    tight_mamba: bool = False,

    # Search grids for the other block (n', h'_dim, dim')
    dims_prime: Iterable[int] = tuple(range(256, 4097, 64)),
    hdims_prime: Iterable[int] = (64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384, 448, 512),
    nprime_min: int = 1, nprime_max: int = 8192,

    # Optional coupling: tie dim' to your dim or to the mamba dim
    couple_dim_prime_to: Optional[str] = None,  # None | "mamba" | "yours"

    # Hard constraints / caps for “smaller Mamba”
    max_mamba_dim: Optional[int] = None,        # e.g., 640
    max_mamba_params_frac: Optional[float] = None,  # e.g., 0.35 means P_mamba <= 35% of total

    # Matching mode
    exact: bool = False,  # allow small slack by default
    tol_pct: float = 0.005,  # 0.5% budget tolerance

    # Preference weights (sum doesn’t need to be 1)
    w_param: float = 1.0,     # weight on parameter matching error
    w_closeness: float = 1.0, # weight on staying close to (n, h_dim, dim)
    w_mamba: float = 0.5,     # weight on keeping Mamba block small

    # Closeness scaling (normalize differences)
    # Use relative-to-original for dim, hdim; log distance for n (often spans wider)
    eps: float = 1e-9,

    # Output control
    top_k: int = 50,
) -> List[Dict]:
    """
    Find tuples (Mamba2 config, n', h'_dim, dim') that approximately satisfy:
        4*n*h_dim*dim ≈ 4*n'*h'_dim*dim' + P_mamba
    while preferring small Mamba and (n',h'_dim,dim') close to (n,h_dim,dim).

    Returns top_k candidates sorted by the composite score.
    """

    P_yours = 4 * n * h_dim * dim

    # Make effective dim' grid if coupled
    def effective_dims_prime(dim_M: int) -> Iterable[int]:
        if couple_dim_prime_to == "mamba":
            return (dim_M,)
        elif couple_dim_prime_to == "yours":
            return (dim,)
        else:
            return dims_prime

    candidates: List[Dict] = []

    for dim_M in mamba_dims:
        if max_mamba_dim is not None and dim_M > max_mamba_dim:
            continue

        for expand in mamba_expands:
            for d_state in mamba_d_states:
                for d_conv in mamba_d_convs:
                    P_m2 = mamba2_params(dim_M, expand=expand, d_conv=d_conv,
                                         d_state=d_state, tight=tight_mamba)

                    # If capped as a fraction of total budget, approximate feasibility
                    if max_mamba_params_frac is not None:
                        if P_m2 > max_mamba_params_frac * P_yours:
                            continue

                    target_linear = P_yours - P_m2
                    if target_linear <= 0:
                        continue  # Mamba already too big

                    for dprime in effective_dims_prime(dim_M):
                        for hprime in hdims_prime:
                            denom = 4 * hprime * dprime

                            # exact or approximate n'
                            if exact:
                                if target_linear % denom != 0:
                                    continue
                                nprime = target_linear // denom
                            else:
                                nprime_f = target_linear / denom
                                nprime = int(round(nprime_f))

                            if not (nprime_min <= nprime <= nprime_max) or nprime <= 0:
                                continue

                            P_other = 4 * nprime * hprime * dprime + P_m2
                            err = abs(P_other - P_yours) / P_yours
                            if exact or err <= tol_pct:
                                # --- closeness metric ---
                                # relative deltas for dims; log distance for n
                                rel_dim = abs(dprime - dim) / (dim + eps)
                                rel_h = abs(hprime - h_dim) / (h_dim + eps)
                                logn = abs(math.log((nprime + eps) / (n + eps)))
                                closeness = (rel_dim + rel_h + logn) / 3.0

                                # --- mamba smallness metric ---
                                mamba_frac = P_m2 / P_yours  # smaller is better

                                score = w_param * err + w_closeness * closeness + w_mamba * mamba_frac

                                candidates.append(dict(
                                    score=score,
                                    error_pct=err * 100.0,
                                    closeness=closeness,
                                    mamba_frac=mamba_frac,
                                    # chosen configs
                                    n_prime=int(nprime), h_dim_prime=int(hprime), dim_prime=int(dprime),
                                    mamba_dim=int(dim_M), mamba_expand=int(expand),
                                    mamba_d_state=int(d_state), mamba_d_conv=int(d_conv),
                                    # budgets
                                    P_yours=int(P_yours), P_other=int(P_other), P_mamba2=int(P_m2),
                                ))

    candidates.sort(key=lambda r: (r["score"], r["error_pct"], r["mamba_frac"], r["closeness"]))
    return candidates[:top_k]


In [8]:
# Your current layer
n, h_dim, dim = 12, 64, 768

# Prefer smaller Mamba; keep widths near yours; allow 0.5% param slack
cands = find_balanced_configs(
    n, h_dim, dim,
    # Encourage smaller Mamba:
    mamba_dims=(256, 320, 384, 448, 512, 576, 640),
    mamba_expands=(1, 2),
    max_mamba_dim=640,
    max_mamba_params_frac=0.35,   # cap Mamba at <= 35% of total params

    # Prefer (n', h', d') close to yours:
    couple_dim_prime_to="yours",  # keep dim' tied to your dim (optional)
    exact=False, tol_pct=0.005,   # allow ±0.5% budget slack
    w_param=1.0, w_closeness=1.5, w_mamba=0.7,  # bump closeness importance
)

# Inspect top few
for r in cands[:5]:
    print(
        f"P_err={r['error_pct']:.3f}%  close={r['closeness']:.4f}  "
        f"Mamba={r['mamba_frac']*100:.1f}%  |  "
        f"n'={r['n_prime']}, h'={r['h_dim_prime']}, d'={r['dim_prime']}  |  "
        f"Mamba(dim={r['mamba_dim']}, expand={r['mamba_expand']}, d_state={r['mamba_d_state']})"
    )


P_err=0.000%  close=0.0290  Mamba=8.3%  |  n'=11, h'=64, d'=768  |  Mamba(dim=256, expand=1, d_state=64)
P_err=0.000%  close=0.0290  Mamba=8.3%  |  n'=11, h'=64, d'=768  |  Mamba(dim=256, expand=1, d_state=128)
P_err=0.000%  close=0.0608  Mamba=16.7%  |  n'=10, h'=64, d'=768  |  Mamba(dim=256, expand=2, d_state=64)
P_err=0.000%  close=0.0608  Mamba=16.7%  |  n'=10, h'=64, d'=768  |  Mamba(dim=256, expand=2, d_state=128)
P_err=0.000%  close=0.1352  Mamba=33.3%  |  n'=8, h'=64, d'=768  |  Mamba(dim=512, expand=1, d_state=64)


In [2]:
T = target_params(n_p=12, hdim_p=128, dim_p=768, dim_M=768, expand=2, tight=False)
candidates = match_layer(T, tol_pct=0.005)  # within 0.5%

In [3]:
candidates

[(0.0, 7, 192, 1536, 8257536),
 (0.0, 7, 384, 768, 8257536),
 (0.0, 14, 96, 1536, 8257536),
 (0.0, 14, 192, 768, 8257536),
 (0.0, 14, 384, 384, 8257536),
 (0.0, 21, 64, 1536, 8257536),
 (0.0, 21, 96, 1024, 8257536),
 (0.0, 21, 128, 768, 8257536),
 (0.0, 21, 192, 512, 8257536),
 (0.0, 21, 256, 384, 8257536),
 (0.0, 21, 384, 256, 8257536),
 (0.0, 28, 96, 768, 8257536),
 (0.0, 28, 192, 384, 8257536),
 (0.0, 42, 64, 768, 8257536),
 (0.0, 42, 96, 512, 8257536),
 (0.0, 42, 128, 384, 8257536),
 (0.0, 42, 192, 256, 8257536),
 (0.0, 56, 96, 384, 8257536),
 (0.0, 63, 64, 512, 8257536),
 (0.0, 63, 128, 256, 8257536),
 (0.0, 84, 64, 384, 8257536),
 (0.0, 84, 96, 256, 8257536),
 (0.0, 126, 64, 256, 8257536)]

In [1]:
import torch
from mamba_ssm import Mamba2

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model = Mamba2(
    d_model=16,   # hidden size
    d_state=16,   # state size
    d_conv=4,     # convolution size
    expand=2      # expansion factor
)

# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Dummy input: batch of 2 sequences, each length 8, embedding dim = 16
x = torch.randn(2, 8, 16).to(device)

# Forward pass
y = model(x)

print("Torch:", torch.__version__, "CUDA:", torch.version.cuda)
print("CUDA available?", torch.cuda.is_available())
print("Torch path:", torch.__file__)
print("Input shape:", x.shape)
print("Output shape:", y.shape)

AssertionError: 