<a href="https://colab.research.google.com/github/kaballas/AutoGPT/blob/master/Refactored_SimplerMambaSSM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Pin to the versions this notebook was run with (see the install log below): later cells
# import mamba_ssm.modules.mamba_simple.Block and mamba_ssm.ops.triton.layernorm, and newer
# mamba-ssm releases reorganized those modules. %pip targets the running kernel's environment.
%pip install mamba-ssm==1.1.1 causal-conv1d==1.1.1 torch

Collecting mamba-ssm
  Downloading mamba_ssm-1.1.1.tar.gz (34 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting causal-conv1d
  Downloading causal_conv1d-1.1.1.tar.gz (6.6 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ninja (from mamba-ssm)
  Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m307.2/307.2 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops (from mamba-ssm)
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
Collecting buildtools (from causal-conv1d)
  Downloading buildtools-1.0.6.tar.gz (446 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m446.5/446.5 kB[0m [31m41.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting argpars

In [None]:
# -nc (no-clobber): re-running the cell reuses input.txt instead of saving input.txt.1, ...
!wget -nc https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-01-12 22:44:59--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-01-12 22:45:00 (141 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [None]:
# -p makes the cell safe to re-run (no error if the directory already exists).
# NOTE(review): nothing visible in this notebook writes to this directory.
!mkdir -p differentattention

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from mamba_ssm import Mamba
import time

# Hyperparameters and device configuration shared by every cell below.
# NOTE(review): max_iters, eval_iters and eval_interval are not read anywhere in this notebook.
config = dict(
    device="cuda" if torch.cuda.is_available() else "cpu",
    train_data_fraction=0.99,  # fraction of windows used for training
    epochs=1,
    lr=1e-3,
    batch_size=64,
    block_size=256,            # context length: tokens per training window
    max_iters=10000,
    print_iters=100,           # train_model validates and logs every print_iters steps
    eval_iters=10,
    eval_interval=300,
    n_embed=384,               # model / embedding width
    n_heads=6,                 # attention heads (ignored by the Mamba block)
    n_layers=6,
    dropout=0.2,
)


# Model definition
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network used in each block:
    Linear -> ReLU -> Linear -> Dropout.

    Args:
        embed_dim: Size of the input and output feature dimension.
        expansion_factor: Width multiplier for the hidden layer.
        dropout: Dropout probability on the output. ``None`` (default) falls back
            to the notebook-level ``config["dropout"]``, read at construction time.
    """
    def __init__(self, embed_dim, expansion_factor=4, dropout=None):
        super().__init__()
        if dropout is None:
            dropout = config["dropout"]
        hidden_dim = expansion_factor * embed_dim
        # Keep the nn.Sequential layout (ffn.0 / ffn.2) so state_dict keys stay unchanged.
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.ffn(x)


class Block(nn.Module):
    """
    Pre-norm residual block: a Mamba mixer followed by a feed-forward network.

    ``num_heads`` is accepted only so the constructor matches the attention
    variant of this block; the Mamba mixer does not use it.
    """
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Named `attention` so this block is a drop-in swap for an attention block;
        # it holds a Mamba mixer (roughly 3 * expand * d_model^2 parameters).
        self.attention = Mamba(
            d_model=embed_dim,  # model dimension
            d_state=16,         # SSM state expansion factor
            d_conv=4,           # local convolution width
            expand=2,           # block expansion factor
        )
        self.feed_forward = FeedForward(embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x + self.attention(self.norm1(x))
        return x + self.feed_forward(self.norm2(x))


class BigramNeuralNetwork(nn.Module):
    """
    Character-level language model: token + learned position embeddings, a stack
    of residual blocks, and a linear head over the vocabulary.

    Despite the name, the model conditions on up to ``block_size`` previous tokens.

    Args:
        vocab_size: Number of distinct tokens.
        embed_dim: Embedding / hidden size.
        num_heads: Forwarded to each block constructor.
        num_layers: Number of stacked blocks.
        block_size: Maximum context length (size of the position table).
            ``None`` (default) uses ``config["block_size"]``.
        block_cls: Block class, called as ``block_cls(embed_dim, num_heads)``.
            ``None`` (default) uses whatever ``Block`` is bound when the model is
            constructed; ``Block`` is redefined more than once in this notebook.
    """
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, block_size=None, block_cls=None):
        super().__init__()
        self.block_size = config["block_size"] if block_size is None else block_size
        if block_cls is None:
            block_cls = Block
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(self.block_size, embed_dim)
        self.blocks = nn.Sequential(*[block_cls(embed_dim, num_heads) for _ in range(num_layers)])
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, idx, targets=None):
        """
        Args:
            idx: Integer token ids of shape (batch, seq) with seq <= block_size.
            targets: Optional next-token ids with the same shape as ``idx``.

        Returns:
            ``(logits, loss)``: logits of shape (batch, seq, vocab_size) and the mean
            cross-entropy, or ``None`` for loss when ``targets`` is not given.
        """
        positions = torch.arange(idx.size(1), device=idx.device)
        x = self.token_embedding(idx) + self.position_embedding(positions)
        x = self.blocks(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted as well.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """
        Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Runs without autograd (no graph is built while sampling). The model's
        train/eval mode is left untouched; call ``.eval()`` first to disable dropout.
        """
        for _ in range(max_new_tokens):
            # The position table covers only block_size positions, so crop the context.
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            next_token_probs = F.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(next_token_probs, 1)
            idx = torch.cat([idx, next_token], dim=-1)
        return idx


class TextDataset(Dataset):
    """
    Character-level dataset over a single text file.

    Item ``i`` is a pair ``(x, y)`` of ``block_size`` token ids, where ``y`` is ``x``
    shifted one character to the right (next-character targets). Consecutive items
    start one character apart, so windows overlap.

    Attributes:
        chars: Sorted list of distinct characters (the vocabulary).
        stoi / itos: Character <-> id lookup tables.
        data: 1-D long tensor with the whole file encoded as ids.
    """
    def __init__(self, file_path, block_size):
        # Explicit UTF-8 so the vocabulary does not depend on the platform's locale.
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()
        self.chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}
        self.data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        self.block_size = block_size

    def __len__(self):
        # Each item needs block_size + 1 characters. Clamp at 0 so a short file yields
        # an empty dataset instead of a negative length, which len() rejects.
        return max(len(self.data) - self.block_size, 0)

    def __getitem__(self, idx):
        chunk = self.data[idx:idx + self.block_size + 1]
        return chunk[:-1], chunk[1:]


# Load data.
# Windows overlap (item i and item i+1 share block_size - 1 characters), so a random
# split would place near-copies of every validation window in the training set.
# Split contiguously instead, and skip block_size windows between the two parts so no
# validation window shares any character with a training window.
dataset = TextDataset(file_path="input.txt",
                      block_size=config["block_size"])
train_size = int(config["train_data_fraction"] * len(dataset))
val_start = min(train_size + config["block_size"], len(dataset))
config["vocab_size"] = len(dataset.chars)
print("Train size:", train_size)
print("Val size:", len(dataset) - val_start)
print("Vocab size:", config["vocab_size"])

train_dataset = torch.utils.data.Subset(dataset, range(train_size))
val_dataset = torch.utils.data.Subset(dataset, range(val_start, len(dataset)))
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
# Validation averages over every batch, so there is no need to shuffle it.
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)


# Build the Mamba-backed language model on the configured device, then its optimizer.
mamba_model = BigramNeuralNetwork(
    num_layers=config["n_layers"],
    num_heads=config["n_heads"],
    embed_dim=config["n_embed"],
    vocab_size=config["vocab_size"],
)
mamba_model = mamba_model.to(config["device"])
optimizer = torch.optim.AdamW(params=mamba_model.parameters(), lr=config["lr"])


def _evaluate(model, val_loader, device):
    """Return the mean per-batch loss over ``val_loader`` (NaN if it yields no batches).

    Runs in eval mode under ``torch.no_grad()`` and switches the model back to train mode.
    """
    model.eval()
    total_val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for x_val, y_val in val_loader:
            _, val_loss = model(x_val.to(device), y_val.to(device))
            total_val_loss += val_loss.item()
            num_batches += 1
    model.train()
    return total_val_loss / num_batches if num_batches else float("nan")


def train_model(model, train_loader, val_loader, optimizer, config, early_stop=None):
    """
    Train ``model`` and periodically report the full validation loss.

    ``model(x, y)`` must return a 2-tuple whose second element is the scalar loss.

    Args:
        model: Module to train; batches are moved to ``config["device"]``.
        train_loader: Yields ``(x, y)`` training batches.
        val_loader: Yields ``(x, y)`` validation batches; evaluated in full every
            ``config["print_iters"]`` steps.
        optimizer: Optimizer over ``model``'s parameters.
        config: Must provide ``"epochs"``, ``"device"`` and ``"print_iters"``.
        early_stop: If given, training stops entirely (all epochs) right after the
            step whose 0-based index equals ``early_stop``.
    """
    device = config["device"]
    print("\nStarting train")
    # Start the clock once so the reported duration covers every epoch (and is
    # defined even when config["epochs"] == 0).
    start_time = time.time()
    stopped = False
    for epoch in range(config["epochs"]):
        model.train()

        for i, (x_batch, y_batch) in enumerate(train_loader):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()  # Clear gradients from the previous step
            _, train_loss = model(x_batch, y_batch)  # Forward pass and loss computation
            train_loss.backward()  # Backpropagation
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
            optimizer.step()  # Update model parameters

            if i % config["print_iters"] == 0:
                average_val_loss = _evaluate(model, val_loader, device)
                print(f"Epoch [{epoch+1}/{config['epochs']}], Step [{i+1}/{len(train_loader)}], Train Loss: {train_loss.item():.4f}, Validation Loss: {average_val_loss:.4f}")

                # Optional: Save Model Checkpoint
                # torch.save(model.state_dict(), f"./model_checkpoint_epoch_{epoch+1}.pt")

            if i == early_stop:
                stopped = True
                break
        if stopped:
            break  # early stop ends training, not just the current epoch

    train_duration = time.time() - start_time
    print(f"Training completed in {train_duration:.2f} seconds")



Train size: 1103986
Val size: 11152
Vocab size: 65


In [None]:
# Train the Mamba-backed model; with early_stop=500 training stops after the step
# with index 500 (501 optimizer steps).
train_model(mamba_model, train_loader, val_loader, optimizer, config, early_stop=500)


Starting train
Epoch [1/1], Step [1/17250], Train Loss: 4.7023, Validation Loss: 9.3851
Epoch [1/1], Step [101/17250], Train Loss: 1.5919, Validation Loss: 1.5709
Epoch [1/1], Step [201/17250], Train Loss: 1.3830, Validation Loss: 1.3720
Epoch [1/1], Step [301/17250], Train Loss: 1.2471, Validation Loss: 1.2713
Epoch [1/1], Step [401/17250], Train Loss: 1.2207, Validation Loss: 1.2019
Epoch [1/1], Step [501/17250], Train Loss: 1.1363, Validation Loss: 1.1431
Training completed in 393.98 seconds


In [None]:
class SelfAttentionHead(nn.Module):
    """
    One head of causal (masked) scaled dot-product self-attention.

    Position ``t`` attends only to positions ``<= t``. Without the mask, a language
    model can read the very token it is asked to predict and drive its loss towards
    zero without learning to predict anything.

    Args:
        embed_dim: Input feature size.
        head_size: Size of the key / query / value projections and of the output.
        dropout: Dropout probability on the attention weights. ``None`` (default)
            uses ``config["dropout"]``.
    """
    def __init__(self, embed_dim, head_size, dropout=None):
        super().__init__()
        self.key = nn.Linear(embed_dim, head_size)
        self.query = nn.Linear(embed_dim, head_size)
        self.value = nn.Linear(embed_dim, head_size)
        self.scale = head_size ** -0.5
        self.dropout = nn.Dropout(config["dropout"] if dropout is None else dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq, embed_dim); returns (batch, seq, head_size)."""
        seq_length = x.size(1)
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        # Row t holds query t's scores against every key.
        weights = (queries @ keys.transpose(-2, -1)) * self.scale
        # Hide future positions (strictly above the diagonal) before the softmax.
        future = torch.ones(seq_length, seq_length, dtype=torch.bool, device=x.device).triu(diagonal=1)
        weights = weights.masked_fill(future, float("-inf"))
        weights = torch.softmax(weights, dim=-1)
        weights = self.dropout(weights)

        return weights @ values


class MultiHeadAttention(nn.Module):
    """
    Several ``SelfAttentionHead``s run in parallel. Their outputs are concatenated,
    projected back to ``embed_dim``, and passed through dropout.

    Args:
        num_heads: Number of heads; each has size ``embed_dim // num_heads``.
        embed_dim: Input and output feature size.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``embed_dim``.
            Otherwise the concatenated heads would not match the output projection.
    """
    def __init__(self, num_heads, embed_dim):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        head_size = embed_dim // num_heads
        self.heads = nn.ModuleList([SelfAttentionHead(embed_dim, head_size) for _ in range(num_heads)])
        self.linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(config["dropout"])

    def forward(self, x):
        concatenated_heads = torch.cat([head(x) for head in self.heads], dim=-1)
        output = self.linear(concatenated_heads)
        output = self.dropout(output)
        return output

class Block(nn.Module):
    """
    Pre-norm residual block: multi-head self-attention followed by a feed-forward
    network. It redefines ``Block``, so models built after this cell stack
    attention blocks instead of Mamba blocks.
    """
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = MultiHeadAttention(num_heads, embed_dim)
        self.feed_forward = FeedForward(embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attended = x + self.attention(self.norm1(x))
        return attended + self.feed_forward(self.norm2(attended))


# Initialize the Transformer model: same architecture as before, but `Block` is now
# the attention variant. The vocabulary size comes from config, like the Mamba model.
transformer_model = BigramNeuralNetwork(
    config["vocab_size"],
    config["n_embed"],
    config["n_heads"],
    config["n_layers"],
)
transformer_model = transformer_model.to(config["device"])
optimizer = torch.optim.AdamW(params=transformer_model.parameters(), lr=config["lr"])


In [None]:
# Same training budget as the Mamba run, so the two logs are comparable.
train_model(transformer_model, train_loader, val_loader, optimizer, config, early_stop=500)


Starting train
Epoch [1/1], Step [1/17250], Train Loss: 4.6324, Validation Loss: 9.3575
Epoch [1/1], Step [101/17250], Train Loss: 0.2796, Validation Loss: 0.0974
Epoch [1/1], Step [201/17250], Train Loss: 0.0187, Validation Loss: 0.0107
Epoch [1/1], Step [301/17250], Train Loss: 0.0144, Validation Loss: 0.0102
Epoch [1/1], Step [401/17250], Train Loss: 0.0125, Validation Loss: 0.0102
Epoch [1/1], Step [501/17250], Train Loss: 0.0106, Validation Loss: 0.0102
Training completed in 390.31 seconds


In [None]:

import math
from functools import partial
import json
import os

from collections import namedtuple

import torch
import torch.nn as nn

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.modules.mamba_simple import Mamba, Block
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    """
    Build one mamba_ssm ``Block`` whose mixer is a ``Mamba`` layer.

    Args:
        d_model: Hidden size.
        ssm_cfg: Extra keyword arguments for ``Mamba`` (``None`` means none).
        norm_epsilon: Epsilon for the block's norm.
        rms_norm: Use ``RMSNorm`` instead of ``nn.LayerNorm``.
        residual_in_fp32 / fused_add_norm: Forwarded to ``Block``.
        layer_idx: Passed to ``Mamba`` and also stored on the returned block.
        device, dtype: Factory kwargs for the mixer and norm parameters.

    Returns:
        The constructed ``Block`` with a ``layer_idx`` attribute.
    """
    mixer_options = {} if ssm_cfg is None else ssm_cfg
    tensor_options = {"device": device, "dtype": dtype}
    norm_base = RMSNorm if rms_norm else nn.LayerNorm

    mixer_cls = partial(Mamba, layer_idx=layer_idx, **mixer_options, **tensor_options)
    norm_cls = partial(norm_base, eps=norm_epsilon, **tensor_options)

    new_block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    new_block.layer_idx = layer_idx
    return new_block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    """
    Embedding -> stack of mamba_ssm ``Block``s -> final norm; returns hidden states.

    Follows the layout of mamba_ssm's MixerModel (presumably adapted from
    ``mamba_ssm.models.mixer_seq_simple``), with dropout added after the embedding
    and after the last block.

    Args:
        d_model: Hidden size.
        n_layer: Number of blocks.
        vocab_size: Embedding table size.
        ssm_cfg: Extra keyword arguments for each ``Mamba`` mixer.
        norm_epsilon: Epsilon for the LayerNorm / RMSNorm layers.
        rms_norm: Use ``RMSNorm`` instead of ``nn.LayerNorm``.
        initializer_cfg: Extra keyword arguments forwarded to ``_init_weights``.
        fused_add_norm: Use the Triton fused add+norm kernels.
        residual_in_fp32: Forwarded to each block and to the fused final norm.
        device, dtype: Factory kwargs for parameter creation.
        dropout: Dropout probability at both dropout points. The default 0.2 is the
            value that used to be hard-coded.

    Raises:
        ImportError: if ``fused_add_norm`` is requested but the Triton kernels
            could not be imported.
    """
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
        dropout: float = 0.2,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
        # One module reused at both dropout points in forward().
        self.dropout = nn.Dropout(p=dropout)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Return ``{layer_index: cache}`` from each block's ``allocate_inference_cache``."""
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None):
        """Embed ``input_ids``, run every block, and return the final-normed hidden states."""
        hidden_states = self.embedding(input_ids)
        hidden_states = self.dropout(hidden_states)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
        hidden_states = self.dropout(hidden_states)
        if not self.fused_add_norm:
            # Blocks return the residual separately; add it back before the final norm.
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )
        return hidden_states


class MambaClassifier(nn.Module):
    """
    Sequence classifier: a ``MixerModel`` backbone plus a linear head applied to
    the hidden state at the last position.

    Args:
        config: ``MambaConfig`` providing d_model, n_layer, vocab_size, ssm_cfg,
            rms_norm, fused_add_norm and residual_in_fp32.
        num_classes: Number of output logits.
        initializer_cfg: Extra keyword arguments forwarded to ``_init_weights``.
        device, dtype: Factory kwargs for parameter creation.
    """
    def __init__(
        self,
        config: MambaConfig,
        num_classes: int,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.config = config
        self.num_classes = num_classes

        # Backbone model (MixerModel) setup
        self.backbone = MixerModel(
            d_model=config.d_model,
            n_layer=config.n_layer,
            vocab_size=config.vocab_size,
            ssm_cfg=config.ssm_cfg,
            rms_norm=config.rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=config.fused_add_norm,
            residual_in_fp32=config.residual_in_fp32,
            device=device,
            dtype=dtype,
        )

        # One logit per class, trained with CrossEntropyLoss (single-label classification).
        self.classifier = nn.Linear(config.d_model, num_classes, **{"device": device, "dtype": dtype})
        self.loss_fn = nn.CrossEntropyLoss()
        self.dropout = nn.Dropout(p=0.2)

        # Initialize weights
        self.apply(
            partial(
                _init_weights,
                n_layer=config.n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def forward(self, input_ids, labels=None):
        """
        Classify each sequence from the hidden state at its last position.

        Args:
            input_ids: Token ids of shape (batch, seq). A single unbatched (seq,)
                sequence is treated as a batch of one.
            labels: Optional targets for ``nn.CrossEntropyLoss``. Presumably class
                indices of shape (batch,); TODO confirm against the intended data.

        Returns:
            ``(loss, logits)`` when ``labels`` is given, otherwise ``logits`` of
            shape (batch, num_classes).
        """
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        hidden_states = self.backbone(input_ids)
        last_hidden_state = self.dropout(hidden_states[:, -1, :])
        logits = self.classifier(last_hidden_state)

        if labels is None:
            return logits
        loss = self.loss_fn(logits, labels)
        return loss, logits



In [None]:
# Usage of the class: build the classifier directly on the training device and run
# one forward pass without autograd.
my_mamba_config = MambaConfig(d_model=config['n_embed'], n_layer=config['n_layers'], vocab_size=config['vocab_size'], ssm_cfg={})
mamba_model_v2 = MambaClassifier(my_mamba_config, num_classes=config['vocab_size'], device=config['device'])
# train_dataset[i] is an (x, y) pair of 1-D sequences; add the batch dimension the model expects.
input_ids = train_dataset[0][0].unsqueeze(0).to(config['device'])
with torch.no_grad():
    logits = mamba_model_v2(input_ids)

In [None]:
# Give this model its own optimizer. The global `optimizer` was last built over
# transformer_model's parameters, so stepping it would never update mamba_model_v2.
mamba_model_v2 = mamba_model_v2.to(config["device"])
mamba_v2_optimizer = torch.optim.AdamW(mamba_model_v2.parameters(), lr=config["lr"])

# NOTE(review): MambaClassifier.forward returns (loss, logits) and scores only the last
# position. train_model instead unpacks `_, loss = model(x, y)`, and the loaders yield
# (batch, block_size) targets. Expect this call to fail until the two interfaces agree.
train_model(
    model=mamba_model_v2,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=mamba_v2_optimizer,
    config=config,
    early_stop=500
)