In [None]:
import gymnasium as gym

import numpy as np
import polars as pl

from collections import defaultdict

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px

In [None]:
from windy_gridworld_env import WindyGridworldEnv

In [None]:
# Single environment instance for the whole notebook; generate_episodes()
# reads this module-level name when it resets and steps the environment.
env = WindyGridworldEnv()

In [None]:
class GridWorldAgent:
    """Tabular temporal-difference control agent with four discrete actions.

    Action values are stored per state in ``value_function``, a mapping
    ``state -> {action: value}`` that lazily creates an all-zero entry the
    first time a state is looked up.

    Parameters
    ----------
    strategy : {"random", "epsilon-greedy"}
        How ``get_action`` chooses actions.
    learning_rate : str
        Despite its name, this selects the *exploration* schedule:
        ``"decaying-epsilon"`` uses ``epsilon = 1 / (t + 1)`` where ``t`` is
        the number of ``update`` calls so far; any other value uses the
        constant ``epsilon``.
    learning_method : {"sarsa", "q-learning"}
        Which TD target ``update`` bootstraps from.
    alpha : float
        Step size of the TD update.
    epsilon : float
        Exploration probability used when the schedule is constant.
    gamma : float
        Discount applied to the bootstrapped next-state value. The default
        of 1.0 gives an undiscounted update.

    Raises
    ------
    ValueError
        If ``strategy`` or ``learning_method`` is not one of the supported
        names.
    """

    N_ACTIONS = 4
    STRATEGIES = ("random", "epsilon-greedy")
    LEARNING_METHODS = ("sarsa", "q-learning")

    def __init__(
        self,
        strategy="random",
        learning_rate="decaying-epsilon",
        learning_method="sarsa",
        alpha=0.01,
        epsilon=0.1,
        gamma=1.0,
    ):
        # Fail fast: a misspelled name would otherwise yield ``None`` actions
        # or an agent that never learns.
        if strategy not in self.STRATEGIES:
            raise ValueError(
                f"Unknown strategy {strategy!r}; expected one of {self.STRATEGIES}"
            )
        if learning_method not in self.LEARNING_METHODS:
            raise ValueError(
                f"Unknown learning_method {learning_method!r}; "
                f"expected one of {self.LEARNING_METHODS}"
            )

        self.strategy = strategy
        self.learning_method = learning_method
        self.learning_rate = learning_rate

        self.alpha = alpha
        self.epsilon = epsilon
        self.gamma = gamma

        # Number of update() calls performed so far; drives the decaying epsilon.
        self.t = 0
        self.value_function = defaultdict(
            lambda: dict.fromkeys(range(self.N_ACTIONS), 0.0)
        )

    def _get_epsilon(self):
        """Return the exploration probability for the current step."""
        if self.learning_rate == "decaying-epsilon":
            return 1 / (self.t + 1)
        return self.epsilon

    def get_action(self, state):
        """Choose an action for ``state`` and return it as a plain ``int``.

        With ``"random"`` the action is uniform over all actions. With
        ``"epsilon-greedy"`` a uniform action is taken with probability
        epsilon, otherwise the action with the highest stored value (ties
        resolve to the lowest action index).

        Raises
        ------
        ValueError
            If ``self.strategy`` was changed to an unsupported name.
        """
        if self.strategy == "random":
            return int(np.random.choice(self.N_ACTIONS))

        if self.strategy == "epsilon-greedy":
            if np.random.rand() < self._get_epsilon():
                return int(np.random.choice(self.N_ACTIONS))
            action_values = self.value_function[state]
            return int(
                np.argmax([action_values[a] for a in range(self.N_ACTIONS)])
            )

        raise ValueError(
            f"Unknown strategy {self.strategy!r}; expected one of {self.STRATEGIES}"
        )

    def update(self, state, action, reward, next_state, next_action):
        """Apply one TD update to ``value_function[state][action]``.

        SARSA bootstraps from the value of ``(next_state, next_action)``;
        Q-learning bootstraps from the best value stored for ``next_state``
        and ignores ``next_action``. The step counter ``t`` is incremented
        after a successful update.

        Raises
        ------
        ValueError
            If ``self.learning_method`` was changed to an unsupported name.
        """
        # Read the bootstrap value before writing, so a self-transition
        # (state == next_state) still uses the pre-update estimate.
        if self.learning_method == "sarsa":
            bootstrap = self.value_function[next_state][next_action]
        elif self.learning_method == "q-learning":
            bootstrap = max(self.value_function[next_state].values())
        else:
            raise ValueError(
                f"Unknown learning_method {self.learning_method!r}; "
                f"expected one of {self.LEARNING_METHODS}"
            )

        action_values = self.value_function[state]
        td_error = reward + self.gamma * bootstrap - action_values[action]
        action_values[action] += self.alpha * td_error

        self.t += 1

In [None]:
def generate_episodes(agent, n_episodes=2, environment=None):
    """Run ``agent`` for ``n_episodes`` episodes, learning online, and log each step.

    Every step selects the next action with ``agent.get_action`` and then calls
    ``agent.update`` with the full (state, action, reward, next_state,
    next_action) tuple, so the agent learns while the episode is generated.

    Parameters
    ----------
    agent
        Object providing ``get_action(state)`` and
        ``update(state, action, reward, next_state, next_action)``.
    n_episodes : int
        Number of episodes to run.
    environment : optional
        Environment to interact with. When omitted, the notebook-level ``env``
        is used. It is expected to return the state from ``reset()`` and a
        4-tuple ``(next_state, reward, terminated, _)`` from ``step(action)``,
        with states indexable as ``state[0]`` / ``state[1]``.

    Returns
    -------
    polars.DataFrame
        One row per executed step with columns ``x``, ``y`` (the two state
        components the action was taken from), ``action``, ``reward`` (the
        reward received for that action) and ``episode``.
    """
    if environment is None:
        # Fall back to the module-level environment so existing calls keep working.
        environment = env

    transitions = []
    for episode in range(n_episodes):
        state = environment.reset()
        action = agent.get_action(state)

        terminated = False
        while not terminated:
            next_state, reward, terminated, _ = environment.step(action)
            next_action = agent.get_action(next_state)
            agent.update(state, action, reward, next_state, next_action)

            # Log the step *before* advancing, so each row pairs the state and
            # the action actually executed in it with the reward it produced.
            transitions.append((state[0], state[1], action, reward, episode))

            state, action = next_state, next_action

    return pl.DataFrame(
        transitions, schema=["x", "y", "action", "reward", "episode"], orient="row"
    )


In [None]:
def plot_value_function(agent, start=(3, 0), goal=(3, 7)):
    """Show the agent's state values and greedy policy as an annotated heatmap.

    Each state in ``agent.value_function`` is drawn as one cell: the colour is
    the largest action value stored for that state, and the text is an arrow
    for the action with the largest value. Cells with no stored entry are drawn
    with value 0 and no arrow. The ``start`` cell is prefixed with "S" and the
    ``goal`` cell shows only "G".

    Parameters
    ----------
    agent
        Object whose ``value_function`` maps 2-tuples of integer coordinates
        to ``{action: value}`` dicts for actions 0-3.
    start, goal : tuple of int
        Coordinates of the cells to mark with "S" and "G".

    Notes
    -----
    The first state coordinate is placed on the vertical axis and the second
    on the horizontal axis. The vertical axis is reversed so that coordinate 0
    is at the top while the tick labels still show the true coordinate.
    """
    arrows = {0: "↑", 1: "→", 2: "↓", 3: "←"}

    # Snapshot the items; only .items() is used, so the defaultdict does not
    # grow new entries while plotting.
    visited = list(agent.value_function.items())

    # The grid spans every stored state plus the start/goal markers, so the
    # markers always have a cell and an empty table still yields a valid grid.
    corners = [coords for coords, _ in visited] + [tuple(start), tuple(goal)]
    first_coords = [first for first, _ in corners]
    second_coords = [second for _, second in corners]
    row_labels = np.arange(min(first_coords), max(first_coords) + 1)
    col_labels = np.arange(min(second_coords), max(second_coords) + 1)
    row_origin, col_origin = row_labels[0], col_labels[0]

    values = np.zeros((len(row_labels), len(col_labels)))
    labels = np.full(values.shape, "", dtype="U3")

    # Address cells by coordinate rather than by position in a pivoted table,
    # so unvisited rows or columns cannot shift the annotations.
    for (first, second), action_values in visited:
        i, j = first - row_origin, second - col_origin
        values[i, j] = max(action_values.values())
        best_action = int(np.argmax([action_values[a] for a in range(4)]))
        labels[i, j] = arrows[best_action]

    start_i, start_j = start[0] - row_origin, start[1] - col_origin
    labels[start_i, start_j] = "S " + labels[start_i, start_j]
    labels[goal[0] - row_origin, goal[1] - col_origin] = "G"

    fig = go.Figure(
        data=go.Heatmap(
            z=values,
            x=col_labels,
            y=row_labels,
            colorscale="Viridis",
            text=labels,
            texttemplate="%{text}",
            textfont={"size": 20},
        )
    )
    # Put the first row at the top without falsifying the tick labels.
    fig.update_yaxes(autorange="reversed")
    fig.show()


In [None]:
# SARSA run: train an epsilon-greedy agent (default "decaying-epsilon"
# schedule) for 100,000 episodes, then plot its learned values and greedy
# policy. The DataFrame returned by generate_episodes() is not kept.
agent = GridWorldAgent(strategy="epsilon-greedy", learning_method="sarsa")
generate_episodes(agent, n_episodes=100_000)
plot_value_function(agent)


In [None]:
# Q-learning run: same setup with a fresh agent that uses the "q-learning"
# update, trained for 100,000 episodes and plotted the same way.
# Rebinding `agent` discards the agent created for the SARSA run.
agent = GridWorldAgent(strategy="epsilon-greedy", learning_method="q-learning")
generate_episodes(agent, n_episodes=100_000)
plot_value_function(agent)