from [here](https://github.com/moskomule/pytorch.rl.learning) <br>
[another one](https://github.com/vikasjiitk/Deep-RL-Mountain-Car/blob/master/MCqlearn.py)

In [25]:
%reload_ext autoreload
%autoreload 2
import sys
if "../" not in sys.path:
  sys.path.append("../") 
from lib.envs.gridworld import GridworldEnv
import torch
import random
import numpy as np
import math
from EXITrl.table_base import TableBase
from EXITrl.approximation_v_base import ApproximationVBase, ExperienceReplay
# Notebook-wide environment: the flat-state grid world that the tabular
# Sarsa and SarsaLambda cells below are constructed with.
env = GridworldEnv()

### Sarsa (Table)

In [26]:
class Sarsa(TableBase):
    """Tabular Sarsa: one-step, on-policy TD control over ``self.Q``.

    The constructor forwards its arguments to ``TableBase`` together with the
    policy name ``"epsilon_greedy"``.  ``self.Q``, ``self.policy``,
    ``self.alpha`` and ``self.gamma`` are used here but not defined here
    (presumably provided by ``TableBase`` -- confirm against that class).
    """

    def __init__(self, env, num_episodes, epsilon=0.1, alpha=0.5, gamma=.9):
        super().__init__(env, num_episodes, "epsilon_greedy", epsilon, alpha, gamma)

    def _loop(self, episode) -> int:
        """Play one episode, updating ``self.Q`` after every environment step.

        Args:
            episode: supplied by the caller; not used by this method.

        Returns:
            The sum of the rewards received during the episode.
        """
        episode_return = 0
        finished = False
        state = self.env.reset()
        action = self.policy(state)
        while not finished:
            next_state, reward, finished, _ = self.env.step(action)
            # The follow-up action is picked before Q is modified, so the
            # selection sees the table as it was when the step was taken.
            next_action = self.policy(next_state)
            # A terminal transition has nothing to bootstrap from; otherwise
            # add the discounted value of the pair that will be taken next.
            target = (
                reward
                if finished
                else reward + self.gamma * self.Q[next_state, next_action]
            )
            # Move Q(s, a) a fraction alpha of the way towards the target.
            self.Q[state][action] += self.alpha * (target - self.Q[state][action])
            episode_return += reward
            state, action = next_state, next_action
        return episode_return
# Build a tabular Sarsa agent with num_episodes=50 on the grid world and train it.
s = Sarsa(env, 50)
s.train()
# Last expression of the cell: the notebook displays whatever the inherited
# convert_Q_to_V() returns (the array shown in the cell output).
s.convert_Q_to_V()

array([[ 0,  0, -1, -2],
       [ 0, -1, -2, -1],
       [-1, -1, -1,  0],
       [-1, -1,  0,  0]])

### Sarsa lambda (Table)

In [3]:
class SarsaLambda(TableBase):
    """Tabular Sarsa(lambda) with accumulating eligibility traces in ``self.Z``.

    The constructor forwards its arguments, plus the policy name
    ``"epsilon_greedy"``, to ``TableBase``.  ``self.lambd`` is read in
    ``_loop`` but not assigned in this class (presumably stored by
    ``TableBase`` from the ``lambd`` argument -- confirm).
    """

    def __init__(self, env, num_episodes, epsilon=0.1, alpha=0.5, gamma=.9, lambd=0.1):
        super().__init__(env, num_episodes, "epsilon_greedy", epsilon, alpha, gamma, lambd)
        # Trace table cloned from Q so it has a matching layout; its values
        # are zeroed at the start of every episode in ``_loop``.
        self.Z = self.Q.clone()

    def _loop(self, episode) -> int:
        """Play one episode, updating every entry of ``self.Q`` on each step.

        Args:
            episode: supplied by the caller; not used by this method.

        Returns:
            The sum of the rewards received during the episode.
        """
        state = self.env.reset()
        action = self.policy(state)
        # Traces never carry over from one episode to the next.
        self.Z.zero_()
        episode_return = 0
        finished = False
        while not finished:
            next_state, reward, finished, _ = self.env.step(action)
            # Selected before Q changes, so it reflects the pre-update table.
            next_action = self.policy(next_state)
            # No bootstrapping from a terminal transition.
            target = (
                reward
                if finished
                else reward + self.gamma * self.Q[next_state, next_action]
            )
            delta = target - self.Q[state, action]
            # Accumulating trace: bump the pair that was just visited ...
            self.Z[state, action] += 1
            # ... spread the TD error over all pairs in proportion to their trace ...
            self.Q += self.alpha * delta * self.Z
            # ... then decay every trace by gamma * lambda.
            self.Z = self.gamma * self.lambd * self.Z
            episode_return += reward
            state, action = next_state, next_action
        return episode_return
# Build a SarsaLambda agent with num_episodes=510 (defaults for the other
# hyper-parameters) and train it.
s = SarsaLambda(env, 510)
s.train()
# Last expression of the cell: the notebook displays the returned array.
s.convert_Q_to_V()

array([[ 0, -1, -1, -2],
       [-1, -2, -2, -1],
       [-1, -2, -1, -1],
       [-2, -2, -1,  0]])

### Sarsa Approximation (Grid World)

In [27]:
class GridworldEnv2DState(GridworldEnv):
    """Grid world whose observations are ``[row, col]`` arrays.

    ``reset`` and ``step`` delegate to the parent environment and convert the
    flat state index it returns into a two-element integer array, so a
    function approximator can be fed coordinates instead of a single index.

    Args:
        shape: two-element ``[rows, cols]`` grid size passed on to
            ``GridworldEnv``.  Defaults to ``[4, 4]``.
    """

    def __init__(self, shape=None):
        # ``None`` stands in for the former mutable ``[4, 4]`` default so the
        # default list is never shared between instances.
        if shape is None:
            shape = [4, 4]
        # Keep the column count locally (set before the parent constructor
        # runs) so the conversion follows the real grid width instead of a
        # hard-coded 4.
        self._num_cols = shape[1]
        super().__init__(shape)

    def convert_to_2_dimension_state(self, state):
        """Decode a flat, row-major state index into ``np.array([row, col])``."""
        row, col = divmod(state, self._num_cols)
        return np.array([row, col], dtype=int)

    def reset(self):
        """Reset the parent environment and return the start state as ``[row, col]``."""
        # Zero-argument super() starts the lookup at GridworldEnv itself; the
        # previous super(GridworldEnv, self) form skipped that class in the MRO.
        return self.convert_to_2_dimension_state(super().reset())

    def step(self, action):
        """Take ``action`` in the parent environment.

        Returns:
            ``(state, reward, done, info)`` as produced by the parent, with
            ``state`` converted to ``[row, col]``.
        """
        state, reward, done, info = super().step(action)
        return self.convert_to_2_dimension_state(state), reward, done, info
    
# Rebind the notebook-wide ``env`` to the [row, col]-state grid world; the
# approximation cells below are constructed with this instance.
env = GridworldEnv2DState()

In [28]:
class SarsaApproximation(ApproximationVBase):
    """Sarsa control where Q-values come from ``self.approximate_q``.

    ``approximate_q``, ``update_weight``, ``initialize`` and ``policy`` are
    used here but not defined here (presumably provided by
    ``ApproximationVBase`` -- confirm against that class).

    Args:
        env: environment exposing ``reset()`` and ``step(action)``.
        num_state: forwarded to ``ApproximationVBase``.
        num_action: forwarded to ``ApproximationVBase``.
        num_episodes: forwarded to ``ApproximationVBase``.
        num_experience: ``1`` selects a weight update from each single
            transition; any other value creates
            ``ExperienceReplay(num_experience)`` and updates from the batch
            it returns.
        epsilon: forwarded to ``ApproximationVBase``.
        alpha: forwarded to ``ApproximationVBase``.
        gamma: discount factor used in the TD target.
    """

    def __init__(self,
                 env,
                 num_state,
                 num_action,
                 num_episodes,
                 num_experience=100,
                 epsilon=0.01,
                 alpha=0.008,
                 gamma=.9):
        super().__init__(env,
                         num_state,
                         num_action,
                         num_episodes,
                         "epsilon_greedy",
                         epsilon,
                         alpha,
                         gamma)
        self.initialize()
        # Choose the update strategy once so ``_loop`` can call
        # ``self.update_experience`` without branching on every step.
        if num_experience == 1:
            self.update_experience = self.update_step_by_step_experience
        else:
            self.experience_replay = ExperienceReplay(num_experience)
            self.update_experience = self.update_experience_replay

    def _td_target_and_prediction(self, state, action, reward, state_, action_, done):
        """Return ``(td_target, predict_q)`` for a single transition.

        Shared by both update strategies so the target rule lives in one place.
        """
        if done:
            # Terminal transition: the target is just the reward, wrapped as a tensor.
            td_target = torch.Tensor(np.array(reward))
        else:
            # NOTE(review): the bootstrapped value is not detached here;
            # whether gradients flow through the target depends on
            # ``update_weight`` -- confirm.
            td_target = reward + self.gamma * self.approximate_q(state_)[action_]
        predict_q = self.approximate_q(state)[action]
        return td_target, predict_q

    def update_step_by_step_experience(self, state, action, reward, state_, action_, done):
        """Update the weights from this one transition."""
        td_target, predict_q = self._td_target_and_prediction(
            state, action, reward, state_, action_, done)
        self.update_weight(td_target, predict_q)

    def update_experience_replay(self, state, action, reward, state_, action_, done):
        """Store the transition, then update the weights from a replay batch.

        ``get_batch`` is handed the per-transition target function and is
        expected to return the batched ``(targets, predictions)`` pair.
        """
        self.experience_replay.remember(state, action, reward, state_, action_, done)
        targets, predict_qs = self.experience_replay.get_batch(
            self._td_target_and_prediction)
        self.update_weight(targets, predict_qs)

    def _loop(self, episode) -> int:
        """Play one episode, updating the approximator after every step.

        Args:
            episode: supplied by the caller; not used by this method.

        Returns:
            The sum of the rewards received during the episode.
        """
        done = False
        total_reward = 0
        state = self.env.reset()
        action = self.policy(state)
        while not done:
            state_, reward, done, _ = self.env.step(action)
            # Next action is chosen before the weights change for this step.
            action_ = self.policy(state_)
            self.update_experience(state, action, reward, state_, action_, done)
            total_reward += reward
            state = state_
            action = action_
        return total_reward

    def convert_Q_to_V(self):
        """Return the greedy state values as an array shaped like ``self.env.shape``.

        For every flat state index the approximated action values are printed
        and their maximum is stored.  Requires an environment that provides
        ``observation_space.n``, ``shape`` and ``convert_to_2_dimension_state``.
        """
        num_states = self.env.observation_space.n
        V = np.zeros(num_states)
        for state in range(num_states):
            # Use the agent's own environment rather than the module-level
            # ``env``, which may have been rebound to a different environment.
            convert_state = self.env.convert_to_2_dimension_state(state)
            q_values = self.approximate_q(convert_state)
            print(convert_state, q_values.detach().numpy())
            V[state] = q_values.max().item()
        return V.reshape(self.env.shape)
        
# Approximate-Sarsa agent on the [row, col] grid world: two state inputs
# (row and column) and one output per environment action.  num_experience is
# left at its default of 100, so the experience-replay update path is used.
s = SarsaApproximation(env, 
                       num_state=2, 
                       num_action=env.action_space.n, 
                       num_episodes=50,
                       epsilon=0.01, 
                       alpha=0.008, 
                       gamma=.9)
# The positional True presumably enables the per-episode reward log seen in
# the cell output -- confirm against ApproximationVBase.train.
s.train(True)
# Prints the action values per state and displays the max-per-state grid.
s.convert_Q_to_V()

episode: 0 reward: -1.0
episode: 1 reward: -39.0
episode: 2 reward: -11.0
episode: 3 reward: -5.0
episode: 4 reward: -2.0
episode: 5 reward: -2.0
episode: 6 reward: -3.0
episode: 7 reward: -19.0
episode: 8 reward: -3.0
episode: 9 reward: -3.0
episode: 10 reward: -2.0
episode: 11 reward: -3.0
episode: 12 reward: -4.0
episode: 13 reward: -4.0
episode: 14 reward: 0.0
episode: 15 reward: -2.0
episode: 16 reward: -5.0
episode: 17 reward: -4.0
episode: 18 reward: -5.0
episode: 19 reward: -3.0
episode: 20 reward: -4.0
episode: 21 reward: -1.0
episode: 22 reward: -3.0
episode: 23 reward: -2.0
episode: 24 reward: -4.0
episode: 25 reward: -2.0
episode: 26 reward: 0.0
episode: 27 reward: 0.0
episode: 28 reward: -2.0
episode: 29 reward: 0.0
episode: 30 reward: -2.0
episode: 31 reward: -5.0
episode: 32 reward: -1.0
episode: 33 reward: -5.0
episode: 34 reward: -1.0
episode: 35 reward: -3.0
episode: 36 reward: 0.0
episode: 37 reward: -3.0
episode: 38 reward: -3.0
episode: 39 reward: -3.0
episode: 40 

array([[-0.45985079, -1.00015092, -1.88900685, -2.70013881],
       [-1.00140154, -1.89707613, -2.69473171, -3.44678116],
       [-1.89256406, -2.70018387, -3.42576194, -4.11727571],
       [-2.7459414 , -3.45785093, -4.18732262, -4.85527468]])

### Test the neural network by fitting it to Q-values from the table-based agent

In [29]:
# Reference action values for the grid world: one row per flat state index
# (16 states) and one column per action (4 actions).
Q = np.array([[ 0.0000,  0.0000,  0.0000,  0.0000],
            [-1.6439, -1.4790, -1.3537, -0.9999],
            [-2.1910, -2.2997, -2.0220, -1.9134],
            [-2.8211, -2.6443, -2.3472, -2.4686],
            [-0.9999, -1.0780, -1.4980, -1.6079],
            [-1.8345, -1.8612, -1.7473, -1.6657],
            [-1.8404, -2.1300, -2.0000, -2.1527],
            [-2.0695, -2.3770, -1.8677, -2.0256],
            [-1.8623, -2.0250, -2.4303, -2.0676],
            [-2.1694, -1.8296, -2.3622, -1.9963],
            [-1.9552, -1.6668, -1.4604, -1.7984],
            [-1.0469, -1.5610, -0.9980, -1.0685],
            [-2.4361, -2.4313, -2.4637, -2.7323],
            [-1.9719, -1.8673, -2.0908, -2.6633],
            [-1.4525, -0.9980, -1.5573, -1.9203],
            [ 0.0000,  0.0000,  0.0000,  0.0000]])
env = GridworldEnv2DState()
# The agent is only used as a function approximator here: train() is never
# called, so num_episodes has no effect in this cell.
s = SarsaApproximation(env, 
                       num_state=2, 
                       num_action=env.action_space.n, 
                       num_episodes=50,
                       epsilon=0.01, 
                       alpha=0.008, 
                       gamma=.9)
# Regress the approximator onto the table: repeatedly pick a random
# (state, action) entry and use its tabulated value as the update target.
for _ in range(500):
    idx = np.random.randint(Q.shape[0])
    # Action count comes from the table's width instead of a literal 4.
    action = np.random.randint(Q.shape[1])
    # Reuse the environment's own index -> [row, col] conversion rather than
    # repeating the arithmetic with a hard-coded grid width.
    state = env.convert_to_2_dimension_state(idx)
    td_target = Q[idx, action]
    s.update_weight(td_target, s.approximate_q(state)[action])
# Prints the fitted action values per state and displays the max-per-state grid.
s.convert_Q_to_V()

[0 0] [-0.36118457 -1.3198457  -0.2557702  -0.22230437]
[0 1] [-1.511564  -2.1404922 -1.3061261 -1.1788965]
[0 2] [-2.4676168 -2.8936074 -2.0234046 -1.9880375]
[0 3] [-3.3136873 -3.579819  -2.4548166 -2.7299285]
[1 0] [-1.2686487 -1.6585399 -1.5670905 -1.4436761]
[1 1] [-1.6204808 -1.786285  -1.8365965 -1.4560344]
[1 2] [-1.9652414 -2.3225915 -1.9482744 -1.7251934]
[1 3] [-2.5179617 -2.9341543 -2.1117406 -2.2142372]
[2 0] [-1.9817442 -2.006702  -2.524445  -2.5173528]
[2 1] [-1.7604569 -1.700321  -2.151839  -2.006669 ]
[2 2] [-1.5198797 -1.619311  -1.7397889 -1.4716948]
[2 3] [-1.7550889 -2.0477638 -1.7674761 -1.6426393]
[3 0] [-2.5028756 -2.3395753 -3.1446683 -3.3754904]
[3 1] [-2.1944273 -1.8976816 -2.7189476 -2.8086221]
[3 2] [-1.6339544 -1.3893665 -2.010995  -1.9795097]
[3 3] [-1.3282132 -1.2867506 -1.5140728 -1.3920003]


array([[-0.22230437, -1.17889655, -1.98803747, -2.45481658],
       [-1.26864874, -1.45603442, -1.72519338, -2.11174059],
       [-1.98174417, -1.70032096, -1.47169483, -1.64263928],
       [-2.33957529, -1.89768159, -1.38936651, -1.28675056]])

[CartPole wiki](https://github.com/openai/gym/wiki/CartPole-v0)

### Sarsa Approximation (CartPole)

In [24]:
import gym
# Rebind the notebook-wide ``env`` to CartPole-v1.
env = gym.make('CartPole-v1')
# Same agent class as for the grid world, sized from the environment: one
# input per observation component and one output per action.
# num_experience=512 is forwarded to ExperienceReplay, so the
# experience-replay update path is used.
s = SarsaApproximation(env, 
                       num_state=env.observation_space.shape[0],
                       num_action=env.action_space.n, 
                       num_episodes=50,
                       num_experience=512,
                       epsilon=0.01, 
                       alpha=0.007, 
                       gamma=.99)
# The positional True presumably enables the per-episode reward log seen in
# the cell output -- confirm against ApproximationVBase.train.
s.train(True)


  result = entry_point.load(False)


WARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.
episode: 0 reward: 13.0
episode: 1 reward: 17.0
episode: 2 reward: 21.0
episode: 3 reward: 25.0
episode: 4 reward: 123.0
episode: 5 reward: 102.0
episode: 6 reward: 30.0
episode: 7 reward: 8.0
episode: 8 reward: 46.0
episode: 9 reward: 62.0
episode: 10 reward: 78.0
episode: 11 reward: 84.0
episode: 12 reward: 115.0
episode: 13 reward: 124.0
episode: 14 reward: 121.0
episode: 15 reward: 131.0
episode: 16 reward: 154.0
episode: 17 reward: 182.0
episode: 18 reward: 179.0
episode: 19 reward: 255.0
episode: 20 reward: 205.0
episode: 21 reward: 178.0
episode: 22 reward: 13.0
episode: 23 reward: 19.0
episode: 24 reward: 208.0
episode: 25 reward: 155.0
episode: 26 reward: 457.0
episode: 27 reward: 206.0
episode: 28 reward: 220.0
episode: 29 reward: 178.0
episode: 30 reward: 196.0
episode: 31 reward: 500.0
episode: 32 reward: 192.0
episode: 33 reward: 14.0
episode: 34 reward: 18.0
episode