# Improving the Plinko DQN algorithm using Double Q-Learning

In [2]:
import random
import pandas as pd
from IPython.display import display
from collections import defaultdict, deque
import numpy as np
import copy # For deep copying Q-table
from collections import namedtuple

### Part 1: Double Q-Learning

#### Motivation: 
>The original Plinko code uses standard Q-learning, which is known to suffer from maximization bias: it tends to overestimate action values. Our standard Q-learning algorithm uses a single Q-table both to select the best next action and to evaluate that action's value. If an action's value is overestimated, the max operation is likely to select it, and the overestimate is then propagated into the update targets of other states. Double Q-learning decouples selection from evaluation: we use the online Q-table to select the best next action and the target Q-table to evaluate the value of that chosen action, i.e. the update target becomes $r + \gamma \, Q_{\text{target}}\big(s', \arg\max_{a'} Q_{\text{online}}(s', a')\big)$. This reduces the chance of consistently selecting actions based on overestimated values.

#### Expectation: 
>We expect more accurate Q-value estimates, which should lead to a more stable learning process and convergence to a better final policy, giving a higher success rate for landing in the target bucket. It may also prevent our agent from getting stuck favouring sub-optimal paths because of early overestimates.

### Global Trackers and Dictionaries

In [3]:
# Board feature registries.
pipes = {}   # pipe end (x, y) -> (x, y) of the opposite end of the same pipe
blocks = {}  # blocked row y -> {x: tile that was there before the row was blocked}

# Tabular Double Q-learning: one table selects the next action, the other evaluates it.
# Both map state -> {action: value}; unseen states/actions read as 0.0.
q_table_online = defaultdict(lambda: defaultdict(float))
q_table_target = defaultdict(lambda: defaultdict(float))

# Experience replay.
Experience = namedtuple("Experience", "state action reward next_state done")
replay_buffer = deque(maxlen=10000)  # oldest transitions are dropped once 10k are stored
batch_size = 64

# Statistics (defaultdict(int) counters start every new key at 0).
bucket_tracker = dict.fromkeys(range(5), 0)  # bucket index -> landings
ledge_tracker = defaultdict(int)        # state tuple -> visits
block_row_tracker = defaultdict(int)    # block state tuple -> visits
spike_tracker = defaultdict(int)        # spike row y -> hits
pipe_tracker = defaultdict(int)         # pipe end (x, y) -> uses
button_tracker = defaultdict(int)       # button tile (x, y) -> presses
episode_rewards = []                    # reward recorded for each episode

### Board Functions

In [4]:
def generate_grid(width, height):
    """Build a width x height board of pegs ('O') and gaps (' ').

    Rows are counted from the top: row index ``r`` is stored at
    y = height - 1 - r. A cell holds a peg whenever the parities of its column
    and its row index differ, which yields a checkerboard layout.
    """
    top = height - 1
    return {
        (col, top - row): ('O' if (col + row) % 2 == 1 else ' ')
        for row in range(height)
        for col in range(width)
    }

def mark_ledge(grid, start_x, length, ledge_y, button_x=None):
    """Lay a horizontal ledge of `length` tiles on row `ledge_y`, starting at `start_x`.

    Only missing or empty (' ') cells are written; anything else already on the
    row (pegs, other features) is left in place. A written cell whose column
    equals `button_x` becomes a button tile '⬒' and is registered in
    `button_tracker`; every other written cell becomes a ledge tile '_'.
    """
    for col in range(start_x, start_x + length):
        cell = (col, ledge_y)
        if grid.get(cell, ' ') != ' ':
            continue  # occupied: keep the existing tile
        if col != button_x:
            grid[cell] = '_'
            continue
        grid[cell] = '⬒'
        button_tracker[cell]  # touching the defaultdict registers the button with a 0 count

def mark_spike(grid, start_x, length, spike_y):
    """Place spikes ('^') on row `spike_y` across `length` columns from `start_x`.

    Existing non-empty cells are preserved. The row is registered in
    `spike_tracker` (a 0 count is created if absent) even when no spike is placed.
    """
    empty_cells = [
        (col, spike_y)
        for col in range(start_x, start_x + length)
        if grid.get((col, spike_y), ' ') == ' '
    ]
    for cell in empty_cells:
        grid[cell] = '^'
    spike_tracker[spike_y]  # registers the row with a zero hit count

def mark_pipe(grid, x, y1, y2):
    """Draw a vertical pipe in column `x` joining rows `y1` and `y2` (either order).

    The lower end becomes an up-entrance '↥' and the upper end a
    down-entrance '⤓'. Cells strictly between them become '|', or 'Φ' where a
    peg 'O' was. Both ends are linked to each other in `pipes` and registered
    (with a 0 count) in `pipe_tracker`.
    """
    low, high = sorted((y1, y2))

    grid[(x, low)] = '↥'
    for row in range(low + 1, high):
        grid[(x, row)] = 'Φ' if grid.get((x, row), ' ') == 'O' else '|'
    grid[(x, high)] = '⤓'  # written last so it wins when low == high

    top_end, bottom_end = (x, high), (x, low)
    pipes[top_end] = bottom_end
    pipes[bottom_end] = top_end
    for end in (top_end, bottom_end):
        pipe_tracker[end]  # zero-initialise the usage counter

def mark_slide(grid, start_x, start_y, length, direction):
    """Turn the pegs along a diagonal into slide tiles.

    Starting at (start_x, start_y), walk `length` steps downward (y decreasing):
    one column right per step for direction == "forward" (backslash tile),
    one column left per step otherwise ('/' tile). Only cells holding a peg
    ('O') are replaced; every other cell is left untouched.
    """
    # Exactly one backslash character: every grid cell holds a single character,
    # so a two-character tile would misalign visualize_grid's rows and never
    # compare equal to a single-backslash tile.
    slide_char = '\\' if direction == "forward" else '/'
    dx = 1 if direction == "forward" else -1

    x, y = start_x, start_y
    for _ in range(length):
        if grid.get((x, y)) == 'O':
            grid[(x, y)] = slide_char  # replace pegs with slides
        x += dx
        y -= 1

def mark_block(grid, width, row_y):
    """Fill row `row_y` (columns 0..width-1) with block tiles ('█').

    Pipe cells ('↥', 'Φ', '⤓', '|') are skipped. Every replaced tile (a missing
    cell is recorded as ' ') is saved in `blocks[row_y]`, keyed by column, so
    unmark_block can restore it. A row already present in `blocks` is left as is.
    """
    if row_y in blocks:
        return

    pipe_tiles = {'↥', 'Φ', '⤓', '|'}
    saved = blocks[row_y] = {}
    for col in range(width):
        cell = (col, row_y)
        previous = grid.get(cell, ' ')
        if previous in pipe_tiles:
            continue
        saved[col] = previous
        grid[cell] = '█'

def unmark_block(grid, row_y):
    """Undo mark_block for row `row_y`: restore the saved tiles, then forget the row.

    Does nothing when the row is not recorded in `blocks`.
    """
    if row_y in blocks:
        saved_cells = blocks[row_y]
        grid.update(((col, row_y), tile) for col, tile in saved_cells.items())
        del blocks[row_y]

def mark_buckets(width, num_buckets):
    """Split columns 0..width-1 into `num_buckets` contiguous buckets.

    Returns a dict mapping column -> bucket index. Each bucket gets
    width // num_buckets columns. The width % num_buckets leftover columns go
    one apiece to the lowest-indexed buckets, skipping the middle bucket
    (index num_buckets // 2).
    """
    base, leftover = divmod(width, num_buckets)
    middle = num_buckets // 2
    mapping = {}
    col = 0
    for idx in range(num_buckets):
        widen = leftover > 0 and idx != middle
        if widen:
            leftover -= 1
        span = base + 1 if widen else base
        mapping.update({c: idx for c in range(col, col + span)})
        col += span
    return mapping

def visualize_grid(grid, width, height, ball_position=None, buckets=None):
    """Print the board, top row first, framed by x-axis labels.

    Every cell is printed followed by a space, and each row is prefixed with its
    y-coordinate. When `ball_position` is given, the ball is drawn as 'X' over
    whatever tile is there. A row of bucket indices (blank when `buckets` is
    empty or None) follows the board, then the x-axis labels again and an '='
    divider.
    """
    axis = "   " + " ".join(str(col % 10) for col in range(width))
    print(axis)

    for y in reversed(range(height)):
        cells = [
            'X' if ball_position and (col, y) == ball_position else grid.get((col, y), ' ')
            for col in range(width)
        ]
        print(f"{y:2} " + "".join(cell + " " for cell in cells))

    if buckets:
        labels = "".join(str(buckets.get(col, ' ')) + " " for col in range(width))
    else:
        labels = "  " * width
    print("   " + labels)

    print(axis)
    print("===" + "=" * (2 * width))

### Game Logic: DQN

In [6]:
# Replay-buffer transition record. Same name and fields as the definition in the
# Global Trackers cell; running this line rebinds Experience to an equivalent class.
Experience = namedtuple("Experience", "state action reward next_state done")

# helper: epsilon-greedy action selection
def choose_action(grid, state, available_actions, q_table, epsilon, width):
    """Pick an action for `state` with an epsilon-greedy policy over `q_table`.

    When a state is first seen, or its entry holds no actions, its action set
    is seeded in `q_table` with 0.0 values:
      * block states (``state[0] == 'block'``): every column 0..width-1;
      * otherwise ``state[0]`` is treated as an (x, y) position, and the actions
        are the columns of row y holding a ledge, pipe-entrance or button tile.

    Args:
        grid: board mapping (x, y) -> tile character.
        state: state key in one of the two shapes above.
        available_actions: currently unused; the action set comes from
            `q_table` / `grid`. Kept so existing call sites keep working.
        q_table: mapping state -> {action: q-value}; new states are added to it.
        epsilon: probability of exploring with a uniformly random known action.
        width: number of board columns.

    Returns:
        The chosen column. If no action can be derived for the state, a warning
        is printed and a random column in range(width) is returned.
    """
    if state not in q_table:
        q_table[state] = defaultdict(float)
    q_values = q_table[state]
    if not q_values:
        # New state, or one stored without any actions: seed its action set.
        for act in _valid_actions_for_state(grid, state, width):
            q_values[act] = 0.0

    actions = list(q_values.keys())
    if not actions:
        print(f"Warning: No available actions found for state {state}. Choosing random column.")
        return random.choice(list(range(width)))

    if random.random() < epsilon:
        return random.choice(actions)

    # Greedy: highest q-value, ties broken uniformly at random.
    best_q = -float('inf')
    best_actions = []
    for act in actions:
        q = q_values.get(act, 0.0)
        if q > best_q:
            best_q, best_actions = q, [act]
        elif q == best_q:
            best_actions.append(act)
    # NaN q-values never compare greater/equal, which could leave best_actions
    # empty; fall back to all known actions instead of crashing in random.choice.
    return random.choice(best_actions or actions)


# Tiles on a ledge row that count as a selectable action column.
_LEDGE_ACTION_TILES = frozenset({'_', '⤓', '↥', '⬒'})


def _valid_actions_for_state(grid, state, width):
    """Return the initial action columns for `state` (rules in choose_action)."""
    if isinstance(state[0], str) and state[0] == 'block':
        return list(range(width))
    row = state[0][1]
    return [col for col in range(width) if grid.get((col, row)) in _LEDGE_ACTION_TILES]

# helper: find ledge state
def find_ledge_state_key(x, y, width, grid, pressed_buttons):
    """Return the state key of the ledge under column `x`, or None.

    Rows `y` and then `y + 1` are checked. A ledge is a horizontal run of '_',
    '⬒', '⤓' or '↥' tiles. If (x, row) lies on such a run, the key is
    ``((start_x, row), frozenset(pressed_buttons))``, where start_x is the
    left-most column of the run (never below 0). Columns outside
    0..width-1 never match, so out-of-board positions yield None rather than
    an error.
    """
    if not 0 <= x < width:
        return None

    ledge_tiles = {'_', '⬒', '⤓', '↥'}
    for row in (y, y + 1):
        if grid.get((x, row)) not in ledge_tiles:
            continue  # x is not on a ledge in this row
        # Walk left to the first tile of this contiguous run.
        start = x
        while start > 0 and grid.get((start - 1, row)) in ledge_tiles:
            start -= 1
        return ((start, row), frozenset(pressed_buttons))
    return None