Adapted from [moskomule/pytorch.rl.learning](https://github.com/moskomule/pytorch.rl.learning).

In [445]:
%reload_ext autoreload
%autoreload 2
import sys
if "../" not in sys.path:
  sys.path.append("../") 
from lib.envs.gridworld import GridworldEnv
import torch
import random
import numpy as np
from EXITrl.table_base import TableBase
from EXITrl.approximation_base import ApproximationBase
env = GridworldEnv()

### Sarsa (Table)

In [420]:
class Sarsa(TableBase):
    """Tabular on-policy Sarsa control.

    Each environment step applies
        Q(s, a) <- Q(s, a) + alpha * (r + gamma * Q(s', a') - Q(s, a))
    where a' is chosen by the same epsilon-greedy policy that is being learned.
    On the terminal transition the target is just r (there is no successor value).
    """

    def __init__(self, env, num_episodes, epsilon=0.1, alpha=0.5, gamma=.9):
        super().__init__(env, num_episodes, epsilon, alpha, gamma)

    def _loop(self, episode) -> int:
        """Run one episode, updating ``self.Q`` online; return the summed reward.

        ``episode`` is not used here (presumably it is the index passed in by
        ``TableBase.train`` -- TODO confirm).
        """
        policy = self.epsilon_greedy
        state = self.env.reset()
        action = policy(state)
        total_reward = 0
        done = False
        while not done:
            next_state, reward, done, _ = self.env.step(action)
            next_action = policy(next_state)
            ########## CORE Algorithm #########
            # Do not bootstrap past the end of the episode: Q[next_state] for a
            # terminal state is never updated, so it only holds its initial value.
            bootstrap = 0 if done else self.gamma * self.Q[next_state][next_action]
            td_target = reward + bootstrap
            td_error = td_target - self.Q[state][action]
            self.Q[state][action] += self.alpha * td_error
            ###################################
            total_reward += reward
            state, action = next_state, next_action
        return total_reward
# Train tabular Sarsa for 50 episodes, then display the state values on the grid.
s = Sarsa(env, num_episodes=50)
s.train()
s.convert_Q_to_V()

array([[ 0,  0, -1, -2],
       [ 0, -1, -2, -1],
       [-1, -2, -1,  0],
       [-2, -1,  0,  0]])

### Sarsa lambda (Table)

In [421]:
class SarsaLambda(TableBase):
    """Tabular Sarsa(lambda) with accumulating eligibility traces.

    ``self.Z`` holds one trace per entry of ``self.Q``. On every step the visited
    (state, action) trace is incremented by 1, *all* Q entries move along their
    traces by the TD error, and every trace then decays by ``gamma * lambd``.
    On the terminal transition the target is just the reward.
    """

    def __init__(self, env, num_episodes, epsilon=0.1, alpha=0.5, gamma=.9, lambd=0.1):
        super().__init__(env, num_episodes, epsilon, alpha, gamma, lambd)
        # Same shape/dtype as Q; the values are reset at the start of each episode.
        self.Z = torch.zeros_like(self.Q)

    def _loop(self, episode) -> int:
        """Run one episode of Sarsa(lambda); return the summed reward.

        ``episode`` is not used here (presumably supplied by ``TableBase.train``
        -- TODO confirm).
        """
        policy = self.epsilon_greedy
        state = self.env.reset()
        action = policy(state)
        # Eligibility must not carry over from the previous episode.
        self.Z.zero_()
        total_reward = 0
        done = False
        while not done:
            next_state, reward, done, _ = self.env.step(action)
            next_action = policy(next_state)
            ########## CORE Algorithm #########
            # No successor value after the terminal transition.
            bootstrap = 0 if done else self.gamma * self.Q[next_state, next_action]
            td_target = reward + bootstrap
            td_error = td_target - self.Q[state, action]
            self.Z[state, action] += 1  # accumulating (not replacing) trace
            self.Q += self.alpha * td_error * self.Z
            self.Z = self.gamma * self.lambd * self.Z
            ###################################
            total_reward += reward
            state, action = next_state, next_action
        return total_reward
# Train Sarsa(lambda) for 50 episodes with the default trace decay, then show V.
s = SarsaLambda(env, num_episodes=50)
s.train()
s.convert_Q_to_V()

array([[ 0,  0, -1, -2],
       [ 0, -1, -1, -1],
       [-1, -2, -1,  0],
       [-2, -1,  0,  0]])

### Sarsa Approximation (Grid World)

In [443]:
import math
# Build a new GridworldEnv instance for this section instead of reusing the
# `env` object that the table-based cells above trained on.
env = GridworldEnv()
class SarsaApproximation(ApproximationBase):
    """Episodic semi-gradient Sarsa on ``GridworldEnv`` with a function approximator.

    Each discrete grid state is encoded as an integer ``(row, col)`` pair before
    being passed to ``approximate_q``, which returns one Q-value per action.
    """

    def __init__(self, env, num_episodes, epsilon=0.01, alpha=0.01, gamma=.9):
        num_state = 2  # the approximator's input is a (row, col) pair
        super().__init__(env,
                         num_state,
                         env.action_space.n,
                         num_episodes,
                         epsilon,
                         alpha,
                         gamma)
        # Side length of a square grid. Kept for backward compatibility only;
        # state decoding below uses env.shape directly, which is exact.
        self.state_rage = env.observation_space.n ** (1 / num_state)
        # Placeholder for a replay buffer; not used by this class. (The original
        # assigned the undefined name `new`, which raised NameError.)
        self.experience_replay = []

    def convert_to_2_dimension_state(self, state):
        """Decode a flat grid index into an integer ``[row, col]`` array (row-major)."""
        return np.array(np.unravel_index(state, self.env.shape), dtype=int)

    def _loop(self, episode) -> int:
        """Run one episode of semi-gradient Sarsa; return the summed reward."""
        policy = self.epsilon_greedy
        done = False
        total_reward = 0
        state = self.convert_to_2_dimension_state(self.env.reset())
        action = policy(state)
        while not done:
            next_state, reward, done, _ = self.env.step(action)
            next_state = self.convert_to_2_dimension_state(next_state)
            next_action = policy(next_state)
            ########## CORE Algorithm #########
            # Semi-gradient: the TD target is a constant, so it is detached from the
            # graph. A terminal transition does not bootstrap, because the
            # approximator's output for a terminal state is not pinned to zero.
            if done:
                td_target = reward
            else:
                td_target = reward + self.gamma * self.approximate_q(next_state)[next_action].detach()
            qs = self.approximate_q(state)
            # The target equals the prediction except at the taken action. Presumably
            # update_weight minimises a distance between the two, so only
            # Q(state, action) is pushed toward td_target (TODO confirm).
            target_qs = qs.detach().clone()
            target_qs[action] = td_target
            self.update_weight(target_qs, qs)
            ###################################
            total_reward += reward
            state, action = next_state, next_action
        return total_reward

    def convert_Q_to_V(self, verbose=True):
        """Return greedy state values ``max_a Q(s, a)`` reshaped to ``env.shape``.

        When ``verbose`` is true (the default), each state's encoded coordinates
        and its action values are printed.
        """
        V = np.zeros(self.env.observation_space.n)
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for state in range(self.env.observation_space.n):
                convert_state = self.convert_to_2_dimension_state(state)
                qs = self.approximate_q(convert_state)
                if verbose:
                    print(convert_state, qs.detach().numpy())
                V[state] = qs.max().item()
        return V.reshape(self.env.shape)
        
# Train the grid-world approximator for 50 episodes, then display the fitted V.
s = SarsaApproximation(env, num_episodes=50)
s.train()
s.convert_Q_to_V()

[0 0] [-0.14976396 -1.3569927  -0.06195487 -0.75239843]
[0 1] [-1.25385   -2.7299783 -1.5095493 -1.0155364]
[0 2] [-2.3397498 -4.6651506 -2.782865  -1.8875178]
[0 3] [-3.4441392 -6.66638   -4.079488  -2.7538733]
[1 0] [-1.0204146 -2.569368  -1.5627735 -2.4981616]
[1 1] [-1.835434  -3.7967498 -2.4308038 -2.7876065]
[1 2] [-2.5968835 -5.521755  -3.4551175 -3.39391  ]
[1 3] [-3.4486918 -7.3800316 -4.618518  -4.201066 ]
[2 0] [-1.9644141 -4.1391964 -3.1750574 -4.685544 ]
[2 1] [-2.575557 -5.295725 -3.727499 -4.803001]
[2 2] [-3.2865708 -6.710038  -4.471779  -5.1598663]
[2 3] [-4.0161624 -8.435356  -5.5039196 -5.780836 ]
[3 0] [-2.8253977 -5.827261  -4.6414165 -6.9541206]
[3 1] [-3.3631766 -6.9390755 -5.1094303 -7.009023 ]
[3 2] [-4.025936 -8.207821 -5.72496  -7.142169]
[3 3] [-4.726742 -9.67179  -6.561225 -7.557959]


array([[-0.06195487, -1.01553643, -1.88751781, -2.75387335],
       [-1.02041459, -1.83543396, -2.59688354, -3.44869184],
       [-1.96441412, -2.57555699, -3.28657079, -4.0161624 ],
       [-2.82539773, -3.36317658, -4.02593613, -4.72674179]])

In [440]:
# Sanity check for the approximator: regress it directly onto a hard-coded
# Q-table (presumably taken from an earlier tabular run) by sampling random
# (state, action) entries, then compare the fitted greedy values.
# NOTE(review): rows 0 and 15 are all zeros -- presumably the terminal corners.
Q = np.array([[ 0.0000,  0.0000,  0.0000,  0.0000],
            [-1.6439, -1.4790, -1.3537, -0.9999],
            [-2.1910, -2.2997, -2.0220, -1.9134],
            [-2.8211, -2.6443, -2.3472, -2.4686],
            [-0.9999, -1.0780, -1.4980, -1.6079],
            [-1.8345, -1.8612, -1.7473, -1.6657],
            [-1.8404, -2.1300, -2.0000, -2.1527],
            [-2.0695, -2.3770, -1.8677, -2.0256],
            [-1.8623, -2.0250, -2.4303, -2.0676],
            [-2.1694, -1.8296, -2.3622, -1.9963],
            [-1.9552, -1.6668, -1.4604, -1.7984],
            [-1.0469, -1.5610, -0.9980, -1.0685],
            [-2.4361, -2.4313, -2.4637, -2.7323],
            [-1.9719, -1.8673, -2.0908, -2.6633],
            [-1.4525, -0.9980, -1.5573, -1.9203],
            [ 0.0000,  0.0000,  0.0000,  0.0000]])
NUM_FIT_STEPS = 500
fit_rng = np.random.RandomState(0)  # local, seeded RNG so the fit is reproducible
s = SarsaApproximation(env, 50)
for _ in range(NUM_FIT_STEPS):
    idx = fit_rng.randint(Q.shape[0])
    action = fit_rng.randint(Q.shape[1])
    # Use the class's own encoding instead of re-deriving (row, col) with magic 4s.
    state = s.convert_to_2_dimension_state(idx)
    td_target = Q[idx, action]
    s.update_weight(td_target, s.approximate_q(state)[action])
s.convert_Q_to_V()

[0 0] [ 0.07427087  0.07534689 -0.24210957  0.0764288 ]
[0 1] [-1.4274087 -1.5826619 -1.6245518 -1.0674704]
[0 2] [-2.0183673 -2.4098647 -2.3626473 -1.8950099]
[0 3] [-2.2041118 -2.88327   -2.7740805 -2.3001413]
[1 0] [-0.85711956 -0.9855417  -1.4769936  -1.1950191 ]
[1 1] [-1.5759901 -1.7121133 -1.8472254 -1.4143223]
[1 2] [-1.4742002 -1.7901657 -1.8355591 -1.3156751]
[1 3] [-1.5201169 -2.13996   -2.0964162 -1.6064342]
[2 0] [-1.880239  -1.9784862 -2.69813   -2.435831 ]
[2 1] [-1.7874537 -1.7661083 -2.2819705 -1.8166974]
[2 2] [-0.97662973 -1.1425152  -1.3656131  -0.8282993 ]
[2 3] [-0.71616554 -1.2160507  -1.2582939  -0.735638  ]
[3 0] [-2.455274  -2.4101024 -3.2449052 -3.1650648]
[3 1] [-1.9490496 -1.9849489 -2.72313   -2.3830712]
[3 2] [-1.054707  -1.2291914 -1.7129636 -1.2794075]
[3 3] [-0.14162564 -0.51212746 -0.7025315  -0.14466754]


array([[ 0.0764288 , -1.06747043, -1.89500988, -2.20411181],
       [-0.85711956, -1.41432226, -1.31567514, -1.52011693],
       [-1.88023901, -1.76610827, -0.82829928, -0.71616554],
       [-2.41010237, -1.94904959, -1.05470705, -0.14162564]])

[CartPole wiki](https://github.com/openai/gym/wiki/CartPole-v0)

### Sarsa Approximation (CartPole / MountainCar)

In [264]:
import gym
# To run on CartPole instead, uncomment the next line and comment out MountainCar.
# env = gym.make('CartPole-v1')
env = gym.make('MountainCar-v0')
# NOTE(review): `random` is already imported in the setup cell; this re-import is redundant.
import random
class SarsaApproximation(ApproximationBase):
    """Episodic semi-gradient Sarsa for gym environments with vector observations.

    The approximator's input size is ``env.observation_space.shape[0]`` and it
    outputs one Q-value per discrete action.

    NOTE(review): this redefines, and therefore shadows, the grid-world
    ``SarsaApproximation`` class defined earlier in the notebook.
    """

    def __init__(self, env, num_episodes, epsilon=0.1, alpha=0.001, gamma=.9):
        super().__init__(env,
                         env.observation_space.shape[0],
                         env.action_space.n,
                         num_episodes,
                         epsilon,
                         alpha,
                         gamma)

    def _loop(self, episode) -> int:
        """Run one episode of semi-gradient Sarsa; return the summed reward."""
        policy = self.epsilon_greedy
        done = False
        total_reward = 0
        state = self.env.reset()
        action = policy(state)
        while not done:
            next_state, reward, done, _ = self.env.step(action)
            next_action = policy(next_state)
            ########## CORE Algorithm #########
            # Same API as the grid-world cells: approximate_q(state) returns a
            # vector of per-action values. NOTE(review): assumes approximate_q
            # accepts the raw numpy observation -- TODO confirm for float inputs.
            # The target is detached (semi-gradient) and not bootstrapped on the
            # terminal transition.
            if done:
                td_target = reward
            else:
                td_target = reward + self.gamma * self.approximate_q(next_state)[next_action].detach()
            qs = self.approximate_q(state)
            target_qs = qs.detach().clone()
            target_qs[action] = td_target
            self.update_weight(target_qs, qs)
            ###################################
            total_reward += reward
            state, action = next_state, next_action
        return total_reward
# Train on the gym environment for 50 episodes; the positional flag is forwarded
# unchanged to ApproximationBase.train (its meaning is defined there).
s = SarsaApproximation(env, num_episodes=50)
s.train(True)

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


  result = entry_point.load(False)


AttributeError: 'numpy.ndarray' object has no attribute 'dim'