In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.patches import Polygon
from matplotlib.colors import LinearSegmentedColormap

def _draw_grid_lines(ax, rows, cols):
    """Draw the solid outer border and the dashed inner cell boundaries."""
    ax.plot([0, cols], [0, 0], color='black')
    ax.plot([0, 0], [0, rows], color='black')
    ax.plot([cols, cols], [0, rows], color='black')
    ax.plot([cols, 0], [rows, rows], color='black')

    for i in range(1, rows):
        ax.plot([0, cols], [i, i], color='black', linestyle='--')

    for j in range(1, cols):
        ax.plot([j, j], [0, rows], color='black', linestyle='--')


def _draw_q_triangles(ax, rows, cols, colored_tiles, q_values):
    """Split every non-goal/non-trap cell into 4 triangles coloured by Q-value."""
    q_values = np.asarray(q_values)
    # Min-max normalisation onto [0, 1] for the colormap. The previous
    # q / max(q) divided by zero when all Q-values were 0 (untrained table)
    # and inverted the colours whenever max(q) was negative.
    q_min = float(np.min(q_values))
    q_span = float(np.max(q_values)) - q_min
    cmap = plt.cm.RdYlGn

    for row in range(rows):
        for col in range(cols):
            if (row, col) == colored_tiles['red'] or (row, col) == colored_tiles['green']:
                continue
            center = (col + 0.5, row + 0.5)
            # The cell edge each action's triangle is drawn against.
            edges = (
                ((col + 1, row + 1), (col, row + 1)),  # action 0: towards row + 1
                ((col + 1, row), (col + 1, row + 1)),  # action 1: towards col + 1
                ((col, row), (col + 1, row)),          # action 2: towards row - 1
                ((col, row + 1), (col, row)),          # action 3: towards col - 1
            )
            for action, (corner_a, corner_b) in enumerate(edges):
                q_value = q_values[row, col, action]
                # A flat table (all values equal) is shown in the neutral middle colour.
                shade = 0.5 if q_span == 0 else (q_value - q_min) / q_span
                ax.add_patch(Polygon([center, corner_a, corner_b],
                                     facecolor=cmap(shade), edgecolor='none'))

                # Label the triangle at its centroid.
                x_middle = (corner_a[0] + corner_b[0] + center[0]) / 3
                y_middle = (corner_a[1] + corner_b[1] + center[1]) / 3
                ax.text(x_middle, y_middle, f'{q_value:.2f}', color='black',
                        ha='center', va='center', fontsize=8)


def _draw_tiles(ax, colored_tiles):
    """Draw the 'black' entry as a small circle and fill every other entry's cell."""
    for color, coord in colored_tiles.items():
        row, col = coord
        if color == 'black':
            ax.add_patch(plt.Circle((col + 0.5, row + 0.5), 0.1, color=color))
        else:
            ax.fill_betweenx([row, row + 1], col, col + 1, color=color)


def _figure_to_rgb(fig):
    """Rasterise `fig` with Agg and return an independent (H, W, 3) uint8 array."""
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    # FigureCanvasAgg.tostring_rgb() was deprecated in Matplotlib 3.8 and
    # removed in 3.10; buffer_rgba() is the supported accessor. Drop alpha and
    # copy so the array stays valid after the figure is closed.
    return np.asarray(canvas.buffer_rgba())[:, :, :3].copy()


def plot_colored_grid(rows, cols, colored_tiles, output_size, q_values=None):
    """Render a grid world, optionally overlaid with per-action Q-values, to an image.

    Cell (row, col) occupies the unit square [col, col+1] x [row, row+1] in data
    coordinates; the y axis is not inverted, so row 0 is at the bottom.

    Args:
        rows, cols: Grid dimensions.
        colored_tiles: Mapping of matplotlib colour -> (row, col). The 'black'
            entry is drawn as a small circle; every other entry fills its cell.
            When `q_values` is given, the 'red' and 'green' keys must exist.
            Those cells get no Q-value triangles.
        output_size: (width, height) of the image in pixels.
        q_values: Optional array indexed as q_values[row, col, action] with at
            least 4 actions. Each action is drawn as a triangle against the
            edge it moves towards (0: row+1, 1: col+1, 2: row-1, 3: col-1),
            coloured with RdYlGn from the lowest (red) to highest (green) value.

    Returns:
        np.ndarray of dtype uint8 with shape (height, width, 3), matching
        `output_size` up to Agg's rounding of the figure size.
    """
    # One inch per column, so the width in pixels is exactly output_size[0].
    dpi = output_size[0] / cols
    fig, ax = plt.subplots(figsize=(output_size[0] / dpi, output_size[1] / dpi), dpi=dpi)
    try:
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

        _draw_grid_lines(ax, rows, cols)
        if q_values is not None:
            _draw_q_triangles(ax, rows, cols, colored_tiles, q_values)
        _draw_tiles(ax, colored_tiles)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)

        return _figure_to_rgb(fig)
    finally:
        # Always release the figure, even if drawing fails, to avoid leaking
        # figures when this is called once per rendered frame.
        plt.close(fig)

# colored_tiles = {'red': (1, 3), 'black': (0, 0), 'green': (2, 3)}
# q_values = np.random.random((3, 4, 4))
# image_array = plot_colored_grid(3, 4, colored_tiles, (400, 300), q_values)

In [None]:
import numpy as np
import random

class CustomMiniGridEnv:
    """Grid world with a fixed goal cell and a fixed trap cell.

    The observation is the agent's (row, col) position as an int array.
    Actions:
        0: row + 1
        1: col + 1
        2: row - 1
        3: col - 1
    A move that would leave the grid leaves the agent in place, and so does
    any other action value.

    Rewards per step:
        +reward  on reaching the goal (episode ends)
        -reward  on reaching the trap, or on using up max_step steps (ends)
        -1       otherwise

    The goal and trap are placed once in __init__. reset() only re-places the
    agent.
    """

    def __init__(self, row_num, col_num, reward=100, max_step=40):
        # Agent, goal and trap need three distinct cells; with fewer the
        # rejection-sampling placement below could never terminate.
        if row_num < 1 or col_num < 1 or row_num * col_num < 3:
            raise ValueError(
                f'grid {row_num}x{col_num} is too small: need at least 3 cells')

        self.col_num = col_num
        self.row_num = row_num
        self.reward = reward
        self.max_step = max_step
        self.steps = 0

        self.observation_shape = (row_num, col_num)
        self.action_shape = 4
        self._setup()

    def _random_free_cell(self, *occupied):
        """Return a uniformly random cell (int array) not equal to any of `occupied`."""
        while True:
            # np.random.randint excludes its upper bound, so pass row_num /
            # col_num (not row_num - 1) to make the last row and column reachable.
            cell = np.array([np.random.randint(0, self.row_num),
                             np.random.randint(0, self.col_num)])
            if not any(np.array_equal(cell, other) for other in occupied):
                return cell

    def _setup(self):
        """Place the agent, then the goal and trap on distinct cells."""
        self.agent_pos = self._random_free_cell()
        self.goal_pos = self._random_free_cell(self.agent_pos)
        self.trap_pos = self._random_free_cell(self.agent_pos, self.goal_pos)

    def reset(self):
        """Start a new episode: zero the step counter and re-place the agent.

        Returns a copy of the agent's new position (never the goal or trap).
        """
        self.steps = 0
        self.agent_pos = self._random_free_cell(self.trap_pos, self.goal_pos)
        return self.agent_pos.copy()

    def step(self, action):
        """Apply `action` and return (position copy, reward, done)."""
        self.steps += 1

        if action == 0 and self.agent_pos[0] < self.row_num - 1:
            self.agent_pos[0] += 1
        elif action == 2 and self.agent_pos[0] > 0:
            self.agent_pos[0] -= 1
        elif action == 1 and self.agent_pos[1] < self.col_num - 1:
            self.agent_pos[1] += 1
        elif action == 3 and self.agent_pos[1] > 0:
            self.agent_pos[1] -= 1

        # Goal/trap are checked before the step limit so that reaching a
        # terminal cell on the last allowed step is still scored as such.
        if np.array_equal(self.agent_pos, self.goal_pos):
            reward, done = self.reward, True
        elif np.array_equal(self.agent_pos, self.trap_pos):
            reward, done = -1 * self.reward, True
        elif self.steps >= self.max_step:
            # The episode allows at most max_step moves; running out is penalised.
            reward, done = -1 * self.reward, True
        else:
            reward, done = -1, False

        return self.agent_pos.copy(), reward, done

    def render(self, q_values=None):
        """Return an RGB image of the grid via plot_colored_grid.

        The trap is shown red, the goal green and the agent as a black marker.
        At 100 px per cell, the image is 100*col_num wide and 100*row_num tall.
        """
        colored_tiles = {'red': (self.trap_pos[0], self.trap_pos[1]),
                         'green': (self.goal_pos[0], self.goal_pos[1]),
                         'black': (self.agent_pos[0], self.agent_pos[1])}
        image = plot_colored_grid(self.row_num, self.col_num, colored_tiles,
                                  (100 * self.col_num, 100 * self.row_num), q_values)
        return image


In [None]:
class TabularQLearner:
    """Tabular Q-learning agent with count-based (least-tried) exploration.

    A state is a sequence or array of integer indices, one per dimension of
    `state_shape`. Both tables have shape (*state_shape, num_actions).
    """

    def __init__(self, state_shape, num_actions, learning_rate=0.001, discount_factor=0.9):
        self.state_shape = state_shape
        self.num_actions = num_actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        # Q-value estimates, indexed q_table[*state, action].
        self.q_table = np.zeros((*state_shape, num_actions))
        # How often explore() has picked each (state, action) pair.
        self.q_table_explore = np.zeros((*state_shape, num_actions))

    def _index(self, state):
        """Turn `state` into an index tuple selecting that state's action row.

        Only the first len(state_shape) components are used. This matches the
        old 2-D behaviour (state[0], state[1]) and also supports other ranks.
        """
        return tuple(np.atleast_1d(state)[:len(self.state_shape)])

    def select_action(self, state):
        """Greedy action: argmax of Q over actions (ties -> lowest index)."""
        return np.argmax(self.q_table[self._index(state)])

    def explore(self, state):
        """Pick the least-tried action in `state` (ties -> lowest index) and count it."""
        idx = self._index(state)
        action = np.argmin(self.q_table_explore[idx])
        self.q_table_explore[idx + (action,)] += 1
        return action

    def update(self, state, action, reward, next_state, done=False):
        """One Q-learning step: Q += lr * (reward + gamma * max_a' Q(next) - Q).

        Args:
            done: Optional, default False, which keeps the previous behaviour.
                If True, `next_state` is treated as terminal, so the
                bootstrapped future value is dropped from the target.
        """
        idx = self._index(state) + (action,)
        target = reward
        if not done:
            target = target + self.discount_factor * np.max(self.q_table[self._index(next_state)])
        self.q_table[idx] = self.q_table[idx] + self.learning_rate * (target - self.q_table[idx])

In [None]:
import numpy as np
import imageio
import matplotlib.pyplot as plt

# Seed the global NumPy RNG (used by the environment for cell placement) so the
# layout and training run are reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)

env = CustomMiniGridEnv(5, 5, reward=20, max_step=40)

train_step = 10000  # number of training episodes
LOG_EVERY = 1000    # print the episode return every LOG_EVERY episodes

agent = TabularQLearner(env.observation_shape, env.action_shape)

for episode in range(train_step):
    total_reward = 0
    state = env.reset()
    done = False
    while not done:
        # Count-based exploration: take the least-tried action in this state.
        action = agent.explore(state)

        next_state, reward, done = env.step(action)

        total_reward += reward
        agent.update(state, action, reward, next_state)
        state = next_state

    # total_reward was previously computed but never reported; log it at the
    # same cadence/format as the saved cell output.
    if episode % LOG_EVERY == 0:
        print(f'step: {episode}')
        print(f'total reward: {total_reward}')

step: 0
total reward: -44
step: 1000
total reward: 12
step: 2000
total reward: 10
step: 3000
total reward: -21
step: 4000
total reward: -29
step: 5000
total reward: -60
step: 6000
total reward: -60
step: 7000
total reward: 19
step: 8000
total reward: -45
step: 9000
total reward: -32


In [None]:
import imageio

def run_greedy_episode(env, agent, render=False):
    """Play one episode with the agent's greedy policy.

    Args:
        env: Environment exposing reset(), step(action) -> (state, reward, done)
            and render(q_values).
        agent: Learner exposing select_action(state) and q_table.
        render: When True, render one frame for the start state and one after
            every step. The current Q-table is overlaid on each frame.

    Returns:
        (total_reward, frames); frames is [] when render is False.
    """
    state = env.reset()
    frames = [env.render(agent.q_table)] if render else []
    total_reward = 0
    done = False
    while not done:
        state, reward, done = env.step(agent.select_action(state))
        total_reward += reward
        if render:
            frames.append(env.render(agent.q_table))
    return total_reward, frames


for episode in range(10):
    episode_reward, _ = run_greedy_episode(env, agent)
    print(f'rewards: {episode_reward}')

total_reward, frames = run_greedy_episode(env, agent, render=True)
imageio.mimsave('./render.mp4', frames, fps=3)
# Report the number of frames written (previously len('frames') printed the
# length of the string literal, always 6).
print(len(frames))

In [None]:
from IPython.display import HTML
from base64 import b64encode
# Embed the video written by the rendering cell as a base64 data URL. Read the
# same relative path that cell writes ('./render.mp4') instead of a hard-coded
# '/content/...' path. Use a context manager so the file handle is closed.
with open('render.mp4', 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)