In [1]:
import os
import torch
import pickle
import mujoco
import mujoco.viewer
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from agent import SACAgent
from utils import ReplayBuffer
from trainer import SACTrainer

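# This notebook trains a Soft Actor-Critic (SAC) agent on a custom MuJoCo
# reaching environment (the setup was originally written for the Pendulum environment).

# The SAC implementation uses several techniques to make training work better: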
# Twin Q-Networks
# Memory Replay (Experience Replay Buffer)
# Target Networks with Polyak Averaging
# Automatic Entropy Tuning
# Reparameterization Trick
# Tanh Action Squashing with Log-Probability Correction
# Gradient Clipping
# Random Action Initialization

In [2]:
class SequentialReachingEnv(gym.Env):
    """Custom 2-joint limb with 4 muscles, 12 sensors, and a target position.

    The agent drives the muscle actuators so that the ``hand`` geom reaches a
    sequence of target positions sampled from a table of reachable positions.

    Observation: the sensor readings followed by the current target position,
    in the same order as the bounds of ``observation_space``.

    Reward: the negative Euclidean distance between the hand and the current
    target, plus a bonus of 1 on the step where the target is reached.

    Episode end (Gymnasium convention):
        * ``terminated`` -- the last target of the sequence has been reached.
        * ``truncated``  -- the simulation time exceeded
          ``max_target_duration * max_num_targets``.

    Args:
        xml_file: Model file name, resolved relative to ``mj_dir``.
        max_num_targets: Number of targets sampled per episode.
        max_target_duration: Time budget per target; the episode budget is
            this value times ``max_num_targets`` (compared against
            ``data.time``).
        mj_dir: Directory holding the model XML and the ``*.pkl`` files.
        reach_threshold: Hand-to-target distance below which a target counts
            as reached.
    """

    def __init__(
        self,
        xml_file="your_model.xml",
        max_num_targets=10,
        max_target_duration=3,
        mj_dir="../mujoco",
        reach_threshold=0.05,
    ):
        super().__init__()

        xml_path = os.path.join(mj_dir, xml_file)
        self.model = mujoco.MjModel.from_xml_path(xml_path)
        self.data = mujoco.MjData(self.model)
        self.max_num_targets = max_num_targets
        self.max_target_duration = max_target_duration
        self.reach_threshold = reach_threshold
        self.viewer = None

        # The end effector is looked up as a *geom*, so its world position
        # must be read from ``data.geom_xpos`` (not ``data.site_xpos``).
        self.hand_id = self.model.geom("hand").id

        # Pre-computed statistics and the table of valid target positions.
        self.sensor_stats = self._load_pickle(mj_dir, "sensor_stats.pkl")
        self.target_stats = self._load_pickle(mj_dir, "target_stats.pkl")
        self.reachable_positions = self._load_pickle(
            mj_dir, "reachable_positions.pkl"
        )

        # Per-feature bounds: sensor readings first, then the target position.
        low_values = np.concatenate(
            [
                self.sensor_stats["min"].values,
                self.target_stats["min"].values,
            ]
        )
        high_values = np.concatenate(
            [
                self.sensor_stats["max"].values,
                self.target_stats["max"].values,
            ]
        )

        # Observation space: 12 sensor readings + 3D target position
        self.observation_space = spaces.Box(
            low=low_values, high=high_values, dtype=np.float64
        )

        # Action space: 4 muscle activations
        self.action_space = spaces.Box(low=0.0, high=1.0, shape=(4,), dtype=np.float64)

    @staticmethod
    def _load_pickle(directory, file_name):
        """Load one pickled object from ``directory``."""
        # SECURITY: pickle.load can execute arbitrary code. Only point this
        # environment at files you generated yourself.
        with open(os.path.join(directory, file_name), "rb") as f:
            return pickle.load(f)

    def _get_obs(self):
        """Build the observation in the order declared by ``observation_space``."""
        return np.concatenate(
            [self.data.sensordata.copy(), self.target_positions[self.target_idx]]
        )

    def sample_targets(self, num_samples=10):
        """Return ``num_samples`` rows drawn from the reachable-position table.

        Sampling is driven by the environment's own RNG so that
        ``reset(seed=...)`` makes the target sequence reproducible.
        """
        # assumes ``reachable_positions`` is a pandas object (it exposes
        # ``.sample`` and ``.values``) -- TODO confirm against the pickle.
        sample_seed = int(self.np_random.integers(0, 2**32))
        return self.reachable_positions.sample(
            num_samples, random_state=sample_seed
        ).values

    def update_target(self, position):
        """Write ``position`` into ``mocap_pos`` and re-run the forward pass."""
        # presumably the model defines a mocap body that marks the target --
        # verify against the XML.
        self.data.mocap_pos = position
        mujoco.mj_forward(self.model, self.data)

    def step(self, action):
        """Apply ``action`` for one simulation step.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` following the
            Gymnasium ``Env.step`` convention.
        """
        self.data.ctrl[:] = action
        mujoco.mj_step(self.model, self.data)

        hand_position = self.data.geom_xpos[self.hand_id]
        distance = np.linalg.norm(
            hand_position - self.target_positions[self.target_idx]
        )
        reward = -distance

        # Running out of the time budget is a truncation, not a terminal state.
        truncated = bool(
            self.data.time > self.max_target_duration * self.max_num_targets
        )
        terminated = False

        if distance < self.reach_threshold:
            reward += 1

            if self.target_idx < self.max_num_targets - 1:
                # More targets remain: advance and keep the episode running.
                self.target_idx += 1
                self.update_target(self.target_positions[self.target_idx])
            else:
                # The final target was reached: the task is complete.
                terminated = True

        return self._get_obs(), float(reward), terminated, truncated, {}

    def reset(self, seed=None, options=None):
        """Reset the simulation and sample a fresh sequence of targets.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset``.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(obs, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)
        mujoco.mj_resetData(self.model, self.data)

        self.target_positions = self.sample_targets(self.max_num_targets)
        self.target_idx = 0
        self.update_target(self.target_positions[self.target_idx])

        return self._get_obs(), {}

    def render(self):
        """Launch the passive viewer on first use, then sync it with the sim."""
        if self.viewer is None:
            self.viewer = mujoco.viewer.launch_passive(self.model, self.data)
            self.viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True
            self.viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_ACTUATOR] = True
            self.viewer.cam.lookat[:] = [0, -1.5, -0.5]
            self.viewer.cam.azimuth = 90
            self.viewer.cam.elevation = 0
        self.viewer.sync()

    def close(self):
        """Close the viewer window if one is open."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None

In [3]:
# Build the reaching environment used for training: a single target per episode.
env = SequentialReachingEnv(
    max_target_duration=3,
    max_num_targets=1,
    xml_file="arm_model.xml",
)

# Network sizing: input/output widths come from the environment's spaces.
hidden_layers = [256, 256]
action_dim = env.action_space.shape[0]
state_dim = env.observation_space.shape[0]

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

print("State dimension:", state_dim)
print("Action dimension:", action_dim)
print("Device:", device)

State dimension: 15
Action dimension: 4
Device: cuda


In [None]:
# Training hyper-parameters, forwarded verbatim to SACTrainer.
# NOTE(review): their exact meaning is defined in trainer.py -- confirm there.
trainer_settings = dict(
    batch_size=256,
    start_steps=1000,
    update_after=1000,
    update_every=50,
    max_episode_steps=200,
)

# Assemble the three training components: policy/critics, replay memory, loop driver.
agent = SACAgent(state_dim, action_dim, hidden_layers, device=device)
replay_buffer = ReplayBuffer(state_dim, action_dim)
trainer = SACTrainer(env, agent, replay_buffer, **trainer_settings)

# Train for a fixed number of episodes.
trainer.run(num_episodes=200)

Episode: 001 | Reward: -94.11
Episode: 002 | Reward: -139.34
Episode: 003 | Reward: -121.23
Episode: 004 | Reward: -61.14
Episode: 005 | Reward: -121.18
Episode: 006 | Reward: -127.99


In [None]:
# Visualize the trained policy on the environment it was trained on.
# The agent was built for this environment's state/action dimensions, so it
# cannot be rolled out on a different environment such as Pendulum-v1
# (whose observation and action sizes differ).
state, _ = env.reset()
done = False
try:
    while not done:
        env.render()  # Opens/syncs the viewer window (requires a display)
        action = agent.select_action(state, deterministic=False)
        next_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        state = next_state
finally:
    # Always release the viewer, even if the rollout raises.
    env.close()