In [None]:
import argparse
import time
import json
import os

import torch
import torch.nn.functional as F

from einops import rearrange

from transformers import AutoTokenizer, AutoModelForCausalLM

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from torch.utils.data import Dataset

from transformers import AutoTokenizer, TrainingArguments
from transformers import Trainer

from generation import generate_task

In [None]:
# --- Runtime configuration -------------------------------------------------
# Must be set before the tokenizer is first used so the setting is picked up.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

device = "cuda"  # target device for the model and the generation inputs
genlen = 100     # number of new tokens to generate in the demo cell

# --- Tokenizer -------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# Reuse EOS as the padding token so that batched tokenization with
# `padding=True` has a pad token available.
tokenizer.pad_token = tokenizer.eos_token

# --- Model -----------------------------------------------------------------
# NOTE(review): the weights are loaded in float16 and later fine-tuned as-is;
# confirm that half-precision training is intended.
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m",
    device=device,
    dtype=torch.float16,
)

In [None]:
class MambaTrainer(Trainer):
    """`Trainer` subclass that fine-tunes a Mamba language model.

    Overrides the loss computation (next-token cross-entropy computed from
    `input_ids` only) and the checkpoint writer (plain state dict, tokenizer
    files and a hard-coded `config.json`).
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the causal language-modelling loss for one batch.

        Args:
            model: The model being trained; called as `model(input_ids)` and
                expected to return an object with a `.logits` attribute.
            inputs: Batch mapping; only the `"input_ids"` entry is used. The
                targets are the same ids shifted left by one position.
            return_outputs: When true, return `(loss, outputs)` instead of the
                bare loss. The base `Trainer` requests this form during
                evaluation and unpacks the pair.
            num_items_in_batch: Accepted for compatibility with `transformers`
                releases that pass it to `compute_loss`; not used.

        Returns:
            The scalar loss tensor, or `(loss, {"logits": logits})` when
            `return_outputs` is true.
        """
        input_ids = inputs["input_ids"]
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        # Position t predicts token t + 1: drop the last logit and the first label.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        # NOTE(review): padded positions are not masked out here, so they
        # contribute to the loss — confirm this is intended.
        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        if return_outputs:
            # A dict keeps the logits addressable by key for the caller.
            return lm_loss, {"logits": lm_logits}
        return lm_loss

    def save_model(self, output_dir=None, _internal_call=None):
        """Write model weights, tokenizer files and `config.json` to a directory.

        Args:
            output_dir: Destination directory, created if missing. Falls back
                to `self.args.output_dir` when omitted.
            _internal_call: Accepted for signature compatibility with the base
                class; ignored.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(output_dir, exist_ok=True)

        torch.save(self.model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
        self.tokenizer.save_pretrained(output_dir)

        # Hard-coded copy of the upstream model configuration:
        # https://huggingface.co/state-spaces/mamba-130m/blob/main/config.json
        json_str = """
{
    "d_model": 768,
    "n_layer": 24,
    "vocab_size": 50277,
    "ssm_cfg": {},
    "rms_norm": true,
    "residual_in_fp32": true,
    "fused_add_norm": true,
    "pad_vocab_size_multiple": 8
}"""
        with open(os.path.join(output_dir, "config.json"), "w") as f:
            f.write(json_str)

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """In-memory dataset of tokenized synthetic tasks for language-model fine-tuning.

    Each example is the text ``"{task}\\nanswer\\n{reasoning}"`` built from one
    call to the task generator. All examples are tokenized in a single batch
    with padding enabled.
    """

    def __init__(self, num_examples, sequence_length, tokenize=None, make_task=None):
        """Generate and tokenize `num_examples` examples.

        Args:
            num_examples: Number of examples to generate.
            sequence_length: Passed through to the task generator.
            tokenize: Callable invoked as `tokenize(texts, padding=True)` and
                returning a mapping with an `"input_ids"` entry holding one id
                sequence per text. Defaults to the notebook-level `tokenizer`.
            make_task: Callable invoked as `make_task(sequence_length)` and
                returning a `(task, reasoning, _)` triple. Defaults to
                `generate_task`.
        """
        if tokenize is None:
            tokenize = tokenizer
        if make_task is None:
            make_task = generate_task

        texts = []
        for _ in range(num_examples):
            task, reasoning, _ = make_task(sequence_length)
            texts.append(f"{task}\nanswer\n{reasoning}")

        tokenized = tokenize(texts, padding=True)

        # Keep the per-example id lists, not the whole encoding container:
        # the container's len() counts its fields (input_ids, attention_mask,
        # ...) and integer indexing does not yield a list of token ids.
        self.input_ids = tokenized["input_ids"]
        # Labels equal the inputs; the shift for next-token prediction is
        # left to the consumer of the dataset.
        self.labels = self.input_ids

    def __len__(self):
        """Return the number of examples."""
        return len(self.input_ids)

    def __getitem__(self, i):
        """Return example `i` as a dict with `input_ids` and `labels`."""
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
  

In [None]:
# Size of each split and the length argument forwarded to the task generator.
NUM_EXAMPLES = 1000
SEQUENCE_LENGTH = 10

# The two splits are generated independently with identical settings.
train_dataset = MyDataset(num_examples=NUM_EXAMPLES, sequence_length=SEQUENCE_LENGTH)
eval_dataset = MyDataset(num_examples=NUM_EXAMPLES, sequence_length=SEQUENCE_LENGTH)

In [None]:
!pip install accelerate -U

In [None]:
import accelerate

In [None]:

# Fine-tuning hyper-parameters, kept in one place so they are easy to tune.
training_args = TrainingArguments(
    output_dir="output",
    learning_rate=5e-4,
    num_train_epochs=2,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    optim="adamw_torch",
    logging_steps=1,      # log every optimizer step
    save_steps=500,       # checkpoint interval, in steps
    save_total_limit=2,   # cap on the number of checkpoints kept
)

trainer = MambaTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

# Run fine-tuning, then write the final weights/tokenizer/config to "output".
trainer.train()
trainer.save_model("output")

In [None]:
# Sample task in the training format: the task text followed by the
# "answer" separator; the model is expected to continue with the reasoning.
prompt = "1 +=3 +=4 +=1 mod5 +=3 +=1 *=3 +=1 +=1 mod5 +=3 +=3 mod2 +=2\nanswer\n"

# Tokenize the prompt and move the tensors to the model's device.
tokens = tokenizer(prompt, return_tensors="pt")
input_ids = tokens.input_ids.to(device=device)
attn_mask = tokens.attention_mask.to(device=device)

# Total length budget = prompt tokens + `genlen` newly generated tokens.
max_length = input_ids.shape[1] + genlen

# Decoding settings. With top_k=1 only the single most likely token is kept
# at each step, so the output is effectively greedy.
# NOTE(review): `cg=True` presumably enables the library's cached/graph
# decoding path — confirm against the installed mamba_ssm version.
generation_kwargs = dict(
    cg=True,
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=False,
    temperature=1,
    top_k=1,
    top_p=1,
    min_p=0,
    repetition_penalty=1,
)

out = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    **generation_kwargs,
)

# Decode every returned sequence and show the first (only) one.
text = tokenizer.batch_decode(out.sequences.tolist())
print(text[0])