# Improving the Plinko DQN algorithm using Double Q-Learning

In [5]:
import random
from collections import defaultdict, deque
import copy # deep copying Q-table
from collections import namedtuple
from grid_utils import unmark_block
from state_utils import choose_action, find_ledge_state_key, get_valid_diagonal_moves, handle_blocks, drop_ball, initialize_trackers
from board_builder import build_board
from visualization import print_training_stats, visualize_grid

### Part 1: Double Q-Learning

#### Motivation: 
>The original Plinko code uses standard Q-learning, which is known to suffer from maximization bias: it tends to overestimate action values. Our standard Q-learning algorithm uses a single Q-table both to select the best next action and to evaluate that action's value. If an action's value is overestimated, the max operation is likely to select it, propagating the overestimation into the update target. Double Q-learning decouples selection from evaluation: we use the online Q-table to select the best next action and the target Q-table to evaluate the value of that chosen action. This reduces the chance of consistently selecting actions based on overestimated values.

#### Expectation: 
>We expect more accurate Q-value estimates, which will hopefully result in a more stable learning process and convergence to a better final policy with a higher success rate for the target bucket. It might also prevent our agent from getting stuck favouring sub-optimal paths due to early overestimations.

### Double Q-Learning and Experience Replay

In [6]:
# --- shared bookkeeping -------------------------------------------------------
# tracker dictionaries; later handed to build_board, drop_ball and print_training_stats
trackers = initialize_trackers()

# --- training setup -----------------------------------------------------------
target_bucket = 2  # the bucket the agent should aim for
episodes = 10000   # number of training episodes

# --- Double Q-learning hyper-parameters ---------------------------------------
learning_rate = 0.1
discount_factor = 0.99  # high discount because paths down the board can be long

# --- epsilon-greedy exploration -----------------------------------------------
exploration_rate = 1.0           # start fully exploratory
exploration_decay = 0.9999       # multiplicative decay, applied once per episode
min_exploration = 0.01           # floor for exploration_rate
initial_free_exploration = 1000  # episodes that run before decay starts
# NOTE(review): 0.9999 ** (10000 - 1000) ≈ 0.41, so with these settings the
# exploration rate never gets close to min_exploration — confirm intended.

# --- learning / target-update schedule ----------------------------------------
update_frequency = 4           # should_learn() is true every 4th decision step
target_update_frequency = 100  # decision steps between target-table updates
soft_update_alpha = 0.1        # weight of the online table in each soft update

# --- experience replay --------------------------------------------------------
Experience = namedtuple("Experience", ["state", "action", "reward", "next_state", "done"])
replay_buffer = deque(maxlen=10000)  # only the most recent 10k transitions are kept
batch_size = 64                      # transitions sampled per learn() call

# --- Q-tables: the online table selects actions, the target table evaluates ----
q_table_online = defaultdict(lambda: defaultdict(float))
q_table_target = defaultdict(lambda: defaultdict(float))

### Training Loop (DDQN)

In [7]:
def _next_available_actions(grid, width, next_state):
    """Return the column indices the agent may pick in ``next_state``.

    ``next_state[0]`` is either the string ``'block'`` (every column allowed)
    or a position tuple whose second element is the ledge row ``y``; in that
    case only columns whose cell on row ``y`` is one of ``_ ⬒ ⤓ ↥`` are allowed.
    Any other format is reported and yields no actions.
    """
    next_state_type = next_state[0]
    if isinstance(next_state_type, str) and next_state_type == 'block':
        return list(range(width))
    if isinstance(next_state_type, tuple):
        ledge_y_next = next_state_type[1]
        return [col for col in range(width)
                if grid.get((col, ledge_y_next)) in {'_', '⬒', '⤓', '↥'}]
    print(f"Error: Unknown next state format in learn: {next_state}")
    return []


def _ensure_q_row(q_table, state, actions):
    """Give ``state`` an explicit row in ``q_table`` (every action at 0.0) if it has none yet."""
    if state not in q_table:
        q_table[state] = defaultdict(float, dict.fromkeys(actions, 0.0))


def _double_q_next_value(next_state, next_available_actions):
    """Return Q_target(s', argmax_a Q_online(s', a)) over the valid actions; 0.0 if there are none.

    Ties in the online table are broken uniformly at random; actions without
    an entry count as 0.0 in both tables.
    """
    if not next_available_actions:
        return 0.0
    online_row = q_table_online[next_state]
    best_q = max(online_row.get(act, 0.0) for act in next_available_actions)
    best_actions = [act for act in next_available_actions
                    if online_row.get(act, 0.0) == best_q]
    best_next_action = random.choice(best_actions)
    return q_table_target[next_state].get(best_next_action, 0.0)


def learn(grid, width, learning_rate, discount_factor):
    """Run one Double Q-learning pass over a random mini-batch of replayed transitions.

    Samples ``batch_size`` transitions from the module-level ``replay_buffer``.
    For each one it moves ``q_table_online[state][action]`` towards

        target = reward                                      if done
        target = reward + discount_factor * Q_target(s', a*) otherwise,
                 with a* = argmax_a Q_online(s', a) over the valid next actions

    so the online table *selects* the next action and the target table
    *evaluates* it. Does nothing until the buffer holds ``batch_size`` items.

    Args:
        grid: board lookup keyed by ``(x, y)`` (queried with ``grid.get``),
            used to find valid columns on a next-state ledge row.
        width: board width; actions are column indices in ``range(width)``.
        learning_rate: step size of the tabular update.
        discount_factor: gamma applied to the bootstrapped next-state value.

    NOTE(review): valid next actions are derived from the *current* ``grid``,
    while replayed transitions may come from earlier episodes' boards. That is
    fine only if build_board("hard", ...) always yields the same layout — confirm.
    """
    if len(replay_buffer) < batch_size:
        return  # not enough samples yet

    for state, action, reward, next_state, done in random.sample(replay_buffer, batch_size):
        # make sure the (state, action) pair has explicit entries in both tables
        q_table_online[state].setdefault(action, 0.0)
        q_table_target[state].setdefault(action, 0.0)
        current_q_online = q_table_online[state][action]

        if done:
            target_q = reward  # terminal: no bootstrapping
        else:
            next_available_actions = _next_available_actions(grid, width, next_state)
            _ensure_q_row(q_table_online, next_state, next_available_actions)
            _ensure_q_row(q_table_target, next_state, next_available_actions)
            # Include the stored reward. This is identical to the previous
            # behaviour when intermediate rewards are 0, and correct if
            # drop_ball ever stores a non-zero one.
            target_q = reward + discount_factor * _double_q_next_value(
                next_state, next_available_actions)

        # update online Q-table
        q_table_online[state][action] = current_q_online + learning_rate * (target_q - current_q_online)

def update_target_network(online_q_table, target_q_table, alpha):
    """Soft-update the target Q-table towards the online one, in place.

    Every (state, action) value in ``online_q_table`` is blended into the
    target as ``(1 - alpha) * target + alpha * online``. States missing from
    the target get a fresh ``defaultdict(float)`` row, and missing actions
    start from 0.0. Entries that exist only in the target are left untouched.
    Returns None.
    """
    keep = 1.0 - alpha
    for state, online_row in online_q_table.items():
        if state not in target_q_table:
            target_q_table[state] = defaultdict(float)
        target_row = target_q_table[state]
        for action, online_value in online_row.items():
            target_row[action] = keep * target_row.get(action, 0.0) + alpha * online_value

# agent determines when to call learn() based on total steps
def should_learn():
     return total_decision_steps[0] % update_frequency == 0

In [8]:
# Cumulative number of agent decisions over all episodes. It is kept in a
# one-element list so drop_ball (via extra["total_decision_steps"]) can update
# it in place — presumably once per decision; should_learn() reads it too.
total_decision_steps = [0]
steps_after_episode = 0  # total_decision_steps[0] as seen at the end of the latest episode
episode_rewards_history = []  # final reward of every episode, in order
most_recent_rewards = deque(maxlen=100)  # sliding window for the progress printout
total_stars_collected = 0

# synchronize Q-tables initially
q_table_target = copy.deepcopy(q_table_online)

# build and show one sample board before training starts
grid, buckets, width, height = build_board("hard", trackers)
start_x = random.randint(0, width - 1)
visualize_grid(grid, width, height, ball_position=(start_x, height - 1), buckets=buckets)

# training loop
for episode in range(episodes):
    # fresh board and random start column every episode
    grid, buckets, width, height = build_board("hard", trackers)
    start_x = random.randint(0, width - 1)

    steps_before_episode = total_decision_steps[0]
    episode_final_reward, final_bucket, stars_collected = drop_ball(
        grid=grid,
        width=width,
        height=height,
        start_x=start_x,
        buckets=buckets,
        target_bucket=target_bucket,
        mode="dqn",
        exploration_rate=exploration_rate,
        q_table=q_table_online,
        trackers=trackers,
        extra={
            "replay_buffer": replay_buffer,
            "should_learn": should_learn,
            "learn": learn,
            "Experience": Experience,
            "q_table_target": q_table_target,
            "soft_update_alpha": soft_update_alpha,
            "update_target_network": update_target_network,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "discount_factor": discount_factor,
            "total_decision_steps": total_decision_steps,  # pass as list for mutability
            "target_update_frequency": target_update_frequency,
        },
        # set visualize=(episode == episodes - 3) to render a late episode
        visualize=False,
    )
    steps_after_episode = total_decision_steps[0]
    total_stars_collected += len(stars_collected)

    # Soft-update the target table once if this episode's steps crossed a
    # multiple of target_update_frequency. Compare snapshots taken before and
    # after the episode: the counter is cumulative, so summing it into an
    # accumulator would make this check fire on almost every episode.
    # NOTE(review): drop_ball also receives update_target_network and
    # target_update_frequency and may already do per-step target updates; if
    # so, this end-of-episode update is extra — confirm.
    if steps_before_episode // target_update_frequency < steps_after_episode // target_update_frequency:
        update_target_network(q_table_online, q_table_target, soft_update_alpha)

    # a few extra replay updates at the end of every episode; the guard matches
    # learn()'s own "at least batch_size samples" check
    if len(replay_buffer) >= batch_size:
        for _ in range(10):
            learn(grid, width, learning_rate, discount_factor)

    episode_rewards_history.append(episode_final_reward)  # save the final reward
    most_recent_rewards.append(episode_final_reward)

    # epsilon decays once per episode, after the free-exploration phase
    if episode >= initial_free_exploration:
        exploration_rate = max(min_exploration, exploration_rate * exploration_decay)

    # print progress
    if (episode + 1) % 100 == 0:
        avg_reward = sum(most_recent_rewards) / len(most_recent_rewards)
        avg_stars = total_stars_collected / (episode + 1)
        print(f"Episode {episode + 1} | Avg Reward (Last 100): {avg_reward:.2f} | Avg Stars: {avg_stars:.2f} | Exploration Rate: {exploration_rate:.2f}")

# final statistics summary
print_training_stats(trackers, q_table_online)


   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4
99   O   O   O X O   O   O   O   O   O   O   O   O   
98 O   O   O   O   O   O   O   O   O   O   O   O ⤓ ☆ 
97   O   O   O   O   O   O   O   O   O   O   O   Φ   
96 O   O   O   O   O   O   O   O   O   O   O   O | O 
95 _ _ _ _ _ O   O   O _ _ _ _ _ _   O   O   O   Φ   
94 O   O   O   O   O   O   O   O   O   O   O   O | O 
93   O   O   O   O   O   O   O   O   O   O   O   Φ   
92 O   O   O   O   O   O   O   O   O   O   O   O | O 
91   O   _ _ _ _ _ _ _ ⤓ O   O   O   O   O   O   Φ   
90 O   O   O   O   O   Φ   O   O   O   O   O   O | O 
89   O   O   O   O   O | O   O   _ _ _ ⤓ _ _ _ _ ↥   
88 O   O   O   O   O   Φ   O   O   O   Φ   O   O   / 
87   O   O   O   O   O | O   O   O   O | O   O   /   
86 O   O   O   O   O   Φ   O   O   O   Φ   O   /   O 
85   O _ _ ⬒ _ _ O   O | O   O   O   O | O   /   O   
84 O   O   O   O   O   Φ   O   O   O   Φ   /   O   O 
83   O   O   O   O   O | O   O   _ _ _ ↥ _   O   O   
82 O   O   O   O   O   Φ   O 

Unnamed: 0_level_0,Unnamed: 1_level_0,0,1,2,3,4,10,11,12,13,14,...,9,20,16,17,18,19,21,22,23,24
position,buttons_pressed,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
"(23, 98)","((24, 98),)",,,,,,,,,,,...,,,,,,,,,3.765807,3.766818
"(23, 98)",(),,,,,,,,,,,...,,,,,,,,,3.274226,3.737111
"(0, 95)",(),1.657538,1.660767,1.661866,1.564378,1.641924,1.573074,2.497136,1.764026,2.414461,1.635616,...,,,,,,,,,,
"(10, 95)",(),1.66383,1.654885,1.654361,1.661614,1.623667,1.658592,1.664211,1.674442,2.746805,3.097847,...,,,,,,,,,,
"(3, 91)","((4, 85), (24, 98))",,,,2.704537,2.477349,2.656048,,,,,...,3.183968,,,,,,,,,
"(3, 91)",(),,,,1.650096,1.649266,1.644139,,,,,...,1.791378,,,,,,,,,
"(3, 91)","((24, 98),)",,,,3.777221,3.775241,3.772362,,,,,...,3.781694,,,,,,,,,
"(3, 91)","((4, 85),)",,,,1.523115,1.53594,1.625964,,,,,...,1.575723,,,,,,,,,
"(15, 89)","((4, 85),)",,,,,,,,,,,...,,0.0,0.0,0.0,0.0,0.0,0.0,0.918508,0.0,
"(15, 89)","((24, 98),)",,,,,,,,,,,...,,3.798651,3.78302,3.79173,3.783469,3.794235,3.794457,3.788428,3.746004,
