In [1]:
!pip install gymnasium[atari] ale-py autorom imageio imageio-ffmpeg torch torchvision
!AutoROM --accept-license

Collecting autorom
  Downloading AutoROM-0.6.1-py3-none-any.whl.metadata (2.4 kB)
Downloading AutoROM-0.6.1-py3-none-any.whl (9.4 kB)
Installing collected packages: autorom
Successfully installed autorom-0.6.1
AutoROM will download the Atari 2600 ROMs.
They will be installed to:
	/usr/local/lib/python3.12/dist-packages/AutoROM/roms

Existing ROMs will be overwritten.
Installed /usr/local/lib/python3.12/dist-packages/AutoROM/roms/adventure.bin
Installed /usr/local/lib/python3.12/dist-packages/AutoROM/roms/air_raid.bin
Installed /usr/local/lib/python3.12/dist-packages/AutoROM/roms/alien.bin
Installed /usr/local/lib/python3.12/dist-packages/AutoROM/roms/amidar.bin
Installed /usr/local/lib/python3.12/dist-packages/AutoROM/roms/assault.bin
Installed /usr/local/lib/python3.12/dist-packages/AutoROM/roms/asterix.bin
Installed /usr/local/lib/python3.12/dist-packages/AutoROM/roms/asteroids.bin
Installed /usr/local/lib/python3.12/dist-packages/AutoROM/roms/atlantis.bin
Installed /usr/local/lib/py

In [5]:
from google.colab import drive
drive.mount('/content/drive')

import os
BASE_DIR = "/content/drive/MyDrive/galaxian_rl"
os.makedirs(BASE_DIR, exist_ok=True)
print("Usando carpeta:", BASE_DIR)

Mounted at /content/drive
Usando carpeta: /content/drive/MyDrive/galaxian_rl


In [6]:
# env_utils.py
import gymnasium as gym
import ale_py
import numpy as np
import cv2
from collections import deque


class CustomAtariPreprocessing(gym.Wrapper):
    """
    Simplificado: convierte a escala de grises, reescala a (84, 84)
    y aplica frame_skip (por defecto 4).
    """

    def __init__(self, env, frame_skip=4, screen_size=84, grayscale_obs=True):
        super().__init__(env)
        self.frame_skip = frame_skip
        self.screen_size = screen_size
        self.grayscale_obs = grayscale_obs

        obs_shape = (screen_size, screen_size)
        if not grayscale_obs:
            obs_shape += (3,)

        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=obs_shape, dtype=np.uint8
        )

    def process_frame(self, frame):
        """Convierte y reescala el frame a 84x84"""
        if self.grayscale_obs:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
            frame = cv2.resize(frame, (self.screen_size, self.screen_size), interpolation=cv2.INTER_AREA)
            return frame
        else:
            frame = cv2.resize(frame, (self.screen_size, self.screen_size), interpolation=cv2.INTER_AREA)
            return frame

    def step(self, action):
        """Ejecuta frame_skip pasos y devuelve el último frame procesado"""
        total_reward = 0.0
        terminated = truncated = False
        for _ in range(self.frame_skip):
            obs, reward, term, trunc, info = self.env.step(action)
            total_reward += reward
            terminated |= term
            truncated |= trunc
            if terminated or truncated:
                break
        processed = self.process_frame(obs)
        return processed, total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        processed = self.process_frame(obs)
        return processed, info


class CustomFrameStack(gym.Wrapper):
    """
    Apila los últimos N frames para captar la dinámica temporal.
    """

    def __init__(self, env, num_stack=4):
        super().__init__(env)
        self.num_stack = num_stack
        self.frames = deque([], maxlen=num_stack)
        low = np.repeat(env.observation_space.low[np.newaxis, ...], num_stack, axis=0)
        high = np.repeat(env.observation_space.high[np.newaxis, ...], num_stack, axis=0)
        self.observation_space = gym.spaces.Box(
            low=low.min(), high=high.max(), dtype=env.observation_space.dtype, shape=(num_stack, *env.observation_space.shape)
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.num_stack):
            self.frames.append(obs)
        return self._get_obs(), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(obs)
        return self._get_obs(), reward, terminated, truncated, info

    def _get_obs(self):
        return np.stack(self.frames, axis=0)


def make_galaxian_env(seed: int | None = None, render_mode: str | None = None):
    """
    Crea el entorno ALE/Galaxian-v5 con preprocesamiento manual.
    - Reescalado a 84x84
    - Escala de grises
    - Frame skip = 4
    - Frame stack = 4
    """
    env = gym.make("ALE/Galaxian-v5", render_mode=render_mode)
    if seed is not None:
        env.reset(seed=seed)

    env = CustomAtariPreprocessing(env, frame_skip=4, screen_size=84, grayscale_obs=True)
    env = CustomFrameStack(env, num_stack=4)

    return env


In [7]:
# dqn_galaxian.py
# ------------------------------------------------------------
# DQN desde cero para ALE/Galaxian-v5 (Gymnasium + ALE-Py)
# Compatible con wrappers personalizados (env_utils.make_galaxian_env)
# - Observación por defecto: (4, 84, 84) (canal-primero)
# - Robusto si viene en (84, 84, 4) (canal-último): auto-detecta y permuta
# - Checkpoints periódicos en 'checkpoint_dir' (útil para Google Drive en Colab)
# ------------------------------------------------------------

import os
import csv
import random
from collections import deque
from typing import Tuple, Deque, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib
matplotlib.use("Agg")  # backend sin pantalla para guardar PNG
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ============================================================
# Red DQN (estilo Nature)
# ============================================================
class DQN(nn.Module):
    """
    Red convolucional para estimar Q(s,a).
    Espera tensores en float normalizados [0,1] con formato (B, C, H, W).
    Si recibe (B, H, W, C) permuta automáticamente a (B, C, H, W).
    """
    def __init__(self, input_shape: Tuple[int, int, int], n_actions: int):
        super().__init__()
        c, h, w = input_shape
        self.expected_c = c  # Para validar/ajustar layout en forward

        self.features = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),  # -> (B,32,20,20) aprox con 84x84
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), # -> (B,64,9,9)
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), # -> (B,64,7,7)
            nn.ReLU(inplace=True),
            nn.Flatten()
        )

        # Inferir tamaño del flatten de forma programática
        with torch.no_grad():
            dummy = torch.zeros(1, c, h, w)
            n_flat = self.features(dummy).shape[1]

        self.head = nn.Sequential(
            nn.Linear(n_flat, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, n_actions)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B,C,H,W) uint8/float o (B,H,W,C) uint8/float
        Retorna Q(s,·): (B, n_actions)
        """
        if x.ndim != 4:
            raise ValueError(f"Se esperaba un tensor 4D, recibido x.ndim={x.ndim}")

        # Si viene en canal-último (B,H,W,C), permutamos a (B,C,H,W)
        if x.shape[1] != self.expected_c and x.shape[-1] == self.expected_c:
            x = x.permute(0, 3, 1, 2)

        # Asegurar float y normalizar
        x = x.float() / 255.0
        x = self.features(x)
        return self.head(x)


# ============================================================
# Replay Buffer
# ============================================================
class ReplayBuffer:
    def __init__(self, capacity: int):
        self.buffer: Deque = deque(maxlen=capacity)

    def push(self, s, a, r, ns, d):
        # Guardamos en bruto (uint8 para ahorrar memoria)
        self.buffer.append((s, a, r, ns, d))

    def sample(self, batch_size: int):
        batch = random.sample(self.buffer, batch_size)
        s, a, r, ns, d = map(np.array, zip(*batch))
        # Convertimos tipos; estados se quedan en uint8 para pasar a torch y normalizar allí
        return s, a, r, ns, d

    def __len__(self):
        return len(self.buffer)


# ============================================================
# Política para play.py (interfaz: __call__(obs, info) -> action)
# ============================================================
class DQNPolicy:
    def __init__(self, q_net: 'DQN'):
        self.q_net = q_net.to(device)
        self.q_net.eval()

    @torch.no_grad()
    def __call__(self, obs: np.ndarray, info: dict) -> int:
        """
        obs: (C,H,W) uint8 o (H,W,C) uint8
        """
        if obs.ndim != 3:
            raise ValueError(f"Se esperaba obs 3D, recibido obs.ndim={obs.ndim}")

        # Meter batch
        obs_batch = np.expand_dims(obs, axis=0)  # (1, C,H,W) o (1, H,W,C)
        obs_t = torch.from_numpy(obs_batch).to(device)
        q_vals = self.q_net(obs_t)
        return int(torch.argmax(q_vals, dim=1).item())


# ============================================================
# Utilidades de logging
# ============================================================
def _moving_average(x: List[float], k: int = 100):
    """Calcula la media móvil de una lista de valores."""
    if len(x) == 0:
        return []
    out = []
    s = 0.0
    q = []
    for i, v in enumerate(x):
        q.append(v)
        s += v
        if len(q) > k:
            s -= q.pop(0)
        out.append(s / len(q))
    return out

def _save_plot_and_csv(checkpoint_dir: str, rewards: List[float], episode: int):
    """Guarda gráfica de recompensas (PNG) y registro CSV."""
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Guardar CSV
    csv_path = os.path.join(checkpoint_dir, "rewards_log.csv")
    new_file = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as f:
        w = csv.writer(f)
        if new_file:
            w.writerow(["episode", "reward"])
        w.writerow([episode, rewards[-1]])

    # Guardar gráfica PNG
    plt.figure(figsize=(8, 4.5))
    plt.plot(rewards, label="Reward")
    ma = _moving_average(rewards, k=100)
    if len(ma) > 0:
        plt.plot(ma, label="Moving Avg (100)")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.title("Training Rewards - DQN")
    plt.legend()
    png_path = os.path.join(checkpoint_dir, f"rewards_ep{episode}.png")
    plt.tight_layout()
    plt.savefig(png_path, dpi=120)
    plt.close()
    print(f"[LOG] Guardadas gráfica y CSV en: {png_path} / {csv_path}")


# ============================================================
# Entrenamiento DQN
# ============================================================
def train_dqn(
    checkpoint_dir: str,
    total_episodes: int = 500,
    replay_size: int = 100_000,
    batch_size: int = 32,
    gamma: float = 0.99,
    lr: float = 1e-4,
    start_epsilon: float = 1.0,
    end_epsilon: float = 0.1,
    epsilon_decay_episodes: int = 300,
    target_update_interval: int = 1_000,  # en pasos
    train_start: int = 10_000,            # tamaño mínimo de buffer para empezar a entrenar
    save_interval: int = 50,              # guardar checkpoints cada N episodios
    plot_interval: int = 50,              # guardar gráfica y CSV cada N episodios
    max_steps_per_episode: int | None = None,
    seed: int = 42,
):
    """
    Entrena un DQN minimalista para Galaxian.
    Guarda checkpoints periódicos en `checkpoint_dir` (ideal en /content/drive/... en Colab).
    Guarda gráficas de recompensa y CSV cada `plot_interval` episodios.
    Retorna la red Q entrenada (q_net).
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Entorno con wrappers personalizados (sin depender de gymnasium.wrappers)
    env = make_galaxian_env(seed=seed, render_mode=None)

    # Obtenemos shape de la observación y detectamos layout
    obs, _ = env.reset()
    if obs.ndim != 3:
        env.close()
        raise ValueError(f"Observación inesperada: ndim={obs.ndim}, se esperaba 3")

    # Auto-detectar (C,H,W) vs (H,W,C) y construir input_shape=(C,H,W)
    if obs.shape[0] in (1, 3, 4):  # canal-primero típico de CustomFrameStack
        input_shape = (obs.shape[0], obs.shape[1], obs.shape[2])
    else:                           # canal-último
        input_shape = (obs.shape[2], obs.shape[0], obs.shape[1])

    n_actions = env.action_space.n

    q_net = DQN(input_shape, n_actions).to(device)
    target_net = DQN(input_shape, n_actions).to(device)
    target_net.load_state_dict(q_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(q_net.parameters(), lr=lr)
    buffer = ReplayBuffer(replay_size)

    global_step = 0

    def epsilon_by_episode(ep: int) -> float:
        if ep >= epsilon_decay_episodes:
            return end_epsilon
        frac = ep / float(epsilon_decay_episodes)
        return start_epsilon + frac * (end_epsilon - start_epsilon)

    # (Opcional) logging sencillo
    rewards_log: List[float] = []

    for episode in range(1, total_episodes + 1):
        obs, _ = env.reset()
        done = False
        total_reward = 0.0
        eps = epsilon_by_episode(episode)
        steps_in_ep = 0

        while not done:
            global_step += 1
            steps_in_ep += 1

            # Epsilon-greedy
            if random.random() < eps:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    ob = np.expand_dims(obs, axis=0)  # (1, C,H,W) o (1, H,W,C)
                    ob_t = torch.from_numpy(ob).to(device)
                    q_vals = q_net(ob_t)
                    action = int(torch.argmax(q_vals, dim=1).item())

            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += float(reward)

            # Guardamos en buffer (mantener dtype uint8 en estados y next_estados)
            buffer.push(obs, action, reward, next_obs, done)
            obs = next_obs

            # Entrenar cuando el buffer tenga suficientes transiciones
            if len(buffer) >= train_start:
                s, a, r, ns, d = buffer.sample(batch_size)

                # Convertir a tensores
                s_t = torch.from_numpy(s).to(device)
                ns_t = torch.from_numpy(ns).to(device)
                a_t = torch.from_numpy(a).long().to(device)
                r_t = torch.from_numpy(r).float().to(device)
                d_t = torch.from_numpy(d.astype(np.float32)).to(device)

                # Q(s,a)
                q_vals = q_net(s_t)                       # (B, n_actions)
                q_a = q_vals.gather(1, a_t.unsqueeze(1)).squeeze(1)  # (B,)

                # y = r + gamma * max_a' Q_target(ns, a') * (1 - done)
                with torch.no_grad():
                    next_q = target_net(ns_t).max(1)[0]
                    target = r_t + gamma * next_q * (1.0 - d_t)

                loss = nn.functional.mse_loss(q_a, target)

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(q_net.parameters(), 10.0)
                optimizer.step()

            # Actualizar target network cada cierto número de pasos
            if global_step % target_update_interval == 0:
                target_net.load_state_dict(q_net.state_dict())

            # Límite duro de pasos por episodio (opcional)
            if max_steps_per_episode is not None and steps_in_ep >= max_steps_per_episode:
                break

        rewards_log.append(total_reward)
        print(f"[DQN] Episodio {episode}/{total_episodes} | Recompensa: {total_reward:.1f} | eps={eps:.3f} | buffer={len(buffer)}")

        # Guardar checkpoint en Google Drive cada N episodios
        if episode % save_interval == 0:
            ckpt_path = os.path.join(checkpoint_dir, f"dqn_galaxian_ep{episode}.pth")
            torch.save({
                "q_net": q_net.state_dict(),
                "target_net": target_net.state_dict(),
                "optimizer": optimizer.state_dict(),
                "episode": episode,
                "global_step": global_step,
                "rewards": rewards_log,
                "input_shape": input_shape,
                "n_actions": n_actions,
            }, ckpt_path)
            print(f"[DQN] Checkpoint guardado en: {ckpt_path}")

        # Guardar gráfica + CSV cada N episodios
        if episode % plot_interval == 0:
            _save_plot_and_csv(checkpoint_dir, rewards_log, episode)

    env.close()

    # Guardar pesos finales (solo la Q principal para inferencia)
    final_path = os.path.join(checkpoint_dir, "dqn_galaxian_final.pth")
    torch.save(q_net.state_dict(), final_path)
    print(f"[DQN] Modelo final guardado en: {final_path}")

    return q_net


In [None]:
dqn_dir = f"{BASE_DIR}/dqn"
q_net = train_dqn(checkpoint_dir=dqn_dir, total_episodes=10000, save_interval = 500, plot_interval=500)

[DQN] Episodio 1/10000 | Recompensa: 330.0 | eps=0.997 | buffer=71
[DQN] Episodio 2/10000 | Recompensa: 390.0 | eps=0.994 | buffer=187
[DQN] Episodio 3/10000 | Recompensa: 830.0 | eps=0.991 | buffer=426
[DQN] Episodio 4/10000 | Recompensa: 540.0 | eps=0.988 | buffer=550
[DQN] Episodio 5/10000 | Recompensa: 490.0 | eps=0.985 | buffer=698
[DQN] Episodio 6/10000 | Recompensa: 300.0 | eps=0.982 | buffer=776
[DQN] Episodio 7/10000 | Recompensa: 680.0 | eps=0.979 | buffer=929
[DQN] Episodio 8/10000 | Recompensa: 580.0 | eps=0.976 | buffer=1054
[DQN] Episodio 9/10000 | Recompensa: 1010.0 | eps=0.973 | buffer=1243
[DQN] Episodio 10/10000 | Recompensa: 270.0 | eps=0.970 | buffer=1330
[DQN] Episodio 11/10000 | Recompensa: 710.0 | eps=0.967 | buffer=1503
[DQN] Episodio 12/10000 | Recompensa: 360.0 | eps=0.964 | buffer=1631
[DQN] Episodio 13/10000 | Recompensa: 330.0 | eps=0.961 | buffer=1780
[DQN] Episodio 14/10000 | Recompensa: 390.0 | eps=0.958 | buffer=1877
[DQN] Episodio 15/10000 | Recompensa

In [None]:
# a2c_galaxian.py
# ------------------------------------------------------------
# A2C desde cero para ALE/Galaxian-v5 (Gymnasium + ALE-Py)
# Compatible con wrappers personalizados (env_utils.make_galaxian_env)
# - Observación por defecto: (4, 84, 84) (canal-primero)
# - Robusto si llega en (84, 84, 4) (canal-último): auto-detecta y permuta
# - Checkpoints periódicos en 'checkpoint_dir' (útil para Google Drive en Colab)
# ------------------------------------------------------------

import os
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ============================================================
# Red Actor-Crítico (backbone conv compartido)
# ============================================================
class A2CNet(nn.Module):
    """
    Red Actor-Crítico con backbone convolucional estilo Nature.
    Espera tensores en float normalizados [0,1] con formato (B, C, H, W).
    Si recibe (B, H, W, C), permuta automáticamente a (B, C, H, W).
    """
    def __init__(self, input_shape: Tuple[int, int, int], n_actions: int):
        super().__init__()
        c, h, w = input_shape
        self.expected_c = c  # para validar/ajustar layout en forward

        self.features = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),  # -> (B,32,20,20) aprox con 84x84
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), # -> (B,64,9,9)
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), # -> (B,64,7,7)
            nn.ReLU(inplace=True),
            nn.Flatten()
        )

        with torch.no_grad():
            dummy = torch.zeros(1, c, h, w)
            n_flat = self.features(dummy).shape[1]

        self.actor = nn.Sequential(
            nn.Linear(n_flat, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, n_actions)
        )
        self.critic = nn.Sequential(
            nn.Linear(n_flat, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1)
        )

    def forward(self, x: torch.Tensor):
        """
        x: (B,C,H,W) uint8/float o (B,H,W,C) uint8/float
        Retorna:
          - logits de la política: (B, n_actions)
          - valores V(s): (B,)
        """
        if x.ndim != 4:
            raise ValueError(f"Se esperaba un tensor 4D, recibido x.ndim={x.ndim}")

        # Si viene en canal-último (B,H,W,C), permutamos a (B,C,H,W)
        if x.shape[1] != self.expected_c and x.shape[-1] == self.expected_c:
            x = x.permute(0, 3, 1, 2)

        x = x.float() / 255.0
        feat = self.features(x)
        logits = self.actor(feat)
        values = self.critic(feat).squeeze(-1)
        return logits, values


# ============================================================
# Política para play.py (interfaz: __call__(obs, info) -> action)
# ============================================================
class A2CPolicy:
    def __init__(self, net: 'A2CNet'):
        self.net = net.to(device)
        self.net.eval()

    @torch.no_grad()
    def __call__(self, obs: np.ndarray, info: dict) -> int:
        """
        obs: (C,H,W) uint8 o (H,W,C) uint8
        Estrategia greedy (argmax) para evaluación/competencia.
        """
        if obs.ndim != 3:
            raise ValueError(f"Se esperaba obs 3D, recibido obs.ndim={obs.ndim}")

        obs_batch = np.expand_dims(obs, axis=0)  # (1, C,H,W) o (1, H,W,C)
        obs_t = torch.from_numpy(obs_batch).to(device)
        logits, _ = self.net(obs_t)
        probs = torch.softmax(logits, dim=-1)
        action = torch.argmax(probs, dim=-1).item()
        return int(action)


# ============================================================
# Entrenamiento A2C
# ============================================================
def train_a2c(
    checkpoint_dir: str,
    total_episodes: int = 500,
    gamma: float = 0.99,
    lr: float = 2.5e-4,
    entropy_coef: float = 0.01,
    value_coef: float = 0.5,
    rollout_length: int = 5,
    save_interval: int = 50,
    max_steps_per_episode: int | None = None,
    seed: int = 123,
    gae_lambda: float = 0.95,
):
    """
    Entrena un A2C minimalista para Galaxian con rollouts cortos (n-steps) y GAE(λ).
    Guarda checkpoints periódicos en `checkpoint_dir` (ideal Drive en Colab).
    Retorna la red A2C entrenada.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    env = make_galaxian_env(seed=seed, render_mode=None)

    # Detectar layout y shape de entrada
    obs, _ = env.reset()
    if obs.ndim != 3:
        env.close()
        raise ValueError(f"Observación inesperada: ndim={obs.ndim}, se esperaba 3")

    if obs.shape[0] in (1, 3, 4):  # canal-primero típico de CustomFrameStack
        input_shape = (obs.shape[0], obs.shape[1], obs.shape[2])
    else:                           # canal-último
        input_shape = (obs.shape[2], obs.shape[0], obs.shape[1])

    n_actions = env.action_space.n

    net = A2CNet(input_shape, n_actions).to(device)
    optimizer = optim.RMSprop(net.parameters(), lr=lr, eps=1e-5)

    episode_idx = 0
    rewards_log: List[float] = []

    while episode_idx < total_episodes:
        obs, _ = env.reset()
        done = False
        ep_reward = 0.0
        steps_in_ep = 0

        while not done:
            # Trajectoria corta (rollout)
            log_probs = []
            values = []
            rewards = []
            dones = []
            entropies = []

            for _ in range(rollout_length):
                if done:
                    break

                ob = np.expand_dims(obs, axis=0)             # (1, C,H,W) o (1, H,W,C)
                ob_t = torch.from_numpy(ob).to(device)

                logits, value = net(ob_t)
                probs = torch.softmax(logits, dim=-1)
                dist = torch.distributions.Categorical(probs)

                action = dist.sample()
                entropy = dist.entropy().mean()

                next_obs, reward, terminated, truncated, _ = env.step(action.item())
                done = terminated or truncated

                log_probs.append(dist.log_prob(action).squeeze(0))
                values.append(value.squeeze(0))
                rewards.append(torch.tensor(reward, dtype=torch.float32, device=device))
                dones.append(torch.tensor(float(done), device=device))
                entropies.append(entropy)

                obs = next_obs
                ep_reward += float(reward)
                steps_in_ep += 1

                if max_steps_per_episode is not None and steps_in_ep >= max_steps_per_episode:
                    done = True
                    break

            # Bootstrap del valor para el último estado
            if done:
                next_value = torch.zeros(1, device=device)
            else:
                nb = np.expand_dims(obs, axis=0)
                nb_t = torch.from_numpy(nb).to(device)
                _, nv = net(nb_t)
                next_value = nv.detach()

            # GAE(λ) + retornos
            returns = []
            gae = 0
            # Usamos next_value como V(s_{t+T})
            v_next = next_value
            for r, d, v in zip(reversed(rewards), reversed(dones), reversed(values)):
                v_next = v_next * (1.0 - d)  # si done==1, no bootstrap
                delta = r + gamma * v_next - v
                gae = delta + gamma * gae_lambda * (1.0 - d) * gae
                v_next = v
                returns.insert(0, gae + v)

            returns = torch.stack(returns)
            values_t = torch.stack(values)
            log_probs_t = torch.stack(log_probs)
            entropies_t = torch.stack(entropies) if len(entropies) > 0 else torch.tensor(0.0, device=device)

            advantage = returns - values_t

            policy_loss = -(log_probs_t * advantage.detach()).mean()
            value_loss = advantage.pow(2).mean()
            entropy_loss = entropies_t.mean() if entropies_t.ndim > 0 else entropies_t

            loss = policy_loss + value_coef * value_loss - entropy_coef * entropy_loss

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), 0.5)
            optimizer.step()

            if done:
                break  # salir al cierre del episodio

        episode_idx += 1
        rewards_log.append(ep_reward)
        print(f"[A2C] Episodio {episode_idx}/{total_episodes} | Recompensa: {ep_reward:.1f}")

        # Guardar checkpoint cada N episodios
        if episode_idx % save_interval == 0:
            ckpt_path = os.path.join(checkpoint_dir, f"a2c_galaxian_ep{episode_idx}.pth")
            torch.save({
                "net": net.state_dict(),
                "optimizer": optimizer.state_dict(),
                "episode": episode_idx,
                "rewards": rewards_log,
                "input_shape": input_shape,
                "n_actions": n_actions,
            }, ckpt_path)
            print(f"[A2C] Checkpoint guardado en: {ckpt_path}")

    env.close()

    # Guardar pesos finales para inferencia
    final_path = os.path.join(checkpoint_dir, "a2c_galaxian_final.pth")
    torch.save(net.state_dict(), final_path)
    print(f"[A2C] Modelo final guardado en: {final_path}")

    return net


In [None]:
a2c_dir = f"{BASE_DIR}/a2c"
a2c_net = train_a2c(checkpoint_dir=a2c_dir, total_episodes=50)

[A2C] Episodio 1/50 | Recompensa: 490.0
[A2C] Episodio 2/50 | Recompensa: 800.0
[A2C] Episodio 3/50 | Recompensa: 400.0
[A2C] Episodio 4/50 | Recompensa: 390.0
[A2C] Episodio 5/50 | Recompensa: 600.0
[A2C] Episodio 6/50 | Recompensa: 540.0
[A2C] Episodio 7/50 | Recompensa: 450.0
[A2C] Episodio 8/50 | Recompensa: 500.0
[A2C] Episodio 9/50 | Recompensa: 980.0
[A2C] Episodio 10/50 | Recompensa: 590.0
[A2C] Episodio 11/50 | Recompensa: 810.0
[A2C] Episodio 12/50 | Recompensa: 530.0
[A2C] Episodio 13/50 | Recompensa: 1020.0
[A2C] Episodio 14/50 | Recompensa: 890.0
[A2C] Episodio 15/50 | Recompensa: 660.0
[A2C] Episodio 16/50 | Recompensa: 420.0
[A2C] Episodio 17/50 | Recompensa: 240.0
[A2C] Episodio 18/50 | Recompensa: 710.0
[A2C] Episodio 19/50 | Recompensa: 1170.0
[A2C] Episodio 20/50 | Recompensa: 870.0
[A2C] Episodio 21/50 | Recompensa: 650.0
[A2C] Episodio 22/50 | Recompensa: 600.0
[A2C] Episodio 23/50 | Recompensa: 560.0
[A2C] Episodio 24/50 | Recompensa: 570.0
[A2C] Episodio 25/50 | 

In [8]:
# dqn_dueling_per.py
# ------------------------------------------------------------
# Dueling Double DQN con Prioritized Experience Replay (PER)
# - Compatible con env_utils.make_galaxian_env (custom wrappers)
# - Robusto a (C,H,W) o (H,W,C)
# - Checkpoints periódicos en Google Drive
# - Gráficas de recompensa y media móvil (PNG) + log CSV
# ------------------------------------------------------------

import os
import csv
import math
import random
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib
matplotlib.use("Agg")  # backend sin pantalla para guardar PNG
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ===========================
#  Segment Trees para PER
# ===========================
class SegmentTree:
    def __init__(self, capacity, fn):
        assert capacity > 0 and (capacity & (capacity - 1)) == 0, \
            "capacity debe ser potencia de 2"
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity, dtype=np.float32)
        self.fn = fn

    def update(self, idx, value):
        i = idx + self.capacity
        self.tree[i] = value
        i //= 2
        while i >= 1:
            self.tree[i] = self.fn(self.tree[2 * i], self.tree[2 * i + 1])
            i //= 2

    def reduce(self, start, end):
        res_left = None
        res_right = None
        start += self.capacity
        end += self.capacity
        while start <= end:
            if (start % 2) == 1:
                res_left = self.tree[start] if res_left is None else self.fn(res_left, self.tree[start])
                start += 1
            if (end % 2) == 0:
                res_right = self.tree[end] if res_right is None else self.fn(self.tree[end], res_right)
                end -= 1
            start //= 2
            end //= 2
        if res_left is None:
            return res_right
        if res_right is None:
            return res_left
        return self.fn(res_left, res_right)

    def __getitem__(self, idx):
        return self.tree[idx + self.capacity]


class SumSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super().__init__(capacity, fn=lambda a, b: a + b)

    def sum(self):
        return self.tree[1]

    def find_prefixsum_idx(self, prefixsum):
        """Devuelve el índice i tal que sum(0..i) >= prefixsum."""
        idx = 1
        while idx < self.capacity:
            left = 2 * idx
            if self.tree[left] >= prefixsum:
                idx = left
            else:
                prefixsum -= self.tree[left]
                idx = left + 1
        return idx - self.capacity


class MinSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super().__init__(capacity, fn=min)

    def min(self):
        return self.tree[1]


# =======================================
#  Prioritized Replay Buffer (proportional)
# =======================================
class PrioritizedReplayBuffer:
    def __init__(self, capacity: int, alpha: float = 0.6, eps: float = 1e-6):
        # capacity → potencia de 2 para segment tree
        pow2 = 1
        while pow2 < capacity:
            pow2 *= 2
        self.capacity = pow2
        self.alpha = alpha
        self.eps = eps

        self.pos = 0
        self.size = 0

        self.states = [None] * self.capacity
        self.actions = np.zeros(self.capacity, dtype=np.int64)
        self.rewards = np.zeros(self.capacity, dtype=np.float32)
        self.next_states = [None] * self.capacity
        self.dones = np.zeros(self.capacity, dtype=np.bool_)

        self.sum_tree = SumSegmentTree(self.capacity)
        self.min_tree = MinSegmentTree(self.capacity)
        self.max_priority = 1.0

        # Inicializa árboles con prioridad mínima
        for i in range(self.capacity):
            self.sum_tree.update(i, 0.0)
            self.min_tree.update(i, float("inf"))

    def __len__(self):
        return self.size

    def add(self, s, a, r, ns, d):
        idx = self.pos
        self.states[idx] = s
        self.actions[idx] = a
        self.rewards[idx] = r
        self.next_states[idx] = ns
        self.dones[idx] = d

        p = (self.max_priority + self.eps) ** self.alpha
        self.sum_tree.update(idx, p)
        self.min_tree.update(idx, p)

        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int, beta: float = 0.4):
        """Devuelve (indices, w, batch) con pesos de importancia."""
        out_idx = []
        out_s = []
        out_a = np.empty(batch_size, dtype=np.int64)
        out_r = np.empty(batch_size, dtype=np.float32)
        out_ns = []
        out_d = np.empty(batch_size, dtype=np.float32)

        total = self.sum_tree.sum()
        segment = total / batch_size
        min_prob = self.min_tree.min() / total
        max_w = (min_prob * self.size) ** (-beta)

        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            mass = random.random() * (b - a) + a
            idx = self.sum_tree.find_prefixsum_idx(mass)
            out_idx.append(idx)
            out_s.append(self.states[idx])
            out_a[i] = self.actions[idx]
            out_r[i] = self.rewards[idx]
            out_ns.append(self.next_states[idx])
            out_d[i] = float(self.dones[idx])

        # pesos de importancia
        probs = np.array([self.sum_tree[idx] / total for idx in out_idx], dtype=np.float32)
        w = (probs * self.size) ** (-beta)
        w = w / max_w
        w = w.astype(np.float32)

        return np.array(out_idx), w, (np.array(out_s), out_a, out_r, np.array(out_ns), out_d)

    def update_priorities(self, idxs, priorities):
        for i, p in zip(idxs, priorities):
            p = float(p + self.eps)
            self.sum_tree.update(i, (p) ** self.alpha)
            self.min_tree.update(i, (p) ** self.alpha)
            self.max_priority = max(self.max_priority, p)


# ===========================
#  Red Dueling
# ===========================
class DuelingDQN(nn.Module):
    """
    Backbone conv + (stream Valor) + (stream Ventaja). Q = V + (A - mean(A))
    Espera tensores (B,C,H,W) float/uint8; si llega (B,H,W,C) permuta.
    """
    def __init__(self, input_shape: Tuple[int, int, int], n_actions: int):
        super().__init__()
        c, h, w = input_shape
        self.expected_c = c

        self.features = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.Flatten()
        )
        with torch.no_grad():
            n_flat = self.features(torch.zeros(1, c, h, w)).shape[1]

        self.value = nn.Sequential(
            nn.Linear(n_flat, 512), nn.ReLU(inplace=True),
            nn.Linear(512, 1)
        )
        self.advantage = nn.Sequential(
            nn.Linear(n_flat, 512), nn.ReLU(inplace=True),
            nn.Linear(512, n_actions)
        )

    def forward(self, x):
        if x.ndim != 4:
            raise ValueError(f"x.ndim={x.ndim}, esperado 4")
        if x.shape[1] != self.expected_c and x.shape[-1] == self.expected_c:
            x = x.permute(0, 3, 1, 2)
        x = x.float() / 255.0
        feat = self.features(x)
        v = self.value(feat)                  # (B,1)
        a = self.advantage(feat)              # (B,A)
        q = v + (a - a.mean(dim=1, keepdim=True))
        return q


# ===========================
#  Política para play.py
# ===========================
class DQNPolicy:
    def __init__(self, q_net: 'DuelingDQN'):
        self.q_net = q_net.to(device)
        self.q_net.eval()

    @torch.no_grad()
    def __call__(self, obs: np.ndarray, info: dict) -> int:
        if obs.ndim != 3:
            raise ValueError("obs debe ser 3D")
        ob = np.expand_dims(obs, axis=0)
        ob_t = torch.from_numpy(ob).to(device)
        q = self.q_net(ob_t)
        return int(torch.argmax(q, dim=1).item())


# ===========================
#  Utilidades de logging
# ===========================
def _moving_average(x: List[float], k: int = 100):
    if len(x) == 0:
        return []
    out = []
    s = 0.0
    q = []
    for i, v in enumerate(x):
        q.append(v)
        s += v
        if len(q) > k:
            s -= q.pop(0)
        out.append(s / len(q))
    return out

def _save_plot_and_csv(checkpoint_dir: str, rewards: List[float], episode: int):
    os.makedirs(checkpoint_dir, exist_ok=True)
    # CSV
    csv_path = os.path.join(checkpoint_dir, "rewards_log.csv")
    new_file = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as f:
        w = csv.writer(f)
        if new_file:
            w.writerow(["episode", "reward"])
        w.writerow([episode, rewards[-1]])

    # PNG
    plt.figure(figsize=(8,4.5))
    plt.plot(rewards, label="Reward")
    ma = _moving_average(rewards, k=100)
    if len(ma) > 0:
        plt.plot(ma, label="Moving Avg (100)")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.title("Training Rewards")
    plt.legend()
    png_path = os.path.join(checkpoint_dir, f"rewards_ep{episode}.png")
    plt.tight_layout()
    plt.savefig(png_path, dpi=120)
    plt.close()
    print(f"[LOG] Guardadas gráfica y CSV en: {png_path} / {csv_path}")


# ===========================
#  Entrenamiento
# ===========================
def train_dueling_double_dqn_per(
    checkpoint_dir: str,
    total_episodes: int = 50000,
    buffer_capacity: int = 100_000,
    batch_size: int = 32,
    gamma: float = 0.99,
    lr: float = 1e-4,
    start_epsilon: float = 1.0,
    end_epsilon: float = 0.1,
    epsilon_decay_episodes: int = 30000,
    target_update_interval: int = 1000,     # en pasos
    train_start: int = 10_000,
    per_alpha: float = 0.6,
    per_beta_start: float = 0.4,
    per_beta_end: float = 1.0,
    per_beta_anneal_episodes: int = 50000,
    per_eps: float = 1e-6,
    save_interval: int = 500,               # guardar modelo
    plot_interval: int = 200,               # guardar PNG/CSV
    max_steps_per_episode: int | None = None,
    seed: int = 42,
):
    os.makedirs(checkpoint_dir, exist_ok=True)

    env = make_galaxian_env(seed=seed, render_mode=None)

    # Detecta layout y fija input_shape=(C,H,W)
    obs, _ = env.reset()
    if obs.ndim != 3:
        env.close()
        raise ValueError("obs.ndim inesperado")
    if obs.shape[0] in (1,3,4):
        input_shape = (obs.shape[0], obs.shape[1], obs.shape[2])
    else:
        input_shape = (obs.shape[2], obs.shape[0], obs.shape[1])
    n_actions = env.action_space.n

    q_net = DuelingDQN(input_shape, n_actions).to(device)
    target_net = DuelingDQN(input_shape, n_actions).to(device)
    target_net.load_state_dict(q_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(q_net.parameters(), lr=lr)
    buffer = PrioritizedReplayBuffer(buffer_capacity, alpha=per_alpha, eps=per_eps)

    rewards_log: List[float] = []
    global_step = 0

    def epsilon_by_episode(ep):
        if ep >= epsilon_decay_episodes:
            return end_epsilon
        frac = ep / float(epsilon_decay_episodes)
        return start_epsilon + frac * (end_epsilon - start_epsilon)

    def beta_by_episode(ep):
        # lineal de per_beta_start -> per_beta_end
        frac = min(1.0, ep / float(per_beta_anneal_episodes))
        return per_beta_start + frac * (per_beta_end - per_beta_start)

    for episode in range(1, total_episodes + 1):
        obs, _ = env.reset()
        done = False
        total_reward = 0.0
        steps_in_ep = 0

        eps = epsilon_by_episode(episode)
        beta = beta_by_episode(episode)

        while not done:
            global_step += 1
            steps_in_ep += 1

            # ε-greedy sobre q_net
            if random.random() < eps:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    ob = np.expand_dims(obs, axis=0)
                    ob_t = torch.from_numpy(ob).to(device)
                    q_vals = q_net(ob_t)
                    action = int(torch.argmax(q_vals, dim=1).item())

            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += float(reward)

            buffer.add(obs, action, reward, next_obs, done)
            obs = next_obs

            # Entrenamiento (cuando hay suficientes muestras)
            if len(buffer) >= train_start:
                idxs, isw, batch = buffer.sample(batch_size, beta=beta)
                s, a, r, ns, d = batch

                s_t  = torch.from_numpy(s).to(device)
                ns_t = torch.from_numpy(ns).to(device)
                a_t  = torch.from_numpy(a).long().to(device)
                r_t  = torch.from_numpy(r).float().to(device)
                d_t  = torch.from_numpy(d).float().to(device)
                w_t  = torch.from_numpy(isw).float().to(device)

                # Q_online(s,a)
                q = q_net(s_t).gather(1, a_t.unsqueeze(1)).squeeze(1)

                # Double DQN target:
                # a* = argmax_a Q_online(ns,a)
                with torch.no_grad():
                    q_online_ns = q_net(ns_t)
                    a_star = torch.argmax(q_online_ns, dim=1, keepdim=True)

                    q_target_ns = target_net(ns_t)
                    next_q = q_target_ns.gather(1, a_star).squeeze(1)

                    target = r_t + gamma * next_q * (1.0 - d_t)

                td_error = target - q
                loss = (w_t * td_error.pow(2)).mean()

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(q_net.parameters(), 10.0)
                optimizer.step()

                # actualizar prioridades con |TD error|
                new_prios = td_error.detach().abs().cpu().numpy() + per_eps
                buffer.update_priorities(idxs, new_prios)

            # sync target
            if global_step % target_update_interval == 0:
                target_net.load_state_dict(q_net.state_dict())

            if max_steps_per_episode is not None and steps_in_ep >= max_steps_per_episode:
                break

        rewards_log.append(total_reward)
        print(f"[Dueling-DDQN+PER] Ep {episode}/{total_episodes} | R: {total_reward:.1f} | eps={eps:.3f} | beta={beta:.3f} | buf={len(buffer)}")

        # Guardar checkpoint
        if episode % save_interval == 0:
            ckpt_path = os.path.join(checkpoint_dir, f"dueling_ddqn_per_ep{episode}.pth")
            torch.save({
                "q_net": q_net.state_dict(),
                "target_net": target_net.state_dict(),
                "optimizer": optimizer.state_dict(),
                "episode": episode,
                "global_step": global_step,
                "rewards": rewards_log,
                "input_shape": input_shape,
                "n_actions": n_actions,
            }, ckpt_path)
            print(f"[CKPT] Guardado: {ckpt_path}")

        # Guardar gráfica + CSV
        if episode % plot_interval == 0:
            _save_plot_and_csv(checkpoint_dir, rewards_log, episode)

    env.close()

    # Guardar modelo final
    final_path = os.path.join(checkpoint_dir, "dueling_ddqn_per_final.pth")
    torch.save(q_net.state_dict(), final_path)
    print(f"[DONE] Modelo final guardado en: {final_path}")

    # Gráfica final
    _save_plot_and_csv(checkpoint_dir, rewards_log, total_episodes)

    return q_net

In [None]:
model = train_dueling_double_dqn_per(
    checkpoint_dir=f"{BASE_DIR}/dueling_ddqn_per",
    total_episodes=100000,            # ajusta según GPU/tiempo
    save_interval=500,
    plot_interval=500
)

[1;30;43mSe truncaron las últimas líneas 5000 del resultado de transmisión.[0m
[Dueling-DDQN+PER] Ep 1892/100000 | R: 600.0 | eps=0.943 | beta=0.423 | buf=131072
[Dueling-DDQN+PER] Ep 1893/100000 | R: 1440.0 | eps=0.943 | beta=0.423 | buf=131072
[Dueling-DDQN+PER] Ep 1894/100000 | R: 740.0 | eps=0.943 | beta=0.423 | buf=131072
[Dueling-DDQN+PER] Ep 1895/100000 | R: 430.0 | eps=0.943 | beta=0.423 | buf=131072
[Dueling-DDQN+PER] Ep 1896/100000 | R: 1100.0 | eps=0.943 | beta=0.423 | buf=131072
[Dueling-DDQN+PER] Ep 1897/100000 | R: 680.0 | eps=0.943 | beta=0.423 | buf=131072
[Dueling-DDQN+PER] Ep 1898/100000 | R: 330.0 | eps=0.943 | beta=0.423 | buf=131072
[Dueling-DDQN+PER] Ep 1899/100000 | R: 990.0 | eps=0.943 | beta=0.423 | buf=131072
[Dueling-DDQN+PER] Ep 1900/100000 | R: 510.0 | eps=0.943 | beta=0.423 | buf=131072
[Dueling-DDQN+PER] Ep 1901/100000 | R: 670.0 | eps=0.943 | beta=0.423 | buf=131072
[Dueling-DDQN+PER] Ep 1902/100000 | R: 900.0 | eps=0.943 | beta=0.423 | buf=131072
[Due

In [9]:
def train_dueling_double_dqn_per(
    checkpoint_dir: str,
    total_episodes: int = 5000,
    buffer_capacity: int = 100_000,
    batch_size: int = 32,
    gamma: float = 0.99,
    lr: float = 5e-5,
    start_epsilon: float = 1.0,
    end_epsilon: float = 0.05,
    epsilon_decay_episodes: int = 500,
    target_update_interval: int = 1000,     # en pasos
    train_start: int = 1000,
    per_alpha: float = 0.6,
    per_beta_start: float = 0.4,
    per_beta_end: float = 1.0,
    per_beta_anneal_episodes: int = 500,
    per_eps: float = 1e-6,
    save_interval: int = 500,               # guardar modelo
    plot_interval: int = 200,               # guardar PNG/CSV
    max_steps_per_episode: int | None = None,
    seed: int = 42,
    resume_from: str | None = None,         # <-- NUEVO: ruta a checkpoint .pth
):
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Creamos el entorno
    env = make_galaxian_env(seed=seed, render_mode=None)

    # Detectar layout de observación
    obs, _ = env.reset()
    if obs.ndim != 3:
        env.close()
        raise ValueError("obs.ndim inesperado")
    if obs.shape[0] in (1, 3, 4):
        detected_input_shape = (obs.shape[0], obs.shape[1], obs.shape[2])
    else:
        detected_input_shape = (obs.shape[2], obs.shape[0], obs.shape[1])
    detected_n_actions = env.action_space.n

    # ------------------------------
    # Inicializar redes, optimizador, logs
    # ------------------------------
    start_episode = 1
    global_step = 0
    rewards_log: List[float] = []

    if resume_from is not None:
        print(f"[RESUME] Cargando checkpoint desde {resume_from}")
        ckpt = torch.load(resume_from, map_location=device, weights_only=False)

        # Usar input_shape / n_actions del checkpoint (más confiable)
        input_shape = tuple(ckpt.get("input_shape", detected_input_shape))
        n_actions = int(ckpt.get("n_actions", detected_n_actions))

        q_net = DuelingDQN(input_shape, n_actions).to(device)
        target_net = DuelingDQN(input_shape, n_actions).to(device)

        q_net.load_state_dict(ckpt["q_net"])
        target_net.load_state_dict(ckpt["target_net"])

        optimizer = optim.Adam(q_net.parameters(), lr=lr)
        optimizer.load_state_dict(ckpt["optimizer"])

        start_episode = int(ckpt["episode"]) + 1
        global_step = int(ckpt.get("global_step", 0))
        rewards_log = list(ckpt.get("rewards", []))

        print(f"[RESUME] Reanudando desde episodio {start_episode}, global_step={global_step}")
    else:
        # Entrenamiento desde cero
        input_shape = detected_input_shape
        n_actions = detected_n_actions

        q_net = DuelingDQN(input_shape, n_actions).to(device)
        target_net = DuelingDQN(input_shape, n_actions).to(device)
        target_net.load_state_dict(q_net.state_dict())
        target_net.eval()

        optimizer = optim.Adam(q_net.parameters(), lr=lr)
        start_episode = 1
        global_step = 0
        rewards_log = []

    # Replay buffer PER NUEVO (no se guardó en el checkpoint)
    buffer = PrioritizedReplayBuffer(buffer_capacity, alpha=per_alpha, eps=per_eps)

    def epsilon_by_episode(ep):
        if ep >= epsilon_decay_episodes:
            return end_epsilon
        frac = ep / float(epsilon_decay_episodes)
        return start_epsilon + frac * (end_epsilon - start_epsilon)

    def beta_by_episode(ep):
        frac = min(1.0, ep / float(per_beta_anneal_episodes))
        return per_beta_start + frac * (per_beta_end - per_beta_start)

    # ------------------------------
    # Bucle principal de entrenamiento
    # ------------------------------
    for episode in range(start_episode, total_episodes + 1):
        obs, _ = env.reset()
        done = False
        total_reward = 0.0
        steps_in_ep = 0

        eps = epsilon_by_episode(episode)
        beta = beta_by_episode(episode)

        while not done:
            global_step += 1
            steps_in_ep += 1

            # Política ε-greedy sobre q_net
            if random.random() < eps:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    ob = np.expand_dims(obs, axis=0)
                    ob_t = torch.from_numpy(ob).to(device)
                    q_vals = q_net(ob_t)
                    action = int(torch.argmax(q_vals, dim=1).item())

            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += float(reward)

            buffer.add(obs, action, reward, next_obs, done)
            obs = next_obs

            # Entrenamiento solo si hay suficientes muestras
            if len(buffer) >= train_start:
                idxs, isw, batch = buffer.sample(batch_size, beta=beta)
                s, a, r, ns, d = batch

                s_t  = torch.from_numpy(s).to(device)
                ns_t = torch.from_numpy(ns).to(device)
                a_t  = torch.from_numpy(a).long().to(device)
                r_t  = torch.from_numpy(r).float().to(device)
                d_t  = torch.from_numpy(d).float().to(device)
                w_t  = torch.from_numpy(isw).float().to(device)

                # Q(s,a) online
                q = q_net(s_t).gather(1, a_t.unsqueeze(1)).squeeze(1)

                # Double DQN: a* con q_net, eval con target_net
                with torch.no_grad():
                    q_online_ns = q_net(ns_t)
                    a_star = torch.argmax(q_online_ns, dim=1, keepdim=True)

                    q_target_ns = target_net(ns_t)
                    next_q = q_target_ns.gather(1, a_star).squeeze(1)

                    target = r_t + gamma * next_q * (1.0 - d_t)

                td_error = target - q
                loss = (w_t * td_error.pow(2)).mean()

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(q_net.parameters(), 10.0)
                optimizer.step()

                new_prios = td_error.detach().abs().cpu().numpy() + per_eps
                buffer.update_priorities(idxs, new_prios)

            if global_step % target_update_interval == 0:
                target_net.load_state_dict(q_net.state_dict())

            if max_steps_per_episode is not None and steps_in_ep >= max_steps_per_episode:
                break

        rewards_log.append(total_reward)
        print(f"[Dueling-DDQN+PER] Ep {episode}/{total_episodes} | R: {total_reward:.1f} | eps={eps:.3f} | beta={beta:.3f} | buf={len(buffer)}")

        # Guardar checkpoint
        if episode % save_interval == 0:
            ckpt_path = os.path.join(checkpoint_dir, f"dueling_ddqn_per_ep{episode}.pth")
            torch.save({
                "q_net": q_net.state_dict(),
                "target_net": target_net.state_dict(),
                "optimizer": optimizer.state_dict(),
                "episode": episode,
                "global_step": global_step,
                "rewards": rewards_log,
                "input_shape": input_shape,
                "n_actions": n_actions,
            }, ckpt_path)
            print(f"[CKPT] Guardado: {ckpt_path}")

        # Guardar gráfica + CSV
        if episode % plot_interval == 0:
            _save_plot_and_csv(checkpoint_dir, rewards_log, episode)

    env.close()

    final_path = os.path.join(checkpoint_dir, "dueling_ddqn_per_final.pth")
    torch.save(q_net.state_dict(), final_path)
    print(f"[DONE] Modelo final guardado en: {final_path}")
    _save_plot_and_csv(checkpoint_dir, rewards_log, total_episodes)

    return q_net


In [None]:
BASE_DIR = "/content/drive/MyDrive/galaxian_rl"
ckpt = f"{BASE_DIR}/dueling_ddqn_per/dueling_ddqn_per_ep9500.pth"

q_net = train_dueling_double_dqn_per(
    checkpoint_dir=f"{BASE_DIR}/dueling_ddqn_per",
    total_episodes=20000,          # el total final que quieres alcanzar
    resume_from=ckpt,              # <-- aquí se reanuda
    save_interval=500,
    plot_interval=500
)


[RESUME] Cargando checkpoint desde /content/drive/MyDrive/galaxian_rl/dueling_ddqn_per/dueling_ddqn_per_ep9500.pth
[RESUME] Reanudando desde episodio 9501, global_step=1774076
[Dueling-DDQN+PER] Ep 9501/20000 | R: 3630.0 | eps=0.050 | beta=1.000 | buf=507
[Dueling-DDQN+PER] Ep 9502/20000 | R: 2380.0 | eps=0.050 | beta=1.000 | buf=850
[Dueling-DDQN+PER] Ep 9503/20000 | R: 1460.0 | eps=0.050 | beta=1.000 | buf=1141
[Dueling-DDQN+PER] Ep 9504/20000 | R: 2570.0 | eps=0.050 | beta=1.000 | buf=1485
[Dueling-DDQN+PER] Ep 9505/20000 | R: 590.0 | eps=0.050 | beta=1.000 | buf=1637
[Dueling-DDQN+PER] Ep 9506/20000 | R: 2130.0 | eps=0.050 | beta=1.000 | buf=1991
[Dueling-DDQN+PER] Ep 9507/20000 | R: 740.0 | eps=0.050 | beta=1.000 | buf=2141
[Dueling-DDQN+PER] Ep 9508/20000 | R: 1760.0 | eps=0.050 | beta=1.000 | buf=2400
[Dueling-DDQN+PER] Ep 9509/20000 | R: 2670.0 | eps=0.050 | beta=1.000 | buf=2819
[Dueling-DDQN+PER] Ep 9510/20000 | R: 1530.0 | eps=0.050 | beta=1.000 | buf=3165
[Dueling-DDQN+PER]