In [None]:
# Reload imported modules (e.g. kaermorhenv) when their source changes, so edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from kaermorhenv import KaerMorhenv, map_from_csv, HyperParams, SARSA

In [None]:
from typing import Callable, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
from tqdm import tqdm

# Module name must match the earlier import cell: "kaermorhenv" (was misspelled "kaermohenv").
from kaermorhenv import HyperParams, KaerMorhenv, SARSA, map_from_csv

# Display animations as HTML5 <video>; matplotlib encodes these with the configured
# movie writer (ffmpeg by default), so that tool must be available for inline display.
rc('animation', html='html5')

In [None]:
# Build the environment from the CSV map and place the monster.
env = KaerMorhenv(
    board=map_from_csv("map1.csv"),
    # NOTE(review): presumably grid coordinates on the board — confirm axis order.
    monster_coords=np.array([25, 25])
)
ax = env.render()
# plt.show() takes no figure/axes argument (backends accept only keyword `block`);
# it displays every open figure, including the one env.render() drew on.
plt.show()

In [None]:
# Hyper-parameter schedule: called as f(total_epochs, current_epoch) and returns the
# HyperParams to use for that episode.
HParamGenerator = Callable[[int, int], HyperParams]
# Learning rule: given the Q-table, one SARSA transition and the current HyperParams,
# returns the updated Q-table (q_learn below updates it in place and returns it).
LearnFun = Callable[[np.ndarray, SARSA, HyperParams], np.ndarray]

In [None]:
def q_learn(
    q_table: np.ndarray,
    sarsa: SARSA,
    hparams: HyperParams
) -> np.ndarray:
    """Apply one tabular Q-learning update to ``q_table`` in place and return it.

    Performs ``Q[s, a] += lr * (r + dr * max(Q[s']) - Q[s, a])``. The next
    action carried in the transition is ignored: Q-learning bootstraps from
    the greedy value of the next state (off-policy target).

    Args:
        q_table: Q-value table indexed as ``q_table[state, action]``; mutated.
        sarsa: Transition unpacked as (state, action, reward, next_state, next_action).
        hparams: Unpacked as (exploration_rate, learning_rate, discount_rate);
            the exploration rate is not used here.

    Returns:
        The same ``q_table`` object, after the update.

    NOTE(review): the transition carries no ``done`` flag, so the target always
    bootstraps from ``Q[next_state]`` even when it is terminal — this relies on
    terminal rows staying at their initial value; confirm.
    """
    _exploration, learning_rate, discount = hparams
    state, action, reward, next_state, _next_action = sarsa

    td_target = reward + discount * np.max(q_table[next_state])
    td_error = td_target - q_table[state, action]
    q_table[state, action] = q_table[state, action] + learning_rate * td_error
    return q_table

In [None]:
def choose_action(
    env: KaerMorhenv,
    q_table: np.ndarray,
    state: int,
    hparams: HyperParams,
) -> int:
    """Pick an action for ``state`` with an epsilon-greedy policy.

    With probability ``hparams.exploration_rate`` a random action is sampled
    from ``env.action_space``; otherwise the action with the highest Q-value
    is taken. Ties between equally-valued greedy actions are broken uniformly
    at random: plain ``np.argmax`` always returns the lowest index, which
    biases a zero-initialised table toward action 0 in every state that has
    not been learned about yet.

    Args:
        env: Environment whose ``action_space`` supplies exploratory actions.
        q_table: Q-value table indexed as ``q_table[state, action]``.
        state: Current state index (row of ``q_table``).
        hparams: Hyper-parameters; only ``exploration_rate`` is read here.

    Returns:
        The chosen action as a plain Python ``int``.
    """
    if np.random.rand() > hparams.exploration_rate:
        q_values = q_table[state]
        # All actions tied for the maximum value; choose among them uniformly.
        best_actions = np.flatnonzero(q_values == q_values.max())
        return int(np.random.choice(best_actions))
    return int(env.action_space.sample())

In [None]:
def adventure(
    env: KaerMorhenv,
    q_table: np.ndarray,
    hparams: HyperParams,
    learn_fun: Optional[LearnFun] = None
) -> Tuple[np.ndarray, float, List[int]]:
    """Run one episode with an epsilon-greedy policy, optionally learning online.

    Args:
        env: Environment to play in; it is reset at the start of the episode.
        q_table: Q-value table used for action selection (and updated if
            ``learn_fun`` is given).
        hparams: Hyper-parameters forwarded to ``choose_action`` and ``learn_fun``.
        learn_fun: Update rule applied after every step; when ``None`` the
            Q-table is only read (pure evaluation).

    Returns:
        ``(q_table, total_reward, actions)`` where ``actions`` lists exactly the
        actions passed to ``env.step``, in order. Previously the action chosen
        for the terminal state (never executed) was appended as well, so a
        replay via ``render_actions`` contained one spurious trailing action.
    """
    state = env.reset()
    action = choose_action(env, q_table, state, hparams)
    done = False
    total_reward = 0
    actions = []
    while not done:
        # Record the action at the moment it is actually executed.
        actions.append(action)
        new_state, reward, done, _info = env.step(action)
        # The next action is chosen before learning so on-policy rules (SARSA)
        # can use it; q_learn ignores it.
        new_action = choose_action(env, q_table, new_state, hparams)
        if learn_fun is not None:
            sarsa = SARSA(state, action, reward, new_state, new_action)
            q_table = learn_fun(q_table, sarsa, hparams)
        state, action = new_state, new_action
        total_reward += reward
    return q_table, total_reward, actions

In [None]:
def q_train(
    env: KaerMorhenv,
    epochs: int,
    learn_fun: LearnFun,
    h_param_generator: HParamGenerator,
) -> Tuple[np.ndarray, List[float], List[List[int]]]:
    """Train a Q-table from scratch over ``epochs`` episodes.

    Args:
        env: Environment exposing ``nS`` (state count) and ``nA`` (action count).
        epochs: Number of episodes to run.
        learn_fun: Update rule applied after every step (e.g. ``q_learn``).
        h_param_generator: Called as ``h_param_generator(epochs, epoch)`` before
            each episode to produce that episode's HyperParams.

    Returns:
        ``(q_table, rewards_history, actions_history)``: the learned table of
        shape ``(env.nS, env.nA)``, the total reward of each episode, and the
        list of actions taken in each episode.
    """
    q_table = np.zeros((env.nS, env.nA))
    rewards_history: List[float] = []
    actions_history: List[List[int]] = []
    for epoch in tqdm(range(epochs)):
        hparams = h_param_generator(epochs, epoch)
        q_table, reward, actions = adventure(env, q_table, hparams, learn_fun)
        rewards_history.append(reward)
        actions_history.append(actions)
    return q_table, rewards_history, actions_history

In [None]:
def param_generator(
    epochs: int,
    e: int,
    learning_rate: float = 0.8,
    discount_rate: float = 0.95,
) -> HyperParams:
    """Linear epsilon-decay schedule with fixed learning and discount rates.

    The exploration rate is ``(epochs - e) / epochs``: 1.0 at the first epoch,
    decreasing linearly to ``1 / epochs`` at the last one (``e = epochs - 1``).

    Args:
        epochs: Total number of training epochs.
        e: Current epoch index, ``0 <= e < epochs``.
        learning_rate: Step size for the Q-update (default 0.8).
        discount_rate: Discount applied to the bootstrapped value (default 0.95).

    Returns:
        ``HyperParams(exploration_rate, learning_rate, discount_rate)``.
    """
    exploration_rate = (epochs - e) / epochs
    return HyperParams(exploration_rate, learning_rate, discount_rate)

In [None]:
# Seed NumPy's global RNG, which choose_action uses for its epsilon-greedy draws.
# NOTE(review): env.action_space.sample() and the environment itself may draw from
# their own RNGs, so this alone may not make runs fully reproducible — confirm.
SEED = 0
np.random.seed(SEED)

n_epochs = 5000
q_table, rewards_history, actions_history = q_train(
    env, n_epochs, q_learn, param_generator
)

In [None]:
# Episode rewards are noisy under epsilon-greedy exploration; overlay a moving
# average to make the learning trend visible. x is derived from the data itself
# so the cell stays valid if n_epochs is edited without retraining.
fig, ax = plt.subplots()
episodes = np.arange(len(rewards_history))
ax.plot(episodes, rewards_history, alpha=0.4, label="episode reward")
window = min(100, len(rewards_history))
if window > 0:
    smoothed = np.convolve(rewards_history, np.ones(window) / window, mode="valid")
    ax.plot(episodes[window - 1:], smoothed, label=f"{window}-episode moving average")
ax.set(title="Q-learning training rewards", xlabel="Episode", ylabel="Total reward")
ax.legend();

In [None]:
# Replay the actions of the final (least exploratory) training episode.
anim = env.render_actions(actions_history[-1], interval=200)
# The "pillow" writer relies only on Pillow (a required matplotlib dependency),
# so GIF export does not need an external ImageMagick install.
anim.save("anim.gif", writer="pillow")
anim