In [2]:
import gymnasium as gym
import ale_py
import cv2

from stable_baselines3 import PPO, A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import (
    EvalCallback,
)
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.callbacks import ProgressBarCallback

In [16]:
# gym.register_envs(ale_py)

In [3]:
env = make_atari_env("Breakout-v4", n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)

A.L.E: Arcade Learning Environment (version 0.11.0+dfae0bd)
[Powered by Stella]


In [18]:
eval_callback = EvalCallback(
    env,
    best_model_save_path="../logs/exercise_3/",
    log_path="../logs/exercise_3/",
    eval_freq=5_000,
    deterministic=True,
    render=False,
    n_eval_episodes=50,
)

In [19]:
model = A2C(
    "CnnPolicy",
    env,
    tensorboard_log="../logs/exercise_3/tensorboard/",
    verbose=0,
)

In [20]:
TRAINING_TIMESTEPS =  1_000_000

model.learn(
    total_timesteps=TRAINING_TIMESTEPS, callback=[eval_callback, ProgressBarCallback()]
)

Output()



<stable_baselines3.a2c.a2c.A2C at 0x7a13a0144f50>

In [4]:
model = A2C.load("../logs/exercise_3/best_model.zip")

In [5]:
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward:.2f} ± {std_reward:.2f}")

Mean reward: 12.60 ± 4.48


In [None]:
MAX_STEPS = 500

observation = env.reset()
frames = []
step_count = 0

for step in range(MAX_STEPS):
    frame = env.render()
    frames.append(frame)

    action, _ = model.predict(observation, deterministic=True)
    obs, rewards, dones, infos = env.step(action)
    step_count += 1

    # if dones[0]:
    #     print(
    #         f"Episode finished after {step_count} steps "
    #         f"({'truncated' if infos[0].get('TimeLimit.truncated', False) else 'terminated'})"
    #     )
    #     break

env.close()

print(f"Final Step: {step_count}")
print(f"Number of Frames: {len(frames)}")

Final Step: 500
Number of Frames: 500


In [23]:
from IPython.display import HTML
from base64 import b64encode
import os

# Create a video from the frames
video_filename = "../videos/atari_breakout_dqn.mp4"
compressed_path = "../videos/atari_breakout_dqn_compressed.mp4"
height, width, _ = frames[0].shape

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
video = cv2.VideoWriter(video_filename, fourcc, 30.0, (width, height))

for frame in frames:
    video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
video.release()

print(f"Video guardado como {video_filename}")

os.system(f"rm {compressed_path}")
# Compressed video path
os.system(f"ffmpeg -i {video_filename} -vcodec libx264 {compressed_path}")
os.system(f"rm {video_filename}")
os.system(f"mv {compressed_path} {video_filename}")

# Show video
mp4 = open(video_filename, "rb").read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML(
    """
<video width=800 controls>
      <source src="%s" type="video/mp4">
</video>"""
    % data_url
)

Video guardado como ../videos/atari_breakout_dqn.mp4


rm: cannot remove '../videos/atari_breakout_dqn_compressed.mp4': No such file or directory
ffmpeg version n7.1 Copyright (c) 2000-2024 the FFmpeg developers
  built with gcc 14.2.1 (GCC) 20250207
  configuration: --prefix=/usr --disable-debug --disable-static --disable-stripping --enable-amf --enable-avisynth --enable-cuda-llvm --enable-lto --enable-fontconfig --enable-frei0r --enable-gmp --enable-gnutls --enable-gpl --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libdav1d --enable-libdrm --enable-libdvdnav --enable-libdvdread --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgsm --enable-libharfbuzz --enable-libiec61883 --enable-libjack --enable-libjxl --enable-libmodplug --enable-libmp3lame --enable-libopencore_amrnb --enable-libopencore_amrwb --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libplacebo --enable-libpulse --enable-librav1e --enable-librsvg --enable-librubberband --enable-libsnappy --e