from [here](https://github.com/moskomule/pytorch.rl.learning) <br>
[another one](https://github.com/vikasjiitk/Deep-RL-Mountain-Car/blob/master/MCqlearn.py)

In [1]:
%reload_ext autoreload
%autoreload 2
import sys
if "../" not in sys.path:
  sys.path.append("../") 
from lib.envs.gridworld import GridworldEnv
import torch
import random
import numpy as np
import math
from EXITrl.table_base import TableBase
from EXITrl.approx_v_base import ApproxVBase, ExperienceReplay
from gridworld_env_2d_state import GridworldEnv2DState
env = GridworldEnv()

### Sarsa (Table)

In [3]:
class Sarsa(TableBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.initialize()

    def _loop(self, episode) -> int:
        state = self.env.reset()
        action = self.policy(state)
        total_reward = 0
        done = False
        while not done:
            _state, reward, done, _= self.env.step(action)
            _action = self.policy(_state)
            ########## CORE Algorithm #########
            if done: 
                td_target = reward
            else: 
                td_target = reward + self.gamma * self.Q[_state, _action]
            td_error = td_target - self.Q[state, action]
            self.Q[state, action] += self.alpha * td_error
            ###################################
            total_reward += reward
            state = _state
            action = _action
        return total_reward
s = Sarsa(env, 50, policy="epsilon_greedy")
s.train()
s.convert_Q_to_V()

Episode 1	Average Score: -69.00 	other{}Episode 2	Average Score: -35.50 	other{}Episode 3	Average Score: -24.67 	other{}Episode 4	Average Score: -19.25 	other{}Episode 5	Average Score: -20.80 	other{}Episode 6	Average Score: -17.67 	other{}Episode 7	Average Score: -15.43 	other{}Episode 8	Average Score: -17.38 	other{}Episode 9	Average Score: -15.44 	other{}Episode 10	Average Score: -14.00 	other{}Episode 10	Average Score: -14.00 	other{}
Episode 11	Average Score: -7.10 	other{}Episode 12	Average Score: -8.20 	other{}Episode 13	Average Score: -9.00 	other{}Episode 14	Average Score: -9.20 	other{}Episode 15	Average Score: -6.50 	other{}Episode 16	Average Score: -6.90 	other{}Episode 17	Average Score: -7.20 	other{}Episode 18	Average Score: -5.00 	other{}Episode 19	Average Score: -5.40 	other{}Episode 20	Average Score: -5.60 	other{}Episode 20	Average Score: -5.60 	other{}
Episode 21	Average Score: -5.70 	other{}Episode 22	Average Score: -4.50 	other{}Episode 23	

array([[ 0,  0, -1, -1],
       [ 0, -1, -1, -1],
       [-1, -1, -1,  0],
       [-2, -1,  0,  0]])

### Sarsa lambda (Table)

In [8]:
class SarsaLambda(TableBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.initialize()
        self.Z = self.Q.clone()

    def _loop(self, episode) -> int:
        state = self.env.reset()
        action = self.policy(state)
        self.Z.zero_()
        total_reward = 0
        done = False
        while not done:
            _state, reward, done, _ = self.env.step(action)
            _action = self.policy(_state)
            ########## CORE Algorithm #########
            if done: 
                td_target = reward
            else: 
                td_target = reward + self.gamma * self.Q[_state, _action]
            td_error = td_target - self.Q[state, action]
            self.Z[state, action] += 1
            self.Q += self.alpha * td_error * self.Z
            self.Z = self.gamma * self.lambd * self.Z
            ###################################
            total_reward += reward
            state = _state
            action = _action
        return total_reward
s = SarsaLambda(env, 50, policy="epsilon_greedy")
s.train()
s.convert_Q_to_V()

Episode 1	Average Score: -35.00 	other{}Episode 2	Average Score: -18.00 	other{}Episode 3	Average Score: -24.00 	other{}Episode 4	Average Score: -19.25 	other{}Episode 5	Average Score: -17.60 	other{}Episode 6	Average Score: -17.67 	other{}Episode 7	Average Score: -17.43 	other{}Episode 8	Average Score: -17.75 	other{}Episode 9	Average Score: -17.00 	other{}Episode 10	Average Score: -15.90 	other{}Episode 10	Average Score: -15.90 	other{}
Episode 11	Average Score: -12.50 	other{}Episode 12	Average Score: -12.50 	other{}Episode 13	Average Score: -9.00 	other{}Episode 14	Average Score: -9.00 	other{}Episode 15	Average Score: -8.80 	other{}Episode 16	Average Score: -7.60 	other{}Episode 17	Average Score: -7.40 	other{}Episode 18	Average Score: -6.90 	other{}Episode 19	Average Score: -6.10 	other{}Episode 20	Average Score: -6.10 	other{}Episode 20	Average Score: -6.10 	other{}
Episode 21	Average Score: -6.40 	other{}Episode 22	Average Score: -6.90 	other{}Episode 2

array([[ 0,  0, -1, -2],
       [ 0, -1, -2, -1],
       [-1, -1, -1,  0],
       [-2, -1,  0,  0]])

### Sarsa Approximtion (Grid World)

In [12]:
class SarsaApproximation(ApproxVBase):
    def __init__(self, num_experience=100, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.initialize()
        if num_experience==1:
            self.update_experience = self.update_step_by_step_experience
        else:
            self.experience_replay = ExperienceReplay(num_experience) 
            self.update_experience = self.update_experience_replay
    
    def update_step_by_step_experience(self, state, action, reward, _state, _action, done):
        if done:
            td_target = torch.Tensor(np.array(reward))
        else:
            td_target = reward + self.gamma * self.get_q(_state)[_action]
        current_q = self.get_q(state)[action]
        self.update_q(td_target, current_q)
    
    def update_experience_replay(self, state, action, reward, _state, _action, done):
        def get_target(state, action, reward, _state, _action, done):
            if done:
                td_target = torch.Tensor(np.array(reward))
            else:
                td_target = reward + self.gamma * self.get_q(_state)[_action]
            predict_q = self.get_q(state)[action]
            return td_target, predict_q

        self.experience_replay.remember(state, action, reward, _state, _action, done)
        targets, predict_qs = self.experience_replay.get_batch(get_target)
        current_q = self.get_q(state)[action]
        self.update_q(targets, current_q)
        
    def _loop(self, episode) -> int:
        done = False
        total_reward, reward = 0, 0
        state = self.env.reset()
        action = self.policy(state)
        while not done:
            _state, reward, done, _ = self.env.step(action)
            _action = self.policy(_state)
            self.update_experience(state, action, reward, _state, _action, done)
            total_reward += reward
            state = _state
            action = _action
        return total_reward
    
    def convert_Q_to_V(self):
        V = np.array([0.]*self.env.observation_space.n)
        for state in range(self.env.observation_space.n):
            convert_state = env.convert_to_2_dimension_state(state)
            print(convert_state, self.get_q(convert_state).detach().numpy())
            V[state] = self.get_q(convert_state).max().item()
        return V.reshape(self.env.shape)

env = GridworldEnv2DState()
s = SarsaApproximation(env=env, 
                       num_episodes=50,
                       policy="epsilon_greedy",
                       epsilon=0.01, 
                       alpha=0.008, 
                       gamma=.9)
s.train(True)
s.convert_Q_to_V()

Episode 10	Average Score: -4.30 	other{}
Episode 20	Average Score: -7.30 	other{}
Episode 30	Average Score: -5.20 	other{}
Episode 40	Average Score: -9.40 	other{}}
Episode 50	Average Score: -5.20 	other{}
[0 0] [-0.9077287 -1.282341  -1.1293762 -1.2470013]
[0 1] [-0.87199724 -1.0577717  -0.7746947  -1.0634011 ]
[0 2] [-1.0023547  -1.1457618  -0.78058726 -1.1445779 ]
[0 3] [-1.2064143 -1.3191538 -0.8626278 -1.2914224]
[1 0] [-0.8681651 -1.2761105 -0.7396302 -1.2011229]
[1 1] [-0.88112247 -1.1920553  -0.6213519  -1.1986345 ]
[1 2] [-1.0490115  -1.3578742  -0.68235993 -1.339334  ]
[1 3] [-1.2284728  -1.5378819  -0.75267804 -1.4917085 ]
[2 0] [-0.99110204 -1.52512    -0.75940585 -1.4113271 ]
[2 1] [-1.0444002 -1.4676784 -0.7043794 -1.4636323]
[2 2] [-1.1804383  -1.5996135  -0.75728023 -1.6135907 ]
[2 3] [-1.3209777  -1.7362726  -0.81212986 -1.7646189 ]
[3 0] [-1.1745627  -1.8637522  -0.83010244 -1.6675584 ]
[3 1] [-1.2222526 -1.7895045 -0.79371   -1.7461829]
[3 2] [-1.3582909 -1.9214396 -

array([[-0.90772867, -0.77469468, -0.78058726, -0.8626278 ],
       [-0.73963022, -0.6213519 , -0.68235993, -0.75267804],
       [-0.75940585, -0.70437938, -0.75728023, -0.81212986],
       [-0.83010244, -0.79370999, -0.8466109 , -0.89951169]])

### Test nn by Q from Table base

In [10]:
Q = np.array([[ 0.0000,  0.0000,  0.0000,  0.0000],
            [-1.6439, -1.4790, -1.3537, -0.9999],
            [-2.1910, -2.2997, -2.0220, -1.9134],
            [-2.8211, -2.6443, -2.3472, -2.4686],
            [-0.9999, -1.0780, -1.4980, -1.6079],
            [-1.8345, -1.8612, -1.7473, -1.6657],
            [-1.8404, -2.1300, -2.0000, -2.1527],
            [-2.0695, -2.3770, -1.8677, -2.0256],
            [-1.8623, -2.0250, -2.4303, -2.0676],
            [-2.1694, -1.8296, -2.3622, -1.9963],
            [-1.9552, -1.6668, -1.4604, -1.7984],
            [-1.0469, -1.5610, -0.9980, -1.0685],
            [-2.4361, -2.4313, -2.4637, -2.7323],
            [-1.9719, -1.8673, -2.0908, -2.6633],
            [-1.4525, -0.9980, -1.5573, -1.9203],
            [ 0.0000,  0.0000,  0.0000,  0.0000]])
env = GridworldEnv2DState()
s = SarsaApproximation(env=env, 
                       num_episodes=50,
                       policy="epsilon_greedy",
                       epsilon=0.01, 
                       alpha=0.008, 
                       gamma=.9)
for _ in range(500):
    idx = np.random.randint(Q.shape[0])
    action = np.random.randint(4)
    state = np.array([math.floor(idx/4), idx%4], dtype=int)
    td_target = Q[idx, action]
    s.update_v(td_target, s.get_q(state)[action])
s.convert_Q_to_V()

[0 0] [-0.08950523 -0.04526846 -0.4771667  -0.518084  ]
[0 1] [-1.3291311 -1.3113158 -1.2687043 -1.242636 ]
[0 2] [-2.2121005 -2.2506235 -1.8334401 -2.0281816]
[0 3] [-2.7804286 -2.812126  -2.0403173 -2.5864334]
[1 0] [-1.0722723 -1.1309662 -1.5587296 -1.9006   ]
[1 1] [-1.1853498 -1.3835914 -1.5559304 -1.6240482]
[1 2] [-1.4170573 -1.5560282 -1.491466  -1.6300268]
[1 3] [-1.8548989 -1.9498202 -1.5356302 -2.021651 ]
[2 0] [-2.1810641 -2.0493307 -2.4047909 -2.9850712]
[2 1] [-1.4733295 -1.5274479 -1.7964101 -2.1572762]
[2 2] [-0.71565205 -0.88860154 -1.0541437  -1.2763841 ]
[2 3] [-0.92003447 -1.0146617  -0.85674316 -1.2646188 ]
[3 0] [-2.8465388 -2.507756  -2.814615  -3.7934325]
[3 1] [-1.9259622 -1.7914754 -2.0552983 -2.8513947]
[3 2] [-0.96153903 -0.9717023  -1.1913431  -1.7806051 ]
[3 3] [-0.14823942 -0.25928116 -0.3466466  -0.76996356]


array([[-0.04526846, -1.24263597, -1.83344007, -2.0403173 ],
       [-1.0722723 , -1.18534982, -1.41705728, -1.53563023],
       [-2.04933071, -1.47332954, -0.71565205, -0.85674316],
       [-2.50775599, -1.79147542, -0.96153903, -0.14823942]])

[CartPole wiki](https://github.com/openai/gym/wiki/CartPole-v0)

### Sarsa Aproximation (CartPole)

In [11]:
import gym
env = gym.make('CartPole-v1')
s = SarsaApproximation(env=env, 
                       num_episodes=50,
                       policy="epsilon_greedy",
                       num_experience=512,
                       epsilon=0.01, 
                       alpha=0.007, 
                       gamma=.99)
s.train(True)


  result = entry_point.load(False)


[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
Episode 10	Average Score: 9.70 	other{}
Episode 20	Average Score: 9.10 	other{}
Episode 30	Average Score: 9.50 	other{}
Episode 40	Average Score: 9.20 	other{}
Episode 50	Average Score: 9.10 	other{}
