In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
import numpy as np

import gymnasium as gym
from gymnasium import spaces

class TerraBot(gym.Env):
	"""Grid world in which an agent walks to a randomly placed target.

	The world is a ``size`` x ``size`` grid. An observation is the integer
	vector formed by concatenating the agent's two coordinates with the
	target's two coordinates. Each of the four actions moves the agent by one
	cell along one axis; moves that would leave the grid are clipped to the
	border. An episode terminates, with reward 1, as soon as the agent stands
	on the target. Every other step yields reward 0.

	Args:
		size: Side length of the square grid. Must be at least 2 so the agent
			and the target can occupy different cells.
		render_mode: Accepted for Gymnasium API compatibility. No renderer is
			implemented, so only ``None`` is supported.

	Raises:
		ValueError: If ``size`` is smaller than 2, or if ``render_mode`` is not
			``None`` and not listed in ``metadata["render_modes"]``.
	"""

	# No rendering backend is implemented for this environment.
	metadata = {"render_modes": []}

	def __init__(self, size=5, render_mode=None):
		if size < 2:
			# With a single cell, reset() could never place the target away
			# from the agent and would loop forever.
			raise ValueError(f"size must be at least 2, got {size}")
		if render_mode is not None and render_mode not in self.metadata["render_modes"]:
			raise ValueError(
				f"Unsupported render_mode {render_mode!r}; this environment implements no renderer"
			)

		self.size = size  # The size of the square grid
		# Kept for backward compatibility; nothing in this class draws a window.
		self.window_size = 512
		self.render_mode = render_mode

		# Observation: agent coordinates followed by target coordinates.
		self.observation_space = spaces.Box(0, size - 1, shape=(4,), dtype=int)
		self.action_space = spaces.Discrete(4)

		# Offset added to the agent's location for each discrete action.
		self._action_to_direction = {
			0: np.array([-1, 0]),	# up
			1: np.array([0, 1]), 	# right
			2: np.array([1, 0]),	# down
			3: np.array([0, -1]),	# left
		}

	def _get_obs(self):
		"""Return the agent location followed by the target location."""
		return np.concatenate((self._agent_location, self._target_location))

	def _get_info(self):
		"""Return the info dict holding the Manhattan distance to the target."""
		# The Gymnasium API requires ``info`` to be a dict, not a bare number.
		return {
			"distance": np.linalg.norm(
				self._agent_location - self._target_location, ord=1
			)
		}

	def reset(self, seed=None, options=None):
		"""Start a new episode with random, distinct agent and target cells.

		Args:
			seed: Optional seed forwarded to ``gym.Env.reset`` to seed
				``self.np_random``.
			options: Unused; accepted for Gymnasium API compatibility.

		Returns:
			A tuple ``(observation, info)``.
		"""
		# We need the following line to seed self.np_random
		super().reset(seed=seed)

		# Choose the agent's location uniformly at random
		self._agent_location = self.np_random.integers(0, self.size, size=2, dtype=int)

		# Sample the target's location until it does not coincide with the agent's
		self._target_location = self._agent_location
		while np.array_equal(self._target_location, self._agent_location):
			self._target_location = self.np_random.integers(
				0, self.size, size=2, dtype=int
			)

		return self._get_obs(), self._get_info()

	def step(self, action):
		"""Move the agent one cell and report the outcome.

		Args:
			action: One of the keys of ``self._action_to_direction`` (0-3).

		Returns:
			A tuple ``(observation, reward, terminated, truncated, info)``.
			``truncated`` is always ``False``; the environment itself imposes
			no time limit.

		Raises:
			KeyError: If ``action`` is not a known action.
		"""
		# Map the action (element of {0,1,2,3}) to the direction we walk in
		direction = self._action_to_direction[action]
		# We use `np.clip` to make sure we don't leave the grid
		self._agent_location = np.clip(
			self._agent_location + direction, 0, self.size - 1
		)
		# An episode is done iff the agent has reached the target
		terminated = np.array_equal(self._agent_location, self._target_location)
		reward = 1 if terminated else 0  # Binary sparse rewards

		return self._get_obs(), reward, terminated, False, self._get_info()


In [4]:
# Instantiate the environment class defined in this notebook (default 5x5 grid).
# `GridWorldEnv` is not defined anywhere here, so the original line raised
# NameError on a fresh kernel.
env = TerraBot()

In [5]:
# Network input/output sizes, read off the environment's spaces.
state_size, action_size = (
    env.observation_space.shape[0],  # length of the observation vector
    env.action_space.n,              # number of discrete actions
)

# Define the policy network
class Policy(nn.Module):
    """Two-layer MLP that maps an observation to action probabilities.

    Args:
        n_inputs: Length of the observation vector. When omitted, the
            notebook-level ``state_size`` is used.
        n_actions: Number of discrete actions. When omitted, the
            notebook-level ``action_size`` is used.
        hidden_size: Width of the single hidden layer.
    """

    def __init__(self, n_inputs=None, n_actions=None, hidden_size=32):
        super().__init__()
        # Fall back to the notebook globals so a bare ``Policy()`` keeps working.
        if n_inputs is None:
            n_inputs = state_size
        if n_actions is None:
            n_actions = action_size
        self.fc1 = nn.Linear(n_inputs, hidden_size)
        self.fc2 = nn.Linear(hidden_size, n_actions)

    def forward(self, x):
        """Return a probability distribution over actions along the last axis."""
        hidden = torch.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return torch.softmax(logits, dim=-1)

# Build the policy network that the training loop below samples actions from.
policy = Policy()


In [6]:
# NOTE(review): loss_fn is not referenced by any cell visible in this notebook;
# update_policy builds its own loss. Confirm nothing else uses it before removing.
loss_fn = nn.CrossEntropyLoss()
# Adam over the policy's parameters; no hyperparameters are passed, so the
# library defaults apply.
optimizer = optim.Adam(policy.parameters())


In [7]:
def update_policy(rewards, log_probs, optimizer, gamma=0.99):
    """Run one REINFORCE gradient step from a single finished episode.

    Each step's log-probability is weighted by the discounted return that
    follows it, ``G_t = r_t + gamma * G_{t+1}``. Weighting every step by the
    plain episode total instead would, when the only reward is a terminal 1,
    give every episode the same weight no matter how many steps it took.
    With ``gamma=1.0`` and a single reward on the final step, the weights
    reduce to that episode total.

    Args:
        rewards: Per-step rewards of the episode, in time order.
        log_probs: Tensors holding the log-probability of each action taken,
            aligned with ``rewards`` and still attached to the autograd graph.
        optimizer: Optimizer over the policy parameters.
        gamma: Discount factor applied to future rewards.

    Raises:
        ValueError: If ``rewards`` and ``log_probs`` differ in length.
    """
    if len(rewards) != len(log_probs):
        raise ValueError(
            f"rewards and log_probs must have the same length, "
            f"got {len(rewards)} and {len(log_probs)}"
        )
    if not log_probs:
        # torch.stack rejects an empty list; an empty episode carries no signal.
        return

    # Accumulate discounted reward-to-go from the last step backwards.
    returns = []
    running = 0.0
    for reward in reversed(rewards):
        running = float(reward) + gamma * running
        returns.append(running)
    returns.reverse()

    # One row per step, so the per-step return broadcasts over that step only.
    stacked = torch.stack(log_probs).reshape(len(returns), -1)
    weights = torch.tensor(
        returns, dtype=stacked.dtype, device=stacked.device
    ).unsqueeze(1)

    loss = -torch.mean(stacked * weights)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


In [9]:
# Safety cap so a degenerate policy cannot keep a single episode running forever.
MAX_STEPS_PER_EPISODE = 1_000

for episode in range(10000):
	state, _ = env.reset()
	done = False
	rewards = []
	log_probs = []

	while not done and len(rewards) < MAX_STEPS_PER_EPISODE:
		# Select action: sample from the policy's distribution for this state
		state_tensor = torch.tensor(state, dtype=torch.float32).reshape(1, -1)
		probs = policy(state_tensor)
		action = torch.multinomial(probs, 1).item()
		log_prob = torch.log(probs[0, action])

		# Take step; the episode ends on termination or truncation
		next_state, reward, terminated, truncated, _ = env.step(action)
		done = terminated or truncated
		rewards.append(reward)
		log_probs.append(log_prob)
		state = next_state

	# Update policy
	if episode % 1000 == 0:
		print(f"Episode {episode}: {sum(rewards)}")
	update_policy(rewards, log_probs, optimizer)


Episode 0: 1
Episode 1000: 1
Episode 2000: 1
Episode 3000: 1
Episode 4000: 1
Episode 5000: 1
Episode 6000: 1
Episode 7000: 1
Episode 8000: 1
Episode 9000: 1


In [12]:
# Sanity check: sample a single action from the policy for a fresh start state.
state, _ = env.reset()
print(state)

# Convert the observation to a float tensor with a leading batch axis.
state = torch.tensor(state, dtype=torch.float32).reshape(1, -1)
probs = policy(state)
action = torch.multinomial(probs, num_samples=1).item()

action


[0 4 4 1]


2