# DQN (TorchRL) — CartPole demo

In [None]:

import torch
from agentslab.utils.device import get_device
from agentslab.utils.seeding import set_seed

# Notebook-wide device preference: "cuda", "mps" or "cpu".
# get_device() resolves the preference to the device actually used below.
PREFERRED_DEVICE = "cuda"
device = get_device(PREFERRED_DEVICE)
print(f"Using device: {device}")

# One seed shared by every later cell (training, evaluation, rollout).
SEED = 0
set_seed(SEED)


In [None]:

from agentslab.runner.train import DQNConfig, train_dqn

# Training configuration for the CartPole DQN demo.
cfg = DQNConfig(
    env_id="CartPole-v1",
    seed=SEED,
    total_frames=20_000,
    frames_per_batch=512,
    optim_steps_per_batch=2,
    batch_size=256,
)
cfg.log_root = "logs"
cfg.ckpt_root = "checkpoints"
cfg.run_name = "dqn_cartpole"

# Epsilon-greedy exploration, annealed from eps_init down to eps_end.
cfg.eps_init = 1.0
cfg.eps_end = 0.05
# The annealing horizon must fit inside the run, otherwise epsilon never
# reaches eps_end and most of training stays near-random. Here it anneals over
# the first half of training.
# NOTE(review): assumes epsilon_steps is counted in collected frames, like
# total_frames. TODO confirm in DQNConfig.
cfg.epsilon_steps = cfg.total_frames // 2

cfg.lr = 1e-3
# NOTE(review): confirm DQNConfig's convention. TorchRL's SoftUpdate accepts
# eps=0.995 (target keeps 99.5% per update) or tau=1-eps=0.005. If this field
# is forwarded as SoftUpdate(tau=...), 0.995 is almost a hard copy per update.
cfg.tau = 0.995
cfg


## Обучение

In [None]:

# Run training. The returned mapping is read below for the "log_dir" and
# "ckpt_path" entries.
result = train_dqn(cfg, device=device, record_video=False)
for label, key in (("Logs at:", "log_dir"), ("Checkpoint:", "ckpt_path")):
    print(label, result[key])


## Оценка

In [None]:
from agentslab.runner.eval import evaluate_policy
from agentslab.models.value import build_qvalue_actor
from agentslab.models.policy import build_dqn_policy
from agentslab.envs.gym import make_gym_env
import torch

# Rebuild the policy structure. The temporary env is only needed for its action
# spec, and it is closed even if building the networks fails.
# NOTE(review): hidden=(256, 256) must match the architecture used in training,
# otherwise load_state_dict below fails. TODO confirm against DQNConfig.
tmp_env = make_gym_env(cfg.env_id, seed=SEED)
try:
    n_actions = tmp_env.action_spec.space.n
    value_net, q_head = build_qvalue_actor(n_actions, hidden=(256, 256))
    greedy, _ = build_dqn_policy(value_net, q_head, action_spec=tmp_env.action_spec)
finally:
    tmp_env.close()

# Load the trained weights. map_location lets a checkpoint saved on a different
# device be read here.
state = torch.load(result["ckpt_path"], map_location=device)
value_net.load_state_dict(state["value_net"])
# load_state_dict copies values into the modules' existing parameters without
# moving them. The freshly built policy is presumably on CPU, so move it to
# `device` explicitly. eval() disables any train-only behaviour.
greedy.eval().to(device)

metrics = evaluate_policy(greedy, cfg.env_id, device=device, episodes=5, seed=SEED)
metrics

## Построение графиков по логам

In [None]:

from agentslab.runner.curves import plot_from_csv

# Plot learning curves from the CSV logs in this run's log directory.
log_dir = result["log_dir"]
plot_from_csv(log_dir)


## Одна симуляция с визуализацией

In [None]:

from agentslab.envs.gym import make_gym_env
from tensordict.nn import TensorDictSequential as Seq
import torch, time

# Safety cap on episode length (CartPole-v1 itself truncates at 500 steps).
MAX_STEPS = 10_000

# Play one greedy episode with on-screen rendering.
env = make_gym_env(cfg.env_id, seed=SEED, render_mode="human")
greedy.eval().to(device)
try:
    with torch.no_grad():
        # A TorchRL env is stepped with a TensorDict, not a raw action tensor,
        # and it writes the transition (reward, done, ...) under "next".
        # rollout() resets the env, alternates policy/step, advances
        # next -> current, and stops at the first done (terminated or truncated).
        rollout = env.rollout(
            max_steps=MAX_STEPS,
            policy=greedy,
            auto_cast_to_device=True,  # policy lives on `device`; env may be on CPU
            break_when_any_done=True,
        )
finally:
    env.close()

print("Episode length:", rollout.shape[-1])
print("Episode return:", rollout.get(("next", "reward")).sum().item())


## Поиск последнего checkpoint (для восстановления модели)

In [None]:

import glob, os, torch
from agentslab.runner.checkpointer import latest_run_dir

# Locate the most recent run directory, then the newest checkpoint in it.
run_dir = latest_run_dir(cfg.ckpt_root, cfg.run_name)
print("Latest run dir:", run_dir)
# Order by modification time rather than by name. Lexicographic order misranks
# unpadded step numbers, e.g. "step_9000.pt" sorts after "step_10000.pt".
ckpts = sorted(glob.glob(os.path.join(run_dir, "*.pt")), key=os.path.getmtime)
if not ckpts:
    raise FileNotFoundError(f"No *.pt checkpoints found in {run_dir!r}")
ckpts[-1]
