In [None]:
# Setup (Colab): install MuJoCo, MJX, Brax and the video tooling used below.
# NOTE(review): versions are unpinned, so a re-run may resolve different
# mujoco / mujoco_mjx / brax releases than the ones this notebook was written
# against; pin them (e.g. `%pip install mujoco==X.Y.Z`) for reproducibility.
!pip install mujoco
!pip install mujoco_mjx
!pip install brax
# Install ffmpeg only if it is not already on PATH (mediapy's video functions rely on it).
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy

In [None]:
# Imports and global JAX configuration for the notebook.
import os
# Limit the fraction of GPU memory JAX/XLA preallocates; this must be set
# before JAX initialises its GPU backend, hence before the jax imports below.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8" # 0.9 causes too much lag.
from datetime import datetime
import functools

# Math
import jax.numpy as jp
import numpy as np
import jax
from jax import config # Analytical gradients work much better with double precision.
# NOTE(review): jax_debug_nans is a debugging aid that checks JAX computations
# for NaNs at a performance cost. JAX/Brax are not used by the simulation cells
# below (they use mujoco + numpy only); confirm these global settings are wanted.
config.update("jax_debug_nans", True)
config.update("jax_enable_x64", True)
config.update('jax_default_matmul_precision', 'high')
from brax import math

# Sim
# NOTE(review): mujoco selects its OpenGL backend from MUJOCO_GL when it is
# first imported, but MUJOCO_GL=egl is only set in the next cell; consider
# setting it before this import so EGL rendering is actually used.
import mujoco
import mujoco.mjx as mjx

# Brax
from brax import envs
from brax.base import Motion, Transform
from brax.io import mjcf
from brax.envs.base import PipelineEnv, State
# NOTE(review): `_reformat_contact` is a private Brax helper and may change or
# disappear between Brax releases.
from brax.mjx.pipeline import _reformat_contact
from brax.training.acme import running_statistics
# NOTE(review): the name `model` is rebound to a mujoco.MjModel in the
# simulation cell, which shadows this brax.io module from then on.
from brax.io import model

# Algorithms
from brax.training.agents.apg import train as apg
from brax.training.agents.apg import networks as apg_networks
from brax.training.agents.ppo import train as ppo

# Supporting
from etils import epath
import mediapy as media
import matplotlib.pyplot as plt
from ml_collections import config_dict
from typing import Any, Dict

In [None]:
#@title Check if MuJoCo installation was successful

# Set up GPU rendering.
from google.colab import files
import distutils.util
import os
import subprocess

# Fail early with a clear message when no NVIDIA GPU is reachable. A missing
# nvidia-smi binary (no driver at all) raises FileNotFoundError, which is
# treated the same as a non-zero exit code.
try:
  gpu_unavailable = subprocess.run('nvidia-smi').returncode
except FileNotFoundError:
  gpu_unavailable = 1
if gpu_unavailable:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU).
# Plain-Python equivalent of the `%env MUJOCO_GL=egl` magic.
# NOTE: MuJoCo reads MUJOCO_GL when it is first imported. The imports cell above
# already imports mujoco, so on a fresh kernel this may only take effect if it
# is moved before that import (or after a kernel restart).
print('Setting environment variable to use GPU rendering:')
os.environ['MUJOCO_GL'] = 'egl'

# Check if installation was successful.
try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  # Raise the explanatory error and chain the underlying failure as its cause.
  raise RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".') from e

print('Installation successful.')

# Other imports and helper functions
import time
import itertools
import numpy as np

# Graphics and plotting (mediapy is installed in the first cell).
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

from IPython.display import clear_output
clear_output()

In [None]:
import mujoco
import mediapy as media
import numpy as np

# ==========================================
# BIO-INSPIRED ANT NAVIGATION SIMULATION
# ==========================================

# --- RANDOMIZATION LOGIC ---
# Random (x, y) placement for the package and destination.
def get_random_pos(exclude_points, min_dist=2.0, low=-3.5, high=3.5,
                   max_tries=10_000, rng=None):
    """Sample a point in the square [low, high]^2 far enough from excluded points.

    The floor plane is declared with half-size 10 (it spans 20 x 20 m) and has
    no walls. The default +/-3.5 m range presumably keeps objects inside the
    overhead camera's view (camera at z=9 with MuJoCo's default fovy); TODO
    confirm the intended bound.

    Args:
        exclude_points: iterable of (x, y) points the sample must avoid.
        min_dist: minimum Euclidean distance from every excluded point.
        low, high: bounds of the square each coordinate is drawn from.
        max_tries: rejection-sampling budget. Without it, an unsatisfiable
            constraint would loop forever.
        rng: optional ``np.random.Generator``. Defaults to the global
            ``np.random`` state, so ``np.random.seed`` still controls the draws.

    Returns:
        np.ndarray of shape (2,).

    Raises:
        RuntimeError: if no valid point is found within ``max_tries`` draws.
    """
    sampler = np.random if rng is None else rng
    points = [np.asarray(p, dtype=float) for p in exclude_points]
    for _ in range(max_tries):
        pos = sampler.uniform(low, high, 2)
        if all(np.linalg.norm(pos - p) >= min_dist for p in points):
            return pos
    raise RuntimeError(
        f"No position in [{low}, {high}]^2 at least {min_dist} from "
        f"{len(points)} excluded point(s) after {max_tries} tries"
    )

# 0. Optional seed for a reproducible layout (None = new random layout each run).
SEED = None
if SEED is not None:
    np.random.seed(SEED)

# 1. Start is at 0,0
start_pos = np.array([0.0, 0.0])

# Centre of the obstacle column; keep in sync with the "col" geom in warehouse_xml.
# Excluding it stops the package or rack from spawning inside the column.
obstacle_pos = np.array([1.5, 1.5])

# 2. Randomize Package Position
pkg_pos = get_random_pos([start_pos, obstacle_pos], min_dist=2.5)
pkg_x, pkg_y = pkg_pos

# 3. Randomize Destination Position
dest_pos = get_random_pos([start_pos, pkg_pos, obstacle_pos], min_dist=3.0)
dest_x, dest_y = dest_pos

print(f"Mission Setup -> Package: ({pkg_x:.2f}, {pkg_y:.2f}) | Destination: ({dest_x:.2f}, {dest_y:.2f})")

# NOTE: XML Reordered to place Robot first (indices 0-6) so it is the controllable agent.
# The scene is an f-string: pkg_x/pkg_y and dest_x/dest_y come from the randomized layout.
warehouse_xml = f"""
<mujoco>
  <option timestep="0.005"/>
  <visual>
    <headlight active="1"/>
  </visual>

  <worldbody>
    <!-- FLOOR (size gives half-extents, so the plane spans 20 x 20 m) -->
    <geom name="floor" type="plane" size="10 10 0.1" rgba="0.9 0.9 0.9 1"/>

    <!-- LIGHTS -->
    <light pos="0 0 10" castshadow="false"/>

    <!-- OVERHEAD CAMERA -->
    <camera name="overhead" pos="0 0 9" mode="fixed" euler="0 0 0"/>

    <!-- 1. ROBOT (Black) - First body so it gets qpos[0:7] and qvel[0:6] -->
    <body name="robot" pos="0 0 0.2" euler="0 0 0">
      <freejoint/>
      <geom name="chassis" type="box" size="0.3 0.2 0.1" rgba="0.1 0.1 0.1 1" mass="2.0"/>
      <geom name="head" type="box" size="0.1 0.1 0.1" pos="0.3 0 0" rgba="0.8 0 0 1"/> <!-- Red head for visibility -->
      <!-- EYE: MuJoCo cameras look along their local -Z axis, so with no
           orientation this camera would look straight down. xyaxes sets the
           image right to body -Y and image up to body +Z, so it looks along the
           robot's +X (forward). It sits just in front of the head (which spans
           x 0.2 to 0.4) so the red head never enters the view. -->
      <camera name="eye" pos="0.45 0 0.1" xyaxes="0 -1 0 0 0 1" fovy="80" mode="fixed"/>
    </body>

    <!-- 2. CHARGING DOCK (Green - Home) -->
    <site name="dock_marker" type="cylinder" size="0.5 0.02" pos="0 0 0.01" rgba="0 1 0 0.5"/>

    <!-- 3. PACKAGE (Yellow Box) - Random Position; z equals the half-height so it rests on the floor -->
    <body name="package" pos="{pkg_x} {pkg_y} 0.3">
      <freejoint/>
      <geom name="pkg_geom" type="box" size="0.3 0.3 0.3" rgba="1 0.8 0 1" mass="0.1"/>
    </body>

    <!-- 4. DESTINATION (Red Rack) - Random Position -->
    <body name="rack" pos="{dest_x} {dest_y} 0.5">
        <geom name="shelf" type="box" size="0.5 0.5 0.5" rgba="0.8 0 0 1"/>
    </body>

    <!-- Obstacle -->
    <geom name="col" type="cylinder" size="0.2 1" pos="1.5 1.5 0" rgba="0.5 0.5 0.5 1"/>

  </worldbody>
</mujoco>
"""

# ==========================================
# PATH INTEGRATION
# ==========================================
class PathIntegrator:
    """Ant-style path integrator that maintains a home vector.

    The home vector is the running sum of the displacements passed to
    ``update``, i.e. the current position relative to the home (start)
    position. Its negation points back home.
    """

    def __init__(self, start_pos=(0.0, 0.0)):
        """Create an integrator whose home is ``start_pos``.

        Args:
            start_pos: (x, y) home position. Defaults to the origin, which was
                previously hard-coded. With a hard-coded origin, a robot
                starting elsewhere would record its whole start offset as a
                spurious first displacement.
        """
        self.home_vector = np.zeros(2)
        self.last_pos = np.array(start_pos, dtype=float)

    def update(self, current_pos):
        """Add the displacement since the previous position to the home vector.

        Args:
            current_pos: (x, y) position. It is copied, so passing a view into
                live simulator state is safe.
        """
        current_pos = np.asarray(current_pos, dtype=float)
        self.home_vector += current_pos - self.last_pos
        self.last_pos = current_pos.copy()

    def get_home_direction(self, min_distance=0.1):
        """Return the world-frame heading (radians) pointing back home.

        Returns None when the home vector is shorter than ``min_distance``,
        because the direction is too ill-defined to steer by.
        """
        if self.get_home_distance() < min_distance:
            return None
        return np.arctan2(-self.home_vector[1], -self.home_vector[0])

    def get_home_distance(self):
        """Return the Euclidean length of the home vector."""
        return np.linalg.norm(self.home_vector)

# ==========================================
# VISUAL PROCESSING
# ==========================================
def get_visual_info(model, data, renderer, target_color):
    """Render the robot's "eye" camera and locate pixels of a named colour.

    Args:
        model: not used by this function; kept for call-site compatibility.
        data: simulation state handed to ``renderer.update_scene``.
        renderer: object providing ``update_scene(data, camera=...)`` and
            ``render()`` (a ``mujoco.Renderer`` in this notebook).
        target_color: 'yellow', 'red' or 'green'. Any other value yields no
            detection.

    Returns:
        (error, confidence, pixels):
            error: horizontal offset of the matching pixels' mean column from
                the image centre, divided by half the image width. It is
                positive when the target is left of centre. None when the
                colour is unknown or fewer than 10 pixels match.
            confidence: matched-pixel count / 1000, capped at 1.0. It is 0
                when error is None.
            pixels: the rendered frame.
    """
    renderer.update_scene(data, camera="eye")
    frame = renderer.render()
    red, green, blue = (frame[..., channel] for channel in range(3))

    # Simple per-channel thresholds on 0-255 values, one per target colour.
    color_tests = {
        'yellow': lambda: (red > 150) & (green > 150) & (blue < 100),
        'red': lambda: (red > 150) & (green < 100) & (blue < 100),
        'green': lambda: (red < 100) & (green > 150) & (blue < 100),
    }
    build_mask = color_tests.get(target_color)
    if build_mask is None:
        return None, 0, frame

    _, cols = np.nonzero(build_mask())
    hit_count = cols.size
    if hit_count < 10:  # too few pixels to trust as a detection
        return None, 0, frame

    half_width = frame.shape[1] / 2
    error = (half_width - cols.mean()) / half_width
    confidence = min(hit_count / 1000.0, 1.0)
    return error, confidence, frame

# ==========================================
# SETUP
# ==========================================
model = mujoco.MjModel.from_xml_string(warehouse_xml)
data = mujoco.MjData(model)
# A fresh MjData only holds qpos0. Derived quantities such as body and camera
# poses are filled in by mj_forward. Without this call, the initial "eye"
# render below and the step-0 render/decision in the main loop would use
# uninitialised (zeroed) kinematics.
mujoco.mj_forward(model, data)
renderer = mujoco.Renderer(model, height=240, width=320)

path_integrator = PathIntegrator()

# Mission state: FIND_PACKAGE -> GO_TO_DESTINATION -> RETURN_HOME -> COMPLETE.
frames = []
mission_phase = "FIND_PACKAGE"
carrying_package = False
package_picked = False

# UPDATE TARGETS TO MATCH RANDOM POSITIONS
PACKAGE_POS = np.array([pkg_x, pkg_y])
DESTINATION_POS = np.array([dest_x, dest_y])
HOME_POS = np.array([0.0, 0.0])

# Render the robot view once so `robot_view` exists before the first loop iteration.
_, _, robot_view = get_visual_info(model, data, renderer, 'yellow')

print("=" * 60)
print("BIO-INSPIRED ANT NAVIGATION - WAREHOUSE MISSION")
print("=" * 60)
print("Starting at charging dock (Green). Searching for package (Yellow)...")
print()

# ==========================================
# MAIN LOOP
# ==========================================
# Loop parameters (previously inline magic numbers).
MAX_STEPS = 4500              # hard cap on physics steps
RENDER_EVERY = 3              # append one video frame every N physics steps
COMPLETION_HOLD_FRAMES = 90   # repeat the final frame this many times
SIM_DT = model.opt.timestep   # seconds per mj_step, read from the model
PACKAGE_OFFSET = np.array([0.0, 0.0, 0.35])  # carried package sits above the chassis

# Look up the package's free-joint addresses by name instead of hard-coding
# qpos[7:14] / qvel[6:12], so reordering bodies in the XML cannot silently break it.
_pkg_joint = int(model.body("package").jntadr[0])
PKG_QPOS = int(model.jnt_qposadr[_pkg_joint])  # position (3) + quaternion (4)
PKG_QVEL = int(model.jnt_dofadr[_pkg_joint])   # linear (3) + angular (3) velocity
# The controller writes qpos[0:7] / qvel[0:6] directly, which is only valid
# while the robot's free joint comes first.
assert model.jnt_qposadr[model.body("robot").jntadr[0]] == 0, "robot must be the first free joint"

PHASE_COLORS = {
    "FIND_PACKAGE": (255, 255, 0),      # Yellow
    "GO_TO_DESTINATION": (0, 0, 255),   # Blue
    "RETURN_HOME": (0, 255, 0),         # Green
    "COMPLETE": (0, 255, 0),
}
WHITE = (255, 255, 255)

# cv2 is optional and only used for text overlays. Only a missing module is
# tolerated; errors raised while drawing are no longer silently swallowed.
try:
    import cv2
except ImportError:
    cv2 = None


def _wrap_angle(angle):
    """Wrap an angle in radians into [-pi, pi]."""
    return np.arctan2(np.sin(angle), np.cos(angle))


def _steer_toward(target_xy, current_xy, yaw, gain):
    """Return gain times the wrapped heading error from `yaw` toward `target_xy`."""
    delta = target_xy - current_xy
    return _wrap_angle(np.arctan2(delta[1], delta[0]) - yaw) * gain


def _annotate_map(map_view, phase, home_dist, sim_time, carrying, pos):
    """Return the overhead frame with mission telemetry drawn on a copy.

    The frame is returned unchanged when cv2 is unavailable.
    """
    if cv2 is None:
        return map_view
    annotated = map_view.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(annotated, f"Phase: {phase}", (10, 20), font, 0.35,
                PHASE_COLORS.get(phase, WHITE), 1)
    cv2.putText(annotated, f"Home Vec: {home_dist:.2f}m", (10, 40), font, 0.35, WHITE, 1)
    cv2.putText(annotated, f"Time: {sim_time:.1f}s", (10, 60), font, 0.35, WHITE, 1)
    if carrying:
        cv2.putText(annotated, "[Carrying Package]", (10, 80), font, 0.35, (255, 255, 0), 1)
    cv2.putText(annotated, f"Pos: ({pos[0]:.1f},{pos[1]:.1f})", (10, 100), font, 0.35, WHITE, 1)
    return annotated


for step in range(MAX_STEPS):

    # 1. Odometry. The robot's free joint stores (x, y, z) in qpos[0:3] and the
    #    quaternion (w, x, y, z) in qpos[3:7].
    current_pos = data.qpos[:2].copy()
    path_integrator.update(current_pos)

    w, qx, qy, qz = data.qpos[3:7]
    current_yaw = np.arctan2(2.0 * (w * qz + qx * qy), 1.0 - 2.0 * (qy**2 + qz**2))

    # ===========================================
    # PHASE 1: FIND PACKAGE
    # ===========================================
    if mission_phase == "FIND_PACKAGE":
        error, confidence, robot_view = get_visual_info(model, data, renderer, 'yellow')

        if error is not None and confidence > 0.05:
            # Visual homing: the package is visible.
            turn_speed = error * 7.0
            fwd_speed = 1.0 * (1.0 - abs(error) * 0.4)
        elif np.linalg.norm(PACKAGE_POS - current_pos) > 1.0:
            # Far from the package: steer toward its known position.
            turn_speed = _steer_toward(PACKAGE_POS, current_pos, current_yaw, 5.0)
            fwd_speed = 1.0
        else:
            # Close to the package area but not seeing it: spiral search.
            turn_speed = 2.5 + 0.5 * np.sin(step * 0.02)
            fwd_speed = 0.5

        if np.linalg.norm(current_pos - PACKAGE_POS) < 0.8 and not package_picked:
            print(f"[Step {step}] ✓ PACKAGE PICKED UP!")
            print(f"  Position: ({current_pos[0]:.2f}, {current_pos[1]:.2f})")
            print(f"  Home vector: {path_integrator.get_home_distance():.2f}m")
            print()
            mission_phase = "GO_TO_DESTINATION"
            carrying_package = True
            package_picked = True

    # ===========================================
    # PHASE 2: GO TO DESTINATION
    # ===========================================
    elif mission_phase == "GO_TO_DESTINATION":
        error, confidence, robot_view = get_visual_info(model, data, renderer, 'red')

        if error is not None and confidence > 0.05:
            # Visual homing: the destination is visible.
            turn_speed = error * 7.0
            fwd_speed = 1.0 * (1.0 - abs(error) * 0.4)
        elif np.linalg.norm(DESTINATION_POS - current_pos) > 1.0:
            # Far from the destination: steer toward its known position.
            turn_speed = _steer_toward(DESTINATION_POS, current_pos, current_yaw, 5.0)
            fwd_speed = 1.0
        else:
            # Close to the destination: circle to search.
            turn_speed = 2.5
            fwd_speed = 0.5

        if np.linalg.norm(current_pos - DESTINATION_POS) < 1.0:
            print(f"[Step {step}] ✓ PACKAGE DELIVERED!")
            print(f"  Position: ({current_pos[0]:.2f}, {current_pos[1]:.2f})")
            print(f"  Home vector: {path_integrator.get_home_distance():.2f}m")
            print()
            mission_phase = "RETURN_HOME"
            carrying_package = False

    # ===========================================
    # PHASE 3: RETURN HOME (PATH INTEGRATION)
    # ===========================================
    elif mission_phase == "RETURN_HOME":
        home_direction = path_integrator.get_home_direction()
        home_distance = path_integrator.get_home_distance()
        error, confidence, robot_view = get_visual_info(model, data, renderer, 'green')

        if home_distance > 1.2:
            if home_direction is not None:
                # Path integration is the primary cue, blended with vision when home is seen.
                turn_speed = _wrap_angle(home_direction - current_yaw) * 6.0
                fwd_speed = 1.0
                if error is not None and confidence > 0.1:
                    turn_speed = 0.6 * turn_speed + 0.4 * (error * 7.0)
            else:
                # Defensive fallback. It is unreachable while get_home_direction()
                # only returns None for home vectors shorter than 0.1 m.
                turn_speed = _steer_toward(HOME_POS, current_pos, current_yaw, 5.0)
                fwd_speed = 0.9
        elif error is not None and confidence > 0.05:
            # Visual homing: the dock is visible.
            turn_speed = error * 7.0
            fwd_speed = 0.8 * (1.0 - abs(error) * 0.4)
        elif np.linalg.norm(HOME_POS - current_pos) > 0.3:
            # Close to home but not visible: gentle approach.
            turn_speed = _steer_toward(HOME_POS, current_pos, current_yaw, 4.0)
            fwd_speed = 0.6
        else:
            turn_speed = 1.5
            fwd_speed = 0.3

        if np.linalg.norm(current_pos - HOME_POS) < 0.8:
            print(f"[Step {step}] ✓✓✓ MISSION COMPLETE! ✓✓✓")
            print(f"  Final position: ({current_pos[0]:.2f}, {current_pos[1]:.2f})")
            # The integrator is fed ground-truth positions, so its length is the
            # remaining distance from home, not an estimation error.
            print(f"  Distance from home (home vector): {path_integrator.get_home_distance():.2f}m")
            print(f"  Total time: {step * SIM_DT:.1f} seconds")
            print("=" * 60)
            mission_phase = "COMPLETE"
            fwd_speed = 0
            turn_speed = 0
            # The state no longer changes, so render the final view once and repeat it.
            renderer.update_scene(data, camera="overhead")
            final_view = renderer.render()
            if cv2 is not None:
                final_view = final_view.copy()
                cv2.putText(final_view, "MISSION COMPLETE!", (50, 120),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            final_frame = np.hstack((final_view, robot_view))
            frames.extend([final_frame] * COMPLETION_HOLD_FRAMES)
            break

    else:  # COMPLETE (not reached in practice: the loop breaks on completion)
        fwd_speed = 0
        turn_speed = 0

    # ===========================================
    # MOTOR CONTROL (2D planar motion only)
    # ===========================================
    # World-frame linear velocity along the current heading, with no vertical motion.
    data.qvel[0:3] = (fwd_speed * np.cos(current_yaw), fwd_speed * np.sin(current_yaw), 0)
    # MuJoCo expresses free-joint angular velocity in the body frame. The body
    # is kept level, so qvel[5] is the yaw rate.
    data.qvel[3:5] = 0
    data.qvel[5] = np.clip(turn_speed, -8.0, 8.0)

    # Pin height and remove pitch/roll: yaw-only quaternion (w, x, y, z).
    data.qpos[2] = 0.2
    data.qpos[3:7] = (np.cos(current_yaw / 2), 0.0, 0.0, np.sin(current_yaw / 2))

    # ===========================================
    # PACKAGE ATTACHMENT
    # ===========================================
    if carrying_package and package_picked:
        # Teleport the package on top of the robot and cancel its velocity.
        data.qpos[PKG_QPOS:PKG_QPOS + 3] = data.qpos[:3] + PACKAGE_OFFSET
        data.qpos[PKG_QPOS + 3:PKG_QPOS + 7] = data.qpos[3:7]
        data.qvel[PKG_QVEL:PKG_QVEL + 6] = 0

    mujoco.mj_step(model, data)

    # ===========================================
    # RENDERING
    # ===========================================
    if step % RENDER_EVERY == 0:
        renderer.update_scene(data, camera="overhead")
        map_view = _annotate_map(renderer.render(), mission_phase,
                                 path_integrator.get_home_distance(),
                                 step * SIM_DT, carrying_package, current_pos)
        frames.append(np.hstack((map_view, robot_view)))

# Play back the recorded frames. The loop captures one frame every 3 physics
# steps of 0.005 s, so 30 fps plays the run at roughly 0.45x real time.
VIDEO_FPS = 30
print("\nGenerating video...")
media.show_video(frames, fps=VIDEO_FPS)
print("✓ Simulation Complete!")

Mission Setup -> Package: (2.13, 3.06) | Destination: (-3.11, -2.61)
BIO-INSPIRED ANT NAVIGATION - WAREHOUSE MISSION
Starting at charging dock (Green). Searching for package (Yellow)...

[Step 972] ✓ PACKAGE PICKED UP!
  Position: (1.65, 2.43)
  Home vector: 2.93m

[Step 2271] ✓ PACKAGE DELIVERED!
  Position: (-2.57, -1.77)
  Home vector: 3.12m

[Step 2904] ✓✓✓ MISSION COMPLETE! ✓✓✓
  Final position: (-0.34, -0.72)
  Path integration error: 0.80m
  Total time: 14.5 seconds

Generating video...


0
This browser does not support the video tag.


✓ Simulation Complete!


Trail