# Testing SafePointGoal Environment with PPO-Lagrange

This notebook tests the new SafePointGoal environment with Safety Gymnasium-style constraints.

In [None]:
from datetime import datetime
import functools
import time

import os

import jax
import numpy as np
import jax.numpy as jnp
from matplotlib import pyplot as plt

import mujoco
from mujoco import mjx

from brax import envs
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State as BraxState
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo_lagrange import train as ppo_lagrange
from brax.training.agents.ppo_lagrange_v2 import train as ppo_lagrange_v2
from brax.training.agents.ppo_lagrange_v3 import train as ppo_lagrange_v3
from brax.io import html, mjcf, model as brax_model
from brax.io import json as brax_json
import wandb
from ml_collections import config_dict
import subprocess

In [None]:
# GPU and MuJoCo setup.
#
# Fail fast if no NVIDIA GPU is reachable: the rest of the notebook expects
# JAX/MJX to run on a CUDA device.
_GPU_ERROR_MSG = ('Cannot communicate with GPU. Make sure you have NVIDIA '
                  'drivers installed.')
try:
    gpu_check = subprocess.run('nvidia-smi')
except FileNotFoundError as err:
    # `subprocess.run` raises (instead of returning a non-zero exit code) when
    # the `nvidia-smi` executable itself is missing, e.g. no NVIDIA drivers.
    raise RuntimeError(_GPU_ERROR_MSG) from err
if gpu_check.returncode:
    raise RuntimeError(_GPU_ERROR_MSG)

# Request EGL for headless rendering.
# NOTE(review): the mujoco package appears to read MUJOCO_GL when it is first
# imported, and `mujoco` is imported in an earlier cell, so this assignment may
# come too late to take effect. Consider setting it before the imports.
os.environ['MUJOCO_GL'] = 'egl'

try:
    print('Checking that the installation succeeded:')
    mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
    # Surface the user-facing message and keep the underlying error attached
    # as its cause, so both appear in the traceback.
    raise RuntimeError(
        'Something went wrong during installation. Check the error message above '
        'for more information.') from e

print('Installation successful.')

# Tell XLA to use Triton GEMM. Append the flag only once so that re-running
# this cell does not keep growing XLA_FLAGS.
# NOTE(review): XLA_FLAGS is presumably only honoured if it is set before the
# JAX backend initializes (i.e. before the first JAX computation). Confirm.
_TRITON_GEMM_FLAG = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
if _TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags = f'{xla_flags} {_TRITON_GEMM_FLAG}'.strip()
os.environ['XLA_FLAGS'] = xla_flags

In [None]:
# Smoke-test the SafePointGoal environment. For each hazard count: create the
# environment, reset it, and take one step with an all-zero action.
# A failure in one configuration is reported and the loop continues. The final
# summary says whether every configuration passed.
env_name = 'safe_point_goal'

# Hazard counts to exercise (passed to the environment as `num_hazards`).
hazard_counts = [4, 8]
failed_hazard_counts = []

for num_hazards in hazard_counts:
    print(f"\n=== Testing {env_name} with {num_hazards} hazards ===")

    try:
        # Create environment; extra kwargs are forwarded by get_environment.
        env = envs.get_environment(env_name, num_hazards=num_hazards)

        print("✓ Environment created successfully")
        print(f"  Observation space: {env.observation_size}")
        print(f"  Action space: {env.action_size}")

        # Test reset with a fixed key so the run is reproducible.
        rng = jax.random.PRNGKey(42)
        state = env.reset(rng)
        print("✓ Reset successful")
        print(f"  Initial reward: {state.reward}")

        # Test a single step with an all-zero action.
        action = jnp.zeros(env.action_size)
        next_state = env.step(state, action)
        print("✓ Step successful")
        print(f"  Next reward: {next_state.reward}")

    except Exception as e:  # best-effort: keep testing remaining configs
        failed_hazard_counts.append(num_hazards)
        print(f"✗ Error with {num_hazards} hazards: {type(e).__name__}: {e}")

if failed_hazard_counts:
    print(f"\n✗ Environment testing failed for hazard counts: {failed_hazard_counts}")
else:
    print("\n🎉 Environment testing completed!")