In [1]:
import torch, time
from mamba_ssm.models.mixer_seq_simple import create_block, _init_weights
from mamba_ssm.utils.generation import InferenceParams, update_graph_cache
from functools import wraps

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import torch.nn as nn
from torch.nn.modules.transformer import (Dropout, LayerNorm, Linear, Module, MultiheadAttention, Optional, Tensor,
                                          _get_activation_fn)
from torch.utils.checkpoint import checkpoint

from mamba_ssm.models.mixer_seq_simple import create_block, _init_weights
try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
from functools import partial
import numpy as np

In [3]:
class MambaLayer(Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                n_residuals_per_layer=1 if d_intermediate == 0 else 2,  # 2 if we have MLP
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, hidden_states, single_eval_position = None, *, inference_params=None, **mixer_kwargs):
        residual = None
        for layer in self.layers:
            # pdb.set_trace()
            # src_left, src_right = hidden_states[:single_eval_position], hidden_states[single_eval_position:]

            ## Need to be changed
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
            # hidden_states = torch.cat([src_left, src_right], dim=0)
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm)
            )
        return hidden_states
    

In [4]:
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")

In [5]:
max_length, max_batch_size = 65, batch
inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=max_batch_size)

In [6]:
model = MambaLayer(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    n_layer = 2,
    d_intermediate=dim,
).to("cuda")


In [7]:
# compute number of parameters
def get_num_params(model):
    return sum(p.numel() for p in model.parameters())

In [8]:
get_num_params(model)

19168

In [9]:
def mamba_inference(
    input_ids,
    model,
    max_length,
):
    batch_size, seqlen_og, _ = input_ids.shape
    if not hasattr(model, "_decoding_cache"):
        model._decoding_cache = None
    model._decoding_cache = update_graph_cache(
        model,
        model._decoding_cache,
        batch_size,
        seqlen_og,
        max_length,
    )
    inference_params = model._decoding_cache.inference_params
    inference_params.reset(max_length, batch_size)

    output = model(
        input_ids,
        inference_params=inference_params,
        num_last_tokens=1,
    )

    return output

In [10]:
# utils
def average_time(model, runs=10):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            time_runs = []
            for _ in range(runs):
                start = time.time()
                result = func(*args, **kwargs)
                end = time.time()
                time_runs.append(end - start)
                print(time_runs[-1])
            print(f"| {model} |: {np.mean(time_runs):.4f}({np.std(time_runs):.4f})")
            return result
        return wrapper
    return decorator

In [11]:
y = model(x) # warmup the cuda

In [12]:
@average_time(model='Mamba-I')
def mamba_test(x, cache = True):
    if cache:
        return mamba_inference(x, model, x.shape[1])
    else:
        return model(x)

In [13]:
batch, length, dim = 20, 100000, 16
x = torch.randn(batch, length, dim).to("cuda")
print('with cache')
mamba_test(x, cache = True)

print('without cache')
mamba_test(x, cache = False)
None

with cache


RuntimeError: Given normalized_shape=[16], expected input with shape [*, 16], but got input of size[20, 1]

In [None]:
batch, length, dim = 20, 10000, 16
x = torch.randn(batch, length, dim).to("cuda")

In [None]:
start = time.time()
model(x, inference_params = model._decoding_cache.inference_params)
end = time.time()
print(end - start)

0.0054168701171875


In [None]:
start = time.time()
model(x, inference_params = None)
end = time.time()
print(end - start)

0.007906675338745117


In [None]:
mamba_inference(x, model, 0).shape

torch.Size([20, 10000, 16])

In [None]:
mamba_inference(x, model, 1).shape

torch.Size([20, 10000, 16])

In [None]:
model._decoding_cache.inference_params.shape

AttributeError: 'InferenceParams' object has no attribute 'shape'