## Imports

In [None]:
!pip install tiktoken
!pip install torch torchvision torchaudio

In [None]:
!wget https://cdn.jsdelivr.net/gh/karpathy/nanoGPT/model.py -O model.py


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import os
import numpy as np
import time
import tqdm
from model import GPTConfig, GPT
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

torch.manual_seed(1337)

In [None]:
import wandb, os
wandb.login()

## Dataset

In [None]:
!mkdir -p data
DATASET_FILE="data/tinystories_input.txt"
DATASET_URL="https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt"
!wget $DATASET_URL -O $DATASET_FILE

In [None]:
import os, requests, tiktoken, numpy as np
from tqdm import tqdm

# Download TinyStories (if missing), then tokenize it line by line with the GPT-2
# BPE and write the first 90% of lines to the train .bin and the rest to the val
# .bin as raw uint16 token ids.
os.makedirs("data", exist_ok=True)
input_file_path = "data/tinystories.txt"

if not os.path.exists(input_file_path):
    url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt"
    # Stream the raw bytes to a temporary file and rename only on success:
    #  - avoids holding the whole corpus in memory,
    #  - an HTTP error raises instead of being saved as "the dataset",
    #  - an interrupted download never leaves a file at input_file_path that
    #    would satisfy the existence check on the next run,
    #  - bytes are stored as served, with no dependence on the charset that
    #    requests would guess when decoding `.text`.
    partial_path = input_file_path + ".part"
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(partial_path, input_file_path)

# First pass: count lines so the 90/10 split point and the progress bar total
# are known before tokenizing.
with open(input_file_path, "r", encoding="utf-8") as f:
    num_lines = sum(1 for _ in f)
split = int(0.9 * num_lines)

enc = tiktoken.get_encoding("gpt2")

train_out = "data/tinystories_train.bin"
val_out = "data/tinystories_val.bin"

# Second pass: encode each line and append its token ids to the proper split.
# Context managers guarantee all three files are closed even if encoding fails.
with open(input_file_path, "r", encoding="utf-8") as f, \
        open(train_out, "wb") as train_f, \
        open(val_out, "wb") as val_f:
    for i, line in enumerate(tqdm(f, total=num_lines)):
        tokens = enc.encode(line, allowed_special={"<|endoftext|>"})
        arr = np.array(tokens, dtype=np.uint16)
        arr.tofile(train_f if i < split else val_f)

## AdamW Baseline

In [None]:
import os, time, torch, numpy as np, tqdm, wandb
from model import GPT, GPTConfig

# Hyperparameter grid: every (learning_rate, weight_decay) pair is one sweep run.
LR_VALUES = [3e-4, 1e-4, 1e-3, 3e-3, 1e-2]
WD_VALUES = [0, 1e-1, 1e-2, 1e-3, 1e-4]

# W&B sweep definition: exhaustive grid search minimizing the logged 'val_loss'.
sweep_config = dict(
    method='grid',
    metric=dict(name='val_loss', goal='minimize'),
    parameters=dict(
        learning_rate=dict(values=LR_VALUES),
        weight_decay=dict(values=WD_VALUES),
    ),
)

# Token streams written by the dataset cell, memory-mapped read-only as uint16.
train_data = np.memmap("data/tinystories_train.bin", mode="r", dtype=np.uint16)
val_data = np.memmap("data/tinystories_val.bin", mode="r", dtype=np.uint16)

# Run-wide constants shared by batching, evaluation and the training loop.
vocab_size, block_size = 50304, 256
max_iters, eval_interval = 10000, 100

# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Architecture used for every run in the sweep.
model_config = GPTConfig(
    n_layer=6,
    n_head=6,
    n_embd=384,
    block_size=block_size,
    vocab_size=vocab_size,
    dropout=0.1,
    bias=False,
)

def get_batch(split, batch_size=24):
    """Draw a random batch of input/target token windows from one data split.

    Args:
        split: 'train' reads from ``train_data``; any other value reads
            from ``val_data``.
        batch_size: number of windows to sample.

    Returns:
        Tuple ``(x, y)`` of int64 tensors, each ``(batch_size, block_size)``;
        ``y`` holds the same windows as ``x`` shifted one token to the right.
    """
    if split == 'train':
        source = train_data
    else:
        source = val_data

    # Largest start is len - block_size - 1, leaving room for the shifted target.
    starts = torch.randint(len(source) - block_size, (batch_size,))

    def window(begin):
        # Copy one block out of the memmap and widen the token ids to int64.
        chunk = source[begin:begin + block_size].astype(np.int64)
        return torch.from_numpy(chunk)

    inputs = torch.stack([window(start) for start in starts])
    targets = torch.stack([window(start + 1) for start in starts])
    return inputs, targets

@torch.no_grad()
def estimate_loss(model, eval_iters=100):
    """Estimate the mean loss on the train and val splits.

    Runs ``eval_iters`` random batches from ``get_batch`` per split with the
    model in eval mode and gradients disabled.

    Args:
        model: module called as ``model(x, y)`` and returning ``(_, loss)``.
        eval_iters: number of batches averaged per split; must be positive.

    Returns:
        ``{'train': mean_loss, 'val': mean_loss}`` as Python floats.

    Raises:
        ValueError: if ``eval_iters`` is not positive.
    """
    if eval_iters <= 0:
        raise ValueError(f"eval_iters must be positive, got {eval_iters}")

    # Remember the caller's mode so it can be restored afterwards, rather than
    # unconditionally forcing train mode.
    was_training = model.training
    model.eval()
    try:
        losses = {}
        for split in ['train', 'val']:
            loss_sum = 0.0
            for _ in range(eval_iters):
                xb, yb = get_batch(split)
                xb, yb = xb.to(device), yb.to(device)
                _, loss = model(xb, yb)
                loss_sum += loss.item()
            losses[split] = loss_sum / eval_iters
    finally:
        # Restore the original mode even if a batch raised.
        model.train(was_training)
    return losses

def train_one_run():
    """Train one GPT model with AdamW using the current sweep run's hyperparameters.

    Meant to be passed as ``function`` to ``wandb.agent``: ``wandb.init()``
    provides ``learning_rate`` and ``weight_decay`` through the run config.

    Side effects:
        * Every ``eval_interval`` iterations (and on the last one) evaluates
          with ``estimate_loss``, logs metrics to W&B, and overwrites
          ``best_model.pt`` in the run directory whenever val loss improves.
        * After ``max_iters`` iterations, saves a checkpoint (model, optimizer,
          iteration, config) and logs it as a ``model`` artifact.
    """
    with wandb.init() as run:
        config = run.config

        run.name = f"lr_{config.learning_rate:.0e}_wd_{config.weight_decay}"
        print(f"--- Starting run: {run.name} ---")

        model = GPT(model_config).to(device)
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        best_val_loss = float('inf')

        start_time = time.time()
        for iter_num in tqdm.tqdm(range(max_iters), desc=run.name):
            xb, yb = get_batch('train')
            xb, yb = xb.to(device), yb.to(device)

            _, loss = model(xb, yb)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            if iter_num % eval_interval == 0 or iter_num == max_iters - 1:
                losses = estimate_loss(model)
                elapsed = time.time() - start_time
                current_val_loss = losses["val"]

                if current_val_loss < best_val_loss:
                    best_val_loss = current_val_loss
                    checkpoint_path = os.path.join(run.dir, "best_model.pt")
                    torch.save(model.state_dict(), checkpoint_path)

                run.log({
                    "iter": iter_num,
                    # Loss of the most recent training batch (single, noisy sample).
                    "train_loss": loss.item(),
                    # Averaged train estimate that estimate_loss already computed;
                    # directly comparable to val_loss.
                    "train_loss_eval": losses["train"],
                    "val_loss": current_val_loss,
                    "best_val_loss": best_val_loss,
                    "elapsed_time": elapsed
                })

        # Include the run id in the local filename: every sweep run ends on the
        # same iter_num, so a shared name in the working directory would be
        # overwritten by the next (or a concurrent) run.
        ckpt_name = f"checkpoint_iter_{iter_num}.pt"
        ckpt_path = f"checkpoint_{run.id}_iter_{iter_num}.pt"
        torch.save({
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "iter": iter_num,
            "config": dict(run.config),
        }, ckpt_path)

        artifact_name = f"model-{run.name}-{run.id}"
        artifact = wandb.Artifact(
            name=artifact_name,
            type="model"
        )
        # Keep the original entry name inside the artifact so consumers that
        # download it still find "checkpoint_iter_<n>.pt".
        artifact.add_file(ckpt_path, name=ckpt_name)
        run.log_artifact(artifact)

# Register the grid sweep with W&B; the returned id is what the agent cell runs.
sweep_target = {"entity": "182-proj", "project": "nano-gpt-adamw-200k"}
sweep_id = wandb.sweep(sweep_config, **sweep_target)
print(sweep_id)

In [None]:
# Run the sweep in this process: the agent calls train_one_run once per
# configuration it receives for this sweep.
sweep_path = f"182-proj/nano-gpt-adamw-200k/{sweep_id}"
wandb.agent(sweep_path, function=train_one_run)
print("Sweep complete.")