In [1]:
import sys
sys.path.append('..')

import matplotlib.pyplot as plt
import numpy as np

from myelin.agents import ExpectedSARSA
from myelin.core import RLInteraction, MDPEnvironment
from myelin.mdps import GridWorld
from myelin.policies import RandomPolicy
from myelin.utils import Callback

# Grid dimensions, measured in cells.
ROWS, COLS = 20, 20
# Every grid cell is represented by a CELL_SIZE x CELL_SIZE block of entries
# in the Q-value table that is plotted during training.
CELL_SIZE = 3

mdp = GridWorld(ROWS, COLS)
env = MDPEnvironment(mdp)
# Policy built over the environment's action space; handed to the agent.
policy = RandomPolicy(env.action_space)

class GridWorldQFunction:
    """Tabular Q-function for a grid world, laid out so it can be drawn as an image.

    Each grid cell owns a ``cell_size`` x ``cell_size`` block of ``_table``.
    The value for ``(state, action)`` is stored at that block's centre shifted
    by ``action``. With ``cell_size=3``, every action whose components lie in
    ``{-1, 0, 1}`` gets its own slot inside the cell's 3x3 block.

    NOTE(review): assumes ``state`` is a ``(row, col)`` pair and ``action`` a
    ``(d_row, d_col)`` offset with components in
    ``[-(cell_size // 2), cell_size // 2]``. Confirm this against
    ``myelin.mdps.GridWorld``. Larger offsets would land in a neighbouring
    cell's block, or wrap around or raise ``IndexError`` at the table edges.
    """

    def __init__(self, rows=None, cols=None, cell_size=None):
        """Create an all-zero Q-table.

        Args:
            rows: Number of grid rows; defaults to the module-level ``ROWS``.
            cols: Number of grid columns; defaults to ``COLS``.
            cell_size: Side length of each cell's block of entries; defaults
                to ``CELL_SIZE``.
        """
        # Globals are resolved here rather than as default values, so the
        # class can be used with explicit sizes without those globals defined.
        rows = ROWS if rows is None else rows
        cols = COLS if cols is None else cols
        self._cell_size = CELL_SIZE if cell_size is None else cell_size
        self._table = np.zeros((rows * self._cell_size, cols * self._cell_size))

    def _index(self, key):
        """Map a ``(state, action)`` key to a position tuple in ``_table``."""
        state, action = key
        # Top-left corner of the cell's block plus half a block = its centre.
        centre = np.asarray(state) * self._cell_size + self._cell_size // 2
        return tuple(centre + np.asarray(action))

    def __setitem__(self, key, value):
        """Store ``value`` for ``key == (state, action)``."""
        self._table[self._index(key)] = value

    def __getitem__(self, key):
        """Return the stored value for ``key == (state, action)``."""
        return self._table[self._index(key)]


# Shared Q-table: handed to the agent (presumably read and updated during
# training) and plotted via the module-level name by the Monitor callback.
qf = GridWorldQFunction()
# Expected SARSA agent over the environment's action space, given `policy` and
# the Q-table `qf`. NOTE(review): argument roles follow the positional order of
# myelin's ExpectedSARSA constructor; confirm there.
agent = ExpectedSARSA(env.action_space, policy, qf)


def show_value_function(qf, figsize=(6, 6)):
    """Render the Q-value table held by ``qf`` as a heat map with a colour bar.

    Args:
        qf: Object exposing the 2-D array to plot as ``qf._table``
            (e.g. a ``GridWorldQFunction``).
        figsize: Figure size in inches, applied only for this plot.
    """
    # plt.matshow creates its own figure and takes its size from
    # rcParams["figure.figsize"]. Set that temporarily so the caller's global
    # matplotlib configuration is left untouched.
    with plt.rc_context({"figure.figsize": figsize}):
        plt.matshow(qf._table)
        plt.colorbar()
        plt.show()

    
class Monitor(Callback):
    """Training callback that reports progress and plots the module-level ``qf``.

    A report prints ``Episode <n>`` and then draws the Q-table. One is made at
    the start of every episode whose number is a multiple of ten, and once
    more when training ends.
    """

    def _show_progress(self, episode):
        # Common body of both hooks: announce the episode, then draw the table.
        print('Episode {}'.format(episode))
        show_value_function(qf)

    def on_episode_begin(self, episode):
        if episode % 10:
            return
        self._show_progress(episode)

    def on_train_end(self, episode):
        self._show_progress(episode)


# Run training with a Monitor callback that prints and plots progress.
# NOTE(review): 200 appears to be the number of episodes; the captured output
# ends with "Episode 200" from Monitor.on_train_end. Confirm against
# RLInteraction.train.
RLInteraction(env, agent).train(200, callbacks=[Monitor()])

Episode 0


<Figure size 600x600 with 2 Axes>

Episode 10


<Figure size 600x600 with 2 Axes>

Episode 20


<Figure size 600x600 with 2 Axes>

Episode 30


<Figure size 600x600 with 2 Axes>

Episode 40


<Figure size 600x600 with 2 Axes>

Episode 50


<Figure size 600x600 with 2 Axes>

Episode 60


<Figure size 600x600 with 2 Axes>

Episode 70


<Figure size 600x600 with 2 Axes>

Episode 80


<Figure size 600x600 with 2 Axes>

Episode 90


<Figure size 600x600 with 2 Axes>

Episode 100


<Figure size 600x600 with 2 Axes>

Episode 110


<Figure size 600x600 with 2 Axes>

Episode 120


<Figure size 600x600 with 2 Axes>

Episode 130


<Figure size 600x600 with 2 Axes>

Episode 140


<Figure size 600x600 with 2 Axes>

Episode 150


<Figure size 600x600 with 2 Axes>

Episode 160


<Figure size 600x600 with 2 Axes>

Episode 170


<Figure size 600x600 with 2 Axes>

Episode 180


<Figure size 600x600 with 2 Axes>

Episode 190


<Figure size 600x600 with 2 Axes>

Episode 200


<Figure size 600x600 with 2 Axes>