from [here](https://github.com/moskomule/pytorch.rl.learning) <br>
[another one](https://github.com/vikasjiitk/Deep-RL-Mountain-Car/blob/master/MCqlearn.py)

In [5]:
%reload_ext autoreload
%autoreload 2
import sys
if "../" not in sys.path:
  sys.path.append("../") 
from lib.envs.gridworld import GridworldEnv
import torch
import random
import numpy as np
import math
from EXITrl.table_base import TableBase
from EXITrl.approx_v_base import ApproxVBase, ExperienceReplay
from gridworld_env_2d_state import GridworldEnv2DState
env = GridworldEnv()

### Sarsa (Table)

In [6]:
class Sarsa(TableBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.initialize()

    def _loop(self, episode) -> int:
        state = self.env.reset()
        action = self.policy(state)
        total_reward = 0
        done = False
        while not done:
            _state, reward, done, _= self.env.step(action)
            _action = self.policy(_state)
            ########## CORE Algorithm #########
            if done: 
                td_target = reward
            else: 
                td_target = reward + self.gamma * self.Q[_state, _action]
            td_error = td_target - self.Q[state, action]
            self.Q[state, action] += self.alpha * td_error
            ###################################
            total_reward += reward
            state = _state
            action = _action
        return total_reward
s = Sarsa(env, 50, "epsilon_greedy")
s.train()
s.convert_Q_to_V()

array([[ 0,  0, -1, -2],
       [ 0, -1, -1, -1],
       [-1, -1, -1,  0],
       [-2, -1,  0,  0]])

### Sarsa lambda (Table)

In [7]:
class SarsaLambda(TableBase):
    def __init__(self, env, num_episodes, policy, *args, **kwargs):
        super().__init__(env, num_episodes, policy, *args, **kwargs)
        self.initialize()
        self.Z = self.Q.clone()

    def _loop(self, episode) -> int:
        state = self.env.reset()
        action = self.policy(state)
        self.Z.zero_()
        total_reward = 0
        done = False
        while not done:
            _state, reward, done, _ = self.env.step(action)
            _action = self.policy(_state)
            ########## CORE Algorithm #########
            if done: 
                td_target = reward
            else: 
                td_target = reward + self.gamma * self.Q[_state, _action]
            td_error = td_target - self.Q[state, action]
            self.Z[state, action] += 1
            self.Q += self.alpha * td_error * self.Z
            self.Z = self.gamma * self.lambd * self.Z
            ###################################
            total_reward += reward
            state = _state
            action = _action
        return total_reward
s = SarsaLambda(env, 510, "epsilon_greedy")
s.train()
s.convert_Q_to_V()

array([[ 0, -1, -1, -2],
       [-1, -1, -2, -1],
       [-1, -2, -2, -1],
       [-2, -1, -1,  0]])

### Sarsa Approximtion (Grid World)

In [12]:
class SarsaApproximation(ApproxVBase):
    def __init__(self, num_experience=100, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.initialize()
        if num_experience==1:
            self.update_experience = self.update_step_by_step_experience
        else:
            self.experience_replay = ExperienceReplay(num_experience) 
            self.update_experience = self.update_experience_replay
    
    def update_step_by_step_experience(self, state, action, reward, _state, _action, done):
        if done:
            td_target = torch.Tensor(np.array(reward))
        else:
            td_target = reward + self.gamma * self.get_q(_state)[_action]
        current_q = self.get_q(state)[action]
        self.update_q(td_target, current_q)
    
    def update_experience_replay(self, state, action, reward, _state, _action, done):
        def get_target(state, action, reward, _state, _action, done):
            if done:
                td_target = torch.Tensor(np.array(reward))
            else:
                td_target = reward + self.gamma * self.get_q(_state)[_action]
            predict_q = self.get_q(state)[action]
            return td_target, predict_q

        self.experience_replay.remember(state, action, reward, _state, _action, done)
        targets, predict_qs = self.experience_replay.get_batch(get_target)
        current_q = self.get_q(state)[action]
        self.update_q(targets, current_q)
        
    def _loop(self, episode) -> int:
        done = False
        total_reward, reward = 0, 0
        state = self.env.reset()
        action = self.policy(state)
        while not done:
            _state, reward, done, _ = self.env.step(action)
            _action = self.policy(_state)
            self.update_experience(state, action, reward, _state, _action, done)
            total_reward += reward
            state = _state
            action = _action
        return total_reward
    
    def convert_Q_to_V(self):
        V = np.array([0.]*self.env.observation_space.n)
        for state in range(self.env.observation_space.n):
            convert_state = env.convert_to_2_dimension_state(state)
            print(convert_state, self.get_q(convert_state).detach().numpy())
            V[state] = self.get_q(convert_state).max().item()
        return V.reshape(self.env.shape)

env = GridworldEnv2DState()
s = SarsaApproximation(env=env, 
                       num_episodes=50,
                       policy="epsilon_greedy",
                       epsilon=0.01, 
                       alpha=0.008, 
                       gamma=.9)
s.train(True)
s.convert_Q_to_V()

episode: 0 reward: 0.0
episode: 1 reward: -8.0
episode: 2 reward: -3.0
episode: 3 reward: -10.0
episode: 4 reward: -9.0
episode: 5 reward: -11.0
episode: 6 reward: -4.0
episode: 7 reward: -2.0
episode: 8 reward: -5.0
episode: 9 reward: -1.0
episode: 10 reward: -12.0
episode: 11 reward: -3.0
episode: 12 reward: -6.0
episode: 13 reward: -13.0
episode: 14 reward: -5.0
episode: 15 reward: 0.0
episode: 16 reward: -4.0
episode: 17 reward: -2.0
episode: 18 reward: -23.0
episode: 19 reward: -10.0
episode: 20 reward: -2.0
episode: 21 reward: 0.0
episode: 22 reward: -10.0
episode: 23 reward: -3.0
episode: 24 reward: -6.0
episode: 25 reward: -16.0
episode: 26 reward: -26.0
episode: 27 reward: -2.0
episode: 28 reward: -9.0
episode: 29 reward: -17.0
episode: 30 reward: -17.0
episode: 31 reward: -9.0
episode: 32 reward: -2.0
episode: 33 reward: -20.0
episode: 34 reward: -15.0
episode: 35 reward: -23.0
episode: 36 reward: -7.0
episode: 37 reward: -19.0
episode: 38 reward: -3.0
episode: 39 reward: -21

array([[-0.68539697, -1.00315309, -1.13466418, -1.28627765],
       [-0.79662758, -1.01211452, -1.16810155, -1.32517374],
       [-0.88837278, -1.04370654, -1.20077884, -1.36256003],
       [-0.89547336, -1.06510794, -1.23345613, -1.40207601]])

### Test nn by Q from Table base

In [14]:
Q = np.array([[ 0.0000,  0.0000,  0.0000,  0.0000],
            [-1.6439, -1.4790, -1.3537, -0.9999],
            [-2.1910, -2.2997, -2.0220, -1.9134],
            [-2.8211, -2.6443, -2.3472, -2.4686],
            [-0.9999, -1.0780, -1.4980, -1.6079],
            [-1.8345, -1.8612, -1.7473, -1.6657],
            [-1.8404, -2.1300, -2.0000, -2.1527],
            [-2.0695, -2.3770, -1.8677, -2.0256],
            [-1.8623, -2.0250, -2.4303, -2.0676],
            [-2.1694, -1.8296, -2.3622, -1.9963],
            [-1.9552, -1.6668, -1.4604, -1.7984],
            [-1.0469, -1.5610, -0.9980, -1.0685],
            [-2.4361, -2.4313, -2.4637, -2.7323],
            [-1.9719, -1.8673, -2.0908, -2.6633],
            [-1.4525, -0.9980, -1.5573, -1.9203],
            [ 0.0000,  0.0000,  0.0000,  0.0000]])
env = GridworldEnv2DState()
s = SarsaApproximation(env=env, 
                       num_episodes=50,
                       policy="epsilon_greedy",
                       epsilon=0.01, 
                       alpha=0.008, 
                       gamma=.9)
for _ in range(500):
    idx = np.random.randint(Q.shape[0])
    action = np.random.randint(4)
    state = np.array([math.floor(idx/4), idx%4], dtype=int)
    td_target = Q[idx, action]
    s.update_v(td_target, s.get_q(state)[action])
s.convert_Q_to_V()

[0 0] [-0.29720393 -0.77073574 -0.39772606 -0.09194392]
[0 1] [-1.4442132 -1.6489762 -1.2947102 -1.1942033]
[0 2] [-2.331369  -2.3084543 -1.9601423 -1.9670415]
[0 3] [-2.953517  -2.7731466 -2.2698061 -2.405674 ]
[1 0] [-1.2534157 -1.2084376 -1.3155695 -1.4328315]
[1 1] [-1.5475113 -1.566447  -1.4979403 -1.675886 ]
[1 2] [-1.6540701 -1.7119943 -1.4306933 -1.5571195]
[1 3] [-2.0827663 -2.0529833 -1.5269271 -1.7968255]
[2 0] [-1.9395704 -1.6109526 -1.6911267 -2.23518  ]
[2 1] [-1.7863352 -1.5094359 -1.5015067 -2.0322442]
[2 2] [-1.0499052 -1.1368555 -0.8278215 -1.2377869]
[2 3] [-1.0864139 -1.2765877 -0.7017511 -1.1208764]
[3 0] [-2.4744623 -1.9700685 -2.0178757 -2.951979 ]
[3 1] [-2.1332045 -1.7127857 -1.6662935 -2.6075447]
[3 2] [-1.2380323  -1.0950063  -0.79769313 -1.6168311 ]
[3 3] [-0.38406634 -0.59264046  0.01429594 -0.6642961 ]


array([[-0.09194392, -1.19420326, -1.96014225, -2.26980615],
       [-1.20843756, -1.4979403 , -1.43069327, -1.52692711],
       [-1.61095262, -1.50150669, -0.82782149, -0.70175111],
       [-1.97006845, -1.6662935 , -0.79769313,  0.01429594]])

[CartPole wiki](https://github.com/openai/gym/wiki/CartPole-v0)

### Sarsa Aproximation (CartPole)

In [None]:
import gym
env = gym.make('CartPole-v1')
s = SarsaApproximation(env=env, 
                       num_episodes=50,
                       policy="epsilon_greedy",
                       num_experience=512,
                       epsilon=0.01, 
                       alpha=0.007, 
                       gamma=.99)
s.train(True)
