from [here](https://github.com/moskomule/pytorch.rl.learning) <br>
[another one](https://github.com/vikasjiitk/Deep-RL-Mountain-Car/blob/master/MCqlearn.py)

In [10]:
%reload_ext autoreload
%autoreload 2
import sys
if "../" not in sys.path:
  sys.path.append("../") 
from lib.envs.gridworld import GridworldEnv
import torch
import random
import numpy as np
import math
from EXITrl.table_base import TableBase
from EXITrl.approx_v_base import ApproxVBase, ExperienceReplay
from gridworld_env_2d_state import GridworldEnv2DState
env = GridworldEnv()

### Sarsa (Table)

In [2]:
class Sarsa(TableBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.initialize()

    def _loop(self, episode) -> int:
        state = self.env.reset()
        action = self.policy(state)
        total_reward = 0
        done = False
        while not done:
            _state, reward, done, _= self.env.step(action)
            _action = self.policy(_state)
            ########## CORE Algorithm #########
            if done: 
                td_target = reward
            else: 
                td_target = reward + self.gamma * self.Q[_state, _action]
            td_error = td_target - self.Q[state, action]
            self.Q[state, action] += self.alpha * td_error
            ###################################
            total_reward += reward
            state = _state
            action = _action
        return total_reward
s = Sarsa(env, 50, "epsilon_greedy")
s.train()
s.convert_Q_to_V()

array([[ 0,  0, -1, -2],
       [ 0, -1, -2, -1],
       [-1, -1, -1,  0],
       [-1, -1,  0,  0]])

### Sarsa lambda (Table)

In [104]:
class SarsaLambda(TableBase):
    def __init__(self, env, num_episodes, policy, *args, **kwargs):
        super().__init__(env, num_episodes, policy, *args, **kwargs)
        self.initialize()
        self.Z = self.Q.clone()

    def _loop(self, episode) -> int:
        state = self.env.reset()
        action = self.policy(state)
        self.Z.zero_()
        total_reward = 0
        done = False
        while not done:
            _state, reward, done, _ = self.env.step(action)
            _action = self.policy(_state)
            ########## CORE Algorithm #########
            if done: 
                td_target = reward
            else: 
                td_target = reward + self.gamma * self.Q[_state, _action]
            td_error = td_target - self.Q[state, action]
            self.Z[state, action] += 1
            self.Q += self.alpha * td_error * self.Z
            self.Z = self.gamma * self.lambd * self.Z
            ###################################
            total_reward += reward
            state = _state
            action = _action
        return total_reward
s = SarsaLambda(env, 510, "epsilon_greedy")
s.train()
s.convert_Q_to_V()

array([[ 0, -1, -2, -2],
       [-1, -1, -2, -2],
       [-1, -2, -1, -1],
       [-2, -1, -1,  0]])

### Sarsa Approximtion (Grid World)

In [11]:
class SarsaApproximation(ApproxVBase):
    def __init__(self, num_experience=100, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.initialize()
        if num_experience==1:
            self.update_experience = self.update_step_by_step_experience
        else:
            self.experience_replay = ExperienceReplay(num_experience) 
            self.update_experience = self.update_experience_replay
    
    def update_step_by_step_experience(self, state, action, reward, _state, _action, done):
        if done:
            td_target = torch.Tensor(np.array(reward))
        else:
            td_target = reward + self.gamma * self.get_q(_state)[_action]
        current_q = self.get_q(state)[action]
        self.update_q(td_target, current_q)
    
    def update_experience_replay(self, state, action, reward, _state, _action, done):
        def get_target(state, action, reward, _state, _action, done):
            if done:
                td_target = torch.Tensor(np.array(reward))
            else:
                td_target = reward + self.gamma * self.get_q(_state)[_action]
            predict_q = self.get_q(state)[action]
            return td_target, predict_q

        self.experience_replay.remember(state, action, reward, _state, _action, done)
        targets, predict_qs = self.experience_replay.get_batch(get_target)
        current_q = self.get_q(state)[action]
        self.update_q(targets, current_q)
        
    def _loop(self, episode) -> int:
        done = False
        total_reward, reward = 0, 0
        state = self.env.reset()
        action = self.policy(state)
        while not done:
            _state, reward, done, _ = self.env.step(action)
            _action = self.policy(_state)
            self.update_experience(state, action, reward, _state, _action, done)
            total_reward += reward
            state = _state
            action = _action
        return total_reward
    
    def convert_Q_to_V(self):
        V = np.array([0.]*self.env.observation_space.n)
        for state in range(self.env.observation_space.n):
            convert_state = env.convert_to_2_dimension_state(state)
            print(convert_state, self.approximate_q(convert_state).detach().numpy())
            V[state] = self.approximate_q(convert_state).max().item()
        return V.reshape(self.env.shape)

env = GridworldEnv2DState()
s = SarsaApproximation(env=env, 
                       num_episodes=50,
                       policy="epsilon_greedy",
                       epsilon=0.01, 
                       alpha=0.008, 
                       gamma=.9)
s.train(True)
s.convert_Q_to_V()

episode: 0 reward: -2.0
episode: 1 reward: -43.0
episode: 2 reward: -44.0
episode: 3 reward: -2.0


KeyboardInterrupt: 

### Test nn by Q from Table base

In [73]:
Q = np.array([[ 0.0000,  0.0000,  0.0000,  0.0000],
            [-1.6439, -1.4790, -1.3537, -0.9999],
            [-2.1910, -2.2997, -2.0220, -1.9134],
            [-2.8211, -2.6443, -2.3472, -2.4686],
            [-0.9999, -1.0780, -1.4980, -1.6079],
            [-1.8345, -1.8612, -1.7473, -1.6657],
            [-1.8404, -2.1300, -2.0000, -2.1527],
            [-2.0695, -2.3770, -1.8677, -2.0256],
            [-1.8623, -2.0250, -2.4303, -2.0676],
            [-2.1694, -1.8296, -2.3622, -1.9963],
            [-1.9552, -1.6668, -1.4604, -1.7984],
            [-1.0469, -1.5610, -0.9980, -1.0685],
            [-2.4361, -2.4313, -2.4637, -2.7323],
            [-1.9719, -1.8673, -2.0908, -2.6633],
            [-1.4525, -0.9980, -1.5573, -1.9203],
            [ 0.0000,  0.0000,  0.0000,  0.0000]])
env = GridworldEnv2DState()
s = SarsaApproximation(env, 
                       num_episodes=50,
                       policy="epsilon_greedy",
                       epsilon=0.01, 
                       alpha=0.008, 
                       gamma=.9)
for _ in range(500):
    idx = np.random.randint(Q.shape[0])
    action = np.random.randint(4)
    state = np.array([math.floor(idx/4), idx%4], dtype=int)
    td_target = Q[idx, action]
    s.update_weight(td_target, s.approximate_q(state)[action])
s.convert_Q_to_V()

[0 0] [-0.8593397  -0.02596103  0.04927285 -0.09540172]
[0 1] [-1.550729  -1.3769492 -1.2023138 -1.0628335]
[0 2] [-2.2976317 -2.3127599 -2.0419474 -1.9821557]
[0 3] [-2.9014282 -2.870419  -2.4938223 -2.6326172]
[1 0] [-1.8078032  -0.85622334 -1.4937942  -1.3052136 ]
[1 1] [-1.6243346 -1.3108866 -1.5183533 -1.2685151]
[1 2] [-1.6815829 -1.6833512 -1.5008175 -1.2624112]
[1 3] [-2.1608539 -2.1386018 -1.841018  -1.8120695]
[2 0] [-2.4911911 -1.5572182 -2.3347552 -1.991789 ]
[2 1] [-1.8365465 -1.2378066 -1.7997459 -1.4752573]
[2 2] [-1.279315   -1.0324389  -1.0822536  -0.90967405]
[2 3] [-1.3371313 -1.3591679 -1.0460652 -0.9169352]
[3 0] [-3.1654108 -2.239034  -2.9942985 -2.5828466]
[3 1] [-2.314729  -1.7042893 -2.3467805 -2.056631 ]
[3 2] [-1.5461706 -1.1557423 -1.4187038 -1.2709324]
[3 3] [-0.91932964 -0.7418115  -0.5711436  -0.55552554]


array([[ 0.04927285, -1.06283355, -1.98215568, -2.49382234],
       [-0.85622334, -1.26851511, -1.26241124, -1.81206954],
       [-1.55721819, -1.23780656, -0.90967405, -0.91693521],
       [-2.23903394, -1.70428932, -1.15574229, -0.55552554]])

[CartPole wiki](https://github.com/openai/gym/wiki/CartPole-v0)

### Sarsa Aproximation (CartPole)

In [None]:
import gym
env = gym.make('CartPole-v1')
s = SarsaApproximation(env=env, 
                       num_episodes=50,
                       policy="epsilon_greedy",
                       num_experience=512,
                       epsilon=0.01, 
                       alpha=0.007, 
                       gamma=.99)
s.train(True)


  result = entry_point.load(False)


[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
episode: 0 reward: 10.0
episode: 1 reward: 10.0
episode: 2 reward: 13.0
episode: 3 reward: 12.0
episode: 4 reward: 15.0
episode: 5 reward: 17.0
episode: 6 reward: 14.0
episode: 7 reward: 20.0
episode: 8 reward: 133.0
episode: 9 reward: 17.0
episode: 10 reward: 142.0
episode: 11 reward: 200.0
episode: 12 reward: 208.0
episode: 13 reward: 140.0
