In [2]:
import math
from dataclasses import dataclass
from typing import Optional

import os
os.chdir('/Users/idrishouiralami/Documents/projets_code/GPT')

import torch
import torch.nn as nn
from torch.optim import AdamW
from tqdm.auto import tqdm

from utils.masks import Masks

In [3]:
@dataclass
class TrainConfig:
    """Hyper-parameters read by ``Trainer``.

    Fields can be passed positionally, so their order is part of the interface;
    append new fields at the end rather than inserting them.
    """

    # ---- AdamW optimiser ----
    lr: float = 3e-4                            # learning rate
    betas: tuple[float, float] = (0.9, 0.95)    # AdamW (beta1, beta2) coefficients
    weight_decay: float = 0.05                  # AdamW weight decay
    # ---- training loop ----
    clip_grad_norm: float = 1.0                 # max gradient norm passed to clip_grad_norm_
    pad_id: int = 0                             # padding token id: ignored by the loss, used to build masks
    epochs: int = 10                            # number of passes over the training loader in fit()
    amp: bool = False                           # mixed precision; set True on CUDA if you want it
    log_progress: bool = True                   # show tqdm progress bars / per-batch loss

In [None]:
class Trainer:
    def __init__(self, model: nn.Module, device: Optional[str] = None, cfg: Optional[TrainConfig] = None):
        self.model = model
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.model.to(self.device)

        self.cfg = cfg or TrainConfig()
        self.crit = nn.CrossEntropyLoss(ignore_index=self.cfg.pad_id)
        self.opt = AdamW(self.model.parameters(),
                         lr=self.cfg.lr,
                         betas=self.cfg.betas,
                         weight_decay=self.cfg.weight_decay)
        self.masks = Masks(pad_id=self.cfg.pad_id)
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.cfg.amp)

    def _run_epoch(self, loader, train: bool) -> float:
        self.model.train(train)
        total, steps = 0.0, 0

        pbar = tqdm(loader, desc=("train epoch" if train else "val epoch"),
                    leave=False, disable=not self.cfg.log_progress)

        for src, dec_in, labels in pbar:
            src, dec_in, labels = src.to(self.device), dec_in.to(self.device), labels.to(self.device)

            # masks
            src_mask = self.masks.encoder(src)
            _, _, tgt_mask = self.masks.decoder(dec_in)

            # forward (+ AMP if enabled)
            with torch.amp.autocast("cuda", enabled=self.cfg.amp):
                enc_out = self.model.encode(src, src_mask)
                dec_out = self.model.decode(enc_out, src_mask, dec_in, tgt_mask)
                logits = self.model.project(dec_out)             # (B, T-1, V)
                loss = self.crit(logits.reshape(-1, logits.size(-1)),
                                 labels.reshape(-1))             # use reshape, not view

            if train:
                self.opt.zero_grad(set_to_none=True)
                if self.cfg.amp:
                    self.scaler.scale(loss).backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.clip_grad_norm)
                    self.scaler.step(self.opt)
                    self.scaler.update()
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.clip_grad_norm)
                    self.opt.step()

            total += loss.item()
            steps += 1
            if self.cfg.log_progress:
                pbar.set_postfix({"loss": f"{loss.item():.3f}"})

        pbar.close()
        return total / max(1, steps)

    @torch.no_grad()
    def evaluate(self, loader) -> float:
        return self._run_epoch(loader, train=False)

    def fit(self, train_loader, val_loader=None):
        history = []
        for epoch in range(self.cfg.epochs):
            tr = self._run_epoch(train_loader, train=True)
            if val_loader is not None:
                va = self.evaluate(val_loader)
                ppl = math.exp(va)
                print(f"epoch {epoch+1:02d} | train {tr:.3f} | val {va:.3f} | ppl {ppl:.2f}")
                history.append({"epoch": epoch+1, "train": tr, "val": va, "ppl": ppl})
            else:
                print(f"epoch {epoch+1:02d} | train {tr:.3f}")
                history.append({"epoch": epoch+1, "train": tr})
        return history

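**`for src, dec_in, labels in pbar`** batch loop:
- `src`: source sequence (the dialogue before Michael speaks)
- `dec_in`: decoder input (Michael's response shifted by one position, so each `labels` token is the token that follows the corresponding `dec_in` token)
- `labels`: the expected next tokens (positions equal to `pad_id` are ignored by the loss)
- `src_mask` & `tgt_mask`: masks for the source and target sequences, built by the `Masks` helper
- `enc_out`: encoder output for the source sequence
- `dec_out`: decoder output, conditioned on the encoded source and the previous target tokens
- `logits`: decoder output projected to the vocabulary size
- `loss`: cross-entropy between `logits` and `labels`
- `if train`: backpropagates the loss, clips the gradients and updates the weights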