# Mario Bros (ALE) RL: CrossQ Demo

This notebook trains a **CrossQ** agent on `ALE/MarioBros-v5` (Arcade Mario).
It uses the maintained **Gymnasium** library via `ale-py` and pins a specific JAX version for stability on a Colab T4 GPU.

## 0. Colab Setup
**Instructions:**
1. **Git Push**: Ensure you have pushed your local changes to your GitHub fork.
2. **Update URL**: Replace `YOUR_GITHUB_REPO_URL` below.
3. **Run**: Execute the cell to install dependencies.

In [None]:
# @title Install Dependencies
import os

# ==========================================
# CHANGE THIS TO YOUR FORK URL
YOUR_GITHUB_REPO_URL = "https://github.com/martiincooper/RL-CROSSQ.git"
# ==========================================

# Name of the checkout directory: last path segment of the URL with ".git" removed.
repo_name = YOUR_GITHUB_REPO_URL.split("/")[-1].replace(".git", "")

# 1. Clone if needed
# A `setup.py` in the current directory is taken to mean we are already inside
# the repository. Otherwise clone it (unless the checkout directory already
# exists) and change into the checkout.
if not os.path.exists("setup.py"):
    if not os.path.exists(repo_name):
        print(f"Cloning {YOUR_GITHUB_REPO_URL}...")
        !git clone $YOUR_GITHUB_REPO_URL
    
    if os.path.exists(repo_name):
        os.chdir(repo_name)
        print(f"Changed directory to {os.getcwd()}")

# 2. Install project
# Editable install of the current working directory; the package's
# `requires-python` constraint is deliberately ignored.
cwd = os.getcwd()
!pip install -e "$cwd" --ignore-requires-python

# 3. Install Stable Dependencies for Colab T4
# - JAX pinned to 0.4.28 to avoid 'pytype_aval_mappings' deprecation in JAX 0.5+
# - NumPy < 2.0.0 for compatibility with older wheels
# - Gymnasium[atari] for maintained Mario Bros
# The packages below are uninstalled first so that the pinned versions that
# follow replace whatever was installed before this step.
# NOTE(review): if numpy/jax were already imported in this kernel, a runtime
# restart is presumably needed for the new versions to take effect - confirm.
!pip uninstall -y jax jaxlib tensorflow-probability dopamine-rl numpy
!pip install "numpy<2.0.0" "gymnasium[atari, accept-rom-license]" "opencv-python<4.10.0" matplotlib stable-baselines3
!pip install -U "jax[cuda12_pip]==0.4.28" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install tensorflow-probability==0.23.0

In [None]:
import os
import sys

# Disable XLA's up-front GPU memory preallocation. This has to be in the
# environment BEFORE anything initialises the JAX backend, so it is set ahead
# of importing jax and the libraries built on it (TFP's JAX substrate, sbx).
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# Make the repository checkout (current directory) importable.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

import gymnasium as gym
import numpy as np
import cv2
import jax
import tensorflow_probability as tfp

# Verify TFP JAX backend
try:
    tfd = tfp.substrates.jax.distributions
    print("TFP JAX backend functional.")
except AttributeError:
    print("Error: TFP broken. Ensure JAX <= 0.4.30 is installed.")

# SAC and ReLU are required by the model-construction cell below. Fail here
# with a clear message instead of letting a later cell die with a NameError.
try:
    from sbx import SAC
    from sbx.sac.utils import ReLU
except ImportError as exc:
    raise ImportError(
        "Could not import sbx. Run the 'Install Dependencies' cell first and "
        "make sure the notebook is executed from inside the repository checkout."
    ) from exc

import matplotlib.pyplot as plt
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.atari_wrappers import AtariWrapper

# Using the maintained Gymnasium ALE environment
ENV_ID = 'ALE/MarioBros-v5'
TOTAL_TIMESTEPS = 100000 
LOG_DIR = "./mario_benchmark_logs/"
os.makedirs(LOG_DIR, exist_ok=True)

print(f"JAX Devices: {jax.devices()}")

## 1. Environment (Atari)
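CrossQ requires a continuous action space, so the environment's discrete actions are exposed through a wrapper: the agent outputs one continuous score per discrete action, and the action with the highest score is executed.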

In [None]:
class ContinuousActionWrapper(gym.ActionWrapper):
    """Present a discrete-action environment through a continuous ``Box`` space.

    The wrapped environment's action space must expose ``n`` (the number of
    discrete actions). The agent then emits a float vector with one entry per
    discrete action, each bounded to [-1, 1], and the index of the largest
    entry is forwarded to the wrapped environment.
    """

    def __init__(self, env):
        super().__init__(env)
        num_choices = env.action_space.n
        # CrossQ/SAC expect a Box action space: one dimension per discrete action.
        self.action_space = gym.spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(num_choices,),
            dtype=np.float32,
        )

    def action(self, action):
        """Return the index of the largest entry of ``action`` as a plain int."""
        best_index = np.argmax(action)
        return int(best_index)

def make_atari_env(env_id):
    """Create a preprocessed ALE environment with a continuous action interface.

    Args:
        env_id: Gymnasium id of an ALE environment (e.g. ``'ALE/MarioBros-v5'``).

    Returns:
        The environment wrapped, in order, by SB3's ``AtariWrapper`` and
        ``ContinuousActionWrapper``. It is created with
        ``render_mode='rgb_array'`` so ``render()`` returns frames.
    """
    # ALE "-v5" ids skip 4 frames per step by default, and AtariWrapper applies
    # its own 4-frame max-and-skip on top. Without frameskip=1 here every agent
    # action would be repeated for 16 emulator frames. SB3 documents
    # AtariWrapper as intended for environments without built-in frame skip.
    env = gym.make(env_id, render_mode='rgb_array', frameskip=1)
    # Standard Atari preprocessing from SB3 (noop reset, frame skip with
    # max-pooling, episodic life, resize to 84x84, grayscale, reward clipping).
    # Note: AtariWrapper does not stack frames.
    env = AtariWrapper(env)
    env = ContinuousActionWrapper(env)
    # Observations are left channel-last (H, W, C); no transpose is applied here.
    return env

## 2. CrossQ Configuration
Set up the CrossQ agent.

In [None]:
def create_crossq_model(env, seed=1):
    """Build the sbx ``SAC`` agent configured in CrossQ style.

    Args:
        env: Environment the agent is trained on.
        seed: Random seed forwarded to the agent.

    Returns:
        The constructed model, with its ``name`` attribute set to ``"CrossQ"``.
    """
    # Network settings: batch-normalised networks with wide critic layers.
    policy_settings = dict(
        activation_fn=ReLU,
        n_critics=2,
        batch_norm=True,
        batch_norm_momentum=0.99,
        net_arch={"pi": [256, 256], "qf": [2048, 2048]},
        optimizer_kwargs={"b1": 0.5},
    )

    # CrossQ hyperparameters, passed directly as keyword arguments.
    # NOTE(review): `crossq_style` is presumably specific to this repository's
    # sbx version - confirm against the installed package.
    model = SAC(
        "CnnPolicy",
        env,
        verbose=1,
        seed=seed,
        learning_starts=1000,
        buffer_size=50_000,
        ent_coef="auto",
        crossq_style=True,
        policy_delay=3,
        gradient_steps=1,
        tau=1.0,
        learning_rate=1e-3,
        policy_kwargs=policy_settings,
    )
    model.name = "CrossQ"
    return model

## 3. Training
Train CrossQ on Arcade Mario.

In [None]:
# Build the training environment; Monitor is given a log path under LOG_DIR.
env = Monitor(
    make_atari_env(ENV_ID),
    filename=os.path.join(LOG_DIR, "CrossQ"),
)
model = create_crossq_model(env)

# Train, save the trained agent, then release the environment.
print("Training CrossQ on {}...".format(ENV_ID))
model.learn(total_timesteps=TOTAL_TIMESTEPS, progress_bar=True)
model.save("sbx_CrossQ_ale_mario")
print("CrossQ Done.")
model.env.close()

## 4. Visualization
Watch the agent.

In [None]:
def run_demo(model, env_id, max_steps=1000, video_path='crossq_demo.mp4', fps=30):
    """Roll out one episode with ``model`` and save the rendered frames as a video.

    Args:
        model: Trained agent exposing ``predict(obs, deterministic=True)``.
        env_id: Environment id passed to ``make_atari_env``.
        max_steps: Upper bound on the number of environment steps.
        video_path: Destination file for the encoded video.
        fps: Frame rate of the encoded video.

    Returns:
        None. Progress is printed; the video is written to ``video_path`` when
        at least one frame was rendered.
    """
    env = make_atari_env(env_id)
    frames = []

    try:
        # reset() is inside the try so the env is closed even if it fails.
        obs, _ = env.reset()
        print("Starting Demo!")

        for _ in range(max_steps):
            action, _ = model.predict(obs, deterministic=True)
            obs, _reward, terminated, truncated, _info = env.step(action)

            # The env was created with render_mode='rgb_array', so render()
            # returns an RGB array. The wrappers change observations, not
            # rendering.
            frame = env.render()
            if frame is not None:
                frames.append(frame)

            if terminated or truncated:
                break

        print(f"Demo finished. Frames: {len(frames)}")
    finally:
        env.close()

    if not frames:
        return

    # The video is written on every platform; nothing here is Colab-specific.
    print(f"Saving video to '{video_path}'...")
    height, width = frames[0].shape[:2]
    writer = cv2.VideoWriter(
        video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
    )
    try:
        if not writer.isOpened():
            # Avoid reporting success when the codec/container is unavailable.
            print(f"Could not open video writer for '{video_path}'; video not saved.")
            return
        for frame in frames:
            # OpenCV expects BGR channel order; rendered frames are RGB.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()
    print("Video saved!")

run_demo(model, ENV_ID)