# GPT Training Notebook
**Purpose:** training loop and orchestration notebook using a custom `GPTModel`, `tokenizer`, `dataset`, and `config`.  
Features: streaming data support, FP16/AMP, gradient accumulation, checkpointing, TensorBoard logging, evaluation metrics (loss, perplexity, accuracy, tokens/sec, GPU usage), and generation utilities.


In [14]:
import os
import math
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader,IterableDataset
from torch.nn import functional as F
from pathlib import Path 
from torch.optim import AdamW
from tqdm.auto import tqdm

# Select the compute device: use the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device: " + device)


Using device: cpu


In [15]:
# Pull in the training hyperparameters and echo them so the run is self-documenting.
from config import train_config

print("training config:", train_config)

training config: {'vocab_size': None, 'context_len': 1024, 'emb_dim': 768, 'n_heads': 12, 'n_layers': 12, 'dropout_rate': 0.1, 'qkv_bias': False, 'batch_size': 2, 'grad_accum_steps': 8, 'learning_rate': 0.0003, 'weight_decay': 0.1, 'max_iters': 200000, 'eval_interval': 2000, 'save_interval': 5000, 'lr_warmup_iters': 2000, 'max_grad_norm': 1.0, 'use_fp16': True, 'context_stride': 1, 'num_epochs': 3}


In [16]:
# Configuration: corpus directory, tokenizer directory, and checkpoint directory.
DATA_PATH = "./tig_dataset"
TOKENIZER_PATH = "./tokenizers/Tig_unigram_16000"
SAVE_PATH = "./saved_models/gpt"

# Create the checkpoint directory (and any missing parents) up front.
Path(SAVE_PATH).mkdir(parents=True, exist_ok=True)


In [17]:
# ================================================
# 2. Load Tokenizer
# ================================================
from transformers import GPT2TokenizerFast, PreTrainedTokenizerFast

# Custom-tokenizer alternative, currently disabled:
# tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_PATH)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# No dedicated pad token is configured here, so EOS doubles as padding.
tokenizer.pad_token = tokenizer.eos_token

print("Tokenizer loaded ✅")


Tokenizer loaded ✅


In [18]:
# ================================================
# 4. Load Your GPT Model
# ================================================
from gpt_model import GPTModel  # project-local model class
from config import GPT_CONFIG  # project-local model hyperparameter dict

# Size the embedding table from len(tokenizer) rather than tokenizer.vocab_size:
# vocab_size reports only the base vocabulary, so any added/special tokens
# would receive ids past the end of the embedding table and fail at lookup.
# For the stock GPT-2 tokenizer used above the two values are identical.
GPT_CONFIG["vocab_size"] = len(tokenizer)

model = GPTModel(GPT_CONFIG).to(device)
print("GPT model loaded ✅")


GPT model loaded ✅


In [19]:
# Sanity check: print the tokenizer's vocabulary size next to the size of the
# model's token-embedding table so a mismatch is visible at a glance.
print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
print(f"Model embedding size: {model.tok_emb.num_embeddings}")


Tokenizer vocab size: 50257
Model embedding size: 50257


In [20]:
# Collect every matching corpus file. rglob yields paths in filesystem order,
# so sort them to make the concatenated corpus identical from run to run.
corpus_files = sorted(Path(DATA_PATH).rglob("*eng.txt"))
if not corpus_files:
    # Fail here with a clear message instead of building an empty dataset
    # that only breaks later, inside the dataset/dataloader.
    raise FileNotFoundError(f"No '*eng.txt' files found under {DATA_PATH}")

texts = [file_path.read_text(encoding="utf-8") for file_path in corpus_files]
full_text = "\n".join(texts)
print(f"✅ Loaded {len(texts)} text files from {DATA_PATH}")

# Tokenize entire corpus once; squeeze(0) drops the batch dimension added by
# return_tensors="pt" so token_ids is a flat 1-D tensor.
token_ids = tokenizer(full_text, return_tensors="pt")["input_ids"].squeeze(0)
print(f"✅ Tokenized full dataset: {token_ids.shape[0]} tokens")


✅ Loaded 1 text files from ./tig_dataset


Token indices sequence length is longer than the specified maximum sequence length for this model (1204933 > 1024). Running this sequence through the model will result in indexing errors


✅ Tokenized full dataset: 1204933 tokens


In [21]:
class TokenDataset(Dataset):
    """Non-overlapping next-token-prediction windows over a token sequence.

    Item ``idx`` is the pair ``(x, y)`` where ``x`` is the ``idx``-th chunk of
    ``context_len`` tokens and ``y`` is the same chunk shifted one position to
    the right, so ``y[i]`` is the token that follows ``x[i]``.

    Args:
        tokens: Sliceable sequence of token ids (e.g. a 1-D tensor or a list).
        context_len: Number of tokens in each window.
    """

    def __init__(self, tokens, context_len=1024):
        self.tokens = tokens
        self.context_len = context_len

    def __len__(self):
        # One token is reserved for the final shifted target. Clamp at zero:
        # for an empty sequence the floor division gives -1, and len() raises
        # on a negative value.
        return max(0, (len(self.tokens) - 1) // self.context_len)

    def __getitem__(self, idx):
        # Raising IndexError lets plain iteration (`for x, y in dataset`)
        # terminate; without it out-of-range indices yield empty slices forever.
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for {len(self)} sequences")
        start = idx * self.context_len
        end = start + self.context_len
        x = self.tokens[start:end]
        y = self.tokens[start + 1:end + 1]
        return x, y

# Wrap the token stream in fixed-length windows and batch them in shuffled order.
dataset = TokenDataset(
    token_ids,
    context_len=train_config["context_len"],
)
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=train_config["batch_size"],
)
print("✅ Dataset ready with {} sequences".format(len(dataset)))


✅ Dataset ready with 1176 sequences


In [22]:
# Pass the configured weight decay explicitly; otherwise AdamW falls back to
# its own default (0.01) and the value in train_config is silently ignored.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_config["learning_rate"],
    weight_decay=train_config.get("weight_decay", 0.01),
)


In [23]:
# Epoch count read from the shared training config; used by later cells that
# reference `num_epochs` directly.
num_epochs = train_config["num_epochs"]

In [10]:

# -------------------------------
# Dataset class for sequential token chunks
# -------------------------------
class TokenDataset(Dataset):
    """Sliding-window next-token dataset: one sample per starting position.

    Item ``idx`` is ``(input_ids, target_ids)`` where ``input_ids`` is
    ``token_ids[idx : idx + context]`` and ``target_ids`` is the same window
    shifted one position to the right. Both are returned as new ``torch.long``
    tensors.

    NOTE(review): this redefines ``TokenDataset`` with a different constructor
    keyword (``context``) than the chunked ``TokenDataset(tokens, context_len)``
    defined earlier in the notebook, and shadows it once this cell runs.

    Args:
        token_ids: Sliceable sequence of token ids (a list or a 1-D tensor).
        context: Number of tokens in each window.
    """

    def __init__(self, token_ids, context):
        self.token_ids = token_ids
        self.context = context

    def __len__(self):
        return max(0, len(self.token_ids) - self.context)

    @staticmethod
    def _as_long(window):
        """Copy ``window`` into a fresh ``torch.long`` tensor."""
        if torch.is_tensor(window):
            # torch.tensor(existing_tensor) emits a copy-construct UserWarning
            # on every sample; detach().clone() is the supported equivalent.
            return window.detach().clone().to(torch.long)
        return torch.tensor(window, dtype=torch.long)

    def __getitem__(self, idx):
        input_ids = self.token_ids[idx: idx + self.context]
        target_ids = self.token_ids[idx + 1: idx + 1 + self.context]
        return self._as_long(input_ids), self._as_long(target_ids)

# -------------------------------
# Perplexity helper
# -------------------------------
def compute_perplexity(loss):
    """Return ``exp(loss)``, the perplexity corresponding to ``loss``.

    Args:
        loss: A number (anything ``math.exp`` accepts).

    Returns:
        ``math.exp(loss)``, or ``float("inf")`` when the result is too large
        for a float, so that logging a very high loss does not abort the run
        with ``OverflowError``.
    """
    try:
        return math.exp(loss)
    except OverflowError:
        return float("inf")

In [24]:
def train():
    """Train the global ``model`` on ``dataloader`` with gradient accumulation.

    Reads ``model``, ``dataloader``, ``optimizer``, ``device``, ``SAVE_PATH``
    and ``train_config`` from the notebook globals. For each epoch it
    accumulates gradients over ``train_config["grad_accum_steps"]``
    micro-batches, clips them, steps the optimizer, shows a running loss and
    perplexity on the progress bar, and writes the model ``state_dict`` to
    ``{SAVE_PATH}/epoch_{n}.pth`` when the epoch ends.
    """
    # The config stores this value under "grad_accum_steps"; the key
    # "gradient_accumulation_steps" does not exist there and raised KeyError.
    accum_steps = train_config["grad_accum_steps"]
    max_grad_norm = train_config.get("max_grad_norm", 1.0)
    num_epochs = train_config["num_epochs"]
    log_every = 50

    def apply_update():
        # Clip, step, and clear gradients for one accumulated batch.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()

    model.train()
    # Start from clean gradients in case a previous run was interrupted mid-accumulation.
    optimizer.zero_grad()
    for epoch in range(num_epochs):
        running_loss = 0.0
        pending = 0  # micro-batches accumulated since the last optimizer step
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for step, (inp, tgt) in enumerate(pbar):
            inp, tgt = inp.to(device), tgt.to(device)

            # NOTE(review): assumes GPTModel.forward(inp, targets=...) returns
            # (logits, loss) — confirm against gpt_model.py.
            _, loss = model(inp, targets=tgt)
            # Scale so the accumulated gradient is the mean over micro-batches.
            (loss / accum_steps).backward()
            pending += 1

            if pending == accum_steps:
                apply_update()
                pending = 0

            running_loss += loss.item()
            # Log after every `log_every` steps so each window averages exactly
            # that many losses (the old `step % 50 == 0` check put 51 losses
            # in the first window while dividing by 50).
            if (step + 1) % log_every == 0:
                avg_loss = running_loss / log_every
                pbar.set_postfix(loss=f"{avg_loss:.4f}", ppl=f"{compute_perplexity(avg_loss):.2f}")
                running_loss = 0.0

        # Apply any leftover gradients so they are neither dropped nor mixed
        # into the first update of the next epoch.
        if pending:
            apply_update()

        # Save checkpoint after each epoch
        ckpt_path = f"{SAVE_PATH}/epoch_{epoch+1}.pth"
        torch.save(model.state_dict(), ckpt_path)
        print(f"✅ Saved checkpoint: {ckpt_path}")



# -------------------------------
# TRAINING CALL
# -------------------------------
# Runs the full training loop defined above; train() takes no arguments and
# reads its inputs from the notebook globals.

print("🚀 Starting training...")
train()
print("✅ Training complete!")




🚀 Starting training...


Epoch 1/3:   0%|          | 0/588 [00:00<?, ?it/s]

: 

In [None]:
# -------------------------------
# TEXT GENERATION FUNCTION
# -------------------------------
def generate_text(prompt, max_tokens=100):
    """Generate a continuation of ``prompt`` with the global ``model``.

    Args:
        prompt: Text to condition on; encoded with the global ``tokenizer``.
        max_tokens: Passed to ``model.generate`` as ``max_new_tokens``.

    Returns:
        The first generated sequence decoded to a string, with special tokens
        removed.

    The model is put in eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so calling this between
    training steps does not leave the model stuck in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
        # No gradients are needed for generation; skip building the autograd graph.
        with torch.no_grad():
            generated = model.generate(input_ids, max_new_tokens=max_tokens)
    finally:
        model.train(was_training)
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Example generation: continue a fixed prompt by up to 100 tokens and print
# the decoded result.
prompt = "Once upon a time"
print("📝 Generated text:")
print(generate_text(prompt, max_tokens=100))

In [None]:
# train() takes no arguments: it reads the model, data loader, optimizer and
# hyperparameters from the notebook globals. The previous call passed
# arguments for a signature that is not defined anywhere in this notebook and
# raised TypeError.
# NOTE: running this cell continues training from the model's current weights.
train()

In [None]:
# Put the model in eval mode and encode the prompt as a single-row batch.
model.eval()
prompt = "Once upon a time"
prompt_ids = torch.tensor([tokenizer.encode(prompt)]).to(device)

# Generate 100 additional tokens and decode the first sequence back into text.
generated_ids = model.generate(prompt_ids, max_new_tokens=100)
generated_text = tokenizer.decode(generated_ids[0].tolist())
print("✅ Generated text:", generated_text, sep="\n")


In [None]:
# Restore weights from a checkpoint written by train(), which saves to
# f"{SAVE_PATH}/epoch_{n}.pth". The old "path/to/epoch_3.pth" was a placeholder
# that never existed on disk.
ckpt_path = f"{SAVE_PATH}/epoch_3.pth"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.to(device)

In [None]:

# Standalone training run with its own optimizer and accumulation settings
# (independent of train_config); checkpoints go to ./checkpoints.
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
num_epochs = 3
gradient_accumulation_steps = 4
log_interval = 50
os.makedirs("checkpoints", exist_ok=True)

# compute_perplexity is defined once, earlier in the notebook; the duplicate
# definition that used to live here has been dropped so there is a single
# source of truth.

model.train()
# Start from clean gradients in case a previous run stopped mid-accumulation.
optimizer.zero_grad()
for epoch in range(num_epochs):
    running_loss = 0.0
    pending_micro_batches = 0
    # `train_loader` is never defined in this notebook; the DataLoader built
    # above is named `dataloader`.
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for step, (inp, tgt) in enumerate(pbar):
        inp, tgt = inp.to(device), tgt.to(device)
        # NOTE(review): assumes model(inp) without targets returns logits only
        # — confirm against GPTModel.forward.
        logits = model(inp)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), tgt.view(-1), ignore_index=-100)
        # Scale so the accumulated gradient is the mean over micro-batches.
        (loss / gradient_accumulation_steps).backward()
        pending_micro_batches += 1

        if pending_micro_batches == gradient_accumulation_steps:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            pending_micro_batches = 0

        running_loss += loss.item()
        # Log after every `log_interval` steps so each window averages exactly
        # that many losses (the old check put 51 losses in the first window).
        if (step + 1) % log_interval == 0:
            avg_loss = running_loss / log_interval
            pbar.set_postfix(loss=f"{avg_loss:.3f}", ppl=f"{compute_perplexity(avg_loss):.2f}")
            running_loss = 0.0

    # Apply leftover gradients so they are not carried into the next epoch.
    if pending_micro_batches:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

    # Save checkpoint after each epoch
    ckpt_path = f"checkpoints/epoch_{epoch+1}.pt"
    torch.save(model.state_dict(), ckpt_path)
    print(f"✅ Saved checkpoint: {ckpt_path}")


Epoch 1/3: 0it [00:00, ?it/s]

In [None]:
# # ================================================
# # 6. Training Loop
# # ================================================
# model.train()
# for epoch in range(num_epochs):
#     running_loss = 0.0
#     pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

#     for step, (inp, tgt) in enumerate(pbar):
#         inp, tgt = inp.to(device), tgt.to(device)
#         logits = model(inp)
#         loss = F.cross_entropy(logits.view(-1, logits.size(-1)), tgt.view(-1), ignore_index=-100)
        
#         optimizer.zero_grad()
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#         optimizer.step()

#         running_loss += loss.item()
#         if step % log_interval == 0 and step > 0:
#             avg_loss = running_loss / log_interval
#             perplexity = compute_perplexity(avg_loss)
#             pbar.set_postfix(loss=f"{avg_loss:.3f}", ppl=f"{perplexity:.2f}")
#             running_loss = 0.0

#     # Save checkpoint
#     ckpt_path = f"checkpoints/epoch_{epoch+1}.pt"
#     torch.save(model.state_dict(), ckpt_path)
#     print(f"✅ Saved checkpoint: {ckpt_path}")


Epoch 1/3:   0%|          | 0/3 [00:00<?, ?it/s]

: 

In [None]:
# ================================================
# 7. Evaluation (Perplexity)
# ================================================
@torch.no_grad()
def evaluate(model, data_loader):
    """Return the mean per-token cross-entropy of ``model`` over ``data_loader``.

    Args:
        model: Module called as ``model(inp)``.
            NOTE(review): assumed to return logits whose last dimension is the
            vocabulary — confirm against GPTModel.forward.
        data_loader: Iterable of ``(inp, tgt)`` batches; target positions equal
            to ``-100`` are ignored.

    Returns:
        Summed cross-entropy divided by the number of non-ignored target
        tokens, or ``float("nan")`` when there are no such tokens (e.g. an
        empty loader) instead of raising ``ZeroDivisionError``.

    The model's previous train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    total_loss, total_tokens = 0.0, 0
    try:
        for inp, tgt in tqdm(data_loader, desc="Evaluating"):
            inp, tgt = inp.to(device), tgt.to(device)
            logits = model(inp)
            # reduction="sum" so batches of different sizes are weighted by
            # their token count when averaged at the end.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                tgt.view(-1),
                ignore_index=-100,
                reduction="sum"
            )
            total_loss += loss.item()
            total_tokens += (tgt != -100).sum().item()
    finally:
        model.train(was_training)
    if total_tokens == 0:
        return float("nan")
    return total_loss / total_tokens

# `train_loader` is not defined anywhere in this notebook; evaluate on the
# DataLoader that is (`dataloader`). The notebook builds no held-out split, so
# this number is training-set perplexity and is labelled as such.
val_loss = evaluate(model, dataloader)
print(f"✅ Training-set Perplexity: {math.exp(val_loss):.2f}")
