In [None]:
# SARSA TAXI
import gym
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import random

# Tabular SARSA on Taxi-v3 with an epsilon-greedy behaviour policy.
# `.env` takes the inner environment of the object returned by gym.make
# (presumably to bypass gym's own step limit -- TODO confirm); the episode
# length is capped by the inner `for t in range(1000)` loop instead.
env = gym.make("Taxi-v3").env

# Action-value estimates keyed by (state, action); unseen pairs start at 0.
Q = defaultdict(lambda: 0)
# Per-pair divisor for the step size 1/n.
# NOTE(review): nothing below ever increments n, so the step size is always 1.
# Looks like a visit counter was intended -- confirm before changing.
n = defaultdict(lambda: 1)

actionspace = range(env.action_space.n)


def greedy_action(s):
    """Return the action with the highest current Q-value in state `s` (first one wins ties)."""
    return max(actionspace, key=lambda a: Q[(s, a)])


epsilon = 0.1   # probability of taking a uniformly random action
gamma = 0.9     # discount factor

episode_scores = []
for _ in range(100000):
    state = env.reset()
    current_score = 0.
    for t in range(1000):
        # Epsilon-greedy choice of the action for the current state.
        if epsilon > random.random():
            action = env.action_space.sample()
        else:
            action = greedy_action(state)

        # SARSA update for the previous transition. `reward` still holds the
        # reward returned by the previous env.step call, and (state, action) is
        # the successor pair that was just selected.
        if t > 0:
            step_size = 1. / n[(prev_state, prev_action)]
            td_error = reward + gamma * Q[(state, action)] - Q[(prev_state, prev_action)]
            Q[(prev_state, prev_action)] += step_size * td_error

        next_state, reward, done, info = env.step(action)
        current_score += reward

        if done:
            # Terminal transition: there is no successor pair, so the target is
            # the reward alone. The pair being updated is the one that was just
            # executed, (state, action) -- not the previous pair, which was
            # already updated above and is undefined when t == 0.
            Q[(state, action)] += 1. / n[(state, action)] * (reward - Q[(state, action)])
            break

        prev_state, state, prev_action = state, next_state, action

    episode_scores.append(current_score)


plt.plot(episode_scores)
plt.xlabel('Episode')
plt.ylabel('Episode Reward')
plt.show()

In [None]:
# SARSA for the Cliff Walking environment
import gym
import numpy as np
import random
import math
from collections import defaultdict, deque
import matplotlib.pyplot as plt

# Module-level Cliff Walking environment used by the cells that follow.
env = gym.make("CliffWalking-v0")

# Returns the updated Q-value for the most recent experience
def update_Q_sarsa(alpha, gamma, Q, state, action, reward, next_state=None, next_action=None):
    """Compute the SARSA-updated value for Q[state][action] without storing it.

    Args:
        alpha: step size applied to the TD error.
        gamma: discount factor applied to the successor value.
        Q: mapping such that Q[s][a] yields the current estimate.
        state, action: the pair whose value is being updated.
        reward: reward observed after taking `action` in `state`.
        next_state, next_action: successor pair used for bootstrapping. When
            `next_state` is None the successor value is taken as 0
            (terminal transition).

    Returns:
        The new estimate; the caller is responsible for writing it into Q.
    """
    old_estimate = Q[state][action]
    if next_state is None:
        bootstrap = 0
    else:
        bootstrap = Q[next_state][next_action]
    td_error = (reward + (gamma * bootstrap)) - old_estimate
    return old_estimate + (alpha * td_error)


# Action based on epsilon
def epsilon_greedy(Q, state, nA, eps):
    """Pick an action for `state` using an epsilon-greedy rule.

    Args:
        Q: mapping such that Q[state] is an array-like of action values.
        state: the state to act in.
        nA: number of available actions; random actions are drawn from
            0 .. nA-1.
        eps: exploration probability.

    Returns:
        The index of the greedy action (np.argmax of Q[state]) with
        probability 1 - eps, otherwise an action drawn uniformly at random.
    """
    if random.random() > eps:  # exploit: greedy action with probability 1 - eps
        return np.argmax(Q[state])
    # explore: use the caller-supplied action count rather than a global env
    return random.choice(np.arange(nA))

In [None]:
# Main SARSA Algorithm for Cliff Walking Environment

def sarsa(env, num_episodes, alpha, gamma=1.0, plot_every=100):
    """Run tabular SARSA on `env`, plot averaged episode scores and return Q.

    Args:
        env: environment exposing `action_space.n`, `reset()` returning the
            initial state, and `step(action)` returning
            (next_state, reward, done, info).
        num_episodes: number of training episodes.
        alpha: step size passed to `update_Q_sarsa`.
        gamma: discount factor passed to `update_Q_sarsa`.
        plot_every: window size for the running average; one averaged point
            is recorded every `plot_every` episodes.

    Returns:
        A defaultdict mapping each visited state to a numpy array of
        per-action value estimates.
    """
    no_actions = env.action_space.n
    Q = defaultdict(lambda: np.zeros(no_actions))

    # monitor performance
    tmp_scores = deque(maxlen=plot_every)    # scores of the most recent episodes
    avg_scores = deque(maxlen=num_episodes)  # one mean per completed window

    for episode in range(1, num_episodes+1):
        if episode % 100 == 0:
            print("\rEpisode {}/{}".format(episode, num_episodes), end="")
        score = 0
        state = env.reset()

        # Exploration rate decays as 1/episode.
        eps = 1.0 / episode
        action = epsilon_greedy(Q, state, no_actions, eps)

        while True:
            next_state, reward, done, info = env.step(action)
            score += reward
            if done:
                # Terminal transition: no successor pair to bootstrap from.
                Q[state][action] = update_Q_sarsa(alpha, gamma, Q, \
                                                  state, action, reward)
                tmp_scores.append(score)
                break

            # On-policy: the action used for bootstrapping is the one
            # that will actually be executed next.
            next_action = epsilon_greedy(Q, next_state, no_actions, eps)
            Q[state][action] = update_Q_sarsa(alpha, gamma, Q, \
                                              state, action, reward, next_state, next_action)

            state = next_state
            action = next_action
        if (episode % plot_every == 0):
            avg_scores.append(np.mean(tmp_scores))

    # plot performance
    plt.plot(np.linspace(0,num_episodes,len(avg_scores),endpoint=False), np.asarray(avg_scores))
    plt.xlabel('Episode Number')
    plt.ylabel('Average Reward (Over Next %d Episodes)' % plot_every)
    plt.show()
    # np.max raises on an empty sequence, which happens when fewer than
    # `plot_every` episodes were run; skip the summary in that case.
    if avg_scores:
        print(('Best Average Reward over %d Episodes: ' % plot_every), np.max(avg_scores))
    return Q

In [None]:
# Train for 5000 episodes with step size 0.01; gamma and plot_every keep their defaults.
Q_sarsa = sarsa(env, num_episodes=5000, alpha=0.01)

In [None]:
# Greedy action for each of the 48 state indices, laid out as a 4 x 12 grid.
# States that never became keys of Q_sarsa are marked with -1.
policy_sarsa = np.reshape(
    [(-1 if cell not in Q_sarsa else np.argmax(Q_sarsa[cell])) for cell in np.arange(48)],
    (4, 12),
)
print("\nEstimated Optimal Policy (UP = 0, RIGHT = 1, DOWN = 2, LEFT = 3, N/A = -1):")
print(policy_sarsa)