In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
import numpy as np
import pygame

import gymnasium as gym
from gymnasium import spaces


class GridWorldEnv(gym.Env):
	metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

	def __init__(self, render_mode=None, size=5):
		self.size = size  # The size of the square grid
		self.window_size = 512  # The size of the PyGame window

		self.observation_space = spaces.Box(0, size - 1, shape=(4,), dtype=int)
		self.action_space = spaces.Discrete(4)

		self._action_to_direction = {
			0: np.array([-1, 0]),	# up
			1: np.array([0, 1]), 	# right
			2: np.array([1, 0]),	# down
			3: np.array([0, -1]),	# left
		}

		assert render_mode is None or render_mode in self.metadata["render_modes"]
		self.render_mode = render_mode

		self.window = None
		self.clock = None

	def _get_obs(self):
		return np.concatenate((self._agent_location, self._target_location))

	def _get_info(self):
		return np.linalg.norm(self._agent_location - self._target_location, ord=1)

	def reset(self, seed=None, options=None):
		# We need the following line to seed self.np_random
		super().reset(seed=seed)

		# Choose the agent's location uniformly at random
		self._agent_location = self.np_random.integers(0, self.size, size=2, dtype=int)

		# We will sample the target's location randomly until it does not coincide with the agent's location
		self._target_location = self._agent_location
		while np.array_equal(self._target_location, self._agent_location):
			self._target_location = self.np_random.integers(
				0, self.size, size=2, dtype=int
			)

		observation = self._get_obs()
		info = self._get_info()

		if self.render_mode == "human":
			self._render_frame()

		return observation, info

	def step(self, action):
		# Map the action (element of {0,1,2,3}) to the direction we walk in
		direction = self._action_to_direction[action]
		# We use `np.clip` to make sure we don't leave the grid
		self._agent_location = np.clip(
			self._agent_location + direction, 0, self.size - 1
		)
		# An episode is done iff the agent has reached the target
		terminated = np.array_equal(self._agent_location, self._target_location)
		reward = 1 if terminated else 0  # Binary sparse rewards
		observation = self._get_obs()
		info = self._get_info()

		if self.render_mode == "human":
			self._render_frame()

		return observation, reward, terminated, False, info

	def render(self):
		if self.render_mode == "rgb_array":
			return self._render_frame()

	def _render_frame(self):
		if self.window is None and self.render_mode == "human":
			pygame.init()
			pygame.display.init()
			self.window = pygame.display.set_mode(
				(self.window_size, self.window_size)
			)
		if self.clock is None and self.render_mode == "human":
			self.clock = pygame.time.Clock()

		canvas = pygame.Surface((self.window_size, self.window_size))
		canvas.fill((255, 255, 255))
		pix_square_size = (
			self.window_size / self.size
		)  # The size of a single grid square in pixels

		# First we draw the target
		pygame.draw.rect(
			canvas,
			(255, 0, 0),
			pygame.Rect(
				pix_square_size * self._target_location,
				(pix_square_size, pix_square_size),
			),
		)
		# Now we draw the agent
		pygame.draw.circle(
			canvas,
			(0, 0, 255),
			(self._agent_location + 0.5) * pix_square_size,
			pix_square_size / 3,
		)

		# Finally, add some gridlines
		for x in range(self.size + 1):
			pygame.draw.line(
				canvas,
				0,
				(0, pix_square_size * x),
				(self.window_size, pix_square_size * x),
				width=3,
			)
			pygame.draw.line(
				canvas,
				0,
				(pix_square_size * x, 0),
				(pix_square_size * x, self.window_size),
				width=3,
			)

		if self.render_mode == "human":
			# The following line copies our drawings from `canvas` to the visible window
			self.window.blit(canvas, canvas.get_rect())
			pygame.event.pump()
			pygame.display.update()

			# We need to ensure that human-rendering occurs at the predefined framerate.
			# The following line will automatically add a delay to keep the framerate stable.
			self.clock.tick(self.metadata["render_fps"])
		else:  # rgb_array
			return np.transpose(
				np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
			)

	def close(self):
		if self.window is not None:
			pygame.display.quit()
			pygame.quit()


In [6]:
env = GridWorldEnv()

In [7]:
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

# Define the policy network
class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(state_size, 32)
        self.fc2 = nn.Linear(32, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.softmax(x, dim=-1)

policy = Policy()


In [8]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(policy.parameters())


In [9]:
def update_policy(rewards, log_probs, optimizer):
    log_probs = torch.stack(log_probs)
    loss = -torch.mean(log_probs * sum(rewards))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


In [14]:
for episode in range(10000):
	state, _ = env.reset()
	done = False
	rewards = []
	log_probs = []
	
	while not done:
		# Select action
		state = torch.tensor(state, dtype=torch.float32).reshape(1, -1)
		probs = policy(state)
		action = torch.multinomial(probs, 1).item()
		log_prob = torch.log(probs[0, action])

		# Take step
		next_state, reward, done, _, _ = env.step(action)
		rewards.append(reward)
		log_probs.append(log_prob)
		state = next_state
		
	# Update policy
	print(f"Episode {episode}: {sum(rewards)}")
	update_policy(rewards, log_probs, optimizer)
	rewards = []
	log_probs = []


Episode 0: 1
Episode 1: 1
Episode 2: 1
Episode 3: 1
Episode 4: 1
Episode 5: 1
Episode 6: 1
Episode 7: 1
Episode 8: 1
Episode 9: 1
Episode 10: 1
Episode 11: 1
Episode 12: 1
Episode 13: 1
Episode 14: 1
Episode 15: 1
Episode 16: 1
Episode 17: 1
Episode 18: 1
Episode 19: 1
Episode 20: 1
Episode 21: 1
Episode 22: 1
Episode 23: 1
Episode 24: 1
Episode 25: 1
Episode 26: 1
Episode 27: 1
Episode 28: 1
Episode 29: 1
Episode 30: 1
Episode 31: 1
Episode 32: 1
Episode 33: 1
Episode 34: 1
Episode 35: 1
Episode 36: 1
Episode 37: 1
Episode 38: 1
Episode 39: 1
Episode 40: 1
Episode 41: 1
Episode 42: 1
Episode 43: 1
Episode 44: 1
Episode 45: 1
Episode 46: 1
Episode 47: 1
Episode 48: 1
Episode 49: 1
Episode 50: 1
Episode 51: 1
Episode 52: 1
Episode 53: 1
Episode 54: 1
Episode 55: 1
Episode 56: 1
Episode 57: 1
Episode 58: 1
Episode 59: 1
Episode 60: 1
Episode 61: 1
Episode 62: 1
Episode 63: 1
Episode 64: 1
Episode 65: 1
Episode 66: 1
Episode 67: 1
Episode 68: 1
Episode 69: 1
Episode 70: 1
Episode 71: 1
Ep

In [22]:
state, _ = env.reset()

print(state)

state = torch.tensor(state, dtype=torch.float32).reshape(1, -1)
probs = policy(state)

action = torch.multinomial(probs, 1).item()

action


[0 3 3 0]


3