In [1]:
import gymnasium as gym
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from IPython.display import HTML
import imageio
import base64

import cv2
import numpy as np

import sys

  from pkg_resources import resource_stream, resource_exists


In [2]:
# Make the parent directory importable so the project-local `my_envs` module resolves.
# NOTE(review): ".." is resolved against the kernel's working directory — this assumes
# the notebook is run from a direct subfolder of the project root; confirm.
sys.path.append("..")
# Presumably imported for its side effect of registering the custom
# "MiniGrid-DistributionalShift-*" environment IDs used below — verify in my_envs.
import my_envs

In [3]:
# Choose environment
# Two separate IDs are used: the "Train" variant feeds the vectorized training
# environment, while the "Test" variant is used for the recorded rollout.
# NOTE(review): these IDs are presumably registered by `my_envs` — confirm.
env_id_train = "MiniGrid-DistributionalShift-Train-v0"
env_id_test  = "MiniGrid-DistributionalShift-Test-v0"

In [4]:
# Create and wrap environment
def make_env(env_id, render_mode="rgb_array"):
    """Build a Gymnasium environment and apply the image observation wrappers.

    Args:
        env_id: Environment ID passed to ``gym.make``.
        render_mode: Render mode forwarded to ``gym.make``; defaults to
            ``"rgb_array"``.

    Returns:
        The environment wrapped first in ``RGBImgPartialObsWrapper`` and then
        in ``ImgObsWrapper`` (going by their names, these presumably expose the
        agent's partial view as an RGB image and keep only that image as the
        observation — confirm against the minigrid docs).
    """
    base_env = gym.make(env_id, render_mode=render_mode)
    # Wrapper order matters: the RGB wrapper is innermost, ImgObsWrapper outermost.
    return ImgObsWrapper(RGBImgPartialObsWrapper(base_env))

In [5]:
# Vectorize for stable-baselines3 compatibility
def _make_train_env():
    """Factory for one wrapped copy of the training environment."""
    return make_env(env_id_train)


# Four environment copies are created from the factory.
vec_env = make_vec_env(_make_train_env, n_envs=4)

In [6]:
# Create and train the agent
# A CNN policy is used because the wrapped environments emit image observations.
model = PPO(policy="CnnPolicy", env=vec_env, verbose=1)
# NOTE(review): the captured log shows training reaching 16384 timesteps, i.e. past
# the 10_000 requested here — presumably PPO only stops on whole rollout
# boundaries; confirm if the exact budget matters.
model.learn(total_timesteps=10_000)

Using cpu device
Wrapping the env in a VecTransposeImage.


  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(f"{pre} is not within the observation space.")


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 90.9     |
|    ep_rew_mean     | -131     |
| time/              |          |
|    fps             | 2916     |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 8192     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 132         |
|    ep_rew_mean          | -171        |
| time/                   |             |
|    fps                  | 641         |
|    iterations           | 2           |
|    time_elapsed         | 25          |
|    total_timesteps      | 16384       |
| train/                  |             |
|    approx_kl            | 0.046972707 |
|    clip_fraction        | 0.365       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.91       |
|    explained_variance   | -0.000246   |
|    learning_rate        | 0.

<stable_baselines3.ppo.ppo.PPO at 0x105f823f0>

In [7]:
# Test and record one episode
# NOTE: the loop always runs 300 steps and resets the environment whenever an
# episode ends, so the recording can span more than one episode.
test_env = make_env(env_id_test)

frames = []
try:
    obs, info = test_env.reset()
    for step in range(300):
        action, _ = model.predict(obs)
        obs, reward, terminated, truncated, info = test_env.step(action)

        frame = test_env.render()  # RGB array (make_env defaults to render_mode="rgb_array")

        # draw score and step number on the frame
        # `current_score` is optional on the underlying env; fall back to 0 if absent.
        score = getattr(test_env.unwrapped, "current_score", 0)
        text = f"Step: {step}   Score: {score}"

        # OpenCV works in BGR channel order, so convert, draw, and convert back.
        bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        cv2.putText(
            bgr, text, (5, 15),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5,
            (255, 255, 255), 1, cv2.LINE_AA
        )

        frame = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        frames.append(frame)

        if terminated or truncated:
            obs, info = test_env.reset()
finally:
    # Always release the environment, even if prediction/rendering raised mid-loop.
    test_env.close()

In [8]:
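# Save to video (GIF / MP4) -- currently disabled; note that re-enabling it needs `import os`, which the import cell above does not include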
# os.makedirs("videos", exist_ok=True)

# gif_path = "videos/ppo_lavacrossing.gif"
# mp4_path = "videos/ppo_lavacrossing.mp4"

# print(f"Saving animation to {gif_path} and {mp4_path}...")

# imageio.mimsave(gif_path, frames, fps=10)
# imageio.mimsave(mp4_path, frames, fps=10, quality=8)

# print("✅ Done! You can open the files from:")
# print(os.path.abspath("videos/"))

In [9]:
# Show video
# Encode the recorded frames to a temporary MP4, then embed it inline as base64.
# NOTE(review): "/tmp" is a POSIX location; this path will not exist on Windows.
video_path = "/tmp/temp_video.mp4"
imageio.mimsave(video_path, frames, fps=15)

# Read the file back inside a context manager so the handle is closed promptly
# instead of being left for the garbage collector.
with open(video_path, "rb") as video_file:
    video = video_file.read()
b64 = base64.b64encode(video).decode()

# Last expression of the cell: rendered by the notebook as an HTML5 video player.
HTML(f"""
<video width="480" height="360" controls>
    <source src="data:video/mp4;base64,{b64}" type="video/mp4">
</video>
""")
