In [None]:
import os
import torch
import numpy as np
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, TrainingArguments, TrainerCallback, Trainer

from generation import generate_task, mask_all_values, TasksDataset

# Explicitly allow the HF `tokenizers` library to use parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Token id of the "answer" token that separates the task prompt from the
# reasoning (the loss below is computed only from this token on). The value is
# tied to the tokenizer loaded below.
answer_token = 31984

# Pretrained Mamba LM in bf16 on the GPU, plus the GPT-NeoX tokenizer whose EOS
# and padding token are both set to "<|endoftext|>".
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m",
    device="cuda",
    dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token

# Sanity-check the tokenization on one random task: a task with N steps must
# encode to exactly 2*N + 1 tokens and its reasoning to exactly 3*N + 1 tokens,
# both before and after masking the values.
num_steps = 13
task, reasoning = generate_task(num_steps)
assert len(tokenizer.encode(" ".join(task))) == 2 * num_steps + 1
assert len(tokenizer.encode(" ".join(reasoning))) == 3 * num_steps + 1
reasoning = mask_all_values(reasoning)
assert len(tokenizer.encode(" ".join(reasoning))) == 3 * num_steps + 1


In [None]:
# Use only the first 4 mixer layers of the pretrained backbone (mamba-130m has 24).
# Slicing keeps the layers' pretrained weights; the rest are dropped.
model.backbone.layers = model.backbone.layers[:4]

In [None]:
final_answer_loss_contribution = 0.5  # between 0 and 1


class MambaTrainer(Trainer):
    """``Trainer`` that trains only on the reasoning part of each example.

    The loss is a weighted sum of two mean cross-entropies:

    * the reasoning loss, over the next-token predictions from the "answer"
      token up to (but excluding) the last position, weighted by
      ``1 - final_answer_loss_contribution``;
    * the final-answer loss, over the last position of every sequence,
      weighted by ``final_answer_loss_contribution``.

    Predictions inside the task prompt (before "answer") contribute no loss.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the mixed reasoning / final-answer loss for one batch.

        Args:
            model: language model called as ``model(input_ids)``; its output
                must have a ``logits`` attribute.
            inputs: batch dict; ``"input_ids"`` is popped from it.
            return_outputs: if True, return ``(loss, outputs)`` as ``Trainer``
                expects, instead of only the loss.
            num_items_in_batch: passed by newer ``transformers`` versions.
                Unused, because both loss terms are plain means.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``.

        Raises:
            ValueError: if an example does not contain exactly one answer token.
        """
        input_ids = inputs.pop("input_ids")

        # batched forward pass
        outputs = model(input_ids)
        lm_logits = outputs.logits

        # The logits at position i predict token i + 1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = input_ids.to(lm_logits.device)[:, 1:].contiguous()

        # NOTE(review): the last position of every row is treated as the final
        # answer, which assumes rows are not right-padded. Confirm that
        # TasksDataset / the data collator guarantee this.
        reasoning_logits, reasoning_labels = [], []
        final_answer_logits, final_answer_labels = [], []
        for ex_logits, ex_labels in zip(shift_logits, labels):
            answer_positions = torch.where(ex_labels == answer_token)[0]
            if answer_positions.numel() != 1:
                raise ValueError(
                    f"expected exactly one answer token (id {answer_token}) per example, "
                    f"found {answer_positions.numel()}"
                )
            answer_index = int(answer_positions[0])
            # Drop the task prompt. The slice starts with the prediction of the
            # "answer" token itself; the last position is scored separately.
            reasoning_logits.append(ex_logits[answer_index:-1])
            reasoning_labels.append(ex_labels[answer_index:-1])
            final_answer_logits.append(ex_logits[-1:])
            final_answer_labels.append(ex_labels[-1:])

        reasoning_loss = torch.nn.functional.cross_entropy(
            torch.cat(reasoning_logits),
            torch.cat(reasoning_labels),
        )
        final_answer_loss = torch.nn.functional.cross_entropy(
            torch.cat(final_answer_logits),
            torch.cat(final_answer_labels),
        )
        loss = (
            reasoning_loss * (1 - final_answer_loss_contribution)
            + final_answer_loss * final_answer_loss_contribution
        )
        return (loss, outputs) if return_outputs else loss

    def save_model(self, output_dir=None, _internal_call=False):
        """Save ``self.model``'s state dict and the tokenizer to ``output_dir``.

        Writes a plain ``torch.save`` of the state dict as
        ``pytorch_model.bin``. NOTE(review): presumably this override exists
        because the mamba_ssm model cannot go through the default
        ``Trainer`` saving path; confirm.

        The signature mirrors ``Trainer.save_model``, so both arguments are
        optional and ``trainer.save_model()`` works.

        Args:
            output_dir: target directory; defaults to ``self.args.output_dir``.
            _internal_call: set by ``Trainer`` for checkpoint saves; unused.
        """
        # Only the process designated by the Trainer writes files.
        if not self.args.should_save:
            return
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        torch.save(self.model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

In [None]:
def evaluate_example(model, ex, device="cuda"):
    """Return True if ``model`` reproduces an example's task + reasoning exactly.

    The model is prompted with ``ex["task_ids"]`` and asked to generate up to
    the combined length of task and reasoning, with ``cg=True`` and
    ``temperature=1``. NOTE(review): the sampling strategy otherwise comes
    from ``model.generate``'s defaults (presumably greedy); confirm.

    Args:
        model: object exposing ``generate(input_ids=..., max_length=..., cg=..., temperature=...)``.
        ex: mapping with ``"task_ids"`` and ``"reasoning_ids"`` tensors
            (assumed 1-D, since they are concatenated and the task is reshaped
            to a single row).
        device: device the prompt and target are moved to.

    Returns:
        True only if the generated sequence has the target's length and matches
        it token for token.
    """
    task_ids = ex["task_ids"]
    target_output = torch.cat([task_ids, ex["reasoning_ids"]]).to(device)
    out = model.generate(
        input_ids=task_ids.reshape(1, -1).to(device),
        max_length=len(target_output),
        cg=True,
        temperature=1,
    )
    generated = out[0]
    # A length mismatch (e.g. generation stopped early) counts as incorrect.
    # Without this check, `==` would raise or broadcast. Reducing on-device with
    # .all() also avoids a Python loop over tensor elements.
    return generated.shape == target_output.shape and bool((generated == target_output).all())


def eval_correctness(model, dataset, device="cuda"):
    """Return the fraction of examples in ``dataset`` that ``model`` reproduces exactly."""
    return np.mean([evaluate_example(model, ex, device=device) for ex in dataset])

In [None]:
# Curriculum: train on tasks of 1..task_steps_limit steps, and raise the limit
# once the model solves one fresh example of every length.
task_steps_limit = 1
mask = True  # passed to TasksDataset; presumably toggles value masking (see mask_all_values) — confirm
examples_per_num_steps = 16  # training examples generated per task length in each round

trainer = MambaTrainer(
    model=model,
    tokenizer=tokenizer,
    args=TrainingArguments(
        disable_tqdm=True,  # This disables the progress bars
        learning_rate=5e-4,
        num_train_epochs=1,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,
        dataloader_num_workers=2,
        optim="adamw_torch",
        output_dir="out",
        weight_decay=1e-2,
        # Logging is silenced below, so don't start wandb/tensorboard/... runs either.
        report_to="none",
    ),
)
# Disable log printing. Newer transformers versions call `log(logs, start_time)`,
# so accept and ignore any extra arguments.
trainer.log = lambda logs, *args, **kwargs: None

total_examples = 0
while True:  # runs until interrupted
    # Each call to trainer.train() runs one epoch over a fresh dataset.
    _num_examples_per_num_steps = [
        (ns, examples_per_num_steps) for ns in range(1, task_steps_limit + 1)
    ]
    trainer.train_dataset = TasksDataset(tokenizer, mask, _num_examples_per_num_steps)
    total_examples += examples_per_num_steps * task_steps_limit
    trainer.train()

    # For each task length, check whether the model answers one fresh example correctly.
    _num_examples_per_num_steps = [(ns, 1) for ns in range(1, task_steps_limit + 1)]
    eval_dataset = TasksDataset(tokenizer, mask, _num_examples_per_num_steps)
    scores = [evaluate_example(model, ex) for ex in eval_dataset]

    # One cell per task length: filled if solved, blank otherwise.
    accuracy_bar = "«" + "".join("█" if score else " " for score in scores) + "»"
    print(f"{total_examples:9}  seq.len.: {task_steps_limit:3}  " + accuracy_bar)

    if all(scores):
        # all answers were correct, so increase difficulty level
        task_steps_limit += 1