#  <font color='#FFE15D'><b>💎 Train, Evaluate, and Generate Functions (LLM-specific) </b></font><font color='#FF0B55'><b>[Walkthrough]</b></font>

# 🔴 **Import**

In [None]:
import os
import sys
import time
import math
import json
import random
from tqdm import tqdm
from pprint import pprint
from itertools import cycle
from termcolor import colored
from dataclasses import dataclass
from prettytable import PrettyTable

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from datasets import load_dataset
from tokenizers import Tokenizer

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F

from torchmetrics import MeanMetric

# 🔴 **Utils**

In [None]:
def prepare_data(tokens, seq_len):
    """Split a flat token tensor into rows of ``seq_len`` tokens.

    Trailing tokens that would not fill a complete row are discarded.
    Returns a 2-D view of shape (tokens.shape[0] // seq_len, seq_len).
    """
    total = tokens.shape[0]
    usable = total - total % seq_len  # largest multiple of seq_len <= total
    return tokens[:usable].view(-1, seq_len)

In [None]:
def num_trainable_params(model):
    """Return the number of trainable (``requires_grad``) parameters, in millions."""
    trainable = [p for p in model.parameters() if p.requires_grad]
    total = sum(p.numel() for p in trainable)
    return total / 1e6

In [None]:
# Benchmarking function
def calculate_time(model, x, num_runs=10):
    """Return the mean wall-clock time (seconds) of one ``model(x)`` call.

    Args:
        model: Callable to benchmark (typically an ``nn.Module``).
        x: Input passed unchanged to ``model`` on every run.
        num_runs: Number of timed calls to average over; must be positive.

    Raises:
        ValueError: If ``num_runs`` is not positive.
    """
    if num_runs <= 0:
        raise ValueError(f"num_runs must be positive, got {num_runs}")

    def _sync():
        # CUDA kernels run asynchronously, so wait for them to make the timer
        # measure real work. Skip on CPU-only setups, where
        # torch.cuda.synchronize() may raise.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    _sync()
    start = time.perf_counter()  # monotonic, high-resolution clock for benchmarks
    for _ in range(num_runs):
        model(x)
    _sync()
    return (time.perf_counter() - start) / num_runs

# 🔴 **Init**

In [None]:
# Default to CPU and switch to the GPU when PyTorch can see one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
device

In [None]:
seq_len = 128 # model input length; TinyStoriesDataset rows hold seq_len + 1 tokens

# 🔴 **Dataset**

In [None]:
# Load the TinyStories dataset through Hugging Face `datasets`; the bare
# expression below displays it in the notebook.
dataset = load_dataset("roneneldan/TinyStories")
dataset

In [None]:
# Load a serialized `tokenizers` tokenizer from a local JSON file
# (the filename suggests a BPE tokenizer built for TinyStories — confirm).
tokenizer = Tokenizer.from_file("bpe-tokenizer_tinystories.json")
tokenizer

In [None]:
# Load tokens from pytorch file
train_token_ids = torch.load('tokenized-train-samples_vocab-10k.pt')
valid_token_ids = torch.load('tokenized-valid-samples_vocab-10k.pt')

print("üìä Number of Tokens")
print(f"üîπ Train: {len(train_token_ids):,} tokens")
print(f"üîπ Valid: {len(valid_token_ids):,} tokens")

In [None]:
class TinyStoriesDataset(Dataset):
    """Map-style dataset of fixed-length token windows.

    Each item is a 1-D ``long`` tensor of ``seq_len + 1`` consecutive tokens,
    so a consumer can use ``item[:-1]`` as input and ``item[1:]`` as target.
    Trailing tokens that do not fill a whole window are dropped.
    """

    def __init__(self, data, seq_len):
        self.seq_len = seq_len
        # One extra token per row supplies the shifted next-token targets.
        self.data = prepare_data(data, seq_len + 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx].long()

# 🔴 **Model**

## 🟠 Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a single fused QKV projection.

    ``config`` must expose ``n_embd`` and ``n_head``; ``n_embd`` has to be
    divisible by ``n_head`` so every head gets ``n_embd // n_head`` channels.

    Raises:
        ValueError: If ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})")
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_size = self.n_embd // self.n_head

        # One projection produces Q, K and V side by side along the last dim.
        self.qkv_proj = nn.Linear(self.n_embd, 3*self.n_embd, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # Marker read by GPT._init_weights to scale this residual projection's init.
        self.c_proj.residual = True

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, n_embd); returns the same shape."""
        B, T, C = x.shape
        # (B, T, 3C) -> (B, T, 3*n_head, head_size) -> (B, 3*n_head, T, head_size).
        # The first n_head heads come from the Q slice, then K, then V, so
        # chunking the head dim into three yields q, k, v of (B, n_head, T, head_size).
        q, k, v = (self.qkv_proj(x)
                   .view(B, T, 3*self.n_head, self.head_size)
                   .transpose(1, 2)
                   .chunk(3, dim=-3))
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge heads back: (B, n_head, T, head_size) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

## 🟠 Feed Forward (MLP)

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``int(f_expnd * n_embd)``, apply GELU, project back."""

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.f_expnd = config.f_expnd
        hidden_dim = int(self.f_expnd * self.n_embd)

        self.up_proj = nn.Linear(self.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, self.n_embd, bias=False)
        # Flag consumed by GPT._init_weights (scaled init for residual outputs).
        self.down_proj.residual = True

    def forward(self, x):
        hidden = self.up_proj(x)
        hidden = F.gelu(hidden)
        return self.down_proj(hidden)

## 🟠 Decoder Block

In [None]:
class DecoderBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Each of the two sublayers (causal attention, then MLP) reads a normalized
    copy of the stream and adds its output back through a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        # Sublayer 1: multi-head self-attention on the normalized stream.
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.mha = MultiHeadAttention(config)
        # Sublayer 2: feed-forward network on the normalized stream.
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = FeedForward(config)

    def forward(self, x):
        attn_out = self.mha(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out

## 🟠 GPT

In [None]:
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token + learned position embeddings feed a stack of ``config.n_layer``
    ``DecoderBlock``s, then a final LayerNorm and a linear head whose weight is
    tied to the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.n_embd) # Token embedding
        self.wpe = nn.Embedding(config.max_seq_len, config.n_embd) # Position embedding
        self.decoders = nn.ModuleList([DecoderBlock(config) for _ in range(config.n_layer)]) # Decoders
        self.lnf = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # Classifier
        # Weight tying: the output head shares the token-embedding matrix.
        self.lm_head.weight = self.wte.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero Linear biases.

        Linear layers carrying a ``residual`` attribute get their std scaled by
        ``(2 * n_layer) ** -0.5``, one factor of 2 for the two residual additions
        per decoder block. Other modules (e.g. LayerNorm) keep PyTorch defaults.
        """
        std = 0.02
        if isinstance(module, nn.Linear):
            if hasattr(module, 'residual'):
                std *= (2*self.config.n_layer)**-0.5
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, idx):
        """Map token ids ``idx`` of shape (B, T) to logits of shape (B, T, vocab_size).

        Raises:
            ValueError: If T exceeds ``config.max_seq_len`` (no position embedding exists).
        """
        B, T = idx.shape
        if T > self.config.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.config.max_seq_len}")
        # Create positions on the input's device instead of relying on a global.
        pos = torch.arange(T, device=idx.device)
        x = self.wte(idx) + self.wpe(pos)
        for decoder in self.decoders:
            x = decoder(x)
        x = self.lnf(x)
        logits = self.lm_head(x)
        return logits

## 🟠 Config

In [None]:
@dataclass
class GPTConfig:
    """Model hyper-parameters consumed by ``GPT`` and its sub-modules.

    Raises:
        ValueError: If ``n_head`` is not positive or does not divide ``n_embd``.
    """
    vocab_size: int = 50257 # number of tokens in the vocabulary
    max_seq_len: int = 1024 # max sequence length (size of the position embedding)
    n_layer: int = 12 # number of decoder blocks
    n_head: int = 12 # number of attention heads
    n_embd: int = 768 # embedding dimension
    f_expnd: float = 4 # MLP hidden size is int(f_expnd * n_embd), so fractions are allowed

    def __post_init__(self):
        # Fail at construction time rather than deep inside a forward pass.
        if self.n_head <= 0:
            raise ValueError(f"n_head must be positive, got {self.n_head}")
        if self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})")

# 🔴 **Functions ⚙️**

## 🟠 Temp

In [None]:
torch.manual_seed(1337)

seq_len = 128
batch_size = 192

# TinyStoriesDataset rows hold seq_len + 1 tokens (inputs plus shifted targets).
train_set = TinyStoriesDataset(train_token_ids, seq_len)
valid_set = TinyStoriesDataset(valid_token_ids, seq_len)

# Only the training split is shuffled.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, pin_memory=True)

print(f"üìä Number of Batches")
print(f"üîπ Train: {len(train_loader):,} batches")
print(f"üîπ Valid: {len(valid_loader):,} batches")

In [None]:
# Small GPT: 8 decoder blocks, 16 heads of 128 // 16 = 8 channels each.
model = GPT(
    GPTConfig(
        vocab_size=10_000,
        max_seq_len=256,
        n_layer=8,
        n_head=16,
        n_embd=128,
    )
).to(device)

print(model)
num_trainable_params(model)  # trainable parameters, in millions

In [None]:
learning_rate = 6e-4
weight_decay = 0.1
# The fused AdamW kernel is only guaranteed for CUDA tensors (CPU support
# depends on the PyTorch version), so enable it only when the model is on GPU.
use_fused = all(p.is_cuda for p in model.parameters())

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.95),
    weight_decay=weight_decay,
    fused=use_fused,
)

## 🟠 Logger

In [None]:
# Logger class for saving and plotting training logs
class Logger:
    """
    Manages training history logging, saving to disk, and plotting learning curves.

    The history is written to ``<log_dir>/<run_name>.json``. Whenever the latest
    validation loss is the lowest seen so far, model and optimizer state dicts
    are checkpointed to ``<log_dir>/best-model.pt``.
    """
    def __init__(self, log_dir='logs', run_name='default_run'):
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)
        self.run_name = run_name
        self.history = {
            'train_loss': [],
            'valid_loss': [],
            'best_loss_valid': float('inf'),
            'seen_tokens': []
        }

    def log(self, train_loss, valid_loss, seen_tokens):
        """Append one evaluation point to the in-memory history."""
        self.history['train_loss'].append(train_loss)
        self.history['valid_loss'].append(valid_loss)
        self.history['seen_tokens'].append(seen_tokens)

    def save(self, model_to_save=None, optimizer_to_save=None):
        """Write the history JSON and checkpoint the best model so far.

        Args:
            model_to_save: Model whose ``state_dict()`` is checkpointed. Defaults
                to the notebook-global ``model`` (the original behaviour).
            optimizer_to_save: Optimizer whose ``state_dict()`` is checkpointed.
                Defaults to the notebook-global ``optimizer``.
        """
        valid_losses = self.history['valid_loss']
        is_best = bool(valid_losses) and valid_losses[-1] < self.history['best_loss_valid']
        if is_best:
            # Update before dumping so the JSON on disk reflects the new best.
            self.history['best_loss_valid'] = valid_losses[-1]

        file_path = os.path.join(self.log_dir, f'{self.run_name}.json')
        with open(file_path, 'w') as f:
            json.dump(self.history, f, indent=4)

        if is_best:
            if model_to_save is None:
                model_to_save = model
            if optimizer_to_save is None:
                optimizer_to_save = optimizer
            # Save state dicts (not the optimizer object itself) so the checkpoint
            # can be restored with load_state_dict() and loaded with weights_only.
            log = dict(model=model_to_save.state_dict(),
                       optimizer=optimizer_to_save.state_dict())
            torch.save(log, os.path.join(self.log_dir, 'best-model.pt'))
            print("‚úÖ Model Saved!")

    def plot(self):
        """Plot train/valid loss against seen tokens, save it as a PNG and show it."""
        plt.figure(figsize=(10, 5))
        plt.plot(self.history['seen_tokens'], self.history['train_loss'], label='Train Loss')
        plt.plot(self.history['seen_tokens'], self.history['valid_loss'], label='Valid Loss')
        plt.xlabel('Seen Tokens')
        plt.ylabel('Loss')
        plt.title(f'Training Curve: {self.run_name}')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, f'{self.run_name}_curve.png'))
        plt.show()

## 🟠 Train ➰

In [None]:
# Trainer class to manage model training, evaluation and reporting
class LLMTrainer:
    """
    Trainer handles training loops, periodic evaluation, logging, and sample generation.

    Training is budgeted in tokens, not epochs. It runs until ``total_tokens``
    input tokens have been consumed, and every ``log_interval_tokens`` tokens
    (and at the end) it evaluates, logs, saves and optionally generates samples.

    Batches are expected to be 2-D integer tensors whose rows hold
    ``seq_len + 1`` tokens. The model sees ``[:, :-1]`` and is scored against
    the next tokens ``[:, 1:]``.
    """
    def __init__(self, model, optimizer, train_loader, valid_loader,
                 loss_fn=F.cross_entropy, device='cuda',
                 total_tokens=10_000_000, log_interval_tokens=1_000_000,
                 generation=None):

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.loss_fn = loss_fn
        self.device = device

        self.seen_tokens = 0
        self.total_tokens = total_tokens
        self.token_eval_counter = 0  # tokens seen since the last evaluation
        self.log_interval_tokens = log_interval_tokens

        self.logger = Logger(log_dir='logs', run_name='gpt2_tinystories')
        self._print_config_summary()

        self.generation = generation

    def train(self):
        """
        Main training loop that stops when total token count is reached.

        The train loss that gets logged is the batch-size-weighted mean over the
        steps since the previous evaluation; the meter is reset after each one.
        """
        loss_train = MeanMetric()
        self.model.train()
        train_iter = self._infinite_batches()

        batches = 0
        start_time = time.perf_counter()

        with tqdm(total=self.total_tokens, desc="Training", unit="t") as pbar:
            while self.seen_tokens < self.total_tokens:
                inputs = next(train_iter).to(self.device)
                loss = self._train_step(inputs)
                loss_train.update(loss.item(), inputs.shape[0])

                num_tokens_this_batch = inputs[:, :-1].numel()
                self.seen_tokens += num_tokens_this_batch
                self.token_eval_counter += num_tokens_this_batch
                batches += 1
                elapsed = time.perf_counter() - start_time
                batches_per_sec = batches / elapsed if elapsed > 0 else 0.0

                pbar.set_postfix({
                    "B/S": f"{batches_per_sec:.2f}",
                    "Loss": f"{loss_train.compute().item():.4f}",
                    "LR": f"{self.optimizer.param_groups[0]['lr']:.2e}",
                })
                pbar.update(num_tokens_this_batch)

                if (self.token_eval_counter >= self.log_interval_tokens) or (self.seen_tokens >= self.total_tokens):
                    loss_valid = self.evaluate()
                    print(f"\nValid Loss: {loss_valid:.4f}")
                    self.logger.log(loss_train.compute().item(), loss_valid, self.seen_tokens)
                    self.logger.save()
                    if self.generation:
                        self.generate()
                    # Evaluation/generation may switch the model to eval mode;
                    # make sure training resumes in train mode.
                    self.model.train()
                    # Start fresh per-interval statistics.
                    loss_train.reset()
                    self.token_eval_counter = 0
                    batches = 0
                    start_time = time.perf_counter()

        self.logger.plot()

    def _infinite_batches(self):
        """Yield training batches forever, re-iterating the loader each epoch.

        Unlike ``itertools.cycle``, this does not keep a copy of the first
        epoch's batches. Memory stays bounded, and a shuffling loader
        reshuffles on every pass.

        Raises:
            ValueError: If a full pass over ``train_loader`` yields nothing.
        """
        while True:
            produced = False
            for batch in self.train_loader:
                produced = True
                yield batch
            if not produced:
                raise ValueError("train_loader yielded no batches")

    def _compute_loss(self, inputs):
        """Next-token loss: predict ``inputs[:, 1:]`` from ``inputs[:, :-1]``."""
        logits = self.model(inputs[:, :-1])
        return self.loss_fn(logits.reshape(-1, logits.shape[-1]), inputs[:, 1:].flatten())

    def _train_step(self, inputs):
        """Run forward, backward, clip and optimizer step on one batch; return the loss."""
        loss = self._compute_loss(inputs)
        loss.backward()
        # Clip the global gradient norm to 1.0.
        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
        self.optimizer.step()
        self.optimizer.zero_grad()
        return loss

    def evaluate(self):
        """
        Evaluate model on validation set.

        Returns the mean of the per-batch losses over ``valid_loader``, weighted
        by batch size. The model's previous train/eval mode is restored afterwards.
        """
        loss_valid = MeanMetric()
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for inputs in self.valid_loader:
                    inputs = inputs.to(self.device)
                    loss = self._compute_loss(inputs)
                    loss_valid.update(loss.item(), inputs.shape[0])
        finally:
            self.model.train(was_training)
        return loss_valid.compute().item()

    def generate(self):
        """
        Generate samples for every configured prompt and print the first sample
        of the first prompt (the continuation after the prompt is shown in cyan).

        NOTE(review): relies on a module-level ``generate(...)`` function defined
        outside this class. Its result is assumed to be a list of strings that
        begin with the prompt — confirm.
        """
        if not self.generation.prompts:
            return
        generated_texts = []
        for prompt in self.generation.prompts:
            gen_text = generate(
                self.model, self.generation.tokenizer, prompt,
                n_rep=self.generation.n_rep,
                max_seq_len=self.generation.max_seq_len,
                T=self.generation.T, top_k=self.generation.top_k,
                seed=self.generation.seed)
            generated_texts.append(gen_text)
        # TODO: Save generated samples.
        prompt0 = self.generation.prompts[0]
        gen_text0 = generated_texts[0][0]
        print(colored(f"\n{prompt0}", "green"), end=' ')
        print(colored(f"{gen_text0[len(prompt0):]}", "cyan"))
        print()

    def _print_config_summary(self):
        """
        Print a summary table of training configuration.
        """
        table = PrettyTable()
        table.title = "Training Configuration Summary"
        table.field_names = ["Component", "Details"]
        # Model
        table.add_row(["Model Type", str(self.model.config).replace("Config", "")])
        # Optimizer
        optimizer_name = self.optimizer.__class__.__name__
        optimizer_params = ', '.join([f"{k}={v}" for k, v in self.optimizer.defaults.items() if k in ["lr", "betas", "weight_decay", "fused"]])
        optimizer_display = f"{optimizer_name}({optimizer_params})"
        table.add_row(["Optimizer", optimizer_display])
        # Parameters (the tied token embedding is reported separately)
        total_params = sum(p.numel() for p in self.model.parameters())
        te_params = self.model.wte.weight.numel()
        table.add_row(["Total Parameters (Tr+TE)", f"{total_params:,} ({total_params-te_params:,}+{te_params:,})"])

        table.add_row(["Loss Function", self.loss_fn.__name__ if hasattr(self.loss_fn, '__name__') else str(self.loss_fn)])
        # Dataset rows carry one extra target token, hence the -1.
        table.add_row(["Batch Shape", f"{self.train_loader.batch_size}x{self.train_loader.dataset[0].shape[-1]-1}"])
        table.add_row(["Device", self.device])
        table.add_row(["Max Tokens", f"{self.total_tokens:,}"])
        table.add_row(["Log Interval Tokens", f"{self.log_interval_tokens:,}"])
        print(table)

In [None]:
# Release unoccupied cached GPU memory before (re)starting a run.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
@dataclass
class GenerationConfig:
    """Sampling settings that LLMTrainer.generate() forwards to ``generate(...)``."""
    tokenizer: Tokenizer       # tokenizer handed to generate()
    prompts: list[str]         # prompts to continue; the trainer prints only the first
    T: float = 0.9             # presumably the sampling temperature — confirm in generate()
    max_seq_len: int = 128     # forwarded as generate(max_seq_len=...)
    top_k: int = 10            # presumably the top-k sampling cutoff — confirm
    n_rep: int = 3             # presumably samples generated per prompt — confirm
    seed: int = 42             # forwarded as generate(seed=...)


trainer = LLMTrainer(
    model, optimizer, train_loader, valid_loader,
    device=device,  # match the model's device; the trainer's default is 'cuda'
    total_tokens=100_000, log_interval_tokens=50_000,
    generation=GenerationConfig(tokenizer=tokenizer, prompts=["In last", "One day"]))

In [None]:
# Run the token-budgeted training loop; it evaluates, logs and plots along the way.
trainer.train()

### 🟡 Temp

In [None]:
from torch.utils.data import TensorDataset
# Scratch: a toy loader over [0, 1, 2, 3] in batches of two.
loader = DataLoader(dataset=TensorDataset(torch.arange(4)), batch_size=2)

In [None]:
from itertools import cycle
# Scratch: endless iterator over the toy loader.
# NOTE: per the itertools docs, cycle() saves a copy of every element from the
# first pass and replays those copies afterwards (no re-iteration of `loader`).
iterr = cycle(loader)

In [None]:
# Pull (and display) the next batch from the cycling iterator.
next(iterr)

In [None]:
# Scratch: a tqdm bar with a live postfix, advanced 10 units per step for 10 steps.
with tqdm(total=100, desc="Training", unit='t') as progress:
    for _step in range(10):
        time.sleep(0.1)
        progress.set_postfix({'loss': torch.randn(1).item()})
        progress.update(10)

In [None]:
from prettytable import PrettyTable

# Scratch: a three-column PrettyTable filled row by row, then printed.
table = PrettyTable()
table.field_names = ["Name", "Position", "Age"]
for row in (
    ["Alice", "Manager", 35],
    ["Bob", "Data Analyst", 29],
    ["Charlie", "Engineer", 32],
):
    table.add_row(row)
print(table)

## 🟠 Generate