<a href="https://colab.research.google.com/github/orattanathon/RL/blob/main/rescueWombat_env.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys
import numpy as np
import gym
from gym import spaces
from contextlib import closing
from io import StringIO
from gym import utils
import time
from gym.envs.toy_text import discrete
from gym.envs.registration import register


MAP = [
        "SPPFP",
        "PBPFP",
        "PPWBF",
        "PBPPP",
        "PFPPG"
       ]


In [22]:
env_dict = gym.envs.registration.registry.env_specs.copy()
for env in env_dict:
    if 'rescueWombat-v0' in env:
        print("Remove {} from registry".format(env))
        del gym.envs.registration.registry.env_specs[env]


Remove rescueWombat-v0 from registry


In [24]:
class rescueWombat(discrete.DiscreteEnv):
 
     metadata = {'render.modes': ['console']}
     def __init__(self):

        self.desc = np.asarray(MAP, dtype='c')
        num_states = 50
        num_actions = 5
        nrow = 5
        ncol = 5
        max_row = nrow - 1
        max_col = ncol - 1
        picked = False

        isd = np.array(self.desc == b'S').astype('float64').ravel()

        P = {state: {action: []
                     for action in range(num_actions)} for state in range(num_states)}

        def s_encode(row,col,picked):
          if not picked:
             i = row
             i *= ncol
             i += col
          else:
             i = row
             i *= ncol
             i += col
             i = i*2
          return i

        for row in range(nrow):
            for col in range(ncol):
                    state = s_encode(row,col,picked)

                    if self.desc[row,col] == b'S':
                        isd[state] += 1

                        for action in range(num_actions):
                            new_row, new_col = row, col 
                            reward = -1 #default reward for every action taken
                            done = False
                            if not picked:
                               if action == 0: #DOWN
                                  if self.desc[row + 1, max_row] in b'GP':
                                     new_row = min(row + 1, max_row)
                                  elif self.desc[row + 1, max_row] in b'B':
                                     new_row = row #can't move
                               if action == 1: #UP
                                  if self.desc[row - 1, 0] in b'GP':
                                     new_row = max(row - 1, 0)
                                  elif self.desc[row - 1, 0] in b'B':
                                     new_row = row  #can't move
                               if action == 2: #RIGHT
                                  if self.desc[col + 1, max_col] in b'GP':
                                    new_col = min(col + 1, max_col)
                                  elif self.desc[col + 1, max_col] in b'B':
                                     new_col = col #can't move
                               if action == 3: #LEFT
                                  if self.desc[col + 1, max_col] in b'GP':
                                    new_col = min(col + 1, max_col)
                                  elif self.desc[col + 1, max_col]in b'B':
                                     new_col = col
                               if action == 4: #PICKUP
                                  if self.desc[row, col] != b'W': 
                                     print ('Wombat is not here')
                                  elif self.desc[max_row, col] == b'W':
                                     reward = 20  
                                     self.picked == True
                           
                            if picked: 
                                if action == 0:  #DOWN
                                  if self.desc[row + 1, max_row] in b'GP':
                                     new_row = min(row + 1, max_row)
                                  elif self.desc[row + 1, max_row] in b'B':
                                     new_row = row
                                if action == 1: #UP
                                   if self.desc[row - 1, 0] in b'GP':
                                     new_row = max(row - 1, 0)
                                   elif self.desc[row - 1, 0] in b'B':
                                     new_row = row
                                if action == 2: #RIGHT
                                  if self.desc[col + 1, max_col] in b'GP':
                                    new_col = min(col + 1, max_col)
                                  elif self.desc[col + 1, max_col] in b'B':
                                     new_col = col
                                if action == 3:  #LEFT
                                  if self.desc[col + 1, max_col] in b'GP':
                                    new_col = min(col + 1, max_col)
                                  elif self.desc[col + 1, max_col]in b'B':
                                     new_col = col
                                if action == 4:  #PICKUP
                                    print ('Wombat has been picked up, find the goal')
                          
                            if self.desc[new_row,new_col] == b'F': #walk into the fire
                               reward = -30
                               done = True  #exit episode...failed
                            if self.desc[new_row,new_col] == b'G' and not picked:
                               reward = -30  #reach goal before picking up the wombat
                            elif self.desc[new_row,new_col] == b'G' and picked:
                               reward = 30  #Success
                               done = True
                            new_state = s_encode(new_row, new_col,picked)
                            P[state][action].append((1.0, new_state, reward, done))
                  
        isd /= isd.sum()
        discrete.DiscreteEnv.__init__(self, num_states, num_actions, P, isd)

    

register(
    id='rescueWombat-v0',
    entry_point=f"{__name__}:rescueWombat"
)

In [25]:
env = gym.make('rescueWombat-v0')

Wombat is not here


In [26]:
env.reset()

action_size = env.action_space.n
print("Action size ", action_size)

state_size = env.observation_space.n
print("State size ", state_size)

Action size  5
State size  50


In [6]:
print(env.action_space)
print(env.observation_space)

print(env.desc)


Discrete(5)
Discrete(50)
[[b'S' b'P' b'P' b'F' b'P']
 [b'P' b'B' b'P' b'F' b'P']
 [b'P' b'P' b'W' b'B' b'F']
 [b'P' b'B' b'P' b'P' b'P']
 [b'P' b'F' b'P' b'P' b'G']]


In [28]:
class sarsa_policy:
    def __init__(self, epsilon, alpha, gamma, num_state, num_actions, action_space):
        self.num_actions = num_actions
        self.num_states = num_states
        self.epsilon = epsilon #exploration degree
        self.alpha = alpha  #learn rate
        self.gamma = gamma  #discount
        self.Q = np.zeros((self.num_states, self.num_actions)) #Q_Table
        self.action_space = action_space

        self.last_action = None
        self.last_state = None
        
        def choose_action(self, state): #policy
            
           exp_tradeoff = random.uniform(0, 1) #select a random number
 
           if exp_tradeoff < self.epsilon:   #exploration
              action = self.action_space.sample()
           else:
              action = np.argmax(self.Q[state, :]) #greedy action
           return action

        def update(self, prev_state, next_state, reward, prev_action, next_action): #update Qtable

           predict = self.Q[prev_state, prev_action]
           target = reward + self.gamma * self.Q[next_state, next_action]
           self.Q[prev_state, prev_action] += self.alpha * (target - predict)

           total_episodes = 100
           i = 1
        for i in range(total_episodes):
           prev_state = env.reset()
           prev_action = choose_action(prev_state)
           rewards = []
           while True:
        
               next_state, reward, done, info = env.step(prev_action)
               next_action = agent.choose_action(new_state) 
               agent.update(prev_state, new_state, reward, prev_action, next_action) 
      
               prev_state = new_state 
               prev_action = next_action
               rewards.append(reward)
               i += i
               if done: 
                  cumulativeReward = sum(reward)
                  break