# Improving the Plinko DQN algorithm using Double Q-Learning

In [3]:
import copy  # For deep copying Q-table
import random
from collections import defaultdict, deque, namedtuple

import numpy as np
import pandas as pd
from IPython.display import display

### Improvement 1: Double Q-Learning (Two Q-Tables)

#### Motivation: 
The original Plinko code uses standard Q-learning, which is known to suffer from maximization bias: it tends to overestimate action values. The bias arises because the same Q-table is used both to select the best next action, $\arg\max_{a'} Q(s', a')$, and to evaluate that action, $Q(s', \arg\max_{a'} Q(s', a'))$; together these give the usual $\max_{a'} Q(s', a')$ target. If some action's value is overestimated due to noise or limited sampling, the max operation is likely to select it, and the overestimation propagates through the updates. Double Q-learning decouples selection from evaluation. We use the online Q-table (`q_table_online`) to select the best next action, $a^* = \arg\max_{a'} Q_{\text{online}}(s', a')$, but use the target Q-table (`q_table_target`) to evaluate the value of that chosen action, $Q_{\text{target}}(s', a^*)$. This reduces the chance of consistently selecting actions based on overestimated values.

#### Expectation: 
This should lead to more accurate Q-value estimates, potentially resulting in a more stable learning process and convergence to a better final policy (higher success rate for the target bucket). It might prevent the agent from getting stuck favouring suboptimal paths due to early overestimations.

### Global Trackers and Dictionaries

In [None]:
# Board features shared by the mark_* helpers.
pipes = {}  # maps (x, y) of pipe end -> (x, y) of connected destination
blocks = {}  # maps row_y -> {x: original tile} for restoring blocked rows

# --- Q-Learning Specific ---
# Two tabular Q-functions for Double Q-learning. Both are nested
# defaultdicts, so q_table[state][action] reads as 0.0 until it is written.
q_table_online = defaultdict(lambda: defaultdict(float))
q_table_target = defaultdict(lambda: defaultdict(float))

# --- Experience Replay ---
# One stored transition. `namedtuple` must be imported from `collections`
# for this cell to run.
Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
replay_buffer = deque(maxlen=10000)  # bounded: the deque drops the oldest item once 10k are stored
batch_size = 64

# --- Statistics Trackers ---
bucket_tracker = {i: 0 for i in range(5)}  # maps bucket index -> number of landings
ledge_tracker = defaultdict(int)  # maps state tuple -> number of visits
block_row_tracker = defaultdict(int)  # maps block state tuple -> number of visits
spike_tracker = defaultdict(int)  # maps spike row y -> number of hits
pipe_tracker = defaultdict(int)  # maps (x, y) of pipe entry/exit -> number of uses
button_tracker = defaultdict(int)  # maps (x, y) of button tile -> number of presses
episode_rewards = []  # Track rewards per episode

### Board Functions

In [None]:
def generate_grid(width, height):
    """Build the initial peg board as a dict keyed by (x, y).

    Rows are generated top-down: generation index 0 is stored at
    y = height - 1. A peg ('O') is placed wherever the row index and the
    column have different parity; every other cell is an empty space (' ').
    """
    board = {}
    for row in range(height):
        y = height - 1 - row  # flip so the first generated row sits at the top
        for x in range(width):
            has_peg = (row + x) % 2 == 1  # checkered pattern
            board[(x, y)] = 'O' if has_peg else ' '
    return board

def mark_ledge(grid, start_x, length, ledge_y, button_x=None):
    """Lay a horizontal ledge of `length` tiles on row `ledge_y`.

    Tiles are written from `start_x` rightwards, but only into cells that
    are missing from `grid` or currently empty (' '); pegs and other
    features are left alone. The column equal to `button_x` (if any)
    becomes a button tile and is registered in `button_tracker`.
    """
    for col in range(start_x, start_x + length):
        cell = (col, ledge_y)
        if grid.get(cell, ' ') != ' ':
            continue  # occupied: do not overwrite
        if col != button_x:
            grid[cell] = '_'  # ordinary ledge tile
            continue
        grid[cell] = '⬒'  # special button tile
        button_tracker[cell]  # touching the key creates the tracker entry
    # ledge_tracker is not touched here; it counts state visits, not ledge definitions

def mark_spike(grid, start_x, length, spike_y):
    """Place a run of spike tiles ('^') on row `spike_y`.

    Tiles are written from `start_x` for `length` columns, but only into
    cells that are missing from `grid` or currently empty (' '), so pegs
    and other features are not overwritten. The row is then registered in
    `spike_tracker`.
    """
    for x in range(start_x, start_x + length):
        # The assignment has to be nested under the `if`; with both lines at
        # the same indentation the function does not even parse.
        if (x, spike_y) not in grid or grid[(x, spike_y)] == ' ':
            grid[(x, spike_y)] = '^'
    spike_tracker[spike_y]  # auto-initializes to 0 if not already set

def mark_pipe(grid, x, y1, y2):
    """Draw a vertical pipe in column `x` joining rows `y1` and `y2`.

    The higher end becomes the down entrance ('⤓'), the lower end the up
    entrance ('↥'); cells in between become 'Φ' where a peg ('O') was and
    '|' otherwise. Both ends are linked to each other in `pipes` and
    registered in `pipe_tracker`.
    """
    upper_y, lower_y = max(y1, y2), min(y1, y2)

    for row in range(lower_y, upper_y + 1):
        cell = (x, row)
        if row == upper_y:
            glyph = '⤓'  # down pipe entrance
        elif row == lower_y:
            glyph = '↥'  # up pipe entrance
        elif grid.get(cell, ' ') == 'O':
            glyph = 'Φ'  # pipe body running over a peg
        else:
            glyph = '|'  # plain pipe body
        grid[cell] = glyph

    upper_end, lower_end = (x, upper_y), (x, lower_y)

    # each end teleports to the other
    pipes[upper_end] = lower_end
    pipes[lower_end] = upper_end

    # touching the keys creates the usage counters
    pipe_tracker[upper_end]
    pipe_tracker[lower_end]

def mark_slide(grid, start_x, start_y, length, direction):
    """Turn pegs along a descending diagonal into slide tiles.

    Starting at (`start_x`, `start_y`), visits `length` cells, dropping one
    row per step and moving one column right when `direction` is
    "forward", otherwise one column left. Only cells currently holding a
    peg ('O') are replaced; everything else is left untouched.
    """
    going_forward = direction == "forward"
    # NOTE(review): this literal is two backslash characters, while
    # visualize_grid prints one tile per column -- confirm a single
    # backslash was not intended.
    slide_char = '\\\\' if going_forward else '/'
    dx = 1 if going_forward else -1

    for step in range(length):
        cell = (start_x + dx * step, start_y - step)
        if grid.get(cell) == 'O':
            grid[cell] = slide_char  # replace pegs with slides

def mark_block(grid, width, row_y):
    """Fill row `row_y` with block tiles ('█'), saving what was there.

    Does nothing if the row is already recorded in `blocks`. Pipe tiles are
    left in place; every other column in range(width) has its current tile
    (or ' ' if the cell is missing) stored in blocks[row_y] before being
    overwritten.
    """
    if row_y in blocks:
        return  # already blocked

    pipe_tiles = {'↥', 'Φ', '⤓', '|'}
    saved_tiles = blocks[row_y] = {}  # x -> tile that was covered
    for col in range(width):
        cell = (col, row_y)
        tile = grid.get(cell, ' ')
        if tile in pipe_tiles:
            continue  # pipes pass through blocked rows
        saved_tiles[col] = tile
        grid[cell] = '█'

def unmark_block(grid, row_y):
    """Restore the tiles recorded in `blocks` for row `row_y`.

    Does nothing if the row has no entry in `blocks`. After the saved tiles
    are written back into `grid`, the row's entry is removed from `blocks`.
    """
    if row_y in blocks:
        saved_tiles = blocks[row_y]
        for col, tile in saved_tiles.items():
            grid[col, row_y] = tile  # put the original tile back
        del blocks[row_y]  # row is no longer blocked

def mark_buckets(width, num_buckets):
    """Split the board's columns into contiguous buckets.

    Returns a dict mapping each column x to its bucket index. Every bucket
    gets width // num_buckets columns; leftover columns are handed out one
    each, left to right, to buckets other than the middle one
    (index num_buckets // 2).
    """
    base_size, leftover = divmod(width, num_buckets)
    middle = num_buckets // 2
    column_to_bucket = {}
    next_col = 0

    for index in range(num_buckets):
        takes_extra = leftover > 0 and index != middle
        size = base_size + 1 if takes_extra else base_size
        for col in range(next_col, next_col + size):
            column_to_bucket[col] = index
        next_col += size
        if takes_extra:
            leftover -= 1  # one spare column consumed

    return column_to_bucket

def visualize_grid(grid, width, height, ball_position=None, buckets=None):
    """Print the board with axis labels, top row first.

    Each cell is printed followed by one space; cells missing from `grid`
    print as ' '. If `ball_position` is given, an 'X' is drawn at that
    (x, y) instead of the tile. Below the board comes a row of bucket
    indices (blank when `buckets` is empty or None), the x-axis labels
    again, and a divider line.
    """
    def tile_at(x, y):
        # the ball is drawn over whatever tile it sits on
        if ball_position and (x, y) == ball_position:
            return 'X'
        return grid.get((x, y), ' ')

    # column labels show only the last digit so every column stays one char wide
    header = "   " + " ".join(str(col % 10) for col in range(width))

    print(header)
    for y in reversed(range(height)):
        cells = "".join(tile_at(x, y) + " " for x in range(width))
        print(f"{y:2} " + cells)

    if buckets:
        labels = "".join(str(buckets.get(x, ' ')) + " " for x in range(width))
    else:
        labels = "  " * width
    print("   " + labels)
    print(header)
    print("===" + "=" * (2 * width))