<a href="https://colab.research.google.com/github/jmbcmuphd/mamba_tf_maxDeCoder/blob/jmbcmuphd_google_colab/mamba-tf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
%%capture
!pip install tensorflow[and-cuda]==2.15.0.post1
!pip install transformers==4.36.2
!pip install einops==0.7.0
!pip install datasets==2.16.1
!pip install torch

In [4]:
import tensorflow_datasets as tfds
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers, Model

from dataclasses import dataclass
from einops import rearrange, repeat
from typing import Union

from transformers import AutoTokenizer

import datasets
import math
import numpy as np

from datasets import load_dataset
from tqdm import tqdm

  _torch_pytree._register_pytree_node(


In [5]:
# Only emit TensorFlow log messages at ERROR level.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Build a TF1-style session config and register the session with Keras.
# NOTE(review): this is the compat.v1 session API; confirm it actually limits
# memory for the TF2 runtime used by the rest of the notebook.
config = tf.compat.v1.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.95 # fraction of GPU memory this process may claim; tune as required
tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))

In [6]:
@dataclass
class ModelArgs:
    """Hyper-parameters and compile settings for the Mamba model.

    Derived attributes set in ``__post_init__``:
        model_internal_dim: ``projection_expand_factor * model_input_dims``.
        delta_t_rank: ``ceil(model_input_dims / 16)``.
        final_activation: ``'sigmoid'`` when ``num_classes == 1``,
            ``'softmax'`` for other classification heads, and left as ``None``
            when ``use_lm_head`` is set.

    Raises:
        ValueError: if ``vocab_size`` is missing, if ``num_classes`` is missing
            while ``use_lm_head`` is false, or if ``loss`` is missing.
    """
    model_input_dims: int = 64
    model_states: int = 64
    projection_expand_factor: int = 2
    conv_kernel_size: int = 4
    delta_t_min: float = 0.001
    delta_t_max: float = 0.1
    delta_t_scale: float = 0.1
    delta_t_init_floor: float = 1e-4
    conv_use_bias: bool = True
    dense_use_bias: bool = False
    layer_id: int = -1
    seq_length: int = 128
    num_layers: int = 5
    dropout_rate: float = 0.2
    use_lm_head: bool = False
    num_classes: int = None
    vocab_size: int = None
    # Un-annotated on purpose: a plain class attribute (not a constructor
    # argument); __post_init__ overrides it per instance.
    final_activation = None
    loss: Union[str, keras.losses.Loss] = None
    # None means "create a fresh AdamW for this instance" (see __post_init__).
    # A default optimizer *instance* here would be shared by every ModelArgs.
    optimizer: Union[str, keras.optimizers.Optimizer] = None
    # Un-annotated class attribute; copied per instance in __post_init__.
    metrics = ['accuracy']

    def __post_init__(self):
        self.model_internal_dim: int = int(self.projection_expand_factor * self.model_input_dims)

        self.delta_t_rank = math.ceil(self.model_input_dims/16)
        if self.layer_id == -1:
            # -1 is the "unset" sentinel: pick a random id in [0, 1000).
            self.layer_id = int(np.random.randint(0, 1000))

        if self.vocab_size is None:
            raise ValueError("vocab size cannot be none")

        if self.use_lm_head:
            # A language-model head predicts one class per vocabulary entry.
            self.num_classes=self.vocab_size
        else:
            if self.num_classes is None:
                raise ValueError(f'num classes cannot be {self.num_classes}')

            if self.num_classes == 1:
                self.final_activation = 'sigmoid'
            else:
                self.final_activation = 'softmax'

        if self.loss is None:
            raise ValueError(f"loss cannot be {self.loss}")

        if self.optimizer is None:
            # Optimizers hold per-model state, so each config gets its own.
            self.optimizer = keras.optimizers.AdamW()

        # Give each instance its own list so mutating one config's metrics
        # does not leak into the class-level default.
        self.metrics = list(self.metrics)

# Load the pretrained "bert-base-uncased" tokenizer; its vocabulary size is
# passed to ModelArgs further down to size the embedding table.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [7]:
# Not used by the model below (ResidualBlock uses LayerNormalization), but
# available as a drop-in alternative.
class RMSNorm(layers.Layer):
    """Root-mean-square layer normalisation over the last axis.

    Computes ``x / sqrt(mean(x**2, axis=-1) + eps) * weight``.

    Args:
        d_model: size of the last (feature) axis; length of the learned scale.
        eps: small constant added to the mean square for numerical stability.
    """

    def __init__(self, d_model: int, eps: float=1e-5, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.eps = eps
        # Learned per-feature scale, initialised to ones.
        self.weight = tf.Variable(np.ones(d_model), dtype=tf.float32, trainable=True)

    def call(self, x):
        # Keep the input intact: the mean square is only the normaliser.
        # (Previously `x` was overwritten by its own mean square, so the layer
        # returned a rescaled statistic instead of the normalised input.)
        mean_square = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(mean_square + self.eps) * self.weight

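$$
\begin{aligned}
\Delta_t &= \tau_\Delta\big(\text{Parameter} + s_\Delta(x_t)\big) \\
         &= \text{softplus}\big(\text{Parameter} + \text{Linear}(x_t)\big) \\
         &= \text{softplus}\big(\text{Linear}(x_t)\big)
\end{aligned}
$$

The last step holds because the bias of the linear layer absorbs the learned parameter.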

In [8]:
def selective_scan(u, delta, A, B, C, D):
    """Selective state-space scan evaluated with cumulative sums instead of a loop.

    Shapes, as fixed by the einsum subscripts below:
        u, delta: (b, l, d)
        A:        (d, n)
        B, C:     (b, l, n)
        D:        must broadcast against u (the caller passes a length-d vector).

    With dA[t] = delta[t] * A and dB_u[t] = delta[t] * u[t] * B[t], the code
    computes, for every position t,

        x[t] = sum_{s <= t} dB_u[s] * exp(sum_{j = s+1 .. t} dA[j])

    which is the unrolled form of x[t] = exp(dA[t]) * x[t-1] + dB_u[t], and
    returns ``einsum(x, C) + u * D`` with shape (b, l, d).

    NOTE(review): the sum is obtained by scaling with exp(suffix sum) and
    dividing it back out afterwards; for long sequences those factors can
    become very small -- check numerical behaviour before increasing seq_len.
    """
    dA = tf.einsum('bld,dn->bldn', delta, A) # first step of A_bar = exp(ΔA), i.e., ΔA
    dB_u = tf.einsum('bld,bld,bln->bldn', delta, u, B)

    # Shift dA one step to the left along the sequence axis and put a zero in
    # the last position: after this, entry t holds dA[t+1] (and 0 at t = l-1).
    dA_cumsum = tf.pad(
        dA[:, 1:], [[0, 0], [1, 1], [0, 0], [0, 0]])[:, 1:, :, :]

    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip along axis 1

    # Cumulative sum over the flipped sequence: once flipped back, entry t
    # holds the suffix sum dA[t+1] + ... + dA[l-1], for all tokens at once.
    dA_cumsum = tf.math.cumsum(dA_cumsum, axis=1)
    dA_cumsum = tf.exp(dA_cumsum)  # second step of A_bar = exp(ΔA), i.e., exp(ΔA)

    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip back along axis 1

    # Scale each contribution by its suffix factor, prefix-sum over the
    # sequence, then divide by the suffix factor at t so that only the decay
    # between s and t remains.
    x = dB_u * dA_cumsum
    x = tf.math.cumsum(x, axis=1)/(dA_cumsum + 1e-12) # 1e-12 to avoid division by 0

    # Contract the state axis n with C to get the per-channel output.
    y = tf.einsum('bldn,bln->bld', x, C)

    # D-weighted skip connection from the input.
    return y + u * D

In [9]:
class MambaBlock(layers.Layer):
    """Single Mamba mixer layer: input projection, causal depthwise conv, selective SSM, gated output.

    Maps a (batch, seq_len, model_input_dims) tensor to a tensor of the same
    shape. All sizes are read from the ``ModelArgs`` passed to the constructor.
    """
    def __init__(self, modelargs: ModelArgs, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args = modelargs
        args = modelargs

        # Projects the input to 2 * model_internal_dim features; `call` splits
        # the result into the SSM branch and the gating branch.
        self.in_projection = layers.Dense(
            args.model_internal_dim * 2,
            input_shape=(args.model_input_dims,),
            use_bias=False)

        # groups == filters == model_internal_dim, i.e. one filter per channel;
        # 'causal' padding so position t only sees positions <= t.
        self.conv1d = layers.Conv1D(
            filters=args.model_internal_dim,
            use_bias=args.conv_use_bias,
            kernel_size=args.conv_kernel_size,
            groups=args.model_internal_dim,
            data_format='channels_first',
            padding='causal'
        )

        # this layer takes in current token 'x' and outputs the input-specific Δ, B, C (according to S6)
        self.x_projection = layers.Dense(
            args.delta_t_rank + args.model_states * 2,
            use_bias=False)

        # this layer projects Δ from delta_t_rank to the mamba internal dimension
        self.delta_t_projection = layers.Dense(args.model_internal_dim,
                                               input_shape=(args.delta_t_rank,), use_bias=True)

        # A holds the values 1..model_states repeated for each of the
        # model_internal_dim channels; A_log is its element-wise log.
        # NOTE(review): both are created with trainable=False, so A is never
        # updated during training -- confirm this is intended (the reference
        # implementation is believed to learn A_log).
        self.A = tf.Variable(repeat(
                tf.range(1, args.model_states+1, dtype=tf.float32),
                'n -> d n', d=args.model_internal_dim), trainable=False, dtype=tf.float32)
        self.A_log = tf.Variable(tf.math.log(self.A), trainable=False, dtype=tf.float32)

        # Per-channel weight of the skip connection inside selective_scan, initialised to ones.
        self.D = tf.Variable(np.ones(args.model_internal_dim), dtype=tf.float32)

        # Projects the gated SSM output back to model_input_dims.
        self.out_projection = layers.Dense(
            args.model_input_dims,
            input_shape=(args.model_internal_dim,), use_bias=args.dense_use_bias)

    def call(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].

        Takes x of shape (batch, seq_len, model_input_dims) and returns a
        tensor of the same shape.

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """

        # Requires a rank-3 input; seq_len is reused below to crop the conv output.
        (batch_size, seq_len, dimension) = x.shape

        x_and_res = self.in_projection(x) # shape = (batch, seq_len, 2 * model_internal_dimension)
        # First half feeds the conv + SSM branch, second half is the gate.
        (x, res) = tf.split(x_and_res,
                            [self.args.model_internal_dim, self.args.model_internal_dim], axis=-1)

        # The conv layer is configured channels_first, so move channels to axis 1 and back.
        x = rearrange(x, 'b l d_in -> b d_in l')
        # NOTE(review): with 'causal' padding the output length should already
        # equal seq_len, which would make this crop a no-op -- confirm.
        x = self.conv1d(x)[:, :, :seq_len]
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = tf.nn.swish(x)
        y = self.ssm(x)
        y = y * tf.nn.swish(res) # right side of mamba block image
        return self.out_projection(y)

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Takes x of shape (batch, seq_len, model_internal_dim), derives the
        input-dependent Δ, B, C from it and returns the result of
        ``selective_scan`` (same shape as x).

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)

        # A = -exp(A_log) is strictly negative.
        A = -tf.exp(tf.cast(self.A_log, tf.float32)) # shape -> (d_in, n)
        D = tf.cast(self.D, tf.float32)

        x_dbl = self.x_projection(x) # shape -> (batch, seq_len, delta_t_rank + 2*n)


        (delta, B, C) = tf.split(
            x_dbl,
            num_or_size_splits=[self.args.delta_t_rank, n, n],
            axis=-1) # delta.shape -> (batch, seq_len, delta_t_rank) & B, C shape -> (batch, seq_len, n)

        # softplus keeps the step size Δ positive.
        delta = tf.nn.softplus(
            self.delta_t_projection( delta)) # shape -> (batch, seq_len, model_internal_dim)

        return selective_scan(x, delta, A, B, C, D)

In [10]:
class ResidualBlock(layers.Layer):
    """Pre-norm residual wrapper around one MambaBlock: ``mixer(norm(x)) + x``."""

    def __init__(self, modelargs: ModelArgs, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args = modelargs

        # Sub-layers are created in this order: mixer first, then the norm.
        self.mixer = MambaBlock(modelargs)
        self.norm = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x):
        """Normalise, mix, then add the untouched input back.

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

            Note: the official repo chains its residual blocks as
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            with the first Add being a no-op, purely so that Add->Norm can be fused.

            Here the blocks use the more familiar, simpler, and numerically
            equivalent ordering
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ...
        """
        normalized = self.norm(x)
        mixed = self.mixer(normalized)
        return mixed + x

In [12]:

def init_model(args: ModelArgs):
    """Build and compile the Keras model described by ``args``.

    Architecture: token-id input -> embedding -> ``num_layers`` x
    (ResidualBlock, Dropout) -> LayerNormalization -> [Flatten] ->
    Dense(1024, gelu) -> Dense(num_classes, final_activation).

    The Flatten step is applied only when the model is NOT used with a
    language-model head, i.e. when one prediction per sequence is wanted.

    Returns:
        The compiled ``Model`` named 'Mamba_ka_Mamba', using the loss,
        optimizer and metrics stored on ``args``.
    """
    token_ids = layers.Input(shape=(args.seq_length,), name='input_ids')
    embedding = layers.Embedding(
        args.vocab_size,
        args.model_input_dims,
        input_length=args.seq_length,
    )
    hidden = embedding(token_ids)

    # Stack of residual Mamba blocks, each followed by dropout.
    for block_index in range(args.num_layers):
        hidden = ResidualBlock(args, name=f"Residual_{block_index}")(hidden)
        hidden = layers.Dropout(args.dropout_rate)(hidden)

    hidden = layers.LayerNormalization(epsilon=1e-5)(hidden)

    # Collapse (seq_len, features) into one vector unless an LM head is used.
    if not args.use_lm_head:
        hidden = layers.Flatten()(hidden)
    hidden = layers.Dense(1024, activation=tf.nn.gelu)(hidden)
    predictions = layers.Dense(
        args.num_classes, activation=args.final_activation)(hidden)

    model = Model(inputs=token_ids, outputs=predictions, name='Mamba_ka_Mamba')
    model.compile(loss=args.loss, optimizer=args.optimizer, metrics=args.metrics)

    return model

In [13]:
# Fetch the "ajaykarthick/imdb-movie-reviews" dataset; later cells read its
# 'train' and 'test' splits and the 'review' / 'label' fields of each row.
dataset = load_dataset("ajaykarthick/imdb-movie-reviews")

Downloading readme:   0%|          | 0.00/1.54k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/53.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.3M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [14]:
# Configuration for a binary classifier: num_classes=1 makes ModelArgs pick a
# sigmoid output, paired here with binary cross-entropy. Fields not listed
# keep their ModelArgs defaults (e.g. seq_length).
args = ModelArgs(
    model_input_dims=128,
    model_states=32,
    num_layers=12,
    dropout_rate=0.2,
    vocab_size=vocab_size,
    num_classes=1,
    loss='binary_crossentropy',
)
# Build + compile the model, then print its layer table.
model = init_model(args)
model.summary()

Model: "Mamba_ka_Mamba"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_ids (InputLayer)      [(None, 128)]             0         
                                                                 
 embedding (Embedding)       (None, 128, 128)          3906816   
                                                                 
 Residual_0 (ResidualBlock)  (None, 128, 128)          137216    
                                                                 
 dropout (Dropout)           (None, 128, 128)          0         
                                                                 
 Residual_1 (ResidualBlock)  (None, 128, 128)          137216    
                                                                 
 dropout_1 (Dropout)         (None, 128, 128)          0         
                                                                 
 Residual_2 (ResidualBlock)  (None, 128, 128)       

In [16]:
def encode_split(split, desc=None):
    """Tokenize one dataset split into a fixed-width id matrix plus labels.

    Args:
        split: iterable of rows with 'review' (text) and 'label' fields that
            also supports ``len()``.
        desc: optional label for the progress bar.

    Returns:
        (ids, labels): ``ids`` is an int32 array of shape
        (len(split), args.seq_length); ``labels`` is a list of the row labels.
    """
    # Token ids are integers; store them as such instead of float64.
    ids = np.zeros((len(split), args.seq_length), dtype=np.int32)
    labels = []
    for row, item in enumerate(tqdm(split, desc=desc)):
        # truncation=True lets the tokenizer cut long reviews to max_length
        # itself (keeping the closing special token) instead of slicing the
        # over-long encoding afterwards.
        encoded = tokenizer.encode_plus(
            item['review'],
            max_length=args.seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='np',
        )
        ids[row, :] = encoded['input_ids'][0]
        labels.append(item['label'])
    return ids, labels


train_ids, train_labels = encode_split(dataset['train'], desc='train')
test_ids, test_labels = encode_split(dataset['test'], desc='test')
# The raw text is no longer needed once encoded; free the memory.
# (Re-running this cell therefore requires re-running the load cell first.)
del dataset


100%|██████████| 40000/40000 [00:52<00:00, 761.03it/s]
100%|██████████| 10000/10000 [00:11<00:00, 873.83it/s]


In [17]:
BATCH_SIZE = 32

# Shuffle individual examples *before* batching so every epoch sees different
# batch compositions. (Shuffling after .batch() only reorders fixed batches.)
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_ids, train_labels))
    .shuffle(len(train_ids))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation data does not need shuffling: the metrics are order-independent.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_ids, test_labels))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Train for 10 epochs with the test split as validation data; the returned
# History object is kept in `history`.
history = model.fit(train_dataset, validation_data=test_dataset, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10

In [None]:
def infer(text: str, model: "Model", tokenizer, max_length: int = None):
    """Run the model on a single piece of text and return its first output value.

    Args:
        text: the text to score.
        model: callable model taking a (1, max_length) array of token ids.
        tokenizer: object exposing ``encode(text, max_length=..., padding=...,
            truncation=..., return_tensors=...)``.
        max_length: number of token ids to pad/truncate to. Defaults to the
            notebook-level ``args.seq_length``.

    Returns:
        ``model(tokens)[0, 0]`` -- the first output of the first batch row.
    """
    if max_length is None:
        # Fall back to the sequence length the notebook's model was built with.
        max_length = args.seq_length

    # Encode the caller's text (previously a hard-coded string was used and
    # `text` was ignored); truncate so long inputs still fit the model.
    tokens = tokenizer.encode(
        text,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='np',
    )
    output = model(tokens)[0, 0]
    return output
