# 05 — Evaluation & Analysis

Comprehensive evaluation of the trained World Model agent.

**Prerequisites:** Complete notebooks 02–04 first, so that the trained VAE, MDN-RNN, and controller checkpoints loaded below are available.

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

sys.path.insert(0, str(Path.cwd().parent))
from src.config import Config
from src.controller import Controller
from src.mdn_rnn import MDNRNN
from src.rollout import rollout_episode
from src.vae import ConvVAE

# Project root: the parent of the notebook's working directory. Checkpoint
# paths in the cells below are resolved relative to this directory.
base = Path.cwd().parent

# Project configuration object; later cells read the per-model settings
# (config.vae, config.rnn, config.controller) from it.
config = Config()

In [None]:
# Restore the three trained components (VAE, MDN-RNN, controller) on the CPU.


def _load_checkpoint(path):
    """Read a checkpoint file onto the CPU using weights-only loading."""
    return torch.load(path, map_location="cpu", weights_only=True)


# --- VAE ---
vae = ConvVAE(latent_dim=config.vae.latent_dim)
vae_path = base / config.vae.checkpoint_dir / "vae_final.pt"
vae.load_state_dict(_load_checkpoint(vae_path))
vae.eval()

# --- MDN-RNN ---
rnn = MDNRNN(
    latent_dim=config.rnn.latent_dim,
    action_dim=config.rnn.action_dim,
    hidden_dim=config.rnn.hidden_dim,
    num_gaussians=config.rnn.num_gaussians,
)
rnn_path = base / config.rnn.checkpoint_dir / "rnn_final.pt"
rnn.load_state_dict(_load_checkpoint(rnn_path))
rnn.eval()

# --- Controller ---
# The controller checkpoint is a mapping holding the parameters under
# "params" and the associated training reward under "reward".
controller = Controller(
    latent_dim=config.controller.latent_dim,
    hidden_dim=config.controller.hidden_dim,
    action_dim=config.controller.action_dim,
)
controller_path = base / config.controller.checkpoint_dir / "controller_best.pt"
ckpt = _load_checkpoint(controller_path)
controller.set_params(ckpt["params"])
controller.eval()
print(f"Best training reward: {ckpt['reward']:.1f}")

In [None]:
# Run the evaluation episodes, one per seed (seeds 0 .. n_episodes - 1).
# The episode count is defined once so the loop bound and the progress
# read-out cannot drift apart when it is changed.
n_episodes = 100
report_every = 10  # print a running mean after every `report_every` episodes

rewards = []
for seed in range(n_episodes):
    # presumably returns the episode's total reward as a scalar; verify
    # against src.rollout.rollout_episode
    r = rollout_episode(vae, rnn, controller, seed=seed)
    rewards.append(r)
    if (seed + 1) % report_every == 0:
        print(f"  [{seed + 1}/{n_episodes}] running mean: {np.mean(rewards):.1f}")

# Summary statistics over all episodes. np.std uses its default ddof=0,
# i.e. the population standard deviation.
print(f"\n{'=' * 40}")
print(f"Mean reward: {np.mean(rewards):.1f} ± {np.std(rewards):.1f}")
print(f"Median: {np.median(rewards):.1f}")
print(f"Min: {np.min(rewards):.1f}, Max: {np.max(rewards):.1f}")
print("Paper target: ~906")

In [None]:
# Reward distribution: histogram of episode rewards (left) and the
# per-episode reward trace (right), each with the mean marked in red.
mean_reward = np.mean(rewards)  # computed once, reused for both panels
paper_reward = 906  # reference score shown as the green "Paper" line

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: histogram with the observed mean and the paper reference.
ax1.hist(rewards, bins=20, edgecolor="black", alpha=0.7)
ax1.axvline(
    mean_reward, color="red", linestyle="--", label=f"Mean: {mean_reward:.0f}"
)
ax1.axvline(
    paper_reward, color="green", linestyle="--", label=f"Paper: {paper_reward}"
)
ax1.set_title("Reward Distribution")
ax1.set_xlabel("Total Reward")
# Label the y axis as well, matching the second panel which labels both axes.
ax1.set_ylabel("Number of Episodes")
ax1.legend()

# Right panel: reward of each episode in evaluation order.
ax2.plot(rewards, alpha=0.5)
ax2.axhline(mean_reward, color="red", linestyle="--")
ax2.set_title("Reward per Episode")
ax2.set_xlabel("Episode")
ax2.set_ylabel("Total Reward")

plt.tight_layout()
plt.show()