In [None]:
pip install pygame

In [None]:
import numpy as np
import random
import pygame
import time
from pygame.locals import *

# Environment configuration
# States are (row, column) tuples: 'up'/'down' change the first component and
# 'left'/'right' the second (see get_next_state).
GRID_SIZE = 4  # the world is a GRID_SIZE x GRID_SIZE grid of cells
START = (0, 0)  # cell where every training episode begins
GOAL = (3, 3)  # terminal cell; reaching it ends the episode
ACTIONS = ['up', 'down', 'left', 'right']  # list index == action index in Q

# Learning parameters
EPSILON = 0.1  # probability of taking a random action in choose_action
ALPHA = 0.5  # learning rate used in the update_q step
GAMMA = 0.9  # discount applied to the best Q-value of the next state
EPISODES = 500  # number of episodes run by the main training loop

# Color configuration
# RGB tuples keyed by what they paint. 'trail' carries a fourth (alpha)
# component, but draw_trail only reads its first three values.
# NOTE(review): 'path' is not referenced anywhere in the code visible here.
COLORS = {
    'bg': (25, 25, 25),
    'grid': (100, 100, 100),
    'start': (30, 144, 255),
    'goal': (255, 51, 51),
    'agent': (50, 205, 50),
    'path': (255, 215, 0),
    'qtext': (200, 200, 200),
    'trail': (0, 255, 127, 50)
}

# Initialize Pygame
# pygame.init() must run before the display and font calls below.
pygame.init()
WIDTH, HEIGHT = 800, 600  # window size in pixels
CELL_SIZE = 120  # side length of one grid cell in pixels
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Enhanced Q-Learning Maze")
font = pygame.font.SysFont('Arial', 18)  # used for the per-cell Q-value labels
large_font = pygame.font.SysFont('Arial', 32)  # used for the status line

# Initialize the Q-table
# Q[row, column, action_index] starts at zero for every state/action pair.
Q = np.zeros((GRID_SIZE, GRID_SIZE, len(ACTIONS)))

# New feature: Smooth movement animation
class AnimatedAgent:
    """Track where the agent should be drawn while it moves between cells.

    Positions are (row, column) grid coordinates. The drawn position is a
    linear blend between ``current_pos`` (where the move started) and
    ``target_pos`` (where it ends), weighted by ``animation_progress``.
    """

    def __init__(self):
        self.current_pos = (0, 0)
        self.target_pos = (0, 0)
        self.animation_progress = 0  # 0-1 represents the animation progress

    def update_position(self, new_pos):
        """Begin a move towards ``new_pos``.

        The previous target becomes the new starting point. Without this
        commit, a caller that issues a new move before the previous animation
        finished would leave ``current_pos`` stuck at its initial value and
        the agent would always be drawn next to the first cell.
        """
        self.current_pos = self.target_pos
        self.target_pos = new_pos
        self.animation_progress = 0

    def get_display_pos(self):
        """Return the interpolated pixel position as ``(x_pixels, y_pixels)``.

        The column drives the horizontal pixel axis and the row the vertical
        one, matching how the other drawing functions lay out the grid.
        """
        t = self.animation_progress
        row = self.current_pos[0] * (1 - t) + self.target_pos[0] * t
        col = self.current_pos[1] * (1 - t) + self.target_pos[1] * t
        # NOTE(review): the CELL_SIZE//4 offset puts this point a quarter of
        # the way into the cell, not at its centre -- confirm this is intended.
        return (col * CELL_SIZE + CELL_SIZE//4, row * CELL_SIZE + CELL_SIZE//4)

    def update_animation(self, speed=0.1):
        """Advance the animation by ``speed`` (a fraction of the whole move).

        Progress is capped at 1. As soon as it reaches 1 the move is committed
        by copying ``target_pos`` into ``current_pos``, rather than waiting for
        one further call.
        """
        if self.animation_progress < 1:
            self.animation_progress = min(1, self.animation_progress + speed)
        if self.animation_progress >= 1:
            self.current_pos = self.target_pos

# Single shared agent instance; draw_agent and the training loop both use it.
agent = AnimatedAgent()

# New feature: Trail recording
# Recently visited (row, column) states, oldest first. The training loop
# appends to it and draw_trail reads it.
trail_positions = []

# Reward function and state transition remain unchanged...
def get_reward(state):
    """Return the reward for arriving in ``state``.

    Reaching GOAL pays 10; every other state costs 0.1.
    """
    if state == GOAL:
        return 10
    return -0.1

def get_next_state(state, action):
    """Return the (row, column) reached by taking ``action`` from ``state``.

    Moves that would leave the grid are clamped, so the agent stays in place
    along that axis. Any action other than 'up', 'down' or 'left' is treated
    as 'right'.
    """
    row, col = state
    last = GRID_SIZE - 1
    if action == 'up':
        row = max(row - 1, 0)
    elif action == 'down':
        row = min(row + 1, last)
    elif action == 'left':
        col = max(col - 1, 0)
    else:
        col = min(col + 1, last)
    return (row, col)

def update_q(state, action, reward, next_state):
    """Apply one Q-learning update to the global table ``Q`` in place.

    The entry for (``state``, ``action``) is moved a fraction ALPHA towards
    ``reward + GAMMA * max(Q[next_state])``.
    """
    row, col = state
    action_idx = ACTIONS.index(action)
    next_row, next_col = next_state
    best_next = np.max(Q[next_row, next_col])
    td_error = reward + GAMMA * best_next - Q[row, col, action_idx]
    Q[row, col, action_idx] += ALPHA * td_error

def choose_action(state):
    """Pick an action for ``state`` with an epsilon-greedy rule.

    With probability EPSILON a uniformly random action is returned; otherwise
    the action with the highest Q-value for this state is returned.
    """
    explore = random.uniform(0, 1) < EPSILON
    if explore:
        return random.choice(ACTIONS)
    row, col = state
    best_idx = np.argmax(Q[row, col])
    return ACTIONS[best_idx]

# Enhanced drawing functions
def draw_grid():
    """Paint the tinted cell backgrounds and the grid lines onto ``screen``."""
    # Cell backgrounds: red grows with the row index, green with the column.
    for row in range(GRID_SIZE):
        top = row * CELL_SIZE
        for col in range(GRID_SIZE):
            shade = (50 + row*20, 50 + col*20, 100)
            cell = pygame.Rect(col*CELL_SIZE, top, CELL_SIZE, CELL_SIZE)
            pygame.draw.rect(screen, shade, cell)

    # Grid lines span the whole window, one every CELL_SIZE pixels.
    line_color = COLORS['grid']
    for px in range(0, WIDTH, CELL_SIZE):
        pygame.draw.line(screen, line_color, (px, 0), (px, HEIGHT), 2)
    for py in range(0, HEIGHT, CELL_SIZE):
        pygame.draw.line(screen, line_color, (0, py), (WIDTH, py), 2)

def draw_q_values():
    """Write each cell's Q-values around its centre, one label per action.

    Non-negative values use the 'qtext' colour; negative values are drawn in
    a light red.
    """
    # Pixel offsets of each label relative to the cell centre. Any action not
    # listed here uses the right-hand slot.
    label_offsets = {
        'up': (-15, -40),
        'down': (-15, 30),
        'left': (-50, -10),
    }
    right_offset = (20, -10)
    half = CELL_SIZE//2

    for row in range(GRID_SIZE):
        for col in range(GRID_SIZE):
            center_x = col * CELL_SIZE + half
            center_y = row * CELL_SIZE + half
            for idx, action in enumerate(ACTIONS):
                value = Q[row, col, idx]
                if value >= 0:
                    color = COLORS['qtext']
                else:
                    color = (255, 100, 100)
                dx, dy = label_offsets.get(action, right_offset)
                label = font.render(f"{value:.1f}", True, color)
                screen.blit(label, (center_x + dx, center_y + dy))

def draw_agent():
    """Draw the agent as a filled circle with a thin white outline."""
    center = agent.get_display_pos()
    radius = CELL_SIZE//4
    pygame.draw.circle(screen, COLORS['agent'], center, radius)
    # Same circle again with width=2 to get the outline only.
    pygame.draw.circle(screen, (255, 255, 255), center, radius, 2)

def draw_trail():
    """Draw the recent trail as dots that fade from old (faint) to new.

    Reads the global ``trail_positions`` (oldest first). pygame's draw
    functions do not blend with the target surface, so an RGBA colour drawn
    straight onto the display shows up fully opaque. Each dot is therefore
    drawn onto a small per-pixel-alpha surface and blitted, which is what
    makes the fade visible.
    """
    count = len(trail_positions)
    radius = CELL_SIZE//5
    rgb = COLORS['trail'][:3]
    # One pixel of margin so the circle's edge is not clipped by the surface.
    size = 2 * radius + 2
    mid = radius + 1
    for i, (x, y) in enumerate(trail_positions):
        alpha = 50 + 200 * i // count
        dot = pygame.Surface((size, size), pygame.SRCALPHA)
        pygame.draw.circle(dot, (*rgb, alpha), (mid, mid), radius)
        center_x = y * CELL_SIZE + CELL_SIZE//2
        center_y = x * CELL_SIZE + CELL_SIZE//2
        screen.blit(dot, (center_x - mid, center_y - mid))

def draw_heatmap():
    """Overlay each cell with yellow whose opacity reflects its best Q-value.

    Opacity scales with ``max(Q[cell]) / max(Q)`` up to 200 and is clamped to
    the 0-255 range that ``Surface.set_alpha`` is documented to accept.
    Q-values can be negative (every non-goal step is penalised), which would
    otherwise produce negative and potentially very large alpha values.
    """
    max_q = np.max(Q)
    if max_q <= 0:
        # No positive value anywhere yet: every cell would be transparent.
        return

    # One overlay reused for every cell instead of a new Surface per cell.
    overlay = pygame.Surface((CELL_SIZE, CELL_SIZE))
    overlay.fill((255, 255, 0))
    for x in range(GRID_SIZE):
        for y in range(GRID_SIZE):
            cell_q = np.max(Q[x, y])
            intensity = int(200 * cell_q / (max_q + 1e-5))
            intensity = max(0, min(255, intensity))
            if intensity == 0:
                continue  # fully transparent, nothing to blit
            overlay.set_alpha(intensity)
            screen.blit(overlay, (y*CELL_SIZE, x*CELL_SIZE))

# Main training loop
# Runs EPISODES episodes of Q-learning from START to GOAL, rendering one frame
# per learning step. Closing the window stops training; SPACE toggles pause.
clock = pygame.time.Clock()
running = True
paused = False

for episode in range(EPISODES):
    if not running:
        break

    state = START
    agent.current_pos = START
    agent.target_pos = START
    trail_positions = []
    # Real step count for the status line. The trail is capped at 10 entries,
    # so its length cannot be used as a step counter.
    steps = 0

    while state != GOAL and running:
        for event in pygame.event.get():
            if event.type == QUIT:
                running = False
            elif event.type == KEYDOWN and event.key == K_SPACE:
                paused = not paused

        if not running:
            # Stop right away instead of learning and rendering once more.
            break

        if paused:
            # Keep polling events at the normal frame rate rather than
            # spinning a CPU core while paused.
            clock.tick(30)
            continue

        # Q-learning step
        action = choose_action(state)
        next_state = get_next_state(state, action)
        reward = get_reward(next_state)
        update_q(state, action, reward, next_state)
        steps += 1

        # Update animation and trail
        agent.update_position(next_state)
        trail_positions.append(state)
        if len(trail_positions) > 10:
            trail_positions.pop(0)

        # Rendering
        screen.fill(COLORS['bg'])
        draw_heatmap()
        draw_grid()

        # Draw the start and goal points
        pygame.draw.rect(screen, COLORS['start'],
                         (START[1]*CELL_SIZE, START[0]*CELL_SIZE, CELL_SIZE, CELL_SIZE))
        pygame.draw.rect(screen, COLORS['goal'],
                         (GOAL[1]*CELL_SIZE, GOAL[0]*CELL_SIZE, CELL_SIZE, CELL_SIZE))

        draw_trail()
        draw_q_values()
        agent.update_animation()
        draw_agent()

        # Draw status information
        info = f"Episode: {episode+1}/{EPISODES}  Steps: {steps}"
        text = large_font.render(info, True, (255, 255, 255))
        screen.blit(text, (20, HEIGHT-60))

        pygame.display.flip()
        clock.tick(30)
        state = next_state

# The testing phase maintains similar visualization
# ... (The testing code is similar to the original version but uses the new drawing functions)

pygame.quit()