# Decoders mini-project:
## Implementation of a character-level sequence generator

### I. Introduction

As seen in the explanation.ipynb notebook, even though decoders as implemented in this repository were introduced as part of a larger architecture (the *transformer* architecture), they can be used on their own for sequence generation.

In this notebook, we will implement a simple sequence generator and run a few tests and observations to illustrate what was said in the explanations.
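
To make the idea of a standalone decoder generating sequences concrete, here is a minimal sketch of autoregressive sampling. It does not use the model from this repository: `toy_next_token_probs` is a hypothetical stand-in for a trained decoder, and `generate`, `vocab_size`, and `block_size` are illustrative names chosen for this example only.

```python
import random


def toy_next_token_probs(context, vocab_size):
    """Hypothetical stand-in for a trained decoder.

    Returns a probability distribution over the next token id,
    strongly favoring (last token + 1) modulo vocab_size.
    """
    last = context[-1]
    probs = [0.1 / (vocab_size - 1)] * vocab_size
    probs[(last + 1) % vocab_size] = 0.9
    return probs


def generate(prompt_ids, num_new_tokens, vocab_size, block_size, rng):
    """Sample tokens one at a time, feeding each one back as input."""
    ids = list(prompt_ids)
    for _ in range(num_new_tokens):
        # Only the last block_size tokens are given to the predictor.
        context = ids[-block_size:]
        probs = toy_next_token_probs(context, vocab_size)
        next_id = rng.choices(range(vocab_size), weights=probs, k=1)[0]
        ids.append(next_id)
    return ids


rng = random.Random(0)
out = generate([0, 1, 2], num_new_tokens=10, vocab_size=5, block_size=4, rng=rng)
print(out)
```

The loop appends one sampled token per step and crops the context to a fixed window. A trained decoder would presumably be used the same way at inference time, with the toy predictor replaced by a forward pass and a softmax over the last position.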

### II. Implementation of the model

First, let's import the code from model.py. The content of this file is precisely what was built in the explanations notebook:

In [22]:
from model import SelfAttention, TransformerBlock, StandaloneDecoderBlock, StandaloneDecoder
import torch
import torch.nn as nn
import torch.nn.functional as F

### III. Configuration

In [23]:
BATCH_SIZE = 64
BLOCK_SIZE = 256
NUM_EPOCHS = 5
EMBED_SIZE = 256
NUM_LAYERS = 4
NUM_HEADS = 4
FORWARD_EXPANSION = 4
LEARNING_RATE = 5e-4
DROPOUT = 0.2
MAX_LENGTH = 30
PAD_TOKEN_ID = 11


### IV. Creation of the dataset

For this mini-project, we will use the Tiny Shakespeare dataset, available at https://www.kaggle.com/datasets/thedevastator/the-bards-best-a-character-modeling-dataset

In [24]:
from torch.utils.data import Dataset
import random as rd
import numpy as np
import string

class ShakespearDataset(Dataset):
    def __init__(self,
                 folder="data/train.csv",
                 block_size=256,
                 stride=1):
        super().__init__()

        # Read file
        with open(folder, 'r', encoding='UTF-8') as f:
            data = f.read()

        # Convert to lowercase:
        data = data.lower()

        # Build the vocabulary:
        self.vocab = list(string.ascii_lowercase) + [':', ',', '\'', ';', '-', '.', '?', '!', '(', ')', ' ', '\"', "&", "\n"] # + ['<PAD>', '<SOS>', '<EOS>'] is not needed for this project

        # Remove any characters that are not in the vocabulary:
        vocab_set = set(self.vocab)
        cleaned_data = "".join(char for char in data if char in vocab_set)

            
        self.vocab2idx = {token: i for i, token in enumerate(self.vocab)}
        self.idx2vocab = {i: token for token, i in self.vocab2idx.items()}

        self.block_size = block_size
        self.stride = stride

        # Encode data:
        self.ids = torch.tensor([self.vocab2idx[ch] for ch in cleaned_data], dtype=torch.long)

        self.num_samples = max(0, (len(self.ids) - (self.block_size + 1) + self.stride) // self.stride)
        
    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        start = idx * self.stride
        end = start + self.block_size
        x = self.ids[start:end]       
        y = self.ids[start+1:end+1]           
        
        return x, y
    
    def encode(self, s: str) -> torch.Tensor:
        return torch.tensor([self.vocab2idx[ch] for ch in s if ch in self.vocab2idx], dtype=torch.long)

    def decode(self, ids) -> str:
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        return "".join(self.idx2vocab[int(i)] for i in ids)
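
The indexing in `__len__` and `__getitem__` can be checked on a toy string. The sketch below mirrors the same arithmetic on a plain Python list instead of a tensor; `make_windows` and the toy vocabulary are illustrative names and not part of the notebook's code.

```python
def make_windows(ids, block_size, stride=1):
    """Build (input, target) pairs with the same arithmetic as the dataset class."""
    num_samples = max(0, (len(ids) - (block_size + 1) + stride) // stride)
    samples = []
    for idx in range(num_samples):
        start = idx * stride
        end = start + block_size
        x = ids[start:end]            # block_size input tokens
        y = ids[start + 1:end + 1]    # same window shifted right by one
        samples.append((x, y))
    return samples


text = "to be or not"
vocab = sorted(set(text))
char_to_id = {ch: i for i, ch in enumerate(vocab)}
id_to_char = {i: ch for ch, i in char_to_id.items()}
ids = [char_to_id[ch] for ch in text]

windows = make_windows(ids, block_size=4, stride=1)
for x, y in windows[:3]:
    x_text = "".join(id_to_char[i] for i in x)
    y_text = "".join(id_to_char[i] for i in y)
    print(repr(x_text), "->", repr(y_text))
```

Each target is the input shifted by one character, so position `t` of `y` is the token the model should predict after seeing positions `0..t` of `x`.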


We also need to build the associated dataloaders. We will only use a subset of each dataset for faster training.

In [25]:
from torch.utils.data import DataLoader
from torch.utils.data import Subset, RandomSampler

train_dataset = ShakespearDataset(
    folder="data/train.csv",
    block_size=BLOCK_SIZE
)

test_dataset = ShakespearDataset(
    folder="data/test.csv",
    block_size=BLOCK_SIZE
)

N_TRAIN = 50000
N_EVAL = 10000
train_subset = Subset(train_dataset, list(range(N_TRAIN)))
test_subset = Subset(test_dataset, list(range(N_EVAL)))

train_loader = DataLoader(
    train_subset, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    drop_last=True, # Drop the last batch if the number of samples is not divisible by batch_size
)

test_loader = DataLoader(
    test_subset, 
    batch_size=BATCH_SIZE
)
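
Because the datasets are built with the default `stride=1` and the subsets take the first `N_TRAIN` and `N_EVAL` indices, consecutive windows overlap almost entirely. The sketch below estimates how much raw text those subsets actually span; `chars_covered` is a hypothetical helper written for this illustration.

```python
def chars_covered(num_samples, block_size, stride=1):
    """Count character positions touched by the first num_samples windows.

    Counts inputs and shifted targets together. The count is exact when
    stride <= block_size + 1, i.e. when consecutive windows overlap or touch.
    """
    if num_samples == 0:
        return 0
    last_start = (num_samples - 1) * stride
    return last_start + block_size + 1


N_TRAIN, N_EVAL, BLOCK_SIZE = 50_000, 10_000, 256
train_chars = chars_covered(N_TRAIN, BLOCK_SIZE)
eval_chars = chars_covered(N_EVAL, BLOCK_SIZE)
print(f"train subset spans {train_chars} characters")
print(f"eval subset spans {eval_chars} characters")
```

Under these settings the 50,000 training windows span only about 50,000 characters of text, each seen in up to 256 different windows per epoch. This may be one reason the model can fit the training subset much better than the held-out text; a larger `stride` or randomly chosen indices would cover more of the corpus for the same number of samples.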

### V. Training

In [26]:
import torch.optim as optim

if torch.backends.mps.is_available():
    device = torch.device("mps")  # Apple Silicon GPU
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"USING DEVICE: {device}")

model = StandaloneDecoder(
    trg_vocab_size=len(train_dataset.vocab),
    embed_size=EMBED_SIZE,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    forward_expansion=FORWARD_EXPANSION,
    dropout=DROPOUT,
    device=device,
    max_length=BLOCK_SIZE
)

model = model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss() 


USING DEVICE: mps


In [27]:
x, y = next(iter(train_loader))     # x, y: [B, T]
x, y = x.to(device), y.to(device)
logits = model(x)                   # expected: [B, T, V]
print("logits:", logits.shape, "targets:", y.shape)

logits: torch.Size([64, 256, 40]) targets: torch.Size([64, 256])
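
The training loop flattens the logits from `[B, T, V]` to `[B*T, V]` before calling `nn.CrossEntropyLoss`, and reports perplexity as `exp(loss)`. As a rough reference point, the sketch below computes the same quantity by hand in plain Python (no PyTorch) for a model that predicts all 40 characters uniformly; `cross_entropy` here is an illustrative helper, not the library function.

```python
import math


def cross_entropy(logits_row, target):
    """Cross-entropy at one position: -log softmax(logits_row)[target]."""
    m = max(logits_row)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits_row))
    return log_sum_exp - logits_row[target]


# Baseline: a model with no information assigns equal logits to all V tokens.
V = 40
baseline_loss = cross_entropy([0.0] * V, target=7)
baseline_ppl = math.exp(baseline_loss)
print(f"uniform baseline: loss={baseline_loss:.4f}, ppl={baseline_ppl:.2f}")

# Flattening [B, T, V] to [B*T, V] averages the loss over every position.
logits = [[[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]  # toy values, B=1, T=2, V=3
targets = [[0, 1]]
flat = [(row, t) for seq, tgt in zip(logits, targets) for row, t in zip(seq, tgt)]
mean_loss = sum(cross_entropy(row, t) for row, t in flat) / len(flat)
print(f"toy mean loss={mean_loss:.4f}, ppl={math.exp(mean_loss):.2f}")
```

A loss below `ln(40)` therefore means the model is doing better than uniform guessing, and the perplexity can be read loosely as the effective number of characters the model is hesitating between.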


In [28]:
import math
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

# -------------------------------------------------
# Verification helpers
# -------------------------------------------------
def check_model_device(model, device):
    param_devices = {p.device.type for p in model.parameters()}
    buffer_devices = {b.device.type for b in model.buffers()}
    if param_devices:
        assert all(d == device.type for d in param_devices), \
            f"Params not all on {device.type}: {param_devices}"
    if buffer_devices:
        assert all(d == device.type for d in buffer_devices), \
            f"Buffers not all on {device.type}: {buffer_devices}"
    print(f"✓ model devices -> params: {sorted(param_devices) or 'none'} | buffers: {sorted(buffer_devices) or 'none'}")

def check_batch(x, y, device, V=None):
    assert x.ndim == 2 and y.ndim == 2, f"Bad dims: x{tuple(x.shape)} y{tuple(y.shape)}"
    assert x.shape == y.shape, f"x.shape {x.shape} != y.shape {y.shape}"
    # Compare the device type, not the exact device object
    assert x.device.type == device.type and y.device.type == device.type, \
        f"Batch not on {device.type}: x{x.device} y{y.device}"
    assert y.dtype == torch.long, f"Targets must be int64/long, got {y.dtype}"
    assert x.dtype in (torch.long, torch.int64, torch.int32), f"Inputs should be ints, got {x.dtype}"
    if V is not None:
        ymax = int(y.max().item())
        assert ymax < V, f"Target id {ymax} >= vocab size {V}"
    print(f"✓ batch ok -> {tuple(x.shape)} on {x.device} (expected {device}) dtypes: x={x.dtype}, y={y.dtype}")

def check_logits(logits, y, device):
    assert logits.device.type == device.type, f"logits on {logits.device}, expected {device}"
    assert logits.ndim == 3, f"logits must be [B,T,V], got {tuple(logits.shape)}"
    B, T, V = logits.shape
    assert (B, T) == tuple(y.shape), f"logits [B,T,V]={B,T,V} incompatible with targets {tuple(y.shape)}"
    assert logits.dtype in (torch.float16, torch.float32, torch.bfloat16), f"Unexpected dtype for logits: {logits.dtype}"
    print(f"✓ logits ok -> [B,T,V]={B,T,V} on {logits.device}, dtype={logits.dtype}")
    return B, T, V


def count_params(model):
    n = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return n, trainable

# -------------------------------------------------
# Device info (useful prints)
# -------------------------------------------------
if torch.cuda.is_available():
    dev_name = torch.cuda.get_device_name()
    print(f"CUDA visible | device: {torch.cuda.current_device()} -> {dev_name}")
elif torch.backends.mps.is_available():
    print("MPS visible (Apple Silicon)")
else:
    print("CPU only")

# Move the model to the device, then build the optimizer and criterion
model = model.to(device)
check_model_device(model, device)

total_params, trainable_params = count_params(model)
print(f"Parameters: {total_params/1e6:.2f}M (trainable: {trainable_params/1e6:.2f}M)")

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0.1)
criterion = nn.CrossEntropyLoss()

# -------------------------------------------------
# Epoch loop with checks
# -------------------------------------------------
def run_epoch(loader, train=True, max_batches=None, log_interval=50, desc=None, vocab_size=None):
    model.train(mode=train)
    total_loss, total_tokens = 0.0, 0
    ema_loss = None
    start_time = time.time()

    if desc is None:
        desc = "train" if train else "eval"

    pbar = tqdm(enumerate(loader), total=len(loader), desc=desc, leave=False)
    for b, (x, y) in pbar:
        if max_batches is not None and b >= max_batches:
            break

        # Move to device
        x = x.to(device, non_blocking=(device.type == "cuda"))
        y = y.to(device, non_blocking=(device.type == "cuda"))

        # Batch checks
        if b == 0:  # print once per epoch
            check_batch(x, y, device, V=vocab_size)

        if train:
            optimizer.zero_grad(set_to_none=True)

        # forward (gradients are only tracked in training mode)
        with torch.set_grad_enabled(train):
            logits = model(x)  # [B, T, V]

            # Logits checks
            if b == 0:
                B, T, V = check_logits(logits, y, device)
            else:
                B, T, V = logits.shape

            loss = criterion(logits.reshape(B*T, V), y.reshape(B*T))

        # backward + step
        grad_norm = None
        if train:
            loss.backward()
            grad_norm = clip_grad_norm_(model.parameters(), max_norm=1.0).item()
            optimizer.step()

        # global stats
        tokens = B * T
        total_loss += loss.item() * tokens
        total_tokens += tokens

        # EMA for a more stable display
        ema_loss = loss.item() if ema_loss is None else 0.9 * ema_loss + 0.1 * loss.item()

        # tqdm postfix (per-batch)
        postfix = {
            "loss(batch)": f"{loss.item():.4f}",
            "loss(ema)": f"{ema_loss:.4f}",
            "ppl(ema)": f"{math.exp(ema_loss):.2f}",
        }
        if grad_norm is not None:
            postfix["grad_norm"] = f"{grad_norm:.2f}"
        pbar.set_postfix(postfix)

        # optional periodic prints
        if log_interval and (b + 1) % log_interval == 0:
            avg_loss_so_far = total_loss / max(total_tokens, 1)
            print(f"  step {b+1:05d} | {desc} avg_loss {avg_loss_so_far:.4f} | ppl {math.exp(avg_loss_so_far):.2f}")

    avg_loss = total_loss / max(total_tokens, 1)
    ppl = float(math.exp(avg_loss))
    elapsed = time.time() - start_time
    return avg_loss, ppl, elapsed

# -------------------------------------------------
# Training
# -------------------------------------------------
EPOCHS = 10
best_val = float('inf')

print("===== Starting training =====")
print(f"device: {device} | vocab: {len(train_dataset.vocab)} | "
      f"batch_size: {getattr(train_loader, 'batch_size', 'NA')} | seq_len: {getattr(train_dataset, 'block_size', 'NA')}")
print(f"pin_memory(train_loader)={getattr(train_loader, 'pin_memory', 'NA')} | "
      f"num_workers={getattr(train_loader, 'num_workers', 'NA')}")

for epoch in range(1, EPOCHS + 1):
    print(f"\n----- Epoch {epoch:02d}/{EPOCHS} -----")
    train_loss, train_ppl, train_time = run_epoch(
        train_loader, train=True, desc=f"train[{epoch:02d}]", log_interval=100, vocab_size=len(train_dataset.vocab)
    )
    val_loss, val_ppl, val_time = run_epoch(
        test_loader, train=False, desc=f"eval [{epoch:02d}]", log_interval=None, vocab_size=len(train_dataset.vocab)
    )

    print(f"[{epoch:02d}] train loss {train_loss:.4f} | ppl {train_ppl:.2f} | time {train_time:.1f}s")
    print(f"[{epoch:02d}]   val loss {val_loss:.4f} | ppl {val_ppl:.2f} | time {val_time:.1f}s")

    # checkpoint on improvement
    if val_loss < best_val:
        best_val = val_loss
        ckpt_path = "best_decoder.pt"
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "metrics": {
                "train_loss": train_loss, "train_ppl": train_ppl,
                "val_loss": val_loss, "val_ppl": val_ppl,
            },
            "config": {
                "vocab_size": len(train_dataset.vocab),
                "embed_size": EMBED_SIZE,
                "num_layers": NUM_LAYERS,
                "num_heads": NUM_HEADS,
                "forward_expansion": FORWARD_EXPANSION,
                "dropout": DROPOUT,
                "max_length": BLOCK_SIZE,  # the model above is built with max_length=BLOCK_SIZE
                "block_size": BLOCK_SIZE,
            }
        }, ckpt_path)
        print(f"✓ checkpoint saved → {ckpt_path} (val_loss={val_loss:.4f})")
    else:
        print(f"↳ no improvement (best_val={best_val:.4f})")



MPS visible (Apple Silicon)
✓ model devices -> params: ['mps'] | buffers: none
Parameters: 3.24M (trainable: 3.24M)
===== Starting training =====
device: mps | vocab: 40 | batch_size: 64 | seq_len: 256
pin_memory(train_loader)=False | num_workers=0

----- Epoch 01/10 -----


train[01]:   0%|          | 0/781 [00:00<?, ?it/s]

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


train[01]:  13%|█▎        | 100/781 [01:13<07:56,  1.43it/s, loss(batch)=2.5042, loss(ema)=2.4969, ppl(ema)=12.14, grad_norm=0.21]

  step 00100 | train[01] avg_loss 2.6621 | ppl 14.33


train[01]:  26%|██▌       | 200/781 [02:25<06:55,  1.40it/s, loss(batch)=2.4304, loss(ema)=2.4274, ppl(ema)=11.33, grad_norm=0.18]

  step 00200 | train[01] avg_loss 2.5557 | ppl 12.88


train[01]:  38%|███▊      | 300/781 [03:43<06:13,  1.29it/s, loss(batch)=2.3882, loss(ema)=2.3945, ppl(ema)=10.96, grad_norm=0.22]

  step 00300 | train[01] avg_loss 2.5058 | ppl 12.25


train[01]:  51%|█████     | 400/781 [04:59<04:46,  1.33it/s, loss(batch)=2.3516, loss(ema)=2.3590, ppl(ema)=10.58, grad_norm=0.43]

  step 00400 | train[01] avg_loss 2.4725 | ppl 11.85


train[01]:  64%|██████▍   | 500/781 [06:15<02:41,  1.74it/s, loss(batch)=2.2867, loss(ema)=2.3051, ppl(ema)=10.03, grad_norm=0.29]

  step 00500 | train[01] avg_loss 2.4442 | ppl 11.52


train[01]:  77%|███████▋  | 600/781 [07:24<01:38,  1.85it/s, loss(batch)=2.2831, loss(ema)=2.2632, ppl(ema)=9.61, grad_norm=0.53] 

  step 00600 | train[01] avg_loss 2.4170 | ppl 11.21


train[01]:  90%|████████▉ | 700/781 [08:12<00:35,  2.26it/s, loss(batch)=2.2125, loss(ema)=2.2088, ppl(ema)=9.11, grad_norm=0.66]

  step 00700 | train[01] avg_loss 2.3901 | ppl 10.91


eval [01]:   1%|          | 1/157 [00:00<00:24,  6.33it/s, loss(batch)=2.3157, loss(ema)=2.3157, ppl(ema)=10.13]                 

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


                                                                                                                  

[01] train loss 2.3678 | ppl 10.67 | time 528.1s
[01]   val loss 2.3383 | ppl 10.36 | time 30.7s
✓ checkpoint saved → best_decoder.pt (val_loss=2.3383)

----- Epoch 02/10 -----


train[02]:   0%|          | 0/781 [00:00<?, ?it/s]

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


train[02]:  13%|█▎        | 100/781 [00:51<06:54,  1.64it/s, loss(batch)=2.0832, loss(ema)=2.1032, ppl(ema)=8.19, grad_norm=0.38]

  step 00100 | train[02] avg_loss 2.1247 | ppl 8.37


train[02]:  26%|██▌       | 200/781 [01:54<05:34,  1.74it/s, loss(batch)=2.0387, loss(ema)=2.0553, ppl(ema)=7.81, grad_norm=0.61]

  step 00200 | train[02] avg_loss 2.1005 | ppl 8.17


train[02]:  38%|███▊      | 300/781 [02:55<06:20,  1.26it/s, loss(batch)=2.0092, loss(ema)=2.0075, ppl(ema)=7.44, grad_norm=0.39]

  step 00300 | train[02] avg_loss 2.0758 | ppl 7.97


train[02]:  51%|█████     | 400/781 [04:06<04:41,  1.35it/s, loss(batch)=1.9142, loss(ema)=1.9496, ppl(ema)=7.03, grad_norm=0.55]

  step 00400 | train[02] avg_loss 2.0515 | ppl 7.78


train[02]:  64%|██████▍   | 500/781 [05:06<02:25,  1.93it/s, loss(batch)=1.9067, loss(ema)=1.9093, ppl(ema)=6.75, grad_norm=0.73]

  step 00500 | train[02] avg_loss 2.0267 | ppl 7.59


train[02]:  77%|███████▋  | 600/781 [05:59<01:52,  1.61it/s, loss(batch)=1.9057, loss(ema)=1.8665, ppl(ema)=6.47, grad_norm=0.52]

  step 00600 | train[02] avg_loss 2.0032 | ppl 7.41


train[02]:  90%|████████▉ | 700/781 [07:00<00:49,  1.65it/s, loss(batch)=1.8141, loss(ema)=1.8272, ppl(ema)=6.22, grad_norm=0.66]

  step 00700 | train[02] avg_loss 1.9799 | ppl 7.24


eval [02]:   1%|          | 1/157 [00:00<00:25,  6.16it/s, loss(batch)=2.1930, loss(ema)=2.1930, ppl(ema)=8.96]                  

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


                                                                                                                  

[02] train loss 1.9619 | ppl 7.11 | time 461.4s
[02]   val loss 2.2606 | ppl 9.59 | time 35.0s
✓ checkpoint saved → best_decoder.pt (val_loss=2.2606)

----- Epoch 03/10 -----


train[03]:   0%|          | 0/781 [00:00<?, ?it/s]

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


train[03]:  13%|█▎        | 100/781 [01:11<07:39,  1.48it/s, loss(batch)=1.7712, loss(ema)=1.7576, ppl(ema)=5.80, grad_norm=0.49]

  step 00100 | train[03] avg_loss 1.7674 | ppl 5.86


train[03]:  26%|██▌       | 200/781 [02:17<06:37,  1.46it/s, loss(batch)=1.7170, loss(ema)=1.7177, ppl(ema)=5.57, grad_norm=0.85]

  step 00200 | train[03] avg_loss 1.7502 | ppl 5.76


train[03]:  38%|███▊      | 300/781 [03:25<05:32,  1.45it/s, loss(batch)=1.6858, loss(ema)=1.6757, ppl(ema)=5.34, grad_norm=0.70]

  step 00300 | train[03] avg_loss 1.7312 | ppl 5.65


train[03]:  51%|█████     | 400/781 [04:32<04:10,  1.52it/s, loss(batch)=1.6177, loss(ema)=1.6431, ppl(ema)=5.17, grad_norm=0.46]

  step 00400 | train[03] avg_loss 1.7124 | ppl 5.54


train[03]:  64%|██████▍   | 500/781 [05:41<03:04,  1.52it/s, loss(batch)=1.6027, loss(ema)=1.6222, ppl(ema)=5.06, grad_norm=0.47]

  step 00500 | train[03] avg_loss 1.6959 | ppl 5.45


train[03]:  77%|███████▋  | 600/781 [06:53<02:14,  1.35it/s, loss(batch)=1.5629, loss(ema)=1.5806, ppl(ema)=4.86, grad_norm=0.71]

  step 00600 | train[03] avg_loss 1.6794 | ppl 5.36


train[03]:  90%|████████▉ | 700/781 [08:03<00:57,  1.41it/s, loss(batch)=1.5521, loss(ema)=1.5582, ppl(ema)=4.75, grad_norm=0.58]

  step 00700 | train[03] avg_loss 1.6642 | ppl 5.28


eval [03]:   0%|          | 0/157 [00:00<?, ?it/s]                                                                               

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


                                                                                                                  

[03] train loss 1.6524 | ppl 5.22 | time 539.3s
[03]   val loss 2.2770 | ppl 9.75 | time 38.6s
↳ no improvement (best_val=2.2606)

----- Epoch 04/10 -----


train[04]:   0%|          | 0/781 [00:00<?, ?it/s]

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


train[04]:  13%|█▎        | 100/781 [01:07<07:44,  1.47it/s, loss(batch)=1.5224, loss(ema)=1.5167, ppl(ema)=4.56, grad_norm=0.52]

  step 00100 | train[04] avg_loss 1.5278 | ppl 4.61


train[04]:  26%|██▌       | 200/781 [02:11<06:34,  1.47it/s, loss(batch)=1.5075, loss(ema)=1.4920, ppl(ema)=4.45, grad_norm=0.51]

  step 00200 | train[04] avg_loss 1.5155 | ppl 4.55


train[04]:  38%|███▊      | 300/781 [03:19<05:14,  1.53it/s, loss(batch)=1.4727, loss(ema)=1.4783, ppl(ema)=4.39, grad_norm=0.43]

  step 00300 | train[04] avg_loss 1.5056 | ppl 4.51


train[04]:  51%|█████     | 400/781 [04:25<04:12,  1.51it/s, loss(batch)=1.4492, loss(ema)=1.4512, ppl(ema)=4.27, grad_norm=0.56]

  step 00400 | train[04] avg_loss 1.4943 | ppl 4.46


train[04]:  64%|██████▍   | 500/781 [05:33<03:04,  1.53it/s, loss(batch)=1.4453, loss(ema)=1.4371, ppl(ema)=4.21, grad_norm=0.47]

  step 00500 | train[04] avg_loss 1.4842 | ppl 4.41


train[04]:  77%|███████▋  | 600/781 [06:40<02:04,  1.46it/s, loss(batch)=1.4075, loss(ema)=1.4123, ppl(ema)=4.11, grad_norm=0.53]

  step 00600 | train[04] avg_loss 1.4739 | ppl 4.37


train[04]:  90%|████████▉ | 700/781 [07:40<00:37,  2.16it/s, loss(batch)=1.3718, loss(ema)=1.3972, ppl(ema)=4.04, grad_norm=0.62]

  step 00700 | train[04] avg_loss 1.4644 | ppl 4.32


eval [04]:   1%|          | 1/157 [00:00<00:28,  5.51it/s, loss(batch)=2.4708, loss(ema)=2.4708, ppl(ema)=11.83]                 

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


                                                                                                                  

[04] train loss 1.4566 | ppl 4.29 | time 506.8s
[04]   val loss 2.4009 | ppl 11.03 | time 28.1s
↳ no improvement (best_val=2.2606)

----- Epoch 05/10 -----


train[05]:   0%|          | 0/781 [00:00<?, ?it/s]

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


train[05]:  13%|█▎        | 100/781 [01:11<07:28,  1.52it/s, loss(batch)=1.3536, loss(ema)=1.3724, ppl(ema)=3.94, grad_norm=0.60]

  step 00100 | train[05] avg_loss 1.3754 | ppl 3.96


train[05]:  26%|██▌       | 200/781 [02:12<05:39,  1.71it/s, loss(batch)=1.3578, loss(ema)=1.3430, ppl(ema)=3.83, grad_norm=0.48]

  step 00200 | train[05] avg_loss 1.3645 | ppl 3.91


train[05]:  38%|███▊      | 300/781 [03:21<05:14,  1.53it/s, loss(batch)=1.3546, loss(ema)=1.3357, ppl(ema)=3.80, grad_norm=0.56]

  step 00300 | train[05] avg_loss 1.3570 | ppl 3.88


train[05]:  51%|█████     | 400/781 [04:30<04:14,  1.50it/s, loss(batch)=1.3224, loss(ema)=1.3200, ppl(ema)=3.74, grad_norm=0.56]

  step 00400 | train[05] avg_loss 1.3493 | ppl 3.85


train[05]:  64%|██████▍   | 500/781 [05:38<03:16,  1.43it/s, loss(batch)=1.3164, loss(ema)=1.3046, ppl(ema)=3.69, grad_norm=0.60]

  step 00500 | train[05] avg_loss 1.3415 | ppl 3.82


train[05]:  77%|███████▋  | 600/781 [06:45<01:55,  1.57it/s, loss(batch)=1.2985, loss(ema)=1.2944, ppl(ema)=3.65, grad_norm=0.74]

  step 00600 | train[05] avg_loss 1.3340 | ppl 3.80


train[05]:  90%|████████▉ | 700/781 [07:50<00:49,  1.65it/s, loss(batch)=1.2659, loss(ema)=1.2727, ppl(ema)=3.57, grad_norm=0.68]

  step 00700 | train[05] avg_loss 1.3265 | ppl 3.77


eval [05]:   1%|          | 1/157 [00:00<00:27,  5.73it/s, loss(batch)=2.5852, loss(ema)=2.5852, ppl(ema)=13.27]                 

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


                                                                                                                  

[05] train loss 1.3207 | ppl 3.75 | time 510.0s
[05]   val loss 2.5340 | ppl 12.60 | time 32.0s
↳ no improvement (best_val=2.2606)

----- Epoch 06/10 -----


train[06]:   0%|          | 0/781 [00:00<?, ?it/s]

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


train[06]:  13%|█▎        | 100/781 [00:55<07:06,  1.60it/s, loss(batch)=1.2581, loss(ema)=1.2499, ppl(ema)=3.49, grad_norm=0.53]

  step 00100 | train[06] avg_loss 1.2556 | ppl 3.51


train[06]:  26%|██▌       | 200/781 [01:53<05:35,  1.73it/s, loss(batch)=1.2261, loss(ema)=1.2389, ppl(ema)=3.45, grad_norm=0.68]

  step 00200 | train[06] avg_loss 1.2480 | ppl 3.48


train[06]:  38%|███▊      | 300/781 [02:55<04:59,  1.61it/s, loss(batch)=1.2398, loss(ema)=1.2268, ppl(ema)=3.41, grad_norm=0.61]

  step 00300 | train[06] avg_loss 1.2420 | ppl 3.46


train[06]:  51%|█████     | 400/781 [03:47<03:17,  1.93it/s, loss(batch)=1.2326, loss(ema)=1.2189, ppl(ema)=3.38, grad_norm=0.58]

  step 00400 | train[06] avg_loss 1.2359 | ppl 3.44


train[06]:  64%|██████▍   | 500/781 [04:43<02:25,  1.94it/s, loss(batch)=1.2164, loss(ema)=1.1941, ppl(ema)=3.30, grad_norm=0.72]

  step 00500 | train[06] avg_loss 1.2298 | ppl 3.42


train[06]:  77%|███████▋  | 600/781 [05:36<01:47,  1.68it/s, loss(batch)=1.1787, loss(ema)=1.1931, ppl(ema)=3.30, grad_norm=0.59]

  step 00600 | train[06] avg_loss 1.2239 | ppl 3.40


train[06]:  90%|████████▉ | 700/781 [06:42<00:51,  1.56it/s, loss(batch)=1.2105, loss(ema)=1.1793, ppl(ema)=3.25, grad_norm=0.82]

  step 00700 | train[06] avg_loss 1.2175 | ppl 3.38


eval [06]:   0%|          | 0/157 [00:00<?, ?it/s]                                                                               

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


                                                                                                                  

[06] train loss 1.2124 | ppl 3.36 | time 455.4s
[06]   val loss 2.7100 | ppl 15.03 | time 35.8s
↳ no improvement (best_val=2.2606)

----- Epoch 07/10 -----


train[07]:   0%|          | 0/781 [00:00<?, ?it/s]

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


train[07]:  13%|█▎        | 100/781 [00:55<07:54,  1.43it/s, loss(batch)=1.1535, loss(ema)=1.1524, ppl(ema)=3.17, grad_norm=0.70]

  step 00100 | train[07] avg_loss 1.1553 | ppl 3.18


train[07]:  26%|██▌       | 200/781 [01:53<06:58,  1.39it/s, loss(batch)=1.1240, loss(ema)=1.1408, ppl(ema)=3.13, grad_norm=0.67]

  step 00200 | train[07] avg_loss 1.1520 | ppl 3.16


train[07]:  38%|███▊      | 300/781 [03:02<05:10,  1.55it/s, loss(batch)=1.1472, loss(ema)=1.1299, ppl(ema)=3.10, grad_norm=0.61]

  step 00300 | train[07] avg_loss 1.1461 | ppl 3.15


train[07]:  51%|█████     | 400/781 [04:08<04:13,  1.50it/s, loss(batch)=1.1038, loss(ema)=1.1253, ppl(ema)=3.08, grad_norm=0.63]

  step 00400 | train[07] avg_loss 1.1407 | ppl 3.13


train[07]:  64%|██████▍   | 500/781 [05:17<03:07,  1.49it/s, loss(batch)=1.1226, loss(ema)=1.1136, ppl(ema)=3.05, grad_norm=0.74]

  step 00500 | train[07] avg_loss 1.1360 | ppl 3.11


train[07]:  77%|███████▋  | 600/781 [06:23<01:58,  1.53it/s, loss(batch)=1.0874, loss(ema)=1.0922, ppl(ema)=2.98, grad_norm=0.67]

  step 00600 | train[07] avg_loss 1.1305 | ppl 3.10


train[07]:  90%|████████▉ | 700/781 [07:36<00:56,  1.44it/s, loss(batch)=1.0949, loss(ema)=1.0929, ppl(ema)=2.98, grad_norm=0.72]

  step 00700 | train[07] avg_loss 1.1256 | ppl 3.08


eval [07]:   1%|          | 1/157 [00:00<00:24,  6.45it/s, loss(batch)=2.9530, loss(ema)=2.9530, ppl(ema)=19.16]                 

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


                                                                                                                  

[07] train loss 1.1214 | ppl 3.07 | time 502.7s
[07]   val loss 2.8462 | ppl 17.22 | time 24.9s
↳ no improvement (best_val=2.2606)

----- Epoch 08/10 -----


train[08]:   0%|          | 0/781 [00:00<?, ?it/s]

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


train[08]:  13%|█▎        | 100/781 [00:55<06:45,  1.68it/s, loss(batch)=1.0830, loss(ema)=1.0706, ppl(ema)=2.92, grad_norm=0.78]

  step 00100 | train[08] avg_loss 1.0738 | ppl 2.93


train[08]:  26%|██▌       | 200/781 [01:52<05:49,  1.66it/s, loss(batch)=1.0427, loss(ema)=1.0584, ppl(ema)=2.88, grad_norm=0.69]

  step 00200 | train[08] avg_loss 1.0688 | ppl 2.91


train[08]:  38%|███▊      | 300/781 [02:57<05:03,  1.58it/s, loss(batch)=1.0567, loss(ema)=1.0531, ppl(ema)=2.87, grad_norm=0.74]

  step 00300 | train[08] avg_loss 1.0637 | ppl 2.90


train[08]:  51%|█████     | 400/781 [04:00<03:57,  1.60it/s, loss(batch)=1.0418, loss(ema)=1.0408, ppl(ema)=2.83, grad_norm=0.92]

  step 00400 | train[08] avg_loss 1.0593 | ppl 2.88


train[08]:  64%|██████▍   | 500/781 [05:00<02:07,  2.20it/s, loss(batch)=1.0352, loss(ema)=1.0346, ppl(ema)=2.81, grad_norm=0.77]

  step 00500 | train[08] avg_loss 1.0545 | ppl 2.87


train[08]:  77%|███████▋  | 600/781 [05:50<01:22,  2.19it/s, loss(batch)=1.0308, loss(ema)=1.0177, ppl(ema)=2.77, grad_norm=0.77]

  step 00600 | train[08] avg_loss 1.0494 | ppl 2.86


train[08]:  90%|████████▉ | 700/781 [06:36<00:37,  2.14it/s, loss(batch)=1.0087, loss(ema)=1.0147, ppl(ema)=2.76, grad_norm=0.64]

  step 00700 | train[08] avg_loss 1.0449 | ppl 2.84


eval [08]:   0%|          | 0/157 [00:00<?, ?it/s]                                                                               

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


                                                                                                                  

[08] train loss 1.0410 | ppl 2.83 | time 435.5s
[08]   val loss 2.9962 | ppl 20.01 | time 28.2s
↳ no improvement (best_val=2.2606)

----- Epoch 09/10 -----


train[09]:   0%|          | 0/781 [00:00<?, ?it/s]

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


train[09]:  13%|█▎        | 100/781 [00:46<05:29,  2.07it/s, loss(batch)=0.9965, loss(ema)=0.9936, ppl(ema)=2.70, grad_norm=0.71]

  step 00100 | train[09] avg_loss 0.9993 | ppl 2.72


train[09]:  26%|██▌       | 200/781 [01:34<04:34,  2.12it/s, loss(batch)=0.9788, loss(ema)=0.9839, ppl(ema)=2.67, grad_norm=0.83]

  step 00200 | train[09] avg_loss 0.9935 | ppl 2.70


train[09]:  38%|███▊      | 300/781 [02:22<03:47,  2.11it/s, loss(batch)=0.9828, loss(ema)=0.9783, ppl(ema)=2.66, grad_norm=0.74]

  step 00300 | train[09] avg_loss 0.9892 | ppl 2.69


train[09]:  51%|█████     | 400/781 [03:18<02:55,  2.17it/s, loss(batch)=0.9748, loss(ema)=0.9673, ppl(ema)=2.63, grad_norm=0.64]

  step 00400 | train[09] avg_loss 0.9841 | ppl 2.68


train[09]:  64%|██████▍   | 500/781 [04:04<02:07,  2.21it/s, loss(batch)=0.9339, loss(ema)=0.9553, ppl(ema)=2.60, grad_norm=0.63]

  step 00500 | train[09] avg_loss 0.9798 | ppl 2.66


train[09]:  77%|███████▋  | 600/781 [04:53<01:32,  1.96it/s, loss(batch)=0.9396, loss(ema)=0.9502, ppl(ema)=2.59, grad_norm=0.70]

  step 00600 | train[09] avg_loss 0.9756 | ppl 2.65


train[09]:  90%|████████▉ | 700/781 [05:40<00:36,  2.22it/s, loss(batch)=0.9466, loss(ema)=0.9454, ppl(ema)=2.57, grad_norm=0.76]

  step 00700 | train[09] avg_loss 0.9716 | ppl 2.64


eval [09]:   1%|          | 1/157 [00:00<00:25,  6.16it/s, loss(batch)=3.1978, loss(ema)=3.1978, ppl(ema)=24.48]                 

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


                                                                                                                  

[09] train loss 0.9681 | ppl 2.63 | time 376.5s
[09]   val loss 3.0543 | ppl 21.21 | time 27.2s
↳ no improvement (best_val=2.2606)

----- Epoch 10/10 -----


train[10]:   0%|          | 0/781 [00:00<?, ?it/s]

✓ batch ok -> (64, 256) on mps:0 (expected mps) dtypes: x=torch.int64, y=torch.int64
✓ logits ok -> [B,T,V]=(64, 256, 40) on mps:0, dtype=torch.float32


train[10]:  13%|█▎        | 100/781 [00:53<06:42,  1.69it/s, loss(batch)=0.9287, loss(ema)=0.9247, ppl(ema)=2.52, grad_norm=0.81]

  step 00100 | train[10] avg_loss 0.9291 | ppl 2.53


train[10]:  26%|██▌       | 200/781 [01:46<06:53,  1.41it/s, loss(batch)=0.9433, loss(ema)=0.9196, ppl(ema)=2.51, grad_norm=0.75]

  step 00200 | train[10] avg_loss 0.9256 | ppl 2.52


train[10]:  38%|███▊      | 300/781 [02:37<04:15,  1.88it/s, loss(batch)=0.8966, loss(ema)=0.9111, ppl(ema)=2.49, grad_norm=0.72]

  step 00300 | train[10] avg_loss 0.9220 | ppl 2.51


train[10]:  51%|█████     | 400/781 [03:32<03:00,  2.11it/s, loss(batch)=0.9142, loss(ema)=0.8980, ppl(ema)=2.45, grad_norm=0.73]

  step 00400 | train[10] avg_loss 0.9175 | ppl 2.50


train[10]:  64%|██████▍   | 500/781 [04:22<02:15,  2.07it/s, loss(batch)=0.8976, loss(ema)=0.8965, ppl(ema)=2.45, grad_norm=0.84]

  step 00500 | train[10] avg_loss 0.9138 | ppl 2.49


                                                                                                                                 

KeyboardInterrupt: 
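
In the log above, the validation loss reaches its minimum at epoch 2 (2.2606) and then rises while the training loss keeps falling, which is the usual signature of overfitting. One common response is early stopping. The sketch below replays the logged validation losses through a simple patience rule; `early_stopping_epoch` and the `patience=3` value are assumptions made for this illustration, not something the notebook implements.

```python
# Validation losses copied from the epoch summaries in the log (epochs 1 to 9).
val_losses = [2.3383, 2.2606, 2.2770, 2.4009, 2.5340, 2.7100, 2.8462, 2.9962, 3.0543]


def early_stopping_epoch(losses, patience):
    """Return (stop_epoch, best_epoch), both 1-indexed.

    stop_epoch is the first epoch at which `patience` consecutive epochs
    have passed without a new best loss, or None if that never happens.
    """
    best = float("inf")
    best_epoch = None
    epochs_without_improvement = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            best, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return None, best_epoch


stop_epoch, best_epoch = early_stopping_epoch(val_losses, patience=3)
print(f"best epoch: {best_epoch}, training would stop after epoch: {stop_epoch}")
```

With this rule the run would have ended after epoch 5 and kept the epoch 2 checkpoint, which is the same checkpoint the training loop already saved as `best_decoder.pt`.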