# Temporal Difference Control: SARSA
An implementation of SARSA control using a gridworld.

The code is intentionally not optimal in order to increase legibility and make it easier to understand(TD(0) update is not performed "online" but after finishing the episode)

The gridworld has the shape(3,4) with a winning state "w"(0,3), and a lossing state "l"(1,3), a non valid state "x"(2,1) and a start state s(3,0)

|  |  |  |  |
|---|---|---|---|
|  |  |  | w |
|  |  |  | l |
|  | x |  |  |
| s |  |  |  |

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

import grid_world

### Disccount factor and step size

In [2]:

GAMMA = 0.9
ALPHA =0.1

### Auxiliary function to display the values of a policy after finishing iterative policy evaluation

In [3]:
def print_values(V,grid):
    for i in range(grid.width):
        print("--------------------------")
        for j in range(grid.height):
            v = V.get((i,j),0)
            if v >= 0:
                print(" %.2f|" % v, end="")
            else:
                print("%.2f|" % v, end="")
        print("")

### Auxiliary function to display a stochastic policy

In [4]:
def print_policy(P,grid):
    for i in range(grid.width):
        print("---------------------------")
        for j in range(grid.height):
            a = P.get((i,j),' ')
            if isinstance(a,dict):
                a = list(a)[0]
            print("  %s  |" % a, end="")
        print("")

### From or defined grid world file, import a negative grid ,retrieve all actions and states and print grid rewards
Negative grid is used to encourage the agent to find a shortest path to the goal

In [5]:
grid = grid_world.Grid.standard_grid()
states = grid.all_states()
actions = list(set([action   for action_tup in grid.actions.values() for action in action_tup]))

In [6]:
def argmax_dict(dictionary):
    # returns the argmax key and the max value from a dictionary
    # will be used for policy improvement from Q
    max_key = None
    max_val = float("-inf")
    
    for k,v in dictionary.items():
        if v > max_val:
            max_val = v
            max_key = k
            
    return max_key,max_val
        
argmax_dict({"a":1,"b":2})

('b', 2)

In [7]:
actions

['D', 'U', 'R', 'L']

In [8]:
def epsilon_greedy_action(Q,state,epsilon=0.1):
    # choose an action using epsilon-greedy strategy
    probability = np.random.random()
    result = 0
    
    if probability < epsilon:
        #explore
        result = np.random.choice(actions)
    else: 
        #exploit
        result = argmax_dict(Q[state])[0]
        
    return result

In [9]:
grid = grid_world.Grid.standard_grid()

In [10]:
print("Rewards of grid")
print_values(grid.rewards,grid)

Rewards of grid
--------------------------
 0.00| 0.00| 0.00| 1.00|
--------------------------
 0.00| 0.00| 0.00|-1.00|
--------------------------
 0.00| 0.00| 0.00| 0.00|


### Initialize  policy

In [11]:
policy = {(2,0):'U',
         (1,0):'U',
         (0,0):'R',
         (0,1):'R',
         (0,2):'R',
         (1,2):'R',
         (2,1):'R',
         (2,2):'R',
         (2,3):'U'}

print_policy(policy,grid)

---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  R  |     |
---------------------------
  U  |  R  |  R  |  U  |


In [12]:
def sarsa_control(grid,policy,episodes,gamma =1,alpha=1):
    Q = dict()
    
    for state in policy.keys():
        Q[state] = dict()
        for action in actions:
            Q[state][action] = np.random.rand()
            

    for episode in range(1,episodes +1):
        epsilon  = 0.5/episode
        finished = False
        
        s = (2,0)
        grid.set_state(s)
        a = epsilon_greedy_action(Q,s,epsilon)
        
        while not finished:
            r = grid.move(a)
            
            if grid.game_over():
                finished = True
                continue
                
            s1 = grid.current_state()
            a1 =epsilon_greedy_action(Q,s1,epsilon)
            
            
            Q[s][a] = Q[s][a] + alpha*(r +(gamma*Q[s1][a1]) - Q[s][a])
            
            s = s1
            a = a1
            
    for s in policy.keys():
        state_greedy_action = argmax_dict(Q[s])
        
        if not state_greedy_action[0] is None:
            policy[s] = state_greedy_action[0]
            
    return policy,Q

policy,Q = sarsa_control(grid,policy,20000,GAMMA,ALPHA)

In [13]:
policy

{(0, 0): 'R',
 (0, 1): 'R',
 (0, 2): 'R',
 (1, 0): 'U',
 (1, 2): 'U',
 (2, 0): 'U',
 (2, 1): 'R',
 (2, 2): 'U',
 (2, 3): 'L'}

In [14]:
print("Policy")
print_policy(policy,grid)

Policy
---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  U  |     |
---------------------------
  U  |  R  |  U  |  L  |


In [15]:
V = defaultdict(lambda:0)
for state in policy.keys():
    V[state] = Q[state][policy[state]]
    print("state  |policy action| state value")
    print(state,"|      ",policy[state] , "    |", V[state] )
    

state  |policy action| state value
(2, 0) |       U     | 0.2537105974842908
state  |policy action| state value
(1, 0) |       U     | 0.28190066387143453
state  |policy action| state value
(0, 0) |       R     | 0.31322295985714976
state  |policy action| state value
(0, 1) |       R     | 0.3480255109523889
state  |policy action| state value
(0, 2) |       R     | 0.38669501216932123
state  |policy action| state value
(1, 2) |       U     | 0.3480255109523889
state  |policy action| state value
(2, 1) |       R     | 0.28190066387143453
state  |policy action| state value
(2, 2) |       U     | 0.31322295985714976
state  |policy action| state value
(2, 3) |       L     | 0.31054643107300195


In [16]:
print_values(grid=grid,V=V)

--------------------------
 0.31| 0.35| 0.39| 0.00|
--------------------------
 0.28| 0.00| 0.35| 0.00|
--------------------------
 0.25| 0.28| 0.31| 0.31|


## Conclusions
* Sarsa contol can be used to  to "online" learning from experience(monte carlo requires end of episode to perform learning)
* it required many more iterations to converge(every update implies a more granular observation of state,action)
* it found a good policy