# Improving the Plinko DQN algorithm using Double Q-Learning

In [None]:
import random
from collections import defaultdict, deque
import copy # deep copying Q-table
from collections import namedtuple
from state_utils_deep import drop_ball, initialize_trackers
from board_builder import build_board
from visualization import print_training_stats, visualize_grid
import torch
import torch.nn as nn
import torch.optim as optim


### Part 1: Double Q-Learning

#### Motivation: 
>The original Plinko code uses standard Q-learning, which is known to suffer from maximization bias: it tends to overestimate action values. Standard Q-learning uses a single Q-table both to select the best next action and to evaluate the value of that action. If an action's value is overestimated, the max operation is likely to select it, thereby propagating the overestimation. Double Q-learning keeps selection and evaluation separate: we use the online Q-table to select the best next action and the target Q-table to evaluate the value of that chosen action. This reduces the chance of consistently selecting actions based on overestimated values.

#### Expectation: 
>We expect more accurate Q-value estimates, which should result in a more stable learning process and convergence to a better final policy, i.e. a higher success rate at reaching the target bucket. It may also prevent our agent from getting stuck favouring sub-optimal paths because of early overestimations.

In [None]:
class QNetwork(nn.Module):
    """Fully connected network that maps an encoded state to per-action Q-values.

    Two hidden layers of 128 units, each followed by a ReLU, feed a linear
    output layer with ``output_dim`` values.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_units = 128
        # The attribute names double as state_dict keys, so they stay fc1..fc3.
        self.fc1 = nn.Linear(input_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, output_dim)

    def forward(self, x):
        """Return the raw (un-activated) output of the last layer for ``x``."""
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)

def _button_vector(buttons, width):
    """Return a multi-hot vector of length ``width * 5`` with one slot set per button."""
    size = width * 5  # assuming max 5 buttons, adjust as needed
    vector = torch.zeros(size)
    for bx, by in buttons:
        # Positions that fall outside the vector wrap around via the modulo.
        vector[(by * width + bx) % size] = 1
    return vector


def encode_state(state, width):
    """Encode a state as a ``(base, buttons)`` pair of row tensors.

    Args:
        state: Sequence whose first item identifies the position and whose
            second item is an iterable of ``(bx, by)`` button coordinates.
            A tuple first item ``(x, y)`` is treated as a ledge; a string
            first item is treated as a block.
        width: Board width; sizes the button vector (``width * 5`` slots).

    Returns:
        ``base`` of shape ``(1, 3)`` holding ``[type, x, y]`` (type 0 for a
        ledge, 1 for a block) and ``buttons`` of shape ``(1, width * 5)``.
        Both state kinds yield the same number of features so they can be
        batched together and fed to a network with a fixed input size.

    Raises:
        ValueError: If the first item of ``state`` is neither a tuple nor a
            string (or the block row character is not numeric).
    """
    position = state[0]
    if isinstance(position, tuple):  # ledge
        x, y = position
        base = [0, x, y]
    elif isinstance(position, str):  # block
        # Indexing a str yields a one-character str, which torch.tensor cannot
        # convert, so it is parsed as a number here.
        # NOTE(review): assumes the second character is a digit giving the
        # row -- confirm against the state format produced by drop_ball.
        y = float(position[1])
        # Blocks carry no x coordinate; a 0 placeholder keeps the base vector
        # the same length as for ledges (the type flag disambiguates).
        base = [1, 0, y]
    else:
        raise ValueError("Unrecognized state format")

    base_tensor = torch.tensor(base, dtype=torch.float32).unsqueeze(0)
    return base_tensor, _button_vector(state[1], width).unsqueeze(0)

def preprocess_state(state, width):
    """Encode ``state`` and join the pieces into a single row tensor."""
    encoded_parts = encode_state(state, width)
    # encode_state yields (base features, button vector); join them along the
    # feature axis so the result can be fed to the network directly.
    return torch.cat(encoded_parts, dim=1)


### Double Q-Learning and Experience Replay

In [None]:
# tracker dictionaries handed to build_board and drop_ball
trackers = initialize_trackers()

# --- DDQN hyperparameters ---
# NOTE(review): learning_rate is forwarded to drop_ball/learn, but the Adam
# optimizer is constructed with its own lr -- confirm which one is intended.
learning_rate = 0.1
discount_factor = 0.99  # high discount because paths can be long

# --- epsilon-greedy exploration schedule ---
exploration_rate = 1.0  # begin fully exploratory
exploration_decay = 0.999  # slow multiplicative decay per episode
min_exploration = 0.01  # floor for the exploration rate
initial_free_exploration = 100  # episodes before the decay starts
episodes = 1_000  # number of training episodes

# --- update cadence ---
update_frequency = 4  # learn every 4 steps
target_update_frequency = 100  # update the target network every 100 steps
soft_update_alpha = 0.1  # blend factor for the soft target update

# --- experience replay ---
Experience = namedtuple(
    "Experience",
    ["state", "action", "reward", "next_state", "done"],
)
replay_buffer = deque(maxlen=10_000)  # keeps the most recent 10k transitions
batch_size = 64

# --- training setup ---
target_bucket = 4  # bucket the agent should aim for

### Training Loop (DDQN)

In [None]:
loss_fn = nn.MSELoss()


def learn(grid, width, learning_rate, discount_factor):
    """Run one Double-DQN gradient step on a minibatch from the replay buffer.

    The online network picks the greedy next action and the target network
    evaluates it; the online network is then regressed (MSE) toward
    ``reward + discount_factor * target_value * (1 - done)``.

    Args:
        grid: Unused; kept so existing callers keep working.
        width: Board width, forwarded to ``preprocess_state``.
        learning_rate: Unused; the step size comes from the module-level
            ``optimizer``.
        discount_factor: Discount applied to the bootstrapped next-state value.

    Returns:
        The scalar loss as a float, or ``None`` when the replay buffer holds
        fewer than ``batch_size`` transitions.
    """
    if len(replay_buffer) < batch_size:
        return None

    batch = random.sample(replay_buffer, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)

    state_batch = torch.cat([preprocess_state(s, width) for s in states])
    action_batch = torch.tensor(list(actions), dtype=torch.int64).unsqueeze(1)
    reward_batch = torch.tensor(list(rewards), dtype=torch.float32).unsqueeze(1)
    done_batch = torch.tensor(list(dones), dtype=torch.float32).unsqueeze(1)

    # Q(s, a) of the online network for the actions that were actually taken.
    q_values = online_net(state_batch).gather(1, action_batch)

    # Transitions without a next state keep a bootstrap value of zero.
    non_final_mask = torch.tensor([s is not None for s in next_states], dtype=torch.bool)
    target_q_values = torch.zeros(batch_size, 1)
    # torch.cat rejects an empty list, so only build the next-state batch when
    # at least one sampled transition actually has a next state.
    if non_final_mask.any():
        non_final_next_states = torch.cat(
            [preprocess_state(s, width) for s in next_states if s is not None]
        )
        # The target is a constant for the regression: no gradients needed.
        with torch.no_grad():
            next_actions = online_net(non_final_next_states).argmax(dim=1, keepdim=True)
            target_q_values[non_final_mask] = target_net(non_final_next_states).gather(1, next_actions)

    expected_q = reward_batch + discount_factor * target_q_values * (1 - done_batch)

    loss = loss_fn(q_values, expected_q)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

def update_target_network(online_model, target_model, alpha):
    """Soft-update ``target_model`` in place toward ``online_model``.

    Every target parameter becomes ``alpha * online + (1 - alpha) * target``;
    the online model is left untouched.
    """
    parameter_pairs = zip(target_model.parameters(), online_model.parameters())
    for target_param, online_param in parameter_pairs:
        blended = alpha * online_param.data + (1 - alpha) * target_param.data
        target_param.data.copy_(blended)

In [None]:
# Bookkeeping shared with the training loop.
total_decision_steps = [0]  # single-element list so callees can mutate the counter
steps_after_episode = 0
episode_rewards_history = []  # final reward of every episode
most_recent_rewards = deque(maxlen=100)  # sliding window for the running average
total_stars_collected = 0



# Build one board up front to learn its dimensions (needed to size the networks).
grid, buckets, width, height = build_board("hard", trackers)

# Q-Learning Specific
# Two Q-networks for Double DQN: online (trained) and target (evaluation copy)
# 3 base features [type, x, y] plus a button vector of length width * 5; every
# encoded state must have exactly this many features.
input_dim = 3 + (width * 5)  # 3 for [type, x, y] or [type, y], rest for button vector
output_dim = width  # one output per column choice

online_net = QNetwork(input_dim, output_dim)
target_net = QNetwork(input_dim, output_dim)
# Start the target network as an exact copy of the online network.
target_net.load_state_dict(online_net.state_dict())
target_net.eval()

# Only the online network's parameters are registered with the optimizer.
optimizer = optim.Adam(online_net.parameters(), lr=0.001)

# Random starting column, used here only for the preview below.
start_x = random.randint(0, width - 1)

# Show the initial board with the ball placed at row height - 1.
visualize_grid(grid, width, height, ball_position=(start_x, height - 1), buckets=buckets)

# gate the agent uses to decide when learn() runs, based on total steps
def should_learn():
    """Return True when the step counter is a multiple of ``update_frequency``."""
    steps_taken = total_decision_steps[0]
    return steps_taken % update_frequency == 0

# training loop
for episode in range(episodes):
    grid, buckets, width, height = build_board("hard", trackers)
         
    start_x = random.randint(0, width - 1)
    
    episode_final_reward, final_bucket, stars_collected = drop_ball(
        grid=grid,
        width=width,
        height=height,
        start_x=start_x,
        buckets=buckets,
        target_bucket=target_bucket,
        exploration_rate=exploration_rate,
        q_model=online_net,
        trackers=trackers,
        extra={
            "replay_buffer": replay_buffer,
            "should_learn": should_learn,
            "learn": learn,
            "Experience": Experience,
            "target_net": target_net,
            "soft_update_alpha": soft_update_alpha,
            "update_target_network": update_target_network,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "discount_factor": discount_factor,
            "total_decision_steps": total_decision_steps,  # pass as list for mutability
            "target_update_frequency": target_update_frequency
        },
        # visualize=(episode == episodes - 3)
        visualize=False
    )
    steps_after_episode += total_decision_steps[0]
    total_stars_collected += len(stars_collected)
    
    # perform learning steps from replay buffer
    losses = []
    if len(replay_buffer) > batch_size:
        for _ in range(10): # example 10 learning steps per episode end
            loss = learn(grid, width, learning_rate, discount_factor)
            if loss is not None:
                losses.append(loss)
             
    episode_rewards_history.append(episode_final_reward) # save the final reward
    most_recent_rewards.append(episode_final_reward)
    
    if episode >= initial_free_exploration:
        exploration_rate = max(min_exploration, exploration_rate * exploration_decay)

    # Print per-episode result
    if (episode + 1) % 10 == 0:
        print(f"[Episode {episode+1}] Reward: {episode_final_reward} | Bucket: {final_bucket} | Stars: {len(stars_collected)}")

    # print progress
    if (episode + 1) % 100 == 0:
        avg_reward = sum(most_recent_rewards) / len(most_recent_rewards)
        avg_stars = total_stars_collected / (episode + 1)
        if losses:
            avg_loss = sum(losses) / len(losses)
            print(f"Ep {episode + 1} | Avg R (last 100): {avg_reward:.2f} | Stars: {avg_stars:.2f} | Loss: {avg_loss:.4f} | ε: {exploration_rate:.2f}")
        else:
            print(f"Ep {episode + 1} | Avg R (last 100): {avg_reward:.2f} | Stars: {avg_stars:.2f} | ε: {exploration_rate:.2f}")
