# 02 – Train Baselines


In [None]:
import torch
from pathlib import Path
from sitpath_eval.utils.device import get_device, print_device_info
from sitpath_eval.models import CoordGRU, CoordTransformer, RasterGRU, SocialLSTM
from sitpath_eval.train.fairness import count_trainable_params

device = get_device('train')
print_device_info(device)

In [None]:
def tiny_dataset(batch=4):
    """Return one random ``(x, y)`` batch for smoke-testing the training loop.

    Args:
        batch: Size of the leading (batch) dimension of both tensors.

    Returns:
        A tuple ``(x, y)`` of standard-normal tensors created on the
        module-level ``device``: ``x`` has shape ``(batch, 8, 2)`` and ``y``
        has shape ``(batch, 12, 2)``.
    """
    input_shape = (batch, 8, 2)
    target_shape = (batch, 12, 2)
    # Sample the inputs before the targets so the random stream is consumed
    # in the same order on every call.
    inputs = torch.randn(input_shape, device=device)
    targets = torch.randn(target_shape, device=device)
    return inputs, targets

def train_model(model, data=None, epochs=2, lr=1e-3):
    """Fit ``model`` on a single batch and report its displacement errors.

    The model is optimised with Adam on an MSE loss for ``epochs`` full-batch
    steps, printing the loss of every step. Afterwards the trained model is
    evaluated on the same batch, without gradients and in eval mode, to
    measure the Euclidean distance between predictions and targets.

    Args:
        model: A ``torch.nn.Module`` called as ``model(x)`` whose output can be
            compared element-wise with ``y``.
        data: Optional ``(x, y)`` tensor pair to train on. When ``None`` a
            fresh batch from ``tiny_dataset()`` is used.
        epochs: Number of optimisation steps over the batch.
        lr: Learning rate passed to Adam.

    Returns:
        A dict with two floats:
        ``'ADE'`` -- mean Euclidean error over every predicted point, and
        ``'FDE'`` -- mean Euclidean error at the last index of the
        second-to-last axis.
        NOTE(review): this assumes predictions are laid out as
        ``(..., time, coords)`` like the targets -- confirm for every model.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()
    x, y = tiny_dataset() if data is None else data

    # Remember the caller's mode so it can be restored after evaluation.
    was_training = model.training
    model.train()
    for epoch in range(epochs):
        opt.zero_grad()
        preds = model(x)
        loss = loss_fn(preds, y)
        loss.backward()
        opt.step()
        print('Epoch', epoch + 1, 'loss', float(loss.item()))

    # Measure the metrics on the weights *after* the last update; the training
    # loss above is an MSE taken before each step, not a displacement error.
    model.eval()
    with torch.no_grad():
        errors = torch.linalg.vector_norm(model(x) - y, dim=-1)
    model.train(was_training)

    return {
        'ADE': float(errors.mean()),
        'FDE': float(errors[..., -1].mean()),
    }


In [None]:
import json

# Make sure the log directory exists before any training output is produced.
Path('artifacts/logs').mkdir(parents=True, exist_ok=True)

# Train every baseline once and collect its metrics under the class name.
results = {}
for cls in (CoordGRU, CoordTransformer, RasterGRU, SocialLSTM):
    model = cls().to(device)
    print('Training', cls.__name__)
    metrics = train_model(model)
    results[cls.__name__] = metrics
    print('Params:', count_trainable_params(model))

# Persist all collected metrics as a single pretty-printed JSON document.
Path('artifacts/logs/baselines.json').write_text(json.dumps(results, indent=2))

✅ Notebook complete
