In [2]:
import os
import logging
os.environ["CUDA_DEVICE_ORDER"]='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [3]:
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("training.log"),
        logging.StreamHandler()  
    ]
)
logger = logging.getLogger(__name__)

In [4]:
import torch
import argparse
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, TrainingArguments, AutoModelForCausalLM
from mamba_trainer.data import DataModule
from mamba_trainer.data import LongRangeDataset
from mamba_trainer.trainer import MambaTrainer, GradientCallback
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm
2024-08-05 23:06:05.638607: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-05 23:06:05.657687: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-05 23:06:05.663521: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-05 23:06:05.678634: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
202

In [5]:
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token

optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)

lora_config = LoraConfig(
    r=16,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)

model = get_peft_model(model, lora_config)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MambaForCausalLM(
      (backbone): MambaModel(
        (embeddings): lora.Embedding(
          (base_layer): Embedding(50280, 768)
          (lora_dropout): ModuleDict(
            (default): Identity()
          )
          (lora_A): ModuleDict()
          (lora_B): ModuleDict()
          (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 16x50280 (cuda:0)])
          (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 768x16 (cuda:0)])
          (lora_magnitude_vector): ModuleDict()
        )
        (layers): ModuleList(
          (0-23): 24 x MambaBlock(
            (norm): MambaRMSNorm(768, eps=1e-05)
            (mixer): MambaMixer(
              (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
              (act): SiLU()
              (in_proj): lora.Linear(
                (base_la

In [7]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params:,} || all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.2f}"
    )

print_trainable_parameters(model)



trainable params: 3,796,608 || all params: 132,931,968 || trainable%: 2.86


In [8]:
training_args = TrainingArguments(
    learning_rate=5e-5,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    output_dir="model",
    logging_dir="logs", 
    logging_steps=1,
    save_steps=1,
    report_to="none", 
)

In [11]:
from torch.utils.data import Subset
import numpy as np
grad_callback = GradientCallback()
data_module = DataModule(data_path="./data/basic_20-70/train.tsv", tokenizer=tokenizer)
dataset = data_module.dataset
for i in tqdm(range(20)):
    np.random.seed(None)
    ids = np.random.choice(len(dataset), size=4, replace=False)
    print(ids)
    subset = Subset(dataset, ids.tolist())
    #dataloader = DataLoader(subset, batch_size=4, collate_fn=data_module.data_collator, shuffle=True)
    for ds in subset:
        print(ds["input_ids"].shape)



    trainer = MambaTrainer(
        model=model,
        args=training_args,
        train_dataset=subset,
        tokenizer=tokenizer,
        optimizers=(optimizer, None),
        data_collator=data_module.data_collator,
        callbacks=[grad_callback]
    )

    trainer.train()

./data/basic_20-70/train.tsv


  0%|          | 0/20 [00:00<?, ?it/s]

[257 664 525 962]
torch.Size([47])
torch.Size([100])
torch.Size([91])
torch.Size([61])


2024-08-05 23:08:07,536 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.1039


2024-08-05 23:08:07,615 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.1039
  5%|▌         | 1/20 [00:01<00:19,  1.01s/it]

[779 674 585 506]
torch.Size([47])
torch.Size([43])
torch.Size([115])
torch.Size([71])


2024-08-05 23:08:08,541 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.9364


2024-08-05 23:08:08,653 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.9364
 10%|█         | 2/20 [00:02<00:18,  1.03s/it]

[791 423 490   4]
torch.Size([48])
torch.Size([72])
torch.Size([58])
torch.Size([63])


2024-08-05 23:08:09,553 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,3.7517


2024-08-05 23:08:09,675 - mamba_trainer.trainer - INFO - Training Loss at step 1: 3.7517
 15%|█▌        | 3/20 [00:03<00:17,  1.02s/it]

[957 146 367 589]
torch.Size([44])
torch.Size([64])
torch.Size([63])
torch.Size([64])


2024-08-05 23:08:10,587 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,3.414


2024-08-05 23:08:10,672 - mamba_trainer.trainer - INFO - Training Loss at step 1: 3.414
 20%|██        | 4/20 [00:04<00:16,  1.06s/it]

[311 684 895 437]
torch.Size([52])
torch.Size([66])
torch.Size([99])
torch.Size([139])


2024-08-05 23:08:11,695 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.6037


2024-08-05 23:08:11,781 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.6037
 25%|██▌       | 5/20 [00:05<00:15,  1.05s/it]

[ 42  36 325 107]
torch.Size([57])
torch.Size([92])
torch.Size([67])
torch.Size([64])


2024-08-05 23:08:12,749 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.1107


2024-08-05 23:08:12,864 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.1107
 30%|███       | 6/20 [00:06<00:14,  1.07s/it]

[220 967 981 659]
torch.Size([57])
torch.Size([88])
torch.Size([49])
torch.Size([69])


2024-08-05 23:08:13,820 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.2516


2024-08-05 23:08:13,897 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.2516
 35%|███▌      | 7/20 [00:07<00:13,  1.05s/it]

[531 491 160 969]
torch.Size([90])
torch.Size([67])
torch.Size([61])
torch.Size([89])


2024-08-05 23:08:14,856 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,3.5725


2024-08-05 23:08:14,932 - mamba_trainer.trainer - INFO - Training Loss at step 1: 3.5725
 40%|████      | 8/20 [00:08<00:12,  1.06s/it]

[304 953 282  87]
torch.Size([86])
torch.Size([55])
torch.Size([66])
torch.Size([43])


2024-08-05 23:08:16,353 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.2443


2024-08-05 23:08:16,436 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.2443
 45%|████▌     | 9/20 [00:10<00:13,  1.23s/it]

[295 276 531 856]
torch.Size([111])
torch.Size([84])
torch.Size([90])
torch.Size([43])


2024-08-05 23:08:18,305 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.1875


2024-08-05 23:08:18,389 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.1875
 50%|█████     | 10/20 [00:11<00:14,  1.44s/it]

[ 48 576 577 722]
torch.Size([110])
torch.Size([72])
torch.Size([66])
torch.Size([112])


2024-08-05 23:08:19,471 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,3.7538


2024-08-05 23:08:19,548 - mamba_trainer.trainer - INFO - Training Loss at step 1: 3.7538
 55%|█████▌    | 11/20 [00:12<00:11,  1.31s/it]

[777 462 541 932]
torch.Size([55])
torch.Size([65])
torch.Size([96])
torch.Size([78])


2024-08-05 23:08:20,458 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,3.9884


2024-08-05 23:08:20,570 - mamba_trainer.trainer - INFO - Training Loss at step 1: 3.9884
 60%|██████    | 12/20 [00:14<00:09,  1.24s/it]

[328 594 787 130]
torch.Size([53])
torch.Size([81])
torch.Size([69])
torch.Size([53])


2024-08-05 23:08:21,553 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,3.9733


2024-08-05 23:08:21,641 - mamba_trainer.trainer - INFO - Training Loss at step 1: 3.9733
 65%|██████▌   | 13/20 [00:15<00:08,  1.18s/it]

[895 459 640 272]
torch.Size([99])
torch.Size([63])
torch.Size([71])
torch.Size([47])


2024-08-05 23:08:22,580 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.3711


2024-08-05 23:08:22,657 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.3711
 70%|███████   | 14/20 [00:16<00:06,  1.13s/it]

[ 45 292 897 129]
torch.Size([104])
torch.Size([52])
torch.Size([43])
torch.Size([59])


2024-08-05 23:08:23,652 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.8727


2024-08-05 23:08:23,759 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.8727
 75%|███████▌  | 15/20 [00:17<00:05,  1.12s/it]

[707 157 332 347]
torch.Size([71])
torch.Size([117])
torch.Size([51])
torch.Size([54])


2024-08-05 23:08:24,674 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.8346


2024-08-05 23:08:24,763 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.8346
 80%|████████  | 16/20 [00:18<00:04,  1.10s/it]

[842  12 864 711]
torch.Size([90])
torch.Size([98])
torch.Size([129])
torch.Size([44])


2024-08-05 23:08:25,757 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.3108


2024-08-05 23:08:25,837 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.3108
 85%|████████▌ | 17/20 [00:19<00:03,  1.09s/it]

[997 129 475 708]
torch.Size([109])
torch.Size([59])
torch.Size([110])
torch.Size([47])


2024-08-05 23:08:26,816 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.176


2024-08-05 23:08:26,913 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.176
 90%|█████████ | 18/20 [00:20<00:02,  1.08s/it]

[911 665 374 987]
torch.Size([61])
torch.Size([81])
torch.Size([115])
torch.Size([47])


2024-08-05 23:08:27,849 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,4.6203


2024-08-05 23:08:27,926 - mamba_trainer.trainer - INFO - Training Loss at step 1: 4.6203
 95%|█████████▌| 19/20 [00:21<00:01,  1.06s/it]

[604 885 661 212]
torch.Size([77])
torch.Size([84])
torch.Size([84])
torch.Size([49])


2024-08-05 23:08:28,874 - mamba_trainer.trainer - INFO - Gradient Norm at step 0: 0.0


Step,Training Loss
1,3.5983


2024-08-05 23:08:28,960 - mamba_trainer.trainer - INFO - Training Loss at step 1: 3.5983
100%|██████████| 20/20 [00:22<00:00,  1.12s/it]
