In [1]:
import logging

import torch
from transformers import GPT2Tokenizer

from src import oasst, text_util
from src.model import GPT2
from src.trainer import Trainer, TrainerConfig

logging.basicConfig(level=logging.INFO)

# Seed torch's RNGs so dropout (and any torch-driven data shuffling) is
# reproducible under Restart & Run All. torch.manual_seed seeds every device.
# NOTE(review): if src.oasst / src.trainer also draw from Python's `random`
# or NumPy, those generators need seeding too — TODO confirm.
SEED = 42
torch.manual_seed(SEED)

# Train on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Total number of training samples passed to trainer.train() below.
N_SAMPLES_TRAIN = 2_000

trainer_config = TrainerConfig(
    # --- batching / data loading ---
    batch_size=8,
    gradient_acc_steps=1,
    num_workers=0,
    # presumably forwarded to torch's DataLoader, which requires
    # prefetch_factor=None whenever num_workers == 0 — TODO confirm
    prefetch_factor=None,
    pin_memory=False,
    compile=False,
    # --- optimizer and learning-rate schedule ---
    base_learning_rate=1e-4,
    min_learning_rate=1e-6,
    lr_step_size=100,  # presumably a step-decay schedule: lr *= lr_gamma every lr_step_size steps
    lr_gamma=0.75,
    weight_decay=0.01,
    betas=(0.9, 0.95),
    grad_clip=1.0,
    # --- logging / validation cadence ---
    log_interval=8,
    validation_samples=100,
    validation_interval=200,
    # --- qualitative generations produced during training ---
    generate_sample_prompts=[
        "How do I bake a cake?",
        "What are the best attractions in Rome, Italy?",
        "What does an architect do?",
    ],
    generate_max_tokens=200,
    generate_temperature=1.0,
    generate_top_k=50,
)

In [4]:
# The stock GPT-2 tokenizer defines no padding token. The project helper
# presumably registers one, mutating `tokenizer` in place (its return value
# is discarded here).
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")
text_util.add_pad_token_to_tokenizer(tokenizer)

In [5]:
# Load the oasst1 conversations as a (train, validation) pair, presumably
# tokenized with `tokenizer` — the helper logs one parse count per split.
(train_dataset, validation_dataset) = oasst.load_oasst_dataset(
    "oasst1",
    tokenizer,
)

Extracted and parsed 20147 conversations
Extracted and parsed 1002 conversations


In [6]:
# Load the pretrained "gpt2" checkpoint through the project's GPT2 wrapper,
# overriding its dropout setting to 0.1 for fine-tuning.
model = GPT2.from_pretrained(
    "gpt2",
    override_args=dict(dropout=0.1),
)

INFO:src.model:Initializing a pre-trained gpt2 model...
INFO:src.model:Overriding dropout to 0.1
INFO:src.model:Initialized GPT with 124.44 M parameters (of which 38.60 M in embeddings)
INFO:src.model:Loading pre-trained weights from HuggingFace...


In [7]:
# Derive the fine-tuning variant of the pretrained model, then extend it for a
# padding token. The second call's return value is unused, so it presumably
# mutates `fine_tuneable` in place.
# NOTE(review): this presumably resizes the token embeddings to cover the pad
# token registered by text_util.add_pad_token_to_tokenizer — TODO confirm the
# two agree on the pad token id.
fine_tuneable = model.to_fine_tuneable()
fine_tuneable.add_padding_token()

INFO:src.model:Initialized GPT with 124.44 M parameters (of which 38.60 M in embeddings)


In [8]:
# Wire the config, model, tokenizer, both data splits and the target device
# into the project's training driver (arguments passed positionally).
trainer = Trainer(
    trainer_config,
    fine_tuneable,
    tokenizer,
    train_dataset,
    validation_dataset,
    DEVICE,
)

In [None]:
# Fine-tune on N_SAMPLES_TRAIN samples.
# NOTE(review): the saved run of this cell never finished (execution count
# shows None and the log is truncated); re-run before relying on its output.
# NOTE(review): the first logged loss was ~94, far above the ≈10.8 a uniform
# guess over GPT-2's ~50k-token vocabulary would give. That suggests something
# like the new padding-token embedding's initialization or the loss
# masking/reduction deserves a look — TODO confirm.
trainer.train(N_SAMPLES_TRAIN)

INFO:src.trainer:Staring model training for 2000 samples...


🔄 iter:      0 │ 📊 samples:        8 │ 📉 loss: 94.1180 │ 📈 lr:  1.00e-04 │ ⚡    0 samples/s
🔄 iter:      1 │ 📊 samples:       16 │ 📉 loss: 67.4277 │ 📈 lr:  1.00e-04 │ ⚡    0 samples/s
🔄 iter:      2 │ 📊 samples:       24 │ 📉 loss: 49.0198 │ 📈 lr:  1.00e-04 │ ⚡    0 samples/s
🔄 iter:      3 │ 📊 samples:       32 │ 📉 loss: 38.2399 │ 📈 lr:  1.00e-04 │ ⚡    0 samples/s
🔄 iter:      4 │ 📊 samples:       40 │ 📉 loss: 31.7503 │ 📈 lr:  1.00e-04 │ ⚡    0 samples/s
🔄 iter:      5 │ 📊 samples:       48 │ 📉 loss: 27.3478 │ 📈 lr:  1.00e-04 │ ⚡    0 samples/s
🔄 iter:      6 │ 📊 samples:       56 │ 📉 loss: 24.0879 │ 📈 lr:  1.00e-04 │ ⚡    1 samples/s
🔄 iter:      7 │ 📊 samples:       64 │ 📉 loss: 21.5943 │ 📈 lr:  1.00e-04 │ ⚡    0 samples/s
🔄 iter:      8 │ 📊 samples:       72 │ 📉 loss: 10.3398 │ 📈 lr:  1.00e-04 │ ⚡    1 samples/s
🔄 iter:      9 │ 📊 samples:       80 │ 📉 loss:  5.6349 │ 📈 lr:  1.00e-04 │ ⚡    1 samples/s
🔄 iter:     10 │ 📊 samples:       88 │ 📉 loss:  4.5529 │ 📈 lr:  1.00e-04 │ ⚡    