<a href="https://colab.research.google.com/github/maciekpdev/Beta-DQN-Project/blob/google-colab/%5BUSD%5D_beta_DQN_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install gymnasium[classic-control]
!pip install minigrid
!pip install ale-py
!pip install gym-super-mario-bros
!pip install pympler

Collecting minigrid
  Downloading minigrid-3.0.0-py3-none-any.whl.metadata (6.7 kB)
Downloading minigrid-3.0.0-py3-none-any.whl (136 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m136.7/136.7 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: minigrid
Successfully installed minigrid-3.0.0
Collecting gym-super-mario-bros
  Downloading gym_super_mario_bros-7.4.0-py3-none-any.whl.metadata (10 kB)
Collecting nes-py>=8.1.4 (from gym-super-mario-bros)
  Downloading nes_py-8.2.1.tar.gz (77 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.7/77.7 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyglet<=1.5.21,>=1.4.0 (from nes-py>=8.1.4->gym-super-mario-bros)
  Downloading pyglet-1.5.21-py3-none-any.whl.metadata (7.6 kB)
Downloading gym_super_mario_bros-7.4.0-py3-none-any.whl (199 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.1/199.

In [None]:
import sys
import os
import gymnasium as gym
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import deque
import random
import pickle

import torch.nn as nn
import torch
import torch.nn.functional as F
import random
from collections import deque

DQN

In [None]:
class ReplayBuffer:
    """Fixed-capacity circular replay buffer backed by preallocated NumPy arrays.

    Args:
        capacity: maximum number of transitions kept; once full, the oldest slot
            is overwritten.
        state_shape: shape of a single observation.
        dtype: storage dtype for observations. The ``np.uint8`` default keeps
            image observations compact. If a floating-point observation is
            pushed while storage is non-floating, storage is promoted to
            ``float32`` instead of silently truncating/wrapping the values
            (a uint8 array cannot hold negative or fractional numbers).
    """

    def __init__(self, capacity, state_shape, dtype=np.uint8):
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.pos = 0          # next slot to write
        self.full = False     # True once every slot has been written at least once

        self.states = np.zeros((capacity, *state_shape), dtype=dtype)
        self.next_states = np.zeros((capacity, *state_shape), dtype=dtype)
        self.actions = np.zeros(capacity, dtype=np.int64)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=np.bool_)

    def _promote_if_float(self, obs):
        """Switch observation storage to float32 when a float observation meets integer storage."""
        if np.issubdtype(self.states.dtype, np.floating):
            return
        if np.issubdtype(np.asarray(obs).dtype, np.floating):
            self.states = self.states.astype(np.float32)
            self.next_states = self.next_states.astype(np.float32)

    def push(self, s, a, r, s2, done):
        """Store one transition, overwriting the oldest one when the buffer is full."""
        self._promote_if_float(s)
        self._promote_if_float(s2)

        self.states[self.pos] = s
        self.next_states[self.pos] = s2
        self.actions[self.pos] = a
        self.rewards[self.pos] = r
        self.dones[self.pos] = done

        self.pos = (self.pos + 1) % self.capacity
        # Wrapping back to slot 0 means every slot now holds valid data.
        self.full = self.full or self.pos == 0

    def sample(self, batch_size):
        """Sample ``batch_size`` transitions uniformly with replacement.

        Returns:
            (states, actions, rewards, next_states, dones) as NumPy arrays.

        Raises:
            ValueError: (from NumPy) if the buffer is empty.
        """
        max_idx = self.capacity if self.full else self.pos
        idxs = np.random.randint(0, max_idx, size=batch_size)

        return (self.states[idxs],
                self.actions[idxs],
                self.rewards[idxs],
                self.next_states[idxs],
                self.dones[idxs])

    def __len__(self):
        """Number of valid transitions currently stored."""
        return self.capacity if self.full else self.pos


class DQN(nn.Module):
    """Q-network following the MinAtar DQN baseline.

    Reference: Young, K., Tian, T. (2019). "MinAtar: An Atari-Inspired Testbed
    for Thorough and Reproducible Reinforcement Learning Experiments."
    https://arxiv.org/abs/1903.03176

    A 3-D ``obs_shape`` is read as ``(C, H, W)``: one 3x3 conv (16 filters,
    stride 1) + ReLU + flatten, then a 128-unit hidden layer. Any other shape
    uses ``obs_shape[0]`` as the flat input size for a 64 -> 128 MLP.
    The output holds ``num_actions`` Q-values per input row.
    """

    def __init__(self, obs_shape, num_actions):
        super().__init__()
        self.is_image = len(obs_shape) == 3
        if self.is_image:
            self._build_conv_layers(obs_shape, num_actions)
        else:
            self._build_mlp_layers(obs_shape[0], num_actions)

    def _build_conv_layers(self, obs_shape, num_actions):
        """Create ``net`` (conv + flatten) and ``fc`` (two linear layers) for image input."""
        channels, height, width = obs_shape
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=16, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Push one blank frame through the conv stack to size the first linear layer.
        with torch.no_grad():
            flat_size = self.net(torch.zeros(1, channels, height, width)).shape[1]
        self.fc = nn.Sequential(
            nn.Linear(flat_size, 128),
            nn.ReLU(),
            nn.Linear(128, num_actions),
        )

    def _build_mlp_layers(self, input_dim, num_actions):
        """Create ``net`` (two hidden layers) and a linear ``fc`` output for flat input."""
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
        )
        self.fc = nn.Linear(128, num_actions)

    def forward(self, x):
        """Return Q-values for the batch ``x``.

        For image input, a batch whose maximum exceeds 1.0 is treated as raw
        0-255 pixels and divided by 255. The check is made on the whole batch.
        """
        if self.is_image:
            if x.max() > 1.0:
                x = x / 255.0
        features = self.net(x)
        return self.fc(features)

class BenchDQNAgent:
    """Baseline DQN agent: epsilon-greedy acting, experience replay, target network.

    Reads gamma, epsilon_start/epsilon_min/epsilon_decay, batch_size, lr and
    buffer_size from ``config``. The ``policy`` argument of ``select_action`` is
    ignored; it exists so every agent shares the signature ``run_benchmark`` uses.
    """

    def __init__(self, state_dim, action_dim, config, device):
        self.device = device
        self.action_dim = action_dim
        self.config = config

        (self.gamma, self.epsilon, self.epsilon_min,
         self.epsilon_decay, self.batch_size, self.lr) = (
            config[key] for key in
            ("gamma", "epsilon_start", "epsilon_min", "epsilon_decay", "batch_size", "lr")
        )

        # Online network, plus a target copy that starts with identical weights.
        self.policy_net = DQN(state_dim, action_dim).to(device)
        self.target_net = DQN(state_dim, action_dim).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.lr)

        self.memory = ReplayBuffer(config["buffer_size"], state_dim)

    def select_action(self, state, policy, training=True):
        """Epsilon-greedy while training, greedy otherwise; returns an int action."""
        explore = training and np.random.rand() < self.epsilon
        if explore:
            return random.randrange(self.action_dim)

        obs = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.policy_net(obs)
        return q_values.argmax().item()

    def store_transition(self, s, a, r, ns, d):
        """Append one transition to the replay buffer."""
        self.memory.push(s, a, r, ns, d)

    def train_step(self):
        """One Huber-loss TD update on a replay batch; returns the loss or None if too few samples."""
        if len(self.memory) < self.batch_size:
            return None

        s, a, r, s2, d = self.memory.sample(self.batch_size)
        to_float = lambda arr: torch.FloatTensor(np.array(arr)).to(self.device)
        states, rewards, next_states, dones = map(to_float, (s, r, s2, d))
        actions = torch.LongTensor(np.array(a)).unsqueeze(1).to(self.device)

        q_taken = self.policy_net(states).gather(1, actions).squeeze(1)

        with torch.no_grad():
            bootstrap = self.target_net(next_states).max(dim=1).values
            # Terminal transitions (done=1) get no bootstrap term.
            targets = rewards + self.gamma * bootstrap * (1 - dones)

        loss = F.smooth_l1_loss(q_taken, targets)

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

        return loss.item()

    def decay_epsilon(self):
        """Multiply epsilon by epsilon_decay, never going below epsilon_min."""
        decayed = self.epsilon * self.epsilon_decay
        self.epsilon = decayed if decayed > self.epsilon_min else self.epsilon_min


NoisyNet-DQN

In [None]:
# TODO: implement the NoisyNet-DQN baseline.

beta-DQN

In [None]:
class BetaNetwork(DQN):
    """DQN-shaped network whose output is a probability distribution over actions.

    Same layers as ``DQN``; a softmax over the last dimension turns the head's
    outputs into per-action probabilities.
    """

    def forward(self, x):
        action_scores = super().forward(x)
        return F.softmax(action_scores, dim=-1)

In [None]:
import random
import torch

class CovPolicy:
    """Coverage policy of beta-DQN, parameterised by ``delta``.

    If every action has behaviour probability ``beta(a|s) > delta``, the greedy
    action ``argmax_a Q(s, a)`` is returned. Otherwise one of the under-covered
    actions (``beta(a|s) <= delta``) is sampled uniformly, and the step is
    flagged as exploratory.

    Returns ``(action: int, explored: bool)``.
    """

    def __init__(self, delta):
        self.delta = delta

    def __call__(self, Q_values, beta_probs, epsilon):
        """Pick an action for a single state.

        Args:
            Q_values: action values, shape ``(A,)`` or a one-state batch ``(1, A)``.
            beta_probs: behaviour probabilities with the same number of elements.
            epsilon: unused; accepted for call-signature parity with ``CorPolicy``.

        Raises:
            ValueError: if the inputs are empty or hold different numbers of actions.
        """
        # Work on 1-D views. Iterating/indexing a (1, A) tensor addresses rows, not
        # actions, and indexing that size-1 row dim with action ids goes out of bounds.
        q = torch.as_tensor(Q_values).reshape(-1)
        beta = torch.as_tensor(beta_probs).reshape(-1)
        if q.numel() == 0 or q.numel() != beta.numel():
            raise ValueError(
                "Q_values and beta_probs must hold the same non-zero number of actions, "
                f"got {q.numel()} and {beta.numel()}"
            )

        order = torch.argsort(q)
        low_coverage_actions = order[beta[order] <= self.delta].tolist()
        if not low_coverage_actions:
            return int(torch.argmax(q)), False
        return random.choice(low_coverage_actions), True

    def __str__(self):
        return f"CovPolicy delta={self.delta}"

class CorPolicy:
    """Correction policy of beta-DQN, parameterised by ``alpha``.

    Each action is scored as ``alpha * Q(a) + (1 - alpha) * Q_mask(a)``, where
    ``Q_mask(a) = Q(a)`` when ``beta(a|s) > epsilon`` and ``min_a Q(a)`` otherwise.
    The step counts as exploratory when the chosen action differs from
    ``argmax_a Q_mask(a)``, i.e. from what ``alpha = 0`` would pick.

    Returns ``(action: int, explored: bool)``.
    """

    def __init__(self, alpha):
        self.alpha = alpha

    @staticmethod
    def _flat(values):
        # Accept (A,) or a one-state batch (1, A) and work on a 1-D view of actions.
        return torch.as_tensor(values).reshape(-1)

    @staticmethod
    def _masked_q(q, beta, epsilon):
        # Actions the behaviour model considers unlikely get the worst Q-value.
        return torch.where(beta > epsilon, q, q.min())

    def __call__(self, Q, beta, epsilon):
        """Pick an action for a single state.

        Args:
            Q: action values, shape ``(A,)`` or ``(1, A)``.
            beta: behaviour probabilities with the same number of elements.
            epsilon: threshold on ``beta`` used for masking.

        Raises:
            ValueError: if the inputs are empty or hold different numbers of actions.
        """
        q, b = self._flat(Q), self._flat(beta)
        if q.numel() == 0 or q.numel() != b.numel():
            raise ValueError(
                "Q and beta must hold the same non-zero number of actions, "
                f"got {q.numel()} and {b.numel()}"
            )

        masked = self._masked_q(q, b, epsilon)
        scores = self.alpha * q + (1 - self.alpha) * masked
        action = int(torch.argmax(scores))
        return action, self.is_exploration(action, None, q, b, epsilon)

    def is_exploration(self, action, actions, Q, beta, epsilon):
        """Return True if ``action`` differs from ``argmax Q_mask``.

        ``actions`` is ignored; it is kept so existing callers still work.
        """
        q, b = self._flat(Q), self._flat(beta)
        # Mask with min Q, exactly as in __call__. Masking with 0 made a masked-out
        # action the reference whenever all Q-values were negative.
        reference = int(torch.argmax(self._masked_q(q, b, epsilon)))
        return int(action) != reference

    def __str__(self):
        return f"CorPolicy alpha={self.alpha}"


In [None]:
from collections import deque
import numpy as np

class MetaController:
    """Sliding-window selector that picks which behaviour policy runs the next episode.

    ``history`` keeps the last ``window_size`` episodes as
    ``(policy_idx, reward, exploration_ratio)``. ``select_policy`` first returns
    any policy with no entry in the window (lowest index first). Otherwise it
    returns the policy maximising ``count_mean(i) + count_exploration_bonus(i)``.
    """

    def __init__(self, window_size=1000):
        # One correction policy per alpha in 0.1, 0.2, ..., 0.9.
        cor_policies = [CorPolicy(i / 10) for i in range(1, 10)]
        cov_policies = [CovPolicy(0.05), CovPolicy(0.1)]
        self.policies = cor_policies + cov_policies  # policy set used in the paper

        self.num_policies = len(self.policies)
        self.window_size = window_size
        # deque(maxlen=...) discards the oldest entry once the window is full.
        # Entries are (policy_idx, reward, exploration_ratio).
        self.history = deque(maxlen=window_size)

    def select_policy(self):
        """Return the policy object to use for the next episode."""
        used_indices = {entry[0] for entry in self.history}
        for i in range(self.num_policies):
            if i not in used_indices:
                return self.policies[i]

        best_value = -float("inf")
        best_policy = 0
        for i in range(self.num_policies):
            value = self.count_mean(i) + self.count_exploration_bonus(i)
            if value > best_value:
                best_value = value
                best_policy = i

        return self.policies[best_policy]

    def policy_index(self, policy):
        """Normalise ``policy`` (an index or a policy object from ``self.policies``) to an index.

        Raises:
            IndexError: for an integer outside ``range(num_policies)``.
            ValueError: for an object that is not one of ``self.policies``.
        """
        if isinstance(policy, (int, np.integer)):
            idx = int(policy)
            if not 0 <= idx < self.num_policies:
                raise IndexError(f"policy index {idx} out of range 0..{self.num_policies - 1}")
            return idx
        for idx, candidate in enumerate(self.policies):
            if candidate is policy:
                return idx
        raise ValueError(f"policy {policy!r} is not managed by this MetaController")

    def update(self, policy_idx, reward, exploration_ratio):
        """Record one finished episode.

        ``policy_idx`` may be the index or the policy object returned by
        ``select_policy``. History must hold indices, otherwise ``select_policy``
        would never recognise a policy as used.
        """
        self.history.append((self.policy_index(policy_idx), reward, exploration_ratio))

    def count_mean(self, i):
        """Mean episode reward of policy ``i`` within the window (0.0 if unused)."""
        rewards = [reward for policy_idx, reward, _ in self.history if policy_idx == i]
        if not rewards:
            return 0.0
        return sum(rewards) / len(rewards)

    def count_exploration_bonus(self, i):  # b_k(pi, L)
        """Mean exploration ratio of policy ``i`` within the window (inf if unused)."""
        n = self.count_policy_occurance(i)
        if n == 0:
            return float("inf")
        return (1 / n) * self.count_sum_of_exploration_ratio(i)

    def count_sum_of_exploration_ratio(self, i):  # sum of B_m(pi_i)
        """Sum of exploration ratios recorded for policy ``i`` in the window."""
        return sum(exploration_ratio for policy_idx, _, exploration_ratio in self.history if policy_idx == i)

    def count_policy_occurance(self, i):  # N_k(pi, L)
        """Number of episodes in the window that used policy ``i``."""
        return sum(1 for policy_idx, _, _ in self.history if policy_idx == i)

In [None]:
class BetaDQNAgent:
    """beta-DQN agent: DQN plus a behaviour network ``beta_net`` fitted to the replay data.

    ``beta_net`` is trained with cross-entropy to predict the action stored in the
    replay buffer for each sampled state. Its softmax output and the Q-values are
    passed to a policy object (e.g. ``CorPolicy`` / ``CovPolicy``) chosen by the
    meta-controller.

    Config keys: gamma, epsilon_start, epsilon_min, epsilon_decay, batch_size, lr,
    buffer_size, and optionally ``beta_threshold`` (see ``select_action``).
    """

    def __init__(self, state_dim, action_dim, config, device):
        self.device = device
        self.action_dim = action_dim
        self.config = config

        self.gamma = config["gamma"]
        self.epsilon = config["epsilon_start"]
        self.epsilon_min = config["epsilon_min"]
        self.epsilon_decay = config["epsilon_decay"]
        self.batch_size = config["batch_size"]
        self.lr = config["lr"]
        # Optional fixed threshold on beta(a|s). None keeps the original behaviour of
        # reusing the decaying epsilon as the threshold.
        self.beta_threshold = config.get("beta_threshold")

        # Online Q-network and a target copy that starts with identical weights.
        self.policy_net = DQN(state_dim, action_dim).to(device)
        self.target_net = DQN(state_dim, action_dim).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer_q = torch.optim.Adam(self.policy_net.parameters(), lr=self.lr)

        # beta_net outputs logits (its loss is cross_entropy); select_action applies softmax.
        self.beta_net = DQN(state_dim, action_dim).to(device)
        self.optimizer_beta = torch.optim.Adam(self.beta_net.parameters(), lr=self.lr)

        self.memory = ReplayBuffer(config["buffer_size"], state_dim)

    def _to_tensor(self, array, dtype=torch.float32):
        """Convert an observation or a NumPy batch to a tensor on ``self.device``."""
        return torch.as_tensor(np.asarray(array), dtype=dtype, device=self.device)

    def select_action(self, state, policy, training=True):
        """Choose an action for ``state``.

        Returns:
            - ``training=False``: the greedy action as an ``int``.
            - training with ``policy=None``: an epsilon-greedy ``int`` action
              (same behaviour as ``BenchDQNAgent``).
            - otherwise: ``policy(q_values, beta_probs, threshold)``, i.e.
              ``(action, explored)`` for the policies in this notebook.
        """
        state_t = self._to_tensor(state).unsqueeze(0)

        if not training:
            with torch.no_grad():
                return self.policy_net(state_t).argmax().item()

        if policy is None:
            if np.random.rand() < self.epsilon:
                return random.randrange(self.action_dim)
            with torch.no_grad():
                return self.policy_net(state_t).argmax().item()

        with torch.no_grad():
            # Drop the batch dim so policies index actions rather than rows. Indexing
            # a (1, A) tensor with action ids overruns dim 0 (the likely source of the
            # CUDA device-side assert). The policies loop in Python, so move to CPU once.
            q_values = self.policy_net(state_t).squeeze(0).cpu()
            beta_probs = torch.softmax(self.beta_net(state_t), dim=-1).squeeze(0).cpu()

        # NOTE(review): by default the decaying epsilon doubles as the beta threshold.
        # With epsilon_start = 1.0 no softmax probability can exceed it early on.
        # Set config["beta_threshold"] to use a fixed threshold instead.
        beta_threshold = getattr(self, "beta_threshold", None)
        threshold = self.epsilon if beta_threshold is None else beta_threshold
        return policy(q_values, beta_probs, threshold)

    def store_transition(self, s, a, r, ns, d):
        """Append one transition to the replay buffer."""
        self.memory.push(s, a, r, ns, d)

    def train_step(self):
        """Sample a batch, then update the Q-network (TD) and beta_net (cross-entropy).

        Returns:
            The TD loss as a float, or None while the buffer holds fewer than
            ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return None

        states, actions, rewards, next_states, dones = self.memory.sample(self.batch_size)

        states = self._to_tensor(states)
        actions = self._to_tensor(actions, dtype=torch.long).unsqueeze(1)
        rewards = self._to_tensor(rewards)
        next_states = self._to_tensor(next_states)
        dones = self._to_tensor(dones)

        td_loss = self.train_td(states, actions, rewards, next_states, dones)
        self.train_beta_net(states, actions)

        return td_loss  # already a Python float (train_td calls .item())

    def train_td(self, states, actions, rewards, next_states, dones):
        """One Huber-loss TD update of ``policy_net``; returns the loss as a float."""
        curr_q = self.policy_net(states).gather(1, actions).squeeze(1)

        with torch.no_grad():
            next_q = self.target_net(next_states).max(1)[0]
            # Terminal transitions (done=1) get no bootstrap term.
            target_q = rewards + self.gamma * next_q * (1 - dones)

        loss = F.smooth_l1_loss(curr_q, target_q)

        self.optimizer_q.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer_q.step()

        return loss.item()

    def train_beta_net(self, states, actions):
        """One cross-entropy step teaching ``beta_net`` the stored action for each state.

        ``actions`` may have shape (B,) or (B, 1). It is flattened because
        ``F.cross_entropy`` expects class indices of shape (B,) for (B, A) logits.
        Returns the loss as a float.
        """
        logits = self.beta_net(states)
        loss = F.cross_entropy(logits, actions.reshape(-1))
        self.optimizer_beta.zero_grad()
        loss.backward()
        self.optimizer_beta.step()
        return loss.item()

    def decay_epsilon(self):
        """Multiply epsilon by epsilon_decay, never going below epsilon_min."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)



In [None]:
from minigrid.wrappers import ImgObsWrapper
from gym_super_mario_bros.actions import RIGHT_ONLY
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
import ale_py
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation

def make_env(env_name, render_mode=None):
    """Create ``env_name`` with the wrappers the agents in this notebook expect.

    Supported ids:
        - "MountainCar-v0": plain gymnasium env.
        - MiniGrid ids containing "LavaCrossing": wrapped in ``ImgObsWrapper`` so the
          observation is the image array (with a defined ``observation_space.shape``)
          instead of a dict.
        - ids starting with "ALE/" (e.g. "ALE/Breakout-v5"): Atari preprocessing
          (grayscale, unscaled uint8) plus a 4-frame stack.
        - ids containing "SuperMarioBros": ``JoypadSpace`` with the RIGHT_ONLY actions.

    Raises:
        ValueError: for any other ``env_name``.
    """
    if env_name == "MountainCar-v0":
        return gym.make(env_name, render_mode=render_mode)

    if "LavaCrossing" in env_name:
        # e.g. MiniGrid-LavaCrossingS9N1-v0
        # NOTE(review): the image is presumably channel-last (H, W, C), while DQN reads
        # a 3-D shape as (C, H, W). Confirm, and add a transpose wrapper if needed.
        env = gym.make(env_name, render_mode=render_mode)
        return ImgObsWrapper(env)

    if env_name.startswith("ALE/"):
        gym.register_envs(ale_py)
        # Frame skipping is done by AtariPreprocessing, so the base env must not skip.
        env = gym.make(env_name, frameskip=1, render_mode=render_mode)
        env = AtariPreprocessing(env, grayscale_obs=True, scale_obs=False)
        return FrameStackObservation(env, stack_size=4)

    if "SuperMarioBros" in env_name:
        # NOTE(review): gym_super_mario_bros builds on the legacy `gym` package (see the
        # deprecation warning printed on import). Its reset/step signatures may not
        # match the gymnasium-style unpacking in run_benchmark; render_mode is not
        # forwarded here.
        env = gym_super_mario_bros.make(env_name)
        return JoypadSpace(env, RIGHT_ONLY)

    raise ValueError(f"Unknown environment: {env_name}")

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)


In [None]:
def _policy_index(meta_controller, policy):
    """Map a policy object returned by ``select_policy`` to its index in ``meta_controller.policies``.

    ``MetaController.history`` stores policy indices, so ``update`` must receive the
    index. Matching is by identity. If the controller has no ``policies`` list or
    the object is not in it, ``policy`` is returned unchanged.
    """
    for idx, candidate in enumerate(getattr(meta_controller, "policies", ())):
        if candidate is policy:
            return idx
    return policy


def _train_one_seed(agent_class, env_name, file_tag, config, seed, save_dir, meta_controller, device):
    """Train one fresh agent for ``config["total_steps"]`` env steps and return its history.

    Saves ``<file_tag>_seed<seed>_best.pth`` whenever the 50-episode moving average
    improves after ``config.get("save_best_after_episodes", 10000)`` episodes.
    Saves ``<file_tag>_seed<seed>_final.pth`` at the end. The environment is
    closed even if training raises.
    """
    max_steps = config["total_steps"]
    train_every = config.get("train_frequency", 4)
    target_every = config.get("target_update_freq", 1000)  # counted in env steps
    save_best_after = config.get("save_best_after_episodes", 10000)

    env = make_env(env_name)
    try:
        env.reset(seed=seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        state_dim = env.observation_space.shape
        action_dim = env.action_space.n
        agent = agent_class(state_dim, action_dim, config, device=device)

        history = {"episode_rewards": [], "global_steps": [], "loss": []}
        global_step_count = 0
        episode_count = 0
        best_avg_reward = -float("inf")

        while global_step_count < max_steps:
            state, _ = env.reset()
            episode_count += 1
            total_reward = 0
            exploration_steps = 0
            episode_steps = 0
            ep_losses = []
            done = truncated = False

            policy = policy_idx = None
            if meta_controller is not None:
                policy = meta_controller.select_policy()
                policy_idx = _policy_index(meta_controller, policy)

            while not (done or truncated):
                if meta_controller is not None:
                    action, explored = agent.select_action(state, policy, training=True)
                    if explored:
                        exploration_steps += 1
                else:
                    action = agent.select_action(state, policy, training=True)

                next_state, reward, done, truncated, _ = env.step(action)
                # Only termination (done) is stored; time-limit truncation still bootstraps.
                agent.store_transition(state, action, reward, next_state, done)

                if global_step_count % train_every == 0:
                    loss = agent.train_step()
                    if loss is not None:
                        ep_losses.append(loss)

                if global_step_count % target_every == 0:
                    agent.target_net.load_state_dict(agent.policy_net.state_dict())

                state = next_state
                total_reward += reward
                global_step_count += 1
                episode_steps += 1

                if global_step_count >= max_steps:
                    print(f"   [Limit Reached] Hit max steps at Episode {episode_count}")
                    break

            if meta_controller is not None:
                meta_controller.update(policy_idx, total_reward, exploration_steps / max(episode_steps, 1))

            agent.decay_epsilon()

            # Logging
            avg_loss = np.mean(ep_losses) if ep_losses else 0
            history["episode_rewards"].append(total_reward)
            history["global_steps"].append(global_step_count)
            history["loss"].append(avg_loss)

            avg_r = np.mean(history["episode_rewards"][-50:])

            if episode_count > save_best_after and avg_r > best_avg_reward:
                best_avg_reward = avg_r
                torch.save(agent.policy_net.state_dict(), f"{save_dir}/{file_tag}_seed{seed}_best.pth")

            if episode_count % 50 == 0:
                print(f"   Step {global_step_count}/{max_steps} (Ep {episode_count}) | Avg Reward: {avg_r:.2f} | Loss: {avg_loss:.4f} | Eps: {agent.epsilon:.2f}")

        final_path = f"{save_dir}/{file_tag}_seed{seed}_final.pth"
        torch.save(agent.policy_net.state_dict(), final_path)
        print(f"   [Saved Final] Saved to {final_path}")
        return history
    finally:
        env.close()


def run_benchmark(agent_class, env_name, algorithm_name, config, seeds=(42,), save_dir="results",
                  meta_controller=None, device=None):
    """Train ``agent_class`` on ``env_name`` once per seed and pickle the learning curves.

    Args:
        agent_class: called as ``agent_class(state_dim, action_dim, config, device=device)``.
        env_name: id understood by ``make_env``.
        algorithm_name: label used in printed output and file names.
        config: hyperparameters; must contain ``total_steps``.
        seeds: iterable of integer seeds.
        save_dir: directory for checkpoints and the ``<algorithm>_data.pkl`` file.
        meta_controller: optional object with ``select_policy()`` / ``update(idx, reward, ratio)``.
            When given, ``agent.select_action`` must return ``(action, explored)``.
        device: torch device. Defaults to a notebook-global ``device`` if one exists,
            otherwise CUDA when available, else CPU.

    Returns:
        ``{"algorithm", "env", "seeds_data": {seed: history}}``; this dict is also pickled.
    """
    os.makedirs(save_dir, exist_ok=True)
    if device is None:
        # Earlier versions read the notebook-global `device`; keep honouring it.
        device = globals().get("device")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Env ids such as "ALE/Breakout-v5" contain "/", which torch.save/open would read
    # as a (non-existent) sub-directory of save_dir.
    file_tag = f"{algorithm_name}_{env_name}".replace("/", "_")
    all_seeds_history = {}

    print(f"\n--- Starting: {algorithm_name} on {env_name} ---")

    for seed in seeds:
        print(f" > Seed {seed}...")
        all_seeds_history[seed] = _train_one_seed(
            agent_class, env_name, file_tag, config, seed, save_dir, meta_controller, device
        )

    data = {"algorithm": algorithm_name, "env": env_name, "seeds_data": all_seeds_history}
    with open(f"{save_dir}/{algorithm_name.replace('/', '_')}_data.pkl", "wb") as f:
        pickle.dump(data, f)
    return data

In [None]:
def plot_results(data):
    """Plot episode reward against global step for every seed of one run.

    Args:
        data: dict returned by ``run_benchmark``, with ``"algorithm"`` and
            ``"seeds_data"`` (seed -> history holding ``"global_steps"`` and
            ``"episode_rewards"``).

    seaborn's lineplot aggregates rows that share an x value; the band shows
    one standard deviation (``errorbar="sd"``).
    """
    algorithm = data["algorithm"]
    rows = [
        {
            "Algorithm": algorithm,
            "Global Steps": step,
            "Reward": history["episode_rewards"][i],
        }
        for history in data["seeds_data"].values()
        for i, step in enumerate(history["global_steps"])
    ]
    frame = pd.DataFrame(rows)

    _, ax = plt.subplots(figsize=(10, 6))
    sns.lineplot(data=frame, x="Global Steps", y="Reward", errorbar="sd", ax=ax)
    ax.set_title(f"{algorithm} Training Curve")
    ax.grid(True, alpha=0.3)
    plt.show()

In [None]:
# Hyperparameters read by the agent classes and run_benchmark.
# Units: "train_frequency" and "target_update_freq" are counted in environment
# steps (run_benchmark checks global_step_count % value); "epsilon_decay" is
# applied once per episode through decay_epsilon().
MC_CONFIG = {
    "lr": 1e-3,
    "gamma": 0.99,
    "batch_size": 32,
    "buffer_size": 100000,
    "epsilon_start": 1.0,
    "epsilon_min": 0.05,
    "epsilon_decay": 0.95,  # multiplied into epsilon once per episode
    # Env steps between target-network syncs.
    # NOTE(review): an earlier comment said "episodes"; confirm 10 steps is intended.
    "target_update_freq": 10,
    "train_frequency": 4,  # one gradient step every 4 env steps
    "total_steps": 5_000_000
}

BREAKOUT_CONFIG = {  # values reported to follow the paper
    "lr": 1e-3,
    "gamma": 0.99,
    "batch_size": 32,
    "buffer_size": 100000,
    "epsilon_start": 1.0,
    "epsilon_min": 0.01,
    "epsilon_decay": 0.999,  # multiplied into epsilon once per episode
    "target_update_freq": 1000,  # env steps between target-network syncs
    "train_frequency": 4,  # one gradient step every 4 env steps
    "total_steps": 5_000_000
}

seeds = [31]  # full sweep: [31, 12, 1123, 111, 145]

In [None]:
# Use the GPU when available; run_benchmark picks up this global `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")

breakout_meta_controller = MetaController()
breakout_data = run_benchmark(
    BetaDQNAgent,
    "ALE/Breakout-v5",
    "BetaDQN",
    BREAKOUT_CONFIG,
    seeds=seeds,
    save_dir="results_breakout",
    meta_controller=breakout_meta_controller,
)

plot_results(breakout_data)

Running on: cuda

--- Starting: BetaDQN on ALE/Breakout-v5 ---
 > Seed 31...


AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


  return datetime.utcnow().replace(tzinfo=utc)
