# Implementation of the SARSA algorithm
## (State-Action-Reward-State-Action)

Implementation of the SARSA algorithm in tabular form, updating the values according to the following formula:
$$
Q^{\pi}(s,a)=r+\gamma Q^{\pi}(s',a')
$$
where $s'$ and $a'$ are the next state and the next action taken.
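In practice, a tabular update does not overwrite $Q(s,a)$ with the right-hand side above; it moves the stored estimate toward it by a step size $\alpha$ (the `alpha` defined later in this notebook). A minimal sketch of that incremental step, assuming a table stored as a list of lists indexed by integer state and action; the helper name `sarsa_update` and its `done` flag are illustrative and not part of the notebook:

```python
def sarsa_update(q, s, a, r, s_next, a_next, alpha, gamma, done=False):
    """One SARSA step: move q[s][a] toward r + gamma * q[s_next][a_next]."""
    # On a terminal transition there is no next state-action value to bootstrap from.
    target = r if done else r + gamma * q[s_next][a_next]
    q[s][a] += alpha * (target - q[s][a])
    return q[s][a]

# Toy table with 3 states and 2 actions (hypothetical values).
q = [[0.0, 0.0] for _ in range(3)]
q[1][0] = 1.0
sarsa_update(q, s=0, a=1, r=1.0, s_next=1, a_next=0, alpha=0.5, gamma=0.5)
print(q[0][1])  # 0.5 * (1.0 + 0.5 * 1.0 - 0.0) = 0.75
```

With `alpha=1` the update reduces to assigning the target directly, which matches the formula as written.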
An epsilon-greedy strategy is used so that the agent explores the environment at first and then learns from its experience as the randomness controlled by $\epsilon$ is reduced.
The value of $\epsilon$ decays with the number of steps taken in the episode.
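The epsilon-greedy choice and a decaying $\epsilon$ can be sketched as follows. The helper names `decayed_epsilon` and `epsilon_greedy`, and the exponential schedule with its `decay` rate, are assumptions made for illustration; only the start and end values (0.99 and 0.01) come from this notebook.

```python
import random

def decayed_epsilon(step, eps_start=0.99, eps_end=0.01, decay=0.99):
    """Exponentially decayed epsilon, never going below eps_end."""
    return max(eps_end, eps_start * decay ** step)

def epsilon_greedy(q_values, epsilon, rng=random):
    """With probability epsilon pick a random action, otherwise the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    # Greedy action: index of the largest Q-value.
    return max(range(len(q_values)), key=lambda a: q_values[a])

print(decayed_epsilon(0))                        # 0.99
print(decayed_epsilon(1000))                     # 0.01
print(epsilon_greedy([0.1, 0.7], epsilon=0.0))   # 1
```

A slower decay rate keeps the agent exploring for more steps; a faster one makes it act greedily almost immediately.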


In [23]:
import torch
import gym
import numpy as np
import random

In [63]:
# Consider an environment with a discrete action space
env_name='CartPole-v0'
env=gym.make(env_name)
epsilon_start=0.99
epsilon_end=0.01
max_steps=1000
alpha=0.99
gamma=0.99

In [64]:
# First, a tabular implementation of SARSA
class QTable:
    def __init__(self,action_space,observation_space):
        self.device="cuda:0" if torch.cuda.is_available() else "cpu"
        self.action_space=action_space
        self.observation_space=observation_space
        self.observation_size=observation_space.shape[0]
        self.action_size=action_space.n # Table dimensions needed for SARSA
        self.q_table=torch.zeros(self.observation_size,self.action_size)


In [65]:
tabla=QTable(env.action_space,env.observation_space)
tabla.q_table.shape

torch.Size([4, 2])
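CartPole observations are vectors of four continuous values, so a tabular method needs a mapping from an observation to an integer row index. One common approach is to bin each component and combine the bins into a single index. The sketch below is only an illustration: the bin count, the clipping bounds, and the helper name `discretize` are assumptions, not values taken from this notebook.

```python
import numpy as np

# Assumed clipping bounds for (cart position, cart velocity, pole angle, pole angular velocity).
LOW = np.array([-2.4, -3.0, -0.21, -3.0])
HIGH = np.array([2.4, 3.0, 0.21, 3.0])
N_BINS = 6  # assumed number of bins per component

def discretize(obs, low=LOW, high=HIGH, n_bins=N_BINS):
    """Map a continuous observation to one integer in [0, n_bins ** len(obs))."""
    obs = np.clip(np.asarray(obs, dtype=float), low, high)
    # Scale each component to [0, 1], then to a bin index in [0, n_bins - 1].
    ratios = (obs - low) / (high - low)
    bins = np.minimum((ratios * n_bins).astype(int), n_bins - 1)
    # Combine the per-component bins into a single row index.
    return int(np.ravel_multi_index(tuple(bins), (n_bins,) * len(bins)))

n_states = N_BINS ** 4
print(n_states)                          # 1296
print(discretize([0.0, 0.0, 0.0, 0.0]))  # 777
```

Under these assumptions a Q-table would need `n_states` rows and one column per action. More bins give a finer approximation but require more episodes to visit every row.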

# Training the model: computing the Q-values in tabular form

In [73]:
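def state_index(obs):
    # Assumption: the 4-row table is indexed by a coarse discretization of the
    # observation, using the signs of the pole angle (obs[2]) and of the pole
    # angular velocity (obs[3]); this yields a row index in {0, 1, 2, 3}.
    return 2*int(obs[2]>0)+int(obs[3]>0)

for episode in range(max_steps):
    done=False
    state=state_index(env.reset())
    total_reward=0
    action=env.action_space.sample()
    counter=1
    while not done:
        next_obs,reward,done,_=env.step(action)
        next_state=state_index(next_obs)
        # Exponential decay of epsilon with the step counter, floored at epsilon_end
        epsilon=max(epsilon_end,epsilon_start*(1/100**counter))
        if random.random()>epsilon:
            next_action=tabla.q_table[next_state].argmax().item() # Greedy action for the next state
        else:
            next_action=env.action_space.sample() # Random action (epsilon-greedy exploration)
        # SARSA update: move Q(s,a) toward r + gamma*Q(s',a'); no bootstrap on terminal steps
        bootstrap=0.0 if done else tabla.q_table[next_state,next_action].item()
        target=reward+gamma*bootstrap
        tabla.q_table[state,action]+=alpha*(target-tabla.q_table[state,action])
        state=next_state
        action=next_action
        total_reward+=reward
        counter+=1
    if episode % 10 ==0:
        print("Episode: {} Reward: {} Solved: {}".format(episode,total_reward,(total_reward>195.0)))
    if total_reward>195.0:
        print("Solved after {} episodes".format(episode))
        break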
    

Episode: 0 Reward: 10.0 Solved: False
Episode: 10 Reward: 31.0 Solved: False
Episode: 20 Reward: 14.0 Solved: False
Episode: 30 Reward: 41.0 Solved: False
Episode: 40 Reward: 10.0 Solved: False
Episode: 50 Reward: 26.0 Solved: False
Episode: 60 Reward: 11.0 Solved: False
Episode: 70 Reward: 22.0 Solved: False
Episode: 80 Reward: 11.0 Solved: False
Episode: 90 Reward: 23.0 Solved: False
Episode: 100 Reward: 11.0 Solved: False
Episode: 110 Reward: 25.0 Solved: False
Episode: 120 Reward: 12.0 Solved: False
Episode: 130 Reward: 51.0 Solved: False
Episode: 140 Reward: 28.0 Solved: False
Episode: 150 Reward: 10.0 Solved: False
Episode: 160 Reward: 24.0 Solved: False
Episode: 170 Reward: 79.0 Solved: False
Episode: 180 Reward: 9.0 Solved: False
Episode: 190 Reward: 19.0 Solved: False
Episode: 200 Reward: 31.0 Solved: False
Episode: 210 Reward: 22.0 Solved: False
Episode: 220 Reward: 46.0 Solved: False
Episode: 230 Reward: 23.0 Solved: False
Episode: 240 Reward: 10.0 Solved: False
Episode: 250