In [None]:

from datetime import datetime
import functools
import time # Added for PRNGKey seeding in rollout

import os

import jax
import numpy as np
import jax.numpy as jnp
from matplotlib import pyplot as plt

import mujoco
from mujoco import mjx

from brax import envs
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State as BraxState # Alias Brax's State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo_lagrange import train as ppo_lagrange
from brax.training.agents.ppo_lagrange_v2 import train as ppo_lagrange_v2
from brax.io import html, mjcf, model as brax_model # Alias brax model
from brax.io import json as brax_json # Added for saving trajectories if needed
import wandb
from ml_collections import config_dict
import subprocess


In [None]:
# Fail fast if no NVIDIA GPU is reachable. subprocess.run raises FileNotFoundError when the
# nvidia-smi binary itself is missing, so treat that the same as a non-zero exit code.
try:
    _nvidia_smi_failed = subprocess.run('nvidia-smi').returncode != 0
except FileNotFoundError:
    _nvidia_smi_failed = True
if _nvidia_smi_failed:
    raise RuntimeError('Cannot communicate with GPU. Make sure you have NVIDIA drivers installed.')

# Configure MuJoCo to use the EGL rendering backend (requires GPU).
# NOTE(review): `mujoco` is imported in the first cell, before this is set; if the EGL
# backend does not take effect, set MUJOCO_GL before importing mujoco.
os.environ['MUJOCO_GL'] = 'egl'

try:
    print('Checking that the installation succeeded:')
    mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
    # Raise the descriptive RuntimeError and keep the original failure as its __cause__.
    raise RuntimeError(
        'Something went wrong during installation. Check the error message above '
        'for more information.') from e

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs.
# Append the flag only once so re-running this cell does not keep growing XLA_FLAGS.
# NOTE(review): jax is imported earlier; presumably this only takes effect if the XLA
# backend has not been initialized yet (no JAX computation run before this cell).
_TRITON_GEMM_FLAG = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
if _TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags = f'{xla_flags} {_TRITON_GEMM_FLAG}'.strip()
os.environ['XLA_FLAGS'] = xla_flags

In [None]:
def default_config() -> config_dict.ConfigDict:
  """Builds the baseline settings for the PointHazardGoal environment.

  Returns:
    A ``ConfigDict`` holding reward-shaping, hazard, health/termination,
    reset-noise and observation settings.
  """
  settings = dict(
      # Reward shaping (safety-gymnasium style).
      reward_distance=3,  # Match successful PPO config
      reward_goal=10.0,  # Match successful PPO config
      goal_size=0.7,  # Distance threshold for achieving the goal
      reward_orientation=False,  # Optional reward for maintaining upright orientation
      reward_orientation_scale=0.002,  # Scale for the orientation reward
      reward_orientation_body='agent',  # Only relevant when reward_orientation=True
      ctrl_cost_weight=0.001,  # Match successful PPO config
      # Safety.
      hazard_size=0.7,  # Distance threshold for hazard cost
      # Health / termination.
      terminate_when_unhealthy=True,
      healthy_z_range=(0.05, 0.3),
      # Reset and observation.
      reset_noise_scale=0.005,
      exclude_current_positions_from_observation=True,
      max_velocity=5.0,  # Velocity limit for calculation stability
      debug=False,
  )
  return config_dict.create(**settings)

In [None]:
env_name = 'point_resetting_goal_random_hazard_lidar_sensor_obs'

# Build two independent instances of the registered environment: the first is handed to
# the trainer, the second to its periodic evaluator. No cost wrapper is applied to either.
train_environment, eval_env = (envs.get_environment(env_name) for _ in range(2))

print('\n'.join(
    f"{role} environment '{env_name}' instantiated."
    for role in ('Training', 'Evaluation')
))


In [None]:
# Import necessary modules for the wrapper
from brax.envs.base import Wrapper

# Define a custom wrapper to handle cost field properly
class CostExtraWrapper(Wrapper):
    """Wrapper that guarantees a per-step ``cost`` entry in ``state.info``.

    Meant for PPO-Lagrange training, which is expected to read the constraint
    cost from ``state.info['cost']`` during collection (assumption — confirm
    against the ``ppo_lagrange_v2`` training code).

    Resolution order on every ``reset`` / ``step``:
      1. ``state.metrics['cost']`` if the environment reports it there;
      2. otherwise an existing ``state.info['cost']`` set by the environment;
      3. otherwise zeros shaped like ``state.reward``.
    """

    @staticmethod
    def _with_cost(state: BraxState) -> BraxState:
        """Return ``state`` with ``info['cost']`` populated per the resolution order above."""
        if 'cost' in state.metrics:
            cost = state.metrics['cost']
        elif 'cost' in state.info:
            # The environment maintains info['cost'] itself; leave it untouched.
            return state
        else:
            cost = jnp.zeros_like(state.reward)
        # Build a fresh info dict instead of mutating the one shared with the previous state.
        return state.replace(info={**state.info, 'cost': cost})

    def step(self, state: BraxState, action: jax.Array) -> BraxState:
        # Re-derive the cost on every step. Checking only for a missing key is not enough:
        # if the environment carries `info` forward from the previous state, the zero inserted
        # at reset would persist and metrics['cost'] would never be picked up.
        return self._with_cost(self.env.step(state, action))

    def reset(self, rng: jax.Array) -> BraxState:
        # Initialize info['cost'] so the state pytree has the same structure as after step().
        return self._with_cost(self.env.reset(rng))

# Custom wrap function that includes the cost wrapper
def wrap_env_with_cost(env: envs.Env) -> envs.Env:
    """Wrap ``env`` in ``CostExtraWrapper`` so every state exposes ``info['cost']``."""
    return CostExtraWrapper(env)


In [None]:
def custom_progress_fn(num_steps, metrics, metrics_list=None, use_wandb=False):
    """Report one batch of training/evaluation metrics.

    Prints the step header plus every metric whose name mentions lambda, cost or
    constraint. Builds a W&B payload in which keys already namespaced with
    ``episode/``, ``eval/`` or ``training/`` are kept as-is and all other keys are
    prefixed with ``training_batch/``. The payload is logged only when ``use_wandb``
    is true, a W&B run is active, and the payload is non-empty.

    Args:
      num_steps: environment step count reported by the trainer.
      metrics: mapping of metric name -> scalar (anything with ``.item()`` is unwrapped
        for printing/logging).
      metrics_list: optional list; if given, ``{'step': num_steps, **metrics}`` is appended
        with the raw (unconverted) metric values.
      use_wandb: whether to send the payload to Weights & Biases.
    """
    print(f"Step {num_steps}:")
    namespaced_prefixes = ("episode/", "eval/", "training/")
    debug_markers = ("lambda", "cost", "constraint")

    wandb_log_data = {}
    for name, raw in metrics.items():
        scalar = raw.item() if hasattr(raw, 'item') else raw
        # Surface Lagrangian/cost signals on stdout for quick debugging.
        if any(marker in name for marker in debug_markers):
            print(f"  {name}: {scalar}")
        log_key = name if name.startswith(namespaced_prefixes) else f"training_batch/{name}"
        wandb_log_data[log_key] = scalar

    if use_wandb and wandb.run is not None and wandb_log_data:
        wandb.log(wandb_log_data, step=int(num_steps))

    if metrics_list is not None:
        # A 'step' key inside `metrics` overrides num_steps, matching dict.update semantics.
        metrics_list.append({'step': num_steps, **metrics})

In [None]:
class Args:
    """Minimal attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

# Training hyper-parameters, kept on an `Args` attribute bag so they can be logged to W&B as-is.
args = Args(
    num_timesteps=30_000_000,
    num_evals=5,  # Number of periodic evaluations run on the separate eval_env
    reward_scaling=0.1,
    episode_length=2000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=8,
    num_minibatches=32,
    num_updates_per_batch=6,
    discounting=0.99,
    learning_rate=5e-4,
    entropy_cost=5e-3,
    num_envs=2048,
    batch_size=1024,
    max_devices_per_host=None,
    seed=5,  # Match successful PPO config
    safety_bound=0.2,
    lagrangian_coef_rate=0.001,
    initial_lambda_lagr=0.0,
)

# +++ Initialize Weights & Biases +++
# Copy the hyper-parameters: vars(args) *is* args.__dict__, so updating it in place would
# silently add the environment fields below as attributes of `args`.
config_dict_for_wandb = dict(vars(args))
env = envs.get_environment(env_name)  # NOTE(review): not used by any cell shown here.
# NOTE(review): these are default_config() values, not necessarily the ones the registered
# environment was built with — confirm they match before relying on them in W&B.
current_env_config_for_wandb = default_config().to_dict()
config_dict_for_wandb["environment_name"] = env_name
config_dict_for_wandb.update({
    key: current_env_config_for_wandb.get(key)
    for key in (
        "reward_distance",
        "reward_goal",
        "goal_size",
        "ctrl_cost_weight",
        "reward_orientation",
        "reward_orientation_scale",
    )
})
run_name = f"{env_name}_ppo_lag_v2_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
run = wandb.init(
    project="safe_brax",
    name=run_name,
    config=config_dict_for_wandb,
)

# Progress reporting: custom_progress_fn (defined above) prints cost/lambda metrics, logs
# everything to W&B, and appends each report to `metrics_list` for later CSV export.
metrics_list = []
bound_progress_fn = functools.partial(custom_progress_fn, metrics_list=metrics_list, use_wandb=True)

# Setup the PPO-Lagrange training function
train_fn = functools.partial(
    ppo_lagrange_v2,
    num_timesteps=args.num_timesteps,
    num_evals=args.num_evals,  # Separate evaluations using eval_env
    reward_scaling=args.reward_scaling,
    episode_length=args.episode_length,
    normalize_observations=args.normalize_observations,
    action_repeat=args.action_repeat,
    unroll_length=args.unroll_length,
    num_minibatches=args.num_minibatches,
    num_updates_per_batch=args.num_updates_per_batch,
    learning_rate=args.learning_rate,
    entropy_cost=args.entropy_cost,
    discounting=args.discounting,
    num_envs=args.num_envs,
    batch_size=args.batch_size,
    max_devices_per_host=args.max_devices_per_host,
    seed=args.seed,
    # Brax's internal episode-metrics logging.
    log_training_metrics=True,
    # A step count, so keep it an integer (true division would pass a float).
    training_metrics_steps=args.episode_length * args.num_envs // 5,
    safety_bound=args.safety_bound,
    lagrangian_coef_rate=args.lagrangian_coef_rate,
    initial_lambda_lagr=args.initial_lambda_lagr,
)
print("Training arguments and PPO train_fn configured. Weights & Biases run initialized.")

In [None]:
print(f"Starting PPO-Lagrange training for {env_name}...")
make_inference_fn, params, final_eval_metrics = train_fn(
    environment=train_environment,
    eval_env=eval_env,  # Separate environment instance used by the periodic evaluator
    progress_fn=bound_progress_fn,
)
print("Training finished.")
print(f"Final evaluation metrics: {final_eval_metrics}")

# +++ Log final evaluation metrics to W&B +++
if wandb.run is not None and final_eval_metrics:
    # W&B ignores log calls whose step is below the last logged step. The trainer may run past
    # num_timesteps (it works in whole training iterations) and has already reported those
    # steps through bound_progress_fn, so log at the largest step seen so far.
    final_step = max([int(args.num_timesteps)] + [int(entry['step']) for entry in metrics_list])
    # Keys are logged unchanged; they are expected to already carry an 'eval/' prefix.
    final_log_data = {
        key: value.item() if hasattr(value, 'item') else value
        for key, value in final_eval_metrics.items()
    }
    wandb.log(final_log_data, step=final_step)

In [None]:
# import csv
# timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# metrics_filename = f"metrics/training_metrics_notebook_{env_name}_{timestamp}.csv"
# os.makedirs('metrics', exist_ok=True)
# with open(metrics_filename, 'w', newline='') as f:
#     writer = csv.writer(f)
#     # ... (write headers and data from metrics_list as in your save_brax_metrics) ...
# print(f"Metrics saved to {metrics_filename}")

In [None]:
# Persist the trained parameters under a timestamped name so earlier runs are kept.
os.makedirs('models', exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = f'models/{env_name.lower()}_lag_notebook_{timestamp}'
brax_model.save_params(model_path, params)
print(f"Trained model parameters saved to: {model_path}")

In [None]:
# Policy: turn the trained params into an action-sampling function and JIT-compile it.
# NOTE(review): the exact structure of `params` (normalizer / policy / value networks, and
# possibly a cost value network for PPO-Lagrange) is defined by the trainer — not shown here.
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

# A fresh instance of the training environment, used only for the manual rollout below.
eval_env_name = env_name
eval_environment = envs.get_environment(eval_env_name)
jit_eval_reset, jit_eval_step = jax.jit(eval_environment.reset), jax.jit(eval_environment.step)

print(f"Inference function for rollout created for {eval_env_name}.")

In [None]:
num_rollout_steps = 5000
rollout_frames = []

# Keys copied from `eval_state.metrics` after every step (NaN when the env does not report one).
ROLLOUT_METRIC_KEYS = (
    'distance_to_goal', 'reward', 'cost', 'dist_reward', 'goal_reward', 'orientation_reward',
    'ctrl_cost', 'x_position', 'y_position', 'x_velocity', 'y_velocity', 'goals_reached_count',
)
# Keys read from `eval_state.info` after every step.
ROLLOUT_INFO_KEYS = ('last_dist_goal', 'agent_pos', 'goal_pos')

# Per-step series consumed by the analysis and plotting cells below.
rollout_metrics_data = {
    'distance_to_goal': [],
    'last_dist_goal': [],
    'reward': [],
    'dist_reward': [],
    'goal_reward': [],
    'orientation_reward': [],
    'ctrl_cost': [],
    'x_position': [],
    'y_position': [],
    'agent_pos_x': [],
    'agent_pos_y': [],
    'goal_pos_x': [],
    'goal_pos_y': [],
    'x_velocity': [],
    'y_velocity': [],
    'goals_reached_count': [],
    'cost': [],
}

actions = []
_NAN_POS = np.full(3, np.nan)  # Placeholder when agent/goal position is missing from info

rng_rollout = jax.random.PRNGKey(int(time.time()))
eval_state = jit_eval_reset(rng_rollout)

print(f"Starting rollout for {num_rollout_steps} steps...")
for i in range(num_rollout_steps):
    act_rng, rng_rollout = jax.random.split(rng_rollout)
    action, _ = jit_inference_fn(eval_state.obs, act_rng)
    actions.append(action)

    eval_state = jit_eval_step(eval_state, action)
    rollout_frames.append(eval_state.pipeline_state)

    # Fetch everything recorded for this step in a single device->host transfer.
    step_metrics, step_info = jax.device_get((
        {k: eval_state.metrics[k] for k in ROLLOUT_METRIC_KEYS if k in eval_state.metrics},
        {k: eval_state.info[k] for k in ROLLOUT_INFO_KEYS if k in eval_state.info},
    ))

    # Every metric defaults to NaN (never None) so downstream numeric code — negation,
    # cumsum, plotting — works on every series.
    for key in ROLLOUT_METRIC_KEYS:
        rollout_metrics_data[key].append(step_metrics.get(key, np.nan))

    # Values from info describe the state *after* this step.
    rollout_metrics_data['last_dist_goal'].append(step_info.get('last_dist_goal', np.nan))
    agent_pos = step_info.get('agent_pos', _NAN_POS)
    goal_pos = step_info.get('goal_pos', _NAN_POS)
    rollout_metrics_data['agent_pos_x'].append(agent_pos[0])
    rollout_metrics_data['agent_pos_y'].append(agent_pos[1])
    rollout_metrics_data['goal_pos_x'].append(goal_pos[0])
    rollout_metrics_data['goal_pos_y'].append(goal_pos[1])

    if i % 100 == 0 or i == num_rollout_steps - 1:
        print(f"Rollout step {i+1}/{num_rollout_steps} completed. Goals reached: {step_metrics.get('goals_reached_count', 0)}")

    if bool(eval_state.done):
        print(f"Rollout terminated early at step {i+1} due to done signal.")
        # Pad every series with NaN so all of them keep length num_rollout_steps.
        remaining_steps = num_rollout_steps - (i + 1)
        for series in rollout_metrics_data.values():
            series.extend([np.nan] * remaining_steps)
        break
print("Rollout finished.")

# Save trajectory as JSON (optional). Timestamped like the model and plot files so repeated
# runs do not overwrite each other.
trajectory_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
os.makedirs('trajectories', exist_ok=True)
rollout_trajectory_path = f'trajectories/{eval_env_name}_lag_rollout_notebook_{trajectory_timestamp}.json'
brax_json.save(rollout_trajectory_path, eval_environment.sys, rollout_frames)
print(f"Rollout trajectory saved to {rollout_trajectory_path}")

# Render video (optional; needs `from IPython.display import HTML` and `import mediapy as media`,
# neither of which is imported in this notebook)
# video_html = HTML(html.render(eval_environment.sys, rollout_frames))
# video_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# os.makedirs('videos', exist_ok=True)
# media.write_video(f'videos/{eval_env_name}_rollout_notebook_{video_timestamp}.mp4',
#                   eval_environment.render(rollout_frames, camera='side'),
#                   fps=1.0 / eval_environment.dt)
# print(f"Rollout video saved.")
# video_html

# TODO: save rollout metrics to wandb

In [None]:
def analyze_rollout_bug(rollout_metrics_data):
    """Print diagnostics for the distance-reward term recorded during a rollout.

    Three checks are printed:
      * around the first step with a positive goal reward, the recorded
        ``dist_reward`` next to ``last_dist_goal - distance_to_goal``;
      * whether ``last_dist_goal`` and ``distance_to_goal`` are numerically
        identical across the rollout;
      * the distance change vs. summed distance reward over up to 50 steps
        following the first goal.

    NaN entries (e.g. padding written after an early ``done``) are ignored in
    the comparison; empty series and a goal on the final step are handled
    without raising. Output is printed; the function returns ``None``.

    Args:
      rollout_metrics_data: dict of per-step lists that must contain the keys
        'distance_to_goal', 'last_dist_goal', 'dist_reward' and 'goal_reward'.
    """
    print("=== Analyzing Rollout Distance Reward Bug ===")

    # dtype=float turns None entries into NaN and unwraps 0-d device/host arrays.
    distance_to_goal = np.asarray(rollout_metrics_data['distance_to_goal'], dtype=float)
    last_dist_goal = np.asarray(rollout_metrics_data['last_dist_goal'], dtype=float)
    dist_reward = np.asarray(rollout_metrics_data['dist_reward'], dtype=float)
    goal_reward = np.asarray(rollout_metrics_data['goal_reward'], dtype=float)
    num_steps = len(distance_to_goal)

    # Find goal achievement steps (NaN compares False, so padding is never a goal).
    goal_steps = np.where(goal_reward > 0)[0]
    print(f"Goal achieved at steps: {goal_steps}")

    # Focus on the problem area - around the first goal achievement
    if len(goal_steps) > 0:
        problem_start = goal_steps[0]

        print(f"\nAnalyzing steps around goal reset (step {problem_start}):")

        for i in range(max(0, problem_start - 2), min(problem_start + 10, num_steps)):
            # NOTE(review): this ignores any scale the env applies to the distance term
            # (default_config sets reward_distance=3) — confirm against the env's reward.
            if i > 0:
                expected_dist_reward = last_dist_goal[i] - distance_to_goal[i]
            else:
                expected_dist_reward = 0.0

            actual_dist_reward = dist_reward[i]

            marker = " <-- GOAL!" if i == problem_start else ""
            print(f"  Step {i}: dist_to_goal={distance_to_goal[i]:.3f}, "
                  f"last_dist_goal={last_dist_goal[i]:.3f}, "
                  f"expected_reward={expected_dist_reward:.4f}, "
                  f"actual_reward={actual_dist_reward:.4f}{marker}")

    # Check if last_dist_goal equals distance_to_goal, ignoring steps where either is NaN.
    diff = np.abs(last_dist_goal - distance_to_goal)
    diff = diff[~np.isnan(diff)]

    print("\nComparison of last_dist_goal vs distance_to_goal:")
    if diff.size == 0:
        print("  No steps with both values recorded; skipping comparison.")
    else:
        max_diff = diff.max()
        steps_with_diff = int(np.sum(diff > 1e-4))
        print(f"  Max difference: {max_diff:.6f}")
        print(f"  Steps with significant difference (>1e-4): {steps_with_diff}")
        print(f"  Are they essentially identical? {max_diff < 1e-3}")

        # If they're identical, that's the bug!
        if max_diff < 1e-3:
            print("\n🚨 BUG FOUND: last_dist_goal ≈ distance_to_goal")
            print("   This makes dist_reward = (last_dist_goal - distance_to_goal) ≈ 0")
            print("   The agent can move freely without distance penalties!")

    # Check the actual values during the major movement after the first goal
    if len(goal_steps) > 0:
        start_step = goal_steps[0] + 1
        end_step = min(start_step + 50, num_steps)

        if start_step >= end_step:
            # The first goal was reached on the last recorded step; nothing follows it.
            print(f"\nNo steps recorded after the first goal (step {goal_steps[0]}); "
                  f"skipping movement check.")
        else:
            print(f"\nDuring major movement (steps {start_step}-{end_step}):")
            movement_distances = distance_to_goal[end_step - 1] - distance_to_goal[start_step]
            total_dist_rewards = np.sum(dist_reward[start_step:end_step])

            print(f"  Distance change: {distance_to_goal[start_step]:.2f} → {distance_to_goal[end_step-1]:.2f} "
                  f"(Δ = {movement_distances:.2f})")
            print(f"  Total distance rewards: {total_dist_rewards:.4f}")
            print(f"  Expected total if working: ≈ {-movement_distances:.2f}")

            if abs(total_dist_rewards) < 0.1 and abs(movement_distances) > 10:
                print(f"  🚨 CONFIRMED: Large movement ({movement_distances:.1f}) with tiny rewards ({total_dist_rewards:.4f})")

# Run the analysis on the rollout collected above; results are printed, nothing is returned.
analyze_rollout_bug(rollout_metrics_data)

In [None]:

# Rollout diagnostics: every figure is saved under plots/ with a shared timestamped prefix.
plot_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
plot_dir = 'plots'
os.makedirs(plot_dir, exist_ok=True)
plot_path_base = f'{plot_dir}/{eval_env_name}_lag_rollout_notebook_{plot_timestamp}'

# Convert every recorded series to a float array once. Missing values (None, or the NaN
# padding written after an early `done`) become NaN, so negation/cumsum/plotting all work.
rollout_series = {key: np.asarray(values, dtype=float) for key, values in rollout_metrics_data.items()}
num_actual_rollout_steps = len(rollout_series['distance_to_goal'])  # Use a consistent metric for length
time_steps_rollout = np.arange(num_actual_rollout_steps)

plt.style.use('seaborn-v0_8-darkgrid')


def _save_show_close(fig, path, label):
    """Lay out `fig`, save it to `path`, display it inline, close it and report the path."""
    fig.tight_layout()
    fig.savefig(path)
    plt.show()
    plt.close(fig)
    print(f"{label} saved to: {path}")


# Plot 1: Distance and Last Distance to Goal
fig, ax = plt.subplots(figsize=(12, 7))
ax.plot(time_steps_rollout, rollout_series['distance_to_goal'], label='Current Distance to Goal (metrics)', linestyle='-')
ax.plot(time_steps_rollout, rollout_series['last_dist_goal'], label='Last Distance to Goal (info)', linestyle='--')
ax.set(xlabel="Time Step", ylabel="Distance", title=f"{eval_env_name} - Rollout: Goal Tracking")
ax.legend()
goal_tracking_plot_path = f'{plot_path_base}_goal_distances.png'
_save_show_close(fig, goal_tracking_plot_path, "Goal tracking plot")

# Per-step cost
fig, ax = plt.subplots(figsize=(12, 7))
ax.plot(time_steps_rollout, rollout_series['cost'], label='Cost', linestyle='-')
ax.set(xlabel="Time Step", ylabel="Cost", title=f"{eval_env_name} - Rollout: Cost")
ax.legend()
cost_plot_path = f'{plot_path_base}_cost.png'
_save_show_close(fig, cost_plot_path, "Cost plot")

# Cumulative cost (becomes NaN after an early termination, so nothing is drawn past it)
cumulative_cost = np.cumsum(rollout_series['cost'])
fig, ax = plt.subplots(figsize=(12, 7))
ax.plot(time_steps_rollout, cumulative_cost, label='Cumulative Cost', color='red')
ax.set(xlabel="Time Step", ylabel="Cumulative Cost", title=f"{eval_env_name} - Rollout: Cumulative Cost Over Time")
ax.legend()
cumulative_cost_plot_path = f'{plot_path_base}_cumulative_cost.png'
_save_show_close(fig, cumulative_cost_plot_path, "Cumulative cost plot")

# Plot 2: Reward Component Breakdown
fig, ax = plt.subplots(figsize=(12, 7))
ax.plot(time_steps_rollout, rollout_series['dist_reward'], label='Distance Reward', alpha=0.7)
ax.plot(time_steps_rollout, rollout_series['goal_reward'], label='Goal Reward', alpha=0.7)
ax.plot(time_steps_rollout, rollout_series['orientation_reward'], label='Orientation Reward', alpha=0.7)
ax.plot(time_steps_rollout, -rollout_series['ctrl_cost'], label='Negative Control Cost', alpha=0.7)  # Plotted as negative
ax.plot(time_steps_rollout, rollout_series['reward'], label='Total Reward', linestyle='--', color='black', linewidth=2)
ax.set(xlabel="Time Step", ylabel="Reward Value", title=f"{eval_env_name} - Rollout: Reward Component Breakdown")
ax.legend()
reward_breakdown_plot_path = f'{plot_path_base}_reward_breakdown.png'
_save_show_close(fig, reward_breakdown_plot_path, "Reward breakdown plot")

# Plot 3: Agent and Goal Positions (X and Y over time)
fig, axs = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
for ax, axis_name in zip(axs, ('X', 'Y')):
    suffix = axis_name.lower()
    ax.plot(time_steps_rollout, rollout_series[f'agent_pos_{suffix}'], label=f'Agent {axis_name} Position (info)', linestyle='-')
    ax.plot(time_steps_rollout, rollout_series[f'goal_pos_{suffix}'], label=f'Goal {axis_name} Position (info)', linestyle='--')
    ax.set_ylabel(f"{axis_name} Position")
    ax.set_title(f"{eval_env_name} - Agent and Goal {axis_name} Positions Over Time")
    ax.legend()
    ax.grid(True)
axs[1].set_xlabel("Time Step")
agent_goal_pos_plot_path = f'{plot_path_base}_agent_goal_positions.png'
_save_show_close(fig, agent_goal_pos_plot_path, "Agent and Goal positions plot")

# X-Y trajectory of the agent (from metrics) with the goal position(s) (from info)
fig, ax = plt.subplots(figsize=(10, 8))
agent_x_series = rollout_series['x_position']
agent_y_series = rollout_series['y_position']
goal_x_series = rollout_series['goal_pos_x']
goal_y_series = rollout_series['goal_pos_y']

# Drop NaN samples (missing data or post-termination padding)
agent_ok = ~(np.isnan(agent_x_series) | np.isnan(agent_y_series))
valid_x_agent, valid_y_agent = agent_x_series[agent_ok], agent_y_series[agent_ok]
goal_ok = ~(np.isnan(goal_x_series) | np.isnan(goal_y_series))
valid_x_goal, valid_y_goal = goal_x_series[goal_ok], goal_y_series[goal_ok]

if valid_x_agent.size > 0:
    ax.plot(valid_x_agent, valid_y_agent, 'k-', alpha=0.7, label='Agent Path')
    ax.scatter(valid_x_agent[0], valid_y_agent[0], c='green', s=100, label='Agent Start', zorder=5, marker='o')
    ax.scatter(valid_x_agent[-1], valid_y_agent[-1], c='red', s=100, label='Agent End', zorder=5, marker='x')

    if valid_x_goal.size > 0:
        ax.scatter(valid_x_goal[0], valid_y_goal[0], c='blue', s=150, label='Initial Goal', zorder=4, marker='*')
        # Only draw a goal path when the goal actually moved during the rollout.
        goal_moved = np.any(valid_x_goal != valid_x_goal[0]) or np.any(valid_y_goal != valid_y_goal[0])
        if goal_moved:
            ax.plot(valid_x_goal, valid_y_goal, 'b--', alpha=0.5, label='Goal Path (if dynamic)')
            ax.scatter(valid_x_goal[-1], valid_y_goal[-1], c='purple', s=150, label='Final Goal', zorder=4, marker='*')

    ax.set(xlabel="X Position", ylabel="Y Position", title=f"{eval_env_name} - Rollout: X-Y Trajectory with Goal(s)")
    ax.legend()
    ax.axis('equal')
    ax.grid(True)
else:
    ax.text(0.5, 0.5, "No valid position data for trajectory plot", ha='center', va='center', transform=ax.transAxes)
trajectory_plot_path_updated = f'{plot_path_base}_xy_trajectory_with_goals.png'
_save_show_close(fig, trajectory_plot_path_updated, "X-Y trajectory plot with goals")


# Hyperparameter Tuning Diagnostic Checklist

## 1. Performance Health Check
- [ ] **Episode Reward**: Is it increasing? Target: > 0 after 5M steps
- [ ] **Goals Reached**: Are any goals being reached? Target: > 0.5 per episode
- [ ] **Distance Reward**: Is it positive on average? Target: > 0

## 2. Safety Balance Check  
- [ ] **Cost Return**: Is it near the cost limit (`args.safety_bound` in this notebook)? Target: within 20% of the limit
- [ ] **Lambda**: Has it stabilized? Target: 0.1 - 2.0 range
- [ ] **Constraint Violation**: Is it converging to 0? Target: |violation| < 0.05

## 3. Learning Dynamics Check
- [ ] **Policy Loss**: Is it decreasing? Check for plateaus
- [ ] **Value Loss**: Are both the reward and the cost value losses below 1.0?
- [ ] **Entropy**: Is it > 0.001? (not collapsing)

## 4. Trajectory Analysis
- [ ] **Movement Pattern**: Circular? Straight? Avoiding hazards?
- [ ] **Goal Approach**: Is distance to goal decreasing?
- [ ] **Hazard Interaction**: Does the agent steer around hazards instead of driving through them?


In [None]:
# # Automated hyperparameter grid search for systematic tuning

# def create_hyperparameter_grid():
#     """Create a grid of hyperparameters to try based on common issues."""
    
#     # Base configuration
#     base_config = {
#         'num_timesteps': 5_000_000,  # Shorter for quick tests
#         'num_evals': 10,
#         'episode_length': 2000,
#         'normalize_observations': True,
#         'action_repeat': 1,
#         'unroll_length': 10,
#         'num_minibatches': 32,
#         'num_updates_per_batch': 8,
#         'discounting': 0.97,
#         'num_envs': 2048,
#         'batch_size': 512,
#         'clipping_epsilon': 0.2,
#     }
    
#     # Hyperparameter grid for different scenarios
#     grid = {
#         # Conservative baseline
#         'conservative': {
#             **base_config,
#             'reward_scaling': 1.0,
#             'cost_limit': 0.3,
#             'lambda_init': 0.1,
#             'lambda_lr': 0.01,
#             'lambda_max': 5.0,
#             'entropy_cost': 5e-3,
#             'learning_rate': 3e-4,
#         },
        
#         # Aggressive exploration
#         'exploratory': {
#             **base_config,
#             'reward_scaling': 3.0,
#             'cost_limit': 0.3,
#             'lambda_init': 0.0,
#             'lambda_lr': 0.001,
#             'lambda_max': 2.0,
#             'entropy_cost': 2e-2,
#             'learning_rate': 5e-4,
#         },
        
#         # Balanced approach
#         'balanced': {
#             **base_config,
#             'reward_scaling': 2.0,
#             'cost_limit': 0.25,
#             'lambda_init': 0.0,
#             'lambda_lr': 0.005,
#             'lambda_max': 3.0,
#             'entropy_cost': 1e-2,
#             'learning_rate': 3e-4,
#         },
        
#         # High safety
#         'safety_first': {
#             **base_config,
#             'reward_scaling': 1.5,
#             'cost_limit': 0.15,
#             'lambda_init': 0.5,
#             'lambda_lr': 0.02,
#             'lambda_max': 10.0,
#             'entropy_cost': 5e-3,
#             'learning_rate': 3e-4,
#         },
        
#         # Curriculum (start easy, constraint later)
#         'curriculum': {
#             **base_config,
#             'reward_scaling': 2.5,
#             'cost_limit': 0.4,  # Very permissive initially
#             'lambda_init': 0.0,
#             'lambda_lr': 0.0001,  # Very slow growth
#             'lambda_max': 2.0,
#             'entropy_cost': 1.5e-2,
#             'learning_rate': 4e-4,
#         }
#     }
    
#     return grid

# def run_hyperparameter_experiments(grid, env_name, num_seeds=3):
#     """Run experiments with different hyperparameter configurations."""
    
#     results = {}
    
#     for config_name, config in grid.items():
#         print(f"\n{'='*60}")
#         print(f"Running configuration: {config_name}")
#         print(f"{'='*60}")
        
#         config_results = []
        
#         for seed in range(num_seeds):
#             print(f"\nSeed {seed}...")
            
#             # Update seed in config
#             config['seed'] = seed
            
#             # Create Args object
#             exp_args = Args(**config)
            
#             # Create unique wandb run name
#             run_name = f"{env_name}_{config_name}_seed{seed}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            
#             # Initialize wandb
#             run = wandb.init(
#                 project="safe_brax_hyperparam_search",
#                 name=run_name,
#                 config=config,
#                 group=config_name,
#                 reinit=True
#             )
            
#             try:
#                 # Run training (abbreviated for grid search)
#                 train_env = envs.get_environment(env_name)
#                 eval_env = envs.get_environment(env_name)
                
#                 metrics_list = []
#                 progress_fn = functools.partial(
#                     custom_progress_fn, 
#                     metrics_list=metrics_list, 
#                     use_wandb=True
#                 )
                
#                 train_fn = functools.partial(
#                     ppo_lagrange.train,
#                     **{k: v for k, v in config.items() if k != 'seed'}
#                 )
                
#                 make_inference_fn, params, final_metrics = train_fn(
#                     environment=train_env,
#                     eval_env=eval_env,
#                     progress_fn=progress_fn
#                 )
                
#                 # Extract key metrics
#                 result = {
#                     'seed': seed,
#                     'final_goals': final_metrics.get('eval/episode_goals_reached_count', 0),
#                     'final_reward': final_metrics.get('eval/episode_reward', 0),
#                     'final_cost': final_metrics.get('training/cost_return', 0),
#                     'final_lambda': final_metrics.get('training/lambda', 0),
#                     'metrics_list': metrics_list
#                 }
                
#                 config_results.append(result)
                
#             except Exception as e:
#                 print(f"Error in {config_name} seed {seed}: {e}")
#                 config_results.append({'error': str(e)})
            
#             finally:
#                 wandb.finish()
        
#         # Aggregate results
#         results[config_name] = {
#             'config': config,
#             'runs': config_results,
#             'avg_goals': np.mean([r['final_goals'] for r in config_results if 'final_goals' in r]),
#             'avg_reward': np.mean([r['final_reward'] for r in config_results if 'final_reward' in r]),
#             'avg_cost': np.mean([r['final_cost'] for r in config_results if 'final_cost' in r]),
#         }
    
#     return results

# # Usage example:
# # grid = create_hyperparameter_grid()
# # results = run_hyperparameter_experiments(grid, env_name, num_seeds=2)
# # 
# # # Analyze results
# # for config_name, data in results.items():
# #     print(f"\n{config_name}:")
# #     print(f"  Avg Goals: {data['avg_goals']:.2f}")
# #     print(f"  Avg Reward: {data['avg_reward']:.2f}")
# #     print(f"  Avg Cost: {data['avg_cost']:.3f}")


In [None]:
# Mark the W&B run created by `wandb.init` in the training-setup cell as finished.
# Later cells only need `run_name` (a plain string), not the live run.
run.finish()

In [None]:
import csv

def save_metrics_list_to_csv(metrics_list, filename="metrics_list.csv"):
    """
    Save a list of metrics dictionaries to a CSV file.

    Each dictionary becomes one row. The header is the sorted union of keys
    across all entries, and entries missing a key get an empty cell.
    Single-element array values are unwrapped to Python scalars. This covers
    NumPy ndarrays and any other array type exposing ``shape`` and ``item()``,
    such as JAX arrays. Multi-element arrays are written in their ``str()``
    form. The parent directory of ``filename`` is created if it does not exist.

    Args:
        metrics_list (list): List of dictionaries containing metrics.
        filename (str): Output CSV filename.
    """
    if not metrics_list:
        print("metrics_list is empty, nothing to save.")
        return

    def _to_scalar(value):
        # NumPy scalars (np.float32, ...) and plain Python values stringify
        # cleanly as-is; only array containers need unwrapping.
        if isinstance(value, np.generic) or not hasattr(value, "shape"):
            return value
        try:
            return value.item()
        except (AttributeError, TypeError, ValueError):
            # Multi-element arrays raise ValueError; objects with a `shape`
            # but no usable item() raise AttributeError/TypeError. Keep the
            # original value and let csv str() it.
            return value

    # Convert every entry once and reuse it for both the header and the rows.
    rows = [{key: _to_scalar(val) for key, val in entry.items()} for entry in metrics_list]
    fieldnames = sorted({key for row in rows for key in row})

    directory = os.path.dirname(filename)
    if directory:
        # e.g. a ./metrics/ subdirectory that does not exist yet on a fresh checkout.
        os.makedirs(directory, exist_ok=True)

    with open(filename, mode='w', newline='') as csvfile:
        # restval="" fills cells for keys an entry does not have.
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, restval="")
        writer.writeheader()
        writer.writerows(rows)

    print(f"Saved {len(metrics_list)} metric entries to {filename}")

# Example usage: write this run's metrics to ./metrics/<run_name>_seed_<seed>_metrics_list.csv
filename = f"./metrics/{run_name}_seed_{args.seed}_metrics_list.csv"
save_metrics_list_to_csv(metrics_list, filename)

In [None]:
from IPython.display import Audio, display, Javascript
import time

def play_completion_sound(frequency=500, duration=1.0, volume=0.2):
    """
    Play a short sine-wave beep in the notebook to signal that processing is complete.

    The function first displays an autoplaying IPython ``Audio`` widget with the
    tone. It then injects Javascript that calls ``play()`` on the most recently
    added ``<audio>`` element, at the given volume. If that ``play()`` call is
    rejected, the script falls back to a Web Audio API oscillator at the same
    frequency. The fallback tone lasts 0.5 s.

    Args:
        frequency (float): Tone frequency in Hz. It must be audible (roughly
            20 Hz and above) for the beep to be heard.
        duration (float): Length of the widget's tone in seconds.
        volume (float): Playback volume in [0, 1] set on the audio element.
    """
    fs = 22050  # Sampling rate (Hz)
    t = np.linspace(0, duration, int(fs * duration), False)
    beep = 0.3 * np.sin(2 * np.pi * frequency * t)

    # Create audio with autoplay enabled
    audio = Audio(beep, rate=fs, autoplay=True)

    # Give each call its own display_id (millisecond timestamp).
    unique_id = f"completion_sound_{int(time.time() * 1000)}"
    display(audio, display_id=unique_id)

    # Browsers often block autoplay, so explicitly play() the newest <audio>
    # element. The fallback oscillator uses the same (audible) frequency as the
    # widget; it previously used 5 Hz, which is below human hearing.
    display(Javascript(f"""
    setTimeout(function() {{
        // Pick the most recently added audio element (the widget displayed above)
        var audios = document.querySelectorAll('audio');
        var lastAudio = audios[audios.length - 1];

        if (lastAudio) {{
            lastAudio.volume = {volume};
            lastAudio.play().catch(function(error) {{
                console.log('Audio autoplay failed:', error);
                // Fallback: synthesize the same tone using the Web Audio API
                try {{
                    var audioContext = new (window.AudioContext || window.webkitAudioContext)();
                    var oscillator = audioContext.createOscillator();
                    var gainNode = audioContext.createGain();

                    oscillator.connect(gainNode);
                    gainNode.connect(audioContext.destination);

                    oscillator.frequency.value = {frequency};
                    oscillator.type = 'sine';
                    gainNode.gain.setValueAtTime(0.1, audioContext.currentTime);
                    gainNode.gain.exponentialRampToValueAtTime(0.01, audioContext.currentTime + 0.5);

                    oscillator.start(audioContext.currentTime);
                    oscillator.stop(audioContext.currentTime + 0.5);
                }} catch(e) {{
                    console.log('Web Audio API fallback also failed:', e);
                }}
            }});
        }}
    }}, 100);
    """))

# Beep once (default 500 Hz tone) to signal that the cells above have finished running.
play_completion_sound()


In [None]:
# Standalone cell: re-run it to replay the completion beep.
play_completion_sound()

In [1]:
import wandb

# W&B Bayesian sweep definition. It maximizes `episode/goals_reached_count`
# (the "metric" entry below); point "metric" at a cost/reward key or a
# composite if a different objective is wanted.
sweep_config = {
    "name": "ppol_bayes_pointgoal_nb",
    "method": "bayes",
    "metric": {"name": "episode/goals_reached_count", "goal": "maximize"},
    # NOTE(review): hyperband brackets are defined in terms of logged metric
    # iterations; confirm min_iter=3 fits how often this trainer logs the metric.
    "early_terminate": {"type": "hyperband", "min_iter": 3, "eta": 3},
    "parameters": {
        # Fixed context for this project
        "env": {"value": "point_resetting_goal_random_hazard_lidar_sensor_obs"},
        # NOTE(review): train() calls ppo_lagrange_v2 and never reads "alg";
        # this value is only recorded in each run's config.
        "alg": {"value": "ppo_lagrange"},
        # PPO core
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-3},
        "entropy_cost": {"distribution": "log_uniform_values", "min": 1e-4, "max": 1e-2},
        # NOTE(review): brax PPO typically requires batch_size * num_minibatches
        # to be a multiple of num_envs. Every combination here satisfies that for
        # num_envs=1024 (smallest product 256 * 16 = 4096); confirm this holds for
        # ppo_lagrange_v2.
        "batch_size": {"values": [256, 512, 1024]},
        "num_minibatches": {"values": [16, 32, 64]},
        "num_updates_per_batch": {"values": [2, 4, 8]},
        "unroll_length": {"values": [5, 10, 20]},
        "gae_lambda": {"min": 0.9, "max": 0.98},
        "clipping_epsilon": {"min": 0.1, "max": 0.3},
        # PPO-Lagrange hparams
        "safety_bound": {"value": 0.2},
        "lagrangian_coef_rate": {"distribution": "log_uniform_values", "min": 1e-3, "max": 1e-1},
        "initial_lambda_lagr": {"values": [0.0, 0.1, 1.0]},
        # Env overrides (optional); train() passes them as
        # envs.get_environment(env, **env_kwargs).
        "env_kwargs": {"value": {"config_overrides": {"hazard_size": 0.7}}},
        # Runtime + eval
        "num_timesteps": {"value": 30_000_000},
        "episode_length": {"value": 1000},
        "num_envs": {"value": 1024},
        "num_evals": {"value": 3},
        "num_eval_envs": {"value": 128},
        "deterministic_eval": {"value": False},
        "normalize_observations": {"value": True},
        # Single fixed seed: every trial uses the same seed, so this sweep does
        # not compare robustness across seeds.
        "seed": {"value": 243512},
    },
}


In [2]:
import json
import functools
import jax
from brax import envs
from brax.training.agents.ppo_lagrange_v2 import train as ppo_lagrange_v2
import wandb
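# Relies on names imported above or defined elsewhere in this notebook:
# - ppo_lagrange_v2: the ppo_lagrange_v2 `train` function, imported under this alias
#   (the sweep trial function defined below is the one named `train`)
# - custom_progress_fn: progress callback that logs to wandb; a fallback is defined
#   below if no earlier cell defined it
# - environments are constructed inside each trial so every run is self-contained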

# Fallback progress callback, defined only when no earlier cell in this kernel
# session has already provided custom_progress_fn.
if 'custom_progress_fn' not in globals():
    def custom_progress_fn(num_steps, metrics, metrics_list=None, use_wandb=True):
        """Record one progress report from the trainer.

        Args:
            num_steps: Environment step count for this report (cast with int()).
            metrics: Mapping of metric name -> value, or None. Values that
                float() cannot convert are skipped.
            metrics_list: If a list, a float-only copy of `metrics` (without
                "env_steps") is appended to it.
            use_wandb: If True, log the converted metrics plus "env_steps" to
                wandb at step=num_steps.
        """
        step = int(num_steps)

        def _as_floats():
            # Fresh float-only dict each call; non-convertible values are dropped.
            converted = {}
            for name, value in (metrics or {}).items():
                try:
                    converted[name] = float(value)
                except Exception:
                    pass
            return converted

        if use_wandb:
            # A metric literally named "env_steps" overrides the step counter entry.
            wandb.log({"env_steps": step, **_as_floats()}, step=step)
        if isinstance(metrics_list, list):
            metrics_list.append(_as_floats())


def _final_metrics_for_wandb(metrics):
    """Convert a final-metrics mapping into float values keyed under ``eval/``.

    Args:
        metrics: Mapping of metric name -> value (Python numbers, NumPy/JAX
            scalars, ...), or None.

    Returns:
        dict: ``{name: float}``. Names that do not already start with ``"eval/"``
        get that prefix. Values that ``float()`` cannot convert (e.g.
        multi-element arrays) are skipped.
    """
    converted = {}
    for key, value in (metrics or {}).items():
        try:
            scalar = float(value)
        except (TypeError, ValueError):
            continue
        converted[key if key.startswith("eval/") else f"eval/{key}"] = scalar
    return converted


def train():
    """Run a single W&B sweep trial.

    Reads hyperparameters from ``wandb.config`` (populated by the sweep agent)
    and builds the train/eval environments. It then trains with
    ``ppo_lagrange_v2`` while the progress callback streams metrics to wandb,
    and finally logs the returned final metrics under ``eval/``.
    """
    # Do not pass project explicitly; the sweep sets it and avoids the warning
    with wandb.init() as run:
        c = wandb.config

        # Create train/eval envs with optional overrides from sweep config.
        # Optional keys are read with .get() throughout, since they are absent
        # from some sweep configs.
        env_kwargs = c.get("env_kwargs", None) or {}
        train_environment = envs.get_environment(c.env, **env_kwargs)
        eval_env = envs.get_environment(c.env, **env_kwargs)

        # Fresh list per run so runs don't leak state.
        metrics_list = []
        # Largest env step passed to the progress callback. The trainer can
        # overshoot num_timesteps (e.g. env_steps=30_064_640 for
        # num_timesteps=30_000_000), and wandb drops logs whose step is lower
        # than one already recorded.
        last_logged_step = [0]

        def bound_progress_fn(num_steps, metrics, *args, **kwargs):
            last_logged_step[0] = max(last_logged_step[0], int(num_steps))
            # Same keyword binding as functools.partial(custom_progress_fn, ...).
            forwarded = {"metrics_list": metrics_list, "use_wandb": True, **kwargs}
            return custom_progress_fn(num_steps, metrics, *args, **forwarded)

        # Build training callable with sweep hyperparams
        train_fn = functools.partial(
            ppo_lagrange_v2,  # the ppo_lagrange_v2 `train` entry point, imported under this alias
            num_timesteps=int(c.num_timesteps),
            num_evals=int(c.num_evals),
            num_eval_envs=int(c.num_eval_envs),
            deterministic_eval=bool(c.deterministic_eval),
            episode_length=int(c.episode_length),
            num_envs=int(c.num_envs),
            action_repeat=int(c.get("action_repeat", 1)),
            unroll_length=int(c.unroll_length),
            batch_size=int(c.batch_size),
            num_minibatches=int(c.num_minibatches),
            num_updates_per_batch=int(c.num_updates_per_batch),
            learning_rate=float(c.learning_rate),
            entropy_cost=float(c.entropy_cost),
            discounting=float(c.get("discounting", 0.99)),
            reward_scaling=float(c.get("reward_scaling", 1.0)),
            gae_lambda=float(c.gae_lambda),
            clipping_epsilon=float(c.clipping_epsilon),
            normalize_observations=bool(c.normalize_observations),
            safety_bound=float(c.safety_bound),
            lagrangian_coef_rate=float(c.lagrangian_coef_rate),
            initial_lambda_lagr=float(c.get("initial_lambda_lagr", 0.0)),
            # Ensure we get training metrics periodically
            log_training_metrics=True,
            training_metrics_steps=None,
            seed=int(c.seed),
        )

        # Train
        make_inference_fn, params, final_eval_metrics = train_fn(
            environment=train_environment,
            eval_env=eval_env,
            progress_fn=bound_progress_fn,
        )

        # Convert with float() rather than filtering on isinstance(int/float).
        # The metrics may be NumPy/JAX scalars (np.float32 is not a float
        # subclass), and such a filter would silently drop them.
        final_logs = _final_metrics_for_wandb(final_eval_metrics)
        if final_logs:
            # Log at (never before) the last step already written, so wandb keeps the row.
            final_step = max(int(c.num_timesteps), last_logged_step[0])
            wandb.log(final_logs, step=final_step)

# 1) Register a new sweep built from sweep_config in the "safe_brax" project.
#    Each execution of this cell creates another sweep with a new ID, so run it once.
sweep_id = wandb.sweep(sweep_config, project="safe_brax")

# 2) Run up to 60 trials sequentially in this kernel, each calling train() (the
#    sweep trial function defined above). For parallel workers, start more agents
#    (e.g. in terminals) with the same sweep_id.
wandb.agent(sweep_id, function=train, count=60)


Create sweep with ID: 4gqfm0w5
Sweep URL: https://wandb.ai/m-boustani-eindhoven-university-of-technology/safe_brax/sweeps/4gqfm0w5


[34m[1mwandb[0m: Agent Starting Run: tajsjzqt with config:
[34m[1mwandb[0m: 	alg: ppo_lagrange
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	clipping_epsilon: 0.21600191113706177
[34m[1mwandb[0m: 	deterministic_eval: False
[34m[1mwandb[0m: 	entropy_cost: 0.006172038247113296
[34m[1mwandb[0m: 	env: point_resetting_goal_random_hazard_lidar_sensor_obs
[34m[1mwandb[0m: 	env_kwargs: {'config_overrides': {'hazard_size': 0.7}}
[34m[1mwandb[0m: 	episode_length: 1000
[34m[1mwandb[0m: 	gae_lambda: 0.9076281169918482
[34m[1mwandb[0m: 	initial_lambda_lagr: 1
[34m[1mwandb[0m: 	lagrangian_coef_rate: 0.015475408878600418
[34m[1mwandb[0m: 	learning_rate: 8.252955652905625e-05
[34m[1mwandb[0m: 	normalize_observations: True
[34m[1mwandb[0m: 	num_envs: 1024
[34m[1mwandb[0m: 	num_eval_envs: 128
[34m[1mwandb[0m: 	num_evals: 3
[34m[1mwandb[0m: 	num_minibatches: 32
[34m[1mwandb[0m: 	num_timesteps: 30000000
[34m[1mwandb[0m: 	num_updates_per_bat



Body 0 name: world, mocapid: [-1]
Body 1 name: agent, mocapid: [-1]
Body 2 name: goal, mocapid: [0]
Goal body found with mocapid: [0]
Body 3 name: hazard1, mocapid: [1]
Hazard body found with mocapid: [1]
Body 4 name: hazard2, mocapid: [2]
Hazard body found with mocapid: [2]
Body 5 name: hazard3, mocapid: [3]
Hazard body found with mocapid: [3]
Model has 7 sensors. Searching for required sensors...
  Found sensor: accelerometer, ID: 0, Address: 0, Dim: 3
  Found sensor: velocimeter, ID: 1, Address: 3, Dim: 3
  Found sensor: gyro, ID: 2, Address: 6, Dim: 3
  Found sensor: magnetometer, ID: 3, Address: 9, Dim: 3
Body 0 name: world, mocapid: [-1]
Body 1 name: agent, mocapid: [-1]
Body 2 name: goal, mocapid: [0]
Goal body found with mocapid: [0]
Body 3 name: hazard1, mocapid: [1]
Hazard body found with mocapid: [1]
Body 4 name: hazard2, mocapid: [2]
Hazard body found with mocapid: [2]
Body 5 name: hazard3, mocapid: [3]
Hazard body found with mocapid: [3]
Model has 7 sensors. Searching for 

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


Reset method - Goal Position: Traced<ShapedArray(float32[3])>with<BatchTrace> with
  val = Traced<ShapedArray(float32[128,3])>with<DynamicJaxprTrace>
  batch_dim = 0
Reset method - Goal Position: Traced<ShapedArray(float32[3])>with<BatchTrace> with
  val = Traced<ShapedArray(float32[128,3])>with<DynamicJaxprTrace>
  batch_dim = 0


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
env_steps,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████
episode/cost,▆▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▄█▅▅▃▃▃▃▅▅▇▆▆▅
episode/ctrl_cost,█████▇▇▆▆▅▅▅▅▅▄▄▄▄▄▄▄▄▄▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
episode/dist_reward,▄▁▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████████████████
episode/distance_to_goal,▃▃▃█▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
episode/goal_reward,█▇▇▇▄▄▇▇▇▇▂▂▁▁▁▁▃▃▃▃▃▃▁▁▂▁▂▂▃▃▁▁▁▁▁▁▂▃▃▂
episode/goals_reached_count,▁▄▄▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇█████
episode/last_dist_goal,▃██▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
episode/length,████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
episode/orientation_reward,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
env_steps,30064640.0
episode/cost,148.68
episode/ctrl_cost,0.15959
episode/dist_reward,0.01625
episode/distance_to_goal,1695.09381
episode/goal_reward,0.1
episode/goals_reached_count,763.75
episode/last_dist_goal,1695.09921
episode/length,999.0
episode/orientation_reward,0.0


[34m[1mwandb[0m: Agent Starting Run: f4obidul with config:
[34m[1mwandb[0m: 	alg: ppo_lagrange
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	clipping_epsilon: 0.11739169907954404
[34m[1mwandb[0m: 	deterministic_eval: False
[34m[1mwandb[0m: 	entropy_cost: 0.0007906313819905515
[34m[1mwandb[0m: 	env: point_resetting_goal_random_hazard_lidar_sensor_obs
[34m[1mwandb[0m: 	env_kwargs: {'config_overrides': {'hazard_size': 0.7}}
[34m[1mwandb[0m: 	episode_length: 1000
[34m[1mwandb[0m: 	gae_lambda: 0.9362177424349144
[34m[1mwandb[0m: 	initial_lambda_lagr: 1
[34m[1mwandb[0m: 	lagrangian_coef_rate: 0.07873242402786093
[34m[1mwandb[0m: 	learning_rate: 1.267672037417473e-05
[34m[1mwandb[0m: 	normalize_observations: True
[34m[1mwandb[0m: 	num_envs: 1024
[34m[1mwandb[0m: 	num_eval_envs: 128
[34m[1mwandb[0m: 	num_evals: 3
[34m[1mwandb[0m: 	num_minibatches: 16
[34m[1mwandb[0m: 	num_timesteps: 30000000
[34m[1mwandb[0m: 	num_updates_per_bat



Body 0 name: world, mocapid: [-1]
Body 1 name: agent, mocapid: [-1]
Body 2 name: goal, mocapid: [0]
Goal body found with mocapid: [0]
Body 3 name: hazard1, mocapid: [1]
Hazard body found with mocapid: [1]
Body 4 name: hazard2, mocapid: [2]
Hazard body found with mocapid: [2]
Body 5 name: hazard3, mocapid: [3]
Hazard body found with mocapid: [3]
Model has 7 sensors. Searching for required sensors...
  Found sensor: accelerometer, ID: 0, Address: 0, Dim: 3
  Found sensor: velocimeter, ID: 1, Address: 3, Dim: 3
  Found sensor: gyro, ID: 2, Address: 6, Dim: 3
  Found sensor: magnetometer, ID: 3, Address: 9, Dim: 3
Body 0 name: world, mocapid: [-1]
Body 1 name: agent, mocapid: [-1]
Body 2 name: goal, mocapid: [0]
Goal body found with mocapid: [0]
Body 3 name: hazard1, mocapid: [1]
Hazard body found with mocapid: [1]
Body 4 name: hazard2, mocapid: [2]
Hazard body found with mocapid: [2]
Body 5 name: hazard3, mocapid: [3]
Hazard body found with mocapid: [3]
Model has 7 sensors. Searching for 

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


Reset method - Goal Position: Traced<ShapedArray(float32[3])>with<BatchTrace> with
  val = Traced<ShapedArray(float32[128,3])>with<DynamicJaxprTrace>
  batch_dim = 0


E0825 09:41:05.224910  474025 pjrt_stream_executor_client.cc:2839] Execution of replica 0 failed: INTERNAL: CpuCallback error calling callback: Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 973, in _bootstrap
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
  File "/home/mrdbstn/school/safe-brax/safe-brax/env/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
  File "/usr/lib/python3.10/threading.py", line 953, in run
  File "/home/mrdbstn/school/safe-brax/safe-brax/env/lib/python3.10/site-packages/wandb/agents/pyagent.py", line 306, in _run_job
  File "/tmp/ipykernel_176727/291527096.py", line 79, in train
  File "/home/mrdbstn/school/safe-brax/safe-brax/brax/training/agents/ppo_lagrange_v2/train.py", line 744, in train
  File "/home/mrdbstn/school/safe-brax/safe-brax/brax/training/agents/ppo_lagrange_v2/train.py", line 611, in training_epoch_with_timing
  File "/home/mrdbstn/school/safe-brax/safe

0,1
env_steps,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇████
episode/cost,█████████████████████████▁▁▁▁▁▁▁▁▁▁▁▁
episode/ctrl_cost,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████
episode/dist_reward,█████████████████████████▁▁▁▁▁▁▁▁▁▁▁▁
episode/distance_to_goal,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████
episode/goal_reward,█████████████████████████▁▁▁▁▁▁▁▁▁▁▁▁
episode/goals_reached_count,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████
episode/last_dist_goal,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████
episode/length,█████████████████████████▁▁▁▁▁▁▁▁▁▁▁▁
episode/orientation_reward,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
env_steps,2498560.0
episode/cost,207.26
episode/ctrl_cost,0.56136
episode/dist_reward,-1.60944
episode/distance_to_goal,1843.4392
episode/goal_reward,0.0
episode/goals_reached_count,29.97
episode/last_dist_goal,1842.90271
episode/length,999.0
episode/orientation_reward,0.0


[34m[1mwandb[0m: Agent Starting Run: j3i5vnif with config:
[34m[1mwandb[0m: 	alg: ppo_lagrange
[34m[1mwandb[0m: 	batch_size: 512
[34m[1mwandb[0m: 	clipping_epsilon: 0.27612835142968584
[34m[1mwandb[0m: 	deterministic_eval: False
[34m[1mwandb[0m: 	entropy_cost: 0.00019619198590600128
[34m[1mwandb[0m: 	env: point_resetting_goal_random_hazard_lidar_sensor_obs
[34m[1mwandb[0m: 	env_kwargs: {'config_overrides': {'hazard_size': 0.7}}
[34m[1mwandb[0m: 	episode_length: 1000
[34m[1mwandb[0m: 	gae_lambda: 0.9371334851850563
[34m[1mwandb[0m: 	initial_lambda_lagr: 1
[34m[1mwandb[0m: 	lagrangian_coef_rate: 0.010989660630117222
[34m[1mwandb[0m: 	learning_rate: 6.650888753766045e-05
[34m[1mwandb[0m: 	normalize_observations: True
[34m[1mwandb[0m: 	num_envs: 1024
[34m[1mwandb[0m: 	num_eval_envs: 128
[34m[1mwandb[0m: 	num_evals: 3
[34m[1mwandb[0m: 	num_minibatches: 64
[34m[1mwandb[0m: 	num_timesteps: 30000000
[34m[1mwandb[0m: 	num_updates_per_b



Body 0 name: world, mocapid: [-1]
Body 1 name: agent, mocapid: [-1]
Body 2 name: goal, mocapid: [0]
Goal body found with mocapid: [0]
Body 3 name: hazard1, mocapid: [1]
Hazard body found with mocapid: [1]
Body 4 name: hazard2, mocapid: [2]
Hazard body found with mocapid: [2]
Body 5 name: hazard3, mocapid: [3]
Hazard body found with mocapid: [3]
Model has 7 sensors. Searching for required sensors...
  Found sensor: accelerometer, ID: 0, Address: 0, Dim: 3
  Found sensor: velocimeter, ID: 1, Address: 3, Dim: 3
  Found sensor: gyro, ID: 2, Address: 6, Dim: 3
  Found sensor: magnetometer, ID: 3, Address: 9, Dim: 3
Body 0 name: world, mocapid: [-1]
Body 1 name: agent, mocapid: [-1]
Body 2 name: goal, mocapid: [0]
Goal body found with mocapid: [0]
Body 3 name: hazard1, mocapid: [1]
Hazard body found with mocapid: [1]
Body 4 name: hazard2, mocapid: [2]
Hazard body found with mocapid: [2]
Body 5 name: hazard3, mocapid: [3]
Hazard body found with mocapid: [3]
Model has 7 sensors. Searching for 

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


Reset method - Goal Position: Traced<ShapedArray(float32[3])>with<BatchTrace> with
  val = Traced<ShapedArray(float32[128,3])>with<DynamicJaxprTrace>
  batch_dim = 0


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.
