In [1]:
import os
import logging
import sys
from contextlib import contextmanager
# Select which GPU this notebook may see. These variables are set here,
# before `torch` is imported in a later cell.
# NOTE(review): the device index "4" is machine-specific — adjust per host.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "4",
    }
)

In [2]:
# @contextmanager
# def suppress_output():
#     with open(os.devnull, 'w') as fnull:
#         old_stdout = sys.stdout
#         old_stderr = sys.stderr
#         sys.stdout = fnull
#         sys.stderr = fnull
#         try:
#             yield
#         finally:
#             sys.stdout = old_stdout
#             sys.stderr = old_stderr


In [3]:
# Route INFO-and-above log records to the file "training.log"; the file
# handler is the only handler passed to the root logger configuration.
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
file_handler = logging.FileHandler("training.log")
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, handlers=[file_handler])

# Logger used by the rest of the notebook (e.g. for validation loss).
logger = logging.getLogger(__name__)

In [4]:
import torch
import numpy as np
from torch.utils.data import Subset, DataLoader
from tqdm.auto import tqdm
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, TrainingArguments, AutoModelForCausalLM
from mamba_trainer.data import DataModule
from mamba_trainer.trainer import MambaTrainer, GradientCallback
%matplotlib inline 

  from .autonotebook import tqdm as notebook_tqdm
2024-08-15 19:35:23.100235: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-15 19:35:23.119032: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-15 19:35:23.124730: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-15 19:35:23.139595: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
# Load the pretrained Mamba checkpoint and its tokenizer.
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
# Use the end-of-text token for padding as well.
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token

# Attach rank-16 LoRA adapters to the listed projection / embedding modules.
lora_config = LoraConfig(
    r=16,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)

model = get_peft_model(model, lora_config)

# The optimizer must be created AFTER the adapters are injected:
# `get_peft_model` freezes the base weights and adds new LoRA parameters.
# An optimizer built from the un-wrapped model would hold only the frozen
# weights, so training would never update anything.
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_parameters, lr=3e-4)

In [6]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
# Kept as the cell's last expression so the model summary is still displayed.
model.to(device)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MambaForCausalLM(
      (backbone): MambaModel(
        (embeddings): lora.Embedding(
          (base_layer): Embedding(50280, 768)
          (lora_dropout): ModuleDict(
            (default): Identity()
          )
          (lora_A): ModuleDict()
          (lora_B): ModuleDict()
          (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 16x50280 (cuda:0)])
          (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 768x16 (cuda:0)])
          (lora_magnitude_vector): ModuleDict()
        )
        (layers): ModuleList(
          (0-23): 24 x MambaBlock(
            (norm): MambaRMSNorm(768, eps=1e-05)
            (mixer): MambaMixer(
              (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
              (act): SiLU()
              (in_proj): lora.Linear(
                (base_la

In [7]:
def print_trainable_parameters(model):
    """Print how many of `model`'s parameters are trainable.

    Args:
        model: Object exposing `named_parameters()` that yields
            `(name, parameter)` pairs, where each parameter provides
            `numel()` and a `requires_grad` flag.

    Prints one line with the trainable count, the total count, and the
    trainable percentage (two decimals). A model without any parameters
    is reported as 0.00% instead of raising ZeroDivisionError.
    """
    total_count = 0
    trainable_count = 0
    for _, param in model.named_parameters():
        size = param.numel()
        total_count += size
        if param.requires_grad:
            trainable_count += size

    # Guard the ratio: an empty model would otherwise divide by zero.
    percentage = 100 * trainable_count / total_count if total_count else 0.0
    print(
        f"trainable params: {trainable_count:,} || all params: {total_count:,} || trainable%: {percentage:.2f}"
    )

# Report how many of the model's parameters are currently trainable.
print_trainable_parameters(model)

trainable params: 3,796,608 || all params: 132,931,968 || trainable%: 2.86


In [8]:
# Trainer settings shared by every short training run started in the training loop.
training_args = TrainingArguments(
    # Output locations for checkpoints and logs.
    output_dir="model",
    logging_dir="logs",
    # Optimisation settings.
    # NOTE(review): the training loop passes its own optimizer (created with
    # lr=3e-4), so this learning_rate is presumably not the one applied — confirm.
    learning_rate=5e-5,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    # Log and checkpoint every step; the trainer schedules no evaluation itself.
    # NOTE(review): save_steps=1 looks like it writes a checkpoint on every run — confirm this is wanted.
    logging_steps=1,
    save_steps=1,
    evaluation_strategy="no",
    # No external experiment tracker and no per-run progress bars.
    report_to="none",
    disable_tqdm=True,
)



In [9]:
# Locations of the training and validation splits.
train_path = "./data/basic_20-70/train.tsv"
val_path = "./data/basic_20-70/val.tsv"

# Build one DataModule per split and keep a handle on each underlying dataset.
# The train module's collator is reused later by the training loop.
train_data_module = DataModule(tokenizer=tokenizer, data_path=train_path)
train_dataset = train_data_module.dataset

val_data_module = DataModule(tokenizer=tokenizer, data_path=val_path)
val_dataset = val_data_module.dataset

./data/basic_20-70/train.tsv


I0000 00:00:1723739732.375249 2047537 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723739732.386345 2047537 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723739732.397141 2047537 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723739732.412803 2047537 cuda_executor.cc:1015] successful NUMA node read from SysFS ha

./data/basic_20-70/val.tsv


2024-08-15 19:35:35.160134: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [10]:
def evaluate(trainer, eval_dataset, step):
    """Evaluate on `eval_dataset` and log the validation loss.

    Args:
        trainer: Object whose `evaluate(dataset)` returns a mapping of metrics.
        eval_dataset: Dataset forwarded unchanged to `trainer.evaluate`.
        step: Step number included in the log message.

    The value stored under 'eval_loss' (None when the key is missing) is
    written at INFO level via the module-level `logger`. Nothing is returned.
    """
    eval_metrics = trainer.evaluate(eval_dataset)
    eval_loss = eval_metrics.get('eval_loss')
    logger.info(f"Validation Loss at step {step}: {eval_loss}")

In [11]:
NUM_ITERATIONS = 1000   # number of short training runs
SAMPLES_PER_RUN = 4     # examples drawn per run; matches per_device_train_batch_size=4
EVAL_EVERY = 20         # validate whenever the callback's step counter is a multiple of this

# Shared across all runs so its step counter keeps advancing.
grad_callback = GradientCallback()

# Dedicated generator for drawing training examples. It is independent of the
# global NumPy RNG, so it needs no reseeding on every iteration and is not
# affected by any seeding of the global RNG that happens elsewhere.
# NOTE(review): trainer construction presumably reseeds the global RNGs
# (transformers `set_seed`) — confirm against MambaTrainer.
sample_rng = np.random.default_rng()

for iteration in tqdm(range(NUM_ITERATIONS)):
    # Draw a fresh random set of distinct training examples for this run.
    ids = sample_rng.choice(len(train_dataset), size=SAMPLES_PER_RUN, replace=False)
    subset = Subset(train_dataset, ids.tolist())

    # A new trainer is built per run; model, optimizer and callback are reused.
    trainer = MambaTrainer(
        model=model,
        args=training_args,
        train_dataset=subset,
        tokenizer=tokenizer,
        optimizers=(optimizer, None),
        data_collator=train_data_module.data_collator,
        callbacks=[grad_callback],
    )

    trainer.train()

    if grad_callback.step % EVAL_EVERY == 0:
        evaluate(
            trainer=trainer,
            eval_dataset=val_dataset,
            step=grad_callback.step,
        )

  0%|          | 0/1000 [00:00<?, ?it/s]

{'loss': 2.8619, 'grad_norm': 1.4639662504196167, 'learning_rate': 0.0, 'pre_update_loss': 2.8619091510772705, 'epoch': 1.0}


  0%|          | 1/1000 [00:02<36:45,  2.21s/it]

{'train_runtime': 1.4695, 'train_samples_per_second': 2.722, 'train_steps_per_second': 0.68, 'train_loss': 2.8619091510772705, 'pre_update_loss': 2.8619091510772705, 'epoch': 1.0}
{'loss': 2.7825, 'grad_norm': 1.142163634300232, 'learning_rate': 0.0, 'pre_update_loss': 2.7824604511260986, 'epoch': 1.0}


  0%|          | 2/1000 [00:03<29:45,  1.79s/it]

{'train_runtime': 1.1459, 'train_samples_per_second': 3.491, 'train_steps_per_second': 0.873, 'train_loss': 2.7824604511260986, 'pre_update_loss': 2.7824604511260986, 'epoch': 1.0}
{'loss': 2.8572, 'grad_norm': 1.4149364233016968, 'learning_rate': 0.0, 'pre_update_loss': 2.8571767807006836, 'epoch': 1.0}


  0%|          | 3/1000 [00:05<33:24,  2.01s/it]

{'train_runtime': 1.835, 'train_samples_per_second': 2.18, 'train_steps_per_second': 0.545, 'train_loss': 2.8571767807006836, 'pre_update_loss': 2.8571767807006836, 'epoch': 1.0}
{'loss': 2.8561, 'grad_norm': 1.3130896091461182, 'learning_rate': 0.0, 'pre_update_loss': 2.856112241744995, 'epoch': 1.0}


  0%|          | 4/1000 [00:07<30:58,  1.87s/it]

{'train_runtime': 1.3393, 'train_samples_per_second': 2.987, 'train_steps_per_second': 0.747, 'train_loss': 2.856112241744995, 'pre_update_loss': 2.856112241744995, 'epoch': 1.0}
{'loss': 2.7542, 'grad_norm': 1.096690058708191, 'learning_rate': 0.0, 'pre_update_loss': 2.7541539669036865, 'epoch': 1.0}


  0%|          | 5/1000 [00:09<31:28,  1.90s/it]

{'train_runtime': 1.5762, 'train_samples_per_second': 2.538, 'train_steps_per_second': 0.634, 'train_loss': 2.7541539669036865, 'pre_update_loss': 2.7541539669036865, 'epoch': 1.0}
{'loss': 2.5959, 'grad_norm': 0.9770918488502502, 'learning_rate': 0.0, 'pre_update_loss': 2.5959413051605225, 'epoch': 1.0}


  1%|          | 6/1000 [00:11<31:57,  1.93s/it]

{'train_runtime': 1.6633, 'train_samples_per_second': 2.405, 'train_steps_per_second': 0.601, 'train_loss': 2.5959413051605225, 'pre_update_loss': 2.5959413051605225, 'epoch': 1.0}
{'loss': 3.1706, 'grad_norm': 1.556930422782898, 'learning_rate': 0.0, 'pre_update_loss': 3.1706106662750244, 'epoch': 1.0}


  1%|          | 7/1000 [00:13<31:42,  1.92s/it]

{'train_runtime': 1.4068, 'train_samples_per_second': 2.843, 'train_steps_per_second': 0.711, 'train_loss': 3.1706106662750244, 'pre_update_loss': 3.1706106662750244, 'epoch': 1.0}
{'loss': 2.7686, 'grad_norm': 1.046669840812683, 'learning_rate': 0.0, 'pre_update_loss': 2.7685508728027344, 'epoch': 1.0}


  1%|          | 8/1000 [00:15<33:26,  2.02s/it]

{'train_runtime': 1.9521, 'train_samples_per_second': 2.049, 'train_steps_per_second': 0.512, 'train_loss': 2.7685508728027344, 'pre_update_loss': 2.7685508728027344, 'epoch': 1.0}
{'loss': 2.8142, 'grad_norm': 1.0975260734558105, 'learning_rate': 0.0, 'pre_update_loss': 2.814185619354248, 'epoch': 1.0}


  1%|          | 9/1000 [00:18<35:35,  2.16s/it]

{'train_runtime': 2.1528, 'train_samples_per_second': 1.858, 'train_steps_per_second': 0.465, 'train_loss': 2.814185619354248, 'pre_update_loss': 2.814185619354248, 'epoch': 1.0}
{'loss': 2.884, 'grad_norm': 1.0256454944610596, 'learning_rate': 0.0, 'pre_update_loss': 2.883972644805908, 'epoch': 1.0}


  1%|          | 10/1000 [00:20<34:50,  2.11s/it]

{'train_runtime': 1.6337, 'train_samples_per_second': 2.448, 'train_steps_per_second': 0.612, 'train_loss': 2.883972644805908, 'pre_update_loss': 2.883972644805908, 'epoch': 1.0}
{'loss': 2.8162, 'grad_norm': 1.073409080505371, 'learning_rate': 0.0, 'pre_update_loss': 2.8162150382995605, 'epoch': 1.0}


  1%|          | 11/1000 [00:21<30:16,  1.84s/it]

{'train_runtime': 0.9303, 'train_samples_per_second': 4.3, 'train_steps_per_second': 1.075, 'train_loss': 2.8162150382995605, 'pre_update_loss': 2.8162150382995605, 'epoch': 1.0}
{'loss': 2.6104, 'grad_norm': 0.9044678807258606, 'learning_rate': 0.0, 'pre_update_loss': 2.610399007797241, 'epoch': 1.0}


  1%|          | 12/1000 [00:23<32:41,  1.99s/it]

{'train_runtime': 1.8502, 'train_samples_per_second': 2.162, 'train_steps_per_second': 0.54, 'train_loss': 2.610399007797241, 'pre_update_loss': 2.610399007797241, 'epoch': 1.0}
{'loss': 2.761, 'grad_norm': 1.2777113914489746, 'learning_rate': 0.0, 'pre_update_loss': 2.761019468307495, 'epoch': 1.0}


  1%|▏         | 13/1000 [00:26<38:15,  2.33s/it]

{'train_runtime': 2.6347, 'train_samples_per_second': 1.518, 'train_steps_per_second': 0.38, 'train_loss': 2.761019468307495, 'pre_update_loss': 2.761019468307495, 'epoch': 1.0}
{'loss': 2.7473, 'grad_norm': 1.113443374633789, 'learning_rate': 0.0, 'pre_update_loss': 2.7472660541534424, 'epoch': 1.0}


  1%|▏         | 14/1000 [00:30<43:10,  2.63s/it]

{'train_runtime': 3.0103, 'train_samples_per_second': 1.329, 'train_steps_per_second': 0.332, 'train_loss': 2.7472660541534424, 'pre_update_loss': 2.7472660541534424, 'epoch': 1.0}
{'loss': 2.8829, 'grad_norm': 1.7518501281738281, 'learning_rate': 0.0, 'pre_update_loss': 2.882904052734375, 'epoch': 1.0}


  2%|▏         | 15/1000 [00:33<47:06,  2.87s/it]

{'train_runtime': 3.1244, 'train_samples_per_second': 1.28, 'train_steps_per_second': 0.32, 'train_loss': 2.882904052734375, 'pre_update_loss': 2.882904052734375, 'epoch': 1.0}
{'loss': 2.6482, 'grad_norm': 1.1963293552398682, 'learning_rate': 0.0, 'pre_update_loss': 2.6482226848602295, 'epoch': 1.0}


  2%|▏         | 16/1000 [00:37<53:35,  3.27s/it]

{'train_runtime': 3.4567, 'train_samples_per_second': 1.157, 'train_steps_per_second': 0.289, 'train_loss': 2.6482226848602295, 'pre_update_loss': 2.6482226848602295, 'epoch': 1.0}
{'loss': 3.0071, 'grad_norm': 1.822587251663208, 'learning_rate': 0.0, 'pre_update_loss': 3.007148265838623, 'epoch': 1.0}


  2%|▏         | 17/1000 [00:41<55:23,  3.38s/it]

{'train_runtime': 3.2644, 'train_samples_per_second': 1.225, 'train_steps_per_second': 0.306, 'train_loss': 3.007148265838623, 'pre_update_loss': 3.007148265838623, 'epoch': 1.0}
{'loss': 2.7929, 'grad_norm': 1.2938666343688965, 'learning_rate': 0.0, 'pre_update_loss': 2.792912483215332, 'epoch': 1.0}


  2%|▏         | 18/1000 [00:44<56:06,  3.43s/it]

{'train_runtime': 3.2777, 'train_samples_per_second': 1.22, 'train_steps_per_second': 0.305, 'train_loss': 2.792912483215332, 'pre_update_loss': 2.792912483215332, 'epoch': 1.0}
{'loss': 2.8393, 'grad_norm': 1.2552732229232788, 'learning_rate': 0.0, 'pre_update_loss': 2.839315414428711, 'epoch': 1.0}


  2%|▏         | 19/1000 [00:48<57:51,  3.54s/it]

{'train_runtime': 3.512, 'train_samples_per_second': 1.139, 'train_steps_per_second': 0.285, 'train_loss': 2.839315414428711, 'pre_update_loss': 2.839315414428711, 'epoch': 1.0}
{'loss': 2.506, 'grad_norm': 0.8212756514549255, 'learning_rate': 0.0, 'pre_update_loss': 2.506035804748535, 'epoch': 1.0}
{'train_runtime': 3.405, 'train_samples_per_second': 1.175, 'train_steps_per_second': 0.294, 'train_loss': 2.506035804748535, 'pre_update_loss': 2.506035804748535, 'epoch': 1.0}
