In [23]:
%pip install swig
%pip install stable-baselines3[extra]
%pip install gymnasium[box2d]

import gymnasium as gym
import torch
import imageio
import numpy as np
import cv2
from collections import deque
import imageio
from stable_baselines3 import PPO
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor




In [24]:
def preprocess_frame(frame, resolution=(84, 84), grayscale=True):
    """Resize a rendered frame and optionally collapse it to a single gray channel.

    Args:
        frame: Image array to process. Assumed to be an RGB image laid out as
            (height, width, 3), since it is converted with ``cv2.COLOR_RGB2GRAY``
            when ``grayscale`` is set -- confirm against the caller.
        resolution: Target size forwarded unchanged to ``cv2.resize`` as its
            ``dsize`` argument, which OpenCV reads as ``(width, height)``.
        grayscale: When true, convert to grayscale before resizing and append a
            trailing channel axis of length 1 to the result.

    Returns:
        The resized frame as a ``uint8`` array; grayscale results carry an extra
        last axis so they can be concatenated channel-wise.
    """
    source = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) if grayscale else frame
    resized = cv2.resize(source, resolution).astype(np.uint8)
    if not grayscale:
        return resized
    # The grayscale image is 2-D at this point; add the channel axis back.
    return resized[..., np.newaxis]





class FrameStack:
    """Rolling buffer of the ``k`` most recent preprocessed frames.

    Every incoming observation is run through ``preprocess_frame`` and the
    buffered frames are joined along the last (channel) axis.
    """

    def __init__(self, k, resolution=(84, 84), grayscale=True):
        """Store the stacking configuration and create an empty buffer.

        Args:
            k: Number of frames kept in the buffer.
            resolution: Passed through to ``preprocess_frame``.
            grayscale: Passed through to ``preprocess_frame``.
        """
        self.k = k
        self.resolution = resolution
        self.grayscale = grayscale
        # A bounded deque drops the oldest frame automatically on append.
        self.frames = deque(maxlen=k)

    def _stacked(self):
        """Return the buffered frames concatenated along the last axis."""
        return np.concatenate(list(self.frames), axis=-1)

    def reset(self, obs):
        """Fill the whole buffer with copies of the first frame of an episode.

        Args:
            obs: Raw observation to preprocess.

        Returns:
            The ``k`` identical frames concatenated along the last axis.
        """
        first = preprocess_frame(obs, self.resolution, self.grayscale)
        # Independent copies so the k slots never alias the same array.
        self.frames.extend(np.copy(first) for _ in range(self.k))
        return self._stacked()

    def step(self, obs):
        """Push one new frame and return the updated stack.

        Args:
            obs: Raw observation to preprocess.

        Returns:
            All currently buffered frames concatenated along the last axis.
        """
        self.frames.append(preprocess_frame(obs, self.resolution, self.grayscale))
        return self._stacked()




In [None]:
class PreprocessedCarRacing(gym.Wrapper):
    """Gym wrapper that returns stacked, preprocessed, channel-first frames.

    Observations from the wrapped environment are pushed through a
    ``FrameStack`` and transposed from (H, W, C) to (C, H, W).
    """

    def __init__(self, env, frame_stack=4, resolution=(84, 84), grayscale=True):
        """Wrap ``env`` and declare the matching observation space.

        Args:
            env: Environment to wrap.
            frame_stack: Number of consecutive frames stacked per observation.
            resolution: Target size handed to ``preprocess_frame``; it ends up
                as ``cv2.resize``'s ``dsize``, i.e. ``(width, height)``.
            grayscale: If true each frame contributes one channel, otherwise
                three (assumes RGB frames from the wrapped env).
        """
        super().__init__(env)
        self.frame_stack = FrameStack(frame_stack, resolution, grayscale)
        channels = frame_stack if grayscale else 3 * frame_stack
        # cv2.resize(frame, resolution) yields arrays shaped
        # (resolution[1], resolution[0]), so the declared space must list the
        # height (resolution[1]) before the width (resolution[0]). Using
        # (resolution[0], resolution[1]) only matched for square resolutions.
        width, height = resolution[0], resolution[1]
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(channels, height, width), dtype=np.uint8
        )

    def reset(self, **kwargs):
        """Reset the wrapped env and return ``(observation, info)``.

        The frame stack is refilled from the first observation and the result
        is transposed from (H, W, C) to (C, H, W).
        """
        obs, info = self.env.reset(**kwargs)
        stacked = self.frame_stack.reset(obs)
        # The stack is built as (H, W, C); reorder to channel-first.
        return np.transpose(stacked, (2, 0, 1)), info

    def step(self, action):
        """Step the wrapped env and return the stacked channel-first observation.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where only
            the observation differs from what the wrapped env returned.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        stacked = self.frame_stack.step(obs)
        return np.transpose(stacked, (2, 0, 1)), reward, terminated, truncated, info


In [26]:
class CustomCNNExtractor(BaseFeaturesExtractor):
    """Three-layer convolutional feature extractor for channel-first images.

    The flattened output of the last convolution is returned directly as the
    feature vector; there is no projection layer afterwards.
    """

    def __init__(self, observation_space, features_dim=512):
        """Build the convolution stack and measure its output width.

        Args:
            observation_space: Space whose ``shape`` is read as
                (channels, height, width).
            features_dim: Forwarded to the base-class constructor. Note that
                ``_features_dim`` is overwritten below with the flattened CNN
                output size, so this value does not set the final width.
        """
        super().__init__(observation_space, features_dim)
        obs_shape = observation_space.shape
        in_channels = obs_shape[0]
        layers = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*layers)
        # Push one all-zero image through the network to learn how many
        # features the flatten layer emits for this observation size.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, obs_shape[1], obs_shape[2])
            flat = self.cnn(probe)
        self._features_dim = flat.shape[1]

    def forward(self, observations):
        """Return the flattened convolutional features for a batch of observations."""
        return self.cnn(observations)




In [None]:
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo, NormalizeObservation

def make_env():
    """Create a CarRacing-v3 environment that records episode statistics.

    Returns:
        The environment created with ``render_mode="rgb_array"`` and wrapped in
        ``RecordEpisodeStatistics``.
    """
    base_env = gym.make("CarRacing-v3", render_mode="rgb_array")
    # Extra wrappers (e.g. PreprocessedCarRacing) can be layered on top of the
    # statistics wrapper before returning.
    return RecordEpisodeStatistics(base_env)




In [28]:
# def record_policy(model_path, output_video_path, max_steps=1000, fps=30):
#     env = gym.make("CarRacing-v3", render_mode="rgb_array")
#     env = PreprocessedCarRacing(env, frame_stack=4, resolution=(84, 84), grayscale=True)
#     # Only need policy_kwargs if this extractor was used in training!
#     model = PPO.load(model_path, env=env)
#     obs, info = env.reset()
#     frames = []
#     done = False
#     steps = 0
#     while not done and steps < max_steps:
#         action, _ = model.predict(obs, deterministic=True)
#         frame = env.render()
#         frames.append(frame)
#         done = terminated or truncated
#         steps += 1
#     env.close()
#     imageio.mimsave(output_video_path, frames, fps=fps)
#     print(f"Video saved to {output_video_path}")


# The kernel may be crashing due to excessive memory usage from storing all frames in the `frames` list.
# To reduce memory usage, you can write frames directly to the video file as you generate them, instead of storing all in memory.
# imageio.get_writer allows you to write frames one by one.

def record_policy(model_path, output_video_path, max_steps=1000, fps=30):
    """Roll out a saved PPO policy in CarRacing-v3 and write the run to a video.

    Frames are streamed to the video file one at a time instead of being
    collected in a list first, so memory use stays flat for long episodes.

    Args:
        model_path: Path handed to ``PPO.load``.
        output_video_path: Destination file for the rendered video.
        max_steps: Upper bound on the number of environment steps recorded.
        fps: Frame rate passed to the video writer.

    Returns:
        None. The video is written to ``output_video_path`` and a confirmation
        line is printed.
    """
    # The raw (un-preprocessed) environment is only needed while loading the
    # model; close it even if loading fails.
    env_load = gym.make("CarRacing-v3", render_mode="rgb_array")
    try:
        model = PPO.load(model_path, env=env_load)
    finally:
        env_load.close()

    # Separate raw environment used for stepping and rendering.
    env_step = gym.make("CarRacing-v3", render_mode="rgb_array")
    try:
        obs, info = env_step.reset()
        done = False
        steps = 0

        # Write each frame as soon as it is rendered; the writer is closed by
        # the context manager, which finalizes the file.
        with imageio.get_writer(output_video_path, fps=fps) as writer:
            while not done and steps < max_steps:
                action, _ = model.predict(obs, deterministic=True)
                # Render before stepping so the frame shows the state the
                # action was chosen for.
                writer.append_data(env_step.render())
                obs, reward, terminated, truncated, info = env_step.step(action)
                done = terminated or truncated
                steps += 1
    finally:
        # Release the simulator even if prediction, rendering or encoding fails.
        env_step.close()

    print(f"Video saved to {output_video_path}")

In [29]:
# Colab-only step: upload the trained model archive from the local machine.
# NOTE(review): this import fails outside Google Colab; presumably the uploaded
# file lands in the current working directory -- confirm before relying on it.
from google.colab import files
uploaded = files.upload()  # This will prompt you to select the file

Saving ppo_carracing_900000_steps (1).zip to ppo_carracing_900000_steps (1).zip


In [30]:

# Record up to 1000 steps of the uploaded 900k-step checkpoint.
# The model path is given without the ".zip" suffix of the uploaded archive --
# presumably PPO.load resolves the extension; confirm against the SB3 docs.
record_policy("ppo_carracing_900000_steps (1)", "video_eval_900k(1).mp4", max_steps=1000)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.




Video saved to video_eval_900k(1).mp4
