# Train a GRU to Generate Human-like Mouse Trajectories

This notebook:
- Loads all `../data/session_*.jsonl` recordings.
- Normalizes each trajectory: translates the start point to (0,0), rotates so the target lies at angle 0, and scales so the target sits at (1,0).
- Resamples each path to a fixed length (e.g., 64 steps).
- Trains a small GRU to predict the next (x,y) given the current (x,y) and normalized time `u`.
- Generates 100 trajectories towards random canvas targets (750x550) and visualizes them.

Requirements: `torch` (installed via `training/requirements.txt`). The repo `.venv` already has numpy/matplotlib.


In [None]:
import os, json, math, glob
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE


In [None]:
# Recordings are looked up in `../data` relative to the current working directory.
DATA_DIR = Path.cwd().joinpath('..', 'data').resolve()
files = sorted(DATA_DIR.glob('session_*.jsonl'))
if len(files) == 0:
    # Fail early: nothing downstream can run without recordings.
    raise FileNotFoundError(f'No session_*.jsonl in {DATA_DIR}. Collect data via web_capture first.')
len(files), files[-3:]


In [None]:
# Utilities: normalization and resampling
def to_arrays(path):
    """Convert a recorded path into time/x/y arrays with strictly increasing time.

    Args:
        path: sequence of mappings, each holding numeric 't', 'x' and 'y' entries.

    Returns:
        Tuple ``(t, x, y)`` of float arrays of equal length. ``t`` is shifted so
        that the smallest recorded timestamp becomes 0, and every sample whose
        timestamp does not exceed all earlier timestamps is dropped, so the
        returned ``t`` is strictly increasing. An empty ``path`` yields three
        empty arrays.
    """
    t = np.array([p['t'] for p in path], dtype=float)
    x = np.array([p['x'] for p in path], dtype=float)
    y = np.array([p['y'] for p in path], dtype=float)
    if t.size == 0:
        return t, x, y
    t = t - t.min()
    # Compare each sample with the running maximum of the samples before it,
    # not just its immediate predecessor: with timestamps such as 0, 5, 3, 4
    # a plain diff(t) > 0 test would keep 4 after 5 and hand a non-monotonic
    # time axis to the interpolation step.
    running_max = np.maximum.accumulate(t)
    keep = np.r_[True, t[1:] > running_max[:-1]]
    return t[keep], x[keep], y[keep]

def normalize_to_target(t, x, y, target):
    """Express a path in a start-centred, target-aligned, unit-distance frame.

    The path is shifted so its first point is the origin, rotated so the
    start-to-target direction is the +x axis, and divided by the
    start-to-target distance.

    Args:
        t: unused here; kept for signature compatibility.
        x, y: coordinate arrays of the path.
        target: mapping with optional 'x' / 'y'; a missing coordinate falls
            back to the path's first point.

    Returns:
        ``(xnorm, ynorm, D, ang, (x0, y0))`` where ``D`` is the scale that was
        divided out (1.0 when the target coincides with the start), ``ang`` is
        the start-to-target angle in radians and ``(x0, y0)`` is the first
        point of the input path.
    """
    origin_x, origin_y = x[0], y[0]
    dx, dy = x - origin_x, y - origin_y

    goal_dx = float(target.get('x', origin_x)) - origin_x
    goal_dy = float(target.get('y', origin_y)) - origin_y

    dist = math.hypot(goal_dx, goal_dy)
    if dist == 0:
        # avoid dividing by zero when the target sits on the start point
        dist = 1.0

    ang = math.atan2(goal_dy, goal_dx)
    # rotate by -ang so the target direction lands on the +x axis
    cos_a, sin_a = math.cos(-ang), math.sin(-ang)
    x_rot = cos_a*dx - sin_a*dy
    y_rot = sin_a*dx + cos_a*dy
    return x_rot / dist, y_rot / dist, dist, ang, (origin_x, origin_y)

def resample_fixed(t, x, y, steps=64):
    """Linearly resample a path onto ``steps`` evenly spaced normalized times.

    Args:
        t, x, y: time and coordinate arrays of the path.
        steps: number of output samples.

    Returns:
        ``(u, x_i, y_i)`` where ``u`` runs evenly from 0 to 1 and ``x_i`` /
        ``y_i`` are the interpolated coordinates. A path with fewer than two
        points produces all-zero coordinates.
    """
    u = np.linspace(0, 1, steps)
    if len(t) < 2:
        return u, np.zeros(steps), np.zeros(steps)

    duration = t[-1] - t[0]
    if duration <= 0:
        # degenerate time span: fall back to a unit duration
        duration = 1.0
    progress = (t - t[0]) / duration
    return u, np.interp(u, progress, x), np.interp(u, progress, y)

def inverse_transform(xn, yn, D, ang, origin):
    """Map normalized coordinates (target at (1, 0)) back to canvas coordinates.

    Scales by ``D``, rotates by ``ang`` radians, then shifts by ``origin``.

    Returns:
        ``(x, y)`` in canvas space.
    """
    cos_a = math.cos(ang)
    sin_a = math.sin(ang)
    x_scaled, y_scaled = xn * D, yn * D
    x_canvas = cos_a*x_scaled - sin_a*y_scaled + origin[0]
    y_canvas = sin_a*x_scaled + cos_a*y_scaled + origin[1]
    return x_canvas, y_canvas


In [None]:
# Load and prepare dataset
# Every non-empty line of each session file is parsed as one JSON record.
# A record is kept only if it has a path of at least 5 points, a usable target
# and finite values; everything else is skipped without raising.
STEPS = 64
samples = []  # each: dict with 'u' (STEPS,) and 'xy' (STEPS x 2), both float32
count_lines = 0  # every line read, including blank and rejected ones
for fp in files:
    with open(fp, 'r', encoding='utf-8') as f:
        for line in f:
            count_lines += 1
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError
                continue
            if not isinstance(rec, dict):
                continue
            path = rec.get('path') or []
            if not isinstance(path, list) or len(path) < 5:  # missing or too short
                continue
            # A usable target is mandatory: normalize_to_target silently falls
            # back to scale 1 / angle 0 when the target is missing, which would
            # feed an un-normalized (pixel-unit) path into training.
            target = rec.get('target')
            try:
                goal = {'x': float(target['x']), 'y': float(target['y'])}
            except (TypeError, KeyError, ValueError):
                continue
            if not (math.isfinite(goal['x']) and math.isfinite(goal['y'])):
                continue
            t, x, y = to_arrays(path)
            # resample_fixed returns an all-zero path for fewer than two
            # points; do not let such placeholders into the training set.
            if len(t) < 2:
                continue
            # basic sanity
            if not np.isfinite([t, x, y]).all():
                continue
            u, xi, yi = resample_fixed(t, x, y, steps=STEPS)
            # Target on top of the start point: direction and scale are undefined.
            if math.hypot(goal['x'] - xi[0], goal['y'] - yi[0]) == 0:
                continue
            xn, yn, D, ang, origin = normalize_to_target(t, xi, yi, goal)
            if np.isnan(xn).any() or np.isnan(yn).any():
                continue
            xy = np.stack([xn, yn], axis=1)
            samples.append({'u': u.astype(np.float32), 'xy': xy.astype(np.float32)})
print(f'Read {count_lines} lines, kept {len(samples)} trajectories')
len(samples)


In [None]:
# Torch Dataset
class TrajDataset(Dataset):
    """Next-step prediction pairs built from fixed-length trajectories.

    Each stored sample is a dict with 'u' of shape (S,) and 'xy' of shape
    (S, 2). An item pairs rows 0..S-2 of (x, y, u) with rows 1..S-1 of (x, y).
    """

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        record = self.samples[idx]
        progress = record['u']     # (S,)
        positions = record['xy']   # (S, 2)
        # model input: current position plus normalized time, last step dropped
        features = np.concatenate(
            (positions[:-1], progress[:-1, None]),
            axis=1,
        )  # (S-1, 3)
        # training target: the position one step later
        next_positions = positions[1:]  # (S-1, 2)
        return torch.from_numpy(features), torch.from_numpy(next_positions)

# 90/10 train/validation split; the fixed generator seed makes it reproducible.
ds = TrajDataset(samples)
n_train = int(0.9*len(ds))
n_val = len(ds) - n_train
train_ds, val_ds = torch.utils.data.random_split(
    ds,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(0),
)
len(train_ds), len(val_ds)


In [None]:
BATCH = 128
# drop_last keeps every training batch the same size, but with fewer than
# BATCH training samples it would discard the only (partial) batch: the loader
# would then yield nothing and training would silently do no work. Only drop
# the remainder when at least one full batch exists.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH,
    shuffle=True,
    drop_last=len(train_ds) >= BATCH,
)
val_loader = DataLoader(val_ds, batch_size=BATCH, shuffle=False, drop_last=False)
# Peek at a single batch (one iterator, one draw) to confirm tensor shapes.
_peek_inp, _peek_tgt = next(iter(train_loader))
_peek_inp.shape, _peek_tgt.shape


In [None]:
# Model: GRU that maps (x,y,u) -> next (x,y)
class TrajGRU(nn.Module):
    """Stacked GRU followed by a small MLP head.

    The GRU consumes a batch-first sequence of feature rows; the head maps
    every GRU output vector to a 2-value prediction.
    """

    def __init__(self, input_size=3, hidden=128, layers=2):
        super().__init__()
        self.gru = nn.GRU(
            input_size,
            hidden,
            num_layers=layers,
            batch_first=True,
        )
        # hidden -> 64 -> 2, with a ReLU in between
        self.head = nn.Sequential(
            nn.Linear(hidden, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
        )

    def forward(self, x, h=None):
        # x: (B, T, input_size); h is passed straight to the GRU as its
        # initial hidden state (None lets the GRU use its default).
        seq_features, h_next = self.gru(x, h)
        return self.head(seq_features), h_next

# Network on the selected device, Adam at lr=1e-3, mean-squared-error objective.
model = TrajGRU().to(DEVICE)
opt = torch.optim.Adam(params=model.parameters(), lr=1e-3)
crit = nn.MSELoss(reduction='mean')
model


In [None]:
# Train loop
# Each epoch runs one optimisation pass over train_loader and one evaluation
# pass over val_loader, records both losses, and the curves are plotted at the
# end. Losses are accumulated per sample (batch loss x batch size) so that a
# short final batch is not over-weighted in the epoch mean; this assumes
# `crit` averages over the batch, as nn.MSELoss does by default.
EPOCHS = 20
train_hist, val_hist = [], []
for epoch in range(1, EPOCHS+1):
    model.train()
    train_sum, train_n = 0.0, 0
    for inp, tgt in train_loader:
        inp = inp.to(DEVICE).float()
        tgt = tgt.to(DEVICE).float()
        opt.zero_grad()
        pred, _ = model(inp)  # (B, T, 2)
        loss = crit(pred, tgt)
        loss.backward()
        # cap the global gradient norm at 1.0 before the optimizer step
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        train_sum += loss.item() * inp.size(0)
        train_n += inp.size(0)
    train_loss = train_sum / max(1, train_n)  # max() guards an empty loader
    # val
    model.eval()
    val_sum, val_n = 0.0, 0
    with torch.no_grad():
        for inp, tgt in val_loader:
            inp = inp.to(DEVICE).float()
            tgt = tgt.to(DEVICE).float()
            pred, _ = model(inp)
            loss = crit(pred, tgt)
            val_sum += loss.item() * inp.size(0)
            val_n += inp.size(0)
    val_loss = val_sum / max(1, val_n)
    train_hist.append(train_loss); val_hist.append(val_loss)
    print(f'Epoch {epoch:02d}  train {train_loss:.4f}  val {val_loss:.4f}')

plt.figure(figsize=(6,3))
plt.plot(train_hist, label='train')
plt.plot(val_hist, label='val')
plt.xlabel('epoch'); plt.ylabel('MSE'); plt.legend(); plt.title('Training curve')
plt.tight_layout(); plt.show()


In [None]:
# Save checkpoints to training/checkpoints/
# When the working directory has no `training` sub-directory, the checkpoints
# folder is created directly under the working directory instead.
from pathlib import Path
ROOT = Path.cwd()
TRAIN_DIR = ROOT / 'training'
if not TRAIN_DIR.exists():
    TRAIN_DIR = ROOT
CKPT_DIR = TRAIN_DIR / 'checkpoints'
CKPT_DIR.mkdir(parents=True, exist_ok=True)
last_path = CKPT_DIR / 'last.pt'
# only the model weights are stored, under the 'model' key
checkpoint = {'model': model.state_dict()}
torch.save(checkpoint, last_path)
print(f'Saved last checkpoint to {last_path}')


In [None]:
# Optional: load a checkpoint before generation (set LOAD accordingly)
from pathlib import Path
LOAD = 'last'  # 'last', 'best', a file path, or None to skip
# NOTE(review): no cell visible in this notebook writes best.pt; 'best' only
# works if that file was produced elsewhere.
ROOT = Path.cwd()
TRAIN_DIR = ROOT / 'training' if (ROOT / 'training').exists() else ROOT
CKPT_DIR = TRAIN_DIR / 'checkpoints'
if LOAD in ('last', 'best'):
    path = CKPT_DIR / f'{LOAD}.pt'
elif LOAD:
    path = Path(LOAD)
else:
    path = None
if path is None:
    print('No checkpoint loaded')
elif not path.exists():
    # A checkpoint was requested but is absent: say so explicitly, otherwise
    # generation would quietly run with whatever weights are in memory.
    print(f'No checkpoint loaded: {path} does not exist; keeping current model weights')
else:
    # torch.load unpickles the file, and LOAD may point at an arbitrary path.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all this checkpoint format ({'model': state_dict}) needs.
    state = torch.load(path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state['model'])
    print(f'Loaded checkpoint from {path}')


In [None]:
# Generation: sample 100 random targets on 750x550 canvas
CANVAS_W, CANVAS_H = 750, 550
RADIUS = 12
MARGIN = RADIUS + 10
def sample_start_target():
    """Draw a random start point and a random target inside the canvas margins.

    The target is redrawn (up to 100 draws in total) until it is at least
    ``5*RADIUS`` pixels from the start; the 100th draw is returned regardless.

    Returns:
        ``((sx, sy), (tx, ty))`` in canvas pixels.
    """
    min_sq_dist = (5*RADIUS)**2
    sx = np.random.randint(MARGIN, CANVAS_W - MARGIN)
    sy = np.random.randint(MARGIN, CANVAS_H - MARGIN)
    draws = 0
    while True:
        tx = np.random.randint(MARGIN, CANVAS_W - MARGIN)
        ty = np.random.randint(MARGIN, CANVAS_H - MARGIN)
        draws += 1
        far_enough = (tx-sx)**2 + (ty-sy)**2 >= min_sq_dist
        if far_enough or draws == 100:
            return (sx, sy), (tx, ty)

def generate_norm(model, steps=STEPS):
    """Roll the model forward one step at a time in normalized coordinates.

    Starts at (0, 0) and, for each step, feeds the current position and
    normalized time into the model, storing its output as the next position.
    The recurrent state is carried from step to step.

    Returns:
        ``(u, xn, yn)`` float32 arrays of length ``steps``.
    """
    model.eval()
    u = np.linspace(0, 1, steps, dtype=np.float32)
    xn = np.zeros((steps,), dtype=np.float32)
    yn = np.zeros((steps,), dtype=np.float32)
    hidden = None
    with torch.no_grad():
        for i in range(steps - 1):
            step_in = torch.tensor(
                [[[xn[i], yn[i], u[i]]]],
                dtype=torch.float32,
                device=DEVICE,
            )  # (1,1,3)
            step_out, hidden = model(step_in, hidden)  # (1,1,2)
            xn[i+1], yn[i+1] = step_out[0, 0].cpu().numpy()
    return u, xn, yn

# Generate 100
# Each entry: (x, y, start, target) with x/y in canvas pixels.
trajectories = []
for _ in range(100):
    start, goal = sample_start_target()
    (sx, sy), (tx, ty) = start, goal
    u, xn, yn = generate_norm(model, steps=STEPS)
    # scale and angle of the start-to-target vector, used to undo normalization
    D, ang = math.hypot(tx-sx, ty-sy), math.atan2(ty-sy, tx-sx)
    x, y = inverse_transform(xn, yn, D, ang, start)
    trajectories.append((x, y, start, goal))
len(trajectories)


In [None]:
# Visualize generated trajectories
plt.figure(figsize=(8,6))
ax = plt.gca()
for x, y, _start, _goal in trajectories:
    ax.plot(x, y, color='tab:purple', alpha=0.2, lw=1.8)
# mark start (green) and target (red) for a subset of the trajectories
stride = max(1, len(trajectories)//10)
for _x, _y, (sx, sy), (tx, ty) in trajectories[::stride]:
    ax.scatter([sx, tx], [sy, ty], s=12, c=['#22c55e','#ef4444'])
ax.set_xlim(0, CANVAS_W)
ax.set_ylim(0, CANVAS_H)
ax.invert_yaxis()  # canvas y-down convention
ax.set_aspect('equal', adjustable='box')
ax.set_title('Generated trajectories to random targets (purple)')
ax.set_xlabel('x (px)')
ax.set_ylabel('y (px)')
plt.tight_layout()
plt.show()
