In [None]:
# Install necessary packages and set up the environment
!apt-get update -qq
!apt-get install -y xvfb ffmpeg
!pip install gym[atari] stable-baselines3 gym[accept-rom-license] shimmy pyvirtualdisplay
!ale-import-roms
!pip install --upgrade stable-baselines3 gym

import gym
import matplotlib.pyplot as plt
import numpy as np
from pyvirtualdisplay import Display
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack

# Start an invisible virtual X display (1400x900) so environments can render
# frames on a headless machine.
# NOTE(review): in Jupyter/Colab this rebinds the notebook's built-in `display`
# helper; consider renaming if later cells need IPython's display().
display = Display(size=(1400, 900), visible=0)
display.start()

# Environment factory
def make_env():
    """Build one MsPacman-v4 environment whose frames can be captured as RGB arrays.

    Takes no arguments so it can be handed directly to the vectorized-env
    constructor; prints a progress message once the environment exists.

    Returns:
        The newly created gym environment.
    """
    pacman_env = gym.make('MsPacman-v4', render_mode='rgb_array')
    print("Environment created")
    return pacman_env

# Create vectorized environments
num_envs = 1  # Number of parallel environments
# SubprocVecEnv runs each env in its own worker process, which only pays off
# with several envs. For a single env, DummyVecEnv avoids the extra process and
# inter-process communication (and keeps make_env's print in this process).
vec_env_cls = DummyVecEnv if num_envs == 1 else SubprocVecEnv
env = vec_env_cls([make_env for _ in range(num_envs)])  # Vectorized environment
env = VecFrameStack(env, n_stack=4)  # Stack the 4 most recent observations
print("Environment stacked and vectorized")

# Initialize the model.
# NOTE: per the SB3 DQN docs, the default replay buffer holds 1_000_000
# transitions. With raw (unpreprocessed) Atari RGB frames stacked 4 deep,
# that is far more memory than a notebook VM has, so the buffer is sized
# for this short demo run instead.
# learning_starts must be below total_timesteps; otherwise learn() only fills
# the buffer and never performs a gradient update (the library default is
# larger than the 10 steps used here).
total_timesteps = 10  # Train the model with very few timesteps (smoke-test run)
model = DQN(
    'CnnPolicy',
    env,
    verbose=1,
    buffer_size=1_000,
    learning_starts=0,
)
print("Model initialized")

model.learn(total_timesteps=total_timesteps)
print("Model trained")

# How many evaluation episodes to play; their rewards are averaged.
num_episodes: int = 1

# Evaluate the model
def _show_frame(env):
    """Render the current frame of ``env`` and display it inline with matplotlib."""
    img = env.render(mode='rgb_array')  # Get the image
    if img is None:
        print("Rendering returned None")
        return
    if isinstance(img, list):  # Handle a list of frames: show the first one
        img = img[0]
    img = np.asarray(img)  # Ensure img is a numpy array
    if img.ndim == 3 and img.shape[-1] > 4:
        # imshow accepts only 3 (RGB) or 4 (RGBA) channels; keep the first 3.
        img = img[:, :, :3]
    plt.imshow(img)
    plt.title("DQN Playing")
    plt.axis('off')
    plt.show()


def evaluate_model(model, env, num_episodes=1, *, deterministic=True, render=True):
    """Play ``num_episodes`` episodes with ``model`` and return the mean episode reward.

    ``env`` is expected to be a vectorized environment (as built above), whose
    ``step`` returns batched ``(obs, rewards, dones, infos)``. Only the first
    sub-environment (index 0) is tracked: its rewards are summed, and an
    episode ends when it reports done.

    Args:
        model: Model exposing ``predict(obs, deterministic=...)``.
        env: Vectorized environment the model was trained on.
        num_episodes: Number of episodes to play.
        deterministic: Passed to ``model.predict``; True requests the policy's
            deterministic (greedy) action instead of exploratory sampling.
        render: When True, draw the current frame with matplotlib after every step.

    Returns:
        Mean total reward over the played episodes, or ``nan`` if
        ``num_episodes`` is not positive.
    """
    all_rewards = []
    for _ in range(num_episodes):
        obs = env.reset()
        if isinstance(obs, tuple):  # Tolerate (obs, info)-style resets
            obs = obs[0]
        total_reward = 0.0
        while True:
            action, _states = model.predict(obs, deterministic=deterministic)
            obs, rewards, dones, _infos = env.step(action)
            # A VecEnv yields one reward/done per sub-env; take env 0 so the
            # reward stays a float and the loop test is a plain bool even
            # when there is more than one sub-env.
            total_reward += float(np.asarray(rewards).reshape(-1)[0])

            if render:
                _show_frame(env)

            if bool(np.asarray(dones).reshape(-1)[0]):
                break

        all_rewards.append(total_reward)
    return float(np.mean(all_rewards)) if all_rewards else float("nan")

# Run evaluation, then close the environment. The close happens in `finally`
# so the env (including any SubprocVecEnv worker processes) is released even
# if evaluation raises.
try:
    avg_reward = evaluate_model(model, env, num_episodes)
    print(f"Average reward over {num_episodes} episodes: {avg_reward}")
finally:
    env.close()
    print("Environment closed")


Output hidden; open in https://colab.research.google.com to view.