In [1]:
# Action indices: keys of the inner dicts in GridWorldEnv.T and column
# indices of the solver's q / pi arrays.
LEFT, DOWN, RIGHT, UP = range(4)


class GridWorldEnv:
    """Deterministic 7-state grid world with two absorbing exit states.

    Members
    -------
    s  : the current state (0 after construction or reset()).
    nA : number of actions (4).
    nS : number of states (the number of keys in T).
    T  : transition table, T[s][a] == (next_state, reward, done);
         done=True means the episode is over after this step.

    States 4 and 6 are exits: every action there stays in place, ends the
    episode and pays -5 (state 4) or +5 (state 6). Every other move pays 0.
    """

    def __init__(self):
        # Non-exit states -> destination reached by (LEFT, RIGHT, UP, DOWN).
        neighbours = {
            0: (0, 1, 0, 2),
            1: (0, 1, 1, 3),
            2: (2, 3, 0, 2),
            3: (2, 4, 1, 5),
            5: (5, 6, 3, 5),
        }
        exit_reward = {4: -5, 6: 5}
        # Insertion order of every inner dict is LEFT, RIGHT, UP, DOWN.
        actions = (LEFT, RIGHT, UP, DOWN)

        table = {}
        for state in range(len(neighbours) + len(exit_reward)):
            if state in exit_reward:
                table[state] = {act: (state, exit_reward[state], True) for act in actions}
            else:
                table[state] = {act: (dest, 0, False)
                                for act, dest in zip(actions, neighbours[state])}

        self.T = table
        self.s = 0
        self.nA = 4
        self.nS = len(self.T)

    def reset(self):
        """Move back to the start state 0 and return it."""
        self.s = 0
        return self.s

    def move(self, a):
        """Return (next_state, reward, done) for taking action `a` in self.s.

        self.s is NOT advanced here; the caller assigns the next state itself
        (TD_solver relies on self.s still holding the pre-move state).
        """
        return self.T[self.s][a]


In [2]:
import numpy as np

In [3]:
class TD_solver:
    """Tabular TD control (SARSA or Q-learning) on GridWorldEnv.

    Attributes
    ----------
    estimator : "SARSA" (on-policy) or "QL" (off-policy Q-learning).
    grid      : the GridWorldEnv being solved.
    q         : action-value array, shape [nS, nA].
    pi        : epsilon-greedy policy, shape [nS, nA]; every row sums to 1
                once cal_Pi has run.
    alpha     : step size; starts at 1, multiplied by ALPHA_DECAY every step.
    gamma     : discount factor.
    epsilon   : exploration rate; starts at 1, multiplied by EPSILON_DECAY
                every step.
    """

    ESTIMATORS = ("SARSA", "QL")
    ALPHA_DECAY = 0.9996
    EPSILON_DECAY = 0.9992
    TOL = 1e-10  # "no change" tolerance used by need_termination

    def __init__(self, estimator):
        # Reject typos explicitly: previously any non-"SARSA" string silently ran Q-learning.
        if estimator not in self.ESTIMATORS:
            raise ValueError(
                f"estimator must be one of {self.ESTIMATORS}, got {estimator!r}")
        self.estimator = estimator
        self.grid = GridWorldEnv()
        self.q = np.zeros([self.grid.nS, self.grid.nA])
        self.alpha = 1
        self.gamma = 0.9
        self.epsilon = 1
        self.pi = np.ones([self.grid.nS, self.grid.nA])

    def reset(self):
        """Restore the freshly constructed state (new env, zero q, alpha = epsilon = 1)."""
        self.__init__(self.estimator)

    def cal_Pi(self, epsilon=None):
        """Rebuild self.pi as the epsilon-greedy policy w.r.t. the current q.

        Every action gets epsilon / nA; the greedy action (first argmax of its
        row) additionally gets 1 - epsilon.

        Parameters
        ----------
        epsilon : exploration rate to use; defaults to self.epsilon.
        """
        if epsilon is None:
            epsilon = self.epsilon
        explore = epsilon / self.grid.nA
        greedy = np.argmax(self.q, axis=1)
        self.pi[:] = explore
        self.pi[np.arange(self.q.shape[0]), greedy] = 1 - epsilon + explore

    def pickAction(self, s):
        """Sample an action index for state `s` from the policy row self.pi[s]."""
        return np.random.choice(self.grid.nA, p=self.pi[s])

    def update(self):
        """Decay alpha and epsilon once; solve() calls this after every step."""
        self.alpha *= self.ALPHA_DECAY
        self.epsilon *= self.EPSILON_DECAY

    def need_termination(self, q0, q, optP, optP0):
        """Return True when neither q nor the greedy policy moved by more than TOL."""
        q_stable = np.all(np.abs(q0 - q) < self.TOL)
        policy_stable = np.all(np.abs(optP0 - optP) < self.TOL)
        return bool(q_stable and policy_stable)

    def SARSA_estimator(self, s0, a0):
        """Discounted value of the next (state, action) actually chosen (on-policy)."""
        return self.gamma * self.q[s0, a0]

    def QL_estimator(self, s0):
        """Discounted value of the best action in the next state (off-policy)."""
        return self.gamma * np.max(self.q[s0])

    def _run_episode(self):
        """Play one episode from the start state, updating q, pi, alpha and epsilon every step."""
        self.grid.reset()
        state = self.grid.s
        action = self.pickAction(state)
        done = False
        while not done:
            # GridWorldEnv.move reads but does not advance grid.s, so set it explicitly.
            self.grid.s = state
            next_state, reward, done = self.grid.move(action)
            # Drawn for both estimators (only SARSA's target uses it); it is
            # also the action taken on the next step.
            next_action = self.pickAction(next_state)
            if done:
                # Terminal step: no bootstrap. The value is written to EVERY action
                # of `state`, which assumes all actions there are equivalent (true
                # for GridWorldEnv's exit states 4 and 6) — revisit if the env changes.
                self.q[state, :] = self.q[state, action] + self.alpha * \
                    (reward - self.q[state, action])
            else:
                if self.estimator == "SARSA":
                    target = reward + self.SARSA_estimator(next_state, next_action)
                else:
                    target = reward + self.QL_estimator(next_state)
                self.q[state, action] += self.alpha * (target - self.q[state, action])
            self.cal_Pi(self.epsilon)
            self.update()
            state, action = next_state, next_action

    def solve(self):
        """Train from scratch until q and the greedy policy stop changing.

        Prints the number of episodes played.

        Returns
        -------
        (pi, q, optP): the final epsilon-greedy policy, the action values and
        the greedy action index for each state.

        NOTE(review): stability is judged over a single episode, and q entries
        the episode never visits cannot change. A short episode through
        already-settled pairs can therefore stop training early. TODO consider
        requiring several consecutive stable episodes.
        """
        self.reset()
        self.cal_Pi(self.epsilon)
        q_prev = self.q.copy()
        # Sized from the env (was hard-coded to 7 states).
        policy_prev = np.zeros(self.grid.nS, dtype=int)
        episodes = 0
        terminated = False
        while not terminated:
            episodes += 1
            self._run_episode()
            policy = np.argmax(self.q, axis=1)
            terminated = self.need_termination(q_prev, self.q, policy, policy_prev)
            policy_prev = policy.copy()
            q_prev = self.q.copy()

        print("there are totally ", episodes, "times")

        return self.pi, self.q, policy

    

# test
# Demo: train each estimator several times and print the greedy policy it
# settles on (one action index per state; LEFT=0, DOWN=1, RIGHT=2, UP=3).
# Seed NumPy's global RNG (TD_solver.pickAction samples with np.random.choice)
# so Restart & Run All reproduces the printed results.
np.random.seed(0)

e = 5  # number of independent solve() runs per estimator

print("\nQ-Learning:")
s1 = TD_solver("QL")
for i in range(e):
    # solve() returns (pi, q, optP); index 2 is the greedy policy.
    print("optimal policy is: ", s1.solve()[2], "\n")

print("\nSARSA:")
s2 = TD_solver("SARSA")
for i in range(e):
    print("optimal policy is: ", s2.solve()[2], "\n")



Q-Learning:
there are totally  10 times
optimal policy is:  [0 0 0 0 0 0 0] 

there are totally  37 times
optimal policy is:  [2 1 2 1 0 2 0] 

there are totally  40 times
optimal policy is:  [2 1 2 1 0 2 0] 

there are totally  29 times
optimal policy is:  [1 1 2 1 0 2 0] 

there are totally  32 times
optimal policy is:  [2 1 2 1 0 2 0] 


SARSA:
there are totally  1580 times
optimal policy is:  [1 0 2 1 0 2 0] 

there are totally  2028 times
optimal policy is:  [1 0 2 1 0 2 0] 

there are totally  1116 times
optimal policy is:  [2 1 3 1 0 2 0] 

there are totally  1473 times
optimal policy is:  [1 0 2 1 0 2 0] 

there are totally  4194 times
optimal policy is:  [2 1 3 1 0 2 0] 

