In [1]:
import torch
import os
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, TrainingArguments, TrainerCallback, Trainer
from transformers import DataCollatorForLanguageModeling

from generation import generate_task, mask_all_values

# Set before any tokenizer is created so the `tokenizers` library reads it.
# NOTE(review): the training cells use dataloader_num_workers=2; confirm "true" is intended there.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Token id that MambaTrainer.compute_loss searches for to find the "answer" marker.
# NOTE(review): presumably the id the tokenizer below assigns to "answer" — confirm.
answer_token = 31984

# Pretrained Mamba LM head model, loaded in bfloat16 directly onto the GPU.
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", dtype=torch.bfloat16, device="cuda")
# Tokenizer is taken from the GPT-NeoX-20B checkpoint.
# NOTE(review): presumably the vocabulary the Mamba checkpoint was trained with — confirm.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# Reuse the end-of-text token as the padding token (MyDataset tokenizes with padding=True).
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token
# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

ModuleNotFoundError: No module named 'torch'

In [2]:
# Keep only the first 8 backbone layers (the original comment states the
# checkpoint has 24); slicing keeps the container type of `layers`.
model.backbone.layers = model.backbone.layers[:8]

In [3]:
class MambaTrainer(Trainer):
    """Trainer that scores only the part of each example from the "answer" marker on.

    Relies on the notebook-level ``answer_token`` id to locate where the task
    description ends inside every tokenized example.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return reasoning loss plus final-token loss for one batch.

        Args:
            model: callable that, given a batch of token ids, returns an object
                exposing ``logits``.
            inputs: mapping holding an ``"input_ids"`` tensor, indexed here as
                (batch, sequence). The mapping is not modified.
            return_outputs: when true, return ``(loss, outputs)`` instead of
                the bare loss, as the ``Trainer`` contract expects.
            num_items_in_batch: accepted for compatibility with ``Trainer``
                versions that pass this keyword; unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is true.

        Raises:
            ValueError: if an example does not contain ``answer_token``.
        """
        input_ids = inputs["input_ids"]

        # One forward pass over the whole batch.
        outputs = model(input_ids)
        lm_logits = outputs.logits

        # Next-token prediction: the logits at position j are scored against
        # the token at position j + 1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = input_ids.to(lm_logits.device)[:, 1:].contiguous()

        reasoning_shift_logits = []
        reasoning_labels = []
        final_answer_shift_logits = []
        final_answer_labels = []
        for ex_shift_logits, ex_labels in zip(shift_logits, labels):
            # Positions whose target is the "answer" marker token.
            answer_positions = torch.where(ex_labels == answer_token)[0]
            if answer_positions.numel() == 0:
                raise ValueError(
                    f"answer token {answer_token} not found in a training example"
                )
            # Use the first marker; everything before it is the task text and
            # is excluded from the loss.
            answer_index = int(answer_positions[0])

            # Reasoning part: from the marker position up to, but excluding,
            # the last position.
            reasoning_shift_logits.append(ex_shift_logits[answer_index:-1])
            reasoning_labels.append(ex_labels[answer_index:-1])
            # The last position is scored separately so it carries the same
            # weight as the whole (averaged) reasoning part.
            # NOTE(review): if a batch is right-padded, the last position is a
            # pad token rather than the real final answer — confirm all
            # examples in a batch have equal length.
            final_answer_shift_logits.append(ex_shift_logits[-1:])
            final_answer_labels.append(ex_labels[-1:])

        loss_fct = torch.nn.CrossEntropyLoss()
        reasoning_lm_loss = loss_fct(
            torch.cat(reasoning_shift_logits),
            torch.cat(reasoning_labels),
        )
        final_answer_lm_loss = loss_fct(
            torch.cat(final_answer_shift_logits),
            torch.cat(final_answer_labels),
        )
        loss = reasoning_lm_loss + final_answer_lm_loss
        return (loss, outputs) if return_outputs else loss

    def save_model(self, output_dir=None, _internal_call=False):
        """Write the model's state dict and the tokenizer files to ``output_dir``.

        Args:
            output_dir: target directory; falls back to ``self.args.output_dir``
                when omitted, so ``trainer.save_model()`` works.
            _internal_call: accepted for signature compatibility with
                ``Trainer.save_model``; unused.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(output_dir, exist_ok=True)

        torch.save(self.model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

    # def log(self, logs):
    #     pass  # Override to do nothing and avoid printing logs

NameError: name 'Trainer' is not defined

In [4]:
class MyDataset(torch.utils.data.Dataset):
    """In-memory dataset of freshly generated, tokenized task/reasoning examples."""

    def __init__(self, mask, num_examples_per_num_steps):
        """
        Generate and tokenize all examples up front.

        mask: when truthy, ``mask_all_values`` is called on each reasoning
            before it is formatted into the example text.
        num_examples_per_num_steps: list of tuples (num_steps, num_examples)
        """
        prompts = []
        for step_count, example_count in num_examples_per_num_steps:
            for _ in range(example_count):
                task, reasoning = generate_task(step_count)
                if mask:
                    # The return value is not used.
                    # NOTE(review): this only has an effect if mask_all_values
                    # mutates `reasoning` in place — confirm.
                    mask_all_values(reasoning)
                prompts.append(f"{task}\nanswer\n{reasoning}")

        # padding=True is forwarded to the notebook-level tokenizer for the
        # whole list of texts at once.
        encoded = tokenizer(prompts, padding=True)
        self.input_ids = [
            torch.tensor(ids, dtype=torch.long) for ids in encoded["input_ids"]
        ]

    def __len__(self):
        """Number of stored examples."""
        return len(self.input_ids)

    def __getitem__(self, i):
        """Return ``{"input_ids": ...}`` for index ``i``; ``i`` is forwarded to the list as-is."""
        return {"input_ids": self.input_ids[i]}

NameError: name 'torch' is not defined

In [51]:
def get_completion_from_model(example_input_ids):
    """Let the model regenerate everything after the "answer" marker.

    The example is decoded, cut right after the word "answer", re-tokenized
    and used as the prompt. Generation is capped at the token length of the
    original example. Returns the first generated sequence.

    The decoded text must contain "answer" exactly once, otherwise the
    two-way unpacking below raises.
    """
    decoded = tokenizer.decode(example_input_ids)
    task, _ = decoded.split("answer")
    prompt = task + "answer"
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device="cuda")

    generated = model.generate(
        input_ids=prompt_ids,
        max_length=len(example_input_ids),
        cg=True,
        temperature=1,
    )
    return generated[0]

In [55]:
# Scratch check of the prompt-building step: decode one example and keep the
# text up to and including the "answer" marker.
# NOTE(review): `ex` is not defined in any cell visible here — it must come
# from earlier kernel state.
example_input_ids = ex["input_ids"]
# The two-way unpacking requires the decoded text to contain "answer" exactly once.
task, _ = tokenizer.decode(example_input_ids).split("answer")
task += "answer"
# Bare last expression: the cell displays the prompt string.
task
# task_tokens = tokenizer(task, return_tensors="pt")
# task_tokens

'1 +=3 mod5 +=3 +=2 mod2 *=2 +=3 +=2 mod2 mod5 mod5 *=2 +=3 +=4 mod2 *=2 *=3 +=1 +=1 mod5 mod2 +=1 +=3 +=1 +=2 mod5 mod2 *=2 +=2 +=3 +=2 mod2 mod2 mod2 +=3 *=2 mod5 +=2 +=3 mod5 mod2 +=2 +=1 +=3 +=1 mod5 +=4 +=1 mod5 *=2 mod5 +=2 *=2 mod5 mod5 *=2 *=3 +=1 mod2\nanswer'

In [None]:


def eval_correctness(dataset):
    """Return the fraction of examples whose generated reasoning matches the target.

    For every example the text after the "answer" marker is compared,
    whitespace-stripped, between the decoded target and the decoded output of
    ``get_completion_from_model``.

    Args:
        dataset: either an object whose ``dataset[:]["input_ids"]`` yields the
            token-id sequences of all examples, or a plain list/tuple of such
            sequences.

    Returns:
        A float in [0, 1]; ``0.0`` when there are no examples.
    """
    if isinstance(dataset, (list, tuple)):
        all_input_ids = dataset
    else:
        all_input_ids = dataset[:]["input_ids"]

    # maybe todo: parallelize
    is_corrects = []
    for full_input_ids in all_input_ids:
        # Target: everything after the last "answer" in the stored example.
        target_reasoning = tokenizer.decode(full_input_ids).split("answer")[-1]
        # Prediction: the model's continuation of the task prompt.
        generated_ids = get_completion_from_model(full_input_ids)
        model_reasoning = tokenizer.decode(generated_ids).split("answer")[-1]

        is_correct = model_reasoning.strip() == target_reasoning.strip()
        is_corrects.append(is_correct)

    if not is_corrects:
        # Avoid dividing by zero on an empty evaluation set.
        return 0.0
    return sum(is_corrects) / len(is_corrects)

In [41]:
# Curriculum training: train on freshly generated tasks with `seq_length`
# steps, evaluate on a small fresh set, and lengthen the tasks by one step
# whenever accuracy reaches 95%.
seq_length = 59
mask = True

trainer = MambaTrainer(
    model=model,
    tokenizer=tokenizer,
    args=TrainingArguments(
        disable_tqdm=True,  # This disables the progress bars
        learning_rate=5e-4,
        num_train_epochs=1,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,
        dataloader_num_workers=2,
        optim="adamw_torch",
        output_dir="out",
        weight_decay=1e-2,
        # logging_steps=10,
        # save_strategy="epoch"
    ),
)

i = 1
examples_per_epoch = 16 * 160
num_eval_examples = 25
# Runs until interrupted from the notebook (there is no exit condition).
while True:
    # MyDataset takes (num_steps, num_examples) pairs; one pair per round.
    train_dataset = MyDataset(
        mask=mask,
        num_examples_per_num_steps=[(seq_length, examples_per_epoch)],
    )

    trainer.train_dataset = train_dataset
    trainer.train()

    eval_dataset = MyDataset(
        mask=mask,
        num_examples_per_num_steps=[(seq_length, num_eval_examples)],
    )
    perc_correct = eval_correctness(eval_dataset)
    # round(), not int(): fraction * count can land just below the true
    # integer in floating point, and int() would then drop one hit.
    num_correct = round(perc_correct * num_eval_examples)
    accuracy_bar = "«" + "█" * num_correct + " " * (num_eval_examples - num_correct) + "»"
    total_examples = i * examples_per_epoch
    print(f"{total_examples:9}  seq.len.: {seq_length:3}  " + accuracy_bar)
    i += 1

    # Make the task one step longer once the current length is mastered.
    if perc_correct >= 0.95 and seq_length < 100:
        seq_length += 1

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


KeyboardInterrupt: 

In [10]:
# Accuracy sweep over task lengths around the one the model was trained on.
# The loop variable has its own name so the sweep does not overwrite the
# training cell's `seq_length`.
sweep_eval_examples = 50
for eval_seq_length in range(55, 63):
    # MyDataset takes (num_steps, num_examples) pairs.
    eval_dataset = MyDataset(
        mask=mask,
        num_examples_per_num_steps=[(eval_seq_length, sweep_eval_examples)],
    )
    perc_correct = eval_correctness(eval_dataset)
    # round(), not int(): e.g. 0.58 * 50 evaluates to 28.999999999999996,
    # which int() would truncate to 28.
    num_correct = round(perc_correct * sweep_eval_examples)
    accuracy_bar = "«" + "█" * num_correct + " " * (sweep_eval_examples - num_correct) + "»"
    print(f"seq.len.: {eval_seq_length:3}  " + accuracy_bar)

seq.len.:  55  «                                                  »
seq.len.:  56  «                                                  »
seq.len.:  57  «███████████                                       »
seq.len.:  58  «███████████████████████████                       »
seq.len.:  59  «███████████████████████████████████               »
seq.len.:  60  «███████████████████████                           »
seq.len.:  61  «                                                  »
seq.len.:  62  «                                                  »


In [13]:
# Manual walk-through of one evaluation example: rebuild the prompt, generate,
# and print the model output next to the target.
full_input_ids = eval_dataset[0]["input_ids"]
# The two-way unpacking requires the decoded text to contain "answer" exactly once.
task, target_reasoning = tokenizer.decode(full_input_ids).split("answer")
task += "answer"

# Re-tokenize only the prompt (task text plus the "answer" marker).
task_tokens = tokenizer(task, return_tensors="pt")
input_ids = task_tokens.input_ids.to(device="cuda")
# attn_mask = task_tokens.attention_mask.to(device="cuda")
# Cap generation at the token length of the stored example.
max_length = len(full_input_ids)

out = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    cg=True,
    temperature=1,
)
text = tokenizer.decode(out[0])
# Everything after the last "answer" in the generated text.
model_reasoning = text.split("answer")[-1]
# is_correct = model_reasoning.strip() == target_reasoning.strip()
print(text)
print(target_reasoning)



1 +=1 +=4 +=2 +=1 mod2 mod5 mod2 mod5 *=3 +=1 +=2 +=2 mod5 mod2 mod2 *=2 mod5 +=1 +=1 +=3 mod5 +=2 mod5 +=3 +=1 mod5 +=4 +=1 +=1 mod5 +=2 +=1 mod5 +=2 +=4 mod5 +=1 *=2 mod5 +=4 +=1 mod5 +=1 +=3 mod2 +=1 +=2 +=1 +=3 mod5 +=4 mod5 +=4 mod5 mod2 +=4 +=3 +=1 mod5 +=2 +=3 mod2
answer
1 +=1 0 +=4 0 +=2 0 +=1 0 mod2 0 mod5 0 mod2 0 mod5 0 *=3 0 +=1 0 +=2 0 +=2 0 mod5 0 mod2 0 mod2 0 *=2 0 mod5 0 +=1 0 +=1 0 +=3 0 mod5 0 +=2 0 mod5 0 +=3 0 +=1 0 mod5 0 +=4 0 +=1 0 +=1 0 mod5 0 +=2 0 +=1 0 mod5 0 +=2 0 +=4 0 mod5 0 +=1 0 *=2 0 mod5 0 +=4 0 +=1 0 mod5 0 +=1 0 +=3 0 mod2 0 +=1 0 +=2 0 +=1 0 +=3 0 mod5 0 +=4 0 mod5 0 mod2 0 +=4 0 +=4 0 mod5 0 +=2 0 +=3 0 mod2 1 1 += concess 0 d d d d d

1 +=1 0 +=4 0 +=2 0 +=1 0 mod2 0 mod5 0 mod2 0 mod5 0 *=3 0 +=1 0 +=2 0 +=2 0 mod5 0 mod2 0 mod2 0 *=2 0 mod5 0 +=1 0 +=1 0 +=3 0 mod5 0 +=2 0 mod5 0 +=3 0 +=1 0 mod5 0 +=4 0 +=1 0 +=1 0 mod5 0 +=2 0 +=1 0 mod5 0 +=2 0 +=4 0 mod5 0 +=1 0 *=2 0 mod5 0 +=4 0 +=1 0 mod5 0 +=1 0 +=3 0 mod2 0 +=1 0 +=2 0 +=1 0 +=3 0 mod

In [None]:
# Accuracy on the first 20 evaluation examples.
# eval_dataset[:20]["input_ids"] is a plain list of token-id tensors
# (MyDataset.__getitem__ forwards the slice to its list).
# NOTE(review): confirm eval_correctness accepts such a list rather than a dataset.
eval_correctness(eval_dataset[:20]["input_ids"])

In [16]:
# trainer.save_model("output", None)

In [28]:
# 50 masked evaluation examples with 57 steps each; MyDataset takes
# (num_steps, num_examples) pairs.
eval_dataset = MyDataset(mask=True, num_examples_per_num_steps=[(57, 50)])

In [29]:
# Not the last expression of the cell, so this slice is evaluated but not displayed.
eval_dataset[0:4]

# Fresh trainer with the same hyper-parameters as the training cell.
training_args = TrainingArguments(
    disable_tqdm=True,  # This disables the progress bars
    learning_rate=5e-4,
    num_train_epochs=1,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    dataloader_num_workers=2,
    optim="adamw_torch",
    output_dir="out",
    # logging_steps=10,
    weight_decay=1e-2,
    # save_strategy="epoch"
)
trainer = MambaTrainer(model=model, tokenizer=tokenizer, args=training_args)

[tensor([   18,   771,    19,  7079,    21,  7079,    21,   771,    22, 38825,
            19,   771,    22,  7079,    20,  7079,    20,   771,    19,  7079,
            21,  7079,    21,   771,    19, 38825,    20,  7079,    18, 38825,
            19,   771,    22, 38825,    20,   771,    22,  7079,    18,  7079,
            19,  7079,    18,   771,    22,   771,    22,  7079,    20,  7079,
            20,   771,    22,  7079,    18,   771,    19, 38825,    19,   771,
            22,  7079,    21,  7079,    20,   771,    22,  7079,    19,   771,
            22, 38825,    20,  7079,    20,   771,    22, 38825,    19,  7079,
            21,  7079,    18,   771,    19,  7079,    19, 38825,    20,   771,
            22,  7079,    20,  7079,    18,   771,    22,  7079,    19,  7079,
            18,  7079,    18,  7079,    18,  7079,    18,   771,    19, 38825,
            20,  7079,    20,  7079,    19,   187, 31984,   187,    18,   771,
            19,   470,  7079,    21,   470,  7079,  