In [None]:
import torch
import os
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, TrainingArguments, TrainerCallback, Trainer
from transformers import DataCollatorForLanguageModeling

from generation import generate_task


# Global setup shared by every later cell: environment flag, marker token id,
# pretrained model and tokenizer.

# Sets the TOKENIZERS_PARALLELISM environment variable for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Token id that the loss code compares labels against to find where the
# completion starts. Presumably the id of "answer" in the gpt-neox-20b
# vocabulary -- TODO confirm against the tokenizer.
answer_token = 31984

# Pretrained 130M Mamba checkpoint, loaded in bfloat16 directly onto the GPU.
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", dtype=torch.bfloat16, device="cuda")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.eos_token = "<|endoftext|>"
# Padding reuses the EOS token, so padded positions hold the EOS id.
tokenizer.pad_token = tokenizer.eos_token
# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [None]:
class MambaTrainer(Trainer):
    """``Trainer`` subclass that scores only the completion part of each example.

    The loss is a next-token cross-entropy restricted to the positions from the
    ``answer_token`` marker onward, so the task description that precedes the
    marker never contributes to the gradient. Saving and logging are overridden
    as well: checkpoints are written as a plain ``state_dict`` and log records
    are discarded.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the cross-entropy over the tokens from the answer marker onward.

        Args:
            model: Callable that, given ``input_ids``, returns an object with a
                ``.logits`` attribute indexed as ``(batch, seq_len, vocab)``.
            inputs: Mapping with an ``"input_ids"`` tensor indexed as
                ``(batch, seq_len)``. The mapping is left unmodified.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: Accepted so that ``Trainer`` versions which pass
                this keyword do not fail; it is not used.

        Returns:
            The scalar loss tensor, or ``(loss, outputs)`` if ``return_outputs``.

        Raises:
            ValueError: If an example does not contain ``answer_token`` among
                its labels.
        """
        # Index instead of pop so the caller's dict is not mutated.
        input_ids = inputs["input_ids"]
        # One forward pass over the whole batch.
        outputs = model(input_ids)
        lm_logits = outputs.logits

        # Standard causal shift: the logits at position t predict token t + 1.
        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        # Drop the task part (everything before the marker) of every example.
        cut_shift_logits = []
        cut_labels = []
        for ex_shift_logits, ex_labels in zip(shift_logits, labels):
            answer_positions = torch.where(ex_labels == answer_token)[0]
            if answer_positions.numel() == 0:
                raise ValueError(
                    f"answer token id {answer_token} not found in example; "
                    "cannot locate the start of the completion"
                )
            # Use the first occurrence, so a marker id that shows up again
            # inside the completion does not break the slicing.
            answer_index = int(answer_positions[0])
            # The slice starts at the marker itself, so the marker is the first
            # scored target. No padding mask is applied: any padding after the
            # completion is scored too.
            cut_shift_logits.append(ex_shift_logits[answer_index:])
            cut_labels.append(ex_labels[answer_index:])

        # Mean cross-entropy over all kept positions of all examples.
        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(
            torch.cat(cut_shift_logits),
            torch.cat(cut_labels),
        )

        return (lm_loss, outputs) if return_outputs else lm_loss

    def save_model(self, output_dir=None, _internal_call=False):
        """Write the model weights and the tokenizer files to ``output_dir``.

        Args:
            output_dir: Target directory; falls back to ``self.args.output_dir``
                when omitted. Created if it does not exist.
            _internal_call: Accepted for signature compatibility with
                ``Trainer.save_model``; ignored.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        torch.save(self.model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

    def log(self, logs, *args, **kwargs):
        """Discard log records so that nothing is printed during training.

        Extra positional/keyword arguments are accepted and ignored so the
        override keeps working if the caller passes more than ``logs``.
        """
        pass

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """In-memory dataset of freshly generated, tokenized task examples.

    Every example is the text ``"<task>\\nanswer\\n<completion>"`` where the
    completion is either the masked or the normal reasoning returned by
    ``generate_task``. All texts are tokenized in a single batched call with
    ``padding=True`` and stored as one ``LongTensor`` per example.
    """

    def __init__(self, num_examples, sequence_length, mask):
        """Generate and tokenize ``num_examples`` examples.

        Args:
            num_examples: Number of examples to generate.
            sequence_length: Forwarded unchanged to ``generate_task``.
            mask: If truthy, use the masked reasoning as the completion,
                otherwise the normal reasoning.
        """

        def build_text():
            # generate_task yields (task, normal reasoning, masked reasoning).
            task, plain, masked = generate_task(sequence_length)
            completion = masked if mask else plain
            return f"{task}\nanswer\n{completion}"

        texts = [build_text() for _ in range(num_examples)]

        encoded = tokenizer(texts, padding=True)
        self.input_ids = [torch.LongTensor(ids) for ids in encoded["input_ids"]]

    def __len__(self):
        """Number of stored examples."""
        return len(self.input_ids)

    def __getitem__(self, i):
        """Return ``{"input_ids": ...}`` for index ``i``.

        ``i`` is applied directly to the underlying list, so a slice yields a
        dict whose value is a list of tensors.
        """
        return {"input_ids": self.input_ids[i]}

In [None]:
# eval correctness
def eval_correctness(all_input_ids):
    """Return the fraction of examples whose generated completion matches the target.

    Each example is decoded to text and split at the first ``"answer"``. The
    text up to and including that marker is re-tokenized and passed to
    ``model.generate``; whatever the model produces after the marker is then
    compared with the reference text that followed the marker in the example
    (exact string match after stripping surrounding whitespace).

    Args:
        all_input_ids: Iterable of token-id sequences accepted by
            ``tokenizer.decode``. Each must support ``len()`` and decode to
            text that contains ``"answer"``.

    Returns:
        Share of exactly matching examples as a float in ``[0, 1]``; ``0.0``
        if ``all_input_ids`` is empty.

    Raises:
        ValueError: If a decoded example does not contain ``"answer"``.
    """
    # todo parallelize
    is_corrects = []
    for full_input_ids in all_input_ids:
        decoded = tokenizer.decode(full_input_ids)
        # Split at the first marker only, so a later "answer" inside the
        # completion stays part of the target instead of breaking the unpack.
        task, marker, target_reasoning = decoded.partition("answer")
        if not marker:
            raise ValueError('decoded example does not contain the "answer" marker')
        task += marker

        input_ids = tokenizer(task, return_tensors="pt").input_ids.to(device="cuda")
        # Let the model generate up to the length of the reference example.
        max_length = len(full_input_ids)

        out = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            temperature=1,
        )
        text = tokenizer.decode(out[0])
        # Same rule as for the target: everything after the first marker.
        model_reasoning = text.split("answer", 1)[-1]
        is_corrects.append(model_reasoning.strip() == target_reasoning.strip())

    # Nothing to evaluate: report zero instead of dividing by zero.
    if not is_corrects:
        return 0.0

    return sum(is_corrects) / len(is_corrects)

In [None]:
# Curriculum training: train on freshly generated data for the current
# sequence length, evaluate, and move on to the next length once the model is
# accurate enough.
sequence_length = 5  # starting task length
mask = True  # train on the masked reasoning variant

NUM_TRAIN_EXAMPLES = 160  # examples generated for every training round
NUM_EVAL_EXAMPLES = 20  # examples generated for every evaluation; also the bar width
ACCURACY_TO_ADVANCE = 0.9  # minimum eval accuracy required to increase the length
MAX_SEQUENCE_LENGTH = 40  # stop once the curriculum goes past this length

trainer = MambaTrainer(
    model=model,
    tokenizer=tokenizer,
    args=TrainingArguments(
        disable_tqdm=True,  # This disables the progress bars
        learning_rate=5e-4,
        num_train_epochs=1,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,
        dataloader_num_workers=2,
        optim="adamw_torch",
        output_dir="out",
        weight_decay=1e-2,
    ),
)

# NOTE: the loop only ends when the curriculum passes MAX_SEQUENCE_LENGTH; if
# the accuracy threshold is never reached it keeps training indefinitely.
while True:
    train_dataset = MyDataset(num_examples=NUM_TRAIN_EXAMPLES, sequence_length=sequence_length, mask=mask)
    eval_dataset = MyDataset(num_examples=NUM_EVAL_EXAMPLES, sequence_length=sequence_length, mask=mask)

    # One epoch over the new data; the same trainer (and model) is reused.
    trainer.train_dataset = train_dataset
    trainer.train()

    perc_correct = eval_correctness(eval_dataset[:NUM_EVAL_EXAMPLES]["input_ids"])
    # round() rather than int(): the product of a float fraction and the count
    # must not be truncated down by floating-point error.
    num_correct = round(perc_correct * NUM_EVAL_EXAMPLES)
    accuracy_bar = "«" + "█" * num_correct + " " * (NUM_EVAL_EXAMPLES - num_correct) + "»"
    print(f"seq.len.: {sequence_length:3}   " + accuracy_bar)
    if perc_correct >= ACCURACY_TO_ADVANCE:
        sequence_length += 1
    if sequence_length > MAX_SEQUENCE_LENGTH:
        break

In [None]:
# Re-run the evaluation on the eval set left over from the last training
# round; the returned fraction is shown as the cell output.
eval_correctness(eval_dataset[:20]["input_ids"])

In [None]:
# Sanity check: answer-only loss of the current model on the first 16 eval
# examples, stacked into a single batch on the GPU. The loss is the cell output.
examples = torch.stack(eval_dataset[0:16]["input_ids"]).to("cuda")
trainer.compute_loss(model, {"input_ids": examples})

In [None]:
# Scratch cell: step-by-step replay of the answer-only loss on three eval
# examples, leaving every intermediate tensor in the kernel for inspection.
# NOTE(review): this repeats the body of MambaTrainer.compute_loss; the two
# have to be kept in sync by hand.
answer_token = 31984  # same value as assigned in the setup cell

examples = eval_dataset[0:3]["input_ids"]
input_ids = torch.stack(examples).to("cuda")
# single forward pass over the whole batch (no sampling/generation involved)
lm_logits = model(input_ids).logits

# causal shift: the logits at position t are scored against token t + 1
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()

# cut out the task part (the part before "answer")
cut_shift_logits = []
cut_labels = []
for ex_shift_logits, ex_labels in zip(shift_logits, labels):
    # find the index of the "answer" token; int() only succeeds when the token
    # occurs exactly once in the example
    answer_index = torch.where(ex_labels == answer_token)[0]
    answer_index = int(answer_index)
    # keep everything from the marker onward (the marker itself included)
    cut_shift_logits.append(ex_shift_logits[answer_index:])
    cut_labels.append(ex_labels[answer_index:])

# cross-entropy over the kept positions of all three examples concatenated
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
    torch.cat(cut_shift_logits),
    torch.cat(cut_labels),
)

In [None]:
# def compute_metrics(eval_pred):
#     # predictions, labels = eval_pred
#     # # Assuming predictions are token IDs and have already been trimmed of padding
#     # avg_num_tokens = predictions.shape[-1]  # Get the sequence length dimension
#     # return {"avg_num_tokens": avg_num_tokens}
#     return {"dummy": 3}


In [None]:
# import transformers
# from dataclasses import dataclass

# @dataclass
# class DataCollatorForSFTDataset(object):
#     """
#     Collate examples for supervised fine-tuning.
#     """

#     tokenizer: transformers.PreTrainedTokenizer

#     def __call__(self, instances):
#         input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "input_ids"))
#         input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
#         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)

#         return dict(
#             input_ids=input_ids,
#             # labels=labels,
#             # attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
#         )