# CrossQ on Super-Mario-RL (Replicated)

This notebook runs **CrossQ** on the Super-Mario-RL environment.
To ensure stability and simplicity, we replicate the environment stack (Noop, Skip, EpisodicLife, Warp, Stack) from `Super-Mario-RL` natively in **Gymnasium**, avoiding legacy dependency conflicts.

## 0. Setup
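1. **Push** your local `sbx` changes to your fork first — the setup cell clones the repository from GitHub, so unpushed changes will not be installed.
2. **Run** the setup cell below.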

In [None]:
# @title Install Dependencies
# Setup cell: clones the CrossQ fork and the Super-Mario-RL reference repo,
# installs the fork in editable mode, then pins the runtime dependencies.
# Lines starting with `!` are IPython shell escapes, so this cell only runs
# inside a notebook kernel.
import os

# ==========================================
# YOUR FORK
YOUR_GITHUB_REPO_URL = "https://github.com/martiincooper/RL-CROSSQ.git"
# ==========================================

# 1. Clone CrossQ
# Checkout directory name = last URL path component with ".git" removed.
repo_name = YOUR_GITHUB_REPO_URL.split("/")[-1].replace(".git", "")
# A setup.py in the current directory is taken to mean we are already inside
# the checkout, so cloning and chdir are skipped (keeps re-runs from nesting).
if not os.path.exists("setup.py"):
    if not os.path.exists(repo_name):
        !git clone $YOUR_GITHUB_REPO_URL
    # Only enter the directory if it exists (i.e. the clone succeeded or it
    # was already present).
    if os.path.exists(repo_name):
        os.chdir(repo_name)

# 2. Clone Super-Mario-RL (For reference/assets)
# Cloned relative to the current working directory, which is the fork
# checkout when step 1 changed into it.
if not os.path.exists("Super-Mario-RL"):
    !git clone https://github.com/jiseongHAN/Super-Mario-RL.git

# 3. Install SBX
# Editable install of whatever directory we are in now.
cwd = os.getcwd()
!pip install -e "$cwd" --ignore-requires-python

# 4. Install Dependencies (Shimmy = Bridge from Old Gym -> Gymnasium)
# Remove these packages first, then reinstall with explicit version pins.
# NOTE(review): presumably this avoids conflicts with versions preinstalled
# on the host (e.g. Colab) — confirm against the target runtime.
!pip uninstall -y jax jaxlib tensorflow-probability dopamine-rl numpy
!pip install "numpy<2.0.0" "shimmy>=1.3.0" "gym-super-mario-bros==7.4.0" "nes-py==8.2.1" "opencv-python<4.10.0" matplotlib stable-baselines3
!pip install -U "jax[cuda12_pip]==0.4.28" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install tensorflow-probability==0.23.0

In [None]:
import os
import jax
import numpy as np
import gymnasium as gym
import shimmy
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from nes_py.wrappers import JoypadSpace
from sbx import SAC
from sbx.sac.utils import ReLU
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.atari_wrappers import NoopResetEnv, MaxAndSkipEnv, WarpFrame
from gymnasium.wrappers import FrameStack

# Prevent JAX memory issues
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# === Replicating Super-Mario-RL Wrappers in Gymnasium ===

class EpisodicLifeMario(gym.Wrapper):
    """End the episode whenever Mario loses a life (Gymnasium step/reset API).

    A lost life is reported as ``terminated=True`` so the agent sees a
    terminal transition, but the underlying game is only truly reset once the
    wrapped environment itself reports ``terminated`` or ``truncated``.

    Attributes:
        lives: Life counter observed after the most recent step/reset.
        was_real_done: Whether the wrapped env reported the episode as over
            on the last step. Starts as ``True`` so the first ``reset()``
            performs a real reset.
    """

    def __init__(self, env):
        super().__init__(env)
        self.lives = 0
        self.was_real_done = True

    def _current_lives(self):
        """Return the remaining-lives counter (``_life``) of the base game env.

        ``self.env.unwrapped`` is the innermost *Gymnasium* env. If that
        object does not expose ``_life`` itself, fall back to the legacy env
        stored on its ``gym_env`` attribute.
        NOTE(review): shimmy's Gym compatibility env is believed to keep the
        old-Gym env on ``gym_env`` — confirm against the installed version.
        """
        base = self.env.unwrapped
        if not hasattr(base, "_life"):
            legacy = getattr(base, "gym_env", None)
            if legacy is not None:
                base = getattr(legacy, "unwrapped", legacy)
        return base._life

    def step(self, action):
        """Step the wrapped env and flag a lost life as termination.

        Returns:
            The usual ``(obs, reward, terminated, truncated, info)`` tuple,
            with ``terminated`` forced to ``True`` when the life counter
            dropped but is still positive.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Remember the wrapped env's own verdict before overriding it below.
        self.was_real_done = terminated or truncated

        lives = self._current_lives()
        if 0 < lives < self.lives:
            terminated = True

        self.lives = lives
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset for real only after a true episode end.

        After a mere life loss the game is advanced with a single no-op
        action (``0``) instead of being restarted.

        Returns:
            ``(obs, info)`` for the start of the next (pseudo-)episode.
        """
        if self.was_real_done:
            obs, info = self.env.reset(**kwargs)
        else:
            # No-op step to move past the life-lost state.
            obs, _, terminated, truncated, info = self.env.step(0)
            # That no-op can itself finish the game; returning its
            # observation would hand the caller an already-finished episode,
            # so perform a real reset in that case.
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)

        self.lives = self._current_lives()
        return obs, info

class ContinuousToDiscreteWrapper(gym.ActionWrapper):
    """Expose a discrete-action env through a continuous ``Box`` action space.

    The agent emits one score in ``[-1, 1]`` per discrete action; the index
    of the highest score is the action forwarded to the wrapped env.
    """

    def __init__(self, env):
        super().__init__(env)
        num_actions = env.action_space.n
        self.n = num_actions
        # One continuous component per discrete action of the wrapped env.
        self.action_space = gym.spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(num_actions,),
            dtype=np.float32,
        )

    def action(self, action):
        """Map a score vector to the index of its largest entry."""
        best_index = np.argmax(action)
        return int(best_index)

def make_mario_env_replicated():
    """Build the Super Mario Bros env with the replicated wrapper stack.

    Pipeline: legacy Gym env -> SIMPLE_MOVEMENT joypad -> shimmy bridge to
    Gymnasium -> Noop / Skip / EpisodicLife / Warp / FrameStack -> continuous
    action adapter for the agent.

    Returns:
        The fully wrapped environment with a ``Box`` action space.
    """
    # 1. Base env (old Gym API), restricted to the SIMPLE_MOVEMENT action set.
    legacy_env = JoypadSpace(
        gym_super_mario_bros.make('SuperMarioBros-v0'),
        SIMPLE_MOVEMENT,
    )

    # 2. Bridge old Gym -> Gymnasium via shimmy.
    wrapped = shimmy.GymV21CompatibilityV0(env=legacy_env)

    # 3. Replicated wrappers. The order is deliberate: it is meant to mirror
    #    `wrap_mario` from the Super-Mario-RL repo.
    wrapped = NoopResetEnv(wrapped, noop_max=30)
    wrapped = MaxAndSkipEnv(wrapped, skip=4)
    wrapped = EpisodicLifeMario(wrapped)
    # Resize/grayscale the frames. ClipRewardEnv is intentionally left out,
    # as the reference `wrap_mario` does not use it.
    # NOTE(review): frames are presumably uint8 in 0-255 here and normalized
    # inside the CNN policy — confirm for the SBX fork in use.
    wrapped = WarpFrame(wrapped)
    # NOTE(review): verify the stacked observation shape is one the policy's
    # feature extractor accepts.
    wrapped = FrameStack(wrapped, 4)

    # 4. Adapter so a continuous-action agent can drive the discrete env.
    return ContinuousToDiscreteWrapper(wrapped)

## 1. Train CrossQ

In [None]:
# Directory for the Monitor log of this run.
LOG_DIR = "./mario_replicated_logs/"
os.makedirs(LOG_DIR, exist_ok=True)

# Wrap the env in Monitor so episode statistics are written under LOG_DIR.
env = Monitor(
    make_mario_env_replicated(),
    filename=os.path.join(LOG_DIR, "CrossQ"),
)

# Keyword arguments forwarded to the SAC constructor, with the fork's
# `crossq_style` flag switched on.
kwargs = dict(
    verbose=1,
    learning_starts=1000,
    buffer_size=50_000,
    ent_coef="auto",
    crossq_style=True,
    policy_delay=3,
    gradient_steps=1,
    tau=1.0,
    learning_rate=1e-3,
    policy_kwargs=dict(
        activation_fn=ReLU,
        n_critics=2,
        batch_norm=True,
        # Separate layer widths for the actor ("pi") and critics ("qf").
        net_arch=dict(pi=[256, 256], qf=[2048, 2048]),
        optimizer_kwargs=dict(b1=0.5),
    ),
)

model = SAC("CnnPolicy", env, **kwargs)

# Train, then save the resulting model to disk.
print("Training CrossQ on Super-Mario-RL stack...")
model.learn(total_timesteps=100_000, progress_bar=True)
model.save("crossq_super_mario_replicated")
print("Done.")