# 04 — Train Controller (CMA-ES)

Train a linear controller with CMA-ES that maps the feature vector [z, h] — the VAE latent z and the MDN-RNN hidden state h — to an action.

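**Prerequisites:**
1. Train the VAE (notebook 02); this notebook loads `vae_final.pt` from the VAE checkpoint directory.
2. Train the MDN-RNN (notebook 03); this notebook loads `rnn_final.pt` from the RNN checkpoint directory.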

In [None]:
import sys
from pathlib import Path

import torch

sys.path.insert(0, str(Path.cwd().parent))
from src.config import Config
from src.controller import Controller
from src.mdn_rnn import MDNRNN
from src.train_controller import train_controller
from src.vae import ConvVAE

# Environment rollouts dominate controller training and are CPU-bound,
# so the device is pinned to the CPU for this notebook.
config = Config()
config.device = "cpu"

# Report the size of the parameter vector exposed by a fresh Controller.
num_params = Controller().num_params
print(f"Controller params: {num_params}")

In [None]:
# Restore the trained VAE and MDN-RNN from their final checkpoints.
def _restore(model, checkpoint_path):
    """Load CPU-mapped weights from *checkpoint_path* into *model*.

    The model is switched to evaluation mode and returned.
    """
    state_dict = torch.load(
        checkpoint_path,
        map_location="cpu",
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Checkpoint directories are resolved relative to the parent of the
# current working directory.
base = Path.cwd().parent

vae = _restore(
    ConvVAE(latent_dim=config.vae.latent_dim),
    base / config.vae.checkpoint_dir / "vae_final.pt",
)

rnn = _restore(
    MDNRNN(
        latent_dim=config.rnn.latent_dim,
        action_dim=config.rnn.action_dim,
        hidden_dim=config.rnn.hidden_dim,
        num_gaussians=config.rnn.num_gaussians,
    ),
    base / config.rnn.checkpoint_dir / "rnn_final.pt",
)
print("Models loaded.")

In [None]:
# Smaller settings so a local test run finishes quickly: fewer candidates
# per generation, fewer rollouts, and fewer generations.
local_test_overrides = {
    "population_size": 16,
    "num_rollouts": 4,
    "max_generations": 10,
}
for setting_name, setting_value in local_test_overrides.items():
    setattr(config.controller, setting_name, setting_value)

# Run training with the reduced settings and the loaded VAE / RNN.
best_controller = train_controller(config.controller, config, vae, rnn)

In [None]:
# Quick evaluation
import numpy as np

from src.rollout import rollout_episode

def _evaluate_seed(seed):
    """Run one episode with *seed*, print its reward, and return it."""
    r = rollout_episode(vae, rnn, best_controller, seed=seed)
    print(f"Episode {seed}: reward = {r:.1f}")
    return r


# One episode per seed 0..9, collected in seed order.
rewards = [_evaluate_seed(seed) for seed in range(10)]

# Summarise the run as mean ± standard deviation of the episode rewards.
print(f"\nMean: {np.mean(rewards):.1f} ± {np.std(rewards):.1f}")