In [1]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
import torch

  from .autonotebook import tqdm as notebook_tqdm


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


  warn("The installed version of bitsandbytes was compiled without GPU support. "


In [2]:
# tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf")
# model = MambaForCausalLM.from_pretrained("state-spaces/mamba-2.8b-hf")
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d


In [3]:
input_ids = tokenizer("What is mamba?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.batch_decode(out))

['What is mamba?\n\nA:\n\nMamba is a genus of the family Colubridae.\n\nA:\n\nMamba is a genus of the family Colubridae.\n\nA:\n\nMamba is a genus of the family Colubridae.\n\n<|endoftext|>']


In [4]:
dataset = load_dataset("Abirate/english_quotes", split="train")
dataset

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Dataset({
    features: ['quote', 'author', 'tags'],
    num_rows: 2508
})

In [5]:
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config =  LoraConfig(
        r=8,
        target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
        task_type="CAUSAL_LM",
        bias="none"
)

In [6]:
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)



In [7]:
trainer.train()

  1%|          | 10/1881 [02:25<7:02:34, 13.55s/it]

{'loss': 14.9046, 'grad_norm': nan, 'learning_rate': 0.0019893673577884106, 'epoch': 0.02}


  1%|          | 20/1881 [05:44<7:27:15, 14.42s/it] 

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.001978734715576821, 'epoch': 0.03}


  2%|▏         | 30/1881 [07:42<7:44:29, 15.06s/it]

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.001968102073365231, 'epoch': 0.05}


  2%|▏         | 40/1881 [10:22<3:55:32,  7.68s/it] 

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0019574694311536417, 'epoch': 0.06}


  3%|▎         | 50/1881 [14:30<12:39:23, 24.88s/it]

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0019468367889420523, 'epoch': 0.08}


  3%|▎         | 60/1881 [16:23<3:04:52,  6.09s/it] 

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0019362041467304626, 'epoch': 0.1}


  4%|▎         | 70/1881 [18:21<9:22:00, 18.62s/it]

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0019255715045188731, 'epoch': 0.11}


  4%|▍         | 80/1881 [35:21<22:34:51, 45.14s/it]  

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0019149388623072834, 'epoch': 0.13}


  5%|▍         | 90/1881 [37:53<7:51:22, 15.79s/it] 

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0019043062200956938, 'epoch': 0.14}


  5%|▌         | 100/1881 [38:54<1:40:39,  3.39s/it]

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0018936735778841043, 'epoch': 0.16}


  6%|▌         | 110/1881 [40:41<4:54:29,  9.98s/it]

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0018830409356725146, 'epoch': 0.18}


  6%|▋         | 120/1881 [42:52<12:51:16, 26.28s/it]

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.001872408293460925, 'epoch': 0.19}


  7%|▋         | 130/1881 [47:24<18:03:18, 37.12s/it]

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0018617756512493355, 'epoch': 0.21}


  7%|▋         | 140/1881 [53:12<15:21:05, 31.74s/it]

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.001851143009037746, 'epoch': 0.22}


  8%|▊         | 150/1881 [1:03:10<10:52:04, 22.60s/it] 

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0018405103668261563, 'epoch': 0.24}


  8%|▊         | 154/1881 [1:03:28<4:03:40,  8.47s/it] 