In [1]:
# Code adapted from: https://github.com/johnma2006/mamba-minimal

In [2]:
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat
from transformers import AutoTokenizer
from utils.mamba import generate, from_pretrained, ssm, RMSNorm

  from .autonotebook import tqdm as notebook_tqdm


# Model Definition
We build up MAMBA from the inside out: first the innermost `MAMBA Block`, then the surrounding `Residual Block`, and finally the full `MAMBA Model`.

### Model Arguments

| Argument       | Meaning/Definition                                  | Notation in Mamba Paper                  |
|----------------|-----------------------------------------------------|------------------------------------------|
| `b`            | Batch size                                          | `B` in Algorithm 2                       |
| `l`            | Sequence length                                     | `L` in Algorithm 2                       |
| `d`, `d_model` | Hidden dimension                                    |                                          |
| `n`, `d_state` | Latent state dimension                              | `N` in Algorithm 2                       |
| `expand`       | Expansion factor                                    | `E` in Section 3.4                       |
| `d_in`, `d_inner` | `d * expand` (expanded hidden dimension)         | `D` in Algorithm 2                       |
| `A`, `B`, `C`, `D` | State space parameters                          | See state space formulas. `B`, `C` are input-dependent (selective); `A`, `D` are not |
| `Δ`, `delta`   | Input-dependent step size                           |                                          |
| `dt_rank`      | Rank of `Δ`                                         | See Section 3.6: "Parameterization of ∆" |


In [3]:
@dataclass
class ModelArgs:
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4 
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False
    
    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)
        
        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)
            
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)

### MAMBA Block

In [4]:
class MambaBlock(nn.Module):
    """One Mamba mixer layer (Figure 3, Section 3.4 of the Mamba paper)."""

    def __init__(self, args: ModelArgs):
        """Create the projections, the depthwise convolution and the SSM parameters.

        Args:
            args: model hyperparameters; reads d_model, d_inner, d_conv, d_state,
                dt_rank, bias and conv_bias.
        """
        super().__init__()
        self.args = args
        d_inner = args.d_inner

        # One linear layer produces both halves: the SSM branch and the gate branch.
        self.in_proj = nn.Linear(args.d_model, 2 * d_inner, bias=args.bias)

        # Depthwise convolution (groups == channels). Padding by d_conv - 1 and
        # keeping only the first `l` outputs in forward() makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=args.d_conv,
            groups=d_inner,
            padding=args.d_conv - 1,
            bias=args.conv_bias,
        )

        # Maps the conv output to the input-dependent Δ (dt_rank values), B and C
        # (d_state values each).
        self.x_proj = nn.Linear(d_inner, args.dt_rank + 2 * args.d_state, bias=False)

        # Lifts Δ from dt_rank up to d_in.
        self.dt_proj = nn.Linear(args.dt_rank, d_inner, bias=True)

        # A starts as the row 1..d_state repeated for each of the d_in channels
        # and is stored as its logarithm; D starts as all ones.
        state_index = torch.arange(1, args.d_state + 1)
        A = repeat(state_index, 'n -> d n', d=d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(d_inner))

        self.out_proj = nn.Linear(d_inner, args.d_model, bias=args.bias)

    def forward(self, x):
        """Run the block on a batch of sequences.

        Args:
            x: shape (b, l, d)    (b, l, d, d_in are defined in the Model Arguments table)

        Returns:
            Tensor of shape (b, l, d).

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        # Unpacking also enforces that the input is exactly 3-dimensional.
        _, seq_len, _ = x.shape
        d_inner = self.args.d_inner

        projected = self.in_proj(x)  # (b, l, 2 * d_in)
        x, res = projected.split(split_size=[d_inner, d_inner], dim=-1)

        # Conv1d wants channels before length; drop the trailing positions that
        # only exist because of the padding.
        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)
        x = x[:, :, :seq_len]
        x = rearrange(x, 'b d_in l -> b l d_in')
        x = F.silu(x)

        # State-space step, implemented in utils.mamba.ssm
        # (presumably returns (b, l, d_in) -- confirm against that helper).
        y = ssm(x, self.A_log, self.D, self.x_proj, self.args.dt_rank, self.dt_proj)

        # Gate the SSM output with the second half of the input projection.
        gated = y * F.silu(res)

        return self.out_proj(gated)

### Residual Block

In [5]:
class ResidualBlock(nn.Module):
    """A MambaBlock preceded by RMSNorm and wrapped in a skip connection."""

    def __init__(self, args: ModelArgs):
        """Build the mixer and its normalization layer from `args`."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        """Apply Norm -> Mamba -> Add.

        Args:
            x: shape (b, l, d)    (b, l, d are defined in the Model Arguments table)

        Returns:
            Tensor of shape (b, l, d).

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

            Note: the official repo orders each block as
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            with the very first Add being a no-op, so that Add and Norm can be
            fused for speed.

            Here the blocks use the more familiar, simpler and numerically
            equivalent ordering
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ...
        """
        normed = self.norm(x)
        mixed = self.mixer(normed)

        return mixed + x

### MAMBA Model

In [6]:
class Mamba(nn.Module):
    """Full Mamba language model: embedding, residual blocks, final norm, LM head."""

    def __init__(self, args: ModelArgs):
        """Build every layer of the model from `args`."""
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)

        self.layers = nn.ModuleList()
        for _ in range(args.n_layer):
            self.layers.append(ResidualBlock(args))

        self.norm_f = RMSNorm(args.d_model)

        # Weight tying: the output projection shares its weight matrix with the
        # embedding (see the "Weight Tying" paper).
        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids):
        """Compute next-token logits for every position.

        Args:
            input_ids (long tensor): shape (b, l)    (b, l are defined in the Model Arguments table)

        Returns:
            logits: shape (b, l, vocab_size)

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
        """
        hidden = self.embedding(input_ids)

        for block in self.layers:
            hidden = block(hidden)

        return self.lm_head(self.norm_f(hidden))

# Model Usage

### Loading the Model and Tokenizer

In [7]:
# Fetch the pretrained hyperparameters and weights, rebuild the model, load the weights
args, state_dict = from_pretrained('state-spaces/mamba-370m')
model = Mamba(ModelArgs(**args))
model.load_state_dict(state_dict)

# Show the resolved model arguments (the instance the model was built with)
# followed by the module tree
display(model.args, model)

ModelArgs(d_model=1024, n_layer=48, vocab_size=50280, d_state=16, expand=2, dt_rank=64, d_conv=4, pad_vocab_size_multiple=8, conv_bias=True, bias=False)

Mamba(
  (embedding): Embedding(50280, 1024)
  (layers): ModuleList(
    (0-47): 48 x ResidualBlock(
      (mixer): MambaBlock(
        (in_proj): Linear(in_features=1024, out_features=4096, bias=False)
        (conv1d): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)
        (x_proj): Linear(in_features=2048, out_features=96, bias=False)
        (dt_proj): Linear(in_features=64, out_features=2048, bias=True)
        (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
      )
      (norm): RMSNorm()
    )
  )
  (norm_f): RMSNorm()
  (lm_head): Linear(in_features=1024, out_features=50280, bias=False)
)

In [8]:
# Load the tokenizer used below to encode the prompt and decode generated tokens
TOKENIZER_NAME = 'EleutherAI/gpt-neox-20b'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

### Generating Output

In [9]:
# Encode the prompt as a PyTorch tensor of token ids
prompt = "Mamba is the"
input_tokens = tokenizer(prompt, return_tensors='pt').input_ids

# Third positional argument of generate(); presumably the number of tokens to
# produce -- confirm against utils.mamba.generate
n_tokens = 50

# Stream the text: decode and print each token as soon as it is yielded
for token in generate(model, input_tokens, n_tokens):
    print(tokenizer.decode(token), end='', flush=True)

Mamba is the second album to be released by the band. This album features the addition of former members of Black Keys – Jason Isbell and Scott Shamblin, and new drummer, James McBain. They recorded most of the album before being joined by Is