In [1]:
%matplotlib notebook

In [13]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import ListedColormap
import os

# Gridworld setup
class GridWorld:
    def __init__(self, grid_size, start, goal, obstacles, max_steps):
        self.grid_size = grid_size
        self.start = start
        self.goal = goal
        self.obstacles = obstacles
        self.state = start
        self.max_steps = max_steps
        self.steps = 0

    def reset(self):
        self.state = self.start
        self.steps = 0
        return self.state

    def step(self, action):
        self.steps += 1
        x, y = self.state
        if action == 0:   # Up
            y = max(y - 1, 0)
        elif action == 1: # Right
            x = min(x + 1, self.grid_size[0] - 1)
        elif action == 2: # Down
            y = min(y + 1, self.grid_size[1] - 1)
        elif action == 3: # Left
            x = max(x - 1, 0)

        next_state = (x, y)

        # Reward and termination
        if next_state == self.goal:
            return next_state, 1, True
        elif next_state in self.obstacles or self.steps >= self.max_steps:
            return next_state, -1, True
        else:
            return next_state, 0, False

# SARSA Algorithm
class SARSA:
    def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.env = env
        self.q_table = np.zeros((*env.grid_size, 4))
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

    def choose_action(self, state):
        if np.random.rand() < self.epsilon:
            return np.random.choice(4)
        else:
            return np.argmax(self.q_table[state])

    def train(self, episodes):
        rewards = []
        frames = []

        for episode in range(episodes):
            state = self.env.reset()
            action = self.choose_action(state)
            total_reward = 0
            done = False

            episode_frames = []
            while not done:
                next_state, reward, done = self.env.step(action)
                next_action = self.choose_action(next_state)

                # SARSA update
                self.q_table[state][action] += self.alpha * (
                        reward + self.gamma * self.q_table[next_state][next_action] - self.q_table[state][action]
                )

                # Log visualization frame
                episode_frames.append((state, reward))

                state, action = next_state, next_action
                total_reward += reward

            rewards.append(total_reward)
            frames.append(episode_frames)

        return rewards, frames

# Visualization function
def visualize(frames, grid_size, start, goal, obstacles, output_file="gridworld.gif"):
    fig, ax = plt.subplots()
    cmap = ListedColormap(["white", "black", "green", "red"])
    grid = np.zeros(grid_size)
    grid[start] = 2
    grid[goal] = 3
    for obs in obstacles:
        grid[obs] = 1

    ims = []
    for episode_frames in frames:
        for state, reward in episode_frames:
            grid_img = grid.copy()
            grid_img[state] = 4
            im = plt.imshow(grid_img, cmap=cmap, animated=True)
            ax.set_title(f"Reward: {reward}")
            ims.append([im])

    ani = animation.ArtistAnimation(fig, ims, interval=300, blit=True, repeat_delay=1000)
    ani.save(output_file, writer="imagemagick")
    plt.close()

# Create and train the environment
if __name__ == "__main__":
    grid_size = (5, 5)
    start = (0, 0)
    goal = (4, 4)
    obstacles = [(1, 1), (2, 2), (3, 3)]
    max_steps = 50

    env = GridWorld(grid_size, start, goal, obstacles, max_steps)
    sarsa = SARSA(env)
    rewards, frames = sarsa.train(episodes=10)

    # Create visualization
    visualize(frames, grid_size, start, goal, obstacles)

MovieWriter imagemagick unavailable; using Pillow instead.
