In [None]:
!pip install gymnasium
from myst_nb import glue

# Let's practice

In [None]:
import gymnasium
from gymnasium import spaces
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

class GridWorld(gymnasium.Env):
  """
  4x3 grid world with one wall, one pit and one goal.

  Cells are numbered row by row, ``state = row * 4 + col``, giving states 0..11.
  Cell 5 is a wall, cell 7 is the pit (reward -10, ends the episode) and
  cell 11 is the goal (reward +10, ends the episode).

  ``self.P[state][action]`` is a list of ``[probability, next_state, reward]``
  triples describing a stochastic model of the grid. ``step`` does not sample
  from ``self.P``: it moves deterministically in the requested direction.

  NOTE: ``step`` returns the 4-tuple ``(state, reward, done, info)`` and
  ``reset`` returns the bare state, not the 5-tuple / ``(obs, info)`` shapes
  of the current Gymnasium API.
  """

  # Grid geometry and special cells, shared by the transition and step logic.
  _N_ROWS = 3
  _N_COLS = 4
  _WALL_STATE = 5
  _PIT_STATE = 7
  _GOAL_STATE = 11

  def __init__(self):
    # Define the action and observation spaces
    self.action_space = spaces.Discrete(4) # Up, Down, Left, Right
    self.observation_space = spaces.Discrete(12) # 12 cells

    # Transition model: P[state][action] -> [[probability, next_state, reward], ...]
    self.P = { 0: { 0: [[0.9, 0, 0], [0.1, 1, 0]], 1: [[0.8, 4, 0], [0.1, 1, 0], [0.1, 0, 0]], 2: [[0.9, 0, 0], [0.1, 4, 0]], 3: [[0.8, 1, 0], [0.1, 4, 0], [0.1, 0, 0]]},
               1: { 0: [[0.8, 1, 0], [0.1, 0, 0], [0.1, 2, 0]],  1: [[0.8, 1, 0], [0.1, 0, 0], [0.1, 2, 0]], 2: [[0.8, 0, 0], [0.2, 1, 0]], 3: [[0.8, 2, 0], [0.2, 1, 0]]},
               2: { 0: [[0.8, 2, 0], [0.1, 1, 0], [0.1, 3, 0]], 1: [[0.8, 6, 0], [0.1, 1, 0], [0.1, 3, 0]], 2: [[0.8, 1, 0], [0.1, 2, 0], [0.1, 6, 0]], 3: [[0.8, 3, 0], [0.1, 2, 0], [0.1, 6, 0]]},
               3: { 0: [[0.9, 3, 0], [0.1, 2, 0]], 1: [[0.8, 7, -10], [0.1, 2, 0], [0.1, 3, 0]], 2: [[0.8, 2, 0], [0.1, 3, 0], [0.1, 7, -10]], 3: [[0.9, 3, 0], [0.1, 7, -10]]},
               4: { 0: [[0.8, 0, 0], [0.2, 4, 0]], 1: [[0.8, 8, 0], [0.2, 4, 0]], 2: [[0.8, 4, 0], [0.1, 0, 0], [0.1, 8, 0]], 3: [[0.8, 4, 0], [0.1, 0, 0], [0.1, 8, 0]]},
               5: { 0: [[1, 5, 0]], 1: [[1, 5, 0]], 2: [[1, 5, 0]], 3: [[1, 5, 0]]},
               6: { 0: [[0.8, 2, 0], [0.1, 6, 0], [0.1, 7, -10]], 1: [[0.8, 10, 0], [0.1, 6, 0], [0.1, 7, -10]], 2: [[0.8, 6, 0], [0.1, 10, 0], [0.1, 2, 0]], 3: [[0.8, 7, -10], [0.1, 2, 0], [0.1, 10, 0]]},
               7: { 0: [[1, 7, -10]], 1: [[1, 7, -10]], 2: [[1, 7, -10]], 3: [[1, 7, -10]]},
               8: { 0: [[0.8, 4, 0], [0.1, 9, 0], [0.1, 8, 0]], 1: [[0.9, 8, 0], [0.1, 9, 0]], 2: [[0.9, 8, 0], [0.1, 4, 0]], 3: [[0.8, 9, 0], [0.1, 4, 0], [0.1, 8, 0]]},
               9: { 0: [[0.8, 9, 0], [0.1, 8, 0], [0.1, 10, 0]], 1: [[0.8, 9, 0], [0.1, 8, 0], [0.1, 10, 0]], 2: [[0.8, 8, 0], [0.2, 9, 0]], 3: [[0.8, 10, 0], [0.2, 9, 0]]},
               10: { 0: [[0.8, 6, 0], [0.1, 9, 0], [0.1, 11, 10]], 1: [[0.8, 10, 0], [0.1, 9, 0], [0.1, 11, 10]], 2: [[0.8, 9, 0], [0.1, 6, 0], [0.1, 10, 0]], 3: [[0.8, 11, 10], [0.1, 6, 0], [0.1, 10, 0]]},
               11: { 0: [[1, 11, 10]], 1: [[1, 11, 10]], 2: [[1, 11, 10]], 3: [[1, 11, 10]]},
              }

    # Initialize the state
    self.state = 0

  def step(self, action: int):
    """
    Apply ``action`` and report the outcome.
    :param action: 0 = Up (row - 1), 1 = Down (row + 1), 2 = Left, 3 = Right
    :return: ``(state, reward, done, info)``; reward is +10 on the goal cell,
             -10 on the pit cell and 0 elsewhere; ``done`` is True on either.
    """
    self._transition(action)

    done = False
    reward = 0
    if self.state == self._GOAL_STATE:
      reward = 10
      done = True
    elif self.state == self._PIT_STATE:
      reward = -10
      done = True

    # Return the observation, reward, done flag, and info
    return self.state, reward, done, {}

  def _transition(self, action: int):
    """
    Transition function: move one cell in the requested direction.

    The move is clamped at the grid border, and a move into the wall cell
    leaves the agent where it is (matching how ``self.P`` treats the wall).
    Any action outside 0..3 leaves the state unchanged.
    :param action: Action to take
    """
    # Decode with the number of COLUMNS; integer arithmetic keeps the state an int.
    row, col = divmod(int(self.state), self._N_COLS)
    if action == 0:
      row = max(0, row - 1)
    elif action == 1:
      row = min(self._N_ROWS - 1, row + 1)
    elif action == 2:
      col = max(0, col - 1)
    elif action == 3:
      col = min(self._N_COLS - 1, col + 1)

    target = row * self._N_COLS + col
    if target != self._WALL_STATE:
      self.state = target

  def reset(self):
    """
    Reset the environment.
    :return: the initial state (0)
    """
    self.state = 0
    return self.state

  def render(self, render="human"):
    """
    Draw the grid in a new matplotlib figure and show it.

    The goal is drawn green, the pit red, the wall grey and every other cell
    white. Row 0 is drawn at the bottom of the figure.
    :param render: unused by this implementation
    """
    fig, ax = plt.subplots()
    ax.set_xlim(0, 4)
    ax.set_ylim(0, 3)
    ax.set_aspect('equal')

    special_colors = {
      self._GOAL_STATE: 'green',
      self._PIT_STATE: 'red',
      self._WALL_STATE: 'grey',
    }
    for i in range(4):
      for j in range(3):
        facecolor = special_colors.get(j * 4 + i, 'white')
        ax.add_patch(Rectangle((i, j), 1, 1, edgecolor='black', facecolor=facecolor))

    # Hide every tick mark and tick label so that only the cells remain visible.
    ax.tick_params(axis='both',       # changes apply to both axes
                    which='both',      # both major and minor ticks are affected
                    bottom=False,      # ticks along the bottom edge are off
                    top=False,         # ticks along the top edge are off
                    left=False,        # ticks along the left edge are off
                    right=False,       # ticks along the right edge are off
                    labelbottom=False, # labels along the bottom edge are off
                    labelleft=False)   # labels along the left edge are off

    plt.show()


In [None]:
def value_iteration(env, gamma=0.9, theta=0.0001):
  """
  Run in-place value iteration until the values stop changing.

  :param env: object exposing ``observation_space.n`` (number of states) and
              ``P[state][action]`` as a list of
              ``(probability, next_state, reward)`` triples
  :param gamma: discount factor
  :param theta: stop once the largest value change in a sweep is below this
  :return: ``(V, pi)`` where ``V[s]`` is the state value and ``pi[s]`` is the
           position of the greedy action within ``env.P[s]`` (equal to the
           action id when the actions are keyed 0, 1, 2, ... in order)
  """
  n_states = env.observation_space.n

  def action_values(s):
    # Expected one-step return of every action available in state s.
    return [sum(p * (r + gamma * V[s_]) for p, s_, r in env.P[s][a]) for a in env.P[s]]

  V = np.zeros(n_states)
  while True:
    delta = 0
    for s in range(n_states):
      v = V[s]
      V[s] = max(action_values(s))  # in-place update: later states see the new value
      delta = max(delta, np.abs(v - V[s]))
    if delta < theta:
      break

  # Greedy policy with respect to the converged values.
  pi = np.zeros(n_states)
  for s in range(n_states):
    pi[s] = np.argmax(action_values(s))

  # The original computed V and pi but fell off the end, returning None.
  return V, pi

In [None]:
def value_iteration_interact(env, gamma=0.9, theta=0.0001, step=0):
  """
  Run at most ``step`` sweeps of in-place value iteration.

  :param env: object exposing ``observation_space.n`` and ``P[state][action]``
              as a list of ``(probability, next_state, reward)`` triples
  :param gamma: discount factor
  :param theta: stop early once the largest change in a sweep is below this
  :param step: maximum number of sweeps; 0 leaves every value at zero
  :return: ``(V, pi)``: the state values and, per state, the position of the
           greedy action within ``env.P[state]``
  """
  n_states = env.observation_space.n
  V = np.zeros(n_states)

  def backups(state):
    # One-step look-ahead value of each action available in `state`.
    return [
      sum(prob * (reward + gamma * V[nxt]) for prob, nxt, reward in env.P[state][action])
      for action in env.P[state]
    ]

  for _ in range(step):
    largest_change = 0
    for state in range(n_states):
      previous = V[state]
      V[state] = max(backups(state))
      largest_change = max(largest_change, np.abs(previous - V[state]))
    if largest_change < theta:
      break

  # Greedy action index per state, stored in a float array.
  pi = np.zeros(n_states)
  for state in range(n_states):
    pi[state] = np.argmax(backups(state))
  return V, pi


# Environment instance read (as a global) by `update`; without it the first
# slider callback fails with a NameError on a fresh kernel.
env = GridWorld()

# Figure and axes shared with `update`, which redraws into them.
fig, ax = plt.subplots()
ax.set_xlim(0, 4)
ax.set_ylim(0, 3)
ax.set_aspect('equal')

# Cells with a special color, keyed by state index (row * 4 + col); every
# other cell is drawn white.
cell_colors = {11: 'green', 7: 'red', 5: 'grey'}

for i in range(4):
  for j in range(3):
    rect = Rectangle((i, j), 1, 1, edgecolor='black',
                     facecolor=cell_colors.get(j * 4 + i, 'white'))
    ax.add_patch(rect)

    # Initial picture: every cell is labelled with the value 0.
    ax.text(i + 0.5, j + 0.5, 0, ha='center', va='center')

# Hide every tick mark and tick label so that only the cells remain visible.
ax.tick_params(axis='both',        # changes apply to both axes
                which='both',      # both major and minor ticks are affected
                bottom=False,      # ticks along the bottom edge are off
                top=False,         # ticks along the top edge are off
                left=False,        # ticks along the left edge are off
                right=False,       # ticks along the right edge are off
                labelbottom=False, # labels along the bottom edge are off
                labelleft=False)   # labels along the left edge are off

def update(gamma=0.9, step=0):
  """
  Redraw the grid, labelling each cell with its value after ``step`` sweeps.

  Reads the module-level ``env``, ``fig`` and ``ax``.
  :param gamma: discount factor passed to ``value_iteration_interact``
  :param step: number of value-iteration sweeps to run
  """
  V, _ = value_iteration_interact(env, gamma, step=step)

  ax.clear()
  # Axes.clear() also resets the axis limits and tick settings, so the grid
  # framing has to be applied again on every redraw.
  ax.set_xlim(0, 4)
  ax.set_ylim(0, 3)
  ax.set_aspect('equal')
  ax.tick_params(axis='both', which='both',
                 bottom=False, top=False, left=False, right=False,
                 labelbottom=False, labelleft=False)

  # Goal, pit and wall cells get their own color; all others are white.
  special_colors = {11: 'green', 7: 'red', 5: 'grey'}
  for i in range(4):
    for j in range(3):
      state = j * 4 + i
      rect = Rectangle((i, j), 1, 1, edgecolor='black',
                       facecolor=special_colors.get(state, 'white'))
      ax.add_patch(rect)
      # Values are shown truncated toward zero to keep the labels short.
      ax.text(i + 0.5, j + 0.5, int(V[state]), ha='center', va='center')

  # NOTE(review): draw_idle only schedules a redraw of the existing canvas;
  # presumably this needs an interactive backend (e.g. ipympl) to be visible
  # in the notebook - confirm.
  fig.canvas.draw_idle()
  

In [None]:
from ipywidgets import *

# One slider per argument of `update`: gamma from 0.5 to 0.9 in steps of 0.1,
# and the number of sweeps from 0 to 50 in steps of 1.
interact(update, gamma=(0.5, 0.9, 0.1), step=(0, 50, 1))