In [None]:
import gym
import numpy as np
from collections import deque
import matplotlib.pyplot as plt

SEED = 0  # fixes the random Q-table initialisation and epsilon-greedy draws
np.set_printoptions(precision=3)
# Legacy global NumPy RNG, used by the np.random.* calls in later cells.
# NOTE(review): the environment's own randomness (e.g. slippery transitions) is
# seeded separately where `env` is created.
np.random.seed(SEED)

# n-step SARSA

In [None]:
def returns(rewards, gamma):
    """Discounted cumulative return of a reward sequence.

    Computes ``rewards[0] + gamma * rewards[1] + gamma**2 * rewards[2] + ...``
    and yields 0 for an empty sequence.
    """
    total = 0
    for step, reward in enumerate(rewards):
        total += (gamma ** step) * reward
    return total
    
class N_SARSA():
    """Tabular n-step SARSA agent with an epsilon-greedy behaviour policy.

    Expected call order (as used by ``episode3``): ``action(s_0)`` once at the
    start of an episode, then for every environment step
    ``observe(r)`` -> ``action(s_next)`` -> ``learn(done)``.

    Q(s_t, a_t) is moved towards the n-step target
    ``G = r_{t+1} + y*r_{t+2} + ... + y**(n-1)*r_{t+n} + y**n * Q(s_{t+n}, a_{t+n})``.
    When the episode ends, every still-pending pair is updated with its
    truncated return (no bootstrap from the terminal state), and the buffers
    are cleared.
    """

    def __init__(self, n=3, choice=None, shape=None, e=0.8, y=0.9, lr=1e-1):
        """
        Args:
            n: number of rewards accumulated before bootstrapping (must be >= 1).
            choice: action set sampled when exploring; defaults to ``[0, 1, 2, 3]``.
                Greedy actions are argmax indices over the last axis of ``q``,
                so this should presumably match ``range(shape[-1])``.
            shape: Q-table shape; defaults to ``(SIZE, SIZE, 4)``, reading a
                global ``SIZE`` expected to be defined in an earlier cell.
                It is looked up when the agent is constructed, not when the class is defined.
            e: exploration rate epsilon.
            y: discount factor gamma.
            lr: learning rate.

        Raises:
            ValueError: if ``n < 1``.
        """
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        self.n = n
        # The oldest (state, action) pair is the one being updated and the
        # newest one bootstraps its target, so n + 1 pairs are kept alongside
        # the n rewards in between.
        self.states = deque(maxlen=n + 1)
        self.actions = deque(maxlen=n + 1)
        self.rewards = deque(maxlen=n)

        self.e = e  # epsilon
        self.y = y  # gamma
        self.lr = lr  # learning rate
        if shape is None:
            shape = (SIZE, SIZE, 4)
        self.q = np.random.randn(*shape)  # q value array
        # Fresh list per instance instead of a shared mutable default.
        self.choice = [0, 1, 2, 3] if choice is None else choice  # action space

        # While True: act greedily and do not update q (evaluation episodes).
        self.test_mode = False

    @staticmethod
    def _key(state, action=None):
        """Return a plain index tuple into ``q`` for ``state`` (plus ``action``).

        ``q[(r, c), a]`` would be NumPy advanced indexing (rows r and c at
        position a). Flattening into ``(r, c, a)`` selects a single entry.
        """
        key = tuple(np.ravel(state))
        return key if action is None else key + (action,)

    def _discounted_return(self, rewards):
        """Return ``sum(y**i * rewards[i])``; computed locally so no global helper is needed."""
        return sum((self.y ** i) * r for i, r in enumerate(rewards))

    def _update(self, state, action, target):
        """Move Q(state, action) a step of size ``lr`` towards ``target``."""
        key = self._key(state, action)
        self.q[key] += self.lr * (target - self.q[key])

    def reset(self):
        """Forget all pending (state, action, reward) transitions."""
        self.states.clear()
        self.actions.clear()
        self.rewards.clear()

    def action(self, state):
        """Pick an epsilon-greedy action for ``state`` and record the pair."""
        if not self.test_mode and np.random.rand() <= self.e:
            action = np.random.choice(self.choice)
        else:
            action = np.argmax(self.q[self._key(state)])
        self.states.append(state)
        self.actions.append(action)
        return action

    def observe(self, reward):
        """Record the reward that followed the most recent action."""
        self.rewards.append(reward)

    def learn(self, done):
        """Apply the n-step SARSA update(s) available after the latest step.

        Args:
            done: True if the step just observed ended the episode.
        """
        if self.test_mode:
            # Evaluation must not change q; only drop this episode's buffers.
            if done:
                self.reset()
            return

        if not done:
            # Until n rewards have been collected there is no full n-step target.
            if len(self.rewards) == self.n:
                bootstrap = self.q[self._key(self.states[-1], self.actions[-1])]
                target = self._discounted_return(self.rewards) + (self.y ** self.n) * bootstrap
                self._update(self.states[0], self.actions[0], target)
            return

        # Episode ended: the terminal state has no future value, so each pending
        # pair i gets the discounted sum of the rewards that followed it.
        rewards = list(self.rewards)
        for i in range(len(rewards)):
            self._update(self.states[i], self.actions[i], self._discounted_return(rewards[i:]))
        self.reset()

In [None]:
def episode3(env, agent, max_steps=None):
    """Run one episode of agent/environment interaction for the n-step SARSA agent.

    At every step the agent observes the reward, chooses its next action for
    the new state, and then learns. This order gives ``agent.learn`` the
    (state, action) pair it bootstraps from.

    Args:
        env: environment using the classic gym API (``reset() -> obs``,
            ``step(a) -> (obs, reward, done, info)``).
        agent: agent exposing ``action``, ``observe``, ``learn`` and the
            ``states``/``actions``/``rewards`` buffers.
        max_steps: step limit for the episode; defaults to the global ``MAX_STEPS``.

    Returns:
        The reward of the final transition if the episode terminated. Returns
        0.0 if it was cut off after ``max_steps`` steps, so callers can always
        average the results.
    """
    if max_steps is None:
        max_steps = MAX_STEPS

    # Drop transitions left over from a previous (possibly truncated) episode so
    # the first n-step update cannot bootstrap across an episode boundary.
    for buffer in (agent.states, agent.actions, agent.rewards):
        buffer.clear()

    action = agent.action(int2loc(env.reset()))

    for _ in range(max_steps):
        # take action & observe
        obs, reward, done, _ = env.step(action)
        agent.observe(reward)

        # choose next action, then update q values
        action = agent.action(int2loc(obs))
        agent.learn(done)

        if done:
            return reward

    return 0.0

In [None]:
EPISODES = 10000
MAX_STEPS = 100  # max steps before terminating an episode

agent = N_SARSA(n=3, e=1, lr=0.8)
returns = []

for i in range(EPISODES):
    
    episode3(env, agent);
    
    if agent.e >= 0.2:
        agent.e *= 0.996
        
    if i % 20 == 0:
        agent.test_mode = True
        r = [episode3(env, agent) for _ in range(5)]
        returns.append(sum(r) / len(r))
        agent.test_mode = False

In [None]:
plt.scatter(range(len(returns)), returns);

In [None]:
# Greedy policy drawn on the map: walkable tiles ('S'tart / 'F'rozen) show the
# argmax action; other tiles (holes, goal) keep their map letter.
arrows = np.array(['←', '↓', '→', '↑'])  # index = action id (FrozenLake: LEFT, DOWN, RIGHT, UP)
# Presumably bytes (b'S', b'F', ...) as in gym's FrozenLake; 'S1' also accepts str maps.
tiles = np.char.decode(np.asarray(env.desc, dtype='S1'))
rows, cols = tiles.shape
greedy = np.argmax(agent.q, axis=2)[:rows, :cols]  # computed once, not per tile
np.where(np.isin(tiles, ['S', 'F']), arrows[greedy], tiles)