In [None]:
# SPDX-FileCopyrightText: 2026 Mario Gemoll
# SPDX-License-Identifier: 0BSD

import os
import subprocess


def is_correct_repo() -> bool:
    """Report whether the current directory is inside a clone of this project.

    Asks git for the URL of the ``origin`` remote and compares it with the
    HTTPS and SSH URLs of the ``mariogemoll/reinforcement-learning``
    repository. Surrounding whitespace, a trailing slash and the optional
    ``.git`` suffix are ignored, so a clone made without ``.git`` in the URL
    is still recognised.

    Returns:
        True if ``origin`` points at the project repository. False otherwise,
        including when the git command exits with an error or the ``git``
        executable cannot be found.
    """
    known_remotes = {
        "https://github.com/mariogemoll/reinforcement-learning",
        "git@github.com:mariogemoll/reinforcement-learning",
    }
    try:
        result = subprocess.run(
            ["git", "remote", "get-url", "origin"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Best effort: any failure to query git means "not the right repo".
        return False

    # Normalise before comparing: the same repository can be addressed with or
    # without the ".git" suffix and with or without a trailing slash.
    remote_url = result.stdout.strip().rstrip("/")
    if remote_url.endswith(".git"):
        remote_url = remote_url[: -len(".git")]
    return remote_url in known_remotes


if not is_correct_repo():
    !git clone https://github.com/mariogemoll/reinforcement-learning.git

if not os.getcwd().endswith("reinforcement-learning/py"):
    %cd reinforcement-learning/py

In [None]:
# Install gymnax into the kernel's environment (-q keeps pip's output quiet).
# NOTE(review): the version is not pinned — consider pinning it so reruns
# resolve the same release.
%pip install -q gymnax

In [None]:
import base64
from pathlib import Path

from IPython.display import display

from dqn import fresh_params, run_config, make_model
from util import eval_dqn_max_score, plot_dqn_metrics

In [None]:
# Settings handed to run_config. Their exact meaning is defined by
# dqn.run_config — presumably learning rate, exploration decay duration,
# minibatch size, replay-buffer capacity, warm-up steps and target-update
# period; TODO confirm against dqn.py.
cfg = {
    "lr": 5e-4,
    "decay_dur": 10_000,
    "batch_size": 128,
    "buf_cap": 10_000,
    "learn_start": 10_000,
    "upd_every": 200,
}
total_steps = 100_000

ep_rets, losses, best_params = run_config(cfg, total_steps, fresh_params())

# Mean over the last (up to) 50 episode returns. If no episode finished at all
# (e.g. a very small total_steps used as a smoke test) report NaN instead of
# failing with a ZeroDivisionError.
recent_rets = ep_rets[-50:]
mean_last50 = sum(recent_rets) / len(recent_rets) if len(recent_rets) else float("nan")
print(f"Done. {len(ep_rets)} episodes, mean last-50 return: {mean_last50:.1f}")

In [None]:
# Plot the losses and episode returns that run_config returned for the
# training run above.
plot_dqn_metrics(losses, ep_rets)

In [None]:
# Evaluate the best parameters from training with a fixed seed, then print how
# many of the evaluated episodes reached the maximum score.
eval_kwargs = dict(
    num_eval_episodes=100,
    batch_size=20,
    seed=123,
    show_progress=True,
)
max_count, n_eps, max_score, max_pct = eval_dqn_max_score(best_params, **eval_kwargs)
print(f"Episodes with max score ({max_score:.0f}): {max_count}/{n_eps} ({max_pct:.1f}%)")

In [None]:
%%bash
# Build the CartPole visualization bundle from the sibling ts/ directory.
# Abort on the first failing command, unset variable, or failed pipeline stage.
set -euo pipefail

cd ../ts

# Prefer pnpm when it is on PATH; otherwise fall back to npm. Only the name of
# the install subcommand differs between the two.
if command -v pnpm >/dev/null 2>&1; then
  pkg_manager=pnpm
  install_subcmd=i
else
  pkg_manager=npm
  install_subcmd=install
fi

echo "Using package manager: ${pkg_manager}"
"${pkg_manager}" "${install_subcmd}"
"${pkg_manager}" run build:anywidget-cartpole

In [None]:
from cartpole import CartPoleVisualization, as_f32, write_safetensors

model = make_model(best_params)

# Tensor-name suffix for each exported layer.
# NOTE(review): the 0/2/4 numbering presumably matches what the visualization
# expects to find in the weights file — confirm against the widget code.
export_layers = {
    "0": model.layer1,
    "2": model.layer2,
    "4": model.output,
}

# For every layer store the kernel transposed, (in_dim, out_dim) ->
# (out_dim, in_dim), followed by the bias flattened to one dimension.
tensors = {}
for suffix, layer in export_layers.items():
    tensors[f"w{suffix}"] = as_f32(layer.kernel).T
    tensors[f"b{suffix}"] = as_f32(layer.bias).reshape(-1)

out_path = Path("dqn-weights.safetensors")
write_safetensors(out_path, tensors)
print(f"Wrote {out_path}")

In [None]:
# Read the exported weights back from disk and pass them to the widget as
# base64-encoded ASCII text.
with open("dqn-weights.safetensors", "rb") as weights_file:
    weights_bytes = weights_file.read()
weights_b64 = str(base64.b64encode(weights_bytes), "ascii")

display(CartPoleVisualization(weights_base64=weights_b64))