In [1]:
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
from datasets import load_from_disk
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

ds = load_from_disk("data/tinystories")

wiki_ds = load_from_disk("data/wiki")
train = ds["train"]
val = ds["validation"]

In [3]:
import torch
from torch.utils.data import IterableDataset, DataLoader
from datasets.arrow_dataset import Dataset
from typing import Generator

class StreamingTokenDataset(IterableDataset):
    def __init__(
            self, 
            dataset: Dataset,
            tokenizer: AutoTokenizer,
            context_size=256, 
            buffer_size=10_000
        ) -> None:

        self.dataset = dataset
        self.tokenizer = tokenizer

        self.context_size = context_size
        self.buffer_size = buffer_size

    def _token_stream(self) -> Generator[int, None, None]:
        for example in self.dataset:
            tokens = self.tokenizer.encode(example["text"])
            yield from tokens
            yield 0

    def _chunk_stream(self):
        buf = []
        for token in self._token_stream():
            buf.append(token)
            if len(buf) > self.context_size:

                context_batch = buf[:self.context_size + 1]

                input_tokens = torch.tensor(context_batch[:self.context_size], dtype=torch.long)
                pred_tokens = torch.tensor(context_batch[1:], dtype=torch.long)
                yield input_tokens, pred_tokens
                buf = buf[self.context_size:]

    def __iter__(self):
        yield from self._chunk_stream()

In [4]:
batch_size = 142

train_dataset = StreamingTokenDataset(train, tokenizer)
val_dataset = StreamingTokenDataset(val, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=batch_size)
val_loader = DataLoader(val_dataset, batch_size=4) # loader length: 4654
wiki_loader = DataLoader(StreamingTokenDataset(wiki_ds, tokenizer), batch_size=4) # loader length: 244

In [5]:
from utils.gpt import GPTDecoder

vocab_size = tokenizer.vocab_size
embed_dim = 256
num_heads = 8
ff_hidden_dim = 2048
num_layers = 6
context_length = 256
dropout = 0.1
window = (64, 0)

gpt = GPTDecoder(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    ff_hidden_dim=ff_hidden_dim,
    num_layers=num_layers,
    context_length=context_length,
    dropout=dropout,
)

In [6]:
gpt.gradient_checkpointing_enable()

In [7]:
def choose_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

In [8]:
class GPUMonitor:
    def __init__(self):
        self.reset()

    def reset(self):
        torch.cuda.reset_peak_memory_stats()
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start_event.record()
    
    def stop_timing(self):
        self.end_event.record()
        torch.cuda.synchronize()  # Critical: Wait for GPU
        return self.start_event.elapsed_time(self.end_event) / 1000.0 # ms to sec

    def get_peak_memory(self):
        return torch.cuda.max_memory_allocated() / (1024 ** 3) # GB

PROFILE_START = 5
PROFILE_END = 25 
measurements = {
    "forward_mem": [],
    "step_peak_mem": [],
    "step_times": []
}

In [9]:

import torch.nn as nn
from tqdm import tqdm
from torch.amp import autocast
import time

epochs = 1
grad_clip = 10.0

device = torch.device(choose_device())
print(f"Training on device: {device}")

gpt.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.AdamW(gpt.parameters())
monitor = GPUMonitor()

for epoch in range(1, epochs + 1):
    gpt.train()
    total_loss = 0.0

    progress = tqdm(enumerate(train_loader), total=15_000, desc=f"Epoch {epoch}/{epochs}")

    epoch_start_time = time.time()

    for i, (batch_x, batch_y) in progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        is_profiling = (i >= PROFILE_START) and (i < PROFILE_END)

        optimizer.zero_grad()
        if is_profiling:
            monitor.reset()

        with autocast('cuda', dtype=torch.bfloat16): 
            out = gpt(batch_x)
            loss = criterion(
                out.view(-1, out.size(-1)),
                batch_y.view(-1),
            )
        
        if is_profiling:
            measurements["forward_mem"].append(monitor.get_peak_memory())

        loss.backward()
        
        if is_profiling:
            measurements["step_peak_mem"].append(monitor.get_peak_memory())

        torch.nn.utils.clip_grad_norm_(gpt.parameters(), grad_clip)
        optimizer.step()

        # --- D. Record Time ---
        if is_profiling:
            step_time = monitor.stop_timing()
            measurements["step_times"].append(step_time)

        if i == PROFILE_END:
            avg_fwd_mem = sum(measurements["forward_mem"]) / len(measurements["forward_mem"])
            avg_peak_mem = sum(measurements["step_peak_mem"]) / len(measurements["step_peak_mem"])
            avg_time = sum(measurements["step_times"]) / len(measurements["step_times"])
            
            # Use standard print, but we might need to break the tqdm line cleanly
            progress.write("\n" + "="*40)
            progress.write(f"PROFILING COMPLETE (Steps {PROFILE_START}-{PROFILE_END})")
            progress.write(f"Avg Forward Memory: {avg_fwd_mem:.4f} GB")
            progress.write(f"Avg Peak Step Memory: {avg_peak_mem:.4f} GB")
            progress.write(f"Avg Step Time: {avg_time:.4f} sec")
            progress.write("="*40 + "\n")

        loss_val = loss.detach().float().item()
        total_loss += loss_val
        avg_loss = total_loss / (i + 1)

        progress.set_postfix({
            "loss": f"{avg_loss:.4f}",
            "lr": optimizer.param_groups[0]["lr"],
        })

    total_epoch_time = time.time() - epoch_start_time
    print(f"Epoch {epoch} done | Average training loss: {avg_loss:.4f}")
    print(f"Perplexity on training data: {torch.exp(torch.tensor(avg_loss)):.4f}")
    print(f"Total Epoch Time: {total_epoch_time:.2f} sec")
    
    # --- Print Profiling Statistics ---
    if len(measurements["step_times"]) > 0:
        avg_fwd_mem = sum(measurements["forward_mem"]) / len(measurements["forward_mem"])
        avg_peak_mem = sum(measurements["step_peak_mem"]) / len(measurements["step_peak_mem"])
        avg_time = sum(measurements["step_times"]) / len(measurements["step_times"])
        
        print("\n" + "="*40)
        print(f"PROFILING RESULTS (Steps {PROFILE_START}-{PROFILE_END})")
        print(f"Avg Forward Memory: {avg_fwd_mem:.4f} GB")
        print(f"Avg Peak Step Memory: {avg_peak_mem:.4f} GB")
        print(f"Avg Step Time: {avg_time:.4f} sec")
        print("="*40 + "\n")

torch.save(gpt.state_dict(), f"data/gradient_checkpointing_final.pt")
print("Training complete. Model saved to gradient_checkpointing_final.pt")

Training on device: cuda


Epoch 1/1:   0%|          | 0/15000 [00:00<?, ?it/s]

Epoch 1/1:   0%|          | 13/15000 [00:11<3:32:03,  1.18it/s, loss=56.9005, lr=0.001]Token indices sequence length is longer than the specified maximum sequence length for this model (1106 > 1024). Running this sequence through the model will result in indexing errors
Epoch 1/1:   0%|          | 25/15000 [00:21<3:33:17,  1.17it/s, loss=44.3712, lr=0.001]


PROFILING COMPLETE (Steps 5-25)
Avg Forward Memory: 14.1643 GB
Avg Peak Step Memory: 20.9563 GB
Avg Step Time: 0.7282 sec



Epoch 1/1:  87%|████████▋ | 13039/15000 [3:04:08<27:41,  1.18it/s, loss=2.6244, lr=0.001]  

Epoch 1 done | Average training loss: 2.6244
Perplexity on training data: 13.7960
Total Epoch Time: 11048.35 sec

PROFILING RESULTS (Steps 5-25)
Avg Forward Memory: 14.1643 GB
Avg Peak Step Memory: 20.9563 GB
Avg Step Time: 0.7282 sec

Training complete. Model saved to gradient_checkpointing_final.pt





In [10]:
gpt.eval()
val_loss = 0.0

val_progress = tqdm(enumerate(val_loader), total=4654, desc="Validation")

with torch.no_grad():
    for i, (batch_x, batch_y) in val_progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        with autocast('cuda', dtype=torch.bfloat16): 
            out = gpt(batch_x)
            loss = criterion(
                out.view(-1, out.size(-1)),
                batch_y.view(-1),
            )

        val_loss += loss.item()
        val_progress.set_postfix({"val_loss": f"{val_loss / (i + 1):.4f}"})

avg_val_loss = val_loss / (i + 1 )
val_perplexity = torch.exp(torch.tensor(avg_val_loss))

print("\n" + "="*40)
print(f"VALIDATION RESULTS")
print(f"Validation Loss: {avg_val_loss:.4f}")
print(f"Validation Perplexity: {val_perplexity:.4f}")
print("="*40 + "\n")

Validation: 100%|██████████| 4654/4654 [01:02<00:00, 74.56it/s, val_loss=1.8878]


VALIDATION RESULTS
Validation Loss: 1.8878
Validation Perplexity: 6.6051






In [11]:
gpt.eval()
val_loss = 0.0

val_progress = tqdm(enumerate(wiki_loader), total=244, desc="Validation_held_out_dataset")

with torch.no_grad():
    for i, (batch_x, batch_y) in val_progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        with autocast('cuda', dtype=torch.bfloat16): 
            out = gpt(batch_x)
            loss = criterion(
                out.view(-1, out.size(-1)),
                batch_y.view(-1),
            )

        val_loss += loss.item()
        val_progress.set_postfix({"val_loss": f"{val_loss / (i + 1):.4f}"})

avg_val_loss = val_loss / (i + 1 )
val_perplexity = torch.exp(torch.tensor(avg_val_loss))

print("\n" + "="*40)
print(f"VALIDATION RESULTS")
print(f"Validation Loss: {avg_val_loss:.4f}")
print(f"Validation Perplexity: {val_perplexity:.4f}")
print("="*40 + "\n")

Validation_held_out_dataset: 245it [00:03, 65.34it/s, val_loss=10.9659]                         


VALIDATION RESULTS
Validation Loss: 10.9659
Validation Perplexity: 57868.9727




