In [None]:
import torch
import os
import torch
import argparse
# from datasets import load_from_disk
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, TrainingArguments, TrainerCallback, Trainer
from transformers import DataCollatorForLanguageModeling

from generation import generate_task


# Explicitly enable parallelism in the Rust `tokenizers` backend.
# NOTE(review): the Trainer cell uses dataloader_num_workers=2; if worker
# processes are forked after the tokenizer has already been used, forcing
# parallelism on may risk deadlocks — confirm, or set this to "false".
os.environ.update(TOKENIZERS_PARALLELISM="true")


class MambaTrainer(Trainer):
    """Hugging Face ``Trainer`` adapted to ``MambaLMHeadModel``.

    The model is called only with ``input_ids`` and its ``.logits`` are used
    to compute a next-token cross-entropy loss here. Checkpoints are written
    as a plain ``state_dict`` (``pytorch_model.bin``) next to the tokenizer
    files.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the causal language-modeling (next-token) loss.

        Args:
            model: the model being trained; called as ``model(input_ids)`` and
                expected to return an object with a ``.logits`` attribute.
            inputs: batch mapping from the data collator. ``input_ids`` is
                required. If ``labels`` is present (e.g. from
                ``DataCollatorForLanguageModeling``, which marks padding with
                -100), it is used as the target; otherwise ``input_ids`` is.
            return_outputs: if True, return ``(loss, outputs)`` as the base
                ``Trainer`` expects during evaluation/prediction.
            **kwargs: extra keyword arguments passed by newer ``Trainer``
                versions (e.g. ``num_items_in_batch``); accepted and ignored.

        Returns:
            The scalar loss tensor, or ``(loss, outputs)`` if ``return_outputs``.
        """
        input_ids = inputs.pop("input_ids")
        labels = inputs.pop("labels", None)

        outputs = model(input_ids)
        lm_logits = outputs.logits

        # Fall back to the inputs themselves when the collator supplied no labels.
        if labels is None:
            labels = input_ids
        labels = labels.to(lm_logits.device)

        # Position t predicts token t+1: drop the last logit and the first label.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        # Targets equal to -100 (padding, as marked by the collator) are skipped.
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return (lm_loss, outputs) if return_outputs else lm_loss

    def save_model(self, output_dir=None, _internal_call=False):
        """Save the model weights and tokenizer to ``output_dir``.

        Args:
            output_dir: target directory; defaults to ``self.args.output_dir``
                so that ``trainer.save_model()`` works without arguments.
            _internal_call: accepted for signature compatibility with the base
                ``Trainer.save_model``; unused here.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(output_dir, exist_ok=True)

        torch.save(self.model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

In [None]:
# Pretrained 130M Mamba checkpoint, loaded in bfloat16 directly on the GPU.
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m",
    device="cuda",
    dtype=torch.bfloat16,
)

# Tokenizer paired with the checkpoint; padding reuses the end-of-text token.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token

# mlm=False selects causal (next-token) collation instead of masked-LM.
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)


class MyDataset(torch.utils.data.Dataset):
    """Synthetic text dataset built from a task generator.

    Each example is the generator's task string, a newline, the literal
    ``answer``, a newline, and then the generator's reasoning string. All
    examples are tokenized together with ``padding=True``, so every item has
    the same length.
    """

    def __init__(self, num_examples, sequence_length, text_tokenizer=None, task_generator=None):
        """Generate and tokenize ``num_examples`` examples.

        Args:
            num_examples: number of examples to generate.
            sequence_length: passed to the task generator for every example.
            text_tokenizer: tokenizer called as ``text_tokenizer(texts,
                padding=True)`` and returning a mapping with ``"input_ids"``.
                Defaults to the notebook-level ``tokenizer``.
            task_generator: called as ``task_generator(sequence_length)`` and
                must return a 3-tuple; defaults to ``generate_task``.
        """
        # Resolve defaults lazily so the globals are only needed when used.
        if text_tokenizer is None:
            text_tokenizer = tokenizer
        if task_generator is None:
            task_generator = generate_task

        texts = []
        for _ in range(num_examples):
            # Only the first and third generator outputs go into the text.
            task, _, reasoning = task_generator(sequence_length)
            texts.append(f"{task}\nanswer\n{reasoning}")

        tokenized = text_tokenizer(texts, padding=True)["input_ids"]
        tensors = [torch.LongTensor(tok) for tok in tokenized]

        self.input_ids = tensors
        # Labels alias the very same tensors as the inputs.
        # NOTE(review): with DataCollatorForLanguageModeling(mlm=False) the
        # collator likely rebuilds labels itself — confirm these are needed.
        self.labels = tensors

    def __len__(self):
        """Return the number of generated examples."""
        return len(self.input_ids)

    def __getitem__(self, i):
        """Return example ``i`` as ``{"input_ids": ..., "labels": ...}``."""
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


# Train and eval sets are generated independently with identical settings.
train_dataset = MyDataset(sequence_length=2, num_examples=1000)
eval_dataset = MyDataset(sequence_length=2, num_examples=1000)

In [None]:
# NOTE(review): eval_dataset is attached, but no evaluation strategy is set in
# TrainingArguments — confirm whether periodic evaluation is wanted.
trainer = MambaTrainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=TrainingArguments(
        output_dir="out",
        # optimization
        optim="adamw_torch",
        learning_rate=5e-4,
        weight_decay=1e-2,
        num_train_epochs=1,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=1,
        # data loading, logging and checkpointing
        dataloader_num_workers=2,
        logging_steps=20,
        save_strategy="epoch",
    ),
)

trainer.train()

In [None]:
# Prompt in the training text format (task, "answer" line); the model is
# expected to continue with the answer text.
prompt = "1 *=3 +=1\nanswer\n"

tokens = tokenizer(prompt, return_tensors="pt")
# NOTE(review): attn_mask is computed but not passed to model.generate().
input_ids, attn_mask = (
    tensor.to(device="cuda") for tensor in (tokens.input_ids, tokens.attention_mask)
)
# Total length budget: the prompt plus up to 50 generated tokens.
max_length = input_ids.size(1) + 50

# NOTE(review): top_k is left at generate()'s default; if that default is
# greedy, temperature has no effect — confirm.
out = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    temperature=1,
    cg=True,
)
text = tokenizer.decode(out[0])
print(text)