In [1]:
import gymnasium as gym
import minigrid
from minigrid.wrappers import ImgObsWrapper
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch
from stable_baselines3.common.env_util import make_vec_env

pygame 2.5.2 (SDL 2.28.3, Python 3.11.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
class MinigridFeaturesExtractor(BaseFeaturesExtractor):
    """Small CNN feature extractor for MiniGrid image observations.

    Architecture: three 2x2 convolutions (16 -> 32 -> 64 output channels), each
    followed by a ReLU, then Flatten, then a Linear + ReLU projection to
    ``features_dim`` features.

    Args:
        observation_space: image observation space. Its first dimension is taken
            as the number of input channels. NOTE(review): this assumes a
            channel-first layout by the time the space reaches this module
            (SB3 presumably transposes HWC image spaces); confirm.
        features_dim: size of the output feature vector.
        normalized_image: never read here; presumably kept so the signature
            matches SB3's built-in CNN extractors.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 512, normalized_image: bool = False) -> None:
        super().__init__(observation_space, features_dim)
        in_channels = observation_space.shape[0]
        channel_plan = [in_channels, 16, 32, 64]

        # Conv -> ReLU pairs, built in the same order (and therefore with the same
        # module indices / parameter names) as an explicitly spelled-out Sequential.
        conv_layers = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            conv_layers += [torch.nn.Conv2d(c_in, c_out, (2, 2)), torch.nn.ReLU()]
        conv_layers.append(torch.nn.Flatten())
        self.cnn = torch.nn.Sequential(*conv_layers)

        # The flattened width depends on the input spatial size, so infer it by
        # pushing one sampled observation (with a batch axis added) through the CNN.
        probe = torch.as_tensor(observation_space.sample()[None]).float()
        with torch.no_grad():
            flat_width = self.cnn(probe).shape[1]

        self.linear = torch.nn.Sequential(torch.nn.Linear(flat_width, features_dim), torch.nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map a batch of observations to a (batch, features_dim) feature tensor."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)

In [3]:
# Gymnasium ID of the MiniGrid task; used both for training and for the rendered rollout.
# NOTE(review): per the MiniGrid docs "N4-S5" presumably means 4 rooms of max size 5 — confirm.
env_name = "MiniGrid-MultiRoom-N4-S5-v0"

In [6]:
# Training budget and seed. 50k steps on a single environment is slow; lower
# TOTAL_TIMESTEPS for a quick smoke test.
TOTAL_TIMESTEPS = 50_000
SEED = 0

# Swap SB3's default CNN for the small MiniGrid extractor with a 128-d output.
policy_kwargs = dict(
    features_extractor_class=MinigridFeaturesExtractor,
    features_extractor_kwargs=dict(features_dim=128),
)

# ImgObsWrapper keeps only the image part of MiniGrid's observation (per the
# MiniGrid wrappers docs), which is what the CNN extractor consumes.
env = ImgObsWrapper(gym.make(env_name, render_mode="rgb_array"))

# `seed` makes SB3 seed its RNGs (python/numpy/torch) and the env, so reruns are reproducible.
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, seed=SEED, verbose=1, device="cpu")
model.learn(TOTAL_TIMESTEPS, progress_bar=False)

KeyboardInterrupt: 

In [None]:
# Watch the trained agent act in the environment with on-screen ("human") rendering.
N_ROLLOUT_STEPS = 500

# Use a separate name so the training env bound to `env` is not clobbered.
eval_env = ImgObsWrapper(gym.make(env_name, render_mode="human"))
try:
    observation, info = eval_env.reset()
    for _ in range(N_ROLLOUT_STEPS):
        # deterministic=False is SB3's default: sample from the policy.
        # Set it to True for greedy actions.
        action, _state = model.predict(observation, deterministic=False)
        observation, reward, terminated, truncated, info = eval_env.step(action)

        if terminated or truncated:
            observation, info = eval_env.reset()
finally:
    # Close the render window even if the loop raises or the cell is interrupted.
    eval_env.close()

In [None]:
# Inspect which objects SB3 reports state-dicts for in this model.
model_params = model.get_parameters()
model_params.keys()