In [None]:
import numpy as np

## SARSA (On-Policy TD Control)

The grid-world environment (`Env`) below is the same as in "2_MC Control agent.ipynb".

In [None]:
class Env:
    """Grid world (5x5 by default) with two penalty cells ("triangles") and one goal.

    States are ``[x, y]`` lists. Actions are ``(dx, dy)`` offsets; ``action_grid``
    lists them in the order U, D, L, R. ``x`` is kept inside ``[0, grid_width - 1]``
    and ``y`` inside ``[0, grid_height - 1]``. Stepping onto a triangle ends the
    episode with reward -1, stepping onto the goal ends it with reward +1, and
    every other move gives reward 0.
    """

    def __init__(self):
        self.grid_width = 5
        self.grid_height = self.grid_width
        self.action_grid = [(-1, 0), (1, 0), (0, -1), (0, 1)]     # U, D, L, R
        self.gtriangle1 = [1, 2]
        self.gtriangle2 = [2, 1]
        self.goal = [2, 2]

    def step(self, state, action):
        """Apply ``action`` to ``state``.

        Args:
            state: current ``[x, y]`` position (not modified).
            action: ``(dx, dy)`` offset, one of ``action_grid``.

        Returns:
            ``(next_state, reward, done)``. Moves that would leave the grid are
            clamped to the nearest edge cell.
        """
        x, y = state
        dx, dy = action
        # Clamp each coordinate against its own axis length: x -> width, y -> height.
        x = min(max(x + dx, 0), self.grid_width - 1)
        y = min(max(y + dy, 0), self.grid_height - 1)
        next_state = [x, y]

        if next_state in (self.gtriangle1, self.gtriangle2):
            return next_state, -1, True
        if next_state == self.goal:
            return next_state, 1, True
        return next_state, 0, False

    def reset(self):
        """Return the fixed start state ``[0, 0]``."""
        return [0, 0]

In [None]:
class SARSA_agent:
    """Tabular SARSA agent with an epsilon-greedy policy on a square grid.

    Q-values are stored in ``Qtable[x, y, action_idx]``, where ``action_idx``
    indexes ``action_grid`` (ordered U, D, L, R to match ``action_text``).

    Args:
        epsilon: exploration probability (stored as ``self.e``).
        learning_rate: step size of the TD update.
        discount_factor: reward discount gamma.
        grid_width: side length of the square grid.
    """

    def __init__(self, epsilon=.1, learning_rate=.01, discount_factor=.95, grid_width=5):
        self.action_grid = [(-1, 0), (1, 0), (0, -1), (0, 1)]
        self.action_text = ['U', 'D', 'L', 'R']
        self.grid_width = grid_width
        self.grid_height = self.grid_width
        self.Qtable = np.zeros((self.grid_width, self.grid_height, len(self.action_grid)))
        self.e = epsilon
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.memory = []  # unused by the methods below; kept in case other cells read it

    def get_action(self, state):
        """Choose an action for ``state`` epsilon-greedily.

        With probability ``self.e`` a uniformly random action is returned.
        Otherwise the greedy action is returned, with ties among maximal
        Q-values broken at random.

        Returns:
            A ``(dx, dy)`` tuple from ``action_grid``.
        """
        # np.random.rand() is uniform on [0, 1), so this fires with probability e.
        # (np.random.randn() is standard-normal: P(randn < 0.1) is about 0.54.)
        if np.random.rand() < self.e:
            idx = np.random.choice(len(self.action_grid))
        else:
            Qvalues = self.Qtable[tuple(state)]
            best = np.flatnonzero(Qvalues == np.amax(Qvalues))
            # Only consume randomness when there is a real tie.
            idx = best[0] if len(best) == 1 else np.random.choice(best)
        return self.action_grid[idx]

    def update(self, state, action, reward, next_state, next_action, done=False):
        """Apply one SARSA (TD(0)) update to Q(state, action).

        The target is ``reward + discount_factor * Q(next_state, next_action)``.
        When ``done`` is True, ``next_state`` is terminal and the bootstrap term
        is dropped, so the target is just ``reward``.
        """
        action_idx = self.action_grid.index(action)
        current_Q = self.Qtable[tuple(state)][action_idx]
        if done:
            target = reward
        else:
            next_action_idx = self.action_grid.index(next_action)
            next_Q = self.Qtable[tuple(next_state)][next_action_idx]
            target = reward + self.discount_factor * next_Q
        self.Qtable[tuple(state)][action_idx] = current_Q + self.learning_rate * (target - current_Q)

    def save_actionseq(self, action_sequence, action):
        """Append the letter ('U', 'D', 'L' or 'R') for ``action`` to ``action_sequence``."""
        idx = self.action_grid.index(action)
        action_sequence.append(self.action_text[idx])

In [None]:
if __name__ == '__main__':
    # Fixed seed so Restart & Run All reproduces the same training run.
    np.random.seed(0)

    env = Env()
    agent = SARSA_agent()
    total_episode = 10000
    sr = 0  # number of episodes that ended on the goal cell

    for episode in range(total_episode):
        action_sequence = []
        total_reward = 0
        walk = 0  # steps taken in this episode

        # SARSA picks the first action before entering the loop.
        state = env.reset()
        action = agent.get_action(state)
        done = False

        while not done:
            agent.save_actionseq(action_sequence, action)

            # Take the action, then choose the next one from the same policy.
            next_state, reward, done = env.step(state, action)
            next_action = agent.get_action(next_state)

            agent.update(state, action, reward, next_state, next_action)

            total_reward += reward
            walk += 1
            # On-policy: the action used in the TD target is the one actually taken next.
            state, action = next_state, next_action

        if episode % 100 == 0:
            print('finished at', state)
            print('episode :{}, The number of step:{}\n The sequence of action is: {}\nThe total reward is: {}\n'.format(
                episode, walk, action_sequence, total_reward))
        if state == env.goal:
            sr += 1

    print('The accuracy :', sr / total_episode * 100, '%')

finished at [1, 2]
episode :0, The number of step:0
 The sequence of action is:                          ['D', 'R', 'R']
The total reward is: -1

finished at [2, 2]
episode :100, The number of step:0
 The sequence of action is:                          ['L', 'U', 'U', 'L', 'U', 'U', 'R', 'L', 'U', 'U', 'U', 'L', 'R', 'R', 'R', 'R', 'D', 'L', 'D', 'D', 'U', 'R', 'U', 'L', 'U', 'U', 'R', 'D', 'R', 'L', 'D', 'L']
The total reward is: 1

finished at [1, 2]
episode :200, The number of step:0
 The sequence of action is:                          ['L', 'U', 'L', 'D', 'R', 'U', 'L', 'U', 'D', 'U', 'D', 'R', 'R']
The total reward is: -1

finished at [2, 2]
episode :300, The number of step:0
 The sequence of action is:                          ['U', 'U', 'U', 'D', 'U', 'R', 'L', 'U', 'R', 'L', 'R', 'L', 'L', 'U', 'L', 'U', 'L', 'U', 'U', 'U', 'R', 'L', 'D', 'U', 'R', 'U', 'L', 'L', 'U', 'R', 'L', 'U', 'U', 'U', 'R', 'L', 'U', 'U', 'U', 'R', 'L', 'U', 'U', 'U', 'U', 'U', 'U', 'U', 'U', 'D', 'D', '

In [None]:
# Learned action values, indexed as Qtable[x, y, action] with actions ordered U, D, L, R.
agent.Qtable[:, :, :]

array([[[-0.03023882, -0.02467009, -0.02837411, -0.06668559],
        [-0.07930537, -0.26357927, -0.02696578, -0.03617608],
        [-0.07255643, -0.9848687 , -0.07104253,  0.21381225],
        [ 0.16530891,  0.27939033, -0.02512164,  0.08539015],
        [ 0.03108897,  0.17120166,  0.0884156 ,  0.03187485]],

       [[-0.02736856,  0.00101901, -0.03508546, -0.28638021],
        [-0.07707683, -0.99483772, -0.03218098, -0.99499104],
        [ 0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.16530639,  0.71922487, -0.97498984,  0.17194065],
        [ 0.05434822,  0.26532044,  0.28438648,  0.11537532]],

       [[-0.01787888,  0.31956485,  0.01884917, -0.99999969],
        [ 0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.28976468,  0.49870141,  1.        ,  0.40044101],
        [ 0.1288828 ,  0.12448834,  0.70403931,  0.30455484]],

       [[ 0.03141885,  0.40551945,  0.31264187,  0.3092279 ],
  

In [None]:
# Q-values at state [0, 1], in action order U, D, L, R.
agent.Qtable[0, 1]

array([-0.07930537, -0.26357927, -0.02696578, -0.03617608])

In [None]:
# Q-values at state [1, 1] (order U, D, L, R). Moving D reaches [2, 1] and
# moving R reaches [1, 2]; both are triangle cells (reward -1, episode ends).
agent.Qtable[1, 1]

array([-0.07707683, -0.99483772, -0.03218098, -0.99499104])

In [None]:
# Q-values at state [1, 0] (order U, D, L, R). Moving R reaches [1, 1].
agent.Qtable[1, 0]

array([-0.02736856,  0.00101901, -0.03508546, -0.28638021])