In [3]:
import random
from collections import defaultdict, deque
from grid_utils import unmark_block
from state_utils import identify_decision_state, choose_action, is_block_row, get_valid_diagonal_moves, handle_blocks, drop_ball, initialize_trackers
from board_builder import build_board
from visualization import print_training_stats, visualize_grid

### Q-Learning

In [4]:
# declaration of tracker dictionaries
trackers = initialize_trackers()

q_table = defaultdict(dict)
episode_rewards_history = []
most_recent_rewards = deque(maxlen=100)
total_stars_collected = 0

# q-learning parameters
learning_rate = 0.1
discount_factor = 0.99
exploration_rate = 1.0  # start fully exploratory
exploration_decay = 0.9999  # reduce randomness over time
min_exploration = 0.01  # smallest possible exploration rate
episodes = 10000  # number of training episodes
initial_free_exploration = 1000

# train agent
target_bucket = 2  # the bucket the agent should aim for

grid, buckets, width, height = build_board("hard", trackers)

start_x = random.randint(0, width - 1)

visualize_grid(grid, width, height, ball_position=(start_x, height - 1), buckets=buckets)

for episode in range(episodes):
    grid, buckets, width, height = build_board("hard", trackers)

    start_x = random.randint(0, width - 1)
    
    # visualize_grid(grid, width, height, ball_position=(start_x, height - 1), buckets=buckets)

    state_action_pairs, reward, stars_collected = drop_ball(
        grid=grid,
        width=width,
        height=height,
        start_x=start_x,
        buckets=buckets,
        target_bucket=target_bucket,
        mode="q",
        exploration_rate=exploration_rate,
        q_table=q_table,
        trackers=trackers,
        extra=None,
        # visualize=(episode == episodes - 1)
        visualize=False
    )

    total_stars_collected += len(stars_collected)

    # update Q-table using recorded decisions
    for state, action in reversed(state_action_pairs):
        current_q = q_table[state][action]
        best_future_q = max(q_table.get(state, {}).values(), default=0)
        q_table[state][action] = current_q + learning_rate * (reward + discount_factor * best_future_q - current_q)

    episode_rewards_history.append(reward) # save the final reward
    most_recent_rewards.append(reward)
    
    # decay exploration rate
    if episode >= initial_free_exploration:
        exploration_rate = max(min_exploration, exploration_rate * exploration_decay)

    # print progress
    if (episode + 1) % 100 == 0:
        avg_reward = sum(most_recent_rewards) / len(most_recent_rewards)
        avg_stars = total_stars_collected / (episode + 1)
        print(f"Episode {episode + 1} | Avg Reward (Last 100): {avg_reward:.2f} | Avg Stars: {avg_stars:.2f} | Exploration Rate: {exploration_rate:.2f} | Q-States: {len(q_table)}")

print_training_stats(
    trackers,
    q_table
)


   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4
99   O   X   O   O   O   O   O   O   O   O   O   O   
98 O   O   O   O   O   O   O   O   O   O   O   O ⤓ ☆ 
97   O   O   O   O   O   O   O   O   O   O   O   Φ   
96 O   O   O   O   O   O   O   O   O   O   O   O | O 
95 _ _ _ _ _ O   O   O _ _ _ _ _ _   O   O   O   Φ   
94 O   O   O   O   O   O   O   O   O   O   O   O | O 
93   O   O   O   O   O   O   O   O   O   O   O   Φ   
92 O   O   O   O   O   O   O   O   O   O   O   O | O 
91   O   _ _ _ _ _ _ _ ⤓ O   O   O   O   O   O   Φ   
90 O   O   O   O   O   Φ   O   O   O   O   O   O | O 
89   O   O   O   O   O | O   O   _ _ _ ⤓ _ _ _ _ ↥   
88 O   O   O   O   O   Φ   O   O   O   Φ   O   O   / 
87   O   O   O   O   O | O   O   O   O | O   O   /   
86 O   O   O   O   O   Φ   O   O   O   Φ   O   /   O 
85   O _ _ ⬒ _ _ O   O | O   O   O   O | O   /   O   
84 O   O   O   O   O   Φ   O   O   O   Φ   /   O   O 
83   O   O   O   O   O | O   O   _ _ _ ↥ _   O   O   
82 O   O   O   O   O   Φ   O 

Unnamed: 0_level_0,Unnamed: 1_level_0,15,16,17,18,19,20,21,22,23,10,...,8,9,2,0,1,11,12,13,14,24
position,buttons_pressed,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
"(23, 98)","((24, 98),)",,,,,,,,,35.680868,,...,,,,,,,,,,37.658803
"(23, 98)",(),,,,,,,,,54.74737,,...,,,,,,,,,,53.073895
"(10, 95)",(),54.259899,,,,,,,,,52.914821,...,,,55.466062,51.105688,53.641651,53.457525,52.062731,54.086871,54.507704,
"(0, 95)",(),41.032538,,,,,,,,,43.786891,...,,,43.711984,43.812802,50.057293,43.675447,43.295795,42.451951,41.642534,
"(3, 91)","((4, 85),)",,,,,,,,,,8.994999,...,13.42759,7.980521,,,,,,,,
"(3, 91)",(),,,,,,,,,,103.062952,...,103.060508,103.058157,,,,,,,,
"(3, 91)","((4, 85), (24, 98))",,,,,,,,,,0.221741,...,0.32035,0.519352,,,,,,,,
"(3, 91)","((24, 98),)",,,,,,,,,,96.737067,...,96.806335,96.924854,,,,,,,,
"(15, 89)","((24, 98),)",19.403475,19.785618,19.797833,18.130881,22.722287,18.501161,20.260971,17.970883,26.217237,,...,,,,,,,,,,
"(15, 89)",(),78.603882,79.328175,78.935662,79.071488,78.757543,78.706714,81.543287,78.411412,78.722003,,...,,,,,,,,,,
