In [1]:
import os
import torch
import numpy as np
from transformers import TrainingArguments, GPT2Tokenizer, GPT2LMHeadModel

from generation import DirectTasksDataset, test_tokenization
from train_helpers import DirectReasoningTrainer, get_accuracy_bar, Curriculum

# Opt in to parallelism inside the Hugging Face tokenizers backend.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Both the weights and the vocabulary come from the same pretrained checkpoint.
PRETRAINED_NAME = 'gpt2'

model = GPT2LMHeadModel.from_pretrained(PRETRAINED_NAME)
model = model.to("cuda")
tokenizer = GPT2Tokenizer.from_pretrained(PRETRAINED_NAME)

# Reuse the end-of-text marker as the padding token, then run the project's
# tokenizer check from `generation`.
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token
test_tokenization(tokenizer)

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [2]:
# use only the first few layers out of 12
NUM_LAYERS_KEPT = 3

# Slicing an nn.ModuleList returns a new ModuleList; assigning it through the
# public attribute registers it as the `h` submodule, so there is no need to
# reach into the private `_modules` dictionaries.
model.transformer.h = model.transformer.h[:NUM_LAYERS_KEPT]

# Keep the config consistent with the real depth. Otherwise a checkpoint saved
# from this model would declare 12 layers and reload with the missing blocks
# freshly (randomly) initialised.
model.config.n_layer = len(model.transformer.h)

In [3]:
def evaluate_example(model, ex):
    """Check whether the model's greedy prediction reproduces an example's labels.

    Args:
        model: callable module returning an object with a ``.logits`` tensor;
            it is invoked with ``input_ids`` and ``attention_mask`` keywords.
        ex: mapping with tensor entries ``"input_ids"``, ``"attention_mask"``
            and ``"labels"``.

    Returns:
        bool: True when every label equals the argmax token at the same
        position, False otherwise.

    NOTE(review): labels are compared position-by-position with the argmax of
    the logits, without any shift or masking of ignored positions — this
    assumes the dataset's labels are already aligned that way; confirm against
    DirectTasksDataset.
    """
    # Run on whatever device the model lives on instead of hard-coding "cuda".
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    input_ids = ex["input_ids"].reshape(1, -1).to(device)
    attention_mask = ex["attention_mask"].reshape(1, -1).to(device)
    labels = ex["labels"].to(device)

    # The training loop leaves the model in train mode; switch to eval mode so
    # dropout does not randomise the prediction, and restore the mode afterwards.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for scoring, so skip building the graph.
        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    finally:
        model.train(was_training)

    output_ids = logits.argmax(dim=2).flatten()
    # tokenizer.decode(output_ids) is handy for inspecting a wrong prediction.
    # Reduce on-device and convert once, rather than iterating the tensor in Python.
    return bool((labels == output_ids).all())

In [4]:
task_steps_limit = 1

# Hyper-parameters for a single training round; every round is one epoch over
# a freshly sampled dataset.
training_args = TrainingArguments(
    output_dir="out",
    num_train_epochs=1,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    dataloader_num_workers=2,
    optim="adamw_torch",
    learning_rate=5e-4,
    weight_decay=1e-2,
    disable_tqdm=True,  # This disables the progress bars
)
trainer = DirectReasoningTrainer(model=model, tokenizer=tokenizer, args=training_args)
# The first token id of the encoded newline is used as the trainer's answer token.
trainer.answer_token = tokenizer.encode("\n")[0]
# Replace the trainer's log hook with a no-op so nothing is printed per step.
trainer.log = lambda logs: None

curriculum = Curriculum()
total_examples = 0
# Runs until interrupted: train on sampled tasks, evaluate, adapt the curriculum.
while True:
    # Train on 128 task lengths sampled by the curriculum.
    task_lenghts = curriculum.sample_indexes(128)
    trainer.train_dataset = DirectTasksDataset(tokenizer, task_lenghts)
    total_examples += len(trainer.train_dataset)
    trainer.train()

    # Evaluate one example for each step count currently tracked by the curriculum.
    task_steps_limit = len(curriculum.avg_scores)
    task_lenghts = list(range(task_steps_limit))
    eval_dataset = DirectTasksDataset(tokenizer, task_lenghts)
    scores = []
    for example in eval_dataset:
        scores.append(evaluate_example(model, example))
    curriculum.update_scores(scores)

    print(f"{total_examples:9}  seq.len.: {task_steps_limit:3}  {get_accuracy_bar(scores)}")

    # Raise the difficulty once more than 90% of the evaluated lengths are solved.
    mean_score = np.mean(scores)
    if mean_score > 0.9:
        curriculum.increment_limit()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


      128  seq.len.:   1  « »
      256  seq.len.:   1  «█»
      384  seq.len.:   2  « █»
      512  seq.len.:   2  «██»
      640  seq.len.:   3  «██ »
      768  seq.len.:   3  «███»
      896  seq.len.:   4  «██  »
     1024  seq.len.:   4  «███ »
     1152  seq.len.:   4  «███ »
     1280  seq.len.:   4  «███ »
     1408  seq.len.:   4  «███ »
     1536  seq.len.:   4  «████»
     1664  seq.len.:   5  «████ »
     1792  seq.len.:   5  «████ »
     1920  seq.len.:   5  «███  »
     2048  seq.len.:   5  «█ ██ »
     2176  seq.len.:   5  «████ »
     2304  seq.len.:   5  «██ █ »
     2432  seq.len.:   5  «████ »
     2560  seq.len.:   5  «███  »
     2688  seq.len.:   5  «█████»
     2816  seq.len.:   6  «████  »
     2944  seq.len.:   6  «█████ »
     3072  seq.len.:   6  «█████ »
     3200  seq.len.:   6  «█████ »
     3328  seq.len.:   6  «█████ »
     3456  seq.len.:   6  «████ █»
     3584  seq.len.:   6  «████  »
     3712  seq.len.:   6  «█████ »
     3840  seq.len.:   6  «███