<a href="https://colab.research.google.com/github/drhemanm/EvoTransformer-PIQA-Benchmark/blob/main/Copy_of_OptimizedEVo_for_PIQA_40_million_param.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# EvoTransformer PIQA - STABLE & PROVEN (<50M params, 70%+ target)
# Copy -> Paste -> Run in Colab

# ---------------------------
# 1) Install dependencies
# ---------------------------
print("Installing packages (may take a minute)...")
import sys
import subprocess

# Quiet pip install into the environment of the running interpreter
# (transformers is pinned; everything else floats).
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "-q"]
    + ["transformers==4.40.1", "torch", "tqdm", "tensorflow-datasets", "sentencepiece"]
)
print("Installed.")

# ---------------------------
# 2) Imports
# ---------------------------
import os
import math
import random
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from transformers import AutoTokenizer
import tensorflow_datasets as tfds

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# ---------------------------
# 3) PROVEN Config
# ---------------------------
# Hyperparameters shared by the dataset, the model and the training loop.
CONFIG = dict(
    max_len=128,          # tokens per (goal, solution) sequence
    d_model=512,
    num_heads=8,
    ff_dim=2048,
    depth=8,              # number of stacked EvoTransformerBlocks
    dropout=0.15,
    activation="gelu",
    batch_size=16,
    epochs=30,
    lr=3e-5,              # peak learning rate reached after warmup
    min_lr=1e-6,          # floor the cosine decay heads toward
    warmup_epochs=3,
    weight_decay=0.01,
    seed=42,
    train_samples=None,   # None -> load the whole split
    val_samples=None,
)

# Seed PyTorch and Python's `random` module from the same value.
torch.manual_seed(CONFIG["seed"])
random.seed(CONFIG["seed"])

# ---------------------------
# 4) Tokenizer
# ---------------------------
TOKENIZER_NAME = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, use_fast=True)

# ---------------------------
# 5) TFDS loader
# ---------------------------
def load_piqa_tfds(split="train", limit=None):
    """Read one PIQA split from TFDS into a list of plain-Python dicts.

    Each record has keys ``goal``, ``sol1``, ``sol2`` (text; ``bytes``/``bytearray``
    values are decoded as UTF-8) and ``label`` (converted with ``int``).
    When ``limit`` is truthy, reading stops once that many records are collected.
    """
    records = []
    for raw in tfds.as_numpy(tfds.load("piqa", split=split)):
        rec = {}
        for field in ("goal", "sol1", "sol2"):
            value = raw[field]
            rec[field] = value.decode("utf-8") if isinstance(value, (bytes, bytearray)) else value
        rec["label"] = int(raw["label"])
        records.append(rec)
        if limit and len(records) >= limit:
            break
    return records

print("Downloading PIQA...")
# Each split is optionally capped by its own CONFIG entry (None = full split).
train_raw, val_raw = (
    load_piqa_tfds(split_name, limit=CONFIG[cap_key])
    for split_name, cap_key in (("train", "train_samples"), ("validation", "val_samples"))
)
print(f"Train: {len(train_raw)}, Val: {len(val_raw)}")

# ---------------------------
# 6) Dataset (NO augmentation)
# ---------------------------
class PIQAPairDataset(Dataset):
    """Expose each PIQA example as two tokenized "goal [SEP] solution" sequences.

    An item is a dict with ``input_ids_1``/``attn_mask_1`` (goal + sol1),
    ``input_ids_2``/``attn_mask_2`` (goal + sol2), all padded/truncated to
    ``max_len``, plus ``label`` as a 0-d ``torch.long`` tensor.
    """

    def __init__(self, examples, tokenizer, max_len=128):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.examples)

    def _encode(self, text):
        """Tokenize one string to fixed length and drop the leading batch axis."""
        enc = self.tokenizer(text, truncation=True, max_length=self.max_len,
                             padding="max_length", return_tensors="pt")
        return enc["input_ids"].squeeze(0), enc["attention_mask"].squeeze(0)

    def __getitem__(self, idx):
        row = self.examples[idx]
        ids_a, mask_a = self._encode(row["goal"] + " [SEP] " + row["sol1"])
        ids_b, mask_b = self._encode(row["goal"] + " [SEP] " + row["sol2"])
        return {
            "input_ids_1": ids_a,
            "attn_mask_1": mask_a,
            "input_ids_2": ids_b,
            "attn_mask_2": mask_b,
            "label": torch.tensor(row["label"], dtype=torch.long),
        }

def collate_fn(batch):
    """Stack the five PIQAPairDataset fields of every sample along a new batch axis."""
    fields = ("input_ids_1", "attn_mask_1", "input_ids_2", "attn_mask_2", "label")
    return {field: torch.stack([sample[field] for sample in batch]) for field in fields}

# ---------------------------
# 7) Model
# ---------------------------
class EvoTransformerBlock(nn.Module):
    """One post-LayerNorm transformer encoder layer.

    Computes ``h = LN1(x + Drop(SelfAttn(x)))`` and returns ``LN2(h + Drop(FF(h)))``.
    ``activation == "gelu"`` selects GELU inside the feed-forward MLP; any other
    value selects ReLU.
    """

    def __init__(self, d_model, num_heads, ff_dim, dropout, activation):
        super().__init__()
        # Submodule creation order and attribute names determine the RNG draws
        # used for weight init and the state_dict keys, so both are kept as-is.
        self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(d_model)
        act = nn.ReLU() if activation != "gelu" else nn.GELU()
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            act,
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model),
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # Positions where attn_mask == 0 become True in key_padding_mask and are
        # therefore ignored as attention keys.
        pad = None if attn_mask is None else (attn_mask == 0)
        attended, _ = self.attn(x, x, x, key_padding_mask=pad)
        hidden = self.ln1(x + self.dropout(attended))
        return self.ln2(hidden + self.dropout(self.ff(hidden)))

class EvoTransformerEncoder(nn.Module):
    """Embeddings + a stack of EvoTransformerBlocks, pooled at the first position.

    ``forward(input_ids, attn_mask)`` adds learned absolute position embeddings to
    the token embeddings, runs every block with the same padding mask, and returns
    the final-LayerNorm'd hidden state of position 0 for each sequence.
    """

    def __init__(self, vocab_size, max_len, d_model, num_heads, ff_dim, depth, dropout, activation):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        # Learned positional table, initialized with small (std 0.02) noise.
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)
        stack = [
            EvoTransformerBlock(d_model, num_heads, ff_dim, dropout, activation)
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, input_ids, attn_mask):
        seq_len = input_ids.size(1)
        hidden = self.token_emb(input_ids) + self.pos_emb[:, :seq_len, :]
        for block in self.layers:
            hidden = block(hidden, attn_mask)
        first_position = hidden[:, 0, :]
        return self.norm(first_position)

class EvoPIQAClassifier(nn.Module):
    """Shared-weight scorer for the two PIQA candidates.

    Both (goal, solution) sequences go through the same encoder and the same
    scoring MLP; the two scalar scores are returned as ``(batch, 2)`` logits, so
    column 0 corresponds to sol1 and column 1 to sol2.
    """

    def __init__(self, tokenizer, cfg):
        super().__init__()
        d_model = cfg["d_model"]
        self.encoder = EvoTransformerEncoder(
            vocab_size=tokenizer.vocab_size,
            max_len=cfg["max_len"],
            d_model=d_model,
            num_heads=cfg["num_heads"],
            ff_dim=cfg["ff_dim"],
            depth=cfg["depth"],
            dropout=cfg["dropout"],
            activation=cfg["activation"],
        )
        half = d_model // 2
        # Small MLP mapping a pooled encoding to a single score.
        self.score_head = nn.Sequential(
            nn.Linear(d_model, half),
            nn.GELU(),
            nn.Dropout(cfg["dropout"]),
            nn.Linear(half, 1),
        )

    def forward(self, batch):
        # Encode both candidates first, then score both: this keeps the order of
        # dropout RNG draws the same as encoding/scoring them in two passes.
        views = (("input_ids_1", "attn_mask_1"), ("input_ids_2", "attn_mask_2"))
        pooled = [self.encoder(batch[ids_key], batch[mask_key]) for ids_key, mask_key in views]
        scores = [self.score_head(vec).squeeze(-1) for vec in pooled]
        return torch.stack(scores, dim=1)

# ---------------------------
# 8) DataLoaders
# ---------------------------
# Same dataset class for both splits; only the training loader shuffles.
train_ds, val_ds = (
    PIQAPairDataset(rows, tokenizer, CONFIG["max_len"]) for rows in (train_raw, val_raw)
)

train_loader = DataLoader(
    train_ds,
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_ds,
    batch_size=CONFIG["batch_size"],
    collate_fn=collate_fn,
)

# ---------------------------
# 9) Training with warmup
# ---------------------------
def evaluate(model, dataloader, device):
    """Return the fraction of examples whose argmax logit equals ``label``.

    Switches ``model`` to eval mode (and leaves it there), runs without gradient
    tracking, and moves every tensor of each batch to ``device`` first.
    Raises ZeroDivisionError if the loader yields no examples.
    """
    model.eval()
    n_right, n_seen = 0, 0
    with torch.no_grad():
        for raw in dataloader:
            on_dev = {name: t.to(device) for name, t in raw.items()}
            guesses = model(on_dev).argmax(dim=1)
            gold = on_dev["label"]
            n_right += (guesses == gold).sum().item()
            n_seen += gold.size(0)
    return n_right / n_seen

def get_lr(epoch, cfg):
    """Learning rate for 0-based ``epoch``: linear warmup, then cosine decay.

    During the first ``warmup_epochs`` epochs the rate is ``lr * (epoch + 1) / warmup_epochs``.
    Afterwards it follows half a cosine from ``lr`` toward ``min_lr`` over the
    remaining epochs. The last epoch stays slightly above ``min_lr``.
    """
    warm = cfg["warmup_epochs"]
    if epoch >= warm:
        frac = (epoch - warm) / (cfg["epochs"] - warm)
        lo, hi = cfg["min_lr"], cfg["lr"]
        return lo + (hi - lo) * 0.5 * (1 + math.cos(math.pi * frac))
    return cfg["lr"] * (epoch + 1) / warm

def train_loop(model, train_loader, val_loader, cfg, device):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Each epoch:
    - sets the learning rate from ``get_lr(epoch, cfg)``;
    - runs one pass over ``train_loader`` with cross-entropy loss, AdamW and
      gradient-norm clipping at 1.0;
    - scores ``val_loader`` with ``evaluate``.

    Training stops early after ``cfg.get("patience", 10)`` consecutive epochs
    without a strict improvement. The best weights, as CPU copies, are saved to
    ``logs/best_evo_stable.pt`` together with ``cfg``.

    Returns:
        The best validation accuracy seen (0.0 if no epoch ran).
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])
    criterion = nn.CrossEntropyLoss()

    best_val = 0.0
    best_state = None
    patience = 0
    max_patience = cfg.get("patience", 10)

    for epoch in range(cfg["epochs"]):
        # Manual per-epoch schedule (warmup + cosine), applied to every param group.
        current_lr = get_lr(epoch, cfg)
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr

        model.train()
        running_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg['epochs']}")
        for batch in pbar:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            logits = model(batch)
            loss = criterion(logits, batch["label"])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Weight by batch size so the epoch average is per-example.
            running_loss += loss.item() * batch["label"].size(0)
            pbar.set_postfix({'loss': f"{loss.item():.4f}", 'lr': f"{current_lr:.2e}"})

        avg_loss = running_loss / len(train_loader.dataset)
        val_acc = evaluate(model, val_loader, device)

        status = "🔥" if val_acc > best_val else "  "
        print(f"{status} Epoch {epoch+1}: Loss={avg_loss:.4f}, Val={val_acc*100:.2f}%, LR={current_lr:.2e}")

        if val_acc > best_val:
            best_val = val_acc
            # Force a real copy: on a CPU device Tensor.cpu() returns the same
            # storage, so the snapshot would otherwise keep tracking the live
            # weights and the saved "best" model would be the last epoch's.
            best_state = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}
            patience = 0
            print(f"   ✓ New best: {val_acc*100:.2f}%")
        else:
            patience += 1
            if patience >= max_patience:
                print(f"   Early stopping (patience={max_patience})")
                break

    os.makedirs("logs", exist_ok=True)
    if best_state is not None:
        # Store the configuration this run actually used, not the module-level CONFIG.
        torch.save({"state_dict": best_state, "config": cfg}, "logs/best_evo_stable.pt")
        print("\n✓ Model saved")
    return best_val

# ---------------------------
# 10) Run
# ---------------------------
model = EvoPIQAClassifier(tokenizer, CONFIG)
# Count only trainable parameters.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Params: {total_params:,} (~{total_params/1e6:.1f}M)")

# Banner summarizing the run's settings (one print, same text line by line).
print(
    "\n" + "=" * 60,
    "STABLE TRAINING APPROACH:",
    "=" * 60,
    "✓ Conservative LR (3e-5) with warmup",
    "✓ Simple cosine decay",
    "✓ No augmentation (cleaner signal)",
    "✓ Standard dropout (0.15)",
    "✓ No label smoothing",
    "✓ Batch size 16",
    "=" * 60 + "\n",
    sep="\n",
)

best = train_loop(model, train_loader, val_loader, CONFIG, device)

print(f"\n{'='*60}")
print(f"🎯 BEST: {best*100:.2f}%")
print(f"{'='*60}")

print("🎉 70%+ ACHIEVED!" if best >= 0.70 else f"Gap to 70%: {(0.70-best)*100:.2f}pp")

Installing packages (may take a minute)...
Installed.
Device: cuda
Downloading PIQA...
Train: 16113, Val: 1838
Params: 41,044,481 (~41.0M)

STABLE TRAINING APPROACH:
✓ Conservative LR (3e-5) with warmup
✓ Simple cosine decay
✓ No augmentation (cleaner signal)
✓ Standard dropout (0.15)
✓ No label smoothing
✓ Batch size 16



Epoch 1/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 1: Loss=0.6937, Val=54.46%, LR=1.00e-05
   ✓ New best: 54.46%


Epoch 2/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 2: Loss=0.6878, Val=55.17%, LR=2.00e-05
   ✓ New best: 55.17%


Epoch 3/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 3: Loss=0.6788, Val=56.91%, LR=3.00e-05
   ✓ New best: 56.91%


Epoch 4/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 4: Loss=0.6627, Val=57.78%, LR=3.00e-05
   ✓ New best: 57.78%


Epoch 5/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 5: Loss=0.6453, Val=57.51%, LR=2.99e-05


Epoch 6/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 6: Loss=0.6225, Val=59.09%, LR=2.96e-05
   ✓ New best: 59.09%


Epoch 7/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 7: Loss=0.5926, Val=59.85%, LR=2.91e-05
   ✓ New best: 59.85%


Epoch 8/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 8: Loss=0.5659, Val=60.50%, LR=2.85e-05
   ✓ New best: 60.50%


Epoch 9/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 9: Loss=0.5279, Val=61.21%, LR=2.76e-05
   ✓ New best: 61.21%


Epoch 10/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 10: Loss=0.4837, Val=60.72%, LR=2.66e-05


Epoch 11/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 11: Loss=0.4435, Val=60.55%, LR=2.55e-05


Epoch 12/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 12: Loss=0.4046, Val=61.92%, LR=2.42e-05
   ✓ New best: 61.92%


Epoch 13/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 13: Loss=0.3642, Val=61.04%, LR=2.28e-05


Epoch 14/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 14: Loss=0.3344, Val=60.07%, LR=2.12e-05


Epoch 15/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 15: Loss=0.2993, Val=62.02%, LR=1.97e-05
   ✓ New best: 62.02%


Epoch 16/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 16: Loss=0.2806, Val=61.64%, LR=1.80e-05


Epoch 17/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 17: Loss=0.2629, Val=61.59%, LR=1.63e-05


Epoch 18/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 18: Loss=0.2356, Val=61.59%, LR=1.47e-05


Epoch 19/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 19: Loss=0.2207, Val=61.75%, LR=1.30e-05


Epoch 20/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 20: Loss=0.2083, Val=61.37%, LR=1.13e-05


Epoch 21/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 21: Loss=0.1981, Val=62.19%, LR=9.76e-06
   ✓ New best: 62.19%


Epoch 22/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 22: Loss=0.1868, Val=62.19%, LR=8.25e-06


Epoch 23/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 23: Loss=0.1688, Val=62.08%, LR=6.84e-06


Epoch 24/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 24: Loss=0.1648, Val=62.51%, LR=5.55e-06
   ✓ New best: 62.51%


Epoch 25/30:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 25: Loss=0.1582, Val=63.17%, LR=4.39e-06
   ✓ New best: 63.17%


Epoch 26/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 26: Loss=0.1509, Val=62.02%, LR=3.39e-06


Epoch 27/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 27: Loss=0.1438, Val=62.13%, LR=2.54e-06


Epoch 28/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 28: Loss=0.1490, Val=62.35%, LR=1.87e-06


Epoch 29/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 29: Loss=0.1433, Val=62.24%, LR=1.39e-06


Epoch 30/30:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 30: Loss=0.1393, Val=62.08%, LR=1.10e-06

✓ Model saved

🎯 BEST: 63.17%
Gap to 70%: 6.83pp


In [None]:
# EvoTransformer PIQA - OPTIMIZED FOR 70%+
# Copy -> Paste -> Run

print("Installing packages...")
import subprocess
import sys

# Quiet pip install into the running interpreter's environment.
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "-q"]
    + ["transformers==4.40.1", "torch", "tqdm", "tensorflow-datasets", "sentencepiece"]
)

import os, math, random
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import tensorflow_datasets as tfds

# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# OPTIMIZED CONFIG (differences from the first run are noted inline).
CONFIG = dict(
    max_len=128,
    d_model=512,
    num_heads=8,
    ff_dim=2048,
    depth=8,
    dropout=0.1,          # lowered from 0.15
    activation="gelu",
    batch_size=16,
    epochs=50,            # longer schedule
    lr=8e-5,              # raised from 3e-5
    min_lr=1e-6,
    warmup_epochs=5,      # longer warmup
    weight_decay=0.01,
    seed=42,
)

# Seed PyTorch and Python's `random` module from the same value.
torch.manual_seed(CONFIG["seed"])
random.seed(CONFIG["seed"])

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

def load_piqa_tfds(split="train", limit=None):
    """Read one PIQA split from TFDS into a list of plain-Python dicts.

    Each record has keys ``goal``, ``sol1``, ``sol2`` (text) and ``label``
    (converted with ``int``). Text values arriving as ``bytes`` or ``bytearray``
    are decoded as UTF-8. When ``limit`` is truthy, reading stops once that many
    records are collected.
    """
    def _text(value):
        # Accept both immutable and mutable byte buffers.
        return value.decode("utf-8") if isinstance(value, (bytes, bytearray)) else value

    ds = tfds.load("piqa", split=split)
    out = []
    for ex in tfds.as_numpy(ds):
        out.append({
            "goal": _text(ex["goal"]),
            "sol1": _text(ex["sol1"]),
            "sol2": _text(ex["sol2"]),
            "label": int(ex["label"]),
        })
        if limit and len(out) >= limit:
            break
    return out

print("Loading PIQA...")
train_raw, val_raw = (load_piqa_tfds(split_name) for split_name in ("train", "validation"))
print(f"Train: {len(train_raw)}, Val: {len(val_raw)}")

class PIQAPairDataset(Dataset):
    """PIQA examples as two fixed-length "goal [SEP] solution" token sequences.

    Items contain ``input_ids_1``/``attn_mask_1`` (for sol1),
    ``input_ids_2``/``attn_mask_2`` (for sol2) and a 0-d ``torch.long`` ``label``.
    """

    def __init__(self, examples, tokenizer, max_len=128):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        item = {}
        # sol1 first, then sol2: preserves both tokenizer call order and key order.
        for sol_key, ids_key, mask_key in (
            ("sol1", "input_ids_1", "attn_mask_1"),
            ("sol2", "input_ids_2", "attn_mask_2"),
        ):
            encoded = self.tokenizer(
                example["goal"] + " [SEP] " + example[sol_key],
                truncation=True,
                max_length=self.max_len,
                padding="max_length",
                return_tensors="pt",
            )
            item[ids_key] = encoded["input_ids"].squeeze(0)
            item[mask_key] = encoded["attention_mask"].squeeze(0)
        item["label"] = torch.tensor(example["label"], dtype=torch.long)
        return item

def collate_fn(batch):
    """Stack every field present in the first sample across the whole batch."""
    stacked = {}
    for key in batch[0]:
        stacked[key] = torch.stack([sample[key] for sample in batch])
    return stacked

class EvoTransformerBlock(nn.Module):
    """Post-LayerNorm transformer encoder layer: self-attention, then a feed-forward MLP.

    Args:
        d_model: feature size of the input and output.
        num_heads: number of heads for ``nn.MultiheadAttention`` (batch-first).
        ff_dim: hidden width of the feed-forward MLP.
        dropout: dropout probability for attention weights, the MLP and both
            residual branches.
        activation: ``"gelu"`` selects GELU in the MLP; any other value selects ReLU.
    """

    def __init__(self, d_model, num_heads, ff_dim, dropout, activation):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(d_model)
        # Honor the `activation` argument (it used to be ignored, with GELU hard-coded).
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU() if activation == "gelu" else nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model),
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Apply the layer to batch-first ``x``.

        Positions where ``attn_mask == 0`` are excluded as attention keys.
        """
        key_padding_mask = (attn_mask == 0) if attn_mask is not None else None
        attn_out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.ln1(x + self.dropout(attn_out))
        x = self.ln2(x + self.dropout(self.ff(x)))
        return x

class EvoTransformerEncoder(nn.Module):
    """Token + learned position embeddings, ``depth`` EvoTransformerBlocks, final LayerNorm.

    ``forward`` returns the normalized hidden state at position 0 of each sequence.
    """

    def __init__(self, vocab_size, max_len, d_model, num_heads, ff_dim, depth, dropout, activation):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(0.02 * torch.randn(1, max_len, d_model))
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(EvoTransformerBlock(d_model, num_heads, ff_dim, dropout, activation))
        self.norm = nn.LayerNorm(d_model)

    def forward(self, input_ids, attn_mask):
        positions = self.pos_emb[:, : input_ids.shape[1]]
        x = self.token_emb(input_ids) + positions
        for blk in self.layers:
            x = blk(x, attn_mask)
        return self.norm(x[:, 0])

class EvoPIQAClassifier(nn.Module):
    """Scores both PIQA candidates with one shared encoder and one shared MLP head.

    ``forward`` returns ``(batch, 2)`` logits: column 0 for sol1, column 1 for sol2.
    """

    def __init__(self, tokenizer, cfg):
        super().__init__()
        self.encoder = EvoTransformerEncoder(
            vocab_size=tokenizer.vocab_size,
            max_len=cfg["max_len"],
            d_model=cfg["d_model"],
            num_heads=cfg["num_heads"],
            ff_dim=cfg["ff_dim"],
            depth=cfg["depth"],
            dropout=cfg["dropout"],
            activation=cfg["activation"],
        )
        width = cfg["d_model"]
        self.score_head = nn.Sequential(
            nn.Linear(width, width // 2),
            nn.GELU(),
            nn.Dropout(cfg["dropout"]),
            nn.Linear(width // 2, 1),
        )

    def forward(self, batch):
        # Both encodings are computed before either score, matching the original
        # order of dropout RNG draws in training mode.
        first = self.encoder(batch["input_ids_1"], batch["attn_mask_1"])
        second = self.encoder(batch["input_ids_2"], batch["attn_mask_2"])
        score_first = self.score_head(first).squeeze(-1)
        score_second = self.score_head(second).squeeze(-1)
        return torch.stack((score_first, score_second), dim=1)

train_ds = PIQAPairDataset(train_raw, tokenizer, max_len=CONFIG["max_len"])
val_ds = PIQAPairDataset(val_raw, tokenizer, max_len=CONFIG["max_len"])
# Only the training split is shuffled; validation order stays fixed.
train_loader = DataLoader(
    train_ds, batch_size=CONFIG["batch_size"], shuffle=True, collate_fn=collate_fn
)
val_loader = DataLoader(
    val_ds, batch_size=CONFIG["batch_size"], collate_fn=collate_fn
)

def evaluate(model, dataloader, device):
    """Accuracy of ``model``'s argmax prediction against ``label`` over ``dataloader``.

    The model is left in eval mode; no gradients are tracked. Raises
    ZeroDivisionError for an empty loader.
    """
    model.eval()
    hits = seen = 0
    with torch.no_grad():
        for batch in dataloader:
            moved = {key: tensor.to(device) for key, tensor in batch.items()}
            labels = moved["label"]
            hits += (model(moved).argmax(1) == labels).sum().item()
            seen += labels.size(0)
    return hits / seen

def get_lr(epoch, cfg):
    """Per-epoch learning rate: linear warmup to ``cfg["lr"]``, then cosine decay toward ``cfg["min_lr"]``.

    ``epoch`` is 0-based, so warmup produces lr/W, 2*lr/W, ..., lr for W warmup epochs.
    """
    warmup = cfg["warmup_epochs"]
    if epoch < warmup:
        return cfg["lr"] * (epoch + 1) / warmup
    progress = (epoch - warmup) / (cfg["epochs"] - warmup)
    return cfg["min_lr"] + (cfg["lr"] - cfg["min_lr"]) * 0.5 * (1 + math.cos(math.pi * progress))

model = EvoPIQAClassifier(tokenizer, CONFIG).to(device)
print(f"Params: {sum(p.numel() for p in model.parameters()):,}")

optimizer = optim.AdamW(model.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["weight_decay"])
criterion = nn.CrossEntropyLoss()

# torch.save does not create missing directories; make sure the checkpoint
# folder exists even when this cell is run on its own.
os.makedirs("logs", exist_ok=True)

best_val = 0.0
patience = 0
MAX_PATIENCE = 15  # epochs without a strict val improvement before stopping

for epoch in range(CONFIG["epochs"]):
    # Per-epoch schedule (warmup + cosine) applied to every param group.
    lr = get_lr(epoch, CONFIG)
    for g in optimizer.param_groups:
        g['lr'] = lr

    model.train()
    running_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")
    for batch in pbar:
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        loss = criterion(model(batch), batch["label"])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # Weight by batch size so the epoch average is per-example.
        running_loss += loss.item() * batch["label"].size(0)
        pbar.set_postfix({'loss': f"{loss.item():.4f}", 'lr': f"{lr:.2e}"})

    avg_loss = running_loss / len(train_loader.dataset)
    val_acc = evaluate(model, val_loader, device)

    print(f"{'🔥' if val_acc > best_val else '  '} Epoch {epoch+1}: Loss={avg_loss:.4f}, Val={val_acc*100:.2f}%, LR={lr:.2e}")

    if val_acc > best_val:
        best_val = val_acc
        patience = 0
        # torch.save serializes immediately, so the live state_dict is safe to pass.
        torch.save({"state_dict": model.state_dict(), "config": CONFIG}, "logs/best_evo_70plus.pt")
        print(f"   ✓ New best: {val_acc*100:.2f}%")
    else:
        patience += 1
        if patience >= MAX_PATIENCE:
            print("Early stop")
            break

print(f"\n🎯 BEST: {best_val*100:.2f}%")

Installing packages...
Device: cuda
Loading PIQA...
Train: 16113, Val: 1838
Params: 41,044,481


Epoch 1/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 1: Loss=0.6918, Val=54.62%, LR=1.60e-05
   ✓ New best: 54.62%


Epoch 2/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 2: Loss=0.6843, Val=55.82%, LR=3.20e-05
   ✓ New best: 55.82%


Epoch 3/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 3: Loss=0.6754, Val=57.24%, LR=4.80e-05
   ✓ New best: 57.24%


Epoch 4/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 4: Loss=0.6610, Val=58.05%, LR=6.40e-05
   ✓ New best: 58.05%


Epoch 5/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 5: Loss=0.6445, Val=58.54%, LR=8.00e-05
   ✓ New best: 58.54%


Epoch 6/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 6: Loss=0.5987, Val=58.76%, LR=8.00e-05
   ✓ New best: 58.76%


Epoch 7/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 7: Loss=0.5350, Val=59.41%, LR=7.99e-05
   ✓ New best: 59.41%


Epoch 8/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 8: Loss=0.4626, Val=60.34%, LR=7.96e-05
   ✓ New best: 60.34%


Epoch 9/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 9: Loss=0.3971, Val=60.99%, LR=7.91e-05
   ✓ New best: 60.99%


Epoch 10/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 10: Loss=0.3393, Val=60.12%, LR=7.85e-05


Epoch 11/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 11: Loss=0.2888, Val=61.32%, LR=7.76e-05
   ✓ New best: 61.32%


Epoch 12/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 12: Loss=0.2617, Val=61.04%, LR=7.66e-05


Epoch 13/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 13: Loss=0.2177, Val=62.13%, LR=7.54e-05
   ✓ New best: 62.13%


Epoch 14/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 14: Loss=0.2012, Val=61.92%, LR=7.40e-05


Epoch 15/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 15: Loss=0.1644, Val=62.13%, LR=7.25e-05


Epoch 16/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 16: Loss=0.1534, Val=62.79%, LR=7.08e-05
   ✓ New best: 62.79%


Epoch 17/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 17: Loss=0.1397, Val=62.13%, LR=6.89e-05


Epoch 18/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 18: Loss=0.1275, Val=62.19%, LR=6.69e-05


Epoch 19/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 19: Loss=0.1090, Val=62.62%, LR=6.48e-05


Epoch 20/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 20: Loss=0.1055, Val=62.84%, LR=6.26e-05
   ✓ New best: 62.84%


Epoch 21/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 21: Loss=0.0924, Val=63.38%, LR=6.03e-05
   ✓ New best: 63.38%


Epoch 22/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 22: Loss=0.0890, Val=62.62%, LR=5.78e-05


Epoch 23/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 23: Loss=0.0843, Val=63.00%, LR=5.53e-05


Epoch 24/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 24: Loss=0.1694, Val=62.46%, LR=5.27e-05


Epoch 25/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 25: Loss=0.0661, Val=63.76%, LR=5.01e-05
   ✓ New best: 63.76%


Epoch 26/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 26: Loss=0.0560, Val=62.40%, LR=4.74e-05


Epoch 27/50:   0%|          | 0/1008 [00:00<?, ?it/s]

🔥 Epoch 27: Loss=0.0591, Val=64.58%, LR=4.46e-05
   ✓ New best: 64.58%


Epoch 28/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 28: Loss=0.0511, Val=63.17%, LR=4.19e-05


Epoch 29/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 29: Loss=0.0431, Val=62.51%, LR=3.91e-05


Epoch 30/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 30: Loss=0.0449, Val=63.44%, LR=3.64e-05


Epoch 31/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 31: Loss=0.0393, Val=62.89%, LR=3.36e-05


Epoch 32/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 32: Loss=0.0310, Val=63.00%, LR=3.09e-05


Epoch 33/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 33: Loss=0.0339, Val=62.62%, LR=2.83e-05


Epoch 34/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 34: Loss=0.0323, Val=63.71%, LR=2.57e-05


Epoch 35/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 35: Loss=0.0326, Val=63.28%, LR=2.32e-05


Epoch 36/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 36: Loss=0.0232, Val=62.95%, LR=2.08e-05


Epoch 37/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 37: Loss=0.0245, Val=63.22%, LR=1.84e-05


Epoch 38/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 38: Loss=0.0209, Val=64.36%, LR=1.62e-05


Epoch 39/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 39: Loss=0.0201, Val=64.53%, LR=1.41e-05


Epoch 40/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 40: Loss=0.0153, Val=63.76%, LR=1.21e-05


Epoch 41/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 41: Loss=0.0221, Val=63.98%, LR=1.02e-05


Epoch 42/50:   0%|          | 0/1008 [00:00<?, ?it/s]

   Epoch 42: Loss=0.0186, Val=63.76%, LR=8.54e-06
Early stop

🎯 BEST: 64.58%
