In [None]:
import argparse
import time
import json
import os

import torch
import torch.nn.functional as F

from einops import rearrange

from transformers import AutoTokenizer, AutoModelForCausalLM

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from torch.utils.data import Dataset

from transformers import AutoTokenizer, TrainingArguments
from transformers import Trainer

from generation import generate_task

In [None]:
# --- Runtime configuration shared by the cells below ---------------------
# Set before the tokenizer is created so the setting is already in the environment.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

device = "cuda"
genlen = 100  # added to the prompt length to obtain `max_length` for generation

# Checkpoint identifiers passed to `from_pretrained` below.
TOKENIZER_NAME = "EleutherAI/gpt-neox-20b"
MODEL_NAME = "state-spaces/mamba-130m"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
# NOTE(review): the weights are loaded in float16 and the same `model` object is
# later handed to the trainer - confirm half-precision weights are intended for fine-tuning.
model = MambaLMHeadModel.from_pretrained(MODEL_NAME, device=device, dtype=torch.float16)

In [None]:
# Inspect the token ids the tokenizer produces for a sample arithmetic string.
sample_expression = "5 +5 *5 mod5 5 mod5"
tokenizer.encode(sample_expression)

In [None]:
prompt = "the ultimate"

# Tokenize the prompt into PyTorch tensors and move them to the target device.
tokens = tokenizer(prompt, return_tensors="pt")
input_ids = tokens["input_ids"].to(device)
attn_mask = tokens["attention_mask"].to(device)

# Total sequence budget: prompt length plus `genlen` extra positions.
max_length = input_ids.size(1) + genlen

In [None]:
# Sampling knobs, grouped so they can be tweaked in one place.
# NOTE(review): top_k=1 presumably makes decoding greedy - confirm against the
# mamba_ssm generation utilities.
sampling_options = dict(
    temperature=1,
    top_k=1,
    top_p=1,
    min_p=0,
    repetition_penalty=1,
)

# Generate a continuation of `input_ids` up to `max_length` total positions.
out = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    cg=True,
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=False,
    **sampling_options,
)

In [None]:
# Turn the generated id tensor into plain Python lists, decode every sequence,
# and display the first decoded string.
generated_ids = out.sequences.tolist()
text = tokenizer.batch_decode(generated_ids)
text[0]

In [None]:
class MambaTrainer(Trainer):
    """``Trainer`` subclass for next-token fine-tuning of a Mamba language model.

    ``compute_loss`` derives the labels from ``input_ids`` itself (shifted by one
    position), and ``save_model`` writes the raw ``state_dict``, the tokenizer
    files and a ``config.json`` into the output directory.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Compute the next-token cross-entropy loss for one batch.

        Args:
            model: Module called as ``model(input_ids)``; its result must expose ``.logits``.
            inputs: Batch mapping containing an ``"input_ids"`` tensor. The mapping
                is read, not modified.
            return_outputs: When true, return ``(loss, outputs)`` instead of the
                bare loss, which is what ``Trainer`` expects during evaluation.
            num_items_in_batch: Accepted because recent ``transformers`` releases
                pass it to ``compute_loss``; not used here.
            **kwargs: Any further keyword arguments are accepted and ignored.

        Returns:
            The scalar loss tensor, or ``(loss, outputs)`` if ``return_outputs`` is true.
        """
        input_ids = inputs["input_ids"]
        outputs = model(input_ids)
        lm_logits = outputs.logits

        # Position t predicts token t+1: drop the last logit and the first label.
        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        # NOTE(review): no ignore_index is configured for padding, so every
        # position of the batch contributes to the loss - confirm this is intended.
        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return (lm_loss, outputs) if return_outputs else lm_loss

    def save_model(self, output_dir=None, _internal_call=None):
        """Write model weights, tokenizer files and ``config.json`` to ``output_dir``.

        Args:
            output_dir: Target directory; created if missing. Falls back to
                ``self.args.output_dir`` when omitted.
            _internal_call: Accepted for compatibility with ``Trainer.save_model``; unused.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(output_dir, exist_ok=True)

        torch.save(self.model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
        # A trainer may be constructed without a tokenizer; skip instead of crashing.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Hard-coded copy of the mamba-130m configuration; it does not adapt to
        # other checkpoints.
        # https://huggingface.co/state-spaces/mamba-130m/blob/main/config.json
        json_str = """
{
    "d_model": 768,
    "n_layer": 24,
    "vocab_size": 50277,
    "ssm_cfg": {},
    "rms_norm": true,
    "residual_in_fp32": true,
    "fused_add_norm": true,
    "pad_vocab_size_multiple": 8
}"""
        with open(os.path.join(output_dir, "config.json"), "w", encoding="utf-8") as f:
            f.write(json_str)

In [None]:
# Build the training texts: one "task / answer / reasoning" string per example.
# A leftover `break` previously stopped the loop after the first iteration, so
# `texts` held a single item regardless of `num_examples`.
num_examples = 1000
texts = []
for _example_index in range(num_examples):
    # The third value returned by generate_task is not needed here.
    task, reasoning, _unused = generate_task(5)
    text = f"{task}\nanswer\n{reasoning}"
    texts.append(text)

# Preview one sample rather than printing every generated example.
if texts:
    print(texts[0])

In [None]:
# Use the end-of-text token for padding as well.
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): `SFTDataModule` and `args` are not defined or imported in any
# cell shown in this notebook - on a fresh kernel this cell raises NameError
# unless they are provided elsewhere.
data_module = SFTDataModule(tokenizer=tokenizer, data_path=args.data_path)

# Optimisation and checkpointing settings for the run.
training_args = TrainingArguments(
    output_dir="output",
    optim="adamw_torch",
    learning_rate=5e-4,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    logging_steps=50,
    save_steps=500,
    save_total_limit=2,
)

trainer = MambaTrainer(
    model=model,
    args=training_args,
    train_dataset=data_module.dataset,
    data_collator=data_module.data_collator,
    tokenizer=tokenizer,
)

# Fine-tune, then write the final weights/tokenizer/config to `args.output`
# (checkpoints produced during training go to "output").
trainer.train()
trainer.save_model(args.output)