In [44]:
# __future__ import should always be first
from __future__ import annotations

# Standard library imports
from collections import defaultdict

# Third-party imports
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam

import numpy as np


# Gymnasium & Minigrid imports
import gymnasium as gym  # Correct way to import Gymnasium
from gymnasium.spaces import Dict, Discrete, Box
from minigrid.core.constants import COLOR_NAMES
from minigrid.core.constants import DIR_TO_VEC
from minigrid.core.grid import Grid
from minigrid.core.actions import Actions
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Door, Goal, Key, Wall
from minigrid.manual_control import ManualControl
from minigrid.minigrid_env import MiniGridEnv
from gymnasium.utils.play import play
import pandas as pd
# Visualization imports
import matplotlib.pyplot as plt


In [113]:

class SimpleEnv(MiniGridEnv):
    def __init__(
            self, 
            size=10, 
            agent_start_pos=(1, 8), 
            agent_start_dir=0, 
            max_steps=1000, 
            **kwargs,
    ):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        self.goal_pos = (8, 1)
        
        
        
        mission_space = MissionSpace(mission_func=self._gen_mission)

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            max_steps=max_steps,
            **kwargs,
        )

        self.action_space = gym.spaces.Discrete(3)
    @staticmethod
    def _gen_mission():
        return "Find the shortest path"

    def _gen_grid(self, width, height):
        #create gird
        self.grid = Grid(width, height)
        #place barrier
        self.grid.wall_rect(0, 0, width, height)
        #place goal
        self.put_obj(Goal(), 8, 1)
        #place walls
        for i in range(1, width // 2):
            self.grid.set(i, width - 4, Wall())
            self.grid.set(i + width // 2 - 1, width - 7, Wall())
        #place agent
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos #check this
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = "find the shortest path"
    
    def count_states(self):
        free_cells = sum(1 for x in range(self.grid.width)
                      for y in range(self.grid.height)
                      if not self.grid.get(x, y)) * 4
        return free_cells 


In [114]:
env = SimpleEnv(render_mode= None)
env.reset();

In [20]:
action = 0
obs, _, _, _, _ = env.step(action)
print("Obs shape after step:", obs["image"].shape)


Obs shape after step: (7, 7, 3)


In [21]:
action = 0
obs1, _, _, _, _ = env.step(action)
print("Obs shape after step:", obs["image"].shape)

Obs shape after step: (7, 7, 3)


In [22]:
obs["image"] == obs1["image"]

array([[[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True]],

       [[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True]],

       [[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [False, False,  True],
        [ True,  True,  True],
        [ True,  True,  True]],

       [[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [False, False,  True],
        [False, False,  True],
        [ True,  True,  True]],

       [[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],


In [18]:
env.observation_space

Dict('direction': Discrete(4), 'image': Box(0, 255, (7, 7, 3), uint8), 'mission': MissionSpace(<function SimpleEnv._gen_mission at 0x16c611440>, None))

In [None]:
env.observation_space.shape[0]

# Wrappers

In [47]:
class MiniGridFlatImg(gym.ObservationWrapper):
    """
    Keep only the 7x7 RGB image from a MiniGrid Dict observation.
    Output: 147-dim float32 vector in [0, 1].
    """
    def __init__(self, env):
        # initialise the parent ObservationWrapper so it can do its bookkeeping
        super().__init__(env)

        img_size = np.prod(env.observation_space["image"].shape)   # 7*7*3 = 147
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(img_size,), dtype=np.float32
        )

    def observation(self, obs):
        img_flat = obs["image"].astype(np.float32).flatten() / 255.0
        return img_flat

In [48]:
class MiniGridReward(gym.Wrapper):
    def __init__(self, env, goal_states):
        super().__init__(env)
        self.goal_states = set(goal_states)

    def step(self, action):
        obs, _, terminated, truncated, info = self.env.step(action)

        # access agent position after the transition
        x, y = self.env.unwrapped.agent_pos
        next_state = (x, y)

        rew = 0 if next_state in self.goal_states else -1
        done = terminated or truncated or (rew == 0)

        return obs, rew, done, truncated, info
        


In [115]:
env_wrapped = MiniGridFlatImg(env)

In [26]:
env_wrapped.reset();

In [28]:
action = 0
obs, _, _, _, _ = env_wrapped.step(action)


In [32]:
obs.shape

(147,)

In [33]:
action = 0
obs1, _, _, _, _ = env_wrapped.step(action)

In [36]:
env_wrapped.observation_space

Box(0.0, 1.0, (147,), float32)

In [37]:
env_wrapped.observation_space.shape[0]

147

In [116]:
env_wrapped_rew= MiniGridReward(env_wrapped, goal_states = [(8, 1)])

In [None]:
env_wrapped_rew

In [77]:
obs, info = env_wrapped_rew.reset()

In [78]:
print(type(obs), obs.shape) 

AttributeError: 'tuple' object has no attribute 'shape'

In [76]:
len(obs)

147

# Vanilla MLP

In [51]:
def mlp(sizes, activation = nn.Tanh, output_activation = nn.Identity):
    layers = []
    for i in range(len(sizes) - 1):
        act = activation if i < len(sizes) -2 else output_activation #everything but last layer has activation, outherwise output
        layers += [nn.Linear(sizes [i], sizes[i +1], act())]
    return nn.Sequential(*layers)


In [117]:
def train(env_name = env_wrapped_rew, hidden_sizes = [32], lr = 1e-2, epochs = 50, batch_size = 50, render = False):

    env = env_name

    obs_dim = env.observation_space.shape[0] 
    # print("obs_dim:", obs_dim)
    n_acts = env.action_space.n
    # print("n_acts:", n_acts)



    #generate polucy network
    logits_net = mlp(sizes = [obs_dim] + hidden_sizes + [n_acts])

    #takes policy network and returns action distribution
    def get_policy(obs):
        logits = logits_net(obs)
        return Categorical(logits = logits)

    #samples actions from the action distrubution from the policy network
    def get_action(obs):
        return get_policy(obs).sample().item()
    

    # make loss function whose gradient, for the right data, is policy gradient
    def compute_loss(obs, act, weights):
        logp = get_policy(obs).log_prob(act)
        return -(logp * weights).mean()

      # make optimizer
    optimizer = Adam(logits_net.parameters(), lr=lr)

    def train_one_epoch():
        # make some empty lists for logging.
        batch_obs = []          # for observations
        batch_acts = []         # for actions
        batch_weights = []      # for R(tau) weighting in policy gradient
        batch_rets = []         # for measuring episode returns
        batch_lens = []         # for measuring episode lengths

        # reset episode-specific variables
        obs, info = env.reset()      # first obs comes from starting distribution
        # print("obs shape:", obs.shape, "obs type:", type(obs))
        done = False            # signal from environment that episode is over
        ep_rews = []            # list for rewards accrued throughout ep

        # render first episode of each epoch
        finished_rendering_this_epoch = False

        # collect experience by acting in the environment with current policy
        while True:

            # rendering
            if (not finished_rendering_this_epoch) and render:
                env.render()
            # print("obs shape:", obs.shape, "obs type:", type(obs))
            # save obs
            batch_obs.append(obs.copy())

            # act in the environment
            act = get_action(torch.as_tensor(obs, dtype=torch.float32))
            obs, rew, done, _, _ = env.step(act)
            

            # save action, reward
            batch_acts.append(act)
            ep_rews.append(rew)

            if done:
                # if episode is over, record info about episode
                ep_ret, ep_len = sum(ep_rews), len(ep_rews)
                batch_rets.append(ep_ret)
                batch_lens.append(ep_len)

                # the weight for each logprob(a|s) is R(tau)
                batch_weights += [ep_ret] * ep_len     #why is this the way the setup is, this is where i want to add rewards

                # reset episode-specific variables
                obs, info  = env.reset()
                done = False
                ep_rews = []

                # won't render again this epoch
                finished_rendering_this_epoch = True

                # end experience loop if we have enough of it
                if len(batch_obs) > batch_size:
                    break
        # take a single policy gradient update step
        optimizer.zero_grad()
        batch_loss = compute_loss(obs=torch.as_tensor(batch_obs, dtype=torch.float32),
                                  act=torch.as_tensor(batch_acts, dtype=torch.int32),
                                  weights=torch.as_tensor(batch_weights, dtype=torch.float32)
                                  )
        batch_loss.backward()
        optimizer.step()
        return batch_loss, batch_rets, batch_lens

    # training loop
    for i in range(epochs):
        batch_loss, batch_rets, batch_lens = train_one_epoch()
        print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
                (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))
        
    return logits_net





In [63]:
env_wrapped_rew.observation_space.shape

(147,)

In [120]:
policy = train(env_name = env_wrapped_rew, hidden_sizes = [32], lr = 1e-2, epochs = 1500, batch_size = 5000, render = False)

epoch:   0 	 loss: -916.581 	 return: -713.375 	 ep_len: 713.875
epoch:   1 	 loss: -1083.559 	 return: -1000.000 	 ep_len: 1000.000
epoch:   2 	 loss: -1034.143 	 return: -902.167 	 ep_len: 902.333
epoch:   3 	 loss: -951.360 	 return: -790.000 	 ep_len: 790.429
epoch:   4 	 loss: -944.680 	 return: -800.714 	 ep_len: 801.143
epoch:   5 	 loss: -923.436 	 return: -827.000 	 ep_len: 827.571
epoch:   6 	 loss: -798.241 	 return: -642.667 	 ep_len: 643.333
epoch:   7 	 loss: -1023.132 	 return: -893.667 	 ep_len: 893.833
epoch:   8 	 loss: -1009.061 	 return: -843.857 	 ep_len: 844.143
epoch:   9 	 loss: -985.804 	 return: -834.714 	 ep_len: 835.143
epoch:  10 	 loss: -1008.750 	 return: -933.333 	 ep_len: 933.500
epoch:  11 	 loss: -1024.364 	 return: -955.833 	 ep_len: 956.000
epoch:  12 	 loss: -974.409 	 return: -893.500 	 ep_len: 894.000
epoch:  13 	 loss: -1003.632 	 return: -945.333 	 ep_len: 945.667
epoch:  14 	 loss: -927.076 	 return: -841.667 	 ep_len: 842.167
epoch:  15 	 los

In [121]:
policy

Sequential(
  (0): Linear(in_features=147, out_features=32, bias=True)
  (1): Linear(in_features=32, out_features=3, bias=True)
)

In [127]:
5000/36

138.88888888888889

In [139]:
class RandomStart(gym.Wrapper):
    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)   # ← still a dict here
        base = self.unwrapped                  # MiniGridEnv

        # sample a free floor tile
        while True:
            x = base.np_random.integers(1, base.width  - 1)
            y = base.np_random.integers(1, base.height - 1)
            if base.grid.get(x, y) is None:
                base.agent_pos = (x, y)
                base.agent_dir = base.np_random.integers(0, 4)
                break

        # regenerate dict-obs; *do not* flatten here
        obs = base.gen_obs()
        return obs, info

In [164]:
env_base   = LargeSimpleEnv(render_mode="human")
env_rs     = RandomStart(env_base)       # randomise first
env_flat   = MiniGridFlatImg(env_rs)     # then flatten
env_wr     = MiniGridReward(env_flat, goal_states=[(8, 1)])

In [122]:
#if you want to always start at the behinning at the start of each episode
# # create the base env *with* a render mode
# env_base = SimpleEnv(render_mode="human")      # window pops up
# # or render_mode="rgb_array"  # returns an image you can display in a notebook

# # wrap exactly as before
# env_flat_vis  = MiniGridFlatImg(env_base)
# env_wrapped_rew_vis = MiniGridReward(env_flat_vis, goal_states=[(8, 1)])

In [165]:
#takes policy network and returns action distribution
def get_policy(obs):
    logits = policy(obs)
    return Categorical(logits = logits)

#samples actions from the action distrubution from the policy network
def get_action(obs):
    return get_policy(obs).sample().item()



def play_policy(env, policy, num_episodes=1):
    """
    Play the policy in the environment for a number of episodes.
    """
    for episode in range(num_episodes):
        obs, info = env.reset()
        done = False
        ep_rews = []
        while not done:
            action = get_action(torch.as_tensor(obs, dtype=torch.float32))
            obs, rew, done, _, _ = env.step(action)
            ep_rews.append(rew)
            env.render()
        print(f"Episode {episode + 1} finished with reward: {sum(ep_rews)}")

play_policy(env_wr, policy, num_episodes=10)

KeyboardInterrupt: 

In [133]:
cell = env.unwrapped.grid.get(1, 8)
cell

# Testing Generalizability

In [163]:

class LargeSimpleEnv(MiniGridEnv):
    def __init__(
            self, 
            size=20, 
            agent_start_pos=(1, 8), 
            agent_start_dir=0, 
            max_steps=1000, 
            **kwargs,
    ):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        self.goal_pos = (8, 1)
        
        
        
        mission_space = MissionSpace(mission_func=self._gen_mission)

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            max_steps=max_steps,
            **kwargs,
        )

        self.action_space = gym.spaces.Discrete(3)
    @staticmethod
    def _gen_mission():
        return "Find the shortest path"

    def _gen_grid(self, width, height):
        #create gird
        self.grid = Grid(width, height)
        #place barrier
        self.grid.wall_rect(0, 0, width, height)
        #place goal
        self.put_obj(Goal(), 8, 1)
        #place walls
        for i in range(1, width // 2):
            self.grid.set(i, width - 4, Wall())
            self.grid.set(i + width // 2 - 1, width - 7, Wall())
            self.grid.set(i, width - 10, Wall())
            self.grid.set(i + width // 2 - 1, width - 13, Wall())
            self.grid.set(i, height - 4, Wall())
        #place agent
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos #check this
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = "find the shortest path"
    
    def count_states(self):
        free_cells = sum(1 for x in range(self.grid.width)
                      for y in range(self.grid.height)
                      if not self.grid.get(x, y)) * 4
        return free_cells 
