In [1]:
import os
import torch
import numpy as np
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, TrainingArguments, Trainer

from generation import DirectTasksDataset, test_tokenization
from train_helpers import DirectReasoningTrainer, get_accuracy_bar

# Explicitly switch tokenizer parallelism on through its environment variable
# before the tokenizer is created below.
os.environ.update(TOKENIZERS_PARALLELISM="true")

# Pretrained Mamba checkpoint, loaded in bfloat16 straight onto the GPU.
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m",
    device="cuda",
    dtype=torch.bfloat16,
)

# The tokenizer comes from a separate checkpoint; its end-of-text token is
# reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token

# Helper imported from the local `generation` module (its checks are not
# visible in this notebook).
test_tokenization(tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:
# Shrink the network: keep just the first 4 of the backbone's 24 layers.
model.backbone.layers = model.backbone.layers[:4]

In [3]:
def evaluate_example(model, ex, device=None):
    """Check whether the model's greedy prediction matches the labels exactly.

    Args:
        model: Callable module invoked as ``model(input_ids=...)`` whose
            result exposes a ``.logits`` tensor with the vocabulary on the
            last (third) axis.
        ex: Mapping with tensor entries ``"input_ids"`` and ``"labels"``.
            ``"input_ids"`` is reshaped into a single-row batch; ``"labels"``
            is compared element-wise against the flattened per-position
            argmax of the logits.
        device: Device to run on. When ``None`` (default) the device of the
            model's first parameter is used, so a model living on the GPU is
            evaluated on the GPU exactly as before.

    Returns:
        bool: ``True`` only if every predicted token id equals its label.
    """
    if device is None:
        # Follow the model instead of hard-coding "cuda".
        device = next(model.parameters()).device

    input_ids = ex["input_ids"].reshape(1, -1).to(device)
    labels = ex["labels"].to(device)

    # Evaluation only: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits

    output_ids = logits.argmax(dim=2).flatten()
    # Reduce on the tensor in one step rather than iterating it element by
    # element with Python's all(); bool() yields a plain Python value.
    return bool((labels == output_ids).all())

In [4]:
# Current difficulty level: each round uses tasks with 1..task_steps_limit steps.
task_steps_limit = 3
# Optional upper bound on the difficulty level. None keeps the original
# behaviour (train until the kernel is interrupted); set an int so that a
# full "Restart & Run All" can finish.
max_task_steps_limit = None
# Fraction of evaluation examples that must be answered correctly before the
# difficulty level is raised.
advance_threshold = 0.9

trainer = DirectReasoningTrainer(
    model=model,
    tokenizer=tokenizer,
    args=TrainingArguments(
        disable_tqdm=True,  # This disables the progress bars
        learning_rate=5e-4,
        num_train_epochs=1,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,
        dataloader_num_workers=2,
        optim="adamw_torch",
        output_dir="out",
        weight_decay=1e-2,
        # otherwise transformers will remove "labels" item for some reason
        remove_unused_columns=False,
    ),
)
trainer.answer_token = tokenizer.encode("\n")[0]
# Disable log printing. The stub accepts any arguments because the number of
# arguments the Trainer passes to `log` is not the same in every transformers
# release (newer ones pass an extra start time); a one-argument lambda would
# then raise TypeError.
trainer.log = lambda *args, **kwargs: None


total_examples = 0
while max_task_steps_limit is None or task_steps_limit <= max_task_steps_limit:
    # Fresh training set every round: 16 examples for each number of steps
    # up to the current limit, trained on for one epoch.
    _num_examples_per_num_steps = [(ns, 16) for ns in range(1, task_steps_limit + 1)]
    trainer.train_dataset = DirectTasksDataset(tokenizer, _num_examples_per_num_steps)
    total_examples += len(trainer.train_dataset)
    trainer.train()

    # for each steps length, check whether model answers correctly
    _num_examples_per_num_steps = [(ns, 1) for ns in range(1, task_steps_limit + 1)]
    eval_dataset = DirectTasksDataset(tokenizer, _num_examples_per_num_steps)
    scores = [evaluate_example(model, ex) for ex in eval_dataset]
    print(
        f"{total_examples:9}  seq.len.: {task_steps_limit:3}  "
        + get_accuracy_bar(scores)
    )

    if np.mean(scores) > advance_threshold:
        # Accuracy is above the threshold, so increase the difficulty level.
        # With one example per step count this requires every answer to be
        # correct while there are at most 10 of them; beyond that some
        # misses are tolerated.
        task_steps_limit += 1

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


       48  seq.len.:   3  «   »
       96  seq.len.:   3  «██ »
      144  seq.len.:   3  «█  »
      192  seq.len.:   3  «██ »
      240  seq.len.:   3  «███»
      304  seq.len.:   4  «███ »
      368  seq.len.:   4  «███ »
      432  seq.len.:   4  «███ »
      496  seq.len.:   4  «███ »
      560  seq.len.:   4  «███ »
      624  seq.len.:   4  «███ »
      688  seq.len.:   4  «███ »
      752  seq.len.:   4  «███ »
      816  seq.len.:   4  «███ »
      880  seq.len.:   4  «████»
      960  seq.len.:   5  «████ »
     1040  seq.len.:   5  «████ »
     1120  seq.len.:   5  «████ »
     1200  seq.len.:   5  «████ »
     1280  seq.len.:   5  «█████»
     1376  seq.len.:   6  «█████ »
     1472  seq.len.:   6  «████  »
     1568  seq.len.:   6  «████  »
     1664  seq.len.:   6  «████  »
     1760  seq.len.:   6  «████  »
     1856  seq.len.:   6  «█████ »
     1952  seq.len.:   6  «██████»
     2064  seq.len.:   7  «█████  »
     2176  seq.len.:   7  «█████  »
     2288  seq.len.:   