In [1]:
pip install mamba-ssm

Note: you may need to restart the kernel to use updated packages.


In [2]:
import mamba_ssm
import torch
from mamba_ssm import Mamba
from mamba_ssm import Mamba2

In [3]:
mamba_ssm.__version__

'2.2.6.post3'

In [4]:
# Parameters for the model
dim = 128
n_layers = 2
Batch_Size = 16
lr = 1e-5
epochs=5

# setting up the device 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# MambaSSM_model = Mamba2
# (
#     # This module uses roughly 3 * expand * d_model^2 parameters
#     d_model=dim, # Model dimension d_model
#     d_state=128, # SSM state expansion factor
#     d_conv=4,   # Local convolution width
#     conv_init=None,
#     expand=2,  # Block expansion factor
#     headdim=64,
#     d_ssm=None,  # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
#     ngroups=1,
#     A_init_range=(1, 16),
#     D_has_hdim=False,
#     rmsnorm=True,
#     norm_before_gate=False,
#     dt_min=0.001,
#     dt_max=0.1,
#     dt_init_floor=1e-4,
#     dt_limit=(0.0, float("inf")),
#     bias=False,
#     conv_bias=True,
#     # Fused kernel and sharding options
#     chunk_size=256,
#     use_mem_eff_path=True,
#     layer_idx=None,  # Absorb kwarg for general module
#     process_group=None,
#     sequence_parallel=True,
# ).to(device)

from mamba_ssm import Mamba2
import torch.nn as nn

Mamba_block = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
    rmsnorm=True,
    ngroups = 2,
    use_mem_eff_path=True,
    chunk_size=256,
    layer_idx=None,  # Absorb kwarg for general module
).to(device=device)

class MambaStack(nn.Module):
    def __init__(self,model,n_layers=12,**kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            model for _ in range(n_layers)
        )
        
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x
    
MambaSSMmodel = MambaStack(Mamba_block,n_layers = n_layers).to(device=device)

In [6]:
from mamba_ssm.models.config_mamba import MambaConfig

In [7]:
pip show torchtext

Name: torchtext
Version: 0.5.0
Summary: Text utilities and datasets for PyTorch
Home-page: https://github.com/pytorch/text
Author: PyTorch core devs and James Bradbury
Author-email: jekbradbury@gmail.com
License: BSD
Location: /home/mahesh/miniconda3/envs/mcd_env/lib/python3.12/site-packages
Requires: numpy, requests, sentencepiece, six, torch, tqdm
Required-by: 
Note: you may need to restart the kernel to use updated packages.


In [8]:
# Loading a simple Dataset AG News classification Dataset 
import torchtext
from torchtext.datasets import AG_NEWS
from torch.utils.data import DataLoader 

# train_iter = AG_NEWS(split='train')
train_dataset,test_dataset = AG_NEWS(root='/home/mahesh/sharath_MTP/Metro_Data_1',
                                     ngrams = 3,
                                     vocab = None,
                                     include_unk = False
                                    )
# test_iter = AG_NEWS(split = 'test')

120000lines [00:06, 18089.73lines/s]
120000lines [00:12, 9727.03lines/s] 
7600lines [00:00, 10933.65lines/s]


In [9]:
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence

# making a traindata iterator for making the vocab 
train_iter = list(train_dataset) 

def get_token(data_iter):
    for _,tokens in data_iter:
        yield tokens
        
vocab = build_vocab_from_iterator(get_token(train_iter))
# vocab.set_default_index(vocab["<unk>"])

def collate_fn(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(_label)  
        # adding a list comprehension to convert tokens to indices so the vocab object is not called directly
        # and remains callable 
        processed_text = torch.tensor([vocab[token] for token in _text], dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list, text_list, offsets

# Making the testing and training dataloaders 
train_Dataloader = DataLoader(train_dataset,
                        batch_size=Batch_Size,
                        shuffle=True,
                        collate_fn=collate_fn)

test_Dataloader = DataLoader(test_dataset,
                       batch_size=Batch_Size,
                       shuffle=True,
                       collate_fn=collate_fn)

# train_iter = AG_NEWS(root='/home/mahesh/sharath_MTP/Metro_Data_1',split='train')

120000lines [00:29, 4054.38lines/s]


In [10]:
# Replacing the vocab(_text) with a list comprehension [vocab[token] for token in _text], which maps each token to its index.
# This will avoid the 'Vocab' object is not callable error.
for labels,text,_ in train_Dataloader:
    print(labels)
    print(text)
    break

tensor([1, 3, 2, 2, 3, 2, 1, 2, 2, 3, 3, 3, 2, 1, 3, 1])
tensor([0, 0, 0,  ..., 0, 0, 0])


In [11]:
print(len(vocab))
vocab_size = len(vocab)

15220829


In [12]:
# pip show transformers
import transformers
from transformers import AutoModelForCausalLM,GPT2Config,GPT2LMHeadModel,GPT2Model
from transformers import AutoTokenizer

config = GPT2Config(
    vocab_size = vocab_size,
    n_embd=dim,
    n_layer=n_layers,
    n_head=8,
    layer_norm_epsilon=1e-5,
    eos_token_id=vocab_size,
)

GPT2_Model = GPT2LMHeadModel(config).to(device=device)
# Defining the tokenizer 
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [15]:
import torch.optim as optim 
from torch.optim import AdamW
import os 

# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

optimizer = optim.AdamW(GPT2_Model.parameters(),lr = lr)
optimizer_1 = optim.AdamW(MambaSSMmodel.parameters(),lr = lr)

# # i --- IGNORE ---

from transformers import get_linear_schedule_with_warmup
total_steps = len(train_Dataloader) * epochs

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)
scheduler_1 = get_linear_schedule_with_warmup(optimizer_1,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)
# # Training Loop
for epoch in range(epochs): 
    GPT2_Model.train()
    MambaSSMmodel.train()
    total_loss = 0
    for batch_idx,(labels,text,offsets) in enumerate(train_Dataloader):
        inputs = text.to(device=device)
        labels = text.to(device=device)
        
        # GPT2 Model Training
        optimizer.zero_grad()
        outputs = GPT2_Model(inputs,labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        # MambaSSM Model Training
        optimizer_1.zero_grad()
        embeddings = nn.Embedding(vocab_size,dim).to(device=device)
        input_embeddings = embeddings(inputs)
        mamba_outputs = MambaSSMmodel(input_embeddings)
        # Using a linear layer to project the outputs to vocab size for computing loss
        linear_layer = nn.Linear(dim,vocab_size).to(device=device)
        logits = linear_layer(mamba_outputs)
        loss_1 = nn.CrossEntropyLoss()(logits.view(-1,vocab_size),labels.view(-1))
        loss_1.backward()
        optimizer_1.step()
        scheduler_1.step()
        
        total_loss += loss.item() + loss_1.item()
        
    avg_loss = total_loss / len(train_Dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss}")        
    
# # Saving the models
torch.save(GPT2_Model.state_dict(),"GPT2_MambaSSM_AGNews.pth")
torch.save(MambaSSMmodel.state_dict(),"MambaSSM_AGNews.pth")

AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
