# Mario Bros RL: CrossQ Demo

This notebook trains a **CrossQ** agent (implemented as a modified version of `sbx`'s SAC) on `SuperMarioBros-v0`.
Apart from the `sbx` package in this repository, it is self-contained: all required environment wrappers are defined inline.

## 0. Colab Setup
**Instructions:**
1. **Git Push**: Make sure your local changes (including the `sbx` updates) are pushed to your GitHub fork.
2. **Update URL**: Set `YOUR_GITHUB_REPO_URL` in the cell below to the URL of your fork.
3. **Run**: Execute the cell to install the dependencies and the local `sbx` package (in editable mode).

In [None]:
# @title Install Dependencies
# Clones the project fork (if needed), installs it in editable mode so `sbx` is
# importable, then pins the library versions the rest of the notebook relies on.
# Lines starting with `!` are IPython/Colab shell escapes, not Python; IPython
# substitutes `$NAME` with the value of the Python variable NAME.
import os

# ==========================================
# CHANGE THIS TO YOUR FORK URL
YOUR_GITHUB_REPO_URL = "https://github.com/martiincooper/RL-CROSSQ.git"
# ==========================================

# Directory name `git clone` creates, e.g. ".../RL-CROSSQ.git" -> "RL-CROSSQ".
# NOTE: replace() strips every ".git" occurrence in the name, not only the suffix.
repo_name = YOUR_GITHUB_REPO_URL.split("/")[-1].replace(".git", "")

# 1. Clone if setup.py is missing and we aren't in the repo folder yet
#    (a setup.py in the current directory is the "already inside the repo" marker).
if not os.path.exists("setup.py"):
    if not os.path.exists(repo_name):
        print(f"Cloning {YOUR_GITHUB_REPO_URL}...")
        !git clone $YOUR_GITHUB_REPO_URL
    
    # Configure Python to work inside the repo: chdir so that the pip install
    # below and later relative paths resolve against the cloned repository.
    if os.path.exists(repo_name):
        os.chdir(repo_name)
        print(f"Changed directory to {os.getcwd()}")

# 2. Install project in editable mode so 'sbx' is importable
# We use --ignore-requires-python to bypass strict python checks
# WE USE ABSOLUTE PATH to ensure pip finds it regardless of shell CWD
cwd = os.getcwd()
!pip install -e "$cwd" --ignore-requires-python

# 3. Install compatible versions of libraries
# gymnasium==0.29.1 is strict requirement for wrappers
# numpy<2.0.0 is required for scipy and legacy gym compatibility
# flax>=0.8.0 is required for compatibility with newer JAX versions
# FORCE UNINSTALL TFP & DOPAMINE & NUMPY FIRST to remove Colab's stubborn pre-installed versions
# PIN opencv-python<4.10.0 to avoid numpy>=2 requirement
# USE --force-reinstall to ensure binaries match numpy version
# NOTE(review): if numpy was already imported in this runtime, Colab may need a
# runtime restart before the next cell sees the reinstalled version — confirm.
!pip uninstall -y tensorflow-probability dopamine-rl numpy
!pip install --force-reinstall "numpy<2.0.0" "flax>=0.8.0" gymnasium==0.29.1 gym-super-mario-bros==7.4.0 nes-py==8.2.1 shimmy==1.3.0 "opencv-python<4.10.0" matplotlib stable-baselines3 tensorflow-probability==0.23.0

In [None]:
import os
import sys

# JAX reads this flag when its backend is first initialised; set it before any
# import below can touch JAX so it reliably takes effect.
# 'false' stops JAX from pre-allocating most of the GPU memory up front.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# Ensure we can import sbx if we are in the cloned repo
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

try:
    import gymnasium as gym
    from gymnasium.wrappers import ResizeObservation, GrayScaleObservation, FrameStack
except ImportError as e:
    print("ImportError:", e)
    print("Please install gymnasium==0.29.1: pip install gymnasium==0.29.1")
    # The wrappers below need these names; stop here instead of failing later
    # with a confusing NameError.
    raise

import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
import numpy as np
import cv2
import jax
import optax
import tensorflow_probability as tfp

# Verify TFP JAX backend
try:
    tfd = tfp.substrates.jax.distributions
    print("TFP JAX backend functional.")
except AttributeError:
    print("Error: tensorflow_probability.substrates.jax not found. Please upgrade tensorflow-probability.")

try:
    from sbx import SAC
    from sbx.sac.utils import ReLU
except ImportError:
    print("Could not import sbx. Make sure you are running this notebook from the CrossQ repository root or have installed it via 'pip install -e .'")
    # SAC and ReLU are required to build the model; fail fast.
    raise

import matplotlib.pyplot as plt
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy

ENV_ID = 'SuperMarioBros-v0'
TOTAL_TIMESTEPS = 100000 # Adjust as needed
LOG_DIR = "./mario_benchmark_logs/"
os.makedirs(LOG_DIR, exist_ok=True)

print(f"JAX Devices: {jax.devices()}")

## 1. Environment Wrappers
We define the environment wrappers here to make the notebook self-contained.

In [None]:
# Robust wrapper to force Gymnasium API (5-tuple step)
class MarioGymnasiumWrapper(gym.Wrapper):
    """Present a legacy ``gym`` environment through the Gymnasium API.

    The wrapped env's spaces are rebuilt as Gymnasium spaces, ``reset`` always
    returns ``(obs, info)`` and ``step`` always returns the 5-tuple
    ``(obs, reward, terminated, truncated, info)``.
    """

    def __init__(self, env):
        # The env is attached manually instead of via ``super().__init__``;
        # spaces, metadata and reward range are served by the properties below.
        self.env = env

        # Rebuild the observation space as a Gymnasium Box.
        gym_obs = env.observation_space
        self._observation_space = gym.spaces.Box(
            low=gym_obs.low, high=gym_obs.high, shape=gym_obs.shape, dtype=gym_obs.dtype
        )

        # Discrete spaces expose ``n``; anything else is treated as a Box.
        gym_act = env.action_space
        if hasattr(gym_act, 'n'):
            self._action_space = gym.spaces.Discrete(gym_act.n)
        else:
            self._action_space = gym.spaces.Box(
                low=gym_act.low, high=gym_act.high, shape=gym_act.shape, dtype=gym_act.dtype
            )

        self._metadata = getattr(env, 'metadata', {})
        self._reward_range = getattr(env, 'reward_range', (-float('inf'), float('inf')))

    @property
    def action_space(self):
        return self._action_space

    @property
    def observation_space(self):
        return self._observation_space

    @property
    def metadata(self):
        return self._metadata

    @property
    def reward_range(self):
        return self._reward_range

    def reset(self, **kwargs):
        """Reset the env and return ``(obs, info)``.

        ``seed`` and ``options`` are removed from ``kwargs`` because the legacy
        ``reset`` might not accept them. ``seed`` is applied through
        ``env.seed`` when available; ``options`` is discarded.
        """
        seed = kwargs.pop('seed', None)
        kwargs.pop('options', None)  # not forwarded to the legacy reset

        if seed is not None:
            try:
                self.env.seed(seed)
            except AttributeError:
                pass  # Best effort: the env might not support seeding.

        # The legacy reset may return either ``obs`` or ``(obs, info)``.
        ret = self.env.reset(**kwargs)
        if isinstance(ret, tuple):
            if len(ret) == 2:
                return ret
            return ret[0], {}
        return ret, {}

    def step(self, action):
        """Step the env and return ``(obs, reward, terminated, truncated, info)``.

        Raises:
            ValueError: If the wrapped ``step`` returns neither 4 nor 5 values.
        """
        ret = self.env.step(action)
        if len(ret) == 5:
            return ret
        if len(ret) == 4:
            obs, reward, done, info = ret
            # Legacy gym's TimeLimit marks time-outs via info["TimeLimit.truncated"];
            # report those as truncation, not termination, so value targets still
            # bootstrap from the next state.
            truncated = (
                bool(info.get("TimeLimit.truncated", False))
                if isinstance(info, dict)
                else False
            )
            terminated = bool(done) and not truncated
            return obs, reward, terminated, truncated, info
        raise ValueError(f"Unexpected step return length: {len(ret)}")

    def render(self):
        return self.env.render()

    def close(self):
        return self.env.close()

class ContinuousMarioWrapper(gym.ActionWrapper):
    """Expose a discrete-action env through a continuous Box action space.

    CrossQ/SAC emit a real-valued vector with one entry per discrete action;
    the entry with the largest value selects the discrete action to execute.
    """

    def __init__(self, env):
        super().__init__(env)
        # CrossQ/SAC expect a Box action space: one score in [-1, 1] per discrete action.
        self.action_space = gym.spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(env.action_space.n,),
            dtype=np.float32,
        )

    def action(self, action):
        # Highest-scoring entry -> discrete action index.
        best_index = np.argmax(action)
        return int(best_index)

# Transpose to (H, W, C) = (84, 84, 4)
class TransposeWrapper(gym.ObservationWrapper):
    """Move the leading axis of observations to the end (channel-first -> channel-last).

    In this pipeline the input is the ``FrameStack`` output ``(frames, H, W)``,
    which becomes ``(H, W, frames)``.
    """

    def __init__(self, env):
        super().__init__(env)
        old_space = env.observation_space  # e.g. (4, 84, 84) after FrameStack
        # Transpose the wrapped space's own bounds the same way as the data,
        # instead of assuming a 0..255 range.
        self.observation_space = gym.spaces.Box(
            low=np.moveaxis(old_space.low, 0, -1),
            high=np.moveaxis(old_space.high, 0, -1),
            dtype=old_space.dtype,
        )

    def observation(self, obs):
        # obs is (4, 84, 84) -> (84, 84, 4)
        return np.moveaxis(obs, 0, -1)

def make_mario_env(env_id='SuperMarioBros-v0', action_space=SIMPLE_MOVEMENT, stack_frames=4, render_mode='rgb_array'):
    """Build the Mario environment with the preprocessing used for CrossQ/SAC.

    Pipeline: legacy gym env -> ``JoypadSpace(action_space)`` -> Gymnasium API
    adapter -> continuous Box actions -> grayscale -> 84x84 resize ->
    ``stack_frames`` stacked frames -> channel-last observations.

    Args:
        env_id: Registered gym-super-mario-bros environment id.
        action_space: Button combinations passed to ``JoypadSpace``.
        stack_frames: Number of consecutive frames to stack.
        render_mode: Forwarded to ``gym_super_mario_bros.make``; ``'rgb_array'``
            lets ``env.render()`` return image frames (used by the demo).

    Returns:
        The fully wrapped Gymnasium environment.
    """
    # apply_api_compatibility=True asks legacy gym to adapt the nes-py env to the
    # 5-tuple step API. render_mode is forwarded here (it used to be ignored).
    env = gym_super_mario_bros.make(
        env_id, apply_api_compatibility=True, render_mode=render_mode
    )
    env = JoypadSpace(env, action_space)

    # Force load (nes_py lazy initialization might be an issue)
    env.reset()

    # Adapt to Gymnasium spaces and return conventions.
    env = MarioGymnasiumWrapper(env)

    # Convert to continuous for CrossQ/SAC
    env = ContinuousMarioWrapper(env)

    # Observation preprocessing
    env = GrayScaleObservation(env, keep_dim=False)
    env = ResizeObservation(env, (84, 84))
    env = FrameStack(env, stack_frames)

    # (stack_frames, 84, 84) -> (84, 84, stack_frames)
    env = TransposeWrapper(env)

    return env

## 2. CrossQ Configuration
CrossQ-specific hyperparameters, following the CrossQ paper where practical (the replay buffer, for example, is reduced to fit in laptop memory).

In [None]:
def create_crossq_model(env, seed=1):
    """Construct an sbx SAC agent configured in CrossQ style.

    Args:
        env: Environment to train on.
        seed: Random seed forwarded to SAC.

    Returns:
        The SAC model, tagged with ``name = "CrossQ"``.
    """
    # Network settings: batch-normalised critics that are much wider than the actor.
    crossq_policy_kwargs = dict(
        activation_fn=ReLU,
        n_critics=2,
        batch_norm=True,
        batch_norm_momentum=0.99,
        net_arch={"pi": [256, 256], "qf": [2048, 2048]},  # Wider critics for CrossQ
        optimizer_kwargs={"b1": 0.5},  # presumably optax Adam's beta1 — confirm in sbx
    )

    model = SAC(
        "CnnPolicy",
        env,
        verbose=1,
        seed=seed,
        learning_starts=1000,
        buffer_size=50_000,  # Reduced to avoid OOM on laptop
        ent_coef="auto",
        crossq_style=True,
        policy_delay=3,
        gradient_steps=1,  # UTD=1
        tau=1.0,  # No target network
        learning_rate=1e-3,
        policy_kwargs=crossq_policy_kwargs,
    )
    model.name = "CrossQ"
    return model

## 3. Training Loop
Train the CrossQ agent.

In [None]:
# Create environments with Monitor
# Episode statistics are written to LOG_DIR/CrossQ.monitor.csv (read by the plot below).
env_crossq = Monitor(make_mario_env(ENV_ID), filename=os.path.join(LOG_DIR, "CrossQ"))

# Initialize model
model = create_crossq_model(env_crossq)

print(f"Training CrossQ for {TOTAL_TIMESTEPS} steps...")
try:
    model.learn(total_timesteps=TOTAL_TIMESTEPS, progress_bar=True)
finally:
    # Runs even if training fails or the cell is interrupted, so progress is
    # saved and the emulator is shut down.
    model.save("sbx_CrossQ_mario")
    model.env.close()
print("CrossQ Done.")

## 4. Performance Check
Plot the smoothed episode-reward learning curve from the Monitor log.

In [None]:
def moving_average(values, window):
    """Smooth values by calculating moving average.

    Args:
        values: 1-D sequence of numbers.
        window: Number of consecutive points averaged; must be >= 1.

    Returns:
        ``values`` unchanged when it has fewer than ``window`` points, otherwise
        an array of length ``len(values) - window + 1`` ('valid' convolution).

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        # np.repeat / np.convolve would otherwise fail with an obscure message.
        raise ValueError(f"window must be >= 1, got {window}")
    if len(values) < window:
        return values
    weights = np.repeat(1.0, window) / window
    return np.convolve(values, weights, 'valid')

def plot_results(log_folder, title='CrossQ Learning Curve', window=50):
    """Plot the smoothed episode-reward curve from the CrossQ Monitor log.

    Args:
        log_folder: Directory containing ``CrossQ.monitor.csv`` written by SB3's ``Monitor``.
        title: Figure title.
        window: Moving-average window, in episodes, used for smoothing.
    """
    algo = "CrossQ"
    color = 'r'
    log_path = os.path.join(log_folder, f"{algo}.monitor.csv")

    # Check before creating a figure so a missing log leaves no empty plot behind.
    if not os.path.exists(log_path):
        print(f"Log for {algo} not found.")
        return

    try:
        import pandas as pd
        # The first line of a Monitor CSV is a metadata comment, not a header.
        df = pd.read_csv(log_path, skiprows=1)
        # 't' is wall-clock seconds; the x-axis is timesteps, so use the
        # cumulative episode lengths 'l' (as SB3's ts2xy does for 'timesteps').
        x = np.cumsum(df['l'].values)
        y = df['r'].values
    except Exception as e:
        print(f"Could not plot {algo}: {e}")
        return

    if len(x) <= 1:
        print(f"Not enough data to plot for {algo}")
        return

    y_smoothed = moving_average(y, window=window)
    # 'valid' smoothing drops the first window-1 points; align x with the tail.
    x_smoothed = x[len(x) - len(y_smoothed):]

    plt.figure(figsize=(10, 6))
    plt.plot(x_smoothed, y_smoothed, label=algo, color=color)
    plt.xlabel('Timesteps')
    plt.ylabel('Episode Reward')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()

plot_results(LOG_DIR)

## 5. Visual Demo (Watch it Play)
Watch the trained CrossQ agent play.

In [None]:
def _write_video(frames, path, fps):
    """Encode RGB ``frames`` into an MP4 file at ``path``, always releasing the writer."""
    height, width = frames[0].shape[:2]
    # Use mp4v codec for compatibility
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    try:
        for frame in frames:
            # OpenCV writers expect BGR channel order; frames are RGB.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


def run_demo(model, env_id, max_steps=1000, video_path='crossq_demo.mp4', fps=30):
    """Play one deterministic episode with ``model`` and show or record it.

    Outside Colab, frames are shown live in an OpenCV window; press ``q`` to stop.
    On Colab (detected by importing ``google.colab``), frames are collected and
    written to an MP4 file instead.

    Args:
        model: Trained agent exposing ``predict(obs, deterministic=True)``.
        env_id: Environment id passed to ``make_mario_env``.
        max_steps: Maximum number of environment steps to run.
        video_path: Output MP4 path used on Colab.
        fps: Frame rate of the saved video.

    Raises:
        RuntimeError: If ``env.render()`` returns no frame.
    """
    env = make_mario_env(env_id)
    obs, _ = env.reset()
    done = False

    print("Starting Demo! Check the popup window (or output below if on Colab).")

    # The import only serves to detect whether we are running on Colab.
    try:
        from google.colab.patches import cv2_imshow  # noqa: F401
        is_colab = True
    except ImportError:
        is_colab = False

    try:
        frames = []  # only kept on Colab, where they are written to a video
        n_frames = 0
        for _ in range(max_steps):
            if done:
                break

            action, _ = model.predict(obs, deterministic=True)
            obs, _reward, terminated, truncated, _info = env.step(action)
            done = terminated or truncated

            frame = env.render()
            if frame is None:
                raise RuntimeError(
                    "env.render() returned None; build the env with render_mode='rgb_array'."
                )
            frame = np.ascontiguousarray(frame, dtype=np.uint8)
            # Add label
            cv2.putText(frame, "CrossQ", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
            n_frames += 1

            if is_colab:
                frames.append(frame)
            else:
                # Live preview; OpenCV displays BGR.
                cv2.imshow("CrossQ Mario", cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                if cv2.waitKey(20) & 0xFF == ord('q'):
                    break

        print(f"Demo finished. Episode Length: {n_frames} frames")

        # Colab has no GUI window, so save a video instead.
        if is_colab and frames:
            print(f"Saving video to '{video_path}'...")
            _write_video(frames, video_path, fps)
            print(f"Video saved! Download '{video_path}' to view.")

    finally:
        env.close()
        if not is_colab:
            cv2.destroyAllWindows()
            cv2.waitKey(1)

run_demo(model, ENV_ID)