In [None]:

%run -n 00_config.ipynb
%run -n 01_data_pipeline.ipynb
%run -n 02_gpt2_model.ipynb



In [None]:
import torch
import tiktoken
import itertools
from datetime import datetime


def estimate_loss(model, loader, device, eval_iters):
    """Average the model's loss over at most ``eval_iters`` batches of ``loader``.

    The model is switched to eval mode while the batches are processed and
    its previous train/eval mode is restored afterwards, even if a batch
    raises. Gradient tracking is not disabled here; wrap the call in
    ``torch.no_grad()`` if that is wanted.

    Args:
      model: module called as ``model(X, Y)`` and returning ``(logits, loss)``
      loader: iterable yielding ``(X, Y)`` batches
      device: device the batches are moved to before the forward pass
      eval_iters: maximum number of batches to pull from ``loader``

    Returns:
      0-dim float tensor with the mean loss over the batches actually
      processed (NaN when no batch was available).
    """
    was_training = model.training
    model.eval()
    batch_losses = []
    try:
        for X, Y in itertools.islice(loader, eval_iters):
            X, Y = X.to(device), Y.to(device)
            _, loss = model(X, Y)
            batch_losses.append(loss.item())
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)
    # Average only over batches that were really seen: a loader that runs dry
    # before eval_iters must not pull the mean towards zero.
    return torch.tensor(batch_losses).mean()

@torch.no_grad()
def evaluate_model(model, train_loader, val_loader, device, eval_iters):
    """Estimate the train and validation loss with gradient tracking disabled.

    Args:
      model: to evaluate
      train_loader: training dataset iterator
      val_loader: validation dataset iterator
      device: device the batches are moved to before the forward pass
      eval_iters: the number of iterations to pull from each of the loaders

    Returns:
      dict with 'train' and 'val' loss
    """
    # The train split is evaluated first, then the validation split.
    return {
        'train': estimate_loss(model, train_loader, device, eval_iters),
        'val': estimate_loss(model, val_loader, device, eval_iters),
    }

def train_model(model, train_loader, val_loader, optimizer, cfg):
    """Run the training loop for ``cfg.num_epochs`` passes over ``train_loader``.

    Every ``cfg.log_interval`` steps the mean training loss of that window is
    printed; every ``cfg.eval_interval`` steps ``evaluate_model`` is run and
    its train/val losses are printed.

    Args:
      model: module called as ``model(X, Y)`` and returning ``(logits, loss)``
      train_loader: iterable yielding ``(X, Y)`` training batches
      val_loader: iterable yielding ``(X, Y)`` validation batches
      optimizer: optimizer that updates the model's parameters
      cfg: config object providing ``device``, ``num_epochs``,
        ``log_interval``, ``eval_interval`` and ``eval_iters``
    """
    device = torch.device(cfg.device) # Ensure device object
    log_interval = cfg.log_interval

    if device.type == 'cuda':
        print("Starting training on CUDA device. Initializing memory stats.")
        # Reset peak stats at the beginning of training if you want to track peaks per training run
        torch.cuda.reset_peak_memory_stats(device)
        # print_gpu_memory_stats is not defined in this notebook; presumably it
        # comes from one of the notebooks loaded via %run.
        print_gpu_memory_stats("Start of training_model", device)

    for epoch in range(cfg.num_epochs):
        model.train()
        running_loss = 0.0
        print(f"[{epoch + 1} / {cfg.num_epochs}]: starting at {datetime.now()}, will log every {log_interval} steps")
        if device.type == 'cuda':
            print_gpu_memory_stats(f"Start of Epoch {epoch + 1}", device)

        for i, (X, Y) in enumerate(train_loader):
            X, Y = X.to(device), Y.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            _, loss = model(X, Y)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if (i + 1) % log_interval == 0:
                print(f"[{epoch + 1}  {i + 1:5d}]: running loss {running_loss / log_interval:.3f}")
                running_loss = 0.0

            if (i + 1) % cfg.eval_interval == 0:
                # NOTE(review): train_loader is also being iterated by the
                # enclosing loop; if it is a one-shot iterator this evaluation
                # consumes training batches - confirm against the loader.
                losses = evaluate_model(model, train_loader, val_loader, device, eval_iters=cfg.eval_iters)
                print(f"[{epoch + 1}  {i + 1:5d}]: train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f}, eval_iters: {cfg.eval_iters}")
                # running_loss is deliberately not reset here: the logging
                # window always spans log_interval steps, so clearing it
                # mid-window would make the next printed average too low.


In [None]:
cfg = GPT2Config().from_yaml("gpt2_config.yaml")
enc = tiktoken.get_encoding('gpt2')

# NOTE(review): both loaders are built from identical arguments - confirm that
# GeneratorWrapper really yields distinct train and validation data.
train_loader = GeneratorWrapper(cfg, enc)
val_loader = GeneratorWrapper(cfg, enc)

# Drop any model/optimizer left over from an earlier run of this cell.
# Only NameError (nothing to delete on a fresh kernel) is expected, so nothing
# else - KeyboardInterrupt included - is swallowed.
try:
    del model
except NameError:
    pass

try:
    del optimizer
except NameError:
    pass

model = GPTModel(cfg)
model.to(cfg.device)
if cfg.compile_model:
    model = torch.compile(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
train_model(model, train_loader, val_loader, optimizer, cfg=cfg)

print('Finished Training')