In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import time

In [2]:
# Display the installed PyTorch build string so the recorded results can be tied to a version.
torch.__version__

'2.8.0+cu128'

In [3]:
import tiktoken

In [4]:
# Hyperparameters read by the model classes and the data-loader setup in this notebook.
GPT_CONFIG = dict(
    vocab_size=50257,           # rows of the token embedding / columns of the output head
    context_length=256,         # rows of the positional embedding; also the data-loader window and stride
    emb_dim=768,                # width of every embedding and transformer block
    n_heads=12,                 # attention heads per block (emb_dim must be divisible by it)
    n_layers=12,                # number of stacked transformer blocks
    drop_rate=0.1,              # dropout probability used for embeddings, attention and residuals
    qkv_bias=False,             # whether the query/key/value projections carry a bias
    attention_type="flash_v2",  # key resolved by AttentionFactory.get_attention
)

In [5]:
# Device-agnostic setup: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [6]:
# Shell out to nvidia-smi to record the GPU model, driver/CUDA versions and current memory usage.
!nvidia-smi

Sun Aug 24 08:35:44 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.144                Driver Version: 570.144        CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        On  |   00000000:0A:00.0 Off |                  N/A |
|  0%   28C    P5             20W /  575W |       4MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [7]:
#Load file
# Read the whole corpus into memory as one string. The path is relative, so the
# notebook has to be run from the directory that contains the file.
f = "shakespeare.txt"
with open(f, "r", encoding="utf-8") as file:
    text_data = file.read()
# len() of a str counts characters, not tokens.
print(f"Book Loaded Successfully. Length of book is {len(text_data)}")

Book Loaded Successfully. Length of book is 1115393


In [8]:
# Divide the text into train & validation data.
# The cut is made on characters (before tokenization): the first `train_ratio`
# fraction of the text is used for training and the remainder for validation.
train_ratio = 0.8
# Use the named ratio instead of repeating the literal so the two cannot drift apart.
split = int(train_ratio * len(text_data))
train_data = text_data[:split]
val_data = text_data[split:]
print(f"Length of Train Data is {len(train_data)}")
print(f"Length of Val Data is {len(val_data)}")

Length of Train Data is 892314
Length of Val Data is 223079


In [9]:
class DatasetV1(Dataset):
    """Sliding-window next-token dataset built from a single text.

    The text is tokenized once. Every sample is a pair of equal-length 1-D
    tensors in which the target is the input window shifted one token to the
    right.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        """
        Args:
            txt: text to tokenize.
            tokenizer: object whose ``encode(txt)`` returns a sequence of token ids.
            max_length: number of tokens in every input/target window.
            stride: distance, in tokens, between the starts of consecutive windows.
        """
        self.tokenizer = tokenizer
        ids = tokenizer.encode(txt)

        # Stop early enough that the one-token-shifted target window is always full.
        window_starts = range(0, len(ids) - max_length, stride)
        self.input_ids = [
            torch.tensor(ids[start:start + max_length]) for start in window_starts
        ]
        self.target_ids = [
            torch.tensor(ids[start + 1:start + max_length + 1]) for start in window_starts
        ]

    def __len__(self):
        """Number of windows produced from the text."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair stored at position ``idx``."""
        return self.input_ids[idx], self.target_ids[idx]

In [10]:
def create_dataloader_v1(txt, batch_size=32, max_length=256, stride=128, shuffle=True, drop_last=True):
    """Tokenize ``txt`` with the GPT-2 encoding and wrap the windows in a DataLoader.

    Args:
        txt: raw text to turn into training windows.
        batch_size: samples per batch.
        max_length: tokens per window (passed to ``DatasetV1``).
        stride: step between window starts (passed to ``DatasetV1``).
        shuffle: forwarded to ``DataLoader``.
        drop_last: forwarded to ``DataLoader``.

    Returns:
        A ``DataLoader`` over a ``DatasetV1`` built from ``txt``.
    """
    windows = DatasetV1(txt, tiktoken.get_encoding("gpt2"), max_length, stride)
    return DataLoader(windows, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

In [11]:
#Creating Data Loaders
# Window length and stride are both the model's context length, so consecutive
# windows do not overlap.
_window = GPT_CONFIG["context_length"]

# Training: shuffled, and the ragged final batch is dropped.
train_loader = create_dataloader_v1(
    train_data, batch_size=32, max_length=_window, stride=_window,
    drop_last=True, shuffle=True,
)

# Validation: original order, every batch kept.
val_loader = create_dataloader_v1(
    val_data, batch_size=32, max_length=_window, stride=_window,
    drop_last=False, shuffle=False,
)

In [12]:
#Define a Multi head attention class

class MultiAttn(nn.Module):
    """Causal multi-head self-attention implemented with explicit matmuls.

    Queries, keys and values are projected from the input, split into
    ``n_heads`` heads of size ``d_out // n_heads``, combined with scaled
    dot-product attention under an upper-triangular (future-masking) mask, and
    merged back through an output projection.
    """

    def __init__(self, d_in, d_out, context_length, drop_out, n_heads, qkv_bias=False):
        """
        Args:
            d_in: size of the last dimension of the input.
            d_out: size of the last dimension of the output; must be divisible by ``n_heads``.
            context_length: side length of the pre-built causal mask.
            drop_out: dropout probability applied to the attention weights.
            n_heads: number of attention heads.
            qkv_bias: whether the query/key/value projections have a bias.
        """
        super().__init__()
        assert d_out % n_heads == 0, "d_out must be divisible by n_heads"
        self.d_out = d_out
        self.n_heads = n_heads
        self.head_dim = d_out // n_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_values = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(drop_out)
        self.out_proj = nn.Linear(d_out, d_out)
        # Ones strictly above the diagonal mark the "future" positions to hide.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape ``(batch, n_tokens, d_in)``. The mask is sliced to
                ``n_tokens``, so ``n_tokens`` is expected not to exceed
                ``context_length``.

        Returns:
            Tensor of shape ``(batch, n_tokens, d_out)``.
        """
        b, n_tokens, d_in = x.shape

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_values(x)

        # (b, n_tokens, d_out) -> (b, n_heads, n_tokens, head_dim)
        keys = keys.view(b, n_tokens, self.n_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(b, n_tokens, self.n_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, n_tokens, self.n_heads, self.head_dim).transpose(1, 2)

        attn_score = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:n_tokens, :n_tokens]
        # masked_fill is out-of-place: its result has to be kept, otherwise the
        # scores stay unmasked and every token can attend to future tokens.
        attn_score = attn_score.masked_fill(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_score / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, n_heads, n_tokens, head_dim) -> (b, n_tokens, d_out)
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, n_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec

In [13]:
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

class FlashAttentionV2(nn.Module):
    def __init__(self, d_in, d_out, context_length, drop_out, n_heads, qkv_bias=False):
        super().__init__()
        assert d_out % n_heads == 0, "d_out must be divisible by n_heads"
        self.d_out = d_out
        self.n_heads = n_heads
        self.head_dim = d_out // n_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_values = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = drop_out
        self.out_proj = nn.Linear(d_out, d_out)
        self.context_length = context_length

    def forward(self, x):
        b, n_tokens, d_in = x.shape

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_values(x)

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            keys = keys.view(b, n_tokens, self.n_heads, self.head_dim).transpose(1,2)
            queries = queries.view(b, n_tokens, self.n_heads, self.head_dim).transpose(1,2)
            values = values.view(b, n_tokens, self.n_heads, self.head_dim).transpose(1,2)
    
            #Context Manager to select Attention type
            
            context_vec = F.scaled_dot_product_attention(queries, 
                                                         keys, 
                                                         values,
                                                         dropout_p=self.dropout if self.training else 0.0,
                                                         is_causal=True)
            
            
            context_vec = context_vec.transpose(1,2).contiguous().view(b, n_tokens, self.d_out)
            context_vec = self.out_proj(context_vec)

        return context_vec

In [14]:
# Repeats the PyTorch version check made near the top of the notebook.
torch.__version__

'2.8.0+cu128'

In [15]:
class AttentionFactory:
    """Resolve an attention implementation from its configuration name."""

    @staticmethod
    def get_attention(attention_name):
        """Return the attention class registered under ``attention_name``.

        The lookup is case-insensitive. ``None`` is returned when the name is
        not registered, so callers have to check the result before calling it.
        """
        registry = {
            "multihead": MultiAttn,
            "flash_v2": FlashAttentionV2,
        }
        key = attention_name.lower()
        if key in registry:
            return registry[key]
        return None

In [16]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and shift."""

    def __init__(self, emb_dim):
        """Create per-feature ``scale`` (ones) and ``shift`` (zeros) of size ``emb_dim``."""
        super().__init__()
        self.eps = 1e-5  # added to the variance to keep the division stable
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """Normalize ``x`` along its last dimension, then apply scale and shift."""
        centered = x - x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, i.e. divide by N rather than N - 1.
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.scale * normalized + self.shift

In [17]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to four times the embedding width, GELU, project back."""

    def __init__(self, cfg):
        """Build the MLP from ``cfg["emb_dim"]``."""
        super().__init__()
        width = cfg["emb_dim"]
        hidden = 4 * width
        expand = nn.Linear(width, hidden)
        project = nn.Linear(hidden, width)
        self.layers = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension stays ``emb_dim``."""
        return self.layers(x)

In [18]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual.

    Each sub-layer is applied as ``x + dropout(sublayer(norm(x)))``. The
    attention implementation is chosen by ``cfg["attention_type"]`` through
    ``AttentionFactory``.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: mapping with the keys ``attention_type``, ``emb_dim``,
                ``context_length``, ``drop_rate``, ``n_heads`` and ``qkv_bias``.

        Raises:
            ValueError: if ``cfg["attention_type"]`` is not known to the factory.
        """
        super().__init__()
        self.attention_type = cfg["attention_type"]
        attn_cls = AttentionFactory.get_attention(cfg["attention_type"])
        if attn_cls is None:
            # The factory signals an unknown name with None; calling that would
            # fail with an opaque "'NoneType' object is not callable".
            raise ValueError(
                f"Unknown attention_type {cfg['attention_type']!r} in model config"
            )
        self.attn = attn_cls(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            drop_out=cfg["drop_rate"],
            n_heads=cfg["n_heads"],
            qkv_bias=cfg["qkv_bias"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        # One dropout module is shared by both residual branches.
        self.drop_resid = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers; the shape is preserved."""
        # Attention sub-layer with residual connection.
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x = self.drop_resid(x)
        x = x + shortcut

        # Feed-forward sub-layer with residual connection.
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_resid(x)
        x = x + shortcut

        return x

In [19]:
class GPTModel(nn.Module):
    """Decoder-only language model: embeddings, transformer blocks, norm, output head."""

    def __init__(self, cfg):
        """
        Args:
            cfg: mapping with the keys ``vocab_size``, ``emb_dim``,
                ``context_length``, ``drop_rate`` and ``n_layers``, plus
                whatever ``TransformerBlock`` reads from it.
        """
        super().__init__()
        vocab_size, emb_dim = cfg["vocab_size"], cfg["emb_dim"]
        self.tok_embed = nn.Embedding(vocab_size, emb_dim)
        self.pos_embed = nn.Embedding(cfg["context_length"], emb_dim)
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_block = nn.Sequential(*blocks)
        self.final_norm = LayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, in_idx):
        """Map token ids of shape ``(batch, seq_len)`` to logits ``(batch, seq_len, vocab_size)``."""
        _, seq_len = in_idx.shape
        # Positions are created on the input's device so the lookup never mixes devices.
        positions = torch.arange(seq_len, device=in_idx.device)
        hidden = self.tok_embed(in_idx) + self.pos_embed(positions)

        hidden = self.drop_emb(hidden)
        hidden = self.trf_block(hidden)
        hidden = self.final_norm(hidden)

        return self.out_head(hidden)

In [20]:
def text_to_token_ids(text, tokenizer):
    """Encode ``text`` and return the ids as a batch of one.

    Args:
        text: string to encode; the ``<|endoftext|>`` marker is allowed in it.
        tokenizer: object whose ``encode(text, allowed_special=...)`` returns a
            sequence of integer token ids.

    Returns:
        Integer (``torch.long``) tensor of shape ``(1, n_tokens)``.
    """
    encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
    # Pin the dtype: torch.tensor([]) would otherwise default to float32 when
    # the text encodes to no tokens, which cannot be used as embedding indices.
    encoded_tensor = torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
    return encoded_tensor

def token_ids_to_text(token_ids, tokenizer):
    """Decode a tensor of token ids back to text.

    A leading dimension of size one (a batch of one) is removed before the ids
    are handed to ``tokenizer.decode`` as a plain Python list.
    """
    ids = token_ids.squeeze(0).tolist()
    return tokenizer.decode(ids)

In [21]:
def generate(model, idx, max_new_tokens, context_size, temperature = 0.0, top_k = None, eos_id=None):
    """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

    Args:
        model: callable mapping ids ``(batch, n_tokens)`` to logits
            ``(batch, n_tokens, vocab)``.
        idx: integer tensor ``(batch, n_tokens)`` holding the prompt.
        max_new_tokens: upper bound on the number of tokens appended.
        context_size: only the last ``context_size`` tokens are fed to the model.
        temperature: ``0.0`` selects greedy decoding (argmax); a positive value
            divides the logits by it and samples from the softmax.
        top_k: when given, logits below the k-th largest of each row are set to
            ``-inf`` before the next token is chosen.
        eos_id: when given, generation stops as soon as every row of the batch
            produces this id in the same step; that token is not appended.

    Returns:
        ``idx`` with the generated tokens concatenated along dimension 1.
    """
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.inference_mode():
            logits = model(idx_cond)
        # Only the prediction for the last position is needed.
        logits = logits[:, -1, :]

        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            # Keep the trailing dimension so the per-row threshold (batch, 1)
            # broadcasts against (batch, vocab) for any batch size.
            min_val = top_logits[:, -1:]
            logits = logits.masked_fill(logits < min_val, float("-inf"))

        if temperature > 0.0:
            logits = logits / temperature

            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        # Reduce explicitly: the truth value of a multi-element tensor is ambiguous.
        if eos_id is not None and bool((idx_next == eos_id).all()):
            break

        idx = torch.cat((idx, idx_next), dim=1)

    return idx

In [22]:
def calc_loss_batch(input_batch, target_batch, model, device):
    """Compute the mean next-token cross-entropy for one batch under bf16 autocast.

    Args:
        input_batch: token ids fed to the model.
        target_batch: target ids; flattened to match the model's logits
            flattened over their first two dimensions.
        model: module mapping ``input_batch`` to logits.
        device: device to run on; a string (``"cpu"``, ``"cuda"``, ``"cuda:0"``)
            or a ``torch.device``.

    Returns:
        Scalar loss tensor (still attached to the autograd graph).
    """
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    # Kept for callers that pass a model which has not been moved yet.
    model.to(device)
    # autocast only accepts the bare device type as a string ("cuda"/"cpu");
    # normalise so torch.device objects and indexed names such as "cuda:0" work.
    device_type = torch.device(device).type
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits = model(input_batch)
        loss = nn.functional.cross_entropy(logits.flatten(0,1), target_batch.flatten())
    return loss

def calc_loss_loader(dataloader, model, device, num_batches=None):
    """Average ``calc_loss_batch`` over the first batches of ``dataloader``.

    Args:
        dataloader: sized iterable of ``(input_batch, target_batch)`` pairs.
        model: model passed through to ``calc_loss_batch``.
        device: device passed through to ``calc_loss_batch``.
        num_batches: how many batches to average; ``None`` means all of them,
            larger values are capped at ``len(dataloader)``.

    Returns:
        Mean loss as a Python float, or ``nan`` when the loader is empty.
    """
    available = len(dataloader)
    if available == 0:
        return float("nan")
    num_batches = available if num_batches is None else min(num_batches, available)

    running_total = 0
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        if batch_idx >= num_batches:
            break
        running_total += calc_loss_batch(inputs, targets, model, device=device).item()

    return running_total / num_batches

In [23]:
def train_model_simple(model, train_dataloader, val_dataloader, num_epochs, optimizer, 
                    tokenizer, eval_freq, eval_iter, start_context, device):
    """Train ``model`` for ``num_epochs`` epochs with periodic evaluation.

    Every ``eval_freq`` optimizer steps (including the first) the train and
    validation losses are estimated with ``evaluate_model`` over ``eval_iter``
    batches and printed. After each epoch a text sample continuing
    ``start_context`` is printed.

    Returns:
        ``(train_losses, val_losses, track_tokens_seen)``: three parallel lists
        with one entry per evaluation point; the last holds the number of input
        tokens processed so far.
    """
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen = 0
    global_step = -1  # becomes 0 on the first optimizer step

    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        for inputs, targets in train_dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = calc_loss_batch(inputs, targets, model, device=device)
            batch_loss.backward()
            optimizer.step()

            tokens_seen += inputs.numel()
            global_step += 1

            if global_step % eval_freq != 0:
                continue

            train_loss, val_loss = evaluate_model(
                model, train_dataloader, val_dataloader, device, eval_iter)
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            track_tokens_seen.append(tokens_seen)
            print(f"Ep {epoch+1} (Step {global_step:06d}): "
                  f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

        # One-off CUDA memory report at the end of the sixth epoch (epoch index 5).
        if device=="cuda" and epoch==5:
            print(torch.cuda.memory_summary(device=device, abbreviated=True))

    return train_losses, val_losses, track_tokens_seen

In [24]:
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    """Estimate train and validation loss over ``eval_iter`` batches each.

    The model is put in eval mode and gradients are disabled while the losses
    are computed; it is switched back to train mode before returning.

    Returns:
        ``(train_loss, val_loss)``.
    """
    model.eval()
    with torch.no_grad():
        # Train loader first, then validation loader.
        train_loss, val_loss = [
            calc_loss_loader(loader, model, device=device, num_batches=eval_iter)
            for loader in (train_loader, val_loader)
        ]
    model.train()
    return train_loss, val_loss

In [25]:
def generate_and_print_sample(model, tokenizer, device, start_context):
    """Print a 50-token greedy continuation of ``start_context`` on one line.

    The model is put in eval mode for the generation and returned to train
    mode afterwards. The context window is taken from the number of rows of
    ``model.pos_embed``.
    """
    model.eval()
    max_context = model.pos_embed.weight.shape[0]
    prompt_ids = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        generated_ids = generate(
            model=model,
            idx=prompt_ids,
            max_new_tokens=50,
            context_size=max_context,
        )
    sample = token_ids_to_text(generated_ids, tokenizer)
    print(sample.replace("\n", " "))  # newlines flattened so the sample stays on one line
    model.train()

In [26]:
# Select PyTorch's "high" float32 matmul precision setting before building the model.
torch.set_float32_matmul_precision("high")
# Build the model on the selected device and load the GPT-2 tokenizer used for sampling.
model = GPTModel(GPT_CONFIG).to(device)
tokenizer = tiktoken.get_encoding("gpt2")
# NOTE(review): the training cell below rebinds `model` to a fresh, uncompiled
# GPTModel, so this compiled instance does not appear to be the one that is
# trained — confirm which was intended.
model = torch.compile(model)

In [None]:
#Training starts here!!

# Wall-clock timer; it covers model construction as well as the training loop.
start_time = time.time()

# A fresh model is created in this cell, so re-running it restarts training
# from newly initialised weights.
# NOTE(review): this rebinds `model`, discarding the torch.compile()-wrapped
# instance created in the previous cell — confirm whether training was meant
# to run on the compiled model.
model = GPTModel(GPT_CONFIG)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)
print(f"Using {GPT_CONFIG['attention_type']} Attention")

num_epochs = 30
# Evaluate every 5 optimizer steps, averaging over 5 batches of each loader.
train_losses, val_losses, tokens_seen = train_model_simple(
    model=model, train_dataloader=train_loader, val_dataloader=val_loader, optimizer=optimizer, device=device,
    num_epochs=num_epochs, eval_freq=5, eval_iter=5,
    start_context="Every effort moves you", tokenizer=tokenizer
)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")

print("Training completed successfully! Exiting training block")

#Training Ends

Using flash_v2 Attention
Ep 1 (Step 000000): Train loss 9.449, Val loss 9.439
Ep 1 (Step 000005): Train loss 7.900, Val loss 7.911
Ep 1 (Step 000010): Train loss 6.721, Val loss 6.861
Ep 1 (Step 000015): Train loss 6.300, Val loss 6.484
Ep 1 (Step 000020): Train loss 6.151, Val loss 6.381
Ep 1 (Step 000025): Train loss 6.013, Val loss 6.315
Ep 1 (Step 000030): Train loss 5.974, Val loss 6.229
Every effort moves you,                                                 
Ep 2 (Step 000035): Train loss 5.889, Val loss 6.185
Ep 2 (Step 000040): Train loss 5.870, Val loss 6.166
Ep 2 (Step 000045): Train loss 5.831, Val loss 6.148
Ep 2 (Step 000050): Train loss 5.814, Val loss 6.113
Ep 2 (Step 000055): Train loss 5.765, Val loss 6.074
Ep 2 (Step 000060): Train loss 5.724, Val loss 6.049
Every effort moves you                                                  
Ep 3 (Step 000065): Train loss 5.696, Val loss 6.017
Ep 3 (Step 000070): Train loss 5.660, Val loss 6.044
Ep 3 (Step 000075): Train loss 5.5