In [1]:
# __future__ import should always be first
from __future__ import annotations

# Standard library imports
from collections import defaultdict

# Third-party imports
import torch
import numpy as np
import torch.nn as nn
from torch.distributions.categorical import Categorical

# Gymnasium & Minigrid imports
import gymnasium as gym  # Correct way to import Gymnasium
from gymnasium.spaces import Dict, Discrete, Box
from minigrid.core.constants import COLOR_NAMES
from minigrid.core.constants import DIR_TO_VEC
from minigrid.core.grid import Grid
from minigrid.core.actions import Actions
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Door, Goal, Key, Wall
from minigrid.manual_control import ManualControl
from minigrid.minigrid_env import MiniGridEnv
from gymnasium.utils.play import play
import pandas as pd
# Visualization imports
import matplotlib.pyplot as plt


In [10]:

class SimpleEnv(MiniGridEnv):
    def __init__(
            self, 
            size=10, 
            agent_start_pos=(1, 8), 
            agent_start_dir=0, 
            max_steps=256, 
            **kwargs,
    ):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        self.goal_pos = (8, 1)
        
        
        
        mission_space = MissionSpace(mission_func=self._gen_mission)

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            max_steps=max_steps,
            **kwargs,
        )

        self.action_space = gym.spaces.Discrete(3)
    @staticmethod
    def _gen_mission():
        return "Find the shortest path"

    def _gen_grid(self, width, height):
        #create gird
        self.grid = Grid(width, height)
        #place barrier
        self.grid.wall_rect(0, 0, width, height)
        #place goal
        self.put_obj(Goal(), 8, 1)
        #place walls
        for i in range(1, width // 2):
            self.grid.set(i, width - 4, Wall())
            self.grid.set(i + width // 2 - 1, width - 7, Wall())
        #place agent
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos #check this
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = "find the shortest path"
    
    def count_states(self):
        free_cells = sum(1 for x in range(self.grid.width)
                      for y in range(self.grid.height)
                      if not self.grid.get(x, y)) * 4
        return free_cells 


In [2]:
"""This is working code for a simple minigrid environment with full view."""

class SimpleEnv(MiniGridEnv):
    def __init__(
        self, 
        size=10, 
        agent_start_pos=(1, 8), 
        agent_start_dir=0, 
        max_steps=256, 
        **kwargs,
    ):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        self.goal_pos = (8, 1)

        # Create a simple mission space
        mission_space = MissionSpace(mission_func=self._gen_mission)

        # Use an odd agent_view_size to pass MiniGrid's assertion
        # Even though we'll override gen_obs(), we still need an odd value here.
        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            max_steps=max_steps,
            agent_view_size=11,  # Must be odd (>= size)
            **kwargs,
        )

        # 3 discrete actions (forward, turn left, turn right)
        self.action_space = Discrete(3)

    @staticmethod
    def _gen_mission():
        return "Find the shortest path"

    def _gen_grid(self, width, height):
        # Create the grid
        self.grid = Grid(width, height)
        # Place outer walls
        self.grid.wall_rect(0, 0, width, height)
        # Place a goal
        self.put_obj(Goal(), 8, 1)
        # Place some walls
        for i in range(1, width // 2):
            self.grid.set(i, width - 4, Wall())
            self.grid.set(i + width // 2 - 1, width - 7, Wall())

        # Place agent
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = "find the shortest path"

    def count_states(self):
        free_cells = sum(
            1
            for x in range(self.grid.width)
            for y in range(self.grid.height)
            if not self.grid.get(x, y)
        ) * 4
        return free_cells 

    def gen_obs(self):
        """Override the default partial-view to return the full grid."""
        full_grid = self.grid.encode()  # shape (width, height, 3)
        return {
            "image": full_grid,
            "direction": self.agent_dir,
            "mission": self.mission
        }



In [11]:
env = SimpleEnv(render_mode= None)
env.reset();

In [12]:
action = 0
obs, _, _, _, _ = env.step(action)
print("Obs shape after step:", obs["image"].shape)


Obs shape after step: (7, 7, 3)


In [13]:
action = 2
obs1, _, _, _, _ = env.step(action)
print("Obs shape after step:", obs["image"].shape)

Obs shape after step: (7, 7, 3)


In [14]:
obs["image"] == obs1["image"]

array([[[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True]],

       [[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True]],

       [[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [False, False,  True],
        [ True,  True,  True],
        [ True,  True,  True]],

       [[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [False, False,  True],
        [False, False,  True],
        [ True,  True,  True]],

       [[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],


In [15]:
def mlp(sizes, activation = nn.Tanh, output_activation = nn.Identity):
    layers = []
    for i in range(len(sizes) - 1):
        act = activation if i < len(sizes) -2 else output_activation 
        layers += [nn.Linear(sizes [i], sizes[i +1], act())]
    return nn.Sequential(*layers)


In [None]:
def train(env_name = env, hidden_sizes = [32], lr = 1e-2, epochs = 50, batch_size = 5000, render = False):

    obs_dim = env.observation_space.shape[0] 
    n_acts = env.action_space.n

    #generate polucy network
    logits_net = mlp(sizes = [obs_dim] + hidden_sizes + [n_acts])

    #takes policy network and returns action distribution
    def get_policy(obs):
        logits = logits_net(obs)
        return Categorical(logits = logits)

    #samples actions from the action distrubution from the policy network
    def get_action(obs):
        return get_policy(obs).sample().item()
    
    #loss function
    