<a href="https://colab.research.google.com/github/johngrahamreynolds/DeepRL/blob/main/HuggingFaceCourse/DeepQLearning/SpaceInvadersWithRendering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Unit 3: Deep Q-Learning with Atari Games 👾 using RL Baselines3 Zoo

<img src="https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit4/thumbnail.jpg" alt="Unit 3 Thumbnail">

In this notebook, **you'll train a Deep Q-Learning agent** playing Space Invaders using [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo), a training framework based on [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) that provides scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.

We're using the [RL-Baselines-3 Zoo integration, a vanilla version of Deep Q-Learning](https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html) with no extensions such as Double-DQN, Dueling-DQN, or Prioritized Experience Replay.

⬇️ Here is an example of what **you will achieve** ⬇️

In [None]:
%%html
<video controls autoplay><source src="https://huggingface.co/ThomasSimonini/ppo-SpaceInvadersNoFrameskip-v4/resolve/main/replay.mp4" type="video/mp4"></video>

### 🎮 Environments:

- [SpaceInvadersNoFrameskip-v4](https://gymnasium.farama.org/environments/atari/space_invaders/)

You can see the difference between Space Invaders versions here 👉 https://gymnasium.farama.org/environments/atari/space_invaders/#variants

### 📚 RL-Library:

- [RL-Baselines3-Zoo](https://github.com/DLR-RM/rl-baselines3-zoo)

## Objectives of this notebook 🏆
At the end of the notebook, you will:
- Have a deeper understanding of **how RL Baselines3 Zoo works**.
- Be able to **push your trained agent and the code to the Hub** with a nice video replay and an evaluation score 🔥.




## This notebook is from Deep Reinforcement Learning Course
<img src="https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/deep-rl-course-illustration.jpg" alt="Deep RL Course illustration"/>

In this free course, you will:

- 📖 Study Deep Reinforcement Learning in **theory and practice**.
- 🧑‍💻 Learn to **use famous Deep RL libraries** such as Stable Baselines3, RL Baselines3 Zoo, CleanRL and Sample Factory 2.0.
- 🤖 Train **agents in unique environments**

And more! Check 📚 the syllabus 👉 https://simoninithomas.github.io/deep-rl-course

Don’t forget to **<a href="http://eepurl.com/ic5ZUD">sign up to the course</a>** (we are collecting your email to be able to **send you the links when each Unit is published and give you information about the challenges and updates**).


The best way to keep in touch is to join our discord server to exchange with the community and with us 👉🏻 https://discord.gg/ydHrjt3WP5

## Prerequisites 🏗️
Before diving into the notebook, you need to:

🔲 📚 **[Study Deep Q-Learning by reading Unit 3](https://huggingface.co/deep-rl-course/unit3/introduction)**  🤗

We're constantly trying to improve our tutorials, so **if you find some issues in this notebook**, please [open an issue on the Github Repo](https://github.com/huggingface/deep-rl-class/issues).

# Let's train a Deep Q-Learning agent playing Atari's Space Invaders 👾 and upload it to the Hub.

We strongly recommend students **to use Google Colab for the hands-on exercises instead of running them on their personal computers**.

By using Google Colab, **you can focus on learning and experimenting without worrying about the technical aspects of setting up your environments**.

To validate this hands-on for the certification process, you need to push your trained model to the Hub and **get a result of >= 200**.

To find your result, go to the leaderboard and find your model, **the result = mean_reward - std of reward**
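As a quick illustration of that formula, here is a minimal sketch that computes the result from a list of episode rewards. The reward values are made up, and using the population standard deviation (the same convention as `numpy.std` with default arguments) is an assumption about how the leaderboard computes it.

```python
import statistics


def leaderboard_result(episode_rewards):
    """Return mean_reward - std_reward for a list of episode rewards."""
    mean_reward = statistics.fmean(episode_rewards)
    # Assumption: population standard deviation (ddof=0)
    std_reward = statistics.pstdev(episode_rewards)
    return mean_reward - std_reward


# Hypothetical rewards from 5 evaluation episodes
rewards = [250.0, 300.0, 200.0, 350.0, 400.0]
result = leaderboard_result(rewards)
print(f"result = {result:.2f}, validates the hands-on: {result >= 200}")
```

Note that a high mean is not enough: an agent with very uneven episodes is penalized by the standard deviation term.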

For more information about the certification process, check this section 👉 https://huggingface.co/deep-rl-course/en/unit0/introduction#certification-process

## A piece of advice 💡
It's better to run this Colab in a copy on your Google Drive, so that **if it times out** you still have the saved notebook on your Google Drive and do not need to redo everything from scratch.

To do that, you can either press `Ctrl + S` or go to `File > Save a copy in Google Drive`.

Also, we're going to **train it for 90 minutes with 1M timesteps**. Typing `!nvidia-smi` will tell you what GPU you're using.

If you want to train for longer, such as 10 million steps, this will take about 9 hours, potentially resulting in Colab timing out. In that case, I recommend running this on your local computer (or somewhere else). Just click on: `File>Download`.

## Set the GPU 💪
- To **accelerate the agent's training, we'll use a GPU**. To do that, go to `Runtime > Change Runtime type`

<img src="https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/gpu-step1.jpg" alt="GPU Step 1">

- `Hardware Accelerator > GPU`

<img src="https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/gpu-step2.jpg" alt="GPU Step 2">

In [None]:
!nvidia-smi
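If you prefer to check for a GPU from Python rather than reading the `nvidia-smi` table, the sketch below wraps the same tool. The helper name `gpu_names` is hypothetical, and it assumes an NVIDIA driver that ships `nvidia-smi`; it returns an empty list when no GPU runtime is selected.

```python
import shutil
import subprocess


def gpu_names():
    """Return the GPU names reported by nvidia-smi, or [] if none are visible."""
    if shutil.which("nvidia-smi") is None:
        return []  # No NVIDIA driver tools on this machine / runtime
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return []
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


names = gpu_names()
print(names if names else "No GPU detected: check Runtime > Change Runtime type")
```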

# Install RL-Baselines3 Zoo and its dependencies 📚

If you see `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.` **this is normal and not a critical error**: there's a version conflict, but the packages we need are installed.

In [None]:
!pip install git+https://github.com/DLR-RM/rl-baselines3-zoo

In [None]:
!apt-get install swig cmake ffmpeg

To use Atari games in Gymnasium, we need to install the `atari` extra, plus the `accept-rom-license` extra to download the ROM files (the game files).

In [None]:
!pip install gymnasium[atari]
!pip install gymnasium[accept-rom-license]

## Create a virtual display 🔽

During the notebook, we'll need to generate a replay video. To do so, with colab, **we need to have a virtual screen to be able to render the environment** (and thus record the frames).

Hence the following cells will install the libraries, then create and run a virtual screen 🖥

In [None]:
%%capture
!apt install python-opengl
!apt install xvfb
!pip3 install pyvirtualdisplay

In [None]:
# Virtual display
from pyvirtualdisplay import Display

virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()

## Train our Deep Q-Learning Agent to Play Space Invaders 👾

To train an agent with RL-Baselines3-Zoo, we just need to do two things:

1. Create a hyperparameter config file that will contain our training hyperparameters called `dqn.yml`.

This is a template example:

```
SpaceInvadersNoFrameskip-v4:
  env_wrapper:
    - stable_baselines3.common.atari_wrappers.AtariWrapper
  frame_stack: 4
  policy: 'CnnPolicy'
  n_timesteps: !!float 1e6
  buffer_size: 100000
  learning_rate: !!float 1e-4
  batch_size: 32
  learning_starts: 100000
  target_update_interval: 1000
  train_freq: 4
  gradient_steps: 1
  exploration_fraction: 0.1
  exploration_final_eps: 0.01
  # If True, you need to deactivate handle_timeout_termination
  # in the replay_buffer_kwargs
  optimize_memory_usage: False
```

Here we see that:
- We use the `AtariWrapper`, which preprocesses the input (frame skipping, resizing, grayscale), and `frame_stack: 4` stacks 4 frames
- We use `CnnPolicy`, since we use convolutional layers to process the frames
- We train it for 1 million `n_timesteps` (`1e6` in the template above)
- Memory (Experience Replay) size is 100000, i.e. the number of experience steps saved in the buffer and sampled from to train your agent.

💡 My advice is to **keep the training timesteps at 1M,** which will take about 90 minutes on a P100. `!nvidia-smi` will tell you what GPU you're using. At 10 million steps, this will take about 9 hours, which could likely result in Colab timing out. If you want to train that long, I recommend running this on your local computer (or somewhere else). Just click on: `File>Download`.

In terms of hyperparameter optimization, my advice is to focus on these 3 hyperparameters:
- `learning_rate`
- `buffer_size` (Experience Memory size)
- `batch_size`

As a good practice, you need to **check the documentation to understand what each hyperparameter does**: https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html#parameters
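To get a feel for what the template values imply, here is a back-of-the-envelope sketch in plain Python. It is not the library's code. It assumes Stable-Baselines3's defaults as I understand them: epsilon starts at 1.0 and is annealed linearly, observations are 84x84 grayscale frames stored as one byte per pixel, and the replay buffer keeps both the observation and the next observation when `optimize_memory_usage` is `False`.

```python
# Values copied from the dqn.yml template
n_timesteps = 1_000_000
buffer_size = 100_000
learning_starts = 100_000
train_freq = 4
gradient_steps = 1
exploration_fraction = 0.1
exploration_final_eps = 0.01
frame_stack = 4


def epsilon_at(step, initial_eps=1.0):
    """Linear epsilon schedule: anneal during the first fraction of training, then hold."""
    end_step = exploration_fraction * n_timesteps
    if step >= end_step:
        return exploration_final_eps
    return initial_eps + (step / end_step) * (exploration_final_eps - initial_eps)


for step in (0, 50_000, 100_000, 500_000):
    print(f"step {step:>7}: epsilon = {epsilon_at(step):.3f}")

# Rough number of gradient updates: none before learning_starts,
# then gradient_steps updates every train_freq environment steps.
n_updates = (n_timesteps - learning_starts) // train_freq * gradient_steps
print(f"About {n_updates} gradient updates")

# Rough replay buffer size: stacked uint8 frames, stored twice (obs and next obs).
obs_bytes = frame_stack * 84 * 84
buffer_bytes = 2 * buffer_size * obs_bytes
print(f"Replay buffer: about {buffer_bytes / 1e9:.2f} GB")
```

With these values, exploration ends at step 100,000, which is also when learning starts, so under these assumptions every gradient update happens while the agent acts with epsilon at 0.01. The memory estimate also shows why `buffer_size` is worth tuning carefully on Colab.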



2. We start the training and save the models in the `logs` folder 📁

- Define the algorithm after `--algo`, the folder where the model is saved after `-f`, and the hyperparameter config file after `-c`.
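The training command expects `dqn.yml` to exist in the working directory. One way to create it from a notebook cell is to write the template to disk; this is a minimal sketch that assumes you use the template values unchanged (you could equally upload the file by hand).

```python
from pathlib import Path

# Same content as the template shown earlier in this notebook
DQN_CONFIG = """\
SpaceInvadersNoFrameskip-v4:
  env_wrapper:
    - stable_baselines3.common.atari_wrappers.AtariWrapper
  frame_stack: 4
  policy: 'CnnPolicy'
  n_timesteps: !!float 1e6
  buffer_size: 100000
  learning_rate: !!float 1e-4
  batch_size: 32
  learning_starts: 100000
  target_update_interval: 1000
  train_freq: 4
  gradient_steps: 1
  exploration_fraction: 0.1
  exploration_final_eps: 0.01
  # If True, you need to deactivate handle_timeout_termination
  # in the replay_buffer_kwargs
  optimize_memory_usage: False
"""

config_path = Path("dqn.yml")
config_path.write_text(DQN_CONFIG)
print(f"Wrote {config_path.resolve()}")
```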

In [None]:
!python -m rl_zoo3.train --algo dqn --env SpaceInvadersNoFrameskip-v4  -f logs/  -c dqn.yml
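Once training finishes, the saved model ends up in a numbered run folder such as `logs/dqn/SpaceInvadersNoFrameskip-v4_1/`, which is the path used by the recording code later in this notebook. The sketch below lists the `.zip` files it can find; the helper name `find_saved_models` is hypothetical and the folder layout is inferred from that path.

```python
from pathlib import Path


def find_saved_models(log_dir="logs", algo="dqn", env_id="SpaceInvadersNoFrameskip-v4"):
    """Return the .zip model files found under <log_dir>/<algo>/<env_id>_<run>/."""
    # Sorted alphabetically by path, so run _10 sorts before run _2
    return sorted(Path(log_dir, algo).glob(f"{env_id}_*/*.zip"))


for model_file in find_saved_models():
    print(model_file)
```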

## Let's evaluate our agent 👀
- RL-Baselines3-Zoo provides `enjoy.py`, a python script to evaluate our agent. In most RL libraries, we call the evaluation script `enjoy.py`.
- Let's evaluate it for 5000 timesteps 🔥

In [None]:
!python -m rl_zoo3.enjoy  --algo dqn  --env SpaceInvadersNoFrameskip-v4  --no-render  --n-timesteps 5000  --folder logs/

## Publish our trained model on the Hub 🚀
Now that we saw we got good results after the training, we can publish our trained model on the hub 🤗 with one line of code.

<img src="https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/unit3/space-invaders-model.gif" alt="Space Invaders model">

By using `rl_zoo3.push_to_hub` **you evaluate, record a replay, generate a model card of your agent and push it to the hub**.

This way:
- You can **showcase your work** 🔥
- You can **visualize your agent playing** 👀
- You can **share with the community an agent that others can use** 💾
- You can **access a leaderboard 🏆 to see how well your agent is performing compared to your classmates** 👉  https://huggingface.co/spaces/huggingface-projects/Deep-Reinforcement-Learning-Leaderboard

To be able to share your model with the community there are three more steps to follow:

1️⃣ (If it's not already done) create an account on HF ➡ https://huggingface.co/join

2️⃣ Sign in, then store your authentication token from the Hugging Face website.
- Create a new token (https://huggingface.co/settings/tokens) **with write role**

<img src="https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/create-token.jpg" alt="Create HF Token">

- Copy the token
- Run the cell below and paste the token

In [None]:
from huggingface_hub import notebook_login # To log to our Hugging Face account to be able to upload models to the Hub.
notebook_login()
!git config --global credential.helper store

If you don't want to use a Google Colab or a Jupyter Notebook, you need to use this command instead: `huggingface-cli login`

3️⃣ We're now ready to push our trained agent to the 🤗 Hub 🔥

Let's run the `push_to_hub.py` file to upload our trained agent to the Hub.

`--repo-name`: The name of the repo

`-orga`: Your Hugging Face username

`-f`: Where the trained model folder is (in our case `logs`)

<img src="https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/unit3/select-id.png" alt="Select Id">

In [None]:
!python -m rl_zoo3.push_to_hub  --algo dqn  --env SpaceInvadersNoFrameskip-v4  --repo-name dqn-SpaceInvadersNoFrameskip-v4 -orga MarioBarbeque -f logs/
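The cell above hardcodes one username after `-orga`, so you need to replace it with your own. As a small sketch, the helper below (a hypothetical name, not part of RL Baselines3 Zoo) assembles the same command from a username, using only the flags shown in that cell; `your-username` is a placeholder.

```python
import shlex


def build_push_command(username, algo="dqn",
                       env_id="SpaceInvadersNoFrameskip-v4", log_dir="logs/"):
    """Build the push_to_hub command line for a given Hugging Face username."""
    repo_name = f"{algo}-{env_id}"
    args = [
        "python", "-m", "rl_zoo3.push_to_hub",
        "--algo", algo,
        "--env", env_id,
        "--repo-name", repo_name,
        "-orga", username,
        "-f", log_dir,
    ]
    return shlex.join(args)


command = build_push_command("your-username")
print(command)
```

You can paste the printed command into a cell prefixed with `!` to run it.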

In [None]:
!cd /tmp && ls


Congrats 🥳 you've just trained and uploaded your first Deep Q-Learning agent using RL-Baselines-3 Zoo. The script above should have displayed a link to a model repository such as https://huggingface.co/ThomasSimonini/dqn-SpaceInvadersNoFrameskip-v4. When you go to this link, you can:

- See a **video preview of your agent** on the right.
- Click "Files and versions" to see all the files in the repository.
- Click "Use in stable-baselines3" to get a code snippet that shows how to load the model.
- Read the model card (`README.md` file), which gives a description of the model and the hyperparameters you used.

Under the hood, the Hub uses git-based repositories (don't worry if you don't know what git is), which means you can update the model with new versions as you experiment and improve your agent.

**Compare the results of your agents with your classmates** using the [leaderboard](https://huggingface.co/spaces/huggingface-projects/Deep-Reinforcement-Learning-Leaderboard) 🏆

## Load a powerful trained model 🔥
- The Stable-Baselines3 team uploaded **more than 150 trained Deep Reinforcement Learning agents on the Hub**.

You can find them here: 👉 https://huggingface.co/sb3

Some examples:
- Asteroids: https://huggingface.co/sb3/dqn-AsteroidsNoFrameskip-v4
- Beam Rider: https://huggingface.co/sb3/dqn-BeamRiderNoFrameskip-v4
- Breakout: https://huggingface.co/sb3/dqn-BreakoutNoFrameskip-v4
- Road Runner: https://huggingface.co/sb3/dqn-RoadRunnerNoFrameskip-v4

Let's load an agent playing Beam Rider: https://huggingface.co/sb3/dqn-BeamRiderNoFrameskip-v4

In [None]:
%%html
<video controls autoplay><source src="https://huggingface.co/sb3/dqn-BeamRiderNoFrameskip-v4/resolve/main/replay.mp4" type="video/mp4"></video>

1. We download the model using `rl_zoo3.load_from_hub`, and place it in a new folder that we can call `rl_trained`

In [None]:
# Download the model and save it into the rl_trained/ folder
!python -m rl_zoo3.load_from_hub --algo dqn --env BeamRiderNoFrameskip-v4 -orga sb3 -f rl_trained/

2. Let's evaluate it for 5000 timesteps

In [None]:
!python -m rl_zoo3.enjoy --algo dqn --env BeamRiderNoFrameskip-v4 -n 5000  -f rl_trained/ --no-render

Why not try to train your own **Deep Q-Learning Agent playing BeamRiderNoFrameskip-v4? 🏆**

If you want to try, check https://huggingface.co/sb3/dqn-BeamRiderNoFrameskip-v4#hyperparameters: **the model card lists the hyperparameters of the trained agent.**

But finding hyperparameters can be a daunting task. Fortunately, we'll see in the next Unit how we can **use Optuna for optimizing the hyperparameters 🔥.**


## Some additional challenges 🏆
The best way to learn **is to try things on your own**!

In the [Leaderboard](https://huggingface.co/spaces/huggingface-projects/Deep-Reinforcement-Learning-Leaderboard) you will find your agents. Can you get to the top?

Here's a list of environments you can try to train your agent with:
- BeamRiderNoFrameskip-v4
- BreakoutNoFrameskip-v4
- EnduroNoFrameskip-v4
- PongNoFrameskip-v4

Also, **if you want to learn to implement Deep Q-Learning by yourself**, you should definitely look at the CleanRL implementation: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py

<img src="https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit4/atari-envs.gif" alt="Environments"/>

________________________________________________________________________
Congrats on finishing this chapter!

If you still feel confused by all these elements... it's totally normal! **This was the same for me and for all people who studied RL.**

Take time to really **grasp the material before continuing and try the additional challenges**. It’s important to master these elements and have a solid foundation.

In the next unit, **we’re going to learn about [Optuna](https://optuna.org/)**. One of the most critical tasks in Deep Reinforcement Learning is to find a good set of training hyperparameters, and Optuna is a library that helps you automate the search.






### This is a course built with you 👷🏿‍♀️

Finally, we want to improve and update the course iteratively with your feedback. If you have some, please fill this form 👉 https://forms.gle/3HgA7bEHwAmmLfwh9


See you in Bonus Unit 2! 🔥

### Custom code to render an mp4 sample of our agent playing

In [None]:
# load the model from HF
!python -m rl_zoo3.load_from_hub --algo dqn --env SpaceInvadersNoFrameskip-v4 -orga MarioBarbeque -f logs/

In [None]:
!pip install ale-py[roms]
!pip install gymnasium[atari]
!pip install opencv-python

In [None]:
!pip install autorom[accept-rom-license]
# Then import to download ROMs
!python -c "import ale_py.roms"

In [None]:
import gymnasium as gym
print(gym.envs.registry.keys())
# Look for Space Invaders variants

In [None]:
!pip install stable-baselines3[extra]

In [None]:
import os
import cv2
import numpy as np
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.monitor import Monitor

def make_atari_env(env_id='SpaceInvadersNoFrameskip-v4', render_mode='rgb_array'):
    """
    Create Atari environment with the same preprocessing as used in training
    """
    def _init():
        env = gym.make(env_id, render_mode=render_mode)
        env = AtariWrapper(env)  # This applies the standard Atari preprocessing
        return env
    return _init

def record_preprocessed_agent(model_path, output_path="enjoy.mp4",
                            env_id='SpaceInvadersNoFrameskip-v4', n_episodes=1):
    """
    Record agent with proper Atari preprocessing pipeline
    """

    print(f"Creating environment: {env_id}")

    # Create vectorized environment with preprocessing (for model)
    env = DummyVecEnv([make_atari_env(env_id, render_mode='rgb_array')])
    env = VecFrameStack(env, n_stack=4)  # Stack 4 frames as expected by model

    # Create separate environment for rendering (raw frames)
    render_env = gym.make(env_id, render_mode='rgb_array')

    print(f"Loading model from: {model_path}")
    model = DQN.load(model_path, env=env)
    print("Model loaded successfully")

    # Recording setup
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = 60
    all_frames = []

    for episode in range(n_episodes):
        # Reset both environments
        obs = env.reset()
        render_obs, render_info = render_env.reset()

        done = False
        episode_frames = []
        episode_reward = 0
        step_count = 0

        print(f"\nRecording episode {episode + 1}/{n_episodes}...")
        print(f"Model input shape: {obs.shape}")
        print(f"Render frame shape: {render_obs.shape}")

        while not done:
            # Get action from model using preprocessed observations
            action, _states = model.predict(obs, deterministic=True)

            # Step the wrapped environment; a VecEnv returns arrays of length n_envs
            obs, reward, dones, info = env.step(action)
            done = bool(dones[0])

            episode_reward += reward[0]
            step_count += 1

            # Capture the full-resolution frame from the environment the agent is
            # actually playing. Stepping render_env with the same action would drift
            # out of sync, because AtariWrapper skips frames and adds random no-ops.
            try:
                frame = env.get_images()[0]
                if frame is not None and len(frame.shape) == 3:
                    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    episode_frames.append(frame_bgr)
            except Exception as e:
                if step_count == 1:
                    print(f"Warning: Could not capture frame: {e}")

            # Safety break for very long episodes
            if step_count > 50000:
                print("Episode exceeded 50000 steps, ending...")
                break

        all_frames.extend(episode_frames)
        print(f"Episode {episode + 1} completed:")
        print(f"  - Steps: {step_count}")
        print(f"  - Total Reward: {episode_reward:.2f}")
        print(f"  - Frames captured: {len(episode_frames)}")

    # Write video
    if all_frames:
        h, w = all_frames[0].shape[:2]
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

        print(f"\nWriting {len(all_frames)} frames to {output_path}...")
        for i, frame in enumerate(all_frames):
            out.write(frame)
            if i % 100 == 0:
                print(f"Progress: {i}/{len(all_frames)} frames")

        out.release()
        print(f"✓ Video saved to {output_path}")
    else:
        print("✗ No frames captured!")

    env.close()
    render_env.close()


if __name__ == "__main__":
    # Replace with your actual model path
    MODEL_PATH = "logs/dqn/SpaceInvadersNoFrameskip-v4_1/SpaceInvadersNoFrameskip-v4.zip"

    if os.path.exists(MODEL_PATH):
        record_preprocessed_agent(
            model_path=MODEL_PATH,
            output_path="enjoy.mp4",
            env_id='SpaceInvadersNoFrameskip-v4',
            n_episodes=1
        )
    else:
        print(f"Model file not found: {MODEL_PATH}")
        print("Please update MODEL_PATH with the correct path to your .zip file")

        # List files in logs directory to help find the model
        if os.path.exists("logs"):
            print("\nAvailable files in logs/:")
            for root, dirs, files in os.walk("logs"):
                for file in files:
                    if file.endswith('.zip'):
                        print(f"  {os.path.join(root, file)}")

In [None]:
import os
import cv2
import numpy as np
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from stable_baselines3.common.atari_wrappers import AtariWrapper

def record_with_proper_env_matching(model_path, output_path="spaceinvaders_demo.mp4",
                                   env_id='SpaceInvadersNoFrameskip-v4', n_episodes=1):
    """
    Record agent ensuring exact environment matching with training
    """

    print(f"Loading model to inspect training environment: {model_path}")

    # First, load the model without any environment to inspect its expected spaces
    model_data = DQN.load(model_path, device='cpu')
    expected_obs_space = model_data.observation_space
    expected_action_space = model_data.action_space

    print(f"Model expects observation space: {expected_obs_space}")
    print(f"Model expects action space: {expected_action_space}")

    # Create environment that matches the model's expectations
    def make_training_env():
        env = gym.make(env_id, render_mode='rgb_array')
        env = AtariWrapper(env)
        return env

    # Create vectorized environment with frame stacking
    vec_env = DummyVecEnv([make_training_env])
    vec_env = VecFrameStack(vec_env, n_stack=4)

    print(f"Created environment observation space: {vec_env.observation_space}")
    print(f"Created environment action space: {vec_env.action_space}")

    # Verify spaces match
    if vec_env.observation_space != expected_obs_space:
        print("WARNING: Observation spaces don't match exactly!")
        print(f"Expected: {expected_obs_space}")
        print(f"Got: {vec_env.observation_space}")
        print("Attempting to load with force_reset=True...")

    # Load model with the matching environment
    try:
        model = DQN.load(model_path, env=vec_env, force_reset=True)
        print("Model loaded successfully with force_reset=True")
    except Exception as e:
        print(f"Failed to load with environment: {e}")
        print("Loading model without environment and setting manually...")
        model = DQN.load(model_path)
        model.set_env(vec_env)

    # Create separate high-resolution environment for rendering
    render_env = gym.make(env_id, render_mode='rgb_array')

    # Recording setup
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = 30
    all_frames = []

    for episode in range(n_episodes):
        # Reset both environments
        obs = vec_env.reset()
        render_obs, render_info = render_env.reset()

        done = False
        episode_frames = []
        episode_reward = 0
        step_count = 0

        print(f"\nRecording episode {episode + 1}/{n_episodes}...")
        print(f"Vectorized env obs shape: {obs.shape}")
        print(f"Render env obs shape: {render_obs.shape}")

        while not done:
            # Get action from model using preprocessed observations
            action, _states = model.predict(obs, deterministic=True)

            # Step vectorized environment; it returns arrays of length n_envs
            obs, reward, done_vec, info = vec_env.step(action)
            done = bool(done_vec[0])

            episode_reward += reward[0]
            step_count += 1

            # Capture the full-resolution frame from the environment the agent is
            # actually playing. Stepping render_env with the same action would drift
            # out of sync, because AtariWrapper skips frames and adds random no-ops.
            try:
                frame = vec_env.get_images()[0]
                if frame is not None and len(frame.shape) == 3:
                    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    episode_frames.append(frame_bgr)
            except Exception as e:
                if step_count == 1:
                    print(f"Warning: Could not capture frame: {e}")

            # Safety break
            if step_count > 10000:
                print("Episode exceeded 10000 steps, ending...")
                break

        all_frames.extend(episode_frames)
        print(f"Episode {episode + 1} completed:")
        print(f"  - Steps: {step_count}")
        print(f"  - Total Reward: {episode_reward:.2f}")
        print(f"  - Frames captured: {len(episode_frames)}")

    # Write video
    if all_frames:
        h, w = all_frames[0].shape[:2]
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

        print(f"\nWriting {len(all_frames)} frames to {output_path}...")
        for i, frame in enumerate(all_frames):
            out.write(frame)
            if i % 100 == 0:
                print(f"Progress: {i}/{len(all_frames)} frames")

        out.release()
        print(f"✓ Video saved to {output_path}")
    else:
        print("✗ No frames captured!")

    vec_env.close()
    render_env.close()

def record_using_rl_zoo_config(model_path, output_path="spaceinvaders_demo.mp4",
                              env_id='SpaceInvadersNoFrameskip-v4', n_episodes=1):
    """
    Use RL Zoo3's configuration files to recreate exact training environment
    """
    try:
        from rl_zoo3 import ALGOS
        from rl_zoo3.utils import get_saved_hyperparams
        from stable_baselines3.common.utils import set_random_seed
        import yaml

        print("Attempting to use RL Zoo3 configuration...")

        # Set random seed
        set_random_seed(0)

        # Get model directory
        model_dir = os.path.dirname(model_path)

        # Try to load hyperparameters
        try:
            hyperparams, stats_path = get_saved_hyperparams(
                model_dir,
                norm_reward=False,
                test_mode=True
            )
            print(f"Loaded hyperparameters: {hyperparams}")
        except Exception as e:
            print(f"Could not load hyperparameters: {e}")
            hyperparams = {}
            stats_path = None

        # Load model without environment first
        print("Loading model without environment...")
        model = ALGOS['dqn'].load(model_path)

        # Create environment manually with Atari preprocessing
        def make_env():
            env = gym.make(env_id, render_mode='rgb_array')
            env = AtariWrapper(env)
            return env

        vec_env = DummyVecEnv([make_env])
        vec_env = VecFrameStack(vec_env, n_stack=4)

        # Set the environment on the model
        print("Setting environment on model...")
        model.set_env(vec_env)

        # Create render environment
        render_env = gym.make(env_id, render_mode='rgb_array')

        # Continue with recording as before...
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = 30
        all_frames = []

        for episode in range(n_episodes):
            obs = vec_env.reset()
            render_obs, render_info = render_env.reset()

            done = False
            episode_frames = []
            episode_reward = 0
            step_count = 0

            print(f"\nRecording episode {episode + 1}/{n_episodes}...")

            while not done:
                action, _states = model.predict(obs, deterministic=True)

                # VecEnv.step returns arrays of length n_envs
                obs, reward, done_vec, info = vec_env.step(action)
                done = bool(done_vec[0])

                episode_reward += reward[0]
                step_count += 1

                # Capture the frame from the environment the agent is actually
                # playing; a separately stepped render_env would drift out of sync.
                try:
                    frame = vec_env.get_images()[0]
                    if frame is not None and len(frame.shape) == 3:
                        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                        episode_frames.append(frame_bgr)
                except Exception as e:
                    if step_count == 1:
                        print(f"Warning: Could not capture frame: {e}")

                # Safety break
                if step_count > 10000:
                    break

            all_frames.extend(episode_frames)
            print(f"Episode {episode + 1} completed:")
            print(f"  - Steps: {step_count}")
            print(f"  - Total Reward: {episode_reward:.2f}")
            print(f"  - Frames captured: {len(episode_frames)}")

        if all_frames:
            h, w = all_frames[0].shape[:2]
            out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

            print(f"\nWriting {len(all_frames)} frames to {output_path}...")
            for frame in all_frames:
                out.write(frame)

            out.release()
            print(f"✓ Video saved to {output_path}")

        vec_env.close()
        render_env.close()

    except ImportError as e:
        print(f"RL Zoo3 not available: {e}")
        print("Falling back to manual environment creation...")
        record_with_proper_env_matching(model_path, output_path, env_id, n_episodes)
    except Exception as e:
        print(f"RL Zoo3 approach failed: {e}")
        print("Falling back to manual environment creation...")
        record_with_proper_env_matching(model_path, output_path, env_id, n_episodes)

if __name__ == "__main__":
    MODEL_PATH = "logs/dqn/SpaceInvadersNoFrameskip-v4_1/SpaceInvadersNoFrameskip-v4.zip"

    if os.path.exists(MODEL_PATH):
        print("Trying RL Zoo3 configuration approach...")
        record_using_rl_zoo_config(
            model_path=MODEL_PATH,
            output_path="RL_Zoo3_enjoy.mp4",
            env_id='SpaceInvadersNoFrameskip-v4',
            n_episodes=1
        )
    else:
        print(f"Model file not found: {MODEL_PATH}")

        # Help find the correct path
        if os.path.exists("logs"):
            print("\nAvailable model files:")
            for root, dirs, files in os.walk("logs"):
                for file in files:
                    if file.endswith('.zip'):
                        print(f"  {os.path.join(root, file)}")

In [None]:
!python -m rl_zoo3.enjoy  --algo dqn  --env SpaceInvadersNoFrameskip-v4 --no-render --n-timesteps 5000 --folder logs/

In [None]:
import os
import cv2
import numpy as np
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.monitor import Monitor

def compare_environments(env_id='SpaceInvadersNoFrameskip-v4'):
    """
    Compare different environment setups to understand the discrepancy
    """
    print("=== ENVIRONMENT COMPARISON ===")

    # 1. Raw environment (no wrappers)
    print("\n1. Raw Environment:")
    raw_env = gym.make(env_id, render_mode='rgb_array')
    print(f"   Observation space: {raw_env.observation_space}")
    print(f"   Action space: {raw_env.action_space}")
    print(f"   Max episode steps: {raw_env.spec.max_episode_steps}")
    raw_env.close()

    # 2. AtariWrapper only
    print("\n2. AtariWrapper Environment:")
    atari_env = gym.make(env_id, render_mode='rgb_array')
    atari_env = AtariWrapper(atari_env)
    print(f"   Observation space: {atari_env.observation_space}")
    print(f"   Action space: {atari_env.action_space}")
    print(f"   Max episode steps: {getattr(atari_env.spec, 'max_episode_steps', 'Unknown')}")
    atari_env.close()

    # 3. Vectorized + FrameStack
    print("\n3. Vectorized + FrameStack Environment:")
    def make_env():
        env = gym.make(env_id, render_mode='rgb_array')
        env = AtariWrapper(env)
        return env

    vec_env = DummyVecEnv([make_env])
    vec_env = VecFrameStack(vec_env, n_stack=4)
    print(f"   Observation space: {vec_env.observation_space}")
    print(f"   Action space: {vec_env.action_space}")
    vec_env.close()

def diagnostic_record(model_path, output_path="diagnostic_spaceinvaders.mp4",
                     env_id='SpaceInvadersNoFrameskip-v4', n_episodes=3):
    """
    Record with detailed diagnostics to understand the reward/episode length discrepancy
    """

    print("=== DIAGNOSTIC RECORDING ===")
    compare_environments(env_id)

    # Create environments
    def make_model_env():
        env = gym.make(env_id, render_mode='rgb_array')
        env = AtariWrapper(env)
        return env

    model_env = DummyVecEnv([make_model_env])
    model_env = VecFrameStack(model_env, n_stack=4)

    # Create render environment WITHOUT AtariWrapper to see raw rewards
    render_env = gym.make(env_id, render_mode='rgb_array')

    # Also create an AtariWrapper render env for comparison
    atari_render_env = gym.make(env_id, render_mode='rgb_array')
    atari_render_env = AtariWrapper(atari_render_env)

    print(f"\nModel environment obs space: {model_env.observation_space}")
    print(f"Raw render environment obs space: {render_env.observation_space}")
    print(f"AtariWrapper render environment obs space: {atari_render_env.observation_space}")

    # Load model
    print(f"\nLoading model: {model_path}")
    model = DQN.load(model_path, env=model_env, force_reset=True)

    # Video recording setup
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = 60
    all_frames = []

    for episode in range(n_episodes):
        print(f"\n{'='*50}")
        print(f"EPISODE {episode + 1}/{n_episodes}")
        print(f"{'='*50}")

        # Reset all environments
        model_obs = model_env.reset()
        render_obs, render_info = render_env.reset()
        atari_obs, atari_info = atari_render_env.reset()

        done = False
        episode_frames = []

        # Tracking variables
        model_reward = 0
        raw_reward = 0
        atari_reward = 0
        step_count = 0
        lives_info = []

        while not done:
            # Get action from model
            action, _states = model.predict(model_obs, deterministic=True)

            # Step model environment
            model_obs, m_reward, done_vec, m_info = model_env.step(action)
            done = done_vec[0]
            model_reward += m_reward[0]

            # Step raw render environment
            render_action = action[0] if isinstance(action, np.ndarray) else action
            render_obs, r_reward, r_term, r_trunc, r_info = render_env.step(render_action)
            raw_reward += r_reward

            # Step AtariWrapper render environment
            atari_obs, a_reward, a_term, a_trunc, a_info = atari_render_env.step(render_action)
            atari_reward += a_reward

            step_count += 1

            # Capture frame
            frame = render_env.render()
            if frame is not None:
                frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                episode_frames.append(frame_bgr)

            # Log detailed info every 50 steps or when interesting things happen.
            # ALE reports 'lives' in info on every step, so compare against the last
            # recorded value instead of testing for the key (which would log every step).
            lives_changed = ('lives' in r_info and
                             (not lives_info or lives_info[-1][1] != r_info['lives']))
            if (step_count % 50 == 0 or
                lives_changed or
                abs(r_reward) > 0 or
                r_term or r_trunc or
                done):

                print(f"Step {step_count:4d}:")
                print(f"  Model: reward={m_reward[0]:6.1f}, done={done}, info={m_info}")
                print(f"  Raw:   reward={r_reward:6.1f}, term={r_term}, trunc={r_trunc}, info={r_info}")
                print(f"  Atari: reward={a_reward:6.1f}, term={a_term}, trunc={a_trunc}, info={a_info}")

                # Record the life count only when it changes
                if lives_changed:
                    lives_info.append((step_count, r_info['lives']))
                    print(f"  Lives: {r_info['lives']}")

            # Safety break
            if step_count > 20000:
                print(f"Safety break at {step_count} steps")
                break

        all_frames.extend(episode_frames)

        print(f"\nEPISODE {episode + 1} SUMMARY:")
        print(f"  Steps: {step_count}")
        print(f"  Model Environment Reward: {model_reward:.2f}")
        print(f"  Raw Environment Reward: {raw_reward:.2f}")
        print(f"  AtariWrapper Environment Reward: {atari_reward:.2f}")
        print(f"  Frames captured: {len(episode_frames)}")
        print(f"  Lives info: {lives_info}")
        print(f"  Final done reason: Model done={done}")
        print(f"  Raw env final state: term={r_term}, trunc={r_trunc}")
        print(f"  Atari env final state: term={a_term}, trunc={a_trunc}")

    # Write video
    if all_frames:
        print(f"\nWriting video with {len(all_frames)} frames...")
        h, w = all_frames[0].shape[:2]
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

        for frame in all_frames:
            out.write(frame)

        out.release()
        print(f"✅ Diagnostic video saved to: {output_path}")

    # Cleanup
    model_env.close()
    render_env.close()
    atari_render_env.close()

def test_rl_zoo_environment_creation(model_path, env_id='SpaceInvadersNoFrameskip-v4'):
    """
    Try to recreate the exact environment setup used by RL Zoo3
    """
    print("\n=== TESTING RL ZOO3 ENVIRONMENT SETUP ===")

    try:
        from rl_zoo3.utils import get_saved_hyperparams, create_test_env
        from stable_baselines3.common.utils import set_random_seed

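        model_dir = os.path.dirname(model_path)
        # RL Zoo3 saves config.yml in a sub-folder of the run directory named after the environment
        stats_dir = os.path.join(model_dir, env_id)

        # Get hyperparameters
        try:
            hyperparams, stats_path = get_saved_hyperparams(
                stats_dir,
                norm_reward=False,
                test_mode=True
            )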
            print(f"Loaded hyperparams: {hyperparams}")
            print(f"Stats path: {stats_path}")
        except Exception as e:
            print(f"Could not load hyperparams: {e}")
            hyperparams = {}
            stats_path = None

        # Create test environment like RL Zoo3 does
        env = create_test_env(
            env_id,
            n_envs=1,
            stats_path=stats_path,
            seed=0,
            log_dir=None,
            should_render=True,
            hyperparams=hyperparams,
            env_kwargs={}
        )

        print(f"RL Zoo3 test env observation space: {env.observation_space}")
        print(f"RL Zoo3 test env action space: {env.action_space}")

        # Test a few steps. create_test_env returns an SB3 VecEnv, so reset() returns
        # only the observations and step() returns (obs, rewards, dones, infos).
        obs = env.reset()
        print(f"Initial obs shape: {obs.shape}")

        for i in range(10):
            action = env.action_space.sample()
            obs, reward, done, info = env.step([action])
            print(f"Step {i+1}: reward={reward[0]:.1f}, done={done[0]}, info={info[0]}")

            if done[0]:
                break

        env.close()

    except ImportError:
        print("RL Zoo3 not available for environment comparison")
    except Exception as e:
        print(f"RL Zoo3 environment test failed: {e}")

if __name__ == "__main__":
    MODEL_PATH = "logs/dqn/SpaceInvadersNoFrameskip-v4_1/SpaceInvadersNoFrameskip-v4.zip"

    if os.path.exists(MODEL_PATH):
        # First test RL Zoo3 environment setup
        test_rl_zoo_environment_creation(MODEL_PATH)

        # Then run diagnostic recording
        diagnostic_record(
            model_path=MODEL_PATH,
            output_path="diagnostic_spaceinvaders.mp4",
            env_id='SpaceInvadersNoFrameskip-v4',
            n_episodes=2
        )
    else:
        print(f"Model not found: {MODEL_PATH}")
        print("Please update MODEL_PATH")
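
One likely contributor to the reward discrepancy that the diagnostic cell above prints is reward clipping: with its default settings, `AtariWrapper` maps every reward to its sign, while the raw environment reports the actual game points. The sketch below reproduces that effect in plain Python. The reward values are hypothetical and chosen only for illustration, not taken from a real run, and `clip_reward` is an illustrative helper rather than the library's own function.

```python
def clip_reward(reward: float) -> float:
    """Map a reward to its sign (-1.0, 0.0 or 1.0), as sign-based reward clipping does."""
    return float((reward > 0) - (reward < 0))


# Hypothetical per-step game rewards (illustration only, not measured)
raw_rewards = [0, 5, 0, 10, 0, 0, 30, 15, 0, 200]

raw_score = sum(raw_rewards)                             # what the raw environment accumulates
clipped_sum = sum(clip_reward(r) for r in raw_rewards)   # what the wrapped environment accumulates

print(f"Raw game score:     {raw_score}")
print(f"Clipped reward sum: {clipped_sum}")
```

Under clipping, the summed reward counts scoring events rather than points, so it is expected to be much smaller than the on-screen score.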

In [None]:
!python -m rl_zoo3.enjoy --algo dqn --env SpaceInvadersNoFrameskip-v4 --no-render --n-timesteps 10000 --folder logs/ --verbose 1

In [None]:
import os
import cv2
import numpy as np
import gymnasium as gym
from stable_baselines3 import DQN

def create_rl_zoo_compatible_env(env_id='SpaceInvadersNoFrameskip-v4'):
    """
    Create environment setup that matches RL Zoo3's enjoy script
    """
    try:
        from rl_zoo3.utils import get_saved_hyperparams, create_test_env
        from rl_zoo3 import ALGOS

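        # Approximates how the RL Zoo3 enjoy script builds its environment
        print("Creating environment using RL Zoo3 utilities...")

        # Create test environment. No saved hyperparameters (env_wrapper, frame_stack)
        # are passed, so this is a raw environment with no AtariWrapper or frame stacking.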
        env = create_test_env(
            env_id,
            n_envs=1,
            stats_path=None,
            seed=0,
            log_dir=None,
            should_render=True,
            hyperparams={'normalize': False},
            env_kwargs={}
        )

        print(f"RL Zoo3 env observation space: {env.observation_space}")
        return env, "rl_zoo3"

    except ImportError:
        print("RL Zoo3 not available, using manual setup...")
        return None, "manual"

def record_with_rl_zoo_matching(model_path, output_path="rl_zoo_matching.mp4",
                               env_id='SpaceInvadersNoFrameskip-v4', n_episodes=1):
    """
    Record using the exact same setup as RL Zoo3's enjoy script
    """

    print("=== ATTEMPTING RL ZOO3 MATCHING SETUP ===")

    # Try to create RL Zoo3 compatible environment
    rl_zoo_env, env_type = create_rl_zoo_compatible_env(env_id)

    if rl_zoo_env is not None:
        print("✓ Using RL Zoo3 environment setup")

        # Load model with the RL Zoo3 environment
        try:
            # Load model without environment first to avoid conflicts
            model = DQN.load(model_path)
            print("✓ Model loaded without environment")

            # Set the environment
            model.set_env(rl_zoo_env)
            print("✓ Environment set on model")

        except Exception as e:
            print(f"Error with RL Zoo3 setup: {e}")
            print("Falling back to force_reset approach...")
            try:
                model = DQN.load(model_path, env=rl_zoo_env, force_reset=True)
                print("✓ Model loaded with force_reset=True")
            except Exception as e2:
                print(f"Force reset also failed: {e2}")
                rl_zoo_env.close()
                return record_with_manual_setup(model_path, output_path, env_id, n_episodes)

        # Create separate render environment (raw, for high-res video)
        render_env = gym.make(env_id, render_mode='rgb_array')

        # Run recording
        return run_recording_session(model, rl_zoo_env, render_env, output_path, n_episodes, "RL Zoo3")

    else:
        print("Falling back to manual environment setup...")
        return record_with_manual_setup(model_path, output_path, env_id, n_episodes)

def record_with_manual_setup(model_path, output_path, env_id, n_episodes):
    """
    Fallback manual setup that tries to match RL Zoo3 behavior
    """
    print("\n=== USING MANUAL SETUP ===")

    from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
    from stable_baselines3.common.atari_wrappers import AtariWrapper

    # Create environment for model (preprocessed)
    def make_env():
        env = gym.make(env_id, render_mode='rgb_array')
        env = AtariWrapper(env)
        return env

    model_env = DummyVecEnv([make_env])
    model_env = VecFrameStack(model_env, n_stack=4)

    # Load model
    model = DQN.load(model_path, env=model_env, force_reset=True)

    # Create render environment (raw)
    render_env = gym.make(env_id, render_mode='rgb_array')

    return run_recording_session(model, model_env, render_env, output_path, n_episodes, "Manual")

def run_recording_session(model, model_env, render_env, output_path, n_episodes, setup_type):
    """
    Run the actual recording session with detailed logging
    """
    print(f"\n=== RECORDING WITH {setup_type} SETUP ===")
    print(f"Model env obs space: {model_env.observation_space}")
    print(f"Render env obs space: {render_env.observation_space}")

    # Video setup
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = 30
    all_frames = []

    for episode in range(n_episodes):
        print(f"\n--- Episode {episode + 1}/{n_episodes} ---")

        # Reset environments
        try:
            model_reset = model_env.reset()
            if isinstance(model_reset, tuple):
                model_obs, model_info = model_reset
            else:
                model_obs = model_reset
                model_info = {}
        except Exception as e:
            print(f"Model env reset error: {e}")
            continue

        render_obs, render_info = render_env.reset()

        print(f"Model obs shape: {model_obs.shape if hasattr(model_obs, 'shape') else type(model_obs)}")
        print(f"Render obs shape: {render_obs.shape}")

        done = False
        episode_frames = []
        model_reward = 0
        render_reward = 0
        step_count = 0
        action_log = []

        while not done and step_count < 15000:
            # Get action from model
            action, _states = model.predict(model_obs, deterministic=True)
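            # Store a plain int: NumPy arrays are unhashable and would break set() in the summary
            action_log.append(int(np.asarray(action).flat[0]))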

            # Step model environment
            try:
                model_result = model_env.step(action)

                if len(model_result) == 5:
                    model_obs, m_reward, m_terminated, m_truncated, m_info = model_result
                    if hasattr(m_terminated, '__len__'):
                        m_done = m_terminated[0] or m_truncated[0]
                    else:
                        m_done = m_terminated or m_truncated
                elif len(model_result) == 4:
                    model_obs, m_reward, m_done, m_info = model_result
                    if hasattr(m_done, '__len__'):
                        m_done = m_done[0]

                if hasattr(m_reward, '__len__'):
                    m_reward = m_reward[0]

                model_reward += m_reward
                done = m_done

            except Exception as e:
                print(f"Model step error at step {step_count}: {e}")
                break

            # Step render environment with same action
            render_action = action[0] if hasattr(action, '__len__') else action
            render_obs, r_reward, r_term, r_trunc, r_info = render_env.step(render_action)
            render_reward += r_reward

            step_count += 1

            # Capture frame
            try:
                frame = render_env.render()
                if frame is not None:
                    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    episode_frames.append(frame_bgr)
            except Exception as e:
                if step_count <= 3:
                    print(f"Frame capture error: {e}")

            # Log interesting events
            if (step_count <= 10 or
                step_count % 500 == 0 or
                abs(r_reward) > 0 or
                abs(m_reward) > 0 or
                done):

                print(f"Step {step_count:4d}: model_r={m_reward:5.1f}, render_r={r_reward:5.1f}, done={done}, action={render_action}")

                if 'lives' in r_info:
                    print(f"           lives={r_info['lives']}")

        all_frames.extend(episode_frames)

        print(f"\nEpisode {episode + 1} Results:")
        print(f"  Duration: {step_count} steps")
        print(f"  Model Total Reward: {model_reward:.2f}")
        print(f"  Render Total Reward: {render_reward:.2f}")
        print(f"  Frames: {len(episode_frames)}")
        print(f"  Ended: {'Done flag' if done else 'Safety limit'}")
        print(f"  Actions used: {len(set(action_log))} unique actions out of {len(action_log)}")

    # Write video
    if all_frames:
        print(f"\nSaving video with {len(all_frames)} frames...")
        h, w = all_frames[0].shape[:2]
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

        for frame in all_frames:
            out.write(frame)

        out.release()
        print(f"✅ Video saved: {output_path}")

        # File info
        file_size = os.path.getsize(output_path) / (1024 * 1024)
        print(f"📁 Size: {file_size:.1f} MB, Length: {len(all_frames)/fps:.1f}s")

    # Cleanup
    model_env.close()
    render_env.close()

    return True

if __name__ == "__main__":
    MODEL_PATH = "logs/dqn/SpaceInvadersNoFrameskip-v4_1/SpaceInvadersNoFrameskip-v4.zip"

    if os.path.exists(MODEL_PATH):
        success = record_with_rl_zoo_matching(
            model_path=MODEL_PATH,
            output_path="rl_zoo_matching_spaceinvaders.mp4",
            env_id='SpaceInvadersNoFrameskip-v4',
            n_episodes=1
        )

        if success:
            print("\n🎯 Recording completed! Compare this with RL Zoo3 enjoy results.")
        else:
            print("\n❌ Recording failed. Check the error messages above.")
    else:
        print(f"Model not found: {MODEL_PATH}")
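
The recording helpers above step a wrapped model environment and a raw render environment with the same action. A point worth keeping in mind is that `AtariWrapper` repeats each action for several emulator frames (4 by default), whereas a `NoFrameskip` environment advances one frame per step, so the two drift apart. The following sketch does the arithmetic. `FRAME_SKIP` is assumed to be the wrapper default rather than read from the saved config, and the episode length is a hypothetical number.

```python
FRAME_SKIP = 4  # assumed: AtariWrapper's default frame_skip, not read from the saved config


def emulator_frames(agent_steps: int, frames_per_step: int) -> int:
    """Number of emulator frames advanced after a given number of env.step() calls."""
    return agent_steps * frames_per_step


def video_seconds(n_frames: int, fps: int) -> float:
    """Playback length of a video holding n_frames at the given frame rate."""
    return n_frames / fps


agent_steps = 1500  # hypothetical episode length in agent steps

wrapped_frames = emulator_frames(agent_steps, FRAME_SKIP)
raw_frames = emulator_frames(agent_steps, 1)

print(f"Wrapped env advanced {wrapped_frames} emulator frames")
print(f"Raw env advanced     {raw_frames} emulator frames")
print(f"Video length at 30 fps: {video_seconds(raw_frames, 30):.1f}s")
```

Because the raw environment falls behind by a factor of `FRAME_SKIP`, the frames it renders do not show the game state the model is actually reacting to. Rendering from the model environment itself avoids that mismatch.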

In [None]:
import os
import cv2
import numpy as np
import torch as th
from stable_baselines3.common.utils import set_random_seed

def record_exactly_like_rl_zoo3(model_path, output_path="true_rl_zoo3_match.mp4",
                               env_id='SpaceInvadersNoFrameskip-v4', n_timesteps=5000):
    """
    Record agent using EXACTLY the same setup as RL Zoo3's enjoy script
    """

    print("=== EXACT RL ZOO3 REPLICATION ===")

    try:
        from rl_zoo3 import ALGOS, create_test_env, get_saved_hyperparams
        from rl_zoo3.exp_manager import ExperimentManager
        from rl_zoo3.utils import get_model_path
        import yaml

        # Replicate RL Zoo3's setup exactly
        algo = 'dqn'
        seed = 0
        set_random_seed(seed)

        # Get model path info (like RL Zoo3 does).
        # For logs/<algo>/<env>_<id>/<env>.zip, the run folder is the model's own directory.
        model_dir = os.path.dirname(model_path)
        log_path = model_dir

        print(f"Model path: {model_path}")
        print(f"Log path: {log_path}")

        # Check if this is Atari (like RL Zoo3 does)
        is_atari = ExperimentManager.is_atari(env_id)
        print(f"Is Atari: {is_atari}")

        # Get hyperparams exactly like RL Zoo3
        stats_path = os.path.join(log_path, env_id)
        hyperparams, maybe_stats_path = get_saved_hyperparams(
            stats_path,
            norm_reward=False,
            test_mode=True
        )

        print(f"Hyperparams: {hyperparams}")
        print(f"Stats path: {maybe_stats_path}")

        # Load env_kwargs exactly like RL Zoo3
        env_kwargs = {}
        args_path = os.path.join(log_path, env_id, "args.yml")
        if os.path.isfile(args_path):
            with open(args_path) as f:
                loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader)
                if loaded_args.get("env_kwargs") is not None:
                    env_kwargs = loaded_args["env_kwargs"]

        print(f"Env kwargs: {env_kwargs}")

        # Create environment EXACTLY like RL Zoo3
        env = create_test_env(
            env_id,
            n_envs=1,
            stats_path=maybe_stats_path,
            seed=seed,
            log_dir=None,
            should_render=True,  # Enable rendering
            hyperparams=hyperparams,
            env_kwargs=env_kwargs,
            vec_env_cls=ExperimentManager.default_vec_env_cls,
        )

        print(f"Environment observation space: {env.observation_space}")
        print(f"Environment action space: {env.action_space}")

        # Load model exactly like RL Zoo3
        kwargs = dict(seed=seed)

        # Off-policy algorithm handling (like RL Zoo3)
        off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]
        if algo in off_policy_algos:
            kwargs.update(dict(buffer_size=1))
            if "optimize_memory_usage" in hyperparams:
                kwargs.update(optimize_memory_usage=False)

        # Custom objects (like RL Zoo3)
        custom_objects = {
            "learning_rate": 0.0,
            "lr_schedule": lambda _: 0.0,
            "clip_range": lambda _: 0.0,
        }

        model = ALGOS[algo].load(
            model_path,
            custom_objects=custom_objects,
            device='auto',
            **kwargs
        )

        print("✓ Model loaded successfully")

        # Recording setup
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = 30
        all_frames = []

        # Reset environment
        obs = env.reset()

        # Tracking variables (exactly like RL Zoo3)
        episode_reward = 0.0
        episode_rewards, episode_lengths = [], []
        ep_len = 0
        lstm_states = None
        episode_start = np.ones((env.num_envs,), dtype=bool)

        # Atari-specific tracking
        atari_scores = []
        atari_lengths = []

        print(f"\nStarting recording for {n_timesteps} timesteps...")
        print("Deterministic actions (like RL Zoo3 for Atari)")

        for timestep in range(n_timesteps):
            # Get action exactly like RL Zoo3
            action, lstm_states = model.predict(
                obs,
                state=lstm_states,
                episode_start=episode_start,
                deterministic=True,  # Deterministic for Atari (like RL Zoo3)
            )

            # Step environment
            obs, reward, done, infos = env.step(action)
            episode_start = done

            # Capture frame for video
            try:
                frame = env.render('rgb_array')
                if frame is not None:
                    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    all_frames.append(frame_bgr)
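            except Exception as e:
                # Report the first few failures instead of silently swallowing every error
                if timestep < 5:
                    print(f"Frame capture warning: {e}")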

            # Track rewards exactly like RL Zoo3
            episode_reward += reward[0]
            ep_len += 1

            # CRITICAL: Atari-specific handling (like RL Zoo3)
            if env.num_envs == 1:
                if is_atari and infos is not None:
                    episode_infos = infos[0].get("episode")
                    if episode_infos is not None:
                        # This is the TRUE Atari score that RL Zoo3 prints!
                        atari_score = episode_infos['r']
                        atari_length = episode_infos['l']

                        print(f"Atari Episode Score: {atari_score:.2f}")
                        print(f"Atari Episode Length: {atari_length}")

                        atari_scores.append(atari_score)
                        atari_lengths.append(atari_length)

                        # Reset counters
                        episode_reward = 0.0
                        ep_len = 0

                elif done and not is_atari:
                    print(f"Episode Reward: {episode_reward:.2f}")
                    print(f"Episode Length: {ep_len}")
                    episode_rewards.append(episode_reward)
                    episode_lengths.append(ep_len)
                    episode_reward = 0.0
                    ep_len = 0

            # Progress indicator
            if timestep % 1000 == 0:
                print(f"Progress: {timestep}/{n_timesteps} timesteps")

        # Final summary
        print(f"\n=== RECORDING SUMMARY ===")
        if atari_scores:
            print(f"Atari Episodes Recorded: {len(atari_scores)}")
            print(f"Atari Scores: {atari_scores}")
            print(f"Atari Lengths: {atari_lengths}")
            print(f"Mean Atari Score: {np.mean(atari_scores):.2f} +/- {np.std(atari_scores):.2f}")
            print(f"Mean Atari Length: {np.mean(atari_lengths):.2f} +/- {np.std(atari_lengths):.2f}")

        # Write video
        if all_frames:
            print(f"\nSaving video with {len(all_frames)} frames...")
            h, w = all_frames[0].shape[:2]
            out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

            for i, frame in enumerate(all_frames):
                out.write(frame)
                if i % 1000 == 0:
                    print(f"Writing progress: {i}/{len(all_frames)} frames")

            out.release()
            print(f"✅ Video saved: {output_path}")

            # File info
            file_size = os.path.getsize(output_path) / (1024 * 1024)
            duration = len(all_frames) / fps
            print(f"📁 Size: {file_size:.1f} MB")
            print(f"⏱️  Duration: {duration:.1f} seconds")
            print(f"🎮 Episodes in video: {len(atari_scores)}")
        else:
            print("❌ No frames captured!")

        env.close()
        return True

    except ImportError as e:
        print(f"RL Zoo3 not available: {e}")
        return False
    except Exception as e:
        print(f"Error in RL Zoo3 replication: {e}")
        import traceback
        traceback.print_exc()
        return False

def simple_comparison_test(model_path, env_id='SpaceInvadersNoFrameskip-v4'):
    """
    Quick test to compare reward types
    """
    print("=== QUICK REWARD COMPARISON TEST ===")

    try:
        from rl_zoo3 import ALGOS, create_test_env, get_saved_hyperparams
        from rl_zoo3.exp_manager import ExperimentManager

        # Create environment like RL Zoo3
        # The run folder is the model's own directory; config.yml lives in <run folder>/<env_id>
        model_dir = os.path.dirname(model_path)
        log_path = model_dir
        stats_path = os.path.join(log_path, env_id)
        hyperparams, maybe_stats_path = get_saved_hyperparams(stats_path, norm_reward=False, test_mode=True)

        env = create_test_env(
            env_id, n_envs=1, stats_path=maybe_stats_path, seed=0, log_dir=None,
            should_render=False, hyperparams=hyperparams, env_kwargs={},
            vec_env_cls=ExperimentManager.default_vec_env_cls,
        )

        model = ALGOS['dqn'].load(model_path, buffer_size=1)

        obs = env.reset()
        step_count = 0
        episode_reward = 0

        print("Running quick test to compare reward values...")

        for _ in range(1000):
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, infos = env.step(action)

            episode_reward += reward[0]
            step_count += 1

            if infos and infos[0]:
                if 'episode' in infos[0]:
                    episode_info = infos[0]['episode']
                    print(f"\n🎯 FOUND THE DIFFERENCE!")
                    print(f"   Wrapper reward sum: {episode_reward:.2f}")
                    print(f"   True Atari score: {episode_info['r']:.2f}")
                    print(f"   Episode length: {episode_info['l']}")
                    print(f"   Steps taken: {step_count}")
                    break
                elif abs(reward[0]) > 0:
                    print(f"Step {step_count}: reward={reward[0]:.1f}, info={infos[0]}")

        env.close()

    except Exception as e:
        print(f"Comparison test failed: {e}")

if __name__ == "__main__":
    MODEL_PATH = "logs/dqn/SpaceInvadersNoFrameskip-v4_1/SpaceInvadersNoFrameskip-v4.zip"

    if os.path.exists(MODEL_PATH):
        print("Running quick comparison test first...")
        simple_comparison_test(MODEL_PATH)

        print("\n" + "="*60)
        print("Now recording with exact RL Zoo3 matching...")

        success = record_exactly_like_rl_zoo3(
            model_path=MODEL_PATH,
            output_path="exact_rl_zoo3_spaceinvaders.mp4",
            env_id='SpaceInvadersNoFrameskip-v4',
            n_timesteps=5000
        )

        if success:
            print("\n🎉 SUCCESS! This should now match RL Zoo3 exactly!")
            print("The video shows the same episodes that RL Zoo3 evaluates.")
        else:
            print("\n❌ Failed to replicate RL Zoo3 setup.")
    else:
        print(f"Model not found: {MODEL_PATH}")
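
The cell above reads the score from `infos[0]["episode"]` rather than summing rewards until `done`. A plausible reason this matters: when episodes are terminated on every lost life, `done` fires several times per game, while an episode tracker sitting below that wrapper only closes an episode at game over. The sketch below illustrates the two kinds of boundaries on a made-up trace. The `Step` class, the `split_episodes` helper, and all reward values are hypothetical and exist only for this illustration.

```python
from dataclasses import dataclass


@dataclass
class Step:
    """One environment step in a made-up trace (hypothetical values)."""
    reward: float
    life_lost: bool = False
    game_over: bool = False


def split_episodes(trace):
    """Return (per_life_returns, per_game_returns) for a list of Step objects."""
    per_life, per_game = [], []
    life_return = game_return = 0.0
    for step in trace:
        life_return += step.reward
        game_return += step.reward
        # A life-loss wrapper reports done on every lost life, and on game over
        if step.life_lost or step.game_over:
            per_life.append(life_return)
            life_return = 0.0
        # A tracker below that wrapper only closes the episode on game over
        if step.game_over:
            per_game.append(game_return)
            game_return = 0.0
    return per_life, per_game


trace = [
    Step(10), Step(20, life_lost=True),
    Step(30), Step(0, life_lost=True),
    Step(5, game_over=True),
]
per_life, per_game = split_episodes(trace)
print(f"Returns per life: {per_life}")
print(f"Returns per game: {per_game}")
```

Summing rewards until the first `done` therefore yields the return of a single life, which is one more way a hand-rolled loop can report a lower number than the full-game score.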

In [None]:
# Diagnostic run: play the episode to completion while rendering every step for comparison
import os
import cv2
import numpy as np
import torch as th
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from stable_baselines3.common.atari_wrappers import AtariWrapper

def create_model_compatible_env(env_id='SpaceInvadersNoFrameskip-v4'):
    """
    Create environment that matches the model's expected input format
    """
    import gymnasium as gym

    def make_env():
        env = gym.make(env_id, render_mode='rgb_array')
        env = AtariWrapper(env)  # This gives us (84, 84, 1) grayscale
        return env

    # Create vectorized environment with frame stacking
    vec_env = DummyVecEnv([make_env])
    vec_env = VecFrameStack(vec_env, n_stack=4)  # Stacks 4 frames on the channel axis: (84, 84, 4)

    return vec_env

def record_with_proper_atari_tracking(model_path, output_path="fixed_atari_recording.mp4",
                                    env_id='SpaceInvadersNoFrameskip-v4', n_timesteps=5000):
    """
    Record with proper Atari reward tracking but compatible observations
    """

    print("=== HYBRID APPROACH: RL Zoo3 Rewards + Compatible Observations ===")

    try:
        from rl_zoo3 import ALGOS

        # Create model-compatible environment (preprocessed)
        model_env = create_model_compatible_env(env_id)
        print(f"Model environment obs space: {model_env.observation_space}")

        # Create separate raw environment for getting true Atari scores
        import gymnasium as gym
        raw_env = gym.make(env_id, render_mode='rgb_array')
        print(f"Raw environment obs space: {raw_env.observation_space}")

        # Load model with compatible environment
        print(f"Loading model: {model_path}")
        model = ALGOS['dqn'].load(model_path, env=model_env, force_reset=True)
        print("✓ Model loaded successfully")

        # Video recording setup
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = 30
        all_frames = []

        # Reset both environments
        model_obs = model_env.reset()
        raw_obs, raw_info = raw_env.reset()

        print(f"Model obs shape: {model_obs.shape}")
        print(f"Raw obs shape: {raw_obs.shape}")

        # Tracking variables (like RL Zoo3)
        episode_reward = 0.0  # Wrapper reward sum (clipped by AtariWrapper)
        raw_score = 0.0       # Unclipped score accumulated from the raw environment
        step_count = 0
        atari_scores = []
        atari_lengths = []

        print(f"\nStarting synchronized recording for {n_timesteps} timesteps...")

        for timestep in range(n_timesteps):
            # Get action from model using preprocessed observations
            action, _ = model.predict(model_obs, deterministic=True)

            # Step both environments with same action.
            # Note: AtariWrapper repeats each action for 4 frames and starts with random
            # no-ops, so the raw environment only approximately follows the model environment.
            model_obs, model_reward, model_done, model_info = model_env.step(action)

            # Extract scalar action for raw environment
            raw_action = action[0] if hasattr(action, '__len__') else action
            raw_obs, raw_reward, raw_term, raw_trunc, raw_info = raw_env.step(raw_action)

            # Track wrapper reward (like RL Zoo3 does internally) and the raw game score
            episode_reward += model_reward[0]
            raw_score += raw_reward
            step_count += 1

            # Capture frame from raw environment (high resolution)
            try:
                frame = raw_env.render()
                if frame is not None:
                    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    all_frames.append(frame_bgr)
            except Exception as e:
                if timestep < 5:
                    print(f"Frame capture warning: {e}")

            # CRITICAL: Detect the end of the raw game from its terminated/truncated flags.
            # A bare gym.make() environment never adds an 'episode' entry to info
            # (that key comes from a Monitor wrapper), so the score is accumulated manually.
            if raw_term or raw_trunc:
                atari_score = raw_score
                atari_length = step_count

                print(f"🎮 Atari Episode Score: {atari_score:.2f}")
                print(f"🎮 Atari Episode Length: {atari_length}")
                print(f"   Wrapper reward sum: {episode_reward:.2f}")
                print(f"   Steps in this recording: {step_count}")

                atari_scores.append(atari_score)
                atari_lengths.append(atari_length)

                # Reset tracking and restart the raw game
                episode_reward = 0.0
                raw_score = 0.0
                step_count = 0
                raw_obs, raw_info = raw_env.reset()

            # Progress indicator
            if timestep % 1000 == 0 and timestep > 0:
                print(f"Progress: {timestep}/{n_timesteps} timesteps, "
                      f"Episodes: {len(atari_scores)}, "
                      f"Frames: {len(all_frames)}")

            # When the model environment signals done, reset it and keep recording
            if model_done[0]:
                print(f"Model environment signaled done at timestep {timestep}")
                model_obs = model_env.reset()

            # The raw environment does not reset itself, so restart it once its game is over
            if raw_term or raw_trunc:
                raw_obs, raw_info = raw_env.reset()

        # Final results
        print(f"\n=== RECORDING RESULTS ===")
        if atari_scores:
            print(f"✅ Atari Episodes: {len(atari_scores)}")
            print(f"📊 Scores: {atari_scores}")
            print(f"📏 Lengths: {atari_lengths}")
            print(f"🎯 Mean Score: {np.mean(atari_scores):.2f} ± {np.std(atari_scores):.2f}")
            print(f"📐 Mean Length: {np.mean(atari_lengths):.2f} ± {np.std(atari_lengths):.2f}")
        else:
            print("⚠️  No complete Atari episodes detected!")
            print(f"   Total timesteps: {timestep + 1}")
            print(f"   Current wrapper reward: {episode_reward:.2f}")
            print("   Try increasing n_timesteps or check episode detection")

        # Save video
        if all_frames:
            print(f"\n🎥 Saving video with {len(all_frames)} frames...")
            h, w = all_frames[0].shape[:2]
            out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

            for i, frame in enumerate(all_frames):
                out.write(frame)
                if i % 1500 == 0:  # Progress every 50 seconds at 30fps
                    print(f"   Writing: {i}/{len(all_frames)} frames ({100*i/len(all_frames):.1f}%)")

            out.release()

            # File statistics
            file_size = os.path.getsize(output_path) / (1024 * 1024)
            duration = len(all_frames) / fps
            print(f"✅ Video saved: {output_path}")
            print(f"📁 Size: {file_size:.1f} MB")
            print(f"⏱️  Duration: {duration:.1f} seconds")

            if atari_scores:
                episodes_per_minute = len(atari_scores) / (duration / 60)
                print(f"🎮 Episodes/minute: {episodes_per_minute:.1f}")
        else:
            print("❌ No frames captured for video!")

        # Cleanup
        model_env.close()
        raw_env.close()

        return len(atari_scores) > 0

    except Exception as e:
        print(f"Error in hybrid recording: {e}")
        import traceback
        traceback.print_exc()
        return False

def debug_episode_detection(model_path, env_id='SpaceInvadersNoFrameskip-v4'):
    """
    Debug why episodes might not be detected properly
    """
    print("=== DEBUGGING EPISODE DETECTION ===")

    try:
        from rl_zoo3 import ALGOS
        import gymnasium as gym

        # Create environments
        model_env = create_model_compatible_env(env_id)
        raw_env = gym.make(env_id, render_mode='rgb_array')

        # Load model
        model = ALGOS['dqn'].load(model_path, env=model_env, force_reset=True)

        # Reset
        model_obs = model_env.reset()
        raw_obs, raw_info = raw_env.reset()

        print(f"Initial raw_info: {raw_info}")

        episode_reward = 0
        wrapper_reward = 0

        for step in range(2000):  # Limit for debugging
            # Get action and step both environments
            action, _ = model.predict(model_obs, deterministic=True)
            model_obs, m_reward, m_done, m_info = model_env.step(action)

            raw_action = action[0] if hasattr(action, '__len__') else action
            raw_obs, r_reward, r_term, r_trunc, r_info = raw_env.step(raw_action)

            episode_reward += r_reward
            wrapper_reward += m_reward[0]

            # Log interesting events
            if (step < 10 or
                step % 100 == 0 or
                abs(r_reward) > 0 or
                r_info or
                m_done[0] or
                r_term or r_trunc):

                print(f"Step {step:4d}: raw_r={r_reward:5.1f}, wrap_r={m_reward[0]:5.1f}, "
                      f"done={m_done[0]}, term={r_term}, trunc={r_trunc}")

                if r_info:
                    print(f"          raw_info: {r_info}")
                if m_info:
                    print(f"          model_info: {m_info}")

                # Check for episode completion
                if r_info and 'episode' in r_info:
                    episode_info = r_info['episode']
                    print(f"🎯 EPISODE COMPLETE!")
                    print(f"   True Atari Score: {episode_info['r']:.2f}")
                    print(f"   Episode Length: {episode_info['l']}")
                    print(f"   Raw reward sum: {episode_reward:.2f}")
                    print(f"   Wrapper reward sum: {wrapper_reward:.2f}")
                    break

            if m_done[0]:
                print(f"Model environment done at step {step}")
                break

        model_env.close()
        raw_env.close()

    except Exception as e:
        print(f"Debug failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    MODEL_PATH = "logs/dqn/SpaceInvadersNoFrameskip-v4_1/SpaceInvadersNoFrameskip-v4.zip"

    if os.path.exists(MODEL_PATH):
        print("First, let's debug episode detection...")
        debug_episode_detection(MODEL_PATH)

        print("\n" + "="*60)
        print("Now recording with hybrid approach...")

        success = record_with_proper_atari_tracking(
            model_path=MODEL_PATH,
            output_path="hybrid_atari_recording.mp4",
            env_id='SpaceInvadersNoFrameskip-v4',
            n_timesteps=10000  # Increased to capture more episodes
        )

        if success:
            print("\n🎉 SUCCESS! Video should now show proper Atari episodes!")
            print("This matches RL Zoo3's evaluation but with working observations.")
        else:
            print("\n❌ Recording had issues. Check the debug output above.")
    else:
        print(f"Model not found: {MODEL_PATH}")
        print("Please update MODEL_PATH with the correct path.")
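
The recording loop above looks for an `'episode'` entry in the info dict. That entry is normally written by a statistics wrapper (for example Gymnasium's `RecordEpisodeStatistics` or Stable-Baselines3's `Monitor`), not by the bare environment returned from `gym.make`. The sketch below mimics that bookkeeping in plain Python so the mechanism is easy to see. `EpisodeStatsTracker` and the `fake_steps` data are hypothetical names and made-up values for illustration, not part of either library.

```python
class EpisodeStatsTracker:
    """Hypothetical helper: accumulates reward and length like a statistics wrapper would."""

    def __init__(self):
        self.episode_return = 0.0
        self.episode_length = 0

    def update(self, reward, terminated, truncated, info):
        """Record one step; when the episode ends, return info with an 'episode' entry."""
        self.episode_return += float(reward)
        self.episode_length += 1
        info = dict(info)  # copy so the caller's dict is left untouched
        if terminated or truncated:
            info["episode"] = {"r": self.episode_return, "l": self.episode_length}
            # Start counting afresh for the next episode
            self.episode_return = 0.0
            self.episode_length = 0
        return info


# Made-up (reward, terminated, truncated, info) tuples standing in for env.step() results
fake_steps = [
    (0.0, False, False, {"lives": 3}),
    (5.0, False, False, {"lives": 3}),
    (10.0, True, False, {"lives": 0}),
]

tracker = EpisodeStatsTracker()
infos = [tracker.update(r, term, trunc, info) for r, term, trunc, info in fake_steps]
print(infos[-1]["episode"])
```

Without such a wrapper (or a manual tracker like this one), the `'episode'` check never succeeds, which is one likely reason no complete episodes are reported.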

In [None]:
# Make use of the 'lives' entry in the r_info dict, not the nonexistent 'episode' entry
import os
import cv2
import numpy as np
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from stable_baselines3.common.atari_wrappers import AtariWrapper

def record_with_life_episode_detection(model_path, output_path="life_based_recording.mp4",
                                     env_id='SpaceInvadersNoFrameskip-v4', max_episodes=3):
    """
    Record Atari episodes using life-based episode detection and manual score tracking
    """

    print("=== LIFE-BASED EPISODE DETECTION ===")

    try:
        from rl_zoo3 import ALGOS

        # Create model-compatible environment (preprocessed)
        def make_env():
            env = gym.make(env_id, render_mode='rgb_array')
            env = AtariWrapper(env)
            return env

        model_env = DummyVecEnv([make_env])
        model_env = VecFrameStack(model_env, n_stack=4)

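        # Create separate raw environment for rendering and score tracking.
        # Caveat: with its default settings AtariWrapper repeats each action for 4 frames and
        # plays random no-ops on reset, while this raw env advances one frame per step, so
        # the two emulators are not frame-synchronized and can drift apart.
        raw_env = gym.make(env_id, render_mode='rgb_array')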

        print(f"Model env obs space: {model_env.observation_space}")
        print(f"Raw env obs space: {raw_env.observation_space}")

        # Load model
        print(f"Loading model: {model_path}")
        model = ALGOS['dqn'].load(model_path, env=model_env, force_reset=True)
        print("✓ Model loaded successfully")

        # Video recording setup
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = 60
        all_frames = []

        # Episode tracking variables
        episodes_completed = 0
        episode_scores = []
        episode_lengths = []

        while episodes_completed < max_episodes:
            print(f"\n{'='*50}")
            print(f"STARTING EPISODE {episodes_completed + 1}/{max_episodes}")
            print(f"{'='*50}")

            # Reset both environments
            model_obs = model_env.reset()
            raw_obs, raw_info = raw_env.reset()

            # Episode state tracking
            current_lives = raw_info.get('lives', 3)
            initial_lives = current_lives
            episode_score = 0
            episode_steps = 0
            life_lost_recently = False

            print(f"Starting with {current_lives} lives")

            # Play until all lives are lost
            while current_lives > 0:
                # Get action from model
                action, _ = model.predict(model_obs, deterministic=True)

                # Step both environments
                model_obs, model_reward, model_done, model_info = model_env.step(action)

                raw_action = action[0] if hasattr(action, '__len__') else action
                raw_obs, raw_reward, raw_term, raw_trunc, raw_info = raw_env.step(raw_action)

                # Track score and steps
                episode_score += raw_reward
                episode_steps += 1

                # Capture frame
                try:
                    frame = raw_env.render()
                    if frame is not None:
                        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                        all_frames.append(frame_bgr)
                except Exception:
                    # Skip frames that fail to render, but let KeyboardInterrupt through
                    pass

                # Check for life loss
                new_lives = raw_info.get('lives', current_lives)
                if new_lives < current_lives:
                    print(f"  Life lost! Lives: {current_lives} → {new_lives}, Score: {episode_score}, Steps: {episode_steps}")
                    current_lives = new_lives
                    life_lost_recently = True

                # Check if model environment reset (happens on life loss in wrapped env)
                if model_done[0]:
                    print(f"  Model environment reset at step {episode_steps}")
                    model_obs = model_env.reset()

                # Progress indicator
                if episode_steps % 500 == 0:
                    print(f"  Progress: {episode_steps} steps, Score: {episode_score:.0f}, Lives: {current_lives}")

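                # Stop if the raw environment reports game over before lives reach zero
                if raw_term or raw_trunc:
                    break

                # Safety break for very long episodes
                if episode_steps > 20000:
                    print(f"  Episode exceeded 20000 steps, ending...")
                    break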

            # Episode completed (all lives lost)
            episode_scores.append(episode_score)
            episode_lengths.append(episode_steps)
            episodes_completed += 1

            print(f"\n🎮 EPISODE {episodes_completed} COMPLETED!")
            print(f"   Final Score: {episode_score:.0f}")
            print(f"   Episode Length: {episode_steps} steps")
            print(f"   Lives Used: {initial_lives - current_lives}")

            # Add a few seconds of the final screen
            print("   Recording final screen for 3 seconds...")
            for _ in range(3 * fps):  # 3 seconds at the video frame rate
                try:
                    frame = raw_env.render()
                    if frame is not None:
                        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                        all_frames.append(frame_bgr)
                except Exception:
                    break

        # Final results
        print(f"\n{'='*50}")
        print(f"RECORDING SUMMARY")
        print(f"{'='*50}")
        print(f"✅ Episodes Completed: {episodes_completed}")
        print(f"📊 Episode Scores: {episode_scores}")
        print(f"📏 Episode Lengths: {episode_lengths}")
        print(f"🎯 Mean Score: {np.mean(episode_scores):.1f} ± {np.std(episode_scores):.1f}")
        print(f"📐 Mean Length: {np.mean(episode_lengths):.1f} ± {np.std(episode_lengths):.1f}")
        print(f"🎬 Total Frames: {len(all_frames)}")

        # Compare with RL Zoo3 results
        print(f"\n📈 COMPARISON WITH RL ZOO3:")
        print(f"   RL Zoo3 reported: 275-600 points, 2000-5000 steps")
        print(f"   This recording: {min(episode_scores):.0f}-{max(episode_scores):.0f} points, {min(episode_lengths)}-{max(episode_lengths)} steps")

        # Save video
        if all_frames:
            print(f"\n🎥 SAVING VIDEO...")
            h, w = all_frames[0].shape[:2]
            out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

            for i, frame in enumerate(all_frames):
                out.write(frame)
                if i % 1500 == 0:
                    print(f"   Writing: {i}/{len(all_frames)} frames ({100*i/len(all_frames):.1f}%)")

            out.release()

            # File statistics
            file_size = os.path.getsize(output_path) / (1024 * 1024)
            duration = len(all_frames) / fps
            print(f"✅ Video saved: {output_path}")
            print(f"📁 Size: {file_size:.1f} MB")
            print(f"⏱️  Duration: {duration:.1f} seconds ({duration/60:.1f} minutes)")
            print(f"🎮 Episodes per minute: {episodes_completed / (duration/60):.1f}")

            return True
        else:
            print("❌ No frames captured!")
            return False

    except Exception as e:
        print(f"Error in life-based recording: {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        # Cleanup
        try:
            model_env.close()
            raw_env.close()
        except Exception:
            # Environments may not exist if setup failed early
            pass

def quick_life_test(model_path, env_id='SpaceInvadersNoFrameskip-v4'):
    """
    Quick test to understand the life mechanics
    """
    print("=== QUICK LIFE MECHANICS TEST ===")

    try:
        from rl_zoo3 import ALGOS

        # Create environments
        def make_env():
            env = gym.make(env_id, render_mode='rgb_array')
            env = AtariWrapper(env)
            return env

        model_env = DummyVecEnv([make_env])
        model_env = VecFrameStack(model_env, n_stack=4)
        raw_env = gym.make(env_id, render_mode='rgb_array')

        # Load model
        model = ALGOS['dqn'].load(model_path, env=model_env, force_reset=True)

        # Reset and test
        model_obs = model_env.reset()
        raw_obs, raw_info = raw_env.reset()

        print(f"Initial raw_info: {raw_info}")

        current_lives = raw_info.get('lives', 3)
        total_score = 0
        step_count = 0

        print(f"Starting with {current_lives} lives")

        # Run until we see some life changes
        life_changes = 0
        while life_changes < 2 and step_count < 5000:
            action, _ = model.predict(model_obs, deterministic=True)

            model_obs, m_reward, m_done, m_info = model_env.step(action)
            raw_action = action[0] if hasattr(action, '__len__') else action
            raw_obs, r_reward, r_term, r_trunc, r_info = raw_env.step(raw_action)

            total_score += r_reward
            step_count += 1

            # Check for life changes
            new_lives = r_info.get('lives', current_lives)
            if new_lives != current_lives:
                life_changes += 1
                print(f"\n🔄 LIFE CHANGE #{life_changes} at step {step_count}:")
                print(f"   Lives: {current_lives} → {new_lives}")
                print(f"   Score so far: {total_score:.0f}")
                print(f"   Model done: {m_done[0]}")
                print(f"   Raw info: {r_info}")
                current_lives = new_lives

                if m_done[0]:
                    print(f"   Model environment reset")
                    model_obs = model_env.reset()

            # Log rewards
            if abs(r_reward) > 0:
                print(f"Step {step_count}: +{r_reward:.0f} points (total: {total_score:.0f})")

        print(f"\nTest completed:")
        print(f"  Total score: {total_score:.0f}")
        print(f"  Steps: {step_count}")
        print(f"  Final lives: {current_lives}")
        print(f"  Life changes observed: {life_changes}")

        model_env.close()
        raw_env.close()

    except Exception as e:
        print(f"Life test failed: {e}")

if __name__ == "__main__":
    MODEL_PATH = "logs/dqn/SpaceInvadersNoFrameskip-v4_1/SpaceInvadersNoFrameskip-v4.zip"

    if os.path.exists(MODEL_PATH):
        print("First, running quick life mechanics test...")
        quick_life_test(MODEL_PATH)

        print("\n" + "="*60)
        print("Now recording with life-based episode detection...")

        success = record_with_life_episode_detection(
            model_path=MODEL_PATH,
            output_path="life_based_spaceinvaders.mp4",
            env_id='SpaceInvadersNoFrameskip-v4',
            max_episodes=2  # Record 2 complete episodes
        )

        if success:
            print("\n🎉 SUCCESS! Video shows complete Atari episodes with proper scoring!")
            print("The scores should now match RL Zoo3's evaluation results.")
        else:
            print("\n❌ Recording failed. Check the error messages above.")
    else:
        print(f"Model not found: {MODEL_PATH}")
        print("Please update MODEL_PATH with the correct path.")
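
The life-based detection used in this cell boils down to watching the `lives` counter and summing raw rewards until it reaches zero. The sketch below isolates that bookkeeping as a pure function so it can be checked without an emulator. `summarize_games`, `starting_lives`, and the `transitions` data are hypothetical names and made-up values, and the sketch assumes each game starts with the same number of lives.

```python
def summarize_games(transitions, starting_lives=3):
    """Group (reward, lives) pairs into games; a game ends when lives reaches 0."""
    games = []
    score, length, lives_lost = 0.0, 0, 0
    current_lives = starting_lives
    for reward, lives in transitions:
        score += reward
        length += 1
        # A drop in the lives counter means one or more lives were just lost
        if lives < current_lives:
            lives_lost += current_lives - lives
            current_lives = lives
        # All lives gone: close out this game and start tracking the next one
        if current_lives == 0:
            games.append({"score": score, "length": length, "lives_lost": lives_lost})
            score, length, lives_lost = 0.0, 0, 0
            current_lives = starting_lives
    return games


# Made-up (reward, lives) pairs standing in for raw_env.step() results
transitions = [
    (0.0, 3), (5.0, 3), (0.0, 2), (10.0, 2), (0.0, 1), (20.0, 1), (0.0, 0),  # game 1
    (15.0, 3), (0.0, 2), (0.0, 1), (0.0, 0),                                  # game 2
]

games = summarize_games(transitions)
for i, game in enumerate(games, start=1):
    print(f"Game {i}: {game}")
```

Keeping the score tracking separate from rendering like this makes it easier to confirm that a full game, rather than a single life, is being counted as one episode.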