In [None]:
from gpt import ModelArgs, GPTModel, generate_text_simple
import tiktoken
import torch
from gpt_train import (
    text_to_token_ids,
    token_ids_to_text,
    calc_loss_loader,
    generate,
    train_model_simple,
)
from dataset import create_dataloader


# Hyperparameters for the model built later in the notebook; the loaders below
# reuse `context_length` so every sample matches the model's context window.
GPT_CONFIG_124M = ModelArgs(
    # vocabulary and sequence sizes
    vocab_size=50257,
    context_length=1024,
    # transformer shape
    emb_dim=768,
    n_heads=12,
    n_layers=12,
    # regularisation and projection options
    drop_rate=0.1,
    qkv_bias=False,
    inter_dim=3072,
)

# Read the whole training corpus into memory as a single string.
file_path = "asserts/the-verdict.txt"
with open(file_path, mode="r", encoding="utf-8") as file:
    text_data = file.read()

# Character-level split: the first 70% of the text is used for training and
# the remainder for validation (the split is on characters, not tokens).
train_ratio = 0.70
split_idx = int(train_ratio * len(text_data))
train_data, val_data = text_data[:split_idx], text_data[split_idx:]
print("train_data length:", len(train_data))
print("val_data length:", len(val_data))

# Settings common to both loaders. `stride` equals `max_length` — presumably
# this yields non-overlapping windows; confirm against dataset.create_dataloader.
context_length = GPT_CONFIG_124M.context_length
shared_loader_args = dict(
    batch_size=2,
    max_length=context_length,
    stride=context_length,
    num_workers=0,
)

# Training loader: shuffled, incomplete final batch dropped.
train_loader = create_dataloader(
    train_data,
    drop_last=True,
    shuffle=True,
    **shared_loader_args,
)
# Validation loader: fixed order, every batch kept.
val_loader = create_dataloader(
    val_data,
    drop_last=False,
    shuffle=False,
    **shared_loader_args,
)

In [None]:
# Tokenizer plus a freshly constructed model. The global torch RNG is seeded
# right before construction so the run starts from a fixed RNG state.
tokenizer = tiktoken.get_encoding("gpt2")
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Baseline loss before any training. The model is switched to eval mode for
# the measurement so dropout (drop_rate=0.1 in the config) does not add noise
# to the reported numbers; the previous mode is restored afterwards so the
# training call below starts from the same train/eval state as before.
was_training = model.training
model.eval()
with torch.no_grad():
    train_loss = calc_loss_loader(train_loader, model, device)
    val_loss = calc_loss_loader(val_loader, model, device)
model.train(was_training)
print("Training loss:", train_loss)
print("Validation loss:", val_loss)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)
num_epochs = 15

# NOTE: the first returned value must NOT be bound to `train_loader` — doing so
# replaces the DataLoader with the training history, which breaks re-running
# this cell and any later use of the loader. It is presumably the per-eval
# training-loss list paired with `val_losses`; confirm against
# gpt_train.train_model_simple.
train_losses, val_losses, tokens_seen = train_model_simple(
    model,
    train_loader,
    val_loader,
    optimizer,
    device,
    num_epochs=num_epochs,
    eval_freq=5,
    eval_iter=5,
    start_context="Every effort moves you",
    tokenizer=tokenizer,
)

# Sample a short continuation from the trained model. Re-seeding makes the
# top-k / temperature sampling repeatable; eval mode disables dropout.
torch.manual_seed(123)
model.eval()
token_ids = generate(
    model=model,
    idx=text_to_token_ids("Every effort moves you", tokenizer, device=device),
    max_new_tokens=15,
    context_size=GPT_CONFIG_124M.context_length,
    top_k=25,
    temperature=1.4,
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))

# Checkpoint both the weights and the optimizer state so training can be
# resumed later rather than only used for inference.
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    "model_and_optimizer.pth",
)