In [3]:
# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell executes, so edits to local .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import copy
import os
from pathlib import Path

import torch
from transformers.models.mamba2.modeling_mamba2 import Mamba2Block, Mamba2RMSNorm


In [5]:
from transformers import TimeSeriesTransformerConfig, TimeSeriesTransformerModel
from mamba_ssm import Mamba2

In [6]:
# Forecast horizon in time steps (60 * 5 = 300).
PREDICTION_LENGTH = 60 * 5

# Initializing a Time Series Transformer configuration with a 300-step prediction horizon
configuration = TimeSeriesTransformerConfig(prediction_length=PREDICTION_LENGTH)

# Randomly initializing a model (with random weights) from the configuration
tst_model = TimeSeriesTransformerModel(configuration)

# Accessing the model configuration (rebinds `configuration` to the config held by the model)
configuration = tst_model.config

In [7]:
# Display the module tree of the randomly initialised Time Series Transformer.
tst_model

TimeSeriesTransformerModel(
  (scaler): TimeSeriesMeanScaler()
  (encoder): TimeSeriesTransformerEncoder(
    (value_embedding): TimeSeriesValueEmbedding(
      (value_projection): Linear(in_features=9, out_features=64, bias=False)
    )
    (embed_positions): TimeSeriesSinusoidalPositionalEmbedding(600, 64)
    (layers): ModuleList(
      (0): TimeSeriesTransformerEncoderLayer(
        (self_attn): TimeSeriesTransformerAttention(
          (k_proj): Linear(in_features=64, out_features=64, bias=True)
          (v_proj): Linear(in_features=64, out_features=64, bias=True)
          (q_proj): Linear(in_features=64, out_features=64, bias=True)
          (out_proj): Linear(in_features=64, out_features=64, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=64, out_features=32, bias=True)
        (fc2): Linear(in_features=32, out_features=64, bias=True)
        (f

In [8]:
# Build a single Mamba2 block, then move it to the GPU as a separate step.
# This module uses roughly 3 * expand * d_model^2 parameters
mamba2_model = Mamba2(
    d_model=128,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
    headdim=32,   # Head dimension, typically 64 or 128
)
mamba2_model = mamba2_model.to("cuda")

In [9]:
# Show the submodules of the Mamba2 block built above.
mamba2_model


Mamba2(
  (in_proj): Linear(in_features=128, out_features=648, bias=False)
  (conv1d): Conv1d(384, 384, kernel_size=(4,), stride=(1,), padding=(3,), groups=384)
  (act): SiLU()
  (norm): RMSNorm()
  (out_proj): Linear(in_features=256, out_features=128, bias=False)
)

In [10]:
# Random input of shape (batch, length, dim); dim is 128, the d_model the block was built with.
batch = 2
length = 256
dim = 128
x = torch.randn((batch, length, dim)).to("cuda")

# Run the block once on the dummy batch.
y = mamba2_model(x)

In [11]:
# The Mamba2 output keeps the input shape: (batch, length, dim) = (2, 256, 128).
y.shape

torch.Size([2, 256, 128])

In [12]:
# Inspect the first encoder layer of the Time Series Transformer.
tst_model.encoder.layers[0]

TimeSeriesTransformerEncoderLayer(
  (self_attn): TimeSeriesTransformerAttention(
    (k_proj): Linear(in_features=64, out_features=64, bias=True)
    (v_proj): Linear(in_features=64, out_features=64, bias=True)
    (q_proj): Linear(in_features=64, out_features=64, bias=True)
    (out_proj): Linear(in_features=64, out_features=64, bias=True)
  )
  (self_attn_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (activation_fn): GELUActivation()
  (fc1): Linear(in_features=64, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=64, bias=True)
  (final_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)

In [13]:
# Move a deep copy of the first encoder layer's self-attention to the GPU.
# nn.Module.to() modifies the module in place, so calling it directly on the
# layer would leave tst_model with one submodule on the GPU and the rest on
# the CPU, breaking any later forward pass through tst_model.
tst_attn = copy.deepcopy(tst_model.encoder.layers[0].self_attn).to("cuda")

In [14]:
# Random input with a trailing dimension of 64, matching the in_features of
# the attention projections shown above.
x = torch.randn((1, 256, 64)).to("cuda")

# Run the attention module once; the next cell indexes element 0 of the result.
y = tst_attn(x)

In [15]:
# tst_attn returns a sequence; its first element has the same shape as the input, (1, 256, 64).
y[0].shape

torch.Size([1, 256, 64])

In [16]:
from transformers import Mamba2Model, Mamba2ForCausalLM, Mamba2PreTrainedModel
from transformers.models.mamba2.configuration_mamba2 import Mamba2Config
import torch.nn as nn

In [17]:
# Default Mamba2 configuration, with no overrides.
m2_config = Mamba2Config()

# Randomly initialised causal-LM built from that configuration.
# NOTE(review): the repr below shows Embedding(32768, 4096), so the defaults
# describe a large model; instantiation is likely slow and memory-heavy — confirm
# this is intended, or pass smaller sizes to the config.
m2_llm = Mamba2ForCausalLM(config=m2_config)

In [18]:
# Display the module tree of the causal-LM wrapper (a backbone made of repeated Mamba2Block modules).
m2_llm

Mamba2ForCausalLM(
  (backbone): Mamba2Model(
    (embeddings): Embedding(32768, 4096)
    (layers): ModuleList(
      (0): Mamba2Block(
        (norm): Mamba2RMSNorm()
        (mixer): Mamba2Mixer(
          (act): SiLU()
          (conv1d): Conv1d(10240, 10240, kernel_size=(4,), stride=(1,), padding=(3,), groups=10240)
          (in_proj): Linear(in_features=4096, out_features=18560, bias=False)
          (norm): MambaRMSNormGated()
          (out_proj): Linear(in_features=8192, out_features=4096, bias=False)
        )
      )
      (1): Mamba2Block(
        (norm): Mamba2RMSNorm()
        (mixer): Mamba2Mixer(
          (act): SiLU()
          (conv1d): Conv1d(10240, 10240, kernel_size=(4,), stride=(1,), padding=(3,), groups=10240)
          (in_proj): Linear(in_features=4096, out_features=18560, bias=False)
          (norm): MambaRMSNormGated()
          (out_proj): Linear(in_features=8192, out_features=4096, bias=False)
        )
      )
      (2): Mamba2Block(
        (norm): Mam

In [35]:
# Keep a handle on the Mamba2Model backbone held by the causal-LM wrapper.
m2_backbone = m2_llm.backbone

In [None]:
class MyMamba2Config(Mamba2Config):
    """
    `Mamba2Config` extended with one extra field, `symbol_size`.

    `symbol_size` is stored on the config, and every remaining positional and keyword argument is forwarded
    unchanged to the base-class initializer. The entries after `symbol_size` below describe the options handled
    by the base class.

    Note: `symbol_size` is declared before `*args`, so the first positional argument always binds to
    `symbol_size`. Pass base-class options by keyword to avoid shifting them.

    Args:
        symbol_size (`int`, *optional*, defaults to 32768):
            Number of entries in the symbol embedding table. `MyMamba2Model` uses it as the first argument of
            `nn.Embedding(config.symbol_size, config.hidden_size)`.
        num_heads (`int`, *optional*, defaults to 128):
            Number of heads for the evolution matrices of mamba 2.
        head_dim (`int`, *optional*, defaults to 64):
            Dimension of each head.
        vocab_size (`int`, *optional*, defaults to 32768):
            Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Mamba2Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the embeddings and hidden states.
        state_size (`int`, *optional*, defaults to 128): shape of the state space latents.
        num_hidden_layers (`int`, *optional*, defaults to 64):
            Number of hidden layers in the model.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon to use in the layer normalization layers.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            The id of the beginning of sentence token in the vocabulary.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the end of sentence token in the vocabulary.
        expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
        conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
        n_groups (`int`, *optional*, defaults to 8):
            Number of groups for the evolution matrices of mamba 2.
        use_bias (`bool`, *optional*, defaults to `False`):
            Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
        use_conv_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias in the convolution layer of the mixer block.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        initializer_range (`float`, *optional*, defaults to 0.1):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        residual_in_fp32 (`bool`, *optional*, defaults to `True`):
            Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
        time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
            Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
        time_step_min (`float`, *optional*, defaults to 0.001):
            Minimum `time_step` used to bound `dt_proj.bias`.
        time_step_max (`float`, *optional*, defaults to 0.1):
            Maximum `time_step` used to bound `dt_proj.bias`.
        time_step_floor (`float`, *optional*, defaults to 0.0001):
            Minimum clamping value of the `dt_proj.bias` layer initialization.
        time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`):
            Accepted range of time step values.
        rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
            Whether or not to rescale `out_proj` weights when initializing.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the cache should be used.
        rms_norm (`bool`, *optional*, defaults to `True`):
            Whether to use RMS norm or not.
        chunk_size (`int`, *optional*, defaults to 256):
            Size of the chunks that will comprise the sequence.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie word embeddings or not.
    """
    model_type = "mamba2"
    def __init__(
        self,
        symbol_size: int = 32768,
        *args,
        **kwargs
    ):
        # Stored before the base initializer runs; everything else is forwarded untouched.
        self.symbol_size = symbol_size
        super().__init__(*args, **kwargs)

In [36]:
class MyMamba2Model(Mamba2PreTrainedModel):
    """
    Mamba2 backbone skeleton whose input embedding table is sized by `config.symbol_size`.

    The constructor builds an embedding (`symbol_emb`), a stack of `config.num_hidden_layers` `Mamba2Block`
    modules (`layers`) and a final `Mamba2RMSNorm` (`norm_f`). No `forward` is defined on this class.

    Args:
        config (`MyMamba2Config`):
            Configuration providing `symbol_size`, `hidden_size`, `num_hidden_layers` and `layer_norm_epsilon`,
            and passed on to each `Mamba2Block`.
    """

    def __init__(self, config: MyMamba2Config):
        # The base initializer must receive the config; calling it with no arguments raises a TypeError.
        super().__init__(config)

        # Embedding table indexed by symbol id (symbol_size rows, hidden_size columns).
        self.symbol_emb = nn.Embedding(config.symbol_size, config.hidden_size)

        self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False
        self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # This class defines no `load_hook`, so registering `self.load_hook` unconditionally would raise
        # AttributeError unless a parent class supplies one. Register it only when it exists.
        load_hook = getattr(self, "load_hook", None)
        if load_hook is not None:
            self._register_load_state_dict_pre_hook(load_hook)

        # Initialize weights and apply final processing
        self.post_init()

Mamba2Model(
  (embeddings): Embedding(32768, 4096)
  (layers): ModuleList(
    (0): Mamba2Block(
      (norm): Mamba2RMSNorm()
      (mixer): Mamba2Mixer(
        (act): SiLU()
        (conv1d): Conv1d(10240, 10240, kernel_size=(4,), stride=(1,), padding=(3,), groups=10240)
        (in_proj): Linear(in_features=4096, out_features=18560, bias=False)
        (norm): MambaRMSNormGated()
        (out_proj): Linear(in_features=8192, out_features=4096, bias=False)
      )
    )
    (1): Mamba2Block(
      (norm): Mamba2RMSNorm()
      (mixer): Mamba2Mixer(
        (act): SiLU()
        (conv1d): Conv1d(10240, 10240, kernel_size=(4,), stride=(1,), padding=(3,), groups=10240)
        (in_proj): Linear(in_features=4096, out_features=18560, bias=False)
        (norm): MambaRMSNormGated()
        (out_proj): Linear(in_features=8192, out_features=4096, bias=False)
      )
    )
    (2): Mamba2Block(
      (norm): Mamba2RMSNorm()
      (mixer): Mamba2Mixer(
        (act): SiLU()
        (conv1d): 