#### Implementación de Prioritized Experience Replay

# In [None]:
# Prioritized Experience Replay (PER) memory.
class PrioritizedMemory(Memory):
    """Replay memory that samples transitions in proportion to their priority.

    Transitions are stored in fixed-size ring buffers of capacity ``limit``.
    A ``SumTree`` with one leaf per buffer slot holds the sampling priorities
    (leaf k <-> slot k). New transitions receive ``max_priority``, the largest
    priority assigned so far (at least 1.0). ``update_priorities`` replaces a
    slot's priority with ``(|td_error| + 1e-6) ** alpha``.

    Args:
        limit: Capacity of the ring buffers and number of tree leaves.
        alpha: How strongly priorities shape sampling (0 = uniform,
            1 = fully proportional to priority).
        beta: Initial importance-sampling exponent (0 = no correction,
            1 = full correction). It is annealed towards 1 by
            ``beta_increment`` on every priority-based sample.
        beta_increment: Amount added to ``beta`` per priority-based sample.
        window_length: Forwarded to the base ``Memory`` class.
    """

    def __init__(self, limit, alpha=0.6, beta=0.4, beta_increment=0.001, window_length=1):
        # keras-rl's Memory.__init__ requires window_length; calling it with
        # no arguments raises TypeError.
        super(PrioritizedMemory, self).__init__(window_length=window_length)
        self.limit = limit
        self.alpha = alpha  # how much priority is used (0 = uniform, 1 = priority only)
        self.beta = beta    # importance-sampling correction (0 = none, 1 = full)
        self.beta_increment = beta_increment  # beta increase per sampling call
        self.window_length = window_length

        # Ring buffers, indexed by slot position.
        self.actions = np.zeros(limit, dtype=np.uint8)  # action ids must fit in 0..255
        self.rewards = np.zeros(limit, dtype=np.float32)
        # Builtin bool: the ``np.bool`` alias was removed in NumPy 1.24.
        self.terminals = np.zeros(limit, dtype=bool)
        self.observations = [None] * limit

        # PER state. ``priorities`` mirrors the tree leaves per slot.
        self.priorities = np.zeros(limit, dtype=np.float32)
        self.tree = SumTree(limit)
        self.max_priority = 1.0

        self.position = 0    # next slot to write
        self.nb_entries = 0  # number of filled slots (<= limit)

    def append(self, observation, action, reward, terminal, training=True):
        """Store one transition in the next ring-buffer slot with priority ``max_priority``."""
        super(PrioritizedMemory, self).append(observation, action, reward, terminal, training=training)

        slot = self.position
        self.observations[slot] = observation
        self.actions[slot] = action
        self.rewards[slot] = reward
        self.terminals[slot] = terminal

        # SumTree.add writes at its own cursor. That cursor starts at 0 and
        # advances mod ``limit`` exactly like ``self.position``, so the leaf
        # written here belongs to ``slot``.
        self.tree.add(self.max_priority, slot)
        self.priorities[slot] = self.max_priority

        self.position = (slot + 1) % self.limit
        if self.nb_entries < self.limit:
            self.nb_entries += 1

    def _sample_batch_indices(self, batch_size):
        """Draw ``batch_size`` slot indices by stratified priority sampling.

        Returns:
            ``(indices, is_weights)``: the sampled slot indices and their
            importance-sampling weights, normalised so the largest is 1.

        Raises:
            ValueError: if the memory holds no transitions (the weights
                would otherwise be NaN).
        """
        total = self.tree.total()
        if self.nb_entries == 0 or total <= 0:
            raise ValueError('Cannot sample from an empty PrioritizedMemory')

        # Anneal beta towards full importance-sampling correction.
        self.beta = min(1.0, self.beta + self.beta_increment)

        # Split [0, total] into batch_size equal segments; draw one value per
        # segment so the batch covers the whole priority range.
        segment = total / batch_size
        indices = []
        priorities = []
        for i in range(batch_size):
            s = random.uniform(segment * i, segment * (i + 1))
            idx, p, _ = self.tree.get(s)
            indices.append(idx)
            priorities.append(p)

        sampling_probabilities = np.array(priorities) / total
        is_weights = np.power(self.nb_entries * sampling_probabilities, -self.beta)
        is_weights /= is_weights.max()  # normalise to a maximum of 1
        return indices, is_weights

    def sample(self, batch_size, batch_idxs=None):
        """Return a dict batch of transitions.

        When ``batch_idxs`` is None, slots are drawn by priority and
        importance-sampling weights are computed. Otherwise the given slots
        are used, each with weight 1.

        Keys: 'is_weights', 'batch_idxs', 'observations0', 'actions',
        'rewards', 'terminals', 'observations1'.
        """
        if batch_idxs is None:
            batch_idxs, is_weights = self._sample_batch_indices(batch_size)
        else:
            is_weights = np.ones((len(batch_idxs),), dtype=np.float32)

        batch = {
            'is_weights': is_weights,
            'batch_idxs': batch_idxs,
            'observations0': [self.observations[idx] for idx in batch_idxs],
            'actions': self.actions[batch_idxs],
            'rewards': self.rewards[batch_idxs],
            'terminals': self.terminals[batch_idxs],
        }

        # Successor observation: the next slot, or the same slot for terminal
        # transitions.
        # NOTE(review): for the most recently written slot, the "next" slot is
        # either still None (memory not full) or the oldest transition (after
        # wrap-around). Confirm callers tolerate this.
        batch['observations1'] = [
            self.observations[idx if self.terminals[idx] else (idx + 1) % self.limit]
            for idx in batch_idxs
        ]
        return batch

    def update_priorities(self, batch_idxs, td_errors):
        """Set each slot's priority to ``(|td_error| + 1e-6) ** alpha``."""
        leaf_offset = self.tree.capacity - 1
        for idx, error in zip(batch_idxs, td_errors):
            priority = (np.abs(error) + 1e-6) ** self.alpha  # epsilon keeps priority > 0
            # batch_idxs are slot indices, but SumTree.update expects the node
            # index of the slot's leaf inside the flat tree array.
            self.tree.update(int(idx) + leaf_offset, priority)
            self.priorities[idx] = priority
            self.max_priority = max(self.max_priority, priority)

    def get_config(self):
        """Return the constructor settings (``beta`` is its current, annealed value)."""
        config = {
            'limit': self.limit,
            'alpha': self.alpha,
            'beta': self.beta,
            'beta_increment': self.beta_increment,
            'window_length': self.window_length
        }
        return config

# SumTree data structure used by PER.
class SumTree:
    """Binary sum tree over ``capacity`` leaves for proportional sampling.

    Layout of the flat ``tree`` array (size ``2 * capacity - 1``):

    - node ``i`` has children ``2i + 1`` and ``2i + 2``;
    - leaves occupy indices ``capacity - 1 .. 2 * capacity - 2``;
    - every internal node holds the sum of its children, so ``tree[0]`` is
      the total priority.

    ``data`` holds one integer payload per leaf, written in ring-buffer
    order by ``add``.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        # float64 limits rounding drift from many incremental updates.
        self.tree = np.zeros(2 * capacity - 1, dtype=np.float64)
        self.data = np.zeros(capacity, dtype=np.int32)
        self.n_entries = 0
        self.write = 0  # next leaf (0-based among leaves) to overwrite

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx``, up to and including the root."""
        # Stop at the root: its "parent" (-1 // 2 == -1) would wrap around to
        # the last array element, and with capacity 1 the recursion never ended.
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def _retrieve(self, idx, s):
        """Descend from node ``idx`` to the leaf whose cumulative range contains ``s``."""
        n_nodes = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= n_nodes:
                return idx
            right = left + 1
            # Also go left when the right subtree is empty. Floating-point
            # drift can leave ``s`` slightly above the left sum even though no
            # mass lies to the right, which would select an unused leaf.
            if s <= self.tree[left] or self.tree[right] <= 0:
                idx = left
            else:
                s -= self.tree[left]
                idx = right

    def total(self):
        """Return the sum of all leaf priorities."""
        return self.tree[0]

    def add(self, p, data):
        """Write priority ``p`` and payload ``data`` at the next leaf, overwriting the oldest when full."""
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write = (self.write + 1) % self.capacity
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, idx, p):
        """Set node ``idx`` (a tree index, not a leaf number) to ``p`` and refresh its ancestors."""
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        """Return ``(data, priority, leaf_number)`` for the leaf selected by cumulative value ``s``."""
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1

        return (self.data[data_idx], self.tree[idx], data_idx)

#### Implementación de callbacks personalizados

# In [None]:
# Callback that saves weights periodically and whenever a new best episode reward is reached.
class EpisodeCheckpoint(Callback):
    """Save agent weights every ``interval`` episodes and on each new best reward.

    Args:
        filepath: ``str.format`` template for the weights file. It receives
            ``episode`` (the callback's own 1-based episode count, or
            ``'best'`` for the best-reward checkpoint) plus every key of the
            episode ``logs``.
        interval: Save every ``interval`` episodes; values <= 0 disable
            periodic saves (best-reward saves still happen).
        verbose: Print a message on every save when > 0.
    """

    def __init__(self, filepath, interval=1, verbose=1):
        super(EpisodeCheckpoint, self).__init__()
        self.filepath = filepath
        self.interval = interval
        self.verbose = verbose
        self.episode = 0  # incremented once per finished episode
        self.best_reward = -np.inf

    def _format_path(self, episode, logs):
        """Fill the filepath template; ``episode`` overrides any same-named log key."""
        return self.filepath.format(**{**logs, 'episode': episode})

    def on_episode_end(self, episode, logs=None):
        logs = logs or {}
        self.episode += 1

        # Periodic checkpoint every 'interval' episodes.
        if self.interval > 0 and self.episode % self.interval == 0:
            filepath = self._format_path(self.episode, logs)
            if self.verbose > 0:
                print(f'\nEpisodio {self.episode}: guardando pesos en {filepath}')
            self.model.save_weights(filepath, overwrite=True)

        # Best-reward checkpoint. The logs are passed here too, so templates
        # that reference log keys (e.g. '{episode_reward}') do not raise KeyError.
        reward = logs.get('episode_reward', -np.inf)
        if reward > self.best_reward:
            self.best_reward = reward
            best_filepath = self._format_path('best', logs)
            if self.verbose > 0:
                print(f'\nNueva mejor recompensa: {self.best_reward:.2f}, guardando en {best_filepath}')
            self.model.save_weights(best_filepath, overwrite=True)

# Callback that records per-episode metrics to JSON and periodically plots them.
class TrainingVisualization(Callback):
    """Record per-episode metrics, persist them to ``log_file`` as JSON and plot them.

    Args:
        log_file: Path of the JSON file rewritten at the end of every episode.
        plot_interval: Plot every ``plot_interval`` episodes; values <= 0
            disable plotting.
    """

    def __init__(self, log_file, plot_interval=5):
        super(TrainingVisualization, self).__init__()
        self.log_file = log_file
        self.plot_interval = plot_interval
        self.episode_rewards = []
        self.episode_losses = []
        self.episode_maes = []
        self.episode = 0
        self.moving_avg_rewards = []
        self.window_size = 10  # window of the reward moving average

    @staticmethod
    def _as_builtin(value):
        """Convert NumPy scalars (e.g. float32, which json cannot encode) to Python scalars."""
        return value.item() if isinstance(value, np.generic) else value

    def on_episode_end(self, episode, logs=None):
        logs = logs or {}
        self.episode += 1

        # Record metrics as plain Python values so json.dump can serialize them.
        self.episode_rewards.append(self._as_builtin(logs.get('episode_reward', 0)))
        self.episode_losses.append(self._as_builtin(logs.get('loss', 0)))
        self.episode_maes.append(self._as_builtin(logs.get('mae', 0)))

        # Moving average over the last window_size rewards; the slice already
        # covers the whole history while it is shorter than the window.
        avg_reward = np.mean(self.episode_rewards[-self.window_size:])
        self.moving_avg_rewards.append(self._as_builtin(avg_reward))

        data = {
            'episode_rewards': self.episode_rewards,
            'episode_losses': self.episode_losses,
            'episode_maes': self.episode_maes,
            'moving_avg_rewards': self.moving_avg_rewards
        }
        with open(self.log_file, 'w', encoding='utf-8') as f:
            json.dump(data, f)

        if self.plot_interval > 0 and self.episode % self.plot_interval == 0:
            self.visualize_training()

    def visualize_training(self):
        """Draw a 2x2 grid: rewards, moving-average reward, loss and MAE per episode."""
        fig = plt.figure(figsize=(15, 10))
        panels = (
            (self.episode_rewards, 'Recompensas por episodio', 'Recompensa'),
            (self.moving_avg_rewards, f'Promedio móvil de recompensas (ventana={self.window_size})', 'Recompensa promedio'),
            (self.episode_losses, 'Pérdida por episodio', 'Pérdida'),
            (self.episode_maes, 'MAE por episodio', 'MAE'),
        )
        for position, (values, title, ylabel) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            plt.plot(values)
            plt.title(title)
            plt.xlabel('Episodio')
            plt.ylabel(ylabel)

        plt.tight_layout()
        plt.show()
        # A new figure is created on every call; close it so they do not accumulate.
        plt.close(fig)

# Callback that decays the optimizer learning rate every few episodes.
class LearningRateScheduler(Callback):
    """Multiply the learning rate by ``decay_factor`` every ``decay_episodes`` episodes.

    The rate never drops below ``min_lr``.

    Args:
        initial_lr: Stored for reference only. This callback never sets it;
            the starting rate is whatever the optimizer already holds.
        min_lr: Lower bound for the learning rate.
        decay_factor: Multiplier applied at each decay step.
        decay_episodes: Episodes between decay steps; values <= 0 disable decay.
    """

    def __init__(self, initial_lr=0.00025, min_lr=0.00001, decay_factor=0.5, decay_episodes=50):
        super(LearningRateScheduler, self).__init__()
        self.initial_lr = initial_lr
        self.min_lr = min_lr
        self.decay_factor = decay_factor
        self.decay_episodes = decay_episodes
        self.episode = 0

    def _optimizer(self):
        """Return the optimizer to adjust.

        ``self.model.optimizer`` is tried first. It falls back to
        ``self.model.trainable_model.optimizer``, because keras-rl agents
        (e.g. DQNAgent) compile their optimizer into ``trainable_model``.
        """
        optimizer = getattr(self.model, 'optimizer', None)
        if optimizer is None:
            optimizer = self.model.trainable_model.optimizer
        return optimizer

    @staticmethod
    def _lr_variable(optimizer):
        """Return the learning-rate variable.

        ``learning_rate`` is preferred; ``lr`` is the older, deprecated alias.
        """
        lr_var = getattr(optimizer, 'learning_rate', None)
        return optimizer.lr if lr_var is None else lr_var

    def on_episode_end(self, episode, logs=None):
        self.episode += 1

        # Decay only every decay_episodes episodes (guarding against modulo by zero).
        if self.decay_episodes <= 0 or self.episode % self.decay_episodes != 0:
            return

        lr_var = self._lr_variable(self._optimizer())
        old_lr = K.get_value(lr_var)
        new_lr = max(old_lr * self.decay_factor, self.min_lr)  # never below the minimum
        K.set_value(lr_var, new_lr)
        print(f'\nEpisodio {self.episode}: tasa de aprendizaje ajustada de {old_lr:.6f} a {new_lr:.6f}')

#### Implementación del procesador de observaciones

# In [None]:
class AtariProcessor(Processor):
    """Preprocess Atari frames and rewards.

    - Frames are resized to ``INPUT_SHAPE`` and converted to 8-bit grayscale.
    - State batches are scaled to [0, 1].
    - Rewards are clipped to [-1, 1].
    """

    def process_observation(self, observation):
        """Return ``observation`` as a grayscale uint8 array whose shape equals ``INPUT_SHAPE``.

        Raises:
            ValueError: if ``observation`` is not 3-D (height, width, channel).
        """
        if observation.ndim != 3:
            raise ValueError(
                f'Expected a (height, width, channel) observation, got ndim={observation.ndim}')
        height, width = INPUT_SHAPE
        # PIL sizes are (width, height) while NumPy shapes are (height, width).
        # Swapping makes the result match INPUT_SHAPE for non-square shapes too.
        img = Image.fromarray(observation).resize((width, height)).convert('L')  # grayscale
        processed_observation = np.array(img)
        assert processed_observation.shape == INPUT_SHAPE
        return processed_observation.astype('uint8')  # uint8 keeps replay memory small

    def process_state_batch(self, batch):
        """Convert a batch of uint8 frames to float32 scaled to [0, 1]."""
        processed_batch = batch.astype('float32') / 255.
        return processed_batch

    def process_reward(self, reward):
        """Clip rewards to [-1, 1] to stabilise learning."""
        return np.clip(reward, -1., 1.)