In [1]:
import random
from collections import defaultdict
from grid_utils import unmark_block
from state_utils import identify_decision_state, choose_action, is_block_row, get_valid_diagonal_moves, handle_blocks, drop_ball, initialize_trackers
from board_builder import build_board
from visualization import print_training_stats, visualize_grid

### Q-Learning

In [2]:
# declaration of tracker dictionaries
trackers = initialize_trackers()

q_table = defaultdict(dict)

# q-learning parameters
learning_rate = 0.1
discount_factor = 0.9
exploration_rate = 1.0  # start fully exploratory
exploration_decay = 0.995  # reduce randomness over time
min_exploration = 0.01  # smallest possible exploration rate
episodes = 1000  # number of training episodes

# train agent
target_bucket = 2  # the bucket the agent should aim for

for episode in range(episodes):
    grid, buckets = build_board("default", 15, 30, trackers)
    
    width = max(x for (x, y) in grid.keys()) + 1
    height = max(y for (x, y) in grid.keys()) + 1

    start_x = random.randint(0, width - 1)

    visualize_grid(grid, width, height, ball_position=(start_x, height - 1), buckets=buckets)
    
    state_action_pairs, reward = drop_ball(
        grid=grid,
        width=width,
        height=height,
        start_x=start_x,
        buckets=buckets,
        target_bucket=target_bucket,
        mode="q",
        exploration_rate=exploration_rate,
        q_table=q_table,
        trackers=trackers,
        extra=None
    )

    # update Q-table using recorded decisions
    for state, action in reversed(state_action_pairs):
        current_q = q_table[state][action]
        best_future_q = max(q_table.get(state, {}).values(), default=0)
        q_table[state][action] = current_q + learning_rate * (reward + discount_factor * best_future_q - current_q)

    # gradually reduce exploration to favor learned strategies
    exploration_rate = max(min_exploration, exploration_rate * exploration_decay)

    # ✅ Print progress every 100 episodes
    if (episode + 1) % 100 == 0:
        print(f"Episode {episode + 1} | Exploration Rate: {exploration_rate:.3f}")

print_training_stats(
    trackers,
    q_table
)


   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4
29   O   O   O   X   O   O   O   
28 \   O   O   O   O   O   O   O 
27   \ _ _ _ _ ⤓ O   O   O   O   
26 O   \   O   Φ   O   O   O   O 
25   O   \   O | O   O   O   O   
24 O   O   O   ↥ _ _ ⤓ O   O   O 
23   O   O   O   O   Φ   O   /   
22 O   O   O   O   O | O   /   O 
21   _ _ _ _ _ _ O   Φ   /   O   
20 O   O   O   O   O | O   O   O 
19   O   O   O   O   ↥ _ _ _ _   
18 O   O   O ^ ^ ^ ^   O   O   O 
17   O   _ _ _ _ _ _ _   O   O   
16 O   O   O   O   O   O   O   / 
15   O   O   O   _ _ _ _ _ _ /   
14 O   O   O   O   O   O   /   O 
13 _ _ _ _ _ O   O   O   /   O   
12 \   O   O   O   O   O   O   O 
11   \   O   O   O   _ _ _ _ _ _ 
10 O   O   O   O   O   O   O   O 
 9 ⬒ _ _ _ ⤓ _ _ _ _ O   O   O   
 8 O   O   Φ   O   O   O   O   O 
 7   O   O | _ _ _ _ _ _ O   O   
 6 O   O   Φ   O ^ ^   O   O   O 
 5 █ █ █ █ ↥ █ █ █ █ █ █ █ █ █ █ 
 4 O   O   O   O   O   O   O   O 
 3   O   O   O   O _ _ _ _   O   
 2 O   O   _ _ _ _ _ _ _ _ O   O 
 1   O   O   O 

Unnamed: 0_level_0,Unnamed: 1_level_0,9,10,11,12,13,14,4,0,1,2,3,5,6,7,8
position,buttons_pressed,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
"(2, 27)",(),,,,,,,4.482451,,,3.788726,2.219808,2.278835,9.391424,,
"(6, 24)",(),9.333761,,,,,,,,,,,,8.873511,9.080986,9.999958
"(1, 21)",(),,,,,,,1.168311,,0.728198,2.923728,2.079781,1.808535,7.789688,,
"(9, 19)",(),9.999086,7.720207,6.890375,8.298995,7.36852,,,,,,,,,,
"(3, 17)",(),0.868264,,,,,,0.389337,,,,1.343874,6.981866,0.260621,1.142572,0.967671
"(7, 15)",(),0.659209,4.087614,0.843613,1.25342,,,,,,,,,,0.739335,0.73903
"(0, 13)",(),,,,,,,2.98516,0.476139,0.770687,0.317706,0.026812,,,,
"(9, 11)",(),2.049383,2.421859,1.397218,1.584823,5.473872,1.602835,,,,,,,,,
"(0, 9)",(),,,,,,,4.358899,3.660879,4.300184,3.775893,3.630341,3.535526,3.933298,3.9344,9.203196
"(0, 9)","((0, 9),)",,,,,,,0.00981,0.037235,0.099883,0.126444,0.026585,0.010685,0.217698,0.790207,0.209514
