In [None]:
import os
import torch
import numpy as np
from transformers import TrainingArguments, Trainer, GPT2Tokenizer, GPT2LMHeadModel

from generation import TasksDataset, test_tokenization
from train_helpers import ReasoningTrainer

# Set the TOKENIZERS_PARALLELISM switch (read by the Hugging Face `tokenizers`
# library) before any tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Load the pretrained GPT-2 tokenizer and weights; the model is placed on the GPU.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model = model.to("cuda")

# Pin the end-of-text token explicitly and reuse it as the padding token,
# then run the project's tokenization sanity check against this tokenizer.
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token
test_tokenization(tokenizer)

In [None]:
# Use only the first few transformer blocks out of GPT-2's 12 to get a smaller model.
NUM_LAYERS_KEPT = 3
model.transformer.h = model.transformer.h[:NUM_LAYERS_KEPT]
# Keep the config in sync with the truncated stack so that anything reading
# `config.n_layer` (e.g. a checkpoint saved and later reloaded) describes the
# architecture that is actually being trained.
model.config.n_layer = len(model.transformer.h)

In [None]:
def evaluate_example(model, ex):
    """Check whether the model reproduces an example's reasoning exactly.

    The task tokens are fed to ``model.generate`` as the prompt, generation is
    capped at the length of ``task_ids + reasoning_ids``, and the produced
    sequence is compared token-for-token with that target.

    Args:
        model: Object exposing ``generate``, ``eval``, ``train``, ``training``
            and ``device`` (e.g. the GPT-2 LM head model used in this notebook).
        ex: Mapping with tensor entries ``"task_ids"``, ``"reasoning_ids"`` and
            ``"attention_mask"``. Assumed to be 1-D token tensors, with the
            mask covering at least the task tokens -- TODO confirm against
            ``TasksDataset``.

    Returns:
        bool: True only if the generated sequence has the target length and
        every token matches the target.

    Note:
        Relies on the module-level ``tokenizer`` for the padding token id.
    """
    # Follow the model's device instead of assuming CUDA.
    device = model.device
    task_ids = ex["task_ids"]
    task_len = len(task_ids)
    target_output = torch.cat([task_ids, ex["reasoning_ids"]]).to(device)

    # generate() does not change the train/eval mode itself; evaluating right
    # after training would otherwise run with dropout active and make the
    # exact-match check noisy. Restore the previous mode afterwards so the
    # caller's state is untouched.
    was_training = model.training
    model.eval()
    try:
        out = model.generate(
            input_ids=task_ids.reshape(1, -1).to(device),
            attention_mask=ex["attention_mask"][:task_len].reshape(1, -1).to(device),
            max_length=len(target_output),
            temperature=1,
            pad_token_id=tokenizer.pad_token_id,
        )
    finally:
        model.train(was_training)

    generated = out[0]
    # A sequence that stopped early can never match the target.
    if len(generated) != len(target_output):
        return False
    return bool((generated == target_output).all())


def eval_correctness(model, dataset):
    """Return the mean of ``evaluate_example`` outcomes over ``dataset``.

    Each example is scored with ``evaluate_example(model, example)`` and the
    results are averaged with ``np.mean``, i.e. the fraction of examples the
    model reproduces exactly.
    """
    outcomes = []
    for example in dataset:
        outcomes.append(evaluate_example(model, example))
    return np.mean(outcomes)

In [None]:
# --- Curriculum configuration ---
task_steps_limit = 1        # current difficulty: largest number of task steps generated
examples_per_length = 16    # training examples generated per step count in each round
advance_threshold = 0.95    # mean eval score that must be exceeded to raise the difficulty
max_total_examples = None   # optional training budget; None runs until interrupted

trainer = ReasoningTrainer(
    model=model,
    tokenizer=tokenizer,
    args=TrainingArguments(
        disable_tqdm=True,  # This disables the progress bars
        learning_rate=5e-4,
        num_train_epochs=1,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,
        dataloader_num_workers=2,
        optim="adamw_torch",
        output_dir="out",
        weight_decay=1e-2,
    ),
)
trainer.answer_token = tokenizer.encode("\n")[0]
# Disable log printing. Accept any arguments: depending on the transformers
# version, Trainer calls log(logs) or log(logs, start_time), and a one-argument
# lambda raises TypeError on the latter.
trainer.log = lambda *args, **kwargs: None


total_examples = 0
while max_total_examples is None or total_examples <= max_total_examples:
    # Train one round on freshly generated tasks for every step count up to the limit.
    _num_examples_per_num_steps = [
        (ns, examples_per_length) for ns in range(1, task_steps_limit + 1)
    ]
    trainer.train_dataset = TasksDataset(tokenizer, _num_examples_per_num_steps)
    total_examples += examples_per_length * task_steps_limit
    trainer.train()

    # For each step count, check whether the model answers one fresh example correctly.
    _num_examples_per_num_steps = [(ns, 1) for ns in range(1, task_steps_limit + 1)]
    eval_dataset = TasksDataset(tokenizer, _num_examples_per_num_steps)
    # Training leaves the model in train mode; switch dropout off before scoring.
    # The next trainer.train() call puts the model back into train mode.
    model.eval()
    scores = [evaluate_example(model, ex) for ex in eval_dataset]

    # Build the accuracy bar: one «...» group per half of the scores, with a
    # filled cell for every correct answer and a blank for every miss.
    # NOTE(review): the halves are labelled unmasked / masked; presumably this
    # mirrors the ordering produced by TasksDataset -- confirm.
    half = len(scores) // 2
    unmasked, masked = scores[:half], scores[half:]
    accuracy_bar = "".join(
        "«" + "".join("█" if score else " " for score in group) + "»"
        for group in (unmasked, masked)
    )

    print(f"{total_examples:9}  seq.len.: {task_steps_limit:3}  " + accuracy_bar)

    if np.mean(scores) > advance_threshold:
        # (nearly) all answers were correct, so increase the difficulty level
        task_steps_limit += 1