In [58]:
import gym
from gym import spaces
import pygame
import numpy as np
from panel import state

class SawWorldEnv(gym.Env):
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, render_mode=None):
        self.window_size = 512  # The size of the PyGame window

        # Observations are dictionaries with the agent's and the target's location.
        # Each location is encoded as an element of {0, ..., `size`}^2, i.e. MultiDiscrete([size, size]).
        
        self.target_location = np.array([20,0, 0,0]) #dmax= 100, thetamax = 90, ddmax = 20, dthetamax = 20
        self.actions = np.array([ [1,1], [1,0], [1,-1],[0,1], [0,0],[0,-1]])
        self.max_values = np.array([23,20,5,5,len(self.actions)])
        self.angle_scale = .3
        self.observation_space = np.array([self.max_values[0],self.max_values[1], self.max_values[2],self.max_values[3]])#d, theta, dd, dtheta, actins
        self.states = self.max_values.copy()
        self.states[0:4] = self.states[0:4]*2+1
        
        

   

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        """
        If human-rendering is used, `self.window` will be a reference
        to the window that we draw to. `self.clock` will be a clock that is used
        to ensure that the environment is rendered at the correct framerate in
        human-mode. They will remain `None` until human-mode is used for the
        first time.
        """
        self.window = None
        self.clock = None


        
        
    def _get_obs(self):
        return self.observation_space
    
    def _get_info(self):
        return -1#{"distance": np.linalg.norm(self._agent_location - self._target_location, ord=1)}
    
    def reset(self, seed=None, options=None):
        #self.close()
        # We need the following line to seed self.np_random
        super().reset(seed=seed)

        # Choose the agent's location uniformly at random
        
        self.observation_space = np.array([self.max_values[0],self.max_values[1], self.max_values[2],self.max_values[3]])#d, theta, dd, dtheta, actins

        # We will sample the target's location randomly until it does not coincide with the agent's location

        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, info
    
    def step(self, action):
        terminated = False
        reward = 0
        
        self.observation_space = self.move_saw(self.observation_space, action)

        terminated, reward = self.find_terminated(self.observation_space)



        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, reward, terminated, False, info
    
    def render(self):
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def _render_frame(self):
        pass
        
    def close(self):
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
    
    def move_saw(self, state, action):

        dd = int(state[2] + self.angle_scale*(state[1]-self.max_values[1])* action[0]) #dd
        dtheta = state[3] + action[1] #dtheta


        if dd > self.max_values[2]:
            state[2] = self.max_values[2]
        elif dd < -self.max_values[2]:
            state[2] = self.max_values[2]
        else:
            state[2] = dd

        if dtheta > self.max_values[3]:
            state[3] = self.max_values[3]
        elif dtheta < -self.max_values[3]:
            state[3] = self.max_values[3]
        else:
            state[3] = dtheta

        state[0] = state[0] + state[2]
        state[1] = state[1] + state[3]
        return state

    def find_terminated(self, state):
        reward = -1
        terminated = False
        
        if state[0] <= 0 or state[0] >= self.states[0]:
            terminated = True
        elif state[1] <= 0 or state[1] >= self.states[1]:
            terminated = True
        elif np.array_equal(state, self.target_location):
            terminated = True
            reward = 0

        return terminated, reward


In [61]:


def evaluate(policy, values ,env):
    gamma = 1
    
    nextValues = np.zeros(np.shape(values))
        
    for i in range(env.states[0]):
        #print(i)
        for j in range(env.states[1]):
            for k in range(env.states[2]):
                for l in range(env.states[3]):
                    state = np.array([i,j,k,l])
                    stateval = 0
                    terminated, reward = env.find_terminated(state)

                    if terminated:
                        break
                    for m in range(0,env.max_values[4]):
                        
                        stateNext = env.move_saw(state.copy(), env.actions[m])
                        terminated, reward = env.find_terminated(stateNext)
                        if terminated:
                            stateNext = state
                        
                        stateval += policy[i][j][k][l][m] * (reward + gamma * values[stateNext[0]][stateNext[1]][stateNext[2]][stateNext[3]])
                        
                    if not terminated:
                        nextValues[i][j][k][l] = stateval
                        
    return nextValues
        

        
def iterate(policy, values, env):
    
    
    
    nextpolicy = np.zeros(np.shape(policy))
    
    for i in range(env.states[0]):
        #print(i)
        for j in range(env.states[1]):
            for k in range(env.states[2]):
                for l in range(env.states[3]):
                    state = np.array([i,j,k,l])
                    stateval = []
                    terminated, reward = env.find_terminated(state)

                    if terminated:
                        break
                    for m in range(0,env.max_values[4]):
                        
                        s = env.move_saw(state.copy(), env.actions[m])
                        terminated, reward = env.find_terminated(s)
                        if terminated:
                            s = state.copy()
                        
                        stateval.append(values[s[0]][s[1]][s[2]][s[3]])
                    #print(stateval)
                        
                    
                    maxval = stateval[np.argmax(stateval)]
                    #print(maxval)
                    #print("duh", state, maxval)
                    #print(allNext)
                    count = (stateval == maxval).sum()
                    for m in range(0,env.max_values[4]):
                        nextpolicy[i][j][k][l][m] = 1/count
    return nextpolicy
                        

def knownGridworld():
    env = SawWorldEnv(render_mode = "rgb_array")


    states = env.states

    values = np.zeros(states[0:4])

    policy = np.ones(states) * 1/6
    #nextvals = evaluate(policy, values, env)
    
    for y in range (10):
        nextvals = np.zeros(states[0:4])
        for x in range(100):
            values = nextvals.copy()
            #print(nextvals)
            nextvals = evaluate(policy, values, env)
            delta = np.sum(np.abs(nextvals - values))
            print("eval: ",delta)
            if delta <.001:
                break
        
        newpolicy = iterate(policy, nextvals, env)
        delpol = np.sum(np.abs(newpolicy - policy))
        print("iter: ",delpol)
        if delpol < .1:
            break
        policy = newpolicy
    
    env.close()
    return newpolicy

def test():
    env = SawWorldEnv(render_mode = "rgb_array")


    states = env.states

    values = np.zeros(states[0:4])

    policy = np.ones(states) * 1/6
    nextvals = evaluate(policy, values, env)
    newpolicy = iterate(policy, nextvals, env)
knownGridworld()
#test()

eval:  160497.5
eval:  135860.72222222222
eval:  113492.62962962959
eval:  93017.10956790124
eval:  74405.30979938273
eval:  57640.06155692727
eval:  42708.26626443185
eval:  30947.63545655672
eval:  22983.595193484685
eval:  17190.689961806074
eval:  12838.79501313814
eval:  9565.395132464902
eval:  7107.000653606765
eval:  5268.748667399196
eval:  3902.2700682836653
eval:  2891.121252523325
eval:  2145.265993194608
eval:  1596.3095020892808
eval:  1192.794613281912
eval:  896.2849948305932
eval:  678.2331069305284
eval:  517.5488656287743
eval:  398.7267659235258
eval:  310.41952578326016
eval:  244.35743543771932
eval:  194.5349295997668
eval:  156.6023484430215
eval:  127.41452717704368
eval:  104.69813708682412
eval:  86.80833120800328
eval:  72.55192310050899
eval:  61.05970805028184
eval:  51.694755072690505
eval:  43.98681670763913
eval:  37.58556062714502
eval:  32.22727887436794
eval:  27.711197414495754
eval:  23.882597454261237
eval:  20.62075764128296
eval:  17.83030427824