In [1]:
# Make gymnasium environment
import gymnasium as gym
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium import utils
from gymnasium.spaces import Box

import math
from scipy.spatial.transform import Rotation as R
import numpy as np
import os
from typing import Optional
import imageio

# Viewer camera settings handed to MujocoEnv via ``default_camera_config``.
DEFAULT_CAMERA_CONFIG = dict(distance=6.0)

In [2]:
# Directory containing the Unitree Go1 MJCF assets (must hold ``scene_playground.xml``).
# Set the GO1_MENAGERIE_PATH environment variable to point elsewhere; the default
# is the repository-relative location this notebook has always used.
menagerie_path = os.environ.get('GO1_MENAGERIE_PATH', '../../assets/Unitree_GO1')

class Go1JoystickGymEnv(MujocoEnv, utils.EzPickle):
    """Gymnasium (CPU MuJoCo) version of the Go1 joystick locomotion task.

    Intended to mirror the observation layout of mujoco_playground's
    ``Go1JoystickFlatTerrain`` so a policy trained there can be rolled out here.
    Observations are a dict:

    * ``"state"`` (48,): linear velocity in the IMU frame (3), gyro (3),
      gravity direction in the IMU frame (3), joint angles minus the default
      pose (12), joint velocities (12), last action (12), command (3);
      clipped to [-100, 100].
    * ``"privileged_state"`` (123,): ``state`` followed by extra signals;
      quantities this port does not compute (accelerometer, angular velocity,
      contacts, feet data, perturbation force, step count) are zero-filled.

    Actions (12,) are offsets from the default joint pose, scaled by
    ``action_scale`` and clipped to the joint limits before being used as the
    actuator controls. Rewards are not implemented (always 0.0).
    """

    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "depth_array",
        ],
        "render_fps": 50,
    }

    def __init__(
        self,
        obs_noise: float = 0.05,
        action_scale: float = 0.3,
        frame_skip=1,
        **kwargs,
    ):
        """Load ``scene_playground.xml`` from ``menagerie_path``.

        Args:
            obs_noise: Stored for reference only; observations are noise-free.
            action_scale: Multiplier applied to actions before adding the default pose.
            frame_skip: Number of physics steps per ``step`` call.
            **kwargs: Forwarded to ``MujocoEnv`` (e.g. ``render_mode``).
        """
        self._xml_file = os.path.join(menagerie_path, 'scene_playground.xml')
        # EzPickle must record the *constructor* arguments: on unpickle/deepcopy
        # it calls ``type(self)(*args, **kwargs)``. Passing the XML path here
        # would bind it to ``obs_noise``.
        utils.EzPickle.__init__(self, obs_noise, action_scale, frame_skip, **kwargs)

        # Sizes of the two entries returned by ``_get_obs``.
        state_size = 48
        privileged_size = 123
        observation_space = gym.spaces.Dict({
            "state": Box(low=-100.0, high=100.0, shape=(state_size,), dtype=np.float64),
            # Not clipped apart from its leading ``state`` slice.
            "privileged_state": Box(
                low=-np.inf, high=np.inf, shape=(privileged_size,), dtype=np.float64
            ),
        })

        self._obs_noise = obs_noise  # currently not applied in _get_obs
        self._action_scale = action_scale

        self._xml_dt = 0.02  # timestep from the XML file
        MujocoEnv.__init__(
            self,
            self._xml_file,
            frame_skip,
            observation_space=observation_space,
            default_camera_config=DEFAULT_CAMERA_CONFIG,
            **kwargs,
        )

        # integrator 0 is MuJoCo's Euler integrator.
        assert self.model.opt.integrator == 0, 'Use Euler integration for scene_mjx_gym.xml.'
        # assert self.frame_skip == 1, 'Use frame_skip=1 for Euler integration.'

        # Initial qpos: base position (3), base quaternion w,x,y,z (4), then 12 joint angles.
        self._init_q = np.array([0, 0, 0.27, 1, 0, 0, 0, 0.0, 0.9, -1.8, 0.0, 0.9, -1.8, 0.0, 0.9, -1.8, 0.0, 0.9, -1.8])
        # self._init_q = np.array([0, 0, 0.27, 1, 0, 0, 0, 0.1, 0.9, -1.8, -0.1, 0.9, -1.8, 0.1, 0.9, -1.8, -0.1, 0.9, -1.8])
        # Joint pose that a zero action maps to (same angles as the tail of _init_q).
        self._default_pose = np.array([0.0, 0.9, -1.8] * 4)
        # self._default_pose = np.array([0.1, 0.9, -1.8, -0.1, 0.9, -1.8, 0.1, 0.9, -1.8, -0.1, 0.9, -1.8])
        # Per-joint bounds: motor targets are clipped to them and leaving them terminates.
        self.lowers = np.array([-0.7, -1.0, -2.2] * 4)
        self.uppers = np.array([0.52, 2.1, -0.4] * 4)

    def sample_command(
        self,
        command: Optional[np.ndarray] = None,
        randomize: bool = False,
    ) -> np.ndarray:
        """Return the joystick command ``[lin_vel_x, lin_vel_y, ang_vel_yaw]``.

        Args:
            command: Explicit command. When given, it is returned as a float array.
            randomize: Only used when ``command`` is None. If True, each component
                is drawn uniformly (from ``self.np_random``) within
                x in [-0.6, 1.5] m/s, y in [-0.8, 0.8] m/s and yaw in
                [-0.7, 0.7] rad/s. If False (default), the fixed forward command
                ``[1.0, 0.0, 0.0]`` is returned, which is what ``reset`` uses.

        Returns:
            A (3,) float array.
        """
        if command is not None:
            return np.asarray(command, dtype=np.float64)
        if not randomize:
            return np.array([1.0, 0.0, 0.0])
        low = np.array([-0.6, -0.8, -0.7])   # min x [m/s], y [m/s], yaw [rad/s]
        high = np.array([1.5, 0.8, 0.7])     # max x [m/s], y [m/s], yaw [rad/s]
        return self.np_random.uniform(low=low, high=high)

    def _initial_info(self) -> dict:
        """Build a fresh per-episode bookkeeping dict (new arrays on every call)."""
        return {
            'last_act': np.zeros(12),
            'last_vel': np.zeros(12),
            'command': self.sample_command(),
            'kick': np.array([0.0, 0.0]),
            'step': 0,
        }

    def reset(self, seed=None, options=None):
        """Reset to the standing pose ``_init_q`` with zero velocity.

        Args:
            seed: If given, re-seeds ``self.np_random``.
            options: Ignored.

        Returns:
            ``(obs, info)``. ``info`` is the same dict object as ``self.info``,
            so callers may edit e.g. ``info['command']`` in place.
        """
        # Base gym.Env.reset only handles seeding; MujocoEnv.reset is bypassed
        # on purpose because this class does not implement reset_model().
        gym.Env.reset(self, seed=seed)

        self.data.qacc_warmstart[:] = 0.0
        self.set_state(self._init_q, np.zeros(self.model.nv))

        self.info = self._initial_info()
        self.obs = self._get_obs(self.info)
        return self.obs, self.info

    def step(self, action):
        """Apply one action and advance the simulation by ``frame_skip`` steps.

        Args:
            action: (12,) offsets from the default pose; any array-like
                (NumPy, JAX, list) is accepted.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``reward`` is always
            0.0 and ``truncated`` always False. ``terminated`` is True when any
            joint angle leaves [lowers, uppers] or the base height
            ``qpos[2]`` drops below 0.18.
        """
        # Private float copy: lets JAX outputs in, and later mutation of the
        # caller's buffer cannot alter info['last_act'].
        action = np.array(action, dtype=np.float64)

        # Physics step
        motor_targets = np.clip(
            self._default_pose + action * self._action_scale, self.lowers, self.uppers
        )
        # self.data.qacc_warmstart[:] = 0.0
        self.do_simulation(motor_targets, self.frame_skip)

        # Observation uses the *previous* action in info['last_act'], which is
        # updated only afterwards.
        obs = self._get_obs(self.info)
        self.obs = obs
        joint_angles = self.data.qpos[7:]
        joint_vel = self.data.qvel[6:]

        # Terminate if joint limits are exceeded or the robot is falling.
        terminated = bool(
            np.any(joint_angles < self.lowers)
            or np.any(joint_angles > self.uppers)
            or self.data.qpos[2] < 0.18
        )

        # State management
        self.info['last_act'] = action
        self.info['last_vel'] = joint_vel.copy()
        self.info['step'] += 1

        reward = 0.0

        if self.render_mode == "human":
            self.render()

        return obs, reward, terminated, False, self.info

    def _get_obs(self, state_info):
        """Build the ``{"state": (48,), "privileged_state": (123,)}`` observation.

        Args:
            state_info: Dict providing ``'last_act'`` (12,) and ``'command'`` (3,).
        """
        # Assumes site 0 is the IMU site. TODO: confirm against the XML
        # (e.g. ``self.model.site('imu').id``).
        imu_site_id = 0
        imu_rot = self.data.site_xmat[imu_site_id].reshape(3, 3)  # site -> world

        # A free joint's qvel[:3] is expressed in the world frame. Rotate it into
        # the IMU frame so it matches the body-frame gyro/gravity terms and a
        # velocimeter like the commented-out 'local_linvel' sensor (presumably
        # what the trained policy expects). The omega x r lever-arm term
        # between body origin and IMU site is ignored.
        # lin_vel = self.data.sensor('local_linvel').data
        lin_vel = imu_rot.T @ self.data.qvel[:3]
        gyro = self.data.sensor('gyro').data
        # gyro = self.data.sensor('base_gyro').data
        gravity = imu_rot.T @ np.array([0.0, 0.0, -1.0])  # world "down" in the IMU frame
        joint_angles = self.data.qpos[7:] - self._default_pose
        joint_vel = self.data.qvel[6:]
        last_act = state_info['last_act']
        command = state_info['command']

        obs = np.concatenate([
            lin_vel,
            gyro,
            gravity,
            joint_angles,
            joint_vel,
            last_act,
            command,
        ])

        # Clip only; no observation noise is added.
        obs = np.clip(obs, -100.0, 100.0)

        privileged_obs = np.concatenate([
            obs,
            gyro,
            np.zeros(3),  # accelerometer
            gravity,
            lin_vel,
            np.zeros(3),  # angular velocity
            joint_angles,
            joint_vel,
            self.data.actuator_force,
            np.zeros(4),  # last contact
            np.zeros(12),  # feet velocity
            np.zeros(4),  # feet air time
            np.zeros(3),  # torso xfrc_applied
            np.zeros(1),  # step count
        ])

        return {"state": obs, "privileged_state": privileged_obs}

    def reset_to(self, qpos, qvel):
        """Reset the episode, then place the simulator at ``(qpos, qvel)``.

        Returns:
            The observation at the new state (only ``obs``, not ``(obs, info)``).
        """
        self.reset()  # fresh info dict and cleared warm-start
        self.set_state(qpos, qvel)
        self.obs = self._get_obs(self.info)
        return self.obs

    def get_full_state(self):
        """Return a new array ``[qpos, qvel]`` holding the full simulator state."""
        return np.concatenate([self.data.qpos, self.data.qvel])
    
# Expose the env to ``gym.make('joystick_go1', ...)`` (re-running this cell
# re-registers the same id).
gym.register(id='joystick_go1', entry_point=Go1JoystickGymEnv)

Check the environment by applying zero actions (the motor targets stay at the default standing pose)

In [None]:
# Roll out zero actions: the motor targets equal the default pose, so the robot
# should simply stand. The full [qpos, qvel] state is recorded before each step.
env = gym.make('joystick_go1', render_mode='human')
env = env.unwrapped
obs, _ = env.reset()

# Save full-state and action history
obs_hist = []
act_hist = []

try:
    for t in range(300):
        act = np.zeros(12)
        obs_hist.append(env.get_full_state())
        act_hist.append(act)
        obs, _, terminated, truncated, info = env.step(act)

        # To save frames, create the env with render_mode='rgb_array' instead of
        # 'human' (env.render() returns None in human mode), make sure the
        # figure/ directory exists, then uncomment:
        # image = env.render()
        # imageio.imwrite(f"figure/image_{t}.png", image)

        if terminated or truncated:
            break
finally:
    # Close the viewer even if the rollout raises or is interrupted.
    env.close()

Load the trained JAX (Brax PPO) policy

In [11]:
# @title Import MuJoCo, MJX, and Brax
import functools
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from brax.io import model
import jax
import numpy as np
import matplotlib.pyplot as plt

from mujoco_playground import wrapper
from mujoco_playground import registry

In [7]:
env_name = 'Go1JoystickFlatTerrain'
env_cfg = registry.get_default_config(env_name)
# PD gain overrides. These are presumably chosen to match the actuator gains in
# scene_playground.xml used by the gym env above; confirm that they agree.
env_cfg.Kd = 0.7
env_cfg.Kp = 40

from mujoco_playground.config import locomotion_params
ppo_params = locomotion_params.brax_ppo_config(env_name)
ppo_training_params = dict(ppo_params)
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in ppo_params:
    # ppo.train expects a factory callable, not the config sub-dict: drop the
    # sub-dict from the training kwargs and bind its entries into the factory.
    del ppo_training_params["network_factory"]
    network_factory = functools.partial(
        ppo_networks.make_ppo_networks,
        **ppo_params.network_factory,
    )

In [8]:
# ppo.train with num_timesteps=0 is used here only to obtain make_inference_fn;
# the policy weights are loaded from disk below.
train_fn = functools.partial(
    ppo.train,
    **ppo_training_params,
    network_factory=network_factory,
)
# Optional checkpoint to restore from (only used if the line below is uncommented).
restore_checkpoint_path = os.path.expanduser(
    "~/mujoco_playground/checkpoints/Go1JoystickFlatTerrain/final/200540160"
)
make_inference_fn, _untrained_params, metrics = train_fn(
    environment=registry.load(env_name, config=env_cfg),
    eval_env=registry.load(env_name, config=env_cfg),
    wrap_env_fn=wrapper.wrap_for_brax_training,
    # restore_checkpoint_path=restore_checkpoint_path,  # restore from the checkpoint!
    num_timesteps=0,
)

# If you have problems with checkpoint loading, load the params directly
# (presumably written earlier with brax.io.model.save_params).
model_path = './logs/go1_policy'
params = model.load_params(model_path)

jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))

In [9]:
# Roll out the trained policy: zero command (stand) for the first
# COMMAND_SWITCH_STEP steps, then walk forward at 1 m/s.
N_STEPS = 1000
COMMAND_SWITCH_STEP = 100
STAND_COMMAND = np.array([0.0, 0.0, 0.0])
WALK_COMMAND = np.array([1.0, 0.0, 0.0])

env = gym.make('joystick_go1', render_mode='human')
env = env.unwrapped
env.reset()
env.info["command"] = STAND_COMMAND.copy()
# reset() built its observation with the default command; rebuild it so the
# very first policy input already carries the command set above.
obs = env._get_obs(env.info)

# Full [qpos, qvel] state before each step, and the action applied at it
obs_hist = []
act_hist = []

rng = jax.random.PRNGKey(seed=1)

try:
    for t in range(N_STEPS):
        if t == COMMAND_SWITCH_STEP:
            env.info["command"] = WALK_COMMAND.copy()
            obs = env._get_obs(env.info)  # make the new command visible immediately
        act_rng, rng = jax.random.split(rng)
        act, _ = jit_inference_fn(obs, act_rng)
        act = np.asarray(act)  # JAX -> NumPy for storage and the CPU simulator
        obs_hist.append(env.get_full_state())
        act_hist.append(act)
        obs, _, terminated, truncated, info = env.step(act)

        if terminated or truncated:
            break
finally:
    env.close()

# Save the history. The .npy holds a pickled dict: load it with allow_pickle=True.
states = np.array(obs_hist)
actions = np.array(act_hist)
trajectory = {
    'obs': states,
    'act': actions,
}

np.save('go1_trajectory.npy', trajectory)
np.savetxt('go1_state_trajectory.csv', states, delimiter=',')
np.savetxt('go1_action_trajectory.csv', actions, delimiter=',')

In [14]:
## Load the trajectory and replay its actions open-loop from the recorded initial state
env = gym.make('joystick_go1', render_mode='human')
env = env.unwrapped

# The file is written by this notebook (np.save of a dict), so unpickling is acceptable here.
trajectory = np.load('go1_trajectory.npy', allow_pickle=True).item()
obs_hist = trajectory['obs']
act_hist = trajectory['act']

# Each recorded row is the full state [qpos, qvel]; split it by the model's nq.
nq = env.model.nq
obs_init = obs_hist[0]
qpos_init = obs_init[:nq]
qvel_init = obs_init[nq:]

try:
    obs = env.reset_to(qpos_init, qvel_init)
    # Iterate over what was actually recorded: the rollout may have ended early.
    for act in act_hist:
        obs, _, terminated, truncated, info = env.step(act)

        if terminated or truncated:
            break
finally:
    env.close()