# 01 · Training and Baselines (PyTorch)

Train baseline models on CIFAR-10, export them to ONNX, log per-epoch training metrics, and record baseline PyTorch inference benchmarks (CPU, plus CUDA when available) for reproducible comparisons.

- Methodology and reproducibility: docs/reproducibility.md
- Benchmark protocol and metrics: docs/benchmarks.md
- Where results go: docs/results.md

> Cache policy:
>
> - This notebook reuses existing artifacts under `models_saved/` when found:
>   - If both the `.pt` weights and the `.pkl` history exist for a model, training is skipped unless `FORCE_RETRAIN` is enabled.
>   - If the `.onnx` file exists, export is skipped unless `FORCE_REEXPORT_ONNX` is enabled.
> - Datasets are stored in `ROOT/data`. If present, they are reused; otherwise, they are downloaded once.
>
> Both force flags are read from environment variables of the same name at the top of the training cell below; set them to `1` (e.g. `FORCE_RETRAIN=1`) to override cache reuse.

In [None]:
# Setup & Logging
import os
from pathlib import Path

import yaml

from utils.logging_utils import get_logger

os.environ.setdefault("LOG_LEVEL", "INFO")
# Project root: two levels above this file when run as a script; otherwise the
# parent of the current working directory (assumes the notebook lives in a
# subdirectory of the project — TODO confirm layout).
root = Path(__file__).resolve().parent.parent if '__file__' in globals() else Path(os.getcwd()).parent
logger = get_logger("nb01")
logger.info("Notebook 01 starting. Root=%s", root)

# Load config and summarize cache. The `with` block closes the file handle,
# which the bare open(...) call previously leaked.
with open(root / 'config/bench_matrix.yaml', 'r', encoding='utf-8') as cfg_file:
    cfg = yaml.safe_load(cfg_file)
models_cfg = {m['name']: m for m in cfg['models']}

pt_dir = root / 'models_saved' / 'pytorch'
onnx_dir = root / 'models_saved' / 'onnx'

# One (name, pt_exists, hist_exists, onnx_exists) tuple per configured model.
# A comprehension keeps the per-model temporaries out of the notebook globals
# (a global named `onnx` would shadow a later `import onnx`).
rows = [
    (
        model_name,
        (pt_dir / spec['file_pt']).exists(),
        (pt_dir / spec['file_hist']).exists(),
        (onnx_dir / spec['file_onnx']).exists(),
    )
    for model_name, spec in models_cfg.items()
]

print("Model cache status (PT/HIST/ONNX):")
for model_name, pt_ok, hist_ok, onnx_ok in rows:
    print(f"- {model_name}: pt={pt_ok}, hist={hist_ok}, onnx={onnx_ok}")
logger.info("Cache status listed for %d models", len(rows))
print("Use FORCE_RETRAIN=1 and/or FORCE_REEXPORT_ONNX=1 to override cache reuse.")

In [2]:
# Training / Export / Baseline
import os, pickle
from pathlib import Path
import yaml
import torch
from torch import optim
from utils.data_utils import DataLoaderFactory
from utils.train_utils import train_model
from utils.io import csv_append_row, CSV_SCHEMA, sha256_file, utc_timestamp, git_commit_short, nvidia_driver_version
from utils.infer_torch import benchmark_dataloader
from utils.logging_utils import get_logger
from models.cnn import CNN
from models.mlp import MLP
from models.mobilenetv3 import MobileNetV3
from models.efficientnet_lite0 import EfficientNetLite0

logger = get_logger("nb01")


def _env_flag(var: str, default: str = '0') -> bool:
    """Return True when environment variable *var* holds a truthy value.

    Integers behave as before ("0" -> False, any other integer -> True).
    The spellings true/yes/on and false/no/off (case-insensitive) are also
    accepted, so ``FORCE_RETRAIN=True`` no longer raises ``ValueError``.
    Any other non-integer value still raises ``ValueError``.
    """
    raw = os.getenv(var, default).strip().lower()
    if raw in ('true', 'yes', 'on'):
        return True
    if raw in ('false', 'no', 'off', ''):
        return False
    return bool(int(raw))


# Force flags (can also be set via environment variables)
FORCE_RETRAIN = _env_flag('FORCE_RETRAIN')
FORCE_REEXPORT_ONNX = _env_flag('FORCE_REEXPORT_ONNX')

# Project root: see the setup cell; assumes the notebook sits one level below it.
root = Path(__file__).resolve().parent.parent if '__file__' in globals() else Path(os.getcwd()).parent
# Close the config file deterministically instead of leaking the handle.
with open(root / 'config/bench_matrix.yaml', 'r', encoding='utf-8') as cfg_file:
    cfg = yaml.safe_load(cfg_file)
defs = cfg['defaults']
train_cfg = cfg['train']
models_cfg = {m['name']: m for m in cfg['models']}
out_train_csv = root / cfg['outputs']['train_csv']
models_pt_dir = root / 'models_saved/pytorch'
models_onnx_dir = root / 'models_saved/onnx'
models_pt_dir.mkdir(parents=True, exist_ok=True)
models_onnx_dir.mkdir(parents=True, exist_ok=True)

# Devices and seed. `device_cuda` silently falls back to CPU when CUDA is absent.
device_cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_cpu = torch.device('cpu')
seed = defs.get('seed', 42)
torch.manual_seed(seed)
logger.info("Config loaded. batch=%s, epochs=%s, seed=%s", defs['batch'], train_cfg['epochs'], seed)

# Data: CIFAR-10 loaders (downloaded into ROOT/data only if not already present).
try:
    train_loader, test_loader = DataLoaderFactory.get_cifar10_dataloaders(
        batch_size=defs['batch'], num_workers=defs['num_workers'], data_dir=str(root / 'data'), download=True
    )
except Exception as ex:
    logger.exception("Failed to prepare data loaders: %s", ex)
    raise


def build_model(name: str):
    """Return a freshly constructed, untrained model for config name *name*.

    Recognised names are 'cnn', 'mlp' (built with ``input_size=32 * 32 * 3``),
    'mobilenetv3' and 'efficientnetlite0'.

    Raises:
        ValueError: if *name* matches none of the recognised names.
    """
    # Factories are lambdas so each model class is only looked up and
    # instantiated when its name is the one requested.
    registry = (
        ('cnn', lambda: CNN()),
        ('mlp', lambda: MLP(input_size=32 * 32 * 3)),
        ('mobilenetv3', lambda: MobileNetV3()),
        ('efficientnetlite0', lambda: EfficientNetLite0()),
    )
    for key, make in registry:
        if name == key:
            return make()
    raise ValueError(f'Unknown model {name}')


def _write_history_csv(hist_csv: Path, hist) -> None:
    """Write one CSV row per epoch from the training-history mapping *hist*.

    Columns are epoch (1-based), loss, accuracy and epoch_time_s, plus
    epoch_energy_j when *hist* contains that key. The number of rows follows
    ``len(hist['loss'])``; the other series are indexed in lock-step with it.
    """
    import csv

    hist_csv.parent.mkdir(parents=True, exist_ok=True)
    energy = hist['epoch_energy_j'] if 'epoch_energy_j' in hist else None
    header = ['epoch', 'loss', 'accuracy', 'epoch_time_s']
    if energy is not None:
        header.append('epoch_energy_j')
    with open(hist_csv, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for i, loss in enumerate(hist['loss']):
            row = [i + 1, loss, hist['accuracy'][i], hist['epoch_time'][i]]
            if energy is not None:
                row.append(energy[i])
            writer.writerow(row)


def train_and_save(name: str):
    """Train model *name* (or reload it from the cache) and persist its artifacts.

    Cache hit: when both the ``.pt`` weights and the ``.pkl`` history exist
    and ``FORCE_RETRAIN`` is off, the weights are loaded into a fresh model
    on CPU and training is skipped.

    Otherwise the model is trained with Adam on ``train_loader``, using the
    ``lr``, ``weight_decay``, ``epochs`` and ``early_stopping_patience``
    settings from the ``train`` config section. It then writes:

    - ``models_pt_dir/<file_pt>``: a dict with key ``model_state_dict``;
    - ``models_pt_dir/<file_hist>``: the pickled history returned by
      ``train_model``;
    - ``root/metrics/<name>_train_history.csv``: the per-epoch table.

    Returns:
        The (trained or reloaded) model on CPU.
    """
    spec = models_cfg[name]
    pt_path = models_pt_dir / spec['file_pt']
    hist_path = models_pt_dir / spec['file_hist']
    if pt_path.exists() and hist_path.exists() and not FORCE_RETRAIN:
        logger.info("[cache] Skipping training for %s: found %s and %s", name, pt_path.name, hist_path.name)
        m = build_model(name)
        state = torch.load(pt_path, map_location='cpu')
        m.load_state_dict(state['model_state_dict'])
        return m

    device = device_cuda if torch.cuda.is_available() else device_cpu
    m = build_model(name).to(device)
    logger.info("[train] Training %s on %s with input shape %s", name, device, spec['input_shape'])
    opt = optim.Adam(m.parameters(), lr=train_cfg['lr'], weight_decay=train_cfg['weight_decay'])
    hist = train_model(
        m, opt, train_loader, device,
        num_epochs=train_cfg['epochs'],
        early_stopping_patience=train_cfg['early_stopping_patience'],
        verbose=True,
        measure_energy=True,
    )

    # Persist the expensive artifacts first, so that a failure while writing
    # the auxiliary CSV cannot throw away a completed training run.
    torch.save({'model_state_dict': m.state_dict()}, pt_path)
    with open(hist_path, 'wb') as f:
        pickle.dump(hist, f)
    logger.info("[save] Wrote %s and %s", pt_path, hist_path)

    # Epoch-wise history as CSV for plots.
    hist_csv = root / 'metrics' / f'{name}_train_history.csv'
    _write_history_csv(hist_csv, hist)
    logger.info("Saved epoch history: %s", hist_csv)
    return m.cpu()


def export_onnx(name: str, model):
    """Export *model* to ``models_onnx_dir / models_cfg[name]['file_onnx']``.

    If that file already exists and ``FORCE_REEXPORT_ONNX`` is off, the export
    is skipped and the existing path is returned. Otherwise a batch-1 random
    input shaped ``(1, *input_shape[:3])`` is traced at opset 17 with a dynamic
    batch axis:

    - models exposing ``to_onnx`` use that method;
    - all others go through ``torch.onnx.export``, with input name ``input``
      and output name ``logits``.

    Returns:
        The ONNX file path.

    Raises:
        Whatever the exporter raises. The error is logged first, and any
        partially written file is removed.
    """
    onnx_path = models_onnx_dir / models_cfg[name]['file_onnx']
    if onnx_path.exists() and not FORCE_REEXPORT_ONNX:
        logger.info("[cache] Skipping ONNX export for %s: found %s", name, onnx_path.name)
        return onnx_path

    chw = models_cfg[name]['input_shape']
    sample_shape = (1, chw[0], chw[1], chw[2])
    # Export in inference mode on both paths (previously only the
    # torch.onnx.export path switched to eval), so dropout is disabled and
    # normalization layers use their stored statistics.
    model.eval()
    try:
        if hasattr(model, 'to_onnx'):
            model.to_onnx(sample_shape, str(onnx_path), opset=17, dynamic_batch=True)
        else:
            torch.onnx.export(
                model, torch.randn(*sample_shape), str(onnx_path),
                input_names=["input"], output_names=["logits"],
                dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
                opset_version=17, do_constant_folding=True
            )
    except Exception as ex:
        logger.exception("ONNX export failed for %s: %s", name, ex)
        # A truncated file would otherwise be treated as a valid cache hit on
        # the next run and silently skip re-export.
        onnx_path.unlink(missing_ok=True)
        raise
    logger.info("[save] Wrote %s", onnx_path)
    return onnx_path


def log_baseline(name: str, onnx_path: Path):
    """Benchmark PyTorch inference for *name* and append one row per device.

    Benchmarks run on CPU, and also on CUDA when it is available. Each run
    uses a fresh model loaded from the saved ``.pt`` weights, evaluated with
    ``benchmark_dataloader`` over ``test_loader``. The rows are appended to
    ``out_train_csv`` using ``CSV_SCHEMA`` only after all devices have been
    benchmarked.

    ``runs`` in the CSV is the value actually passed to the benchmark,
    ``min(defs['runs'], len(test_loader))``, not the configured value.
    ``model_hash`` is the SHA-256 of the exported ONNX file at *onnx_path*.
    """
    # Load the checkpoint once on CPU. load_state_dict copies the tensors into
    # each freshly built model, so mapping them to CUDA first was wasted work.
    state = torch.load(models_pt_dir / models_cfg[name]['file_pt'], map_location='cpu')
    # These values do not depend on the device, so compute them once.
    model_hash = sha256_file(str(onnx_path))
    commit = git_commit_short()
    # Record the iteration count actually used so the CSV matches the measurement.
    runs = min(defs['runs'], len(test_loader))
    devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])

    rows = []
    for dev in devices:
        model = build_model(name)
        model.load_state_dict(state['model_state_dict'])
        # Defensive: benchmark in inference mode (benchmark_dataloader may
        # already do this; eval() is idempotent).
        model.eval()
        device = torch.device(dev)
        metrics = benchmark_dataloader(
            model, test_loader, device=device,
            warmup=defs['warmup'], runs=runs
        )
        rows.append({
            'ts': utc_timestamp(), 'exp_id': 'train-baseline', 'model': name, 'dataset': defs['dataset'],
            'precision': defs['precision'], 'engine': 'pytorch', 'provider': dev.upper(), 'batch': defs['batch'],
            'warmup': defs['warmup'], 'runs': runs, 'lat_ms_mean': metrics['lat_ms_mean'],
            'lat_ms_p95': metrics['lat_ms_p95'], 'thr_ips': metrics['thr_ips'], 'acc': metrics['acc'],
            'energy_j': '', 'device_name': torch.cuda.get_device_name(0) if dev == 'cuda' else 'CPU',
            'driver_ver': nvidia_driver_version() if dev == 'cuda' else 'N/A', 'commit': commit,
            'model_hash': model_hash
        })
    for r in rows:
        csv_append_row(str(out_train_csv), r, CSV_SCHEMA)


# Main loop: for every model in the config (in config order), train it or
# reload it from the cache, export it to ONNX (or reuse the cached file), then
# append its PyTorch inference baselines to the training CSV.
# NOTE: `name`, `model` and `onnx_path` stay bound as notebook globals after
# the loop (they hold the last model's values).
for name in models_cfg.keys():
    model = train_and_save(name)  # returned on CPU by train_and_save
    onnx_path = export_onnx(name, model)
    log_baseline(name, onnx_path)
print('Done.')

2025-08-13 10:57:34 | INFO | nb01 | Config loaded. batch=64, epochs=20, seed=42
INFO:utils.data_utils:Dataset CIFAR-10: using cached data at C:\Users\padul\OneDrive\Universidad\Doctorado\Desarrollo\federated-lab-multihw\data
2025-08-13 10:57:35 | INFO | nb01 | [train] Training cnn on cuda with input shape [3, 32, 32]
2025-08-13 10:57:48 | INFO | train_utils | Epoch 1/20 | Loss: 1.4993 | Acc: 0.4533 | Time: 12.3s
2025-08-13 10:57:59 | INFO | train_utils | Epoch 2/20 | Loss: 1.1713 | Acc: 0.5813 | Time: 11.8s
2025-08-13 10:58:11 | INFO | train_utils | Epoch 3/20 | Loss: 1.0425 | Acc: 0.6286 | Time: 11.7s
2025-08-13 10:58:23 | INFO | train_utils | Epoch 4/20 | Loss: 0.9723 | Acc: 0.6578 | Time: 11.9s
2025-08-13 10:58:35 | INFO | train_utils | Epoch 5/20 | Loss: 0.9042 | Acc: 0.6834 | Time: 11.5s
2025-08-13 10:58:47 | INFO | train_utils | Epoch 6/20 | Loss: 0.8544 | Acc: 0.7010 | Time: 12.0s
2025-08-13 10:58:58 | INFO | train_utils | Epoch 7/20 | Loss: 0.8173 | Acc: 0.7177 | Time: 11.4s
20

Done.
