In [34]:
import sys
sys.path.append("../lib/myenv")
from gridworld import gridworld
import numpy as np 
import pygame
import random
import matplotlib.pyplot as plt 


# sample from categorical distribution
def sample_categorical(probabilities):
    return random.choices(range(len(probabilities)), probabilities)[0]

# gridworld dimension
dim=10
# gamma discounting factor 
gamma=1

# delta for policy evaluation
delta=0

#env variable 
gw=gridworld(dim)

# State value (V) is an array of dimension nxn where n is the gridworld size
V=np.random.rand(dim,dim)
# V=np.zeros((dim,dim))


#Transition matrix
Pss=np.ones((dim,dim))

# The action mapping for human readibility
# 0:right
# 1:Left
# 2:UP
# 3:Down 

# The agent is placed at (0,0) and value function are initiliazed to a zero array
gw.reset()

((0, 0, 9, 9), 'initialisation')

In [35]:
def policy(dim):
    """
    Initial policy with 1/4 chance for each action
    Input: dimension of gridworld
    output: random uniform policy
    """
    pi={}
    for i in range(dim):
        for j in range(dim):
            pi[(i,j)]=[0.25]*4

    return pi

In [36]:
# Initial random uniform policy
pi=policy(dim)
sliced_policy = dict(list(pi.items())[:5])
sliced_policy

{(0, 0): [0.25, 0.25, 0.25, 0.25],
 (0, 1): [0.25, 0.25, 0.25, 0.25],
 (0, 2): [0.25, 0.25, 0.25, 0.25],
 (0, 3): [0.25, 0.25, 0.25, 0.25],
 (0, 4): [0.25, 0.25, 0.25, 0.25]}

In [37]:
def policy_evaluation(pi,V,iteration_num=1):
    delta=0
    #looping over all states
    for _ in range(iteration_num):
        Vc=V.copy()
        for i in range(dim):
            for j in range(dim):
                gw.reset()
                listupdate=[]
                gw.agent=(i,j)

                if (i,j)!=(0,0):
                   for a in range(4):
                        o,r,_,_,_,=gw.step(a)
                        k,l=o[:2]
                        listupdate.append(pi[(i,j)][a]*(r+(gamma*Vc[k,l])))
                    
                    #update Vk+1 ##### the update here max is different from policy iteration (mean
                   Vc[i,j]=np.max(listupdate)
                
                #calculate delta
                # print("state: ",(i,j))
                # print("V",V[i,j])
                # print("Vc",Vc[i,j])
                # print("Vc-V",np.abs(Vc[i,j]-V[i,j]))
                # delta=max(delta,np.abs(V[i,j]-Vc[i,j]))
                # print("delta ",delta)

        
        V=Vc
        # if delta <0.1:
        #     break

    return V

In [38]:
def policy_improvement(Vpi):
# policy iteration 
    for i in range(dim):
                for j in range(dim):
                    listofall=[-100]*4
                    r=-1
                    if (i,j)==gw.target:
                        r=0
                    
                    # listofall=[Vpi[i+1,j],Vpi[i-1,j],Vpi[i,j-1],Vpi[i,j+1]]
                    if i+1<dim:
                        listofall[0]=pi[(i,j)][0]*(r+(gamma*Vpi[i+1,j]))
                    if j+1<dim:
                        listofall[3]=pi[(i,j)][3]*(r+(gamma*Vpi[i,j+1])) 
                    if i-1>dim:
                        listofall[1]=pi[(i,j)][1]*(r+(gamma*Vpi[i-1,j]))
                    if j-1>dim:
                        listofall[2]=pi[(i,j)][1]*(r+(gamma*Vpi[i,j-1]))


                    a=np.argmax(listofall)

                    pi[(i,j)]=[1 if i == a else 0 for i in range(4)]

    return pi

In [39]:
# value iteration algo 
Vpi=policy_evaluation(pi,V,iteration_num=3)
pi=policy_improvement(Vpi)


In [40]:
# The loop below will test the policy iteration algorithm
# start position
gw.reset()
o=(0,0,dim-1,dim-1)
terminated=False
rewards=[]
while not terminated:

    a=sample_categorical(pi[o[:2]])
    print(a)
    o,r,terminated,_,_,=gw.step(a)
    rewards.append(r)
    gw.render(np.round(Vpi, 3),mode='human')
    print(o[:2])
    print("reward is: ",r)

print("total rewards: ", np.sum(rewards))

pygame.quit()

0
(1, 0)
reward is:  -1
3
(1, 1)
reward is:  -1
0
(2, 1)
reward is:  -1
0
(3, 1)
reward is:  -1
0
(4, 1)
reward is:  -1
0
(5, 1)
reward is:  -1
0
(6, 1)
reward is:  -1
0
(7, 1)
reward is:  -1
0
(8, 1)
reward is:  -1
3
(8, 2)
reward is:  -1
0
(9, 2)
reward is:  -1
3
(9, 3)
reward is:  -1
3
(9, 4)
reward is:  -1
3
(9, 5)
reward is:  -1
3
(9, 6)
reward is:  -1
3
(9, 7)
reward is:  -1
3
(9, 8)
reward is:  -1
3
(9, 9)
reward is:  0
total rewards:  -17
