# Callbacks para mejorar el entrenamiento en DQN

Este notebook implementa callbacks personalizados para mejorar el entrenamiento del agente DQN en Space Invaders.

In [None]:
from __future__ import division

from PIL import Image
import numpy as np
import gym
import matplotlib.pyplot as plt
import json
import os

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten, Convolution2D, Permute, BatchNormalization, Dropout
from tensorflow.keras.optimizers import Adam, RMSprop
import tensorflow.keras.backend as K

from rl.agents.dqn import DQNAgent
from rl.policy import LinearAnnealedPolicy, BoltzmannQPolicy, EpsGreedyQPolicy
from rl.memory import SequentialMemory
from rl.core import Processor
from rl.callbacks import FileLogger, ModelIntervalCheckpoint, Callback

In [None]:
# Base configuration
INPUT_SHAPE = (84, 84)
WINDOW_LENGTH = 4  # stack 4 consecutive frames so that motion can be perceived

env_name = 'SpaceInvaders-v0'
env = gym.make(env_name)
nb_actions = env.action_space.n

# Seed NumPy's global RNG and the environment with the same fixed value.
np.random.seed(123)
env.seed(123)

print(f"Numero de acciones disponibles: {nb_actions!s}")
print(f"Formato de las observaciones: {env.observation_space!s}")

In [None]:
class AtariProcessor(Processor):
    """Adapt raw environment frames and rewards before they reach the agent.

    Observations are resized to ``INPUT_SHAPE`` and converted to 8-bit
    grayscale, state batches are cast to ``float32`` and divided by 255,
    and rewards are clipped to ``[-1, 1]``.
    """

    def process_observation(self, observation):
        """Return ``observation`` as a ``uint8`` grayscale array of shape ``INPUT_SHAPE``."""
        assert observation.ndim == 3  # (height, width, channel)
        # Resize first, then convert to grayscale ('L' mode).
        frame = Image.fromarray(observation).resize(INPUT_SHAPE).convert('L')
        grayscale = np.array(frame)
        assert grayscale.shape == INPUT_SHAPE
        # Keep frames as uint8 to save memory; the float conversion is
        # deferred to process_state_batch.
        return grayscale.astype('uint8')

    def process_state_batch(self, batch):
        """Cast ``batch`` to ``float32`` and divide by 255 (normalizes uint8 frames to [0, 1])."""
        return batch.astype('float32') / 255.

    def process_reward(self, reward):
        """Clip ``reward`` to ``[-1, 1]`` to stabilize learning."""
        return np.clip(reward, -1., 1.)

In [None]:
# Callback that saves weights at the end of episodes
class EpisodeCheckpoint(Callback):
    """Save weights every ``interval`` episodes and whenever the reward improves.

    Args:
        filepath: ``str.format`` template for the output path. It receives
            ``episode`` (the running episode count, or the string ``'best'``
            for the best-reward checkpoint) plus every entry of the
            episode-end ``logs``.
        interval: Save a periodic checkpoint every this many episodes.
        verbose: Print a message for each save when greater than 0.
    """

    def __init__(self, filepath, interval=1, verbose=1):
        super(EpisodeCheckpoint, self).__init__()
        self.filepath = filepath
        self.interval = interval
        self.verbose = verbose
        self.episode = 0  # episodes seen by this callback (1-based after the first episode)
        self.best_reward = -np.inf

    def _format_path(self, episode, logs):
        """Fill the path template with ``logs`` plus the given ``episode`` value."""
        # Merge into one dict instead of ``format(episode=..., **logs)`` so a
        # log entry that happens to be called 'episode' cannot raise
        # "got multiple values for keyword argument".
        fields = dict(logs)
        fields['episode'] = episode
        return self.filepath.format(**fields)

    def on_episode_end(self, episode, logs=None):
        """Write the periodic and/or best-reward checkpoint for the finished episode."""
        logs = logs or {}
        self.episode += 1

        # Periodic checkpoint every 'interval' episodes
        if self.episode % self.interval == 0:
            filepath = self._format_path(self.episode, logs)
            if self.verbose > 0:
                print(f'\nEpisodio {self.episode}: guardando pesos en {filepath}')
            self.model.save_weights(filepath, overwrite=True)

        # Best-reward checkpoint; a missing reward never counts as an improvement
        reward = logs.get('episode_reward', -np.inf)
        if reward > self.best_reward:
            self.best_reward = reward
            # Use the same template fields as the periodic checkpoint so that
            # templates referencing log entries work for both paths.
            best_filepath = self._format_path('best', logs)
            if self.verbose > 0:
                print(f'\nNueva mejor recompensa: {self.best_reward:.2f}, guardando en {best_filepath}')
            self.model.save_weights(best_filepath, overwrite=True)

In [None]:
# Callback that records and plots training progress
class TrainingVisualization(Callback):
    """Track per-episode reward, loss and MAE; dump them to JSON and plot them.

    keras-rl's episode-end logs do not carry training metrics; they are
    reported per step under ``logs['metrics']``, ordered like
    ``self.model.metrics_names``. This callback therefore collects the step
    metrics and averages them (ignoring NaN entries) at the end of each
    episode. A metric that is unavailable is recorded as 0.

    Args:
        log_file: Path of the JSON file rewritten after every episode.
        plot_interval: Plot the curves every this many episodes.
    """

    def __init__(self, log_file, plot_interval=5):
        super(TrainingVisualization, self).__init__()
        self.log_file = log_file
        self.plot_interval = plot_interval
        self.episode_rewards = []
        self.episode_losses = []
        self.episode_maes = []
        self.episode = 0
        self._step_metrics = []  # metric vectors of the episode in progress

    def on_episode_begin(self, episode, logs=None):
        """Start a fresh metric buffer for the new episode."""
        self._step_metrics = []

    def on_step_end(self, step, logs=None):
        """Remember the metric vector reported for this step, if any."""
        metrics = (logs or {}).get('metrics')
        if metrics is not None:
            self._step_metrics.append(metrics)

    def _episode_metric(self, logs, *names):
        """Return the episode value of the first metric in ``names`` that is available.

        A value present directly in ``logs`` wins; otherwise the NaN-ignoring
        mean of the buffered step metrics is used. Returns 0.0 when nothing
        is available.
        """
        for name in names:
            if name in logs:
                return float(logs[name])

        metric_names = list(getattr(self.model, 'metrics_names', None) or [])
        for name in names:
            if name not in metric_names:
                continue
            idx = metric_names.index(name)
            values = np.array(
                [step[idx] for step in self._step_metrics if len(step) > idx],
                dtype=float,
            )
            # Steps without a training update are expected to report NaN.
            values = values[~np.isnan(values)]
            if values.size:
                return float(values.mean())
            break
        return 0.0

    def on_episode_end(self, episode, logs=None):
        """Record the episode's metrics, rewrite the JSON log and maybe plot."""
        logs = logs or {}
        self.episode += 1

        # Store plain Python floats: NumPy scalars such as np.float32 are
        # not JSON serializable.
        self.episode_rewards.append(float(logs.get('episode_reward', 0)))
        self.episode_losses.append(self._episode_metric(logs, 'loss'))
        # NOTE(review): the MAE metric may be exposed under either name
        # depending on the Keras version - confirm against metrics_names.
        self.episode_maes.append(
            self._episode_metric(logs, 'mae', 'mean_absolute_error'))

        # Rewrite the JSON file with the full history
        data = {
            'episode_rewards': self.episode_rewards,
            'episode_losses': self.episode_losses,
            'episode_maes': self.episode_maes
        }
        with open(self.log_file, 'w') as f:
            json.dump(data, f)

        # Plot progress every plot_interval episodes
        if self.episode % self.plot_interval == 0:
            self.visualize_training()

    def visualize_training(self):
        """Plot reward, loss and MAE per episode side by side."""
        panels = [
            (self.episode_rewards, 'Recompensas por episodio', 'Recompensa'),
            (self.episode_losses, 'Pérdida por episodio', 'Pérdida'),
            (self.episode_maes, 'MAE por episodio', 'MAE'),
        ]

        fig = plt.figure(figsize=(15, 5))
        for position, (series, title, ylabel) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.plot(series)
            plt.title(title)
            plt.xlabel('Episodio')
            plt.ylabel(ylabel)

        plt.tight_layout()
        plt.show()
        # Release the figure so repeated plots do not accumulate in memory
        # on backends where show() does not close it.
        plt.close(fig)

In [None]:
# Callback that decays the learning rate during training
class LearningRateScheduler(Callback):
    """Multiply the optimizer's learning rate by ``decay_factor`` every ``decay_episodes`` episodes.

    Args:
        initial_lr: Stored for reference only; this callback never applies
            it, the optimizer's current learning rate is what gets decayed.
        decay_factor: Multiplier applied to the current learning rate.
        decay_episodes: Number of episodes between two decays.
    """

    def __init__(self, initial_lr=0.00025, decay_factor=0.1, decay_episodes=50):
        super(LearningRateScheduler, self).__init__()
        self.initial_lr = initial_lr
        self.decay_factor = decay_factor
        self.decay_episodes = decay_episodes
        self.episode = 0

    def _learning_rate_variable(self):
        """Return the backend variable holding the optimizer's learning rate."""
        # keras-rl hands the *agent* to callbacks as ``self.model``. DQNAgent
        # compiles the user's optimizer into ``trainable_model`` and has no
        # ``optimizer`` attribute itself; fall back to ``self.model`` for
        # plain Keras models.
        owner = getattr(self.model, 'trainable_model', self.model)
        optimizer = owner.optimizer
        # Prefer the current attribute name; ``lr`` is the legacy alias.
        if hasattr(optimizer, 'learning_rate'):
            return optimizer.learning_rate
        return optimizer.lr

    def on_episode_end(self, episode, logs=None):
        """Decay the learning rate when the episode count hits a multiple of ``decay_episodes``."""
        logs = logs or {}
        self.episode += 1

        # Adjust the learning rate every decay_episodes episodes
        if self.episode % self.decay_episodes == 0:
            lr_variable = self._learning_rate_variable()
            old_lr = K.get_value(lr_variable)
            new_lr = old_lr * self.decay_factor
            K.set_value(lr_variable, new_lr)
            print(f'\nEpisodio {self.episode}: tasa de aprendizaje ajustada de {old_lr:.6f} a {new_lr:.6f}')

In [None]:
# CNN model for DQN
# States arrive as (window, height, width): WINDOW_LENGTH stacked frames.
input_shape = (WINDOW_LENGTH,) + INPUT_SHAPE
model = Sequential()

# Move the frame axis to wherever the backend expects the channel axis.
if K.image_data_format() == 'channels_last':
    # (window, height, width) -> (height, width, window)
    model.add(Permute((2, 3, 1), input_shape=input_shape))
elif K.image_data_format() == 'channels_first':
    # Already channel-first: identity permutation, kept to declare the input shape.
    model.add(Permute((1, 2, 3), input_shape=input_shape))
else:
    raise RuntimeError('Unknown image_dim_ordering.')

# Three convolution stages, each followed by batch normalization and ReLU
model.add(Convolution2D(32, (8, 8), strides=(4, 4)))
model.add(BatchNormalization())
model.add(Activation('relu'))

model.add(Convolution2D(64, (4, 4), strides=(2, 2)))
model.add(BatchNormalization())
model.add(Activation('relu'))

model.add(Convolution2D(64, (3, 3), strides=(1, 1)))
model.add(BatchNormalization())
model.add(Activation('relu'))

# Fully connected head
model.add(Flatten())
model.add(Dense(512))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dropout(0.2))

# One linear output per action
model.add(Dense(nb_actions))
model.add(Activation('linear'))

# summary() prints the table itself and returns None, so wrapping it in
# print() would emit a stray "None" line after the table.
model.summary()

In [None]:
# DQN agent configuration
processor = AtariProcessor()
memory = SequentialMemory(limit=50000, window_length=WINDOW_LENGTH)

# Epsilon-greedy policy whose 'eps' attribute is annealed between value_max
# and value_min over nb_steps; value_test is presumably the epsilon used
# when testing (confirm in the keras-rl policy docs).
policy = LinearAnnealedPolicy(
    EpsGreedyQPolicy(),
    attr='eps',
    value_max=1.0,
    value_min=0.1,
    value_test=0.05,
    nb_steps=50000,
)

dqn = DQNAgent(
    model=model,
    nb_actions=nb_actions,
    policy=policy,
    memory=memory,
    processor=processor,
    nb_steps_warmup=1000,
    gamma=0.99,
    target_model_update=1000,
    train_interval=4,
    delta_clip=1.0,
)

dqn.compile(Adam(learning_rate=0.00025), metrics=['mae'])

In [None]:
# Callback configuration
weights_filename = 'dqn_callbacks_{}_weights.h5f'.format(env_name)
# '{episode}' is left as a placeholder; EpisodeCheckpoint fills it in per save.
checkpoint_weights_filename = 'dqn_callbacks_' + env_name + '_weights_episode_{episode}.h5f'
log_filename = 'dqn_callbacks_{}_log.json'.format(env_name)
visualization_log = 'dqn_callbacks_{}_visualization.json'.format(env_name)

# Create the checkpoint directory if it does not exist yet. exist_ok avoids
# the check-then-create race of testing os.path.exists() first.
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Custom callbacks
callbacks = [
    # Save weights every 5 episodes
    EpisodeCheckpoint(os.path.join(checkpoint_dir, checkpoint_weights_filename), interval=5),

    # Plot progress every 5 episodes
    TrainingVisualization(visualization_log, plot_interval=5),

    # Halve the learning rate every 50 episodes
    LearningRateScheduler(initial_lr=0.00025, decay_factor=0.5, decay_episodes=50),

    # Standard keras-rl file logger
    FileLogger(log_filename, interval=100)
]

In [None]:
# Train the agent
dqn.fit(
    env,
    nb_steps=50000,
    callbacks=callbacks,
    visualize=False,
    log_interval=1000,
)

# Save the final weights
dqn.save_weights(
    weights_filename,
    overwrite=True,
)

In [None]:
# Test with the best weights
best_weights_filename = os.path.join(checkpoint_dir, 'dqn_callbacks_' + env_name + '_weights_episode_best.h5f')
# The path ends in '.h5f', not '.h5', so tf.keras saves it as a TensorFlow
# checkpoint: '<path>.index' plus data shards, with no file at '<path>'
# itself. Accept either layout, otherwise the best weights are never found.
if os.path.exists(best_weights_filename) or os.path.exists(best_weights_filename + '.index'):
    print(f"Cargando los mejores pesos desde: {best_weights_filename}")
    dqn.load_weights(best_weights_filename)
else:
    print(f"No se encontraron los mejores pesos, usando los pesos finales: {weights_filename}")
    dqn.load_weights(weights_filename)

# Run n test episodes to measure the final reward
dqn.test(env, nb_episodes=10, visualize=True)