In [1]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn 
import torch.optim as optim
from torch.utils.data import DataLoader
import urllib.request

from gpt.char import CharDataset
from gpt.config import Config
from gpt.model import GPTModel

# 1. Data Preparation

In [2]:
# Fetch Tiny Shakespeare once and cache it next to the notebook, so that
# Restart & Run All does not hit the network every time.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
data_path = "tinyshakespeare.txt"

try:
    # newline="" keeps the text byte-for-byte identical to the downloaded copy
    with open(data_path, encoding="utf-8", newline="") as f:
        data = f.read()
except FileNotFoundError:
    with urllib.request.urlopen(url, timeout=30) as response:
        data = response.read().decode("utf-8")
    with open(data_path, "w", encoding="utf-8", newline="") as f:
        f.write(data)

cfg = Config()
cd = CharDataset(cfg, data)
# The vocabulary is only known after scanning the corpus, so patch it into the config.
cfg.vocab_size = cd.get_vocab_size()
cfg

Config(block_size=128, batch_size=128, d_emb=768, n_heads=8, n_layers=12, drop_rate=0.1, d_mlp=4, qkv_bias=True, vocab_size=65)

# 2. Training

In [3]:
# Seed before anything stochastic (weight init, DataLoader shuffling) so runs are reproducible.
SEED = 42
LEARNING_RATE = 3e-4
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader = DataLoader(cd, batch_size=cfg.batch_size, shuffle=True)
# Place the model on its device *before* building the optimizer, as the torch.optim docs recommend.
model = GPTModel(cfg).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Enable dropout for training; the inference section calls model.eval().
model.train()

n_epochs = 1
log_every = 10  # batches between progress prints

for epoch in range(n_epochs):
    for batch_idx, (x, y) in enumerate(train_loader):
        # move data to device
        x, y = x.to(device), y.to(device)
        # forward pass
        logits = model(x)  # (batch_size, block_size, vocab_size), per the shape check cell
        # CrossEntropyLoss wants (N, C) logits and (N,) class-index targets; passing
        # (B, T, V) directly would be read as C=T and fail. Flatten batch and time.
        # Assumes y is (batch_size, block_size) of token ids -- TODO confirm in CharDataset.
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        # backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # print progress
        if batch_idx % log_every == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

In [None]:
# Save the trained weights as a state_dict (the portable format recommended by PyTorch)
# to "gpt_model.pth", which is the file the inference cell loads with load_state_dict.
torch.save(model.state_dict(), "gpt_model.pth")
# Also keep the full pickled model, as before.
torch.save(model, "gpt_model_full.pth")

# 3. Inference

In [59]:
# Helpers to convert between text and the CharDataset's integer token ids.
def tokenize(seq, dataset=None):
    """Encode a string into a 1-D LongTensor of character ids.

    Args:
        seq: text to encode; every character must be in the vocabulary.
        dataset: object exposing a ``stoi`` char -> id mapping. Defaults to the
            notebook's ``cd`` CharDataset.

    Returns:
        torch.LongTensor of shape (len(seq),).

    Raises:
        KeyError: if ``seq`` contains characters outside the vocabulary.
    """
    stoi = (cd if dataset is None else dataset).stoi
    unknown = [ch for ch in dict.fromkeys(seq) if ch not in stoi]
    if unknown:
        raise KeyError(f"characters not in vocabulary: {unknown!r}")
    # Explicit dtype: an empty list would otherwise become a float tensor.
    return torch.tensor([stoi[ch] for ch in seq], dtype=torch.long)


def decode(tokens, dataset=None):
    """Decode a sequence of token ids back into a string.

    Args:
        tokens: iterable of ids -- e.g. a 1-D tensor or a list of Python ints.
        dataset: object exposing an ``itos`` id -> char lookup. Defaults to the
            notebook's ``cd`` CharDataset.

    Returns:
        The decoded string.
    """
    itos = (cd if dataset is None else dataset).itos
    # int() accepts both plain ints and single-element tensors.
    return "".join(itos[int(tok)] for tok in tokens)

In [None]:
# Rebuild the architecture from `cfg` and load the trained weights saved as a state_dict.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTModel(cfg)
# map_location lets weights saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load("gpt_model.pth", map_location=device))
model = model.to(device)
# eval() switches off dropout; it does NOT disable autograd -- no_grad() below does that.
model.eval()

prompt = "ROMEO:"
prompt_tokens = tokenize(prompt).unsqueeze(0).to(device)  # (1, len(prompt))
with torch.no_grad():
    # NOTE(review): GPTModel.generate is defined in gpt.model and is not visible here.
    # This assumes a nanoGPT-style generate(idx, max_new_tokens, temperature, top_k)
    # that returns (1, len(prompt) + max_new_tokens) token ids -- TODO confirm.
    # Tune temperature / top_k to trade off diversity vs. coherence.
    generated = model.generate(prompt_tokens, max_new_tokens=500, temperature=1.0, top_k=10)
print(decode(generated[0].cpu()))

In [4]:
# Sanity-check the forward pass on a single training sample: expect logits of
# shape (1, block_size, vocab_size).
x_sample = cd[1][0]
# The model may live on GPU (after training/inference cells), so follow its device.
model_device = next(model.parameters()).device
with torch.no_grad():
    logits_sample = model(x_sample.view(1, -1).to(model_device))
assert logits_sample.shape == (1, cfg.block_size, cfg.vocab_size), logits_sample.shape
logits_sample.shape

torch.Size([1, 128, 65])