In [1]:
import torch, os
from mamba_ssm import Mamba, Mamba2
from time import time
from functools import wraps
import numpy as np

os.environ["CUDA_VISIBLE_DEVICES"]="3"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(torch.__version__)

2.1.2


In [3]:
# utils

def average_time(model, runs=10):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            time_runs = []
            for _ in range(runs):
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
                result = func(*args, **kwargs)
                end.record()
                torch.cuda.synchronize()
                time_runs.append(start.elapsed_time(end))
            print(f"| {model} |: {np.mean(time_runs):.4f}({np.std(time_runs):.4f})")
            return result
        return wrapper
    return decorator

batch, length, dim = 20, 5, 16
x = torch.randn(batch, length, dim).to("cuda")

In [4]:
# Mamba1 Block

mamba = Mamba(
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=8,    # Block expansion factor
).to("cuda")

@average_time(model='Mamba-I')
def mamba_test(x):
    return mamba(x)

mamba(x) # warmup cuda
None

In [5]:
# Mamba2 Blocfrom mamba_ssm import Mamba2
mamba2 = Mamba2(
    d_model=dim, # Model dimension d_model
    d_state=128,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=8,    # Block expansion factor
    headdim=16, # Head dimension
).to("cuda")

@average_time(model='Mamba-II')
def mamba2_test(x):
    return mamba2(x)

mamba2(x) # warmup cuda
None

In [6]:
# Transformer

transformer = torch.nn.TransformerEncoderLayer(
    d_model = dim,
    nhead = 16, 
    dim_feedforward=300,
).to("cuda")

@average_time(model='Transformer')
def transformer_test(x):
    return transformer(x)

transformer(x) # warmup cuda
None

In [7]:

# compute number of parameters
def get_num_params(model):
    return sum(p.numel() for p in model.parameters())

print('Model Size')
print("|   Mamba-I   |: ", get_num_params(mamba))
print("|  Mamba-II   |: ", get_num_params(mamba2))
print("| Transformer |: ", get_num_params(transformer))

# data 

batch, length, dim = 20, 10**3, 16
x = torch.randn(batch, length, dim).to("cuda")

print('Inference Time')
# run tests
mamba_test(x)
mamba2_test(x)
transformer_test(x)
None

Model Size
|   Mamba-I   |:  13440
|  Mamba-II   |:  12440
| Transformer |:  11068
Inference Time
| Mamba-I |: 0.6506(1.1563)
| Mamba-II |: 40.8071(119.8397)
| Transformer |: 0.6711(0.0018)


In [8]:
a = x.clone()
a[:,-1,:] = torch.randn_like(x[:,-1,:])

mamba_test(a)
mamba2_test(a)
transformer_test(a)

None


| Mamba-I |: 0.5230(0.1820)
| Mamba-II |: 1.1991(0.0884)
| Transformer |: 0.6985(0.0922)
