<a href="https://colab.research.google.com/github/ipez02/csci164/blob/main/cs166training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --quiet "gymnasium[atari,accept-rom-license]" stable-baselines3[extra] autorom[accept-rom-license]
!AutoROM --accept-license


[0mAutoROM will download the Atari 2600 ROMs.
They will be installed to:
	/usr/local/lib/python3.12/dist-packages/AutoROM/roms

Existing ROMs will be overwritten.


In [2]:
!pip install ale-py



In [3]:
import gymnasium as gym


print(gym.envs.registry.keys())

dict_keys(['CartPole-v0', 'CartPole-v1', 'MountainCar-v0', 'MountainCarContinuous-v0', 'Pendulum-v1', 'Acrobot-v1', 'phys2d/CartPole-v0', 'phys2d/CartPole-v1', 'phys2d/Pendulum-v0', 'LunarLander-v3', 'LunarLanderContinuous-v3', 'BipedalWalker-v3', 'BipedalWalkerHardcore-v3', 'CarRacing-v3', 'Blackjack-v1', 'FrozenLake-v1', 'FrozenLake8x8-v1', 'CliffWalking-v1', 'CliffWalkingSlippery-v1', 'Taxi-v3', 'tabular/Blackjack-v0', 'tabular/CliffWalking-v0', 'Reacher-v2', 'Reacher-v4', 'Reacher-v5', 'Pusher-v2', 'Pusher-v4', 'Pusher-v5', 'InvertedPendulum-v2', 'InvertedPendulum-v4', 'InvertedPendulum-v5', 'InvertedDoublePendulum-v2', 'InvertedDoublePendulum-v4', 'InvertedDoublePendulum-v5', 'HalfCheetah-v2', 'HalfCheetah-v3', 'HalfCheetah-v4', 'HalfCheetah-v5', 'Hopper-v2', 'Hopper-v3', 'Hopper-v4', 'Hopper-v5', 'Swimmer-v2', 'Swimmer-v3', 'Swimmer-v4', 'Swimmer-v5', 'Walker2d-v2', 'Walker2d-v3', 'Walker2d-v4', 'Walker2d-v5', 'Ant-v2', 'Ant-v3', 'Ant-v4', 'Ant-v5', 'Humanoid-v2', 'Humanoid-v3', 

In [4]:
import ale_py
import gymnasium as gym

In [10]:
# =====================================================
# DQN on Atari Space Invaders (based on Pong starter)
# =====================================================

!pip install --quiet "gymnasium[atari,accept-rom-license]" stable-baselines3[extra] ale-py
!AutoROM --accept-license

import os
# Set the ATARI_ROM_DIRS environment variable
# IMPORTANT: Update this path if your ROMs are in a different location
# os.environ["ATARI_ROM_DIRS"] = "/usr/local/lib/python3.10/dist-packages/ale_py/roms/"
# The above line might be needed depending on your setup if AutoROM doesn't place
# ROMs in a discoverable location.

import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from gymnasium.wrappers import RecordVideo # No longer explicitly importing FrameStack/FrameStackObservation


# Corrected import path for make_atari and wrap_deepmind
from stable_baselines3.common.env_util import make_atari_env
# from stable_baselines3.common.atari_wrappers import AtariWrapper # No longer directly using AtariWrapper args this way

# ------------------------------------------
# 1. Environment Setup
# ------------------------------------------

def make_env():
    # Using make_atari_env with "NoFrameskip" version, which often handles preprocessing including frame stacking
    base_env = make_atari_env("SpaceInvadersNoFrameskip-v4", n_envs=1, seed=0)
    # Access the single environment within the VecEnv returned by make_atari_env
    single_env = base_env.envs[0]
    # Wrap the single environment with Monitor
    monitored_env = Monitor(single_env)
    # Return the monitored single environment
    return monitored_env

env = DummyVecEnv([lambda: make_env()]) # DummyVecEnv expects a function that returns the env

# ------------------------------------------
# 2. DQN Model
# ------------------------------------------

model = DQN(
    "CnnPolicy",
    env,
    buffer_size=100000,
    learning_rate=1e-4,
    batch_size=32,
    train_freq=4,
    target_update_interval=1000,
    exploration_fraction=0.1,
    exploration_final_eps=0.01,
    verbose=1,
)

# ------------------------------------------
# 3. Training Runs (Early + Later)
# ------------------------------------------

# Early run (mostly random)
print("Starting early training run...")
model.learn(total_timesteps=5000)
model.save("dqn_spaceinvaders_early")
print("Early training run finished.")

# Continue training for longer run
print("Starting later training run...")
model.learn(total_timesteps=100000)
model.save("dqn_spaceinvaders_later")
print("Later training run finished.")

# ------------------------------------------
# 4. Video Recording
# ------------------------------------------

os.makedirs("videos", exist_ok=True)

def record_video(model_path, save_name, max_steps=2000):
    print(f"Recording video for {save_name}...")
    # Use make_atari_env with "NoFrameskip" version for evaluation as well
    eval_env = make_atari_env("SpaceInvadersNoFrameskip-v4", n_envs=1, seed=0)
    # RecordVideo expects a single env, so access the unwrapped env from the VecEnv
    record_env = RecordVideo(
        eval_env.envs[0],
        video_folder="videos",
        name_prefix=save_name
    )

    model = DQN.load(model_path)

    obs, _ = record_env.reset() # Use record_env here
    done = False
    steps = 0
    while not done and steps < max_steps:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _ = record_env.step(action) # Use record_env here
        done = terminated or truncated
        steps += 1

    record_env.close() # Close record_env
    print(f"Video recording for {save_name} finished.")


# Record both early and later runs
record_video("dqn_spaceinvaders_early", "spaceinvaders_early")
record_video("dqn_spaceinvaders_later", "spaceinvaders_later")

print("✅ Videos saved to /videos/")

[0mAutoROM will download the Atari 2600 ROMs.
They will be installed to:
	/usr/local/lib/python3.12/dist-packages/AutoROM/roms

Existing ROMs will be overwritten.
Using cpu device
Wrapping the env in a VecTransposeImage.
Starting early training run...


  return datetime.utcnow().replace(tzinfo=utc)


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 164      |
|    ep_rew_mean      | 2        |
|    exploration_rate | 0.01     |
| time/               |          |
|    episodes         | 4        |
|    fps              | 49       |
|    time_elapsed     | 13       |
|    total_timesteps  | 658      |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.000968 |
|    n_updates        | 139      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 293      |
|    ep_rew_mean      | 4.88     |
|    exploration_rate | 0.01     |
| time/               |          |
|    episodes         | 8        |
|    fps              | 44       |
|    time_elapsed     | 52       |
|    total_timesteps  | 2342     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00129  |
|    n_updates      

  IMAGEMAGICK_BINARY = r"C:\Program Files\ImageMagick-6.8.8-Q16\magick.exe"
  return datetime.utcnow().replace(tzinfo=utc)
  logger.warn(


Video recording for spaceinvaders_early finished.
Recording video for spaceinvaders_later...
Video recording for spaceinvaders_later finished.
✅ Videos saved to /videos/


In [7]:
import gymnasium.wrappers
import gymnasium.wrappers.atari_preprocessing

print("Contents of gymnasium.wrappers:")
print(dir(gymnasium.wrappers))

print("\nContents of gymnasium.wrappers.atari_preprocessing:")
print(dir(gymnasium.wrappers.atari_preprocessing))

Contents of gymnasium.wrappers:
['AddRenderObservation', 'AddWhiteNoise', 'AtariPreprocessing', 'Autoreset', 'ClipAction', 'ClipReward', 'DelayObservation', 'DtypeObservation', 'FilterObservation', 'FlattenObservation', 'FrameStackObservation', 'GrayscaleObservation', 'HumanRendering', 'MaxAndSkipObservation', 'NormalizeObservation', 'NormalizeReward', 'ObstructView', 'OrderEnforcing', 'PassiveEnvChecker', 'RecordEpisodeStatistics', 'RecordVideo', 'RenderCollection', 'RescaleAction', 'RescaleObservation', 'ReshapeObservation', 'ResizeObservation', 'StickyAction', 'TimeAwareObservation', 'TimeLimit', 'TransformAction', 'TransformObservation', 'TransformReward', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__getattr__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_renamed_wrapper', '_wrapper_to_class', 'atari_preprocessing', 'common', 'importlib', 'rendering', 'stateful_action', 'stateful_observation', 'stateful_reward', 'transform_action', 'tran