# Double DQN - Snake game environment

## Step 1 - Imports

In [1]:
import gym_snakegame
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import matplotlib.pyplot as plt
from collections import deque

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


# Step 2: Replay Buffer & Neural Network

Replay Buffer: Stores experiences $(S, A, R, S', Done)$ and samples them randomly to break correlations between consecutive frames.

DQN: Maps the current state (100 inputs for every cell in input space) to Q-values for every possible action (4 outputs).

In [2]:
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
        return (torch.tensor(states, dtype=torch.float32).to(DEVICE),
                torch.tensor(actions, dtype=torch.int64).to(DEVICE),
                torch.tensor(rewards, dtype=torch.float32).to(DEVICE),
                torch.tensor(next_states, dtype=torch.float32).to(DEVICE),
                torch.tensor(dones, dtype=torch.float32).to(DEVICE))
    
    def __len__(self):
        return len(self.buffer)

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, action_dim)
        )

    def forward(self, x):
        return self.net(x)

def select_action(model, state, epsilon, act_dim):
    # Exploration: random action
    if random.random() < epsilon:
        return random.randrange(act_dim)
    # Exploitation: best action according to the network
    with torch.no_grad():
        # Inference: NN estimates action-value function
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(DEVICE)
        return int(torch.argmax(model(state)).item())

# Step 3: The Dual-Mode Training Function
This function contains the core logic. It accepts a double_dqn boolean flag.

Standard DQN: The Target Net finds the max Q-value directly. It is optimistic.

Double DQN: The Policy Net chooses the best action (argmax), and the Target Net calculates the value of that specific action. This separation prevents the "optimism" from spiraling out of control.

In [3]:
def run_snake_game_experiment(exp_name, double_dqn=True, total_episodes=400):
    env = gym.make(
        "gym_snakegame/SnakeGame-v0", board_size=10, n_channel=1, n_target=1, render_mode=None # 'human' to visualize during training
    )
    # Space Dim. 1x10x10
    obs_dim = np.prod(env.observation_space.shape)
    action_dim = env.action_space.n
    
    # Hyperparameters
    lr = 10e-4
    gamma = 0.95
    batch_size = 256
    target_update_freq = 200
    buffer_capacity = 100000
    min_buffer_size = 2000
    epsilon_decay = 300000 # Decay based on steps, not episodes
    
    # Initialize Networks
    policy_net = DQN(obs_dim, action_dim).to(DEVICE)
    target_net = DQN(obs_dim, action_dim).to(DEVICE)
    target_net.load_state_dict(policy_net.state_dict())
    
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)
    replay_buffer = ReplayBuffer(buffer_capacity)
    
    epsilon_start = 1.0
    epsilon_final = 0.01
    steps_done = 0
    rewards_history = []
    q_value_history = [] # To track Maximization Bias

    print(f"--- Starting: {exp_name} | Double: {double_dqn} ---")
    episode_max_reward = -2
    episode_min_reward = 100

    for episode in range(total_episodes):
        state, _ = env.reset()
        episode_reward = 0
        episode_q_vals = []
        loss = 0
        
        while True:
            # Epsilon Decay (Exponential)
            epsilon = epsilon_final + (epsilon_start - epsilon_final) * np.exp(-1. * steps_done / epsilon_decay)
            
            action = select_action(policy_net, state, epsilon, action_dim)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward
            steps_done += 1

            if len(replay_buffer) >= min_buffer_size:
                states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

                with torch.no_grad():
                    if double_dqn:
                        # --- DOUBLE DQN LOGIC ---
                        # 1. Action Selection: Policy Net decides "which" action is best
                        # Takes higher value from output layer of actions
                        # (batch_size, output_values) => returns best action for each element in batch
                        best_actions = policy_net(next_states).argmax(1)
                        # 2. Action Evaluation: Target Net calculates the value of THAT action
                        next_q = target_net(next_states).gather(1, best_actions.unsqueeze(1)).squeeze(1)
                    else:
                        # --- STANDARD DQN LOGIC ---
                        # Target Net selects AND evaluates (Max operator). This causes bias.
                        # output è una tupla: ([values], [indices]) => prende i valori
                        next_q = target_net(next_states).max(1)[0]
                    
                    target = rewards + (1 - dones) * gamma * next_q

                # Stime Q(s,a) per ogni azione [batch, num_actios]
                q_values = policy_net(states)
                # seleziona il valore Q corrispondente all’azione eseguita per ogni stato del batch
                # array dim. [batch, 1]
                current = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
                
                # Save average Q-value for analysis
                episode_q_vals.append(current.mean().item())

                loss = nn.SmoothL1Loss()(current, target)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_value_(policy_net.parameters(), 1.0)
                optimizer.step()

            if steps_done % target_update_freq == 0:
                target_net.load_state_dict(policy_net.state_dict())

            if done:
                break
        
        rewards_history.append(episode_reward)
        q_value_history.append(np.mean(episode_q_vals) if episode_q_vals else 0)
        if episode_reward > episode_max_reward:
            episode_max_reward = episode_reward
        elif episode_reward < episode_min_reward:
            episode_min_reward = episode_reward

        if episode % 50 == 0:
            print(f"Ep {episode}: Reward {episode_reward:.2f}, Min: {episode_min_reward:.2f}, Max: {episode_max_reward:.2f} | Avg Q: {q_value_history[-1]:.2f} | Eps: {epsilon:.2f} | Loss: {loss:.4f}")
            episode_max_reward = -2
            episode_min_reward = 100

    env.close()
    return rewards_history, q_value_history

# Step 4: Running the Comparison

We run the training loop twice to collect data for comparison.

In [4]:
results = {}
N_EPISODES = 20000

# 1. Train Standard DQN
# print("Collecting data for Standard DQN...")
# res_dqn = run_snake_game_experiment('DQN', double_dqn=False, total_episodes=N_EPISODES)
# results['Standard DQN'] = res_dqn

# 2. Train Double DQN
print("\nCollecting data for Double DQN...")
res_ddqn = run_snake_game_experiment('Double DQN', double_dqn=True, total_episodes=N_EPISODES)
results['Double DQN'] = res_ddqn


Collecting data for Double DQN...


  from pkg_resources import resource_stream, resource_exists


--- Starting: Double DQN | Double: True ---
Ep 0: Reward -1.00, Min: 100.00, Max: -1.00 | Avg Q: 0.00 | Eps: 1.00 | Loss: 0.0000
Ep 50: Reward -1.00, Min: -1.00, Max: 1.00 | Avg Q: 0.00 | Eps: 1.00 | Loss: 0.0000
Ep 100: Reward -1.00, Min: -1.00, Max: 0.00 | Avg Q: 0.00 | Eps: 0.99 | Loss: 0.0000


  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(f"{pre} is not within the observation space.")


Ep 150: Reward -1.00, Min: -1.00, Max: 0.00 | Avg Q: -0.03 | Eps: 0.99 | Loss: 0.0240
Ep 200: Reward -1.00, Min: -1.00, Max: 1.00 | Avg Q: 0.07 | Eps: 0.99 | Loss: 0.0102
Ep 250: Reward -1.00, Min: -1.00, Max: 1.00 | Avg Q: 0.21 | Eps: 0.99 | Loss: 0.0032
Ep 300: Reward -1.00, Min: -1.00, Max: 0.00 | Avg Q: 0.30 | Eps: 0.98 | Loss: 0.0092
Ep 350: Reward -1.00, Min: -1.00, Max: 0.00 | Avg Q: 0.36 | Eps: 0.98 | Loss: 0.0021
Ep 400: Reward 0.00, Min: -1.00, Max: 0.00 | Avg Q: 0.43 | Eps: 0.98 | Loss: 0.0036
Ep 450: Reward -1.00, Min: -1.00, Max: 1.00 | Avg Q: 0.50 | Eps: 0.98 | Loss: 0.0028
Ep 500: Reward -1.00, Min: -1.00, Max: 0.00 | Avg Q: 0.53 | Eps: 0.97 | Loss: 0.0028
Ep 550: Reward -1.00, Min: -1.00, Max: 0.00 | Avg Q: 0.60 | Eps: 0.97 | Loss: 0.0030
Ep 600: Reward -1.00, Min: -1.00, Max: 0.00 | Avg Q: 0.67 | Eps: 0.97 | Loss: 0.0065
Ep 650: Reward -1.00, Min: -1.00, Max: 1.00 | Avg Q: 0.75 | Eps: 0.97 | Loss: 0.0027
Ep 700: Reward -1.00, Min: -1.00, Max: 1.00 | Avg Q: 0.78 | Eps: 