# Hex Tiled Q-Learning SAR example

The challenge in this example is to implement a Q-learning search algorithm in a hexagonally tiled environment.

In [None]:
import sys, os

# Prepend "<cwd>/../src" to the module search path so the imports below resolve
# against the source tree (presumably where the `jsim` package lives). The path
# is built from the current working directory, hence the note on this line.
sys.path.insert(0, os.path.join(os.getcwd(), "../src")) # run from within examples folder

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from copy import copy
from loguru import logger

In [None]:
from jsim.Environment import HexEnvironment, HexDirections
from jsim.Agent import Agent
from jsim.Simulation import Simulation
from jsim.Environment.HexEnvironment.HexCoords import OffsetCoord
from jsim.Meta import State

In [None]:
class QHexEnv(HexEnvironment):
    """Hex-tiled environment carrying a random probability map for Q-learning.

    On construction a probability distribution map (``pdm``) is generated as a
    sum of random Gaussians and discretised into ``n_bins`` levels
    (``digitized_pdm``). An observation is the tuple of discretised values on
    the tiles neighbouring the agent, encoded as one integer through the
    ``possible_states`` lookup table so it can index a Q-table row.
    """

    def __init__(self, **kwargs):
        """Build the map and the observation lookup table.

        Args:
            **kwargs: Forwarded unchanged to ``HexEnvironment.__init__``.
        """
        super().__init__(**kwargs)

        self.pdm = self._generate_pdm()
        self.n_bins = 4
        self.digitized_pdm = self._encode_pdm(n=self.n_bins)

        # Enumerate every combination of levels the neighbouring tiles can
        # take (one axis per direction). Each combination occurs exactly once
        # in the flattened grids, so its position in the enumeration is its
        # state index; no per-state search is needed.
        levels = np.arange(0, 1 + np.max(self.digitized_pdm))
        grids = np.meshgrid(*([levels] * len(HexDirections)))
        flat = [g.flatten() for g in grids]
        self.possible_states = {
            state: idx for idx, state in enumerate(zip(*flat))
        }

    def _generate_pdm(self, N: int = 5):
        """Return a map of shape ``self.shape`` made of ``N`` random Gaussians.

        Args:
            N: Number of Gaussian bumps to sum.

        Returns:
            Float array of shape ``self.shape``.
        """
        # Coordinate grids with exactly the map's shape: ``y`` runs along
        # axis 0 and ``x`` along axis 1, so the sum below also broadcasts for
        # non-square maps.
        # NOTE(review): assumes self.shape is 2-D and ordered (rows, cols), in
        # line with the [row][col] indexing in _evaluate_position -- confirm.
        n_rows, n_cols = self.shape[0], self.shape[1]
        y, x = np.indices((n_rows, n_cols))

        pdm = np.zeros(self.shape)

        for _ in range(N):
            A = max(np.random.rand() * 2, 1)  # amplitude, never below 1
            a = np.random.rand() * 0.03       # spread coefficient along x
            b = 0                             # no cross term: axis-aligned bumps
            c = np.random.rand() * 0.03       # spread coefficient along y
            x0 = np.random.uniform(0, n_cols)  # centre of the Gaussian
            y0 = np.random.uniform(0, n_rows)

            dx, dy = x - x0, y - y0
            pdm += A * np.exp(-(a * dx**2 + 2 * b * dx * dy + c * dy**2))

        return pdm

    def _encode_pdm(self, n: int) -> np.ndarray:
        """Discretise ``self.pdm`` into ``n`` equal-width levels.

        Args:
            n: Number of levels.

        Returns:
            Integer array with the same shape as ``self.pdm``. Because the
            first bin edge is 0, non-negative map values fall in ``1..n``.
        """
        # linspace guarantees exactly ``n`` left bin edges; a floating-point
        # arange step can occasionally produce one edge too many.
        edges = np.linspace(0, np.max(self.pdm), n, endpoint=False)
        return np.digitize(self.pdm, edges)

    def _state_from_list(self, l):
        """Map a list of neighbour levels (one per direction) to a state index."""
        assert len(l) == len(HexDirections)
        return self.possible_states[tuple(l)]

    def _evaluate_position(self, pos: OffsetCoord) -> int:
        """Return the discretised map value at ``pos``.

        Raises:
            IndexError: If ``pos`` lies outside the map. Negative coordinates
                are rejected explicitly; plain NumPy indexing would silently
                wrap them around to the opposite edge.
        """
        n_rows, n_cols = self.digitized_pdm.shape
        if not (0 <= pos.row < n_rows and 0 <= pos.col < n_cols):
            raise IndexError(
                f"Position {pos} is outside the {n_rows}x{n_cols} map"
            )
        return self.digitized_pdm[pos.row][pos.col]

    def _observe(self, pos: OffsetCoord) -> int:
        """Return the encoded state of the tiles neighbouring ``pos``."""
        neighbors = self.neighbors_coord(pos)
        return self._state_from_list(
            [self._evaluate_position(n) for n in neighbors]
        )

    def reset(self, agent_p: OffsetCoord) -> int:
        """Return the initial observation for an agent placed at ``agent_p``.

        The map itself is kept across resets (regeneration is not performed
        here). The returned value uses the same neighbourhood encoding as the
        first element returned by ``step``, so it is a valid Q-table row.

        Raises:
            IndexError: If ``agent_p`` or one of its neighbours is off the map.
        """
        return self._observe(agent_p)

    def step(self, agent_s: OffsetCoord) -> tuple[int, float]:
        """Evaluate the agent's current position.

        Args:
            agent_s: Current agent position.

        Returns:
            ``(state, reward)``: the encoded neighbourhood state and the
            discretised map value of the tile the agent stands on.

        Raises:
            IndexError: If ``agent_s`` or one of its neighbours is off the map.
        """
        reward = self._evaluate_position(agent_s)
        return self._observe(agent_s), reward


In [None]:
class QAgent(Agent):
    """Tabular epsilon-greedy Q-learning agent moving on a hex grid.

    The Q-table has one row per environment state (``penv.possible_states``)
    and one column per ``HexDirections`` member.
    """

    def __init__(
        self,
        penv: QHexEnv = None,
        epsilon: float = 0.2,
        alpha: float = 0.1,
        gamma: float = 0.9,
    ) -> None:
        """Create the agent and a zero-initialised Q-table.

        Args:
            penv: Environment whose ``possible_states`` sizes the Q-table.
            epsilon: Probability of picking a random action in ``policy``.
            alpha: Learning rate used by ``learn``.
            gamma: Discount factor used by ``learn``.
        """
        self.state = OffsetCoord(col=0, row=0)
        self.penv = penv

        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma

        # One action per direction, the same count the environment uses when
        # it encodes a neighbourhood.
        self.n_actions = len(HexDirections)
        self.qtable = np.zeros((len(self.penv.possible_states), self.n_actions))

    def policy(self, pnext_s: int) -> HexDirections:
        """Choose a direction for state ``pnext_s`` (epsilon-greedy).

        With probability ``epsilon`` a uniformly random direction is chosen,
        otherwise the direction with the largest Q-value for the state.
        """
        if np.random.uniform(0, 1) < self.epsilon:
            idx = np.random.randint(0, self.n_actions)
        else:
            idx = np.argmax(self.qtable[pnext_s])
        # Convert the NumPy integer to a plain int before the enum lookup.
        return HexDirections(int(idx))

    def step(self, pnext_s: int) -> HexDirections:
        """Return the action chosen by ``policy`` for state ``pnext_s``."""
        return self.policy(pnext_s)

    def learn(self, reward: float, state: int, next_state: int,  action: HexDirections) -> None:
        """Apply one Q-learning update to ``qtable[state, action]``.

        Args:
            reward: Reward received from the environment.
            state: State index in which ``action`` was taken.
            next_state: State index reached afterwards.
            action: Action taken; used directly as a column index.
        """
        # NOTE(review): the reward is negated before the update, so tiles with
        # higher map values are penalised -- confirm this is intended.
        reward = -reward

        target = reward + self.gamma * np.max(self.qtable[next_state])
        self.qtable[state, action] = (
            (1 - self.alpha) * self.qtable[state, action] + self.alpha * target
        )

    def update(self, pa: HexDirections) -> OffsetCoord:
        """Move one tile in direction ``pa`` and return a copy of the new position.

        Raises:
            RuntimeError: If the agent has no ``state`` attribute.
        """
        if not hasattr(self, 'state'):
            logger.error(f"Reset has not been called as {hasattr(self,'state')=}")
            raise RuntimeError("Agent has no state; call reset() first")

        self.state = QHexEnv.neighbor_coord(self.state, pa)

        return copy(self.state)

    def reset(self, ps: OffsetCoord, vicinity: int) -> tuple[HexDirections, OffsetCoord]:
        """Place the agent at ``ps`` and choose an action for ``vicinity``.

        Returns:
            ``(action, position)``; the position is a copy, as in ``update``,
            so callers never hold a reference to the agent's internal state.
        """
        self.state = copy(ps)
        return self.policy(vicinity), copy(self.state)

In [None]:
class QSim(Simulation):
    """Runs repeated Q-learning trials of a ``QAgent`` inside a ``QHexEnv``.

    ``data_store`` holds the records of the trial in progress (or the most
    recent one); ``long_term_ds`` maps each trial index to a copy of that
    trial's ``data_store``.
    """

    agent: QAgent
    env: QHexEnv

    def __init__(self, initial_pos=OffsetCoord(col=5,row=5)) -> None:
        """Create the environment and agent and empty data stores.

        Args:
            initial_pos: Position the agent is placed at on every reset.
        """
        self.env = QHexEnv(psim=self)
        self.agent = QAgent(penv=self.env)

        self.initial_agent_s = initial_pos

        self.data_store = {'env_s':[],'agent_s':[],'agent_a':[],'reward':[]}
        self.long_term_ds = {}
        super().__init__()

    def reset(self) -> None:
        """Reset environment and agent to the start position and clear ``data_store``."""
        self.vicinity_pdm = self.env.reset(self.initial_agent_s)
        self.agent_a, self.agent_s = self.agent.reset(copy(self.initial_agent_s),self.vicinity_pdm)
        self.data_store = {'env_s':[],'agent_s':[],'agent_a':[],'reward':[]}

    def trials(self, num_trials: int, max_num_steps: int):
        """Run ``num_trials`` trials of at most ``max_num_steps`` steps each.

        A trial ends early when an ``IndexError`` is raised, which is treated
        as the agent leaving the map. The trial's data is archived in
        ``long_term_ds`` in either case.
        """
        for i in range(num_trials):
            try:
                self.steps(max_num_steps)
                logger.info(f"Trial = {i} | Max number of steps ({max_num_steps}) reached ")
            except IndexError:
                logger.info(f"Trial = {i} | Agent went out of bounds, reseting")
            finally:
                # reset() rebinds data_store to fresh lists every trial, so a
                # shallow copy is enough to keep this trial's records.
                self.long_term_ds[i] = copy(self.data_store)

    def steps(self, num_steps: int) -> None:
        """Reset, then run up to ``num_steps`` observe/act/learn iterations.

        Raises:
            IndexError: Propagated from the environment when the agent's
                position cannot be evaluated.
        """
        self.reset()

        # No action has been executed yet after a reset, so the first
        # observation has no (state, action) pair to learn from.
        has_transition = False

        for _ in range(num_steps):
            vicinity_pdm, reward = self.env.step(self.agent_s)

            self.collect_data(vicinity_pdm, self.agent_a, self.agent_s, reward)

            agent_a = self.agent.step(vicinity_pdm)
            agent_s = self.agent.update(agent_a)

            if has_transition:
                # Credit the action that produced this transition: the
                # previous action (self.agent_a), taken in the previous state
                # (self.vicinity_pdm), led to `vicinity_pdm` and `reward`.
                # The newly chosen `agent_a` has not produced a reward yet.
                self.agent.learn(reward, self.vicinity_pdm, vicinity_pdm, self.agent_a)

            self.agent_s = agent_s
            self.agent_a = agent_a
            self.vicinity_pdm = vicinity_pdm
            has_transition = True

    def collect_data(self, env_s: int, agent_a: HexDirections, agent_s: OffsetCoord, reward: float) -> None:
        """Append copies of one step's observation, action, position and reward."""
        self.data_store['env_s'].append(copy(env_s))
        self.data_store['agent_a'].append(copy(agent_a))
        self.data_store['agent_s'].append(copy(agent_s))
        self.data_store['reward'].append(copy(reward))

    def plot(self):
        """Plot total reward per trial and the path stored in ``data_store``.

        The upper axes show the summed reward of every archived trial; the
        lower axes show the environment tiles and the positions currently
        held in ``data_store`` (those of the most recent trial).
        """
        fig, (ax1, ax2) = plt.subplots(2,1, gridspec_kw={'height_ratios': [1, 3]}, figsize=(8, 6), dpi=80)

        # Reward over time
        epochs = list(self.long_term_ds)
        totals = [np.sum(self.long_term_ds[e]['reward']) for e in epochs]
        ax1.bar(epochs, totals)
        ax1.set_ylabel('Total reward')
        ax1.set_xlabel('Epoch')

        # Show pdm and path
        xy = [HexEnvironment.offset_to_pixel(f) for f in self.data_store['agent_s']]
        x = [f.x for f in xy]
        y = [f.y for f in xy]

        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
        # Matplotlib 3.9.
        for p in self.env.as_mpl_polygons(cmap=plt.get_cmap('gray')):
            ax2.add_patch(p)

        # The last trial may have recorded no positions; only draw a path
        # (and its start/end markers) when there is one.
        if x:
            ax2.plot(x,y)
            ax2.scatter(x[0],y[0],label='Start')
            ax2.scatter(x[-1],y[-1],label='End')
            ax2.legend()
        ax2.set_ylabel('y')
        ax2.set_xlabel('x')

        fig.tight_layout()
        plt.show()


In [None]:
# Build the simulation with its default start position, then train the agent
# for 50 trials of at most 100 steps each.
sim = QSim()

sim.trials(num_trials=50, max_num_steps=100)

In [None]:
# Show the total reward per trial and the agent's path from the last trial run.
sim.plot()