In [1]:
# @title Atari Pong AI - Installation Script (More Robust)

# This script aims to provide the most robust installation for Atari Pong
# using Stable Baselines3 and Gymnasium in Google Colab.

# IMPORTANT:
# 1. Start with a FRESH COLAB NOTEBOOK (Runtime -> Restart runtime).
# 2. Run this cell FIRST and wait for it to complete.
# 3. Carefully observe ALL output.
# 4. If the "Quick Environment Test" at the end FAILS, go to "Runtime -> Restart runtime"
#    and run *this entire cell again from scratch*. This is often necessary.
# 5. If it still fails after a couple of restarts and reruns, please share the FULL output.

print("--- Starting Robust Installation for Atari Pong ---")

# 0. Ensure pip is up-to-date
print("\n0. Upgrading pip...")
!pip install --upgrade pip

# 1. Install/Upgrade core Gymnasium and specific ale-py.
#    `gymnasium` is the successor to `gym`.
#    We explicitly install `ale-py` and then `gymnasium[atari]` to ensure order.
print("\n1. Installing/Upgrading ale-py and gymnasium[atari]...")
# Install ale-py first, explicitly, to ensure it's present for gymnasium[atari]
!pip install --upgrade ale-py
# Then install gymnasium with atari extras, which should now find ale-py
!pip install --upgrade gymnasium[atari]

# 2. Install Stable Baselines3 (SB3).
print("\n2. Installing/Upgrading stable-baselines3...")
!pip install --upgrade stable-baselines3

# 3. Install AutoROM.
#    This is CRUCIAL for Atari ROM management.
print("\n3. Installing AutoROM...")
!pip install autorom[accept-rom-license]

# 4. Run AutoROM.build() to download Atari ROMs.
#    This command needs to be run explicitly. This is the most common point of failure.
print("\n4. Running AutoROM.build() to download Atari ROMs. This may take a moment...")
print("Look for messages indicating ROMs are being downloaded/accepted.")
!python -m autorom.accept-rom-license

# 5. Install OpenCV Python (cv2).
print("\n5. Installing opencv-python...")
!pip install --upgrade opencv-python

print("\n--- Installation Steps Completed ---")

# --- Robust Environment Test ---
# This test attempts to create the Pong environment to verify installation.
print("\n--- Running Robust Environment Test for 'ALE/Pong-v5' ---")
try:
    import gymnasium as gym
    # Try importing ale_py directly to check if it's found
    try:
        import ale_py
        print(f"Successfully imported ale_py version: {ale_py.__version__}")
    except ImportError:
        print("ERROR: Could not import 'ale_py'. This indicates a fundamental installation issue.")
        raise

    # Attempt to make the environment
    env_test = gym.make("ALE/Pong-v5")
    env_test.reset()
    env_test.close()
    print(f"Successfully created and reset 'ALE/Pong-v5' environment.")
    print("This indicates that the Atari ROMs and dependencies are likely set up correctly.")
    print("\nSUCCESS: You can now proceed to the training script in a new cell.")
except Exception as e:
    print(f"\nFATAL ERROR: Failed to create 'ALE/Pong-v5' environment during test: {e}")
    print("This error means the Atari ROMs or the 'ale-py' library are NOT correctly set up.")
    print("\n--- TROUBLESHOOTING STEPS ---")
    print("1. Go to 'Runtime -> Restart runtime' in the Colab menu.")
    print("2. Run *this entire installation cell* again from scratch.")
    print("3. Carefully verify the output of `!python -m autorom.accept-rom-license` for ROM downloads.")
    print("4. If the error persists after 2-3 attempts, consider trying a different Colab instance or reporting the full error output.")
# @title Atari Pong AI - Training Script with Custom CNN (Saving to Google Drive)

# This script trains an AI agent to play Atari Pong using Stable Baselines3 (PPO algorithm)
# and Gymnasium. It saves the trained models directly to Google Drive.
# This version includes a custom-defined CNN architecture for the policy.
!pip install stable-baselines3[extra]
!pip install gymnasium[atari,accept-rom-license]
!pip install torch torchvision

import gymnasium as gym
from stable_baselines3 import PPO
# IMPORTANT: Import make_atari_env instead of make_vec_env for proper preprocessing
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import ActorCriticPolicy
import torch as th
import torch.nn as nn
import os

# # --- 0. Mount Google Drive ---
# print("--- Mounting Google Drive ---")
# from google.colab import drive
# drive.mount('/content/drive')
# print("Google Drive mounted successfully.")


# --- 1. Define Custom CNN Architecture ---
# This class defines the neural network that will process the game's image observations.
# It inherits from BaseFeaturesExtractor, which is the standard way to create
# custom feature extractors in Stable Baselines3.
class CustomCNN(BaseFeaturesExtractor):
    """
    A deeper CNN feature extractor for reinforcement learning with complex environments.
    """
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
        super(CustomCNN, self).__init__(observation_space, features_dim)
        n_input_channels = observation_space.shape[0]

        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 128, kernel_size=3, stride=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 128, kernel_size=3, stride=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Flatten()
        )

        with th.no_grad():
            dummy_input = th.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(dummy_input).shape[1]

        self.linear = nn.Sequential(
            nn.Linear(n_flatten, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, features_dim),
            nn.ReLU()
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations / 255.0))

# --- Configuration ---
ENV_ID = "ALE/Pong-v5"  # The Gymnasium ID for Atari Pong

# Set the log directory to a path within your Google Drive
LOG_DIR = "./data/pong_ppo_custom_cnn_logs/"
TOTAL_TIMESTEPS = 250_000  # Total number of timesteps for training
SAVE_FREQ = 100_000  # Save model every X timesteps
N_ENVS = 4  # Number of parallel environments to run for vectorized training

# Create log directory in Google Drive if it doesn't exist
os.makedirs(LOG_DIR, exist_ok=True)

print(f"--- Starting Training for {ENV_ID} with Custom CNN ---")
print(f"Logs and models will be saved in: {LOG_DIR}")
print(f"Total timesteps: {TOTAL_TIMESTEPS}")
print(f"Number of parallel environments: {N_ENVS}")

# --- Environment Setup ---
# *** FIX: Use make_atari_env to apply the correct wrappers ***
# This handles frame-stacking, grayscale, resizing, etc., automatically.
vec_env = make_atari_env(ENV_ID, n_envs=N_ENVS, seed=0)
# The environment is now automatically wrapped with VecFrameStack and other
# essential preprocessing wrappers for Atari.
print(f"Successfully created and wrapped vectorized environment for {ENV_ID}")
print(f"Corrected Observation space shape: {vec_env.observation_space.shape}")
    # print(f"ERROR: Failed to create environment '{ENV_ID}': {e}")
    # exit()

# --- Model Definition with Custom Policy ---

# `policy_kwargs` is a dictionary passed to the model constructor.
# It tells the PPO model to use our `CustomCNN` class as the feature extractor.
policy_kwargs = {
    "features_extractor_class": CustomCNN,
    "features_extractor_kwargs": dict(features_dim=256),
}

# The model is now initialized with the 'CnnPolicy' but its default
# feature extractor will be replaced by our custom one via `policy_kwargs`.
model = PPO(
    "CnnPolicy",
    vec_env,
    policy_kwargs=policy_kwargs,
    verbose=1,
    tensorboard_log=LOG_DIR,
    device="auto" # Automatically uses GPU if available, otherwise CPU
)
print("\n--- PPO model initialized with Custom CNN Policy ---")
print("Model Architecture:")
print(model.policy)
print("---------------------------------------------------\n")


# --- Callbacks ---
checkpoint_callback = CheckpointCallback(
    save_freq=max(SAVE_FREQ // N_ENVS, 1),
    save_path=LOG_DIR,
    name_prefix="pong_ppo_custom_model"
)
print(f"Checkpoint callback set to save every {SAVE_FREQ} total timesteps.")

# --- Training ---
print("\n--- Starting Training Process ---")
try:
    model.learn(
        total_timesteps=TOTAL_TIMESTEPS,
        callback=checkpoint_callback,
        progress_bar=True
    )
    print("\nTraining completed!")
except KeyboardInterrupt:
    print("\nTraining interrupted by user.")
except Exception as e:
    print(f"\nAn unexpected error occurred during training: {e}")

# --- Save Final Model ---
final_model_path = os.path.join(LOG_DIR, "pong_ppo_custom_final_model")
model.save(final_model_path)
print(f"Final model saved to: {final_model_path}.zip")

# --- Optional: Evaluation (same as before) ---
print("\n--- Evaluation (Optional) ---")
# The evaluation code does not need to be changed.
# The loaded model will expect the same preprocessed observations.
try:
    loaded_model = PPO.load(final_model_path)
    eval_env = make_atari_env(ENV_ID, n_envs=1) # Use the same env creation for eval

    num_episodes = 5
    for episode in range(num_episodes):
        obs = eval_env.reset()
        episode_reward = 0
        done = False
        print(f"Starting evaluation episode {episode + 1}/{num_episodes}...")
        while not done:
            action, _states = loaded_model.predict(obs, deterministic=True)
            obs, reward, done, info = eval_env.step(action)
            episode_reward += reward[0] # Reward is an array in vec env
        print(f"Episode {episode + 1} finished with reward: {episode_reward}")
    eval_env.close()
    print("Evaluation complete.")

except Exception as e:
    print(f"Error during evaluation: {e}")
    print("Evaluation skipped.")

print("\nTo view training progress, you can use TensorBoard:")
print(f"Load TensorBoard in a new Colab cell with: %load_ext tensorboard")
print(f"Then run: %tensorboard --logdir {LOG_DIR}")

--- Starting Robust Installation for Atari Pong ---

0. Upgrading pip...

1. Installing/Upgrading ale-py and gymnasium[atari]...
Collecting gymnasium[atari]
  Using cached gymnasium-1.2.0-py3-none-any.whl.metadata (9.9 kB)
Using cached gymnasium-1.2.0-py3-none-any.whl (944 kB)
Installing collected packages: gymnasium
  Attempting uninstall: gymnasium
    Found existing installation: gymnasium 1.1.1
    Uninstalling gymnasium-1.1.1:
      Successfully uninstalled gymnasium-1.1.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
stable-baselines3 2.6.0 requires gymnasium<1.2.0,>=0.29.1, but you have gymnasium 1.2.0 which is incompatible.[0m[31m
[0mSuccessfully installed gymnasium-1.2.0

2. Installing/Upgrading stable-baselines3...
Collecting gymnasium<1.2.0,>=0.29.1 (from stable-baselines3)
  Using cached gymnasium-1.1.1-py3-none-any.whl.metadata (9.4 kB)
U

A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)
[Powered by Stella]




2025-07-21 16:26:19.967670: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-21 16:26:20.070304: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753140380.109489 3875337 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753140380.120937 3875337 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753140380.207591 3875337 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

--- Starting Training for ALE/Pong-v5 with Custom CNN ---
Logs and models will be saved in: ./data/pong_ppo_custom_cnn_logs/
Total timesteps: 250000
Number of parallel environments: 4
Successfully created and wrapped vectorized environment for ALE/Pong-v5
Corrected Observation space shape: (84, 84, 1)
Using cuda device
Wrapping the env in a VecTransposeImage.

--- PPO model initialized with Custom CNN Policy ---
Model Architecture:
ActorCriticCnnPolicy(
  (features_extractor): CustomCNN(
    (cnn): Sequential(
      (0): Conv2d(1, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
      (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_sta

Output()

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 866      |
|    ep_rew_mean     | -20.7    |
| time/              |          |
|    fps             | 881      |
|    iterations      | 1        |
|    time_elapsed    | 9        |
|    total_timesteps | 8192     |
---------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 882         |
|    ep_rew_mean          | -20.6       |
| time/                   |             |
|    fps                  | 716         |
|    iterations           | 2           |
|    time_elapsed         | 22          |
|    total_timesteps      | 16384       |
| train/                  |             |
|    approx_kl            | 0.007114288 |
|    clip_fraction        | 0.0257      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.79       |
|    explained_variance   | 8.4e-06     |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0575      |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.000612   |
|    value_loss           | 0.245       |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 891         |
|    ep_rew_mean          | -20.6       |
| time/                   |             |
|    fps                  | 679         |
|    iterations           | 3           |
|    time_elapsed         | 36          |
|    total_timesteps      | 24576       |
| train/                  |             |
|    approx_kl            | 0.015159073 |
|    clip_fraction        | 0.103       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.78       |
|    explained_variance   | 0.415       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0784      |
|    n_updates            | 20          |
|    policy_gradient_loss | -0.00948    |
|    value_loss           | 0.214       |
-----------------------------------------


------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 902          |
|    ep_rew_mean          | -20.6        |
| time/                   |              |
|    fps                  | 662          |
|    iterations           | 4            |
|    time_elapsed         | 49           |
|    total_timesteps      | 32768        |
| train/                  |              |
|    approx_kl            | 0.0148796085 |
|    clip_fraction        | 0.173        |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.76        |
|    explained_variance   | 0.548        |
|    learning_rate        | 0.0003       |
|    loss                 | 0.0479       |
|    n_updates            | 30           |
|    policy_gradient_loss | -0.0138      |
|    value_loss           | 0.174        |
------------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 907         |
|    ep_rew_mean          | -20.7       |
| time/                   |             |
|    fps                  | 653         |
|    iterations           | 5           |
|    time_elapsed         | 62          |
|    total_timesteps      | 40960       |
| train/                  |             |
|    approx_kl            | 0.018115243 |
|    clip_fraction        | 0.195       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.75       |
|    explained_variance   | 0.562       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0353      |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.0132     |
|    value_loss           | 0.15        |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 915         |
|    ep_rew_mean          | -20.7       |
| time/                   |             |
|    fps                  | 647         |
|    iterations           | 6           |
|    time_elapsed         | 75          |
|    total_timesteps      | 49152       |
| train/                  |             |
|    approx_kl            | 0.020681404 |
|    clip_fraction        | 0.204       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.7        |
|    explained_variance   | 0.831       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0859      |
|    n_updates            | 50          |
|    policy_gradient_loss | -0.0213     |
|    value_loss           | 0.145       |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 918         |
|    ep_rew_mean          | -20.6       |
| time/                   |             |
|    fps                  | 642         |
|    iterations           | 7           |
|    time_elapsed         | 89          |
|    total_timesteps      | 57344       |
| train/                  |             |
|    approx_kl            | 0.028013915 |
|    clip_fraction        | 0.278       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.66       |
|    explained_variance   | 0.533       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.000943    |
|    n_updates            | 60          |
|    policy_gradient_loss | -0.025      |
|    value_loss           | 0.148       |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 905         |
|    ep_rew_mean          | -20.6       |
| time/                   |             |
|    fps                  | 639         |
|    iterations           | 8           |
|    time_elapsed         | 102         |
|    total_timesteps      | 65536       |
| train/                  |             |
|    approx_kl            | 0.035081565 |
|    clip_fraction        | 0.298       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.71       |
|    explained_variance   | 0.162       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0115     |
|    n_updates            | 70          |
|    policy_gradient_loss | -0.0299     |
|    value_loss           | 0.161       |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 889        |
|    ep_rew_mean          | -20.7      |
| time/                   |            |
|    fps                  | 636        |
|    iterations           | 9          |
|    time_elapsed         | 115        |
|    total_timesteps      | 73728      |
| train/                  |            |
|    approx_kl            | 0.04625691 |
|    clip_fraction        | 0.347      |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.71      |
|    explained_variance   | 0.554      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.000519  |
|    n_updates            | 80         |
|    policy_gradient_loss | -0.0369    |
|    value_loss           | 0.121      |
----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 866         |
|    ep_rew_mean          | -20.8       |
| time/                   |             |
|    fps                  | 635         |
|    iterations           | 10          |
|    time_elapsed         | 128         |
|    total_timesteps      | 81920       |
| train/                  |             |
|    approx_kl            | 0.043870613 |
|    clip_fraction        | 0.337       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.63       |
|    explained_variance   | 0.713       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0272     |
|    n_updates            | 90          |
|    policy_gradient_loss | -0.0334     |
|    value_loss           | 0.113       |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 835         |
|    ep_rew_mean          | -20.8       |
| time/                   |             |
|    fps                  | 633         |
|    iterations           | 11          |
|    time_elapsed         | 142         |
|    total_timesteps      | 90112       |
| train/                  |             |
|    approx_kl            | 0.052141033 |
|    clip_fraction        | 0.385       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.61       |
|    explained_variance   | 0.861       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0253     |
|    n_updates            | 100         |
|    policy_gradient_loss | -0.0375     |
|    value_loss           | 0.0853      |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 823        |
|    ep_rew_mean          | -20.9      |
| time/                   |            |
|    fps                  | 631        |
|    iterations           | 12         |
|    time_elapsed         | 155        |
|    total_timesteps      | 98304      |
| train/                  |            |
|    approx_kl            | 0.05028468 |
|    clip_fraction        | 0.352      |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.55      |
|    explained_variance   | 0.824      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.038     |
|    n_updates            | 110        |
|    policy_gradient_loss | -0.0329    |
|    value_loss           | 0.0675     |
----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 823         |
|    ep_rew_mean          | -20.9       |
| time/                   |             |
|    fps                  | 630         |
|    iterations           | 13          |
|    time_elapsed         | 168         |
|    total_timesteps      | 106496      |
| train/                  |             |
|    approx_kl            | 0.055768535 |
|    clip_fraction        | 0.41        |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.54       |
|    explained_variance   | 0.855       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0252     |
|    n_updates            | 120         |
|    policy_gradient_loss | -0.0388     |
|    value_loss           | 0.0722      |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 820         |
|    ep_rew_mean          | -20.9       |
| time/                   |             |
|    fps                  | 628         |
|    iterations           | 14          |
|    time_elapsed         | 182         |
|    total_timesteps      | 114688      |
| train/                  |             |
|    approx_kl            | 0.068365134 |
|    clip_fraction        | 0.453       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.5        |
|    explained_variance   | 0.867       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0299     |
|    n_updates            | 130         |
|    policy_gradient_loss | -0.0367     |
|    value_loss           | 0.0811      |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 820        |
|    ep_rew_mean          | -20.9      |
| time/                   |            |
|    fps                  | 627        |
|    iterations           | 15         |
|    time_elapsed         | 195        |
|    total_timesteps      | 122880     |
| train/                  |            |
|    approx_kl            | 0.07735647 |
|    clip_fraction        | 0.42       |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.3       |
|    explained_variance   | 0.882      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0122    |
|    n_updates            | 140        |
|    policy_gradient_loss | -0.0476    |
|    value_loss           | 0.0707     |
----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 843         |
|    ep_rew_mean          | -20.9       |
| time/                   |             |
|    fps                  | 627         |
|    iterations           | 16          |
|    time_elapsed         | 209         |
|    total_timesteps      | 131072      |
| train/                  |             |
|    approx_kl            | 0.112919174 |
|    clip_fraction        | 0.494       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.34       |
|    explained_variance   | 0.807       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0459     |
|    n_updates            | 150         |
|    policy_gradient_loss | -0.0496     |
|    value_loss           | 0.101       |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 873        |
|    ep_rew_mean          | -20.8      |
| time/                   |            |
|    fps                  | 626        |
|    iterations           | 17         |
|    time_elapsed         | 222        |
|    total_timesteps      | 139264     |
| train/                  |            |
|    approx_kl            | 0.10431445 |
|    clip_fraction        | 0.5        |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.33      |
|    explained_variance   | 0.893      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0412    |
|    n_updates            | 160        |
|    policy_gradient_loss | -0.0546    |
|    value_loss           | 0.113      |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 871        |
|    ep_rew_mean          | -20.8      |
| time/                   |            |
|    fps                  | 626        |
|    iterations           | 18         |
|    time_elapsed         | 235        |
|    total_timesteps      | 147456     |
| train/                  |            |
|    approx_kl            | 0.08115425 |
|    clip_fraction        | 0.451      |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.37      |
|    explained_variance   | 0.927      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0583    |
|    n_updates            | 170        |
|    policy_gradient_loss | -0.044     |
|    value_loss           | 0.126      |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 824        |
|    ep_rew_mean          | -20.9      |
| time/                   |            |
|    fps                  | 625        |
|    iterations           | 19         |
|    time_elapsed         | 248        |
|    total_timesteps      | 155648     |
| train/                  |            |
|    approx_kl            | 0.07816298 |
|    clip_fraction        | 0.416      |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.22      |
|    explained_variance   | 0.929      |
|    learning_rate        | 0.0003     |
|    loss                 | 0.0414     |
|    n_updates            | 180        |
|    policy_gradient_loss | -0.0412    |
|    value_loss           | 0.0864     |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 807        |
|    ep_rew_mean          | -20.9      |
| time/                   |            |
|    fps                  | 625        |
|    iterations           | 20         |
|    time_elapsed         | 262        |
|    total_timesteps      | 163840     |
| train/                  |            |
|    approx_kl            | 0.09500611 |
|    clip_fraction        | 0.405      |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.11      |
|    explained_variance   | 0.932      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.086     |
|    n_updates            | 190        |
|    policy_gradient_loss | -0.0392    |
|    value_loss           | 0.0372     |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 804        |
|    ep_rew_mean          | -21        |
| time/                   |            |
|    fps                  | 624        |
|    iterations           | 21         |
|    time_elapsed         | 275        |
|    total_timesteps      | 172032     |
| train/                  |            |
|    approx_kl            | 0.10785848 |
|    clip_fraction        | 0.435      |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.04      |
|    explained_variance   | 0.914      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0227    |
|    n_updates            | 200        |
|    policy_gradient_loss | -0.0418    |
|    value_loss           | 0.083      |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 823        |
|    ep_rew_mean          | -20.9      |
| time/                   |            |
|    fps                  | 624        |
|    iterations           | 22         |
|    time_elapsed         | 288        |
|    total_timesteps      | 180224     |
| train/                  |            |
|    approx_kl            | 0.12440715 |
|    clip_fraction        | 0.535      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.901     |
|    explained_variance   | 0.938      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0862    |
|    n_updates            | 210        |
|    policy_gradient_loss | -0.0535    |
|    value_loss           | 0.0315     |
----------------------------------------


---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 842       |
|    ep_rew_mean          | -20.9     |
| time/                   |           |
|    fps                  | 624       |
|    iterations           | 23        |
|    time_elapsed         | 301       |
|    total_timesteps      | 188416    |
| train/                  |           |
|    approx_kl            | 0.1598593 |
|    clip_fraction        | 0.438     |
|    clip_range           | 0.2       |
|    entropy_loss         | -0.755    |
|    explained_variance   | 0.914     |
|    learning_rate        | 0.0003    |
|    loss                 | -0.0395   |
|    n_updates            | 220       |
|    policy_gradient_loss | -0.0301   |
|    value_loss           | 0.0876    |
---------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 863        |
|    ep_rew_mean          | -20.8      |
| time/                   |            |
|    fps                  | 623        |
|    iterations           | 24         |
|    time_elapsed         | 315        |
|    total_timesteps      | 196608     |
| train/                  |            |
|    approx_kl            | 0.16405787 |
|    clip_fraction        | 0.473      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.821     |
|    explained_variance   | 0.819      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0176    |
|    n_updates            | 230        |
|    policy_gradient_loss | -0.0329    |
|    value_loss           | 0.122      |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 882        |
|    ep_rew_mean          | -20.8      |
| time/                   |            |
|    fps                  | 623        |
|    iterations           | 25         |
|    time_elapsed         | 328        |
|    total_timesteps      | 204800     |
| train/                  |            |
|    approx_kl            | 0.10990186 |
|    clip_fraction        | 0.462      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.96      |
|    explained_variance   | 0.925      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0312    |
|    n_updates            | 240        |
|    policy_gradient_loss | -0.0288    |
|    value_loss           | 0.111      |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 887        |
|    ep_rew_mean          | -20.8      |
| time/                   |            |
|    fps                  | 623        |
|    iterations           | 26         |
|    time_elapsed         | 341        |
|    total_timesteps      | 212992     |
| train/                  |            |
|    approx_kl            | 0.10281505 |
|    clip_fraction        | 0.452      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.966     |
|    explained_variance   | 0.933      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.00327   |
|    n_updates            | 250        |
|    policy_gradient_loss | -0.0457    |
|    value_loss           | 0.132      |
----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 904         |
|    ep_rew_mean          | -20.7       |
| time/                   |             |
|    fps                  | 622         |
|    iterations           | 27          |
|    time_elapsed         | 355         |
|    total_timesteps      | 221184      |
| train/                  |             |
|    approx_kl            | 0.101563044 |
|    clip_fraction        | 0.444       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.918      |
|    explained_variance   | 0.949       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0374     |
|    n_updates            | 260         |
|    policy_gradient_loss | -0.0445     |
|    value_loss           | 0.105       |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 889        |
|    ep_rew_mean          | -20.8      |
| time/                   |            |
|    fps                  | 622        |
|    iterations           | 28         |
|    time_elapsed         | 368        |
|    total_timesteps      | 229376     |
| train/                  |            |
|    approx_kl            | 0.11701885 |
|    clip_fraction        | 0.462      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.975     |
|    explained_variance   | 0.95       |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0116    |
|    n_updates            | 270        |
|    policy_gradient_loss | -0.0488    |
|    value_loss           | 0.125      |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 888        |
|    ep_rew_mean          | -20.8      |
| time/                   |            |
|    fps                  | 622        |
|    iterations           | 29         |
|    time_elapsed         | 381        |
|    total_timesteps      | 237568     |
| train/                  |            |
|    approx_kl            | 0.07321514 |
|    clip_fraction        | 0.435      |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.19      |
|    explained_variance   | 0.948      |
|    learning_rate        | 0.0003     |
|    loss                 | 0.00223    |
|    n_updates            | 280        |
|    policy_gradient_loss | -0.0304    |
|    value_loss           | 0.0853     |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 871        |
|    ep_rew_mean          | -20.8      |
| time/                   |            |
|    fps                  | 622        |
|    iterations           | 30         |
|    time_elapsed         | 394        |
|    total_timesteps      | 245760     |
| train/                  |            |
|    approx_kl            | 0.06698075 |
|    clip_fraction        | 0.45       |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.29      |
|    explained_variance   | 0.934      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.03      |
|    n_updates            | 290        |
|    policy_gradient_loss | -0.0451    |
|    value_loss           | 0.116      |
----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 890         |
|    ep_rew_mean          | -20.8       |
| time/                   |             |
|    fps                  | 622         |
|    iterations           | 31          |
|    time_elapsed         | 408         |
|    total_timesteps      | 253952      |
| train/                  |             |
|    approx_kl            | 0.101109914 |
|    clip_fraction        | 0.487       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.39       |
|    explained_variance   | 0.0399      |
|    learning_rate        | 0.0003      |
|    loss                 | -0.00299    |
|    n_updates            | 300         |
|    policy_gradient_loss | 0.000138    |
|    value_loss           | 0.16        |
-----------------------------------------



Training completed!
Final model saved to: ./data/pong_ppo_custom_cnn_logs/pong_ppo_custom_final_model.zip

--- Evaluation (Optional) ---
Starting evaluation episode 1/5...
Episode 1 finished with reward: -21.0
Starting evaluation episode 2/5...
Episode 2 finished with reward: -21.0
Starting evaluation episode 3/5...
Episode 3 finished with reward: -21.0
Starting evaluation episode 4/5...
Episode 4 finished with reward: -21.0
Starting evaluation episode 5/5...
Episode 5 finished with reward: -21.0
Evaluation complete.

To view training progress, you can use TensorBoard:
Load TensorBoard in a new Colab cell with: %load_ext tensorboard
Then run: %tensorboard --logdir ./data/pong_ppo_custom_cnn_logs/


In [2]:
# @title Record Atari Pong Game Video

# This script loads a trained Stable Baselines3 model for Atari Pong
# from the Colab runtime disk and records a video of its gameplay.

import gymnasium as gym
from stable_baselines3 import PPO
from gymnasium.wrappers import RecordVideo
import os

# --- Configuration ---
ENV_ID = "ALE/Pong-v5"  # The Gymnasium ID for Atari Pong

# Path to the trained model on the Colab runtime disk
# This assumes your training script saved the final model here.
MODEL_PATH_ON_RUNTIME_DISK = "./data/pong_ppo_logs/pong_ppo_final_model.zip"

# Directory where the video will be saved on the Colab runtime disk
VIDEO_DIR = "./pong_game_videos/"
VIDEO_PREFIX = "pong_agent_game" # Prefix for the video filename

# Create the video directory if it doesn't exist
os.makedirs(VIDEO_DIR, exist_ok=True)

print(f"--- Starting Video Recording for {ENV_ID} ---")
print(f"Loading model from: {MODEL_PATH_ON_RUNTIME_DISK}")
print(f"Video will be saved to: {VIDEO_DIR}")

# --- Load the Trained Model ---
try:
    model = PPO.load(MODEL_PATH_ON_RUNTIME_DISK)
    print("Model loaded successfully!")
except Exception as e:
    print(f"ERROR: Could not load model from {MODEL_PATH_ON_RUNTIME_DISK}: {e}")
    print("Please ensure the training script completed and saved the model to this path.")
    print("If you restarted the runtime, the model might have been deleted. You might need to re-run training or load from Google Drive.")
    exit() # Exit if model cannot be loaded

# --- Create Environment with Video Recording Wrapper ---
# The RecordVideo wrapper will automatically save a video of the episode.
# `video_folder`: directory to save videos.
# `episode_trigger`: records every episode (here, we only run one).
# `disable_logger`: disables verbose logging from the wrapper.
# try:
    # We need render_mode="rgb_array" for video recording
env = gym.make(ENV_ID, render_mode="rgb_array")
env = RecordVideo(
    env,
    video_folder=VIDEO_DIR,
    episode_trigger=lambda x: True, # Record every episode
    name_prefix=VIDEO_PREFIX,
    disable_logger=True # Suppress some logging messages from RecordVideo
)
print(f"Environment '{ENV_ID}' created and wrapped for video recording.")
# except Exception as e:
#     print(f"ERROR: Could not create environment or video wrapper: {e}")
#     print("Ensure gymnasium and its Atari dependencies are correctly installed.")
#     exit()

# --- Play One Episode and Record ---
print("\nStarting game episode and recording video...")
obs, info = env.reset()
done = False
truncated = False
episode_reward = 0


while not done and not truncated:
    action, _states = model.predict(obs, deterministic=True) # deterministic=True for consistent playback
    obs, reward, done, truncated, info = env.step(action)
    episode_reward += reward

print(f"\nGame episode finished. Total reward: {episode_reward}")

# --- Close Environment and Finalize Video ---
env.close() # This is crucial for the RecordVideo wrapper to finalize the video file.
print("Environment closed. Video recording finalized.")

# --- Instructions for Downloading Video ---
print("\n--- Video Saved! ---")
print(f"Your video should be saved in the '{VIDEO_DIR}' directory on the Colab runtime disk.")
print("To download it:")
print("1. Click the 'Files' icon (folder icon) on the left sidebar in Colab.")
print(f"2. Navigate into the '{VIDEO_DIR}' folder.")
print(f"3. Look for a file named something like '{VIDEO_PREFIX}-episode-0.mp4'.")
print("4. Right-click on the video file and select 'Download'.")
print("\nRemember: Colab runtime disk is temporary. Download your video before the session ends!")

--- Starting Video Recording for ALE/Pong-v5 ---
Loading model from: ./data/pong_ppo_logs/pong_ppo_final_model.zip
Video will be saved to: ./pong_game_videos/
Model loaded successfully!


  logger.warn(


Environment 'ALE/Pong-v5' created and wrapped for video recording.

Starting game episode and recording video...

Game episode finished. Total reward: -5.0
Environment closed. Video recording finalized.

--- Video Saved! ---
Your video should be saved in the './pong_game_videos/' directory on the Colab runtime disk.
To download it:
1. Click the 'Files' icon (folder icon) on the left sidebar in Colab.
2. Navigate into the './pong_game_videos/' folder.
3. Look for a file named something like 'pong_agent_game-episode-0.mp4'.
4. Right-click on the video file and select 'Download'.

Remember: Colab runtime disk is temporary. Download your video before the session ends!
