# GPT Training Notebook
**Purpose:** training loop and orchestration notebook using a custom `GPTModel`, `tokenizer`, `dataset`, and `config`.  
Features: streaming data support, FP16/AMP, gradient accumulation, checkpointing, TensorBoard logging, evaluation metrics (loss, perplexity, accuracy, tokens/sec, GPU usage), and generation utilities.


In [1]:
import os
import math
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader,IterableDataset
from torch.nn import functional as F
from pathlib import Path 
from torch.optim import AdamW
from tqdm.auto import tqdm

# Pick the compute device: use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")


Using device: cpu


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the training hyper-parameter dict from the project-local `config` module and echo it
# so the values in effect are recorded in the notebook output.
# NOTE(review): the training cell further down hard-codes its own learning rate, weight decay
# and accumulation steps instead of reading them from this dict - confirm which should win.
from config import train_config
print(f"training config: {train_config}")

training config: {'vocab_size': None, 'context_len': 1024, 'emb_dim': 768, 'n_heads': 12, 'n_layers': 12, 'dropout_rate': 0.1, 'qkv_bias': False, 'batch_size': 4, 'grad_accum_steps': 8, 'learning_rate': 0.0003, 'weight_decay': 0.1, 'max_iters': 200000, 'eval_interval': 2000, 'save_interval': 5000, 'lr_warmup_iters': 2000, 'max_grad_norm': 1.0, 'use_fp16': True, 'context_stride': 1}


In [3]:
# Filesystem locations used throughout the notebook.
DATA_PATH = "./tig_dataset"                         # folder scanned for *.txt training files
TOKENIZER_PATH = "./tokenizers/Tig_unigram_16000"   # only referenced by a commented-out loader
SAVE_PATH = "./saved_models/gpt"                    # output folder created below

# Make sure the output folder (and any missing parents) exists; a no-op if it is already there.
Path(SAVE_PATH).mkdir(parents=True, exist_ok=True)


In [4]:
# ================================================
# 2. Load Tokenizer
# ================================================
# The pretrained "gpt2" tokenizer is what is actually used here; the loader for the tokenizer at
# TOKENIZER_PATH is commented out below, so PreTrainedTokenizerFast is currently imported but unused.
from transformers import PreTrainedTokenizerFast, GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# No dedicated pad token is set, so EOS doubles as padding for *inputs*. The collate function
# pads *targets* with -100 instead, which the loss ignores, so padded positions do not train.
# NOTE(review): the model's vocabulary size must cover every id this tokenizer can emit -
# confirm against GPT_CONFIG, which is not visible from this cell.
tokenizer.pad_token = tokenizer.eos_token  # GPT usually has no pad token, use EOS

# tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_PATH)
print("Tokenizer loaded ✅")


Tokenizer loaded ✅


In [None]:
# ================================================
# 3. Prepare Dataset
# ================================================
class TextFolderDataset(Dataset):
    """Map-style dataset that yields one training example per ``.txt`` file.

    Every ``.txt`` file found recursively under ``folder_path`` is one item. On access the
    whole file is read, tokenized and truncated to ``block_size`` tokens, so text beyond
    the first ``block_size`` tokens of a file is never used.

    Args:
        folder_path: Directory searched recursively for ``*.txt`` files.
        tokenizer: Object whose ``encode(text, truncation=..., max_length=...,
            return_tensors="pt")`` returns a tensor with a leading batch dimension of 1.
        block_size: Maximum number of tokens kept per file.
    """

    def __init__(self, folder_path, tokenizer, block_size=1024):
        # rglob() order depends on the filesystem; sorting makes the index -> file mapping
        # (and therefore seeded shuffling) reproducible across runs and machines.
        self.file_paths = sorted(Path(folder_path).rglob("*.txt"))
        self.tokenizer = tokenizer
        self.block_size = block_size
        print(f"Found {len(self.file_paths)} text files.")

    def __len__(self):
        """Return the number of text files (one example each)."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)`` for file ``idx``.

        The targets are the inputs shifted left by one token, so both tensors have one
        element fewer than the (truncated) token sequence.
        """
        text = self.file_paths[idx].read_text(encoding="utf-8")
        tokens = self.tokenizer.encode(
            text, truncation=True, max_length=self.block_size, return_tensors="pt"
        )
        tokens = tokens.squeeze(0)  # drop the batch dimension added by return_tensors="pt"
        return tokens[:-1], tokens[1:]

def collate_batch(batch, pad_token_id=None):
    """Pad variable-length ``(input_ids, target_ids)`` pairs into rectangular batch tensors.

    Args:
        batch: Sequence of ``(input_ids, target_ids)`` pairs of 1-D tensors.
        pad_token_id: Id used to fill the input padding. When omitted, the pad id of the
            notebook-level ``tokenizer`` is used, which keeps ``collate_fn=collate_batch``
            working unchanged.

    Returns:
        ``(input_ids, target_ids)``, each with one row per example and as many columns as
        the longest sequence in the batch. Targets are padded with ``-100`` so that a loss
        using ``ignore_index=-100`` skips the padded positions.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    input_ids, target_ids = zip(*batch)
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=pad_token_id
    )
    target_ids = torch.nn.utils.rnn.pad_sequence(
        target_ids, batch_first=True, padding_value=-100
    )
    return input_ids, target_ids

# train_data = TextFolderDataset(DATA_PATH, tokenizer)
# train_loader = DataLoader(train_data, batch_size=4, shuffle=True, collate_fn=collate_batch)


Found 11 text files.


In [8]:
class StreamingTextDataset(IterableDataset):
    """Stream ``(input_ids, target_ids)`` training pairs from a folder of ``.txt`` files.

    Files are read ``chunk_size`` characters at a time; each chunk is tokenized on its own
    and cut into consecutive blocks of at most ``block_size`` tokens. Blocks never span two
    chunks, so the last block of a chunk may be shorter than ``block_size``.

    When iterated inside ``DataLoader`` worker processes, the file list is split across the
    workers. Without that split every worker would stream the complete dataset and each
    example would be produced ``num_workers`` times per epoch.

    Args:
        folder_path: Directory searched recursively for ``*.txt`` files.
        tokenizer: Object whose ``encode(text, truncation=False, add_special_tokens=False)``
            returns a list of token ids.
        block_size: Maximum number of tokens per block.
        chunk_size: Number of characters read from a file per chunk.
    """

    def __init__(self, folder_path, tokenizer, block_size=1024, chunk_size=32_768):
        # Sorted so that every worker sees the same order and the shards are stable.
        self.file_paths = sorted(Path(folder_path).rglob("*.txt"))
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.chunk_size = chunk_size

    def _read_file_in_chunks(self, file_path):
        """Yield successive text chunks of up to ``chunk_size`` characters from a file."""
        with open(file_path, "r", encoding="utf-8") as f:
            # read() returns "" at end of file, which stops the iterator.
            for chunk in iter(lambda: f.read(self.chunk_size), ""):
                yield chunk

    def _tokenize_chunk(self, chunk):
        """Tokenize a chunk and split the ids into lists of at most ``block_size`` tokens."""
        tokens = self.tokenizer.encode(chunk, truncation=False, add_special_tokens=False)
        return [
            tokens[start:start + self.block_size]
            for start in range(0, len(tokens), self.block_size)
        ]

    def _files_for_worker(self, worker_info):
        """Return the files this process should read.

        ``worker_info`` is the value of ``torch.utils.data.get_worker_info()``: ``None`` in
        the main process (read everything), otherwise an object with ``id`` and
        ``num_workers``; worker ``i`` takes every ``num_workers``-th file starting at ``i``.
        With fewer files than workers, the surplus workers yield nothing.
        """
        if worker_info is None:
            return self.file_paths
        return self.file_paths[worker_info.id::worker_info.num_workers]

    def __iter__(self):
        """Yield ``(input_ids, target_ids)`` where targets are inputs shifted by one token."""
        my_files = self._files_for_worker(torch.utils.data.get_worker_info())
        for file_path in my_files:
            for chunk in self._read_file_in_chunks(file_path):
                for block in self._tokenize_chunk(chunk):
                    # A single token cannot form an (input, target) pair.
                    if len(block) < 2:
                        continue
                    tokens = torch.tensor(block, dtype=torch.long)
                    yield tokens[:-1], tokens[1:]


In [9]:
def collate_batch(batch, pad_token_id=None):
    """Collate ``(input_ids, target_ids)`` pairs of differing lengths into padded tensors.

    Args:
        batch: Sequence of ``(input_ids, target_ids)`` pairs of 1-D tensors.
        pad_token_id: Fill value for the inputs. Defaults to the pad id of the notebook-level
            ``tokenizer`` so the function can still be passed directly as ``collate_fn``.

    Returns:
        ``(input_ids, target_ids)`` with one row per example, padded on the right to the
        longest sequence in the batch. Target padding is ``-100`` so a loss configured with
        ``ignore_index=-100`` does not count padded positions.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    inputs, targets = zip(*batch)
    pad = torch.nn.utils.rnn.pad_sequence
    input_ids = pad(inputs, batch_first=True, padding_value=pad_token_id)
    target_ids = pad(targets, batch_first=True, padding_value=-100)
    return input_ids, target_ids


In [10]:
# Stream token blocks from DATA_PATH and batch them for training.
train_dataset = StreamingTextDataset(
    folder_path=DATA_PATH,
    tokenizer=tokenizer,
    block_size=1024,
    chunk_size=32_768,
)

# Batches are kept small to avoid running out of memory.
# NOTE(review): with an iterable-style dataset each worker process iterates its own copy of
# the dataset, so num_workers > 1 only avoids duplicated batches if the dataset splits its
# work per worker - confirm.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=2,
    num_workers=2,
    collate_fn=collate_batch,
)
print("Dataset and DataLoader ready ✅")

Dataset and DataLoader ready ✅


In [11]:
# ================================================
# 4. Load Your GPT Model
# ================================================
from gpt_model import GPTModel  # import your model class

from config import GPT_CONFIG  # import your config dictionary

# Build the model from GPT_CONFIG and move it to the device selected at the top of the notebook.
# NOTE(review): GPT_CONFIG is not shown here; its vocabulary size presumably has to match the
# tokenizer loaded above and its context length the block_size used by the dataset - confirm.
model = GPTModel(GPT_CONFIG).to(device)
print("GPT model loaded ✅")


GPT model loaded ✅


In [None]:

# Optimizer and loop settings for this cell.
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
num_epochs = 3
gradient_accumulation_steps = 4   # micro-batches summed into one optimizer step
log_interval = 50                 # micro-batches between progress-bar updates
os.makedirs("checkpoints", exist_ok=True)


def compute_perplexity(loss):
    """Return ``exp(loss)``, or ``inf`` when the loss is too large for a float.

    ``math.exp`` raises ``OverflowError`` above roughly 709; a diverged loss should show
    up as an infinite perplexity in the progress bar rather than abort training.
    """
    try:
        return math.exp(loss)
    except OverflowError:
        return float("inf")


def _apply_accumulated_gradients(grad_scale=1.0):
    """Clip the gradients accumulated on ``model``, take one optimizer step and clear them.

    ``grad_scale`` multiplies the gradients first; it is used for a final, incomplete
    accumulation group so that its gradient is still an average over its micro-batches.
    """
    if grad_scale != 1.0:
        for param in model.parameters():
            if param.grad is not None:
                param.grad.mul_(grad_scale)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()


model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    pending_micro_batches = 0
    # Start every epoch from clean gradients (also covers re-running an interrupted cell).
    optimizer.zero_grad()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for step, (inp, tgt) in enumerate(pbar):
        inp, tgt = inp.to(device), tgt.to(device)
        logits = model(inp)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), tgt.view(-1), ignore_index=-100)
        # Scale so the summed gradient of a full group equals the mean over its micro-batches.
        (loss / gradient_accumulation_steps).backward()
        pending_micro_batches += 1

        if pending_micro_batches == gradient_accumulation_steps:
            _apply_accumulated_gradients()
            pending_micro_batches = 0

        running_loss += loss.item()
        # (step + 1) so that every window, including the first, holds exactly
        # log_interval losses.
        if (step + 1) % log_interval == 0:
            avg_loss = running_loss / log_interval
            pbar.set_postfix(loss=f"{avg_loss:.3f}", ppl=f"{compute_perplexity(avg_loss):.2f}")
            running_loss = 0.0

    # Apply a trailing, incomplete accumulation group instead of silently carrying its
    # gradients into the next epoch.
    if pending_micro_batches:
        _apply_accumulated_gradients(gradient_accumulation_steps / pending_micro_batches)

    # Save checkpoint after each epoch
    ckpt_path = f"checkpoints/epoch_{epoch+1}.pt"
    torch.save(model.state_dict(), ckpt_path)
    print(f"✅ Saved checkpoint: {ckpt_path}")


Epoch 1/3: 0it [00:00, ?it/s]

In [None]:
# # ================================================
# # 6. Training Loop
# # ================================================
# model.train()
# for epoch in range(num_epochs):
#     running_loss = 0.0
#     pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

#     for step, (inp, tgt) in enumerate(pbar):
#         inp, tgt = inp.to(device), tgt.to(device)
#         logits = model(inp)
#         loss = F.cross_entropy(logits.view(-1, logits.size(-1)), tgt.view(-1), ignore_index=-100)
        
#         optimizer.zero_grad()
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#         optimizer.step()

#         running_loss += loss.item()
#         if step % log_interval == 0 and step > 0:
#             avg_loss = running_loss / log_interval
#             perplexity = compute_perplexity(avg_loss)
#             pbar.set_postfix(loss=f"{avg_loss:.3f}", ppl=f"{perplexity:.2f}")
#             running_loss = 0.0

#     # Save checkpoint
#     ckpt_path = f"checkpoints/epoch_{epoch+1}.pt"
#     torch.save(model.state_dict(), ckpt_path)
#     print(f"✅ Saved checkpoint: {ckpt_path}")


Epoch 1/3:   0%|          | 0/3 [00:00<?, ?it/s]

: 

In [None]:
# ================================================
# 7. Evaluation (Perplexity)
# ================================================
@torch.no_grad()
def evaluate(model, data_loader, device=None, show_progress=True):
    """Return the mean per-token cross-entropy of ``model`` over ``data_loader``.

    The loss is summed over all target positions that are not ``-100`` and divided by the
    number of such positions, so batches with different amounts of padding are weighted
    correctly. The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even on error.

    Args:
        model: Module mapping a batch of token ids to logits whose last dimension is the
            vocabulary.
        data_loader: Iterable of ``(input_ids, target_ids)`` batches.
        device: Device the batches are moved to. Defaults to the device of the model's
            first parameter (CPU for a parameter-less model).
        show_progress: Wrap the loader in a ``tqdm`` progress bar when true.

    Returns:
        Average loss per target token as a float.

    Raises:
        ValueError: If the loader produced no target tokens, since the average would be
            undefined.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    batches = tqdm(data_loader, desc="Evaluating") if show_progress else data_loader

    was_training = model.training
    model.eval()
    total_loss, total_tokens = 0.0, 0
    try:
        for inp, tgt in batches:
            inp, tgt = inp.to(device), tgt.to(device)
            logits = model(inp)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                tgt.view(-1),
                ignore_index=-100,
                reduction="sum",
            )
            total_loss += loss.item()
            total_tokens += (tgt != -100).sum().item()
    finally:
        # Do not leave a model that was being trained stuck in eval mode.
        model.train(was_training)

    if total_tokens == 0:
        raise ValueError("evaluate() saw no target tokens; is the data loader empty?")
    return total_loss / total_tokens

# No held-out split is built in this notebook, so this is perplexity on the *training* stream
# and says nothing about generalization. The `val_loss` name is kept for later cells.
val_loss = evaluate(model, train_loader)
print(f"✅ Training-set Perplexity: {math.exp(val_loss):.2f}")
