In [None]:
%pip install datasets==2.16.0 mamba-ssm==1.1.1 accelerate==0.25.0

# See: https://www.reddit.com/r/LocalLLaMA/comments/18da1al/an_interactive_demo_for_mambachat/
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

# tokenizer = AutoTokenizer.from_pretrained("VietAI/gpt-neo-1.3B-vietnamese-news")
# train_dataset = load_dataset("nampdn-ai/tinystories-vietnamese", split="train[:500]").map(lambda d: tokenizer(d["vi"]))
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
train_dataset = load_dataset("roneneldan/TinyStories", split="train[:50000]").map(lambda d: tokenizer(d["text"]))
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token

In [3]:
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

config = MambaConfig()
config.d_model = 256
config.n_layer = 16
config.vocab_size = tokenizer.vocab_size
model = MambaLMHeadModel(config, device="cuda")
print("Model's parameters:", sum(p.numel() for p in model.parameters()))

Model's parameters: 19876096


In [4]:
import os
import torch
from transformers import Trainer, TrainingArguments

# See: https://github.com/havenhq/mamba-chat/blob/bbc9ef6/trainer/mamba_trainer.py
class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits
        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
        return lm_loss

    def save_model(self, output_dir, _internal_call):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        torch.save(self.model.state_dict(), f"{output_dir}/pytorch_model.bin")
        self.tokenizer.save_pretrained(output_dir)

trainer = MambaTrainer(
    model=model,
    args=TrainingArguments(
        "mamba-tiny",
        gradient_accumulation_steps=100, num_train_epochs=1,
        logging_steps=20, report_to="none",
    ),
    train_dataset=train_dataset, tokenizer=tokenizer,
)
trainer.train()

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
20,8.5804
40,7.9563
60,7.7684


TrainOutput(global_step=62, training_loss=8.08849436236966, metrics={'train_runtime': 1275.3967, 'train_samples_per_second': 39.203, 'train_steps_per_second': 0.049, 'total_flos': 0.0, 'train_loss': 8.08849436236966, 'epoch': 0.99})

In [5]:
# See: https://github.com/state-spaces/mamba/blob/1df0df1/benchmarks/benchmark_generation_mamba_simple.py
def complete(model, tokenizer, input_text: str) -> str:
    input_tokens = tokenizer(input_text, return_tensors="pt")
    input_ids = input_tokens.input_ids.to("cuda")
    output_ids = model.generate(input_ids, max_length=100, return_dict_in_generate=True, cg=True)
    output_text = tokenizer.batch_decode(output_ids.sequences.tolist())[0]
    return output_text

In [6]:
complete(model, tokenizer, "Once")

'Once upon a time, there was a little girl named.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n'