# Assault (Atari) con DQN + PER

Este notebook entrena un agente DQN con Prioritized Experience Replay (PER) en ALE/Assault-v5.
Incluye preprocesamiento, checkpoints, logging con TensorBoard y evaluación frente a un baseline aleatorio.

In [None]:
# Instalar dependencias (compatible con Colab)
!pip -q install gymnasium[atari,accept-rom-license] ale-py autorom torch tensorboard
!AutoROM --accept-license

import os
import json
import time
import subprocess
import numpy as np
import torch
import gymnasium as gym
import ale_py  # Registra el namespace ALE
from gymnasium.wrappers import RecordVideo

# Artifact layout: checkpoints, logs and videos all live under
# <current working directory>/artifacts/assault.
ROOT_DIR = os.getcwd()
ART_DIR = os.path.join(ROOT_DIR, "artifacts", "assault")
CKPT_DIR, LOG_DIR, VIDEO_DIR = (
    os.path.join(ART_DIR, leaf) for leaf in ("checkpoints", "logs", "videos")
)
for out_dir in (CKPT_DIR, LOG_DIR, VIDEO_DIR):
    os.makedirs(out_dir, exist_ok=True)

# Seed the NumPy and PyTorch global generators with the same value.
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Print the device and library versions for the run report.
for label, shown in (
    ("Dispositivo:", DEVICE),
    ("Torch:", torch.__version__),
    ("Gymnasium:", gym.__version__),
):
    print(label, shown)

# Best-effort hardware report: any failure to run nvidia-smi is reported as
# "not available" instead of aborting the cell.
try:
    print(subprocess.check_output(["nvidia-smi"]).decode())
except Exception:
    print("nvidia-smi no disponible")

AutoROM will download the Atari 2600 ROMs.
They will be installed to:
	/usr/local/lib/python3.12/dist-packages/AutoROM/roms

Existing ROMs will be overwritten.


In [None]:
# Implementaciones locales (sin importar src/)
import math
from collections import deque
import torch.nn as nn
import torch.nn.functional as F
from gymnasium.wrappers import AtariPreprocessing, TransformReward


class SimpleFrameStack(gym.Wrapper):
    """Wrapper that returns the last ``num_stack`` observations as one array.

    Observations from the wrapped env are kept in a bounded deque and stacked
    along a new leading axis, so the observation space gains a first dimension
    of size ``num_stack``.
    """

    def __init__(self, env, num_stack=4):
        super().__init__(env)
        self.num_stack = num_stack
        self.frames = deque(maxlen=num_stack)
        # Repeat the wrapped space's bounds along the new leading axis.
        inner_space = env.observation_space
        self.observation_space = gym.spaces.Box(
            low=np.repeat(inner_space.low[None, ...], num_stack, axis=0),
            high=np.repeat(inner_space.high[None, ...], num_stack, axis=0),
            dtype=inner_space.dtype,
        )

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the whole stack with its first frame."""
        first_frame, info = self.env.reset(**kwargs)
        self.frames.extend([first_frame] * self.num_stack)
        return self._get_obs(), info

    def step(self, action):
        """Step the wrapped env and push the new frame (the oldest one drops out)."""
        frame, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(frame)
        stacked = self._get_obs()
        return stacked, reward, terminated, truncated, info

    def _get_obs(self):
        """Return the buffered frames stacked along axis 0, oldest first."""
        return np.stack(tuple(self.frames), axis=0)


def make_atari_env(
    game_id: str,
    seed: int,
    frame_skip: int = 4,
    clip_rewards: bool = False,
    render_mode: str | None = None,
    num_stack: int = 4,
):
    """Build a preprocessed, frame-stacked Atari environment.

    The base env is created with ``frameskip=1`` and
    ``repeat_action_probability=0.0``; frame skipping, grayscale conversion and
    resizing to 84x84 are delegated to ``AtariPreprocessing``. Observations are
    left unscaled (``scale_obs=False``).

    Args:
        game_id: Gymnasium environment id, e.g. ``"ALE/Assault-v5"``.
        seed: Seed used for the initial ``reset`` and for the action space.
        frame_skip: Number of emulator frames per agent step.
        clip_rewards: When true, rewards are clamped to ``[-1, 1]``.
        render_mode: Forwarded to ``gym.make``.
        num_stack: Number of consecutive frames stacked per observation.

    Returns:
        The wrapped environment, already reset once with ``seed``.
    """
    env = gym.make(
        game_id,
        frameskip=1,
        repeat_action_probability=0.0,
        render_mode=render_mode,
    )
    env = AtariPreprocessing(
        env,
        frame_skip=frame_skip,
        grayscale_obs=True,
        screen_size=84,
        scale_obs=False,
        terminal_on_life_loss=False,
    )
    if clip_rewards:
        env = TransformReward(env, lambda r: max(-1.0, min(1.0, r)))
    env = SimpleFrameStack(env, num_stack=num_stack)
    # reset(seed=...) does not seed the action space; seed it explicitly so
    # that action_space.sample() (used by random baselines) is reproducible.
    env.action_space.seed(seed)
    env.reset(seed=seed)
    return env


class SumTree:
    """Array-backed binary tree whose internal nodes hold the sum of their children.

    The ``capacity`` leaves occupy indices ``capacity - 1 .. 2 * capacity - 2``
    of ``self.tree`` and store priorities; index 0 is the root and therefore
    holds the total priority. Leaves are written in a ring via ``add``.

    Node sums are kept in float64 so that repeated incremental updates do not
    accumulate noticeable rounding drift between the root and the leaves.
    """

    def __init__(self, capacity: int):
        """Create an empty tree with ``capacity`` leaves.

        Raises:
            ValueError: If ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError("capacity must be a positive integer")
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1, dtype=np.float64)
        self.data_pointer = 0  # next leaf slot (0-based data index) to overwrite

    @property
    def total(self) -> float:
        """Sum of all leaf priorities (value stored at the root)."""
        return float(self.tree[0])

    def add(self, priority: float):
        """Write ``priority`` into the next ring slot and return its tree index."""
        idx = self.data_pointer + self.capacity - 1
        self.update(idx, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        return idx

    def update(self, idx: int, priority: float) -> None:
        """Set the node at tree index ``idx`` to ``priority`` and fix all ancestors."""
        change = priority - self.tree[idx]
        self.tree[idx] = priority
        # Walk up to the root. Stopping when idx reaches 0 also covers the
        # single-leaf tree, where the leaf is the root and has no parent.
        while idx > 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def get(self, s: float):
        """Find the leaf whose cumulative-priority interval contains ``s``.

        Args:
            s: A value in ``[0, total]``.

        Returns:
            ``(tree_index, priority, data_index)`` of the selected leaf.
        """
        idx = 0
        n_nodes = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= n_nodes:
                # No children: idx is a leaf.
                break
            if s <= self.tree[left]:
                idx = left
            else:
                # Skip the whole left subtree and search the right one.
                s -= self.tree[left]
                idx = left + 1
        return idx, float(self.tree[idx]), idx - self.capacity + 1


class PrioritizedReplayBuffer:
    """Ring buffer of transitions sampled with probability proportional to priority.

    Priorities live in a ``SumTree``; transitions live in ``self.data`` at the
    matching data index. New transitions get the largest priority seen so far
    so that each one is likely to be replayed at least once.
    """

    def __init__(self, capacity: int, alpha: float = 0.6):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of stored transitions.
            alpha: Exponent applied to raw priorities in ``update_priorities``.
        """
        self.capacity = capacity
        self.alpha = alpha
        self.tree = SumTree(capacity)
        self.data = [None] * capacity
        # Kept in the same (already exponentiated) scale as the tree leaves.
        self.max_priority = 1.0
        self.size = 0

    def add(self, obs, action, reward, next_obs, done):
        """Store one transition, overwriting the oldest slot once full."""
        idx = self.tree.add(self.max_priority)
        self.data[idx - self.capacity + 1] = (obs, action, reward, next_obs, done)
        self.size = min(self.size + 1, self.capacity)

    def _draw(self, low: float, high: float, max_tries: int = 20):
        """Pick one filled slot for the cumulative-priority segment ``[low, high)``.

        Returns ``(tree_index, priority, data_index)``.
        """
        for _ in range(max_tries):
            s = np.random.uniform(low, high)
            idx, p, data_idx = self.tree.get(s)
            if self.data[data_idx] is not None:
                return idx, p, data_idx
        # Every try landed on an empty slot. Slots are filled in ring order, so
        # the filled ones are exactly 0..size-1: draw one uniformly in O(1).
        data_idx = int(np.random.randint(self.size))
        idx = data_idx + self.capacity - 1
        return idx, self.tree.tree[idx], data_idx

    def sample(self, batch_size: int, beta: float = 0.4):
        """Draw a stratified, priority-weighted batch.

        The total priority is split into ``batch_size`` equal segments and one
        transition is drawn from each.

        Args:
            batch_size: Number of transitions to draw (must be positive).
            beta: Exponent of the importance-sampling correction.

        Returns:
            ``(obs, actions, rewards, next_obs, dones, indices, weights)`` where
            the first five are NumPy arrays, ``indices`` are tree indices to
            pass back to ``update_priorities`` and ``weights`` are
            importance-sampling weights normalised by their maximum.

        Raises:
            ValueError: If ``batch_size`` is not positive or the buffer holds
                no priorities.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        if self.size == 0 or self.tree.total == 0:
            raise ValueError("No hay prioridades en el buffer")
        indices = []
        priorities = []
        samples = []
        segment = self.tree.total / batch_size
        for i in range(batch_size):
            idx, p, data_idx = self._draw(segment * i, segment * (i + 1))
            indices.append(idx)
            priorities.append(p)
            samples.append(self.data[data_idx])
        probs = np.array(priorities, dtype=np.float32) / self.tree.total
        weights = (self.size * probs) ** (-beta)
        weights /= weights.max()
        obs, actions, rewards, next_obs, dones = map(np.array, zip(*samples))
        return obs, actions, rewards, next_obs, dones, indices, weights

    def update_priorities(self, indices, priorities):
        """Replace the priorities at ``indices`` with ``priorities ** alpha``."""
        for idx, p in zip(indices, priorities):
            priority = float(p) ** self.alpha
            self.tree.update(idx, priority)
            self.max_priority = max(self.max_priority, priority)

    def __len__(self):
        return self.size

    def state_dict(self):
        """Return a dict snapshot of the buffer.

        ``tree`` is a copy; ``data`` is the live list (not copied).
        """
        return {
            "capacity": self.capacity,
            "alpha": self.alpha,
            "tree": self.tree.tree.copy(),
            "data": self.data,
            "max_priority": self.max_priority,
            "size": self.size,
            "data_pointer": self.tree.data_pointer,
        }

    def load_state_dict(self, state):
        """Restore the buffer from a dict produced by ``state_dict``."""
        self.capacity = state["capacity"]
        self.alpha = state["alpha"]
        self.tree = SumTree(self.capacity)
        # Copy into the dtype the tree itself uses, so a snapshot saved with a
        # different precision (or shared with the caller) cannot alias or
        # silently change the tree's arithmetic.
        self.tree.tree = np.array(state["tree"], dtype=self.tree.tree.dtype)
        self.tree.data_pointer = state["data_pointer"]
        self.data = state["data"]
        self.max_priority = state["max_priority"]
        self.size = state["size"]


class EpsilonSchedule:
    """Linear interpolation from ``start`` to ``end`` over ``decay_steps`` steps."""

    def __init__(self, start: float, end: float, decay_steps: int):
        self.start = start
        self.end = end
        self.decay_steps = decay_steps

    def value(self, step: int) -> float:
        """Return epsilon at ``step``.

        The result is ``start`` for ``step <= 0`` and ``end`` once ``step``
        reaches ``decay_steps``. A non-positive ``decay_steps`` means "no
        decay phase" and yields ``end`` immediately instead of dividing by zero.
        """
        if self.decay_steps <= 0:
            return float(self.end)
        # Clamp to [0, 1] so out-of-range steps never extrapolate.
        frac = min(max(step / self.decay_steps, 0.0), 1.0)
        return self.start + frac * (self.end - self.start)


class PERBetaSchedule:
    """Linear interpolation of the PER beta exponent from ``start`` to ``end``."""

    def __init__(self, start: float, end: float, steps: int):
        self.start = start
        self.end = end
        self.steps = steps

    def value(self, step: int) -> float:
        """Return beta at ``step``.

        The result is ``start`` for ``step <= 0`` and ``end`` once ``step``
        reaches ``steps``. A non-positive ``steps`` yields ``end`` immediately
        instead of dividing by zero.
        """
        if self.steps <= 0:
            return float(self.end)
        # Clamp to [0, 1] so out-of-range steps never extrapolate.
        frac = min(max(step / self.steps, 0.0), 1.0)
        return self.start + frac * (self.end - self.start)


class QNetwork(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    Args:
        in_channels: Number of channels of the input (the stacked frames).
        num_actions: Number of outputs, one Q-value per action.
        input_hw: ``(height, width)`` of the input frames. The default
            ``(84, 84)`` gives a flattened conv output of ``7 * 7 * 64``.
    """

    def __init__(self, in_channels: int, num_actions: int, input_hw: tuple[int, int] = (84, 84)):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # Measure the flattened conv output with a dummy forward pass instead
        # of hard-coding it, so other frame sizes work too. This consumes no
        # random numbers, so weight initialisation is unaffected.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, *input_hw)
            flat_features = self.conv(probe).flatten(start_dim=1).shape[1]
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        )

    def forward(self, x):
        """Map a ``(batch, channels, height, width)`` input to ``(batch, num_actions)``."""
        x = self.conv(x)
        x = x.flatten(start_dim=1)
        return self.fc(x)


class DQNPerAgent:
    """DQN agent with a target network and a prioritized replay buffer.

    ``step_count`` is never advanced by this class; the caller is expected to
    set it, since it drives the epsilon schedule, the beta schedule and the
    target-network sync.
    """

    def __init__(
        self,
        obs_shape,
        num_actions: int,
        device: str,
        gamma: float = 0.99,
        lr: float = 1e-4,
        target_update_interval: int = 10_000,
        buffer_size: int = 200_000,
        alpha: float = 0.6,
        eps_schedule: "EpsilonSchedule | None" = None,
        beta_schedule: "PERBetaSchedule | None" = None,
    ):
        """Build the online/target networks, the optimizer and the replay buffer.

        Args:
            obs_shape: Observation shape; only ``obs_shape[0]`` (channels) is used.
            num_actions: Size of the discrete action space.
            device: Torch device string for networks and batches.
            gamma: Discount factor.
            lr: Adam learning rate.
            target_update_interval: The target network is synced whenever
                ``step_count`` is a multiple of this value during ``update``.
            buffer_size: Replay capacity.
            alpha: Priority exponent of the replay buffer.
            eps_schedule: Exploration schedule (defaults to 1.0 -> 0.05 over 1M steps).
            beta_schedule: PER beta schedule (defaults to 0.4 -> 1.0 over 1M steps).
        """
        self.device = device
        self.num_actions = num_actions
        self.gamma = gamma
        self.target_update_interval = target_update_interval
        self.step_count = 0

        in_channels = obs_shape[0]
        self.online = QNetwork(in_channels, num_actions).to(device)
        self.target = QNetwork(in_channels, num_actions).to(device)
        self.target.load_state_dict(self.online.state_dict())
        self.target.eval()

        self.optimizer = torch.optim.Adam(self.online.parameters(), lr=lr)
        self.replay = PrioritizedReplayBuffer(buffer_size, alpha=alpha)

        self.eps_schedule = eps_schedule or EpsilonSchedule(1.0, 0.05, 1_000_000)
        self.beta_schedule = beta_schedule or PERBetaSchedule(0.4, 1.0, 1_000_000)

    def _as_network_input(self, frames) -> torch.Tensor:
        """Convert frames to a float32 tensor on ``self.device``.

        Integer-typed input is treated as raw 0-255 pixels and divided by 255;
        floating-point input is assumed to be normalised already and is passed
        through unchanged. This keeps raw frames sampled from the replay buffer
        on the same scale as pre-normalised frames used for action selection.
        """
        arr = np.asarray(frames)
        is_raw_pixels = np.issubdtype(arr.dtype, np.integer)
        tensor = torch.tensor(arr, device=self.device).to(torch.float32)
        return tensor / 255.0 if is_raw_pixels else tensor

    def select_action(self, obs: np.ndarray) -> int:
        """Return an epsilon-greedy action for a single observation."""
        eps = self.eps_schedule.value(self.step_count)
        if np.random.rand() < eps:
            return int(np.random.randint(self.num_actions))
        obs_t = self._as_network_input(obs).unsqueeze(0)
        with torch.no_grad():
            q_values = self.online(obs_t)
        return int(torch.argmax(q_values, dim=1).item())

    def update(self, batch_size: int):
        """Run one gradient step on a prioritized batch.

        Returns:
            ``None`` when the buffer holds fewer than ``batch_size``
            transitions, otherwise ``(loss, mean_abs_td_error)`` as floats.
        """
        if len(self.replay) < batch_size:
            return None

        beta = self.beta_schedule.value(self.step_count)
        obs, actions, rewards, next_obs, dones, indices, weights = self.replay.sample(batch_size, beta)

        obs_t = self._as_network_input(obs)
        next_obs_t = self._as_network_input(next_obs)
        actions_t = torch.tensor(actions, dtype=torch.int64, device=self.device).unsqueeze(1)
        rewards_t = torch.tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
        dones_t = torch.tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1)
        weights_t = torch.tensor(weights, dtype=torch.float32, device=self.device).unsqueeze(1)

        q_values = self.online(obs_t).gather(1, actions_t)
        with torch.no_grad():
            # One-step target: max over the target network's Q-values,
            # with bootstrapping switched off where the done flag is 1.
            next_q = self.target(next_obs_t).max(dim=1, keepdim=True)[0]
            target = rewards_t + self.gamma * (1.0 - dones_t) * next_q

        td_abs = (target - q_values).detach().abs()
        # Per-sample Huber loss, scaled by the importance-sampling weights.
        loss = F.smooth_l1_loss(q_values, target, reduction="none")
        loss = (loss * weights_t).mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.online.parameters(), 10.0)
        self.optimizer.step()

        # reshape(-1) keeps a 1-D array even for a batch of one; the small
        # constant keeps every priority strictly positive.
        priorities = td_abs.reshape(-1).cpu().numpy() + 1e-6
        self.replay.update_priorities(indices, priorities)

        if self.step_count % self.target_update_interval == 0:
            self.target.load_state_dict(self.online.state_dict())

        return float(loss.item()), float(td_abs.mean().item())

    def save_state(self):
        """Return a dict with networks, optimizer, step counter and replay buffer."""
        return {
            "online": self.online.state_dict(),
            "target": self.target.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "step_count": self.step_count,
            "replay": self.replay.state_dict(),
            "gamma": self.gamma,
            "target_update_interval": self.target_update_interval,
        }

    def load_state(self, state):
        """Restore the agent from a dict produced by ``save_state``.

        ``gamma`` and ``target_update_interval`` are optional in ``state``;
        the current values are kept when they are missing.
        """
        self.online.load_state_dict(state["online"])
        self.target.load_state_dict(state["target"])
        self.optimizer.load_state_dict(state["optimizer"])
        self.step_count = int(state["step_count"])
        self.replay.load_state_dict(state["replay"])
        self.gamma = float(state.get("gamma", self.gamma))
        self.target_update_interval = int(state.get("target_update_interval", self.target_update_interval))


def save_checkpoint(state: dict, checkpoint_dir: str, step: int) -> str:
    """Serialize ``state`` with ``torch.save`` and return the file path.

    The file is written to ``<checkpoint_dir>/checkpoint_<step>.pth``; the
    directory is created when it does not exist.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    filename = f"checkpoint_{step}.pth"
    destination = os.path.join(checkpoint_dir, filename)
    torch.save(state, destination)
    return destination


def load_checkpoint(ckpt_path: str, device: str):
    """Load a checkpoint written with ``torch.save`` and map tensors to ``device``.

    ``weights_only=False`` is passed explicitly: the checkpoints saved here
    contain NumPy arrays and plain Python objects (the replay buffer), which
    the restricted unpickler that newer PyTorch versions use by default
    refuses to load.

    WARNING: this performs full unpickling and can execute arbitrary code.
    Only load checkpoint files you created yourself or otherwise trust.
    """
    return torch.load(ckpt_path, map_location=device, weights_only=False)


def save_config(config: dict, checkpoint_dir: str) -> str:
    """Write ``config`` as indented JSON to ``<checkpoint_dir>/config.json``.

    The directory is created when missing. Returns the path of the file.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    destination = os.path.join(checkpoint_dir, "config.json")
    serialized = json.dumps(config, indent=2)
    with open(destination, "w", encoding="utf-8") as handle:
        handle.write(serialized)
    return destination

In [None]:
# Hyperparameters and run configuration.
# The dict is also written to config.json next to the checkpoints.
config = {
    "env_id": "ALE/Assault-v5",
    "frame_skip": 4,  # forwarded to make_atari_env
    "clip_rewards": True,  # forwarded to make_atari_env for the training env
    "total_steps": 2_000_000,  # number of iterations of the training loop
    "learning_starts": 50_000,  # no gradient updates until step exceeds this
    "batch_size": 32,
    "buffer_size": 200_000,  # replay buffer capacity
    "gamma": 0.99,
    "lr": 1e-4,
    "target_update_interval": 10_000,
    "checkpoint_interval": 200_000,  # a checkpoint is saved every this many steps
    "eps_start": 1.0,
    "eps_end": 0.05,
    "eps_decay_steps": 1_000_000,
    "beta_start": 0.4,
    "beta_end": 1.0,
    "beta_steps": 1_000_000,
    "alpha": 0.6,  # priority exponent of the replay buffer
    "seed": SEED
}
save_config(config, CKPT_DIR)

In [None]:
# Build the training environment and read its action/observation sizes
env = make_atari_env(
    game_id=config["env_id"],
    seed=SEED,
    frame_skip=config["frame_skip"],
    clip_rewards=config["clip_rewards"],
)
obs_shape = env.observation_space.shape
num_actions = env.action_space.n
print("Obs shape:", obs_shape, "Acciones:", num_actions)

In [None]:
# Agent: exploration and PER-beta schedules come straight from the config
eps_schedule = EpsilonSchedule(
    start=config["eps_start"],
    end=config["eps_end"],
    decay_steps=config["eps_decay_steps"],
)
beta_schedule = PERBetaSchedule(
    start=config["beta_start"],
    end=config["beta_end"],
    steps=config["beta_steps"],
)
# Hyperparameters whose config key matches the constructor argument name.
agent_kwargs = {
    key: config[key]
    for key in ("gamma", "lr", "target_update_interval", "buffer_size", "alpha")
}
agent = DQNPerAgent(
    obs_shape=obs_shape,
    num_actions=num_actions,
    device=DEVICE,
    eps_schedule=eps_schedule,
    beta_schedule=beta_schedule,
    **agent_kwargs,
)

In [None]:
# TensorBoard writer; event files go to LOG_DIR
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir=LOG_DIR)


def obs_to_array(obs):
    """Return ``obs`` as a NumPy array, channel-first when it is 3-D.

    A 3-D array whose first axis has length 1 or 4 is taken to be
    ``(C, H, W)`` already and returned as is; any other 3-D array is treated
    as ``(H, W, C)`` and transposed. Arrays of other ranks pass through.
    """
    frames = np.asarray(obs)
    if frames.ndim != 3:
        return frames
    already_channel_first = frames.shape[0] in (1, 4)
    if already_channel_first:
        return frames
    return np.transpose(frames, (2, 0, 1))


def preprocess(obs):
    """Return ``obs`` in channel-first layout as float32 divided by 255."""
    channel_first = obs_to_array(obs)
    as_float = channel_first.astype(np.float32)
    return as_float / 255.0

In [None]:
# Training loop with periodic checkpoints
resume_path = None  # set to a checkpoint path to resume training
start_step = 1
if resume_path:
    state = load_checkpoint(resume_path, device=DEVICE)
    agent.load_state(state)
    # Continue from the restored step so the epsilon/beta schedules and the
    # checkpoint numbering pick up where the saved run stopped instead of
    # restarting from step 1.
    start_step = agent.step_count + 1
    print("Reanudado desde", resume_path)

obs, info = env.reset()
episode_reward = 0.0
episode_len = 0
start_time = time.time()

for step in range(start_step, config["total_steps"] + 1):
    # The agent never advances its own counter; drive it from the loop.
    agent.step_count = step
    obs_proc = preprocess(obs)
    action = agent.select_action(obs_proc)
    next_obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated

    # Store only true termination as the "done" flag: a truncated episode
    # (time limit) did not reach a terminal state, so the TD target should
    # still bootstrap from next_obs.
    agent.replay.add(obs_to_array(obs), action, reward, obs_to_array(next_obs), float(terminated))
    episode_reward += reward
    episode_len += 1

    if step > config["learning_starts"]:
        update_out = agent.update(config["batch_size"])
        if update_out is not None:
            loss, td_error = update_out
            writer.add_scalar("train/loss", loss, step)
            writer.add_scalar("train/td_error", td_error, step)

    eps_value = agent.eps_schedule.value(step)
    writer.add_scalar("train/epsilon", eps_value, step)

    if done:
        # Episode boundary (terminated or truncated): log and start a new one.
        writer.add_scalar("rollout/episode_return", episode_reward, step)
        writer.add_scalar("rollout/episode_length", episode_len, step)
        obs, info = env.reset()
        episode_reward = 0.0
        episode_len = 0
    else:
        obs = next_obs

    if step % config["checkpoint_interval"] == 0:
        state = agent.save_state()
        ckpt_path = save_checkpoint(state, CKPT_DIR, step)
        print("Guardado", ckpt_path)

# Make sure buffered scalars reach disk before the results are inspected.
writer.flush()
train_time = time.time() - start_time
print(f"Tiempo de entrenamiento (s): {train_time:.1f}")

In [None]:
# Helpers de evaluacion
def greedy_action(agent, obs):
    """Return the index of the highest Q-value of ``agent.online`` for ``obs``.

    ``obs`` is a single observation; a batch axis is added before the
    forward pass, which runs without gradient tracking.
    """
    batch = torch.tensor(obs, dtype=torch.float32, device=DEVICE)[None]
    with torch.no_grad():
        best = agent.online(batch).argmax(dim=1)
    return int(best.item())

def run_eval(env_id, n_episodes=10, seed_offset=0, use_random=False, record_video=False):
    """Play full episodes in a fresh environment and summarise the returns.

    The evaluation env is built without reward clipping, so returns are raw
    game scores. Episode ``ep`` is reset with seed ``SEED + seed_offset + ep``.

    Args:
        env_id: Gymnasium environment id.
        n_episodes: Number of episodes to play.
        seed_offset: Added to the global ``SEED`` for the env seeds.
        use_random: When true, sample actions from the action space instead of
            using the global ``agent`` greedily.
        record_video: When true, render in ``rgb_array`` mode and wrap the env
            in ``RecordVideo`` writing to ``VIDEO_DIR``.

    Returns:
        ``(rewards, mean, std)``: the list of episode returns and their mean
        and standard deviation as floats.
    """
    eval_env = make_atari_env(
        env_id,
        seed=SEED + seed_offset,
        frame_skip=config["frame_skip"],
        clip_rewards=False,
        render_mode="rgb_array" if record_video else None,
    )
    rewards = []
    # try/finally guarantees the env (and the video recorder wrapping it) is
    # closed even if an episode raises part-way through.
    try:
        if record_video:
            eval_env = RecordVideo(eval_env, video_folder=VIDEO_DIR, name_prefix="assault_eval")
        for ep in range(n_episodes):
            obs, info = eval_env.reset(seed=SEED + seed_offset + ep)
            done = False
            ep_reward = 0.0
            while not done:
                if use_random:
                    action = eval_env.action_space.sample()
                else:
                    # Only the learned policy needs the normalised frames.
                    action = greedy_action(agent, preprocess(obs))
                obs, reward, terminated, truncated, info = eval_env.step(action)
                done = terminated or truncated
                ep_reward += reward
            rewards.append(ep_reward)
    finally:
        eval_env.close()
    mean_r = float(np.mean(rewards))
    std_r = float(np.std(rewards))
    return rewards, mean_r, std_r

# Put the online network in eval mode before scoring the learned policy.
agent.online.eval()
# 10 greedy episodes with the trained agent.
eval_rewards, eval_mean, eval_std = run_eval(config["env_id"], n_episodes=10)
print(f"DQN+PER recompensa media: {eval_mean:.2f} +/- {eval_std:.2f}")

# Same number of episodes with uniformly sampled actions, as a baseline.
rand_rewards, rand_mean, rand_std = run_eval(config["env_id"], n_episodes=10, use_random=True)
print(f"Politica aleatoria recompensa media: {rand_mean:.2f} +/- {rand_std:.2f}")

In [None]:
# Export a short evaluation video: one recorded episode with a separate seed offset
_ = run_eval(config["env_id"], n_episodes=1, seed_offset=9999, record_video=True)
print("Video guardado en:", VIDEO_DIR)

## TensorBoard

Ejecuta en Colab:

```
%load_ext tensorboard
%tensorboard --logdir artifacts/assault/logs
```

## Reporte técnico (completar después del entrenamiento)
- Algoritmo: DQN + PER
- Hiperparámetros: ver config.json
- Librerías y versiones: impresas en la celda de setup
- Hardware: salida de nvidia-smi en la celda de setup
- Tiempo de entrenamiento: impreso al finalizar
- Resultados: media/desviación en 10 episodios + baseline aleatorio
- Conclusiones: agrega observaciones y siguientes pasos