from [here](https://github.com/moskomule/pytorch.rl.learning)

In [46]:
%reload_ext autoreload
%autoreload 2
import sys
if "../" not in sys.path:
  sys.path.append("../") 
from lib.envs.gridworld import GridworldEnv
import torch
import random
import numpy as np
from EXITrl.table_base import TableBase, ApproximationBase
env = GridworldEnv()

### Sarsa (Table)

In [3]:
class Sarsa(TableBase):
    def __init__(self, env, num_episodes, epsilon=0.1, alpha=0.5, gamma=.9):
        super().__init__(env, num_episodes, epsilon, alpha, gamma)

    def _loop(self, episode) -> int:
        policy = self.epsilon_greedy
        state = self.env.reset()
        action = policy(state)
        total_reward = 0
        done = False
        while not done:
            state_, reward, done, _= self.env.step(action)
            action_ = policy(state_)
            ########## CORE Algorithm #########
            td_target = reward + self.gamma * self.Q[state_][action_]
            td_error = td_target - self.Q[state][action]
            self.Q[state][action] += self.alpha * td_error
            ###################################
            total_reward += reward
            state = state_
            action = action_
        return total_reward
s = Sarsa(env, 50)
s.train()
s.convert_Q_to_V()

array([[ 0,  0, -2, -2],
       [ 0, -1, -2, -1],
       [-1, -2, -1,  0],
       [-2, -1,  0,  0]])

### Sarsa lambda (Table)

In [142]:
class SarsaLambda(TableBase):
    def __init__(self, env, num_episodes, epsilon=0.1, alpha=0.5, gamma=.9, lambd=0.1):
        super().__init__(env, num_episodes, epsilon, alpha, gamma, lambd)
        self.Z = self.Q.clone()

    def _loop(self, episode) -> int:
        policy = self.epsilon_greedy
        state = self.env.reset()
        action = policy(state)
        self.Z.zero_()
        total_reward = 0
        done = False
        while not done:
            state_, reward, done, _ = self.env.step(action)
            action_ = policy(state_)
            ########## CORE Algorithm #########
            td_target = reward + self.gamma * self.Q[state_, action_]
            td_error = td_target - self.Q[state, action]
            self.Z[state, action] += 1
            self.Q += self.alpha * td_error * self.Z
            self.Z = self.gamma * self.lambd * self.Z
            ###################################
            total_reward += reward
            state = state_
            action = action_
        return total_reward
s = SarsaLambda(env, 50)
s.train()
s.convert_Q_to_V()

array([[ 0,  0, -1, -2],
       [ 0, -1, -2, -1],
       [-1, -2, -1,  0],
       [-2, -1,  0,  0]])

### Sarsa Approximtion (Grid World)

In [63]:
raw_state = 14# np.array([x for x in range(16)])
raw_state%4, np.floor(raw_state/4)

(2, 3.0)

In [241]:
import math
env = GridworldEnv()
class SarsaApproximation(ApproximationBase):
    def __init__(self, env, num_episodes, epsilon=0.5, alpha=0.001, gamma=.9):
        num_state = 2 # width * height
        super().__init__(env, 
                         num_state, 
                         env.action_space.n, 
                         num_episodes, 
                         epsilon, 
                         alpha, 
                         gamma)
        self.state_rage = env.observation_space.n ** (1/num_state)
        
    def convert_to_2_dimension_state(self, state):
        return np.array([math.floor(state/self.state_rage), state%self.state_rage], dtype=int)

    def _loop(self, episode) -> int:
        policy = self.epsilon_greedy
        done = False
        total_reward, reward = 0, 0
        state = self.convert_to_2_dimension_state(self.env.reset())
        action = policy(state)
        while not done:
            state_, reward, done, _ = self.env.step(action)
            state_ = self.convert_to_2_dimension_state(state_)
            action_ = policy(state_)
            ########## CORE Algorithm #########
            td_target = reward + self.gamma * self.approximate_q(state_, action_)
            self.update_weight(td_target, self.approximate_q(state, action))
#             td_error = td_target - self.approximate_q(state, action)
#             self.weight -= self.alpha * td_error * self.feature(state, action)
            ###################################
            total_reward += reward
            state = state_
            action = action_
        return total_reward
    
    def convert_Q_to_V(self):
        V = np.array([0.]*self.env.observation_space.n)
        for state in range(self.env.observation_space.n):
            convert_state = self.convert_to_2_dimension_state(state)
            qs = []
            for action in range(self.env.action_space.n):
                qs.append(self.approximate_q(convert_state, action).item())
            V[state] = np.array(qs).max()
        return V.reshape(self.env.shape)
        
s = SarsaApproximation(env, 50)
s.train()
s.convert_Q_to_V()

array([[-0.60148382, -0.57171768, -0.54195154, -0.51218545],
       [-0.54788327, -0.58922839, -0.62383676, -0.65844512],
       [-0.49428269, -0.54139221, -0.5769729 , -0.61158127],
       [-0.44068211, -0.48940787, -0.53010905, -0.56471741]])

In [204]:
c = [1,0]
a = np.array([[1,2],[3,4]])
a[c]

array([[3, 4],
       [1, 2]])

[CartPole wiki](https://github.com/openai/gym/wiki/CartPole-v0)

### Sarsa Aproximation (CartPole)

In [49]:
import gym
env = gym.make('CartPole-v1')
import random
class SarsaApproximation(ApproximationBase):
    def __init__(self, env, num_episodes, epsilon=0.1, alpha=0.001, gamma=.9):
        super().__init__(env,
                         env.observation_space.shape[0],
                         env.action_space.n,
                         num_episodes, 
                         epsilon, 
                         alpha, 
                         gamma)

    def _loop(self, episode) -> int:
        policy = self.epsilon_greedy
        done = False
        total_reward, reward = 0, 0
        state = self.env.reset()
        action = policy(state)
        while not done:
            state_, reward, done, _ = self.env.step(action)
            action_ = policy(state_)
            ########## CORE Algorithm #########
            td_target = reward + self.gamma * self.approximate_q(state_, action_)
            self.update_weight(td_target, self.approximate_q(state, action))
#             td_error = td_target - self.approximate_q(state, action)
#             self.weight -= self.alpha * td_error * self.feature(state, action)
            ###################################
            total_reward += reward
            state = state_
            action = action_
        return total_reward
s = SarsaApproximation(env, 50)
s.train(True)
# s.convert_Q_to_V()

  result = entry_point.load(False)


[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
episode: 0 reward: 9.0
episode: 1 reward: 11.0
episode: 2 reward: 10.0
episode: 3 reward: 11.0
episode: 4 reward: 11.0
episode: 5 reward: 10.0
episode: 6 reward: 11.0
episode: 7 reward: 10.0
episode: 8 reward: 8.0
episode: 9 reward: 10.0
episode: 10 reward: 8.0
episode: 11 reward: 11.0
episode: 12 reward: 9.0
episode: 13 reward: 9.0
episode: 14 reward: 9.0
episode: 15 reward: 11.0
episode: 16 reward: 11.0
episode: 17 reward: 10.0
episode: 18 reward: 10.0
episode: 19 reward: 9.0
episode: 20 reward: 10.0
episode: 21 reward: 10.0
episode: 22 reward: 11.0
episode: 23 reward: 8.0
episode: 24 reward: 9.0
episode: 25 reward: 9.0
episode: 26 reward: 10.0
episode: 27 reward: 9.0
episode: 28 reward: 10.0
episode: 29 reward: 8.0
episode: 30 reward: 10.0
episode: 31 reward: 13.0
episode: 32 reward: 8.0
episode: 33 reward: 9.0
episode: 34 reward: 11.0
episode: 35 reward: 9.0
episode: 36 rewar