<a href="https://colab.research.google.com/github/effypelayotran/C_Undergrad_Neural_Trainers/blob/main/with_hip_rotation_render_script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install MuJoCo, MJX, and Brax

In [None]:
!pip install mujoco
!pip install mujoco_mjx
!pip install brax


Collecting mujoco
  Downloading mujoco-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
Collecting glfw (from mujoco)
  Downloading glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux_2_28_x86_64.whl.metadata (5.4 kB)
Downloading mujoco-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.6/6.6 MB[0m [31m25.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux_2_28_x86_64.whl (243 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m243.5/243.5 kB[0m [31m13.9 MB/s[0

In [None]:
#@title Check if MuJoCo installation was successful

from google.colab import files

import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags


Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl
Checking that the installation succeeded:
Installation successful.


In [None]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
# Install ffmpeg through apt only when no ffmpeg executable is already on PATH.
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

Installing mediapy:
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m23.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict


import jax
from jax import numpy as jp # jp is jax numpy
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model


In [None]:
#@title Mount GDrive
from google.colab import drive
# Mount Google Drive at /content/drive; parse_bvh_file reads its mocap clips
# from /content/drive/MyDrive/make_him_walk/mocap by default.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!git clone https://github.com/effypelayotran/mujoco_resources.git

Cloning into 'mujoco_resources'...
remote: Enumerating objects: 83, done.[K
remote: Counting objects: 100% (83/83), done.[K
remote: Compressing objects: 100% (77/77), done.[K
remote: Total 83 (delta 33), reused 35 (delta 3), pack-reused 0 (from 0)[K
Receiving objects: 100% (83/83), 3.78 MiB | 30.28 MiB/s, done.
Resolving deltas: 100% (33/33), done.


# Define Humanoid Env and Reward

In [None]:
#@title Define Humanoid Env 1 - TARGET HEADING DEFINED BUT NOT IN OBSERVATION SPACE

# Directory holding humanoid.xml, resolved relative to the installed `mujoco`
# package resources (mjx/test_data/humanoid).
HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'

class Humanoid(PipelineEnv):
  """MJX humanoid environment rewarded for moving along a sampled goal heading.

  A unit goal direction is drawn in `reset` and kept in `state.info['goal']`.
  It drives the task reward in `step` but is not part of the observation
  built by `_get_obs`.
  """

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=2.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(1.0, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      target_speed=1.0,
      heading_reward_weight=3.0,
      **kwargs,
  ):
    """Loads humanoid.xml, configures the solver and stores reward settings.

    Args:
      forward_reward_weight: stored on the instance; the active reward in
        `step` does not read it (only the commented-out old reward did).
      ctrl_cost_weight: multiplier on the summed squared action.
      healthy_reward: per-step alive bonus.
      terminate_when_unhealthy: if True, `done` is set once `q[2]` leaves
        `healthy_z_range`.
      healthy_z_range: (min, max) bounds checked against `q[2]`.
      reset_noise_scale: half-width of the uniform noise applied at reset.
      exclude_current_positions_from_observation: if True, the first two
        qpos entries are dropped from the observation.
      target_speed: speed along the goal direction above which the heading
        reward saturates.
      heading_reward_weight: multiplier on the heading reward.
      **kwargs: forwarded to `PipelineEnv.__init__`; `backend` is always
        overwritten with 'mjx' and `n_frames` defaults to 5.
    """
    model_file = HUMANOID_ROOT_PATH / 'humanoid.xml'
    mj_model = mujoco.MjModel.from_xml_path(model_file.as_posix())
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    # Five physics steps per control step unless the caller chose otherwise.
    kwargs.setdefault('n_frames', 5)
    kwargs['backend'] = 'mjx'

    super().__init__(mjcf.load_model(mj_model), **kwargs)

    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

    self._target_speed = target_speed
    self._heading_weight = heading_reward_weight

  def reset(self, rng: jp.ndarray) -> State:
    """Builds a noisy initial state and samples the episode's goal heading."""
    _, qpos_key, qvel_key, heading_key = jax.random.split(rng, 4)

    noise = self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        qpos_key, (self.sys.nq,), minval=-noise, maxval=noise
    )
    qvel = jax.random.uniform(
        qvel_key, (self.sys.nv,), minval=-noise, maxval=noise
    )
    data = self.pipeline_init(qpos, qvel)

    # The initial observation is computed with an all-zero action.
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metric_names = (
        'forward_reward',
        'reward_linvel',
        'reward_quadctrl',
        'reward_alive',
        'x_position',
        'y_position',
        'distance_from_origin',
        'x_velocity',
        'y_velocity',
    )
    metrics = {name: zero for name in metric_names}
    state = State(data, obs, reward, done, metrics)

    # Uniformly random heading in the xy-plane, stored as a unit vector.
    theta = jax.random.uniform(heading_key, (), minval=0.0, maxval=2 * jp.pi)
    state.info['goal'] = jp.stack([jp.cos(theta), jp.sin(theta)])

    return state

  def step(self, state: State, action: jp.ndarray) -> State:
    """Advances the simulation by one control step and scores the transition."""
    prev_data = state.pipeline_state
    data = self.pipeline_step(prev_data, action)

    # Finite-difference velocity of subtree_com[1] over one control step.
    com_after = data.subtree_com[1]
    velocity = (com_after - prev_data.subtree_com[1]) / self.dt

    # Heading reward: project the xy velocity onto the goal direction and
    # penalise only the shortfall below the target speed.
    goal_direction = state.info['goal']
    speed_along_goal = velocity[0:2] @ goal_direction
    shortfall = jp.maximum(0.0, self._target_speed - speed_along_goal)
    forward_reward = self._heading_weight * jp.exp(-2.5 * shortfall**2)

    min_z, max_z = self._healthy_z_range
    height = data.q[2]
    is_healthy = jp.where(height < min_z, 0.0, 1.0)
    is_healthy = jp.where(height > max_z, 0.0, is_healthy)

    # With termination enabled the bonus is paid unconditionally; otherwise it
    # is gated by the health flag.
    healthy_reward = (
        self._healthy_reward
        if self._terminate_when_unhealthy
        else self._healthy_reward * is_healthy
    )

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action)
    reward = forward_reward + healthy_reward - ctrl_cost
    if self._terminate_when_unhealthy:
      done = 1.0 - is_healthy
    else:
      done = 0.0

    state.metrics.update({
        'forward_reward': forward_reward,
        'reward_linvel': forward_reward,
        'reward_quadctrl': -ctrl_cost,
        'reward_alive': healthy_reward,
        'x_position': com_after[0],
        'y_position': com_after[1],
        'distance_from_origin': jp.linalg.norm(com_after),
        'x_velocity': velocity[0],
        'y_velocity': velocity[1],
    })

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Concatenates qpos, qvel, cinert, cvel and actuator forces into one vector.

    `action` is accepted for interface symmetry but is not read here.
    """
    if self._exclude_current_positions_from_observation:
      position = data.qpos[2:]
    else:
      position = data.qpos

    # external_contact_forces are excluded
    components = [
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
    ]
    return jp.concatenate(components)


# Register the Humanoid class defined above. (The previous `Humanoid1` name is
# not defined anywhere in this notebook and raised NameError on a fresh kernel.)
envs.register_environment('humanoid_og', Humanoid)

# set the env as this humanoid env
env_name = 'humanoid_og'
env = envs.get_environment(env_name)

# jit-compile reset/step once so repeated calls reuse the compiled functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

In [None]:
#@title Parse Mocap

from pathlib import Path
import os
from scipy.spatial.transform import Rotation as R
import numpy as np

# Module-level store filled by parse_bvh_file: action folder name -> list of
# clips, where each clip is a list of per-frame dicts keyed by joint name.
tasks = {}


def _frame_from_values(joints, values):
	"""Build one frame dict (joint name -> data) from a flat list of channel values.

	The first joint is the root: ``values[0:3]`` is its position and
	``values[3:6]`` its Euler rotation in degrees. Every later joint ``i`` reads
	three rotation values starting at ``(i + 1) * 3``.
	"""
	frame = {}
	for i, joint_name in enumerate(joints):
		if i == 0:
			position = np.array(values[:3])
			rotation = values[3:6]
			frame[joint_name] = {
				# THIS IS WHERE WE SWITCH AXES: new (x, y, z) = old (z, x, y).
				"position": position[[2, 0, 1]],
				# The root keeps a scipy Rotation (turned into a quaternion later);
				# same axis switch applied to the Euler triple first.
				"rotation": R.from_euler(
					'xyz', [rotation[2], rotation[0], rotation[1]], degrees=True),
			}
		else:
			start_ind = (i + 1) * 3
			rotation = values[start_ind:start_ind + 3]
			# Same axis switch, then degrees -> radians; non-root joints stay as
			# plain Euler-angle arrays.
			frame[joint_name] = {
				"rotation": np.deg2rad([rotation[2], rotation[0], rotation[1]]),
			}
	return frame


def _parse_bvh_clip(clip_path):
	"""Read one BVH-style text file and return its frames as a list of frame dicts."""
	# Each motion frame is expected to carry 78 floats (6 root channels plus
	# 3 rotation channels for each of the remaining joints).
	floats_per_frame = 78
	joints = []
	frames = []
	pending = []
	motion_data_started = False

	with open(clip_path, 'r', encoding='utf-8') as f:
		for line in f:
			stripped = line.strip()
			if not stripped:
				continue  # skip blank lines

			if stripped.startswith('ROOT') or stripped.startswith('JOINT'):
				joint_name = stripped.split()[1]
				print("found JOINT: ", joint_name)
				joints.append(joint_name)
				continue

			if stripped.startswith('Frame Time:'):
				motion_data_started = True
				print("Looking at motion data")
				# Fail early with a clear message instead of an IndexError deep
				# inside frame construction.
				if 3 * (len(joints) + 1) > floats_per_frame:
					raise ValueError(
						f"{clip_path}: {len(joints)} joints need "
						f"{3 * (len(joints) + 1)} values per frame, "
						f"but frames are read as {floats_per_frame} values")
				continue

			if not motion_data_started:
				continue

			pending.extend(map(float, stripped.split()))
			# `while` rather than `if`: emit every complete frame that has been
			# buffered, even if one line carried more than a single frame.
			while len(pending) >= floats_per_frame:
				values = pending[:floats_per_frame]
				pending = pending[floats_per_frame:]
				frames.append(_frame_from_values(joints, values))

	return frames


def _attach_joint_velocities(clip):
	"""Add a "velocity" entry to every joint of every frame in ``clip`` (in place).

	The value is the rotation change to the next frame and is NOT divided by a
	time step; callers are expected to scale by their own dt. The last frame
	reuses the previous frame's velocity, and a single-frame clip gets zeros.
	"""
	last = len(clip) - 1
	for idx, frame in enumerate(clip):
		if idx == last:
			for joint in frame.keys():
				if last == 0:
					# No neighbouring frame to difference against.
					frame[joint]["velocity"] = np.zeros(3)
				else:
					frame[joint]["velocity"] = clip[idx - 1][joint]["velocity"]
			continue

		next_frame = clip[idx + 1]
		for joint in frame.keys():
			frame1_rot = frame[joint]["rotation"]
			frame2_rot = next_frame[joint]["rotation"]

			# HIPS ARE A Rotation OBJECT, OTHER JOINTS ARE RADIANS
			if joint == "Hips":
				# Relative rotation q2 * q1^-1 as a rotation vector (axis * angle, radians).
				diff_angle = (frame2_rot * frame1_rot.inv()).as_rotvec()
			else:
				# NOTE(review): plain Euler difference, not wrapped to [-pi, pi].
				diff_angle = frame2_rot - frame1_rot
			frame[joint]["velocity"] = diff_angle


def parse_bvh_file(mocap_folder='/content/drive/MyDrive/make_him_walk/mocap'):
	"""Parse every mocap clip under ``mocap_folder`` into the global ``tasks`` dict.

	Each sub-directory of ``mocap_folder`` is one action; every ``*.txt`` file
	directly inside it (not recursive) is parsed as one clip. Clips are visited
	in sorted filename order so ``tasks[action][0]`` is the same on every run.
	After parsing, ``bvh_to_mujoco()`` is called so each frame also gets a
	"qpos" entry.

	End-effector (hand/foot) world positions are not computed; the previously
	commented-out accumulation code for them was removed as unused.

	Args:
		mocap_folder: directory containing one sub-folder per action. Defaults
			to the Google Drive location used by this notebook.

	Returns:
		The module-level ``tasks`` dict: action name -> list of clips, each clip
		a list of frame dicts keyed by joint name.
	"""
	print("Parsing")
	mocap_folder = Path(mocap_folder)
	print("got mocap folder")

	# go through all the different types of actions in the mocap folder
	for action in mocap_folder.iterdir():
		if not action.is_dir():
			continue
		# Start from an empty list so re-running does not duplicate clips.
		tasks[action.name] = []
		for clip_path in sorted(action.glob('*.txt')):
			clip = _parse_bvh_clip(clip_path)
			_attach_joint_velocities(clip)
			tasks[action.name].append(clip)

	bvh_to_mujoco()
	return tasks


def bvh_to_mujoco():
	"""Attach a MuJoCo-ordered ``qpos`` list to every frame stored in ``tasks``.

	For each frame the list holds, in order: the hip position (z raised by
	0.857), the hip quaternion as w, x, y, z, and then selected Euler
	components of the spine, arm and leg joints, some of them sign-flipped.
	The result (28 values) is written to ``frame["qpos"]``. No qvel is built.
	"""
	# The second ankle joint rotates about a tilted axis; the joint angle is
	# taken as the projection of the foot rotation onto that unit axis.
	right_ankle_axis = np.array([1,0,0.5])
	right_ankle_axis = right_ankle_axis/np.linalg.norm(right_ankle_axis)
	left_ankle_axis = np.array([-1,0,-0.5])
	left_ankle_axis = left_ankle_axis/np.linalg.norm(left_ankle_axis)

	for task_name, clips in tasks.items():
		print("TASK ", task_name)
		for clip in clips:
			for frame in clip:
				hip_pos = frame["Hips"]["position"]
				# scipy returns x, y, z, w; MuJoCo wants w, x, y, z.
				hip_quat = frame["Hips"]["rotation"].as_quat()

				spine = frame["Spine"]["rotation"]
				spine1 = frame["Spine1"]["rotation"]
				right_arm = frame["RightArm"]["rotation"]
				right_forearm = frame["RightForeArm"]["rotation"]
				left_arm = frame["LeftArm"]["rotation"]
				left_forearm = frame["LeftForeArm"]["rotation"]
				right_thigh = frame["RightUpLeg"]["rotation"]
				right_shin = frame["RightLeg"]["rotation"]
				right_foot = frame["RightFoot"]["rotation"]
				left_thigh = frame["LeftUpLeg"]["rotation"]
				left_shin = frame["LeftLeg"]["rotation"]
				left_foot = frame["LeftFoot"]["rotation"]

				frame["qpos"] = [
					# root position
					hip_pos[0],
					hip_pos[1],
					hip_pos[2] + 0.857,
					# root orientation (w, x, y, z)
					hip_quat[3],
					hip_quat[0],
					hip_quat[1],
					hip_quat[2],
					# waist_lower: z rotation then y rotation
					spine[2],
					spine[1],
					# torso y
					spine1[1],
					# upper_arm_right
					right_arm[0],
					right_arm[1] * -1,
					# elbow_right
					right_forearm[2],
					# upper_arm_left
					left_arm[0],
					left_arm[1],
					# elbow_left (axis 0,-1,-1)
					left_forearm[2] * -1,
					# thigh_right
					right_thigh[0],
					right_thigh[2],
					right_thigh[1],
					# knee_right, negated because the axis is 0,-1,0
					right_shin[1] * -1,
					# foot_right; second axis is 1,0,.5
					right_foot[1],
					np.dot(right_foot, right_ankle_axis),
					# thigh_left
					left_thigh[0] * -1,
					left_thigh[2] * -1,
					left_thigh[1],
					# knee_left, negated because the axis is 0,-1,0
					left_shin[1] * -1,
					# foot_left; second axis is -1,0,-.5
					left_foot[1],
					np.dot(left_foot, left_ankle_axis),
				]


#COMMENTING THIS OUT (WILL PROBABLY DELETE) SINCE WE'RE NOT USING QUATERNIONS ANYMORE
# def quat_from_xyz(x_angle=0.0, y_angle=0.0, z_angle=0.0):
# 	cx = jp.cos(x_angle / 2)
# 	sx = jp.sin(x_angle / 2)
# 	cy = jp.cos(y_angle / 2)
# 	sy = jp.sin(y_angle / 2)
# 	cz = jp.cos(z_angle / 2)
# 	sz = jp.sin(z_angle / 2)

# 	qx = jp.array([cx, sx, 0, 0])
# 	qy = jp.array([cy, 0, sy, 0])
# 	qz = jp.array([cz, 0, 0, sz])

# 	# Multiply quaternions: q = qx * (qy * qz)
# 	def quat_mul(q1, q2):
# 			w1, x1, y1, z1 = q1
# 			w2, x2, y2, z2 = q2
# 			return jp.array([
# 					w1*w2 - x1*x2 - y1*y2 - z1*z2,
# 					w1*x2 + x1*w2 + y1*z2 - z1*y2,
# 					w1*y2 - x1*z2 + y1*w2 + z1*x2,
# 					w1*z2 + x1*y2 - y1*x2 + z1*w2
# 			])

# 	q = quat_mul(qx, quat_mul(qy, qz))
# 	return q

# def mujoco_to_quaternion(qpos):

# 	quaternions=jp.array([])

# 	#hip data
# 	quaternions = jp.concatenate([quaternions, jp.array(qpos[:7])])

# 	#lower waist
# 	abdomen_z, abdomen_y = qpos[7:9]
# 	abdomen_quat = quat_from_xyz(y_angle=abdomen_y, z_angle=abdomen_z)
# 	quaternions = jp.concatenate([quaternions, abdomen_quat])


# 	#torso
# 	torso_y = qpos[9]
# 	torso_quat = quat_from_xyz(y_angle = torso_y)
# 	quaternions = jp.concatenate([quaternions, torso_quat])


# 	#upper right arm
# 	should1_right, should2_right = qpos[10:12]
# 	should_right_quat = quat_from_xyz(x_angle = should1_right, y_angle = should2_right)
# 	quaternions = jp.concatenate([quaternions, should_right_quat])


# 	#right elbow
# 	elbow_right = qpos[12]
# 	elbow_right_quat = quat_from_xyz(y_angle = elbow_right)
# 	quaternions = jp.concatenate([quaternions, elbow_right_quat])


# 	#upper left arm
# 	should1_left, should2_left = qpos[13:15]
# 	should_left_quat = quat_from_xyz(x_angle = should1_left, y_angle =should2_left)
# 	quaternions = jp.concatenate([quaternions, should_left_quat])


# 	#left elbow
# 	elbow_left = qpos[15]
# 	elbow_left_quat = quat_from_xyz(y_angle = elbow_left)
# 	quaternions = jp.concatenate([quaternions, elbow_left_quat])


# 	#right thigh
# 	hip_x_right, hip_z_right, hip_y_right = qpos[16:19]
# 	hip_right_quat = quat_from_xyz(x_angle=hip_x_right, y_angle=hip_y_right, z_angle=hip_z_right)
# 	quaternions = jp.concatenate([quaternions, hip_right_quat])


# 	#right shin
# 	knee_right = qpos[19]
# 	knee_right_quat = quat_from_xyz(y_angle=knee_right)
# 	quaternions = jp.concatenate([quaternions, knee_right_quat])


# 	#right foot
# 	ankle_y_right, ankle_x_right = qpos[20:22]
# 	ankle_right_quat = quat_from_xyz(x_angle=ankle_x_right, y_angle=ankle_y_right)
# 	quaternions = jp.concatenate([quaternions, ankle_right_quat])


# 	#left thigh
# 	hip_x_left, hip_z_left, hip_y_left = qpos[22:25]
# 	hip_left_quat = quat_from_xyz(x_angle=hip_x_left * -1.0, y_angle=hip_y_left, z_angle=hip_z_left * -1.0)
# 	quaternions = jp.concatenate([quaternions, hip_left_quat])


# 	#left shin
# 	knee_left = qpos[25]
# 	knee_left_quat = quat_from_xyz(y_angle=knee_left)
# 	quaternions = jp.concatenate([quaternions, knee_left_quat])


# 	#left foot
# 	ankle_y_left, ankle_x_left = qpos[26:28]
# 	ankle_left_quat = quat_from_xyz(x_angle=ankle_x_left * -1.0, y_angle=ankle_y_left)
# 	quaternions = jp.concatenate([quaternions, ankle_left_quat])

# 	print("QUATERNIONS LENGTH: ", len(quaternions))
# 	return quaternions





In [None]:
#@title Define Humanoid Env 1.5 -- no f_xy. only v_xy in OBSERVATION SPACE AND USING humanoid_fixed_4.xml
# Same value as HUMANOID_ROOT_PATH in the earlier Humanoid env cell; repeated
# here so this cell's options are visible in one place.
HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'
# Model folders inside the mujoco_resources checkout (paths are relative to the
# current working directory, where the repo is cloned).
TERRAIN_ROOT_PATH = epath.Path('mujoco_resources/humanoid_terrain')
FIXED_ROOT_PATH = epath.Path('mujoco_resources/humanoid_CMU_folder')
from jax import lax


class Humanoid(PipelineEnv):
  """Humanoid env rewarded for moving at a target speed along a goal heading.

  The goal heading is a random 2D unit vector sampled in `reset` and stored in
  `state.info['goal']`. Mocap frames are loaded once in `__init__`; `step`
  compares the simulated joint configuration against the frame selected by
  `state.info['frame_count']` and reports the result as the
  `reward_imitation` metric (optionally adding it to the reward).
  """

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=2.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(0.5, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      target_speed = 1.0,
      heading_reward_weight= 5.0,
      imitation_reward_weight=3.0,
      include_imitation_reward=False,
      **kwargs,
  ):
    """Loads the model, configures the solver and caches the mocap frames.

    Args:
      forward_reward_weight: stored for the old x-velocity reward; the current
        reward in `step` does not use it.
      ctrl_cost_weight: multiplier on the summed squared action.
      healthy_reward: per-step alive bonus.
      terminate_when_unhealthy: if True, `done` is set when `q[2]` leaves
        `healthy_z_range`.
      healthy_z_range: `(min_z, max_z)` bounds on `q[2]`.
      reset_noise_scale: half-width of the uniform noise added to qpos/qvel
        in `reset`.
      exclude_current_positions_from_observation: if True, the first two qpos
        entries are dropped from the observation.
      target_speed: speed along the goal direction above which the heading
        reward saturates.
      heading_reward_weight: multiplier on the heading reward.
      imitation_reward_weight: multiplier on the mocap pose reward (previously
        hard-coded to 3.0 in `step`).
      include_imitation_reward: if True, the weighted imitation reward is added
        to the step reward. Defaults to False, which keeps the reward identical
        to the previous behaviour (imitation term computed but not summed).
      **kwargs: forwarded to `PipelineEnv.__init__`; `backend` is forced to
        'mjx' and `n_frames` defaults to 5.
    """

    # Option 1: TRAIN WITHOUT HEIGHT FIELD-- plain humanoid
    # mj_model = mujoco.MjModel.from_xml_path(
    #     (HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix())

    # Option 2: TRAIN WITH HEIGHT FIELD-- humanoid with terrain
    # mj_model = mujoco.MjModel.from_xml_path(
    #     (TERRAIN_ROOT_PATH / 'humanoid.xml').as_posix())
    mj_model = mujoco.MjModel.from_xml_path(
         (FIXED_ROOT_PATH / 'humanoid_fixed_4.xml').as_posix())

    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5

    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    # Original Rewards/Costs
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

    # New Parameters
    self._target_speed = target_speed
    self._heading_weight= heading_reward_weight
    self._imitation_reward_weight = imitation_reward_weight
    self._include_imitation_reward = include_imitation_reward

    # Mocap reference: one qpos vector per frame of the first "walking" clip.
    self.mocap_dict = parse_bvh_file()
    self.frames = self.mocap_dict["walking"][0]
    # Each row is a frame.
    frames_qpos = np.array([frame["qpos"] for frame in self.frames])
    print(frames_qpos.shape[0], ", ", frames_qpos.shape[1])
    self.frames_qpos = jp.array(frames_qpos)
    # Kept for backward compatibility only; the per-episode frame counter lives
    # in `state.info['frame_count']` so that it works under jit/vmap.
    self.step_count = 0

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state.

    Samples noisy qpos/qvel around `sys.qpos0`, draws a random goal heading and
    starts the mocap frame counter at 0.
    """
    rng, rng1, rng2, rng3 = jax.random.split(rng, 4)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    # Sample a random 2D unit vector for the desired heading.
    theta = jax.random.uniform(rng3, (), minval=0.0, maxval=2 * jp.pi)
    d_star = jp.stack([jp.cos(theta), jp.sin(theta)])

    state_info = {
        'rng3': rng3,
        'goal': d_star,
        'v_xy': jp.zeros((2,), dtype=jp.float32),
        'frame_count': jp.array(0)
    }

    obs = self._get_obs(data, jp.zeros(self.sys.nu), state_info)
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'reward_imitation': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }

    return State(data, obs, reward, done, metrics, state_info)

  def imit_reward(self, mo_pos, stt_pos):
    """Pose-imitation reward in (0, 1] comparing a mocap qpos with a state qpos.

    Args:
      mo_pos: reference qpos vector from the mocap data.
      stt_pos: qpos vector of the simulated state.

    Returns:
      `exp(-2 * sum((mo_pos[7:] - stt_pos[7:])**2))`, with the exponent clamped
      at -50. Equals 1 when the compared entries match exactly.
    """
    mo_pos = jp.asarray(mo_pos)
    stt_pos = jp.asarray(stt_pos)

    # The first 7 qpos entries are skipped; the debugging snippet in `step`
    # treats `qpos[:7]` as the hip/root pose.
    rot_diff = mo_pos[7:] - stt_pos[7:]

    # Greater difference => more negative exponent => smaller reward.
    pose_exponent = -2.0 * jp.sum(jp.square(rot_diff))
    pose_reward = jp.exp(jp.maximum(pose_exponent, -50))

    # Not implemented yet (kept from the original notes):
    #   velocity term:       exp(max(-0.1 * sum((mo_ang_vel - stt_ang_vel)**2), -50))
    #   center-of-mass term: exp(max(-10.0 * norm(mo_pos[:3] - stt_pos[:3]), -50))
    #   full_reward = 0.7 * pose_reward + 0.3 * center_mass_reward
    return pose_reward

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    # Index of the mocap frame that this step is compared against.
    current_frame = state.info['frame_count']

    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    #FOR DEBUGGING MOCAP DATA
    # data = data0
    # hip_pos = data.qpos[:7]
    # rest = self.frames_qpos[current_frame][7:]
    # new_qpos = jp.concatenate([hip_pos, rest])
    # data = self.pipeline_init(new_qpos, data.qvel)

    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    # velocity = (x_t - x_0) / t
    velocity = (com_after - com_before) / self.dt

    # Target Heading Task Reward: saturates at `heading_weight` once the speed
    # along the goal direction reaches `target_speed`.
    d_star = state.info['goal']
    speed_along_goal = velocity[:2] @ d_star  # [v_x, v_y] @ [d_x, d_y]
    speed_err = jp.maximum(0.0, self._target_speed - speed_along_goal)
    heading_r = jp.exp(-2.5 * speed_err**2)
    forward_reward = self._heading_weight * heading_r

    # Old Reward -- velocity[0] is x component of velocity {vx, vy, vz}
    # forward_reward = self._forward_reward_weight * velocity[0]

    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    # Build the next info dict on a copy so the input state is not mutated.
    state_info = dict(state.info)
    state_info['v_xy'] = velocity[:2]
    state_info['frame_count'] = current_frame + 1

    obs = self._get_obs(data, action, state_info)

    # Imitation reward: only defined while the frame index is below the last
    # mocap frame index; afterwards it is 0.
    max_frame = self.frames_qpos.shape[0] - 1

    def compute_reward(_):
      mocap_qpos = self.frames_qpos[current_frame]
      return self._imitation_reward_weight * self.imit_reward(
          mocap_qpos, data.qpos)

    def zero_reward(_):
      return 0.0

    imit_reward = lax.cond(
        current_frame < max_frame,
        compute_reward,
        zero_reward,
        operand=None
    )

    # ADD IT ALL UP
    reward = forward_reward + healthy_reward - ctrl_cost
    # NOTE(review): the imitation term used to be computed and then dropped.
    # It is opt-in here so existing training runs keep the same reward.
    if self._include_imitation_reward:
      reward = reward + imit_reward

    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        reward_imitation=imit_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done, info=state_info
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray, state_info: dict[str, Any],
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles.

    Args:
      data: pipeline state to observe.
      action: the applied action (currently not part of the observation).
      state_info: info dict providing 'goal' and 'v_xy'.

    Returns:
      A flat vector: qpos (without its first two entries when
      `exclude_current_positions_from_observation` is set), qvel, cinert and
      cvel for all bodies but the first, qfrc_actuator, the goal direction and
      the last COM xy-velocity.
    """

    position = data.qpos
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    # external_contact_forces are excluded
    return jp.concatenate([
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
        state_info['goal'],
        state_info['v_xy'],
    ])


# Register this cell's Humanoid class under the name 'humanoid'.
# NOTE(review): other cells in this notebook register their own Humanoid under
# the same name, so presumably the most recently executed cell decides which
# env is active -- confirm before training.
envs.register_environment('humanoid', Humanoid)

# Instantiate the registered env by name (this runs Humanoid.__init__, which
# loads the XML model and parses the mocap file).
env_name = 'humanoid'
env = envs.get_environment(env_name)

# JIT-compiled reset/step functions, used by the visualization cell.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

Parsing
got mocap folder
found JOINT:  Hips
found JOINT:  LeftUpLeg
found JOINT:  LeftLeg
found JOINT:  LeftFoot
found JOINT:  LeftToeBase
found JOINT:  RightUpLeg
found JOINT:  RightLeg
found JOINT:  RightFoot
found JOINT:  RightToeBase
found JOINT:  Spine
found JOINT:  Spine1
found JOINT:  Neck
found JOINT:  Head
found JOINT:  LeftShoulder
found JOINT:  LeftArm
found JOINT:  LeftForeArm
found JOINT:  LeftHand
found JOINT:  LeftHandThumb
found JOINT:  L_Wrist_End
found JOINT:  RightShoulder
found JOINT:  RightArm
found JOINT:  RightForeArm
found JOINT:  RightHand
found JOINT:  RightHandThumb
found JOINT:  R_Wrist_End
Looking at motion data
found JOINT:  Hips
found JOINT:  LeftUpLeg
found JOINT:  LeftLeg
found JOINT:  LeftFoot
found JOINT:  LeftToeBase
found JOINT:  RightUpLeg
found JOINT:  RightLeg
found JOINT:  RightFoot
found JOINT:  RightToeBase
found JOINT:  Spine
found JOINT:  Spine1
found JOINT:  Neck
found JOINT:  Head
found JOINT:  LeftShoulder
found JOINT:  LeftArm
found JOIN

In [None]:
#@title Train
# PPO training entry point with all hyperparameters bound; only the environment
# and the progress callback are supplied at call time.
train_fn = functools.partial(
    ppo.train, num_timesteps=20_000_000, num_evals=20, reward_scaling=0.1,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=10, num_minibatches=24, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=3072,
    batch_size=512, seed=0)

# Buffers filled by `progress` on every evaluation: step count, mean eval
# reward and its standard deviation.
x_data = []
y_data = []
ydataerr = []
# Wall-clock timestamps; the first entry marks the start of training.
times = [datetime.now()]

# NOTE(review): not used anywhere in this cell -- presumably leftover plot
# y-axis limits; confirm before removing.
max_y, min_y = 13000, 0

def progress(num_steps, metrics):
  """Training callback: records one evaluation point and prints it.

  Appends the current wall-clock time to `times`, the step count to `x_data`,
  and the evaluation reward mean / standard deviation to `y_data` /
  `ydataerr`, then prints a one-line summary.

  Args:
    num_steps: number of environment steps taken so far.
    metrics: mapping that contains 'eval/episode_reward' and
      'eval/episode_reward_std'.
  """
  times.append(datetime.now())
  x_data.append(num_steps)

  reward = metrics['eval/episode_reward']
  y_data.append(reward)

  reward_std = metrics['eval/episode_reward_std']
  ydataerr.append(reward_std)

  print(f"Step: {num_steps} | Eval Reward: {reward:.2f} ± {reward_std:.2f}")

# Run PPO on the currently active env; `progress` is invoked with the step
# count and metrics at each evaluation. The third return value is discarded.
make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)

# times[0] was recorded before training started and times[1] at the first
# progress callback, so their difference is reported as the jit time; the
# remainder up to the last callback is reported as training time.
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

Current Frame: Traced<ShapedArray(int32[], weak_type=True)>with<BatchTrace> with
  val = Traced<ShapedArray(int32[128], weak_type=True)>with<DynamicJaxprTrace>
  batch_dim = 0
mo_rot_data:  21
stt_rot_data:  21
Step: 0 | Eval Reward: 18.60 ± 13.70
Current Frame: Traced<ShapedArray(int32[])>with<BatchTrace> with
  val = Traced<ShapedArray(int32[3072])>with<DynamicJaxprTrace>
  batch_dim = 0
mo_rot_data:  21
stt_rot_data:  21
Step: 1105920 | Eval Reward: 126.22 ± 37.32
Step: 2211840 | Eval Reward: 179.92 ± 38.66
Step: 3317760 | Eval Reward: 218.25 ± 37.61
Step: 4423680 | Eval Reward: 244.45 ± 45.91
Step: 5529600 | Eval Reward: 264.86 ± 47.78
Step: 6635520 | Eval Reward: 301.26 ± 50.58
Step: 7741440 | Eval Reward: 316.92 ± 56.29


In [None]:
#@title Define the Humanoid Env 2 - both v_xy and f_xy IN OBSERVATION SPACE

# Directory of the humanoid test model that ships inside the installed `mujoco` package.
HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'

# Relative path to the humanoid-with-terrain model directory (only referenced by a commented-out option below).
TERRAIN_ROOT_PATH = epath.Path('mujoco_resources/humanoid_terrain')

def quat_to_yaw(q: jp.ndarray) -> jp.ndarray:
    """Returns the yaw (rotation about the z-axis) encoded by a quaternion.

    Args:
      q: quaternion in (w, x, y, z) order.

    Returns:
      The yaw angle in radians, computed with `arctan2`.
    """
    qw, qx, qy, qz = (q[i] for i in range(4))
    sin_yaw_term = 2 * (qw * qz + qx * qy)
    cos_yaw_term = 1 - 2 * (qy * qy + qz * qz)
    return jp.arctan2(sin_yaw_term, cos_yaw_term)

class Humanoid(PipelineEnv):
  """Humanoid env with a goal heading, a speed reward and a hip-alignment reward.

  The observation built by `_get_obs` ends with the goal direction, the last
  COM xy-velocity (`v_xy`) and the normalized xy facing direction (`f_xy`).
  """

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=2.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(1.0, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      target_speed = 1.0,
      heading_reward_weight= 5.0,
      **kwargs,
  ):
    """Loads the plain humanoid model and stores the reward configuration.

    `backend` is forced to 'mjx' and `n_frames` defaults to 5 physics steps
    per control step; remaining kwargs go to `PipelineEnv.__init__`.
    """
    # Plain humanoid (no height field). To train with terrain instead, load
    # (TERRAIN_ROOT_PATH / 'humanoid.xml') here.
    model_path = (HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix()
    mj_model = mujoco.MjModel.from_xml_path(model_path)

    solver_opts = mj_model.opt
    solver_opts.solver = mujoco.mjtSolver.mjSOL_CG
    solver_opts.iterations = 6
    solver_opts.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    # Honour a caller-supplied n_frames; otherwise use 5 physics steps.
    kwargs.setdefault('n_frames', 5)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    # Original reward/cost configuration.
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation)

    # Heading-task configuration.
    self._target_speed = target_speed
    self._heading_weight = heading_reward_weight

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state with a fresh goal heading."""
    rng, qpos_key, qvel_key, goal_key = jax.random.split(rng, 4)

    noise = self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        qpos_key, (self.sys.nq,), minval=-noise, maxval=noise)
    qvel = jax.random.uniform(
        qvel_key, (self.sys.nv,), minval=-noise, maxval=noise)

    data = self.pipeline_init(qpos, qvel)

    # Random 2D unit vector for the desired heading.
    theta = jax.random.uniform(goal_key, (), minval=0.0, maxval=2 * jp.pi)
    goal_dir = jp.stack([jp.cos(theta), jp.sin(theta)])

    state_info = {
        'rng3': goal_key,
        'goal': goal_dir,
        'v_xy': jp.zeros((2,), dtype=jp.float32),
        'f_xy': jp.zeros((2,), dtype=jp.float32),
    }

    obs = self._get_obs(data, jp.zeros(self.sys.nu), state_info)
    reward, done, zero = jp.zeros(3)
    metric_names = (
        'forward_reward',
        'reward_linvel',
        'reward_quadctrl',
        'reward_alive',
        'x_position',
        'y_position',
        'distance_from_origin',
        'x_velocity',
        'y_velocity',
    )
    metrics = {name: zero for name in metric_names}

    return State(data, obs, reward, done, metrics, state_info)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    prev_data = state.pipeline_state
    data = self.pipeline_step(prev_data, action)

    # Finite-difference COM velocity over one control step.
    com_prev = prev_data.subtree_com[1]
    com_now = data.subtree_com[1]
    velocity = (com_now - com_prev) / self.dt

    # Heading reward: saturates once the speed along the goal direction
    # reaches the target speed.
    goal_dir = state.info['goal']
    speed_along_goal = velocity[:2] @ goal_dir
    shortfall = jp.maximum(0.0, self._target_speed - speed_along_goal)
    forward_reward = self._heading_weight * jp.exp(-2.5 * shortfall**2)
    state.info['v_xy'] = velocity[:2]

    # Hip alignment: rotate the local forward axis [1, 0, 0] into world
    # coordinates, project onto xy, normalize (eps avoids division by zero)
    # and take the cosine with the goal direction.
    hip_quat = data.x.rot[0]
    forward_world = math.rotate(jp.array([1.0, 0.0, 0.0]), hip_quat)
    f_xy = forward_world[:2]
    f_xy = f_xy / (jp.linalg.norm(f_xy) + 1e-6)
    hip_reward = 3.0 * (f_xy @ goal_dir)
    state.info['f_xy'] = f_xy

    # Healthy iff q[2] lies inside the configured z range.
    min_z, max_z = self._healthy_z_range
    z_pos = data.q[2]
    is_healthy = jp.where((z_pos < min_z) | (z_pos > max_z), 0.0, 1.0)

    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
      done = 1.0 - is_healthy
    else:
      healthy_reward = self._healthy_reward * is_healthy
      done = 0.0

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action, state.info)

    reward = forward_reward + hip_reward + healthy_reward - ctrl_cost

    state.metrics.update({
        'forward_reward': forward_reward,
        'reward_linvel': forward_reward,
        'reward_quadctrl': -ctrl_cost,
        'reward_alive': healthy_reward,
        'x_position': com_now[0],
        'y_position': com_now[1],
        'distance_from_origin': jp.linalg.norm(com_now),
        'x_velocity': velocity[0],
        'y_velocity': velocity[1],
    })

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done)

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray, state_info: dict[str, Any],
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles.

    Returns a flat vector of qpos (first two entries dropped when configured),
    qvel, cinert/cvel for all bodies but the first, qfrc_actuator, and the
    'goal', 'v_xy' and 'f_xy' entries of `state_info`. External contact forces
    are excluded; `action` is not part of the observation.
    """
    if self._exclude_current_positions_from_observation:
      position = data.qpos[2:]
    else:
      position = data.qpos

    obs_parts = [
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
        state_info['goal'],
        state_info['v_xy'],
        state_info['f_xy'],
    ]
    return jp.concatenate(obs_parts)


# Register this cell's Humanoid class under the name 'humanoid'.
# NOTE(review): other cells in this notebook register their own Humanoid under
# the same name, so presumably the most recently executed cell decides which
# env is active -- confirm before training.
envs.register_environment('humanoid', Humanoid)

# Instantiate the registered env by name.
env_name = 'humanoid'
env = envs.get_environment(env_name)

# JIT-compiled reset/step functions, used by the visualization cell.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

In [None]:
#@title FOR CMU VERSION - Render 1 scene of the humanoid_CMU.xml
# NOTE(review): despite the title, the model loaded below is
# 'humanoid_fixed.xml' from the same folder, not 'humanoid_CMU.xml'.
ROOT_PATH = epath.Path('mujoco_resources/humanoid_CMU_folder')

# Earlier approach (disabled): read humanoid_CMU.xml as text and build the
# model from the XML string.
# xml_path = CMU_ROOT_PATH / 'humanoid_CMU.xml'

# print("Exists?", xml_path.exists(), "  Path:", xml_path)

# with open(xml_path.as_posix(), 'r') as f:
#     xml = f.read()

# #model = mujoco.MjModel.from_xml_string(xml)

# Build the CPU MuJoCo model, its simulation data and an offscreen renderer.
model = mujoco.MjModel.from_xml_path((ROOT_PATH / 'humanoid_fixed.xml').as_posix())
data  = mujoco.MjData(model)
renderer = mujoco.Renderer(model)

# Render the first frame: run forward kinematics on the initial data so the
# scene reflects it, then show the rendered image.
mujoco.mj_forward(model, data)
renderer.update_scene(data)
media.show_image(renderer.render())

# Play through keyframes: reset the data to each keyframe stored in the model
# (model.nkey of them) and show one image per keyframe.
for key in range(model.nkey):
    mujoco.mj_resetDataKeyframe(model, data, key)
    mujoco.mj_forward(model, data)
    renderer.update_scene(data)
    media.show_image(renderer.render())

In [None]:
#@title Define the CMU Humanoid Env - STILL TESTING THIS OUT, NOT READY TO RUN.
# Directory of the humanoid test model that ships inside the installed `mujoco` package.
HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'
# Relative path to the humanoid-with-terrain model directory (only referenced by a commented-out option below).
TERRAIN_ROOT_PATH = epath.Path('mujoco_resources/humanoid_terrain')
# Relative path to the directory holding 'humanoid_CMU.xml', which this cell's env loads.
CMU_ROOT_PATH = epath.Path('mujoco_resources/humanoid_CMU_folder')

class Humanoid(PipelineEnv):
  """CMU humanoid env using the original x-velocity forward reward.

  A goal heading is still sampled in `reset` and stored in
  `state.info['goal']`, but neither the reward nor the observation uses it.
  """

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=2.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(1.0, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      target_speed = 1.0,
      heading_reward_weight= 3.0,
      **kwargs,
  ):
    """Loads 'humanoid_CMU.xml' and stores the reward configuration.

    `backend` is forced to 'mjx' and `n_frames` defaults to 5 physics steps
    per control step; remaining kwargs go to `PipelineEnv.__init__`.
    """
    # Option 3: HUMANOID CMU. The alternatives are the plain humanoid
    # (HUMANOID_ROOT_PATH / 'humanoid.xml') and the terrain humanoid
    # (TERRAIN_ROOT_PATH / 'humanoid.xml').
    model_path = (CMU_ROOT_PATH / 'humanoid_CMU.xml').as_posix()
    mj_model = mujoco.MjModel.from_xml_path(model_path)

    solver_opts = mj_model.opt
    solver_opts.solver = mujoco.mjtSolver.mjSOL_CG
    solver_opts.iterations = 6
    solver_opts.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    # Honour a caller-supplied n_frames; otherwise use 5 physics steps.
    kwargs.setdefault('n_frames', 5)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    # Original reward/cost configuration.
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation)

    # Heading-task configuration (stored; not used by `step` in this env).
    self._target_speed = target_speed
    self._heading_weight = heading_reward_weight

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, qpos_key, qvel_key, goal_key = jax.random.split(rng, 4)

    noise = self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        qpos_key, (self.sys.nq,), minval=-noise, maxval=noise)
    qvel = jax.random.uniform(
        qvel_key, (self.sys.nv,), minval=-noise, maxval=noise)

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metric_names = (
        'forward_reward',
        'reward_linvel',
        'reward_quadctrl',
        'reward_alive',
        'x_position',
        'y_position',
        'distance_from_origin',
        'x_velocity',
        'y_velocity',
    )
    metrics = {name: zero for name in metric_names}
    state = State(data, obs, reward, done, metrics)

    # Random 2D unit vector for the desired heading, kept in the info dict.
    theta = jax.random.uniform(goal_key, (), minval=0.0, maxval=2 * jp.pi)
    state.info['goal'] = jp.stack([jp.cos(theta), jp.sin(theta)])

    return state

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    prev_data = state.pipeline_state
    data = self.pipeline_step(prev_data, action)

    # Finite-difference COM velocity over one control step.
    com_prev = prev_data.subtree_com[1]
    com_now = data.subtree_com[1]
    velocity = (com_now - com_prev) / self.dt

    # Original reward: weight times the x component of the COM velocity.
    # (The goal-heading reward used by the other env variants is disabled.)
    forward_reward = self._forward_reward_weight * velocity[0]

    # Healthy iff q[2] lies inside the configured z range.
    min_z, max_z = self._healthy_z_range
    z_pos = data.q[2]
    is_healthy = jp.where((z_pos < min_z) | (z_pos > max_z), 0.0, 1.0)

    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
      done = 1.0 - is_healthy
    else:
      healthy_reward = self._healthy_reward * is_healthy
      done = 0.0

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action)

    reward = forward_reward + healthy_reward - ctrl_cost

    state.metrics.update({
        'forward_reward': forward_reward,
        'reward_linvel': forward_reward,
        'reward_quadctrl': -ctrl_cost,
        'reward_alive': healthy_reward,
        'x_position': com_now[0],
        'y_position': com_now[1],
        'distance_from_origin': jp.linalg.norm(com_now),
        'x_velocity': velocity[0],
        'y_velocity': velocity[1],
    })

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done)

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles.

    Returns a flat vector of qpos (first two entries dropped when configured),
    qvel, cinert/cvel for all bodies but the first, and qfrc_actuator.
    External contact forces are excluded; `action` is not part of the
    observation.
    """
    if self._exclude_current_positions_from_observation:
      position = data.qpos[2:]
    else:
      position = data.qpos

    obs_parts = [
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
    ]
    return jp.concatenate(obs_parts)


# Register this cell's Humanoid class under the name 'humanoid'.
# NOTE(review): other cells in this notebook register their own Humanoid under
# the same name, so presumably the most recently executed cell decides which
# env is active -- confirm before training.
envs.register_environment('humanoid', Humanoid)

# Instantiate the registered env by name.
env_name = 'humanoid'
env = envs.get_environment(env_name)

# JIT-compiled reset/step functions, used by the visualization cell.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

In [None]:
#@title Visualize Currently Activated Env
# Number of control steps to roll out, and how often to print qpos. Printing
# every step floods the cell output with 500 arrays and forces a
# device-to-host transfer on each iteration; set QPOS_PRINT_EVERY = 1 to get
# the old per-step dump back.
N_ROLLOUT_STEPS = 500
QPOS_PRINT_EVERY = 100

state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# The control is the same constant on every step, so build it once.
ctrl = -0.1 * jp.ones(env.sys.nu)

# grab a trajectory
for i in range(N_ROLLOUT_STEPS):
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)
  if i % QPOS_PRINT_EVERY == 0:
    print("qpos:", state.pipeline_state.qpos)
  #print("qpos_length:", len(state.pipeline_state.qpos))
  #print("qvel:", state.pipeline_state.qvel)
  #print("qvel_length:", len(state.pipeline_state.qvel))

# Play the rollout back in real time (one video frame per control step).
media.show_video(env.render(rollout, camera='side'), fps=1.0 / env.dt)

qpos: [-0.01  -0.01   0.859  1.    -0.006 -0.008 -0.008  0.025 -0.062  0.193  1.31   0.085  0.329 -1.187
 -0.34   0.352 -0.075 -0.026 -0.489 -0.844 -0.307 -0.032  0.026 -0.001  0.09  -0.112 -0.296  0.151]
qpos: [-0.01  -0.01   0.859  1.    -0.006 -0.008 -0.008  0.024 -0.061  0.191  1.309  0.079  0.328 -1.177
 -0.337  0.363 -0.075 -0.022 -0.488 -0.792 -0.299 -0.029  0.025 -0.001  0.098 -0.112 -0.302  0.149]
qpos: [-0.01  -0.01   0.859  1.    -0.006 -0.008 -0.008  0.023 -0.063  0.193  1.309  0.073  0.327 -1.168
 -0.342  0.367 -0.071 -0.022 -0.488 -0.734 -0.286 -0.025  0.02  -0.001  0.103 -0.114 -0.31   0.15 ]
qpos: [-0.01  -0.01   0.859  1.    -0.006 -0.008 -0.008  0.022 -0.062  0.19   1.307  0.069  0.325 -1.161
 -0.338  0.376 -0.072 -0.022 -0.484 -0.671 -0.271 -0.02   0.019  0.001  0.11  -0.115 -0.316  0.146]
qpos: [-0.01  -0.01   0.859  1.    -0.006 -0.008 -0.008  0.021 -0.062  0.19   1.305  0.063  0.324 -1.152
 -0.337  0.384 -0.073 -0.021 -0.479 -0.607 -0.25  -0.006  0.019  0.004  0.1

0
This browser does not support the video tag.


In [None]:
#@title Define the Humanoid Env with mocap reward

# Directory of the humanoid test model that ships inside the installed `mujoco` package.
HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'
# Relative path to the humanoid-with-terrain model directory.
TERRAIN_ROOT_PATH = epath.Path('mujoco_resources/humanoid_terrain')

# SciPy rotation helper, imported under the alias `R`.
from scipy.spatial.transform import Rotation as R


class Humanoid(PipelineEnv):
  """MJX humanoid with a goal-heading reward and a mocap imitation reward.

  Per-step reward:
    forward_reward + healthy_reward + imitation_reward - ctrl_cost

  Everything that changes from step to step (goal direction, mocap frame
  counter) is carried in ``state.info`` rather than on ``self``. ``step`` is
  wrapped in ``jax.jit`` by the surrounding cells, and Python attribute updates
  made inside a jitted function only run while the function is being traced,
  so a counter stored on ``self`` would be frozen at its trace-time value.
  """

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=2.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(1.0, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      target_speed = 1.0,
      heading_reward_weight= 5.0,
      imitation_reward_weight=3.0,
      **kwargs,
  ):
    """Builds the MJX system and loads the mocap reference clip.

    Args:
      forward_reward_weight: stored on the instance; only the commented-out
        legacy forward reward in ``step`` refers to it.
      ctrl_cost_weight: multiplier on the summed squared action.
      healthy_reward: reward granted for staying healthy.
      terminate_when_unhealthy: if True, ``done`` is raised when the root
        height leaves ``healthy_z_range``.
      healthy_z_range: (min_z, max_z) bounds checked against ``data.q[2]``.
      reset_noise_scale: half-width of the uniform noise added to the initial
        qpos / qvel.
      exclude_current_positions_from_observation: if True, the first two qpos
        entries are dropped from the observation.
      target_speed: desired speed along the goal direction.
      heading_reward_weight: multiplier on the goal-heading reward.
      imitation_reward_weight: multiplier on the mocap imitation reward
        (previously the hard-coded constant 3.0).
      **kwargs: forwarded to ``PipelineEnv``; ``n_frames`` defaults to 5 and
        ``backend`` is always forced to 'mjx'.
    """

    # Option 1: TRAIN WITHOUT HEIGHT FIELD -- plain humanoid
    mj_model = mujoco.MjModel.from_xml_path(
        (HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix())

    # Option 2: TRAIN WITH HEIGHT FIELD -- humanoid with terrain
    # mj_model = mujoco.MjModel.from_xml_path(
    #     (TERRAIN_ROOT_PATH / 'humanoid.xml').as_posix())

    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5

    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    # Original Rewards/Costs
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

    # New Parameters
    self._target_speed = target_speed
    self._heading_weight= heading_reward_weight
    self._imitation_weight = imitation_reward_weight

    # Mocap reference clip. The per-frame qpos vectors are stacked into one
    # array so that `step` can select a frame with a traced index under jit.
    # NOTE(review): assumes every frame's "qpos" has the same length --
    # confirm against parse_bvh_file.
    self.mocap_dict = parse_bvh_file()
    walking_frames = self.mocap_dict["walking"][1]
    self._mocap_num_frames = len(walking_frames)
    if self._mocap_num_frames:
      self._mocap_qpos = jp.asarray(
          np.stack([np.asarray(frame["qpos"]) for frame in walking_frames]))
    else:
      self._mocap_qpos = None


  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state.

    Samples noisy initial qpos/qvel and a random unit goal direction, and
    starts the mocap frame counter at zero. Both the goal and the counter are
    stored in ``state.info``.
    """
    rng, rng1, rng2, rng3 = jax.random.split(rng, 4)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    # Sample a random 2D unit vector for the desired heading.
    theta = jax.random.uniform(rng3, (), minval=0.0, maxval=2 * jp.pi)
    d_star = jp.stack([jp.cos(theta), jp.sin(theta)])

    state_info = {
        'goal': d_star,
        # Number of env steps taken so far in this episode; `step` uses it to
        # pick the mocap frame to compare against.
        'mocap_step': jp.zeros((), dtype=jp.int32),
    }

    obs = self._get_obs(data, jp.zeros(self.sys.nu), state_info)
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }

    return State(data, obs, reward, done, metrics, state_info)

  @staticmethod
  def imit_reward(mo_pos, stt_pos):
    """Scores how closely a simulated pose matches a mocap pose.

    Both inputs are flat vectors laid out as ``[x, y, z, q0, q1, ...]``: the
    first three entries are the hip position and the rest is read as
    consecutive 4-component joint quaternions. The two inputs must use the
    same quaternion component order; which order that is does not matter,
    because for unit quaternions the scalar part of ``q_stt * q_mo^-1`` equals
    their dot product.

    Written with jax.numpy only so it can be traced by ``jax.jit``.

    Args:
      mo_pos: mocap reference pose.
      stt_pos: simulated pose in the same layout.

    Returns:
      Scalar ``0.7 * pose_reward + 0.3 * center_mass_reward``; each term is an
      exponential of a non-positive error, so it equals 1 on an exact match.
    """
    mo_pos = jp.asarray(mo_pos)
    stt_pos = jp.asarray(stt_pos)

    mo_rot_data = mo_pos[3:]
    stt_rot_data = stt_pos[3:]

    # Only complete quaternions present in both inputs are compared; a
    # trailing partial group of fewer than four values is ignored.
    num_joints = min(mo_rot_data.shape[0], stt_rot_data.shape[0]) // 4
    mo_quats = mo_rot_data[:num_joints * 4].reshape(num_joints, 4)
    stt_quats = stt_rot_data[:num_joints * 4].reshape(num_joints, 4)

    # Normalise so the dot product below is the cosine of half the rotation
    # angle; eps keeps an all-zero quaternion from producing NaN.
    eps = 1e-8
    mo_quats = mo_quats / (
        jp.linalg.norm(mo_quats, axis=-1, keepdims=True) + eps)
    stt_quats = stt_quats / (
        jp.linalg.norm(stt_quats, axis=-1, keepdims=True) + eps)

    # abs(): q and -q describe the same rotation. clip(): guards arccos
    # against round-off slightly above 1.
    cos_half_angle = jp.clip(
        jp.abs(jp.sum(mo_quats * stt_quats, axis=-1)), 0.0, 1.0)
    joint_angles = 2.0 * jp.arccos(cos_half_angle)

    # Larger rotation error -> more negative exponent -> smaller reward; the
    # reward is 1 when there is no difference. The exponent is floored at -50.
    pose_reward = jp.exp(jp.maximum(-2.0 * jp.sum(joint_angles ** 2), -50.0))

    # TODO: an angular-velocity term (weight -0.1 on the squared velocity
    # error) was sketched in comments here but never enabled.

    # first 3 values in both vectors represent xyz pos of hips
    center_mass_err = jp.linalg.norm(mo_pos[:3] - stt_pos[:3])
    center_mass_reward = jp.exp(jp.maximum(-10.0 * center_mass_err, -50.0))

    return (0.7 * pose_reward) + (0.3 * center_mass_reward)


  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics.

    Returns a new State whose reward is
    ``forward_reward + healthy_reward + imitation_reward - ctrl_cost`` and
    whose ``info['mocap_step']`` has advanced by one (or been set back to zero
    when this step ends the episode).
    """
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    # velocity = (x_t - x_0) / t
    velocity = (com_after - com_before) / self.dt

    # Target Heading Task Reward: project the xy velocity on the goal
    # direction and penalise only the shortfall relative to the target speed.
    d_star = state.info['goal']
    v_xy = velocity[:2] @ d_star  # [v_x, v_y] @ [d_x, d_y]
    speed_err = jp.maximum(0.0, self._target_speed - v_xy)
    heading_r = jp.exp(-2.5 * speed_err**2)
    forward_reward = self._heading_weight * heading_r

    # Old Reward -- velocity[0] is x component of velocity {vx, vy, vz}
    # forward_reward = self._forward_reward_weight * velocity[0]

    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action, state.info)

    # Mocap imitation: the pose after the k-th step is compared with mocap
    # frame k. Frames at or beyond index (num_frames - 1) give no reward,
    # which is the bound the original Python-side check used.
    frame_idx = state.info['mocap_step'] + 1
    if self._mocap_qpos is not None:
      in_clip = frame_idx < (self._mocap_num_frames - 1)
      # Clamp so the lookup stays in range even when the result is masked out.
      safe_idx = jp.clip(frame_idx, 0, self._mocap_num_frames - 1)
      mocap_qpos = self._mocap_qpos[safe_idx]
      # NOTE(review): mujoco_to_quaternion is defined elsewhere; it must be
      # jit-traceable and return the same layout as the mocap qpos -- confirm.
      guy_qpos = mujoco_to_quaternion(data.qpos)
      imitation_reward = jp.where(
          in_clip,
          self._imitation_weight * self.imit_reward(mocap_qpos, guy_qpos),
          0.0,
      )
    else:
      imitation_reward = 0.0

    reward = forward_reward + healthy_reward + imitation_reward - ctrl_cost

    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0

    # Restart the mocap clip when this step terminates the episode, so the
    # counter is already zero if the caller continues from a reset pose.
    # NOTE(review): an episode ended from outside this method (e.g. a length
    # limit applied by a wrapper) is not visible here -- confirm whether the
    # counter should also be reset in that case.
    next_mocap_step = jp.where(done > 0.5, 0, frame_idx)

    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    info = dict(state.info)
    info['mocap_step'] = next_mocap_step

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done, info=info
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray, state_info: dict[str, Any],
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles.

    Concatenates, in order: qpos (without its first two entries when
    positions are excluded), qvel, flattened cinert[1:], flattened cvel[1:],
    qfrc_actuator, and the goal direction from ``state_info['goal']``.
    ``action`` is accepted but not used.
    """

    position = data.qpos
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    # external_contact_forces are excluded
    return jp.concatenate([
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
        state_info['goal'],
    ])


# Register the Humanoid class defined above under the name 'humanoid'.
envs.register_environment('humanoid', Humanoid)

# Make this humanoid the currently active env: `env_name` / `env` are the
# globals that the other cells read.
env_name = 'humanoid'
env = envs.get_environment(env_name)

# Wrap reset/step in jax.jit; these replace any jit_reset / jit_step created
# for a previously activated env.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

In [None]:
#@title Save Model
# Destination file for the serialized policy parameters; the Render cell loads
# from the same path.
model_path = '/tmp/imitation_policy_humanoid_fixed_2' #<---- version 3 has d_star in observation space
# `params` is not defined in this cell; it must already exist in the kernel.
model.save_params(model_path, params)
from google.colab import files
# Hand the saved file to the Colab file-download helper.
files.download(model_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
#@title Render
# NOTE: Run the cell that registers the Humanoid env used to train the policy
# before this one; otherwise the observation/action shapes will not match.

# Every env access in this cell goes through this one instance, so the network
# shapes, the rollout and the renderer cannot disagree with each other.
eval_env = envs.get_environment('humanoid')

# With num_timesteps=0 this call is only used to obtain make_inference_fn; the
# params it returns are replaced by the ones loaded from disk below.
train_fn = functools.partial(
    ppo.train, num_timesteps=0, num_evals=0, reward_scaling=0.1,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=10, num_minibatches=24, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=3072,
    batch_size=512, seed=0)

x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 13000, 0

def progress(num_steps, metrics):
  """Records and prints the eval reward reported by the trainer."""
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics['eval/episode_reward'])
  ydataerr.append(metrics['eval/episode_reward_std'])
  reward = metrics['eval/episode_reward']
  reward_std = metrics['eval/episode_reward_std']
  print(f"Step: {num_steps} | Eval Reward: {reward:.2f} ± {reward_std:.2f}")

make_inference_fn, params, _= train_fn(environment=eval_env, progress_fn=progress)

print('Load Model...')
model_path = '/tmp/imitation_policy_humanoid_fixed_2'
# model_path = '/content/drive/MyDrive/mjx_humanoid_target_heading_run_policy'
# model_path = '/content/drive/MyDrive/run_policy'

params = model.load_params(model_path)

inference_fn = make_inference_fn(params) # <----- plugs the weights in

jit_inference_fn = jax.jit(inference_fn)

jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)

# Override the randomly sampled goal with a fixed heading of -90 degrees.
theta = jp.deg2rad(-90.0)
d_star = jp.stack([jp.cos(theta), jp.sin(theta)])
goal_info = {**state.info, 'goal': d_star}
# reset() already baked the random goal into state.obs, so the observation has
# to be rebuilt too; otherwise the first policy call would see the old goal.
state = state.replace(
    info=goal_info,
    obs=eval_env._get_obs(
        state.pipeline_state, jp.zeros(eval_env.sys.nu), goal_info),
)

rollout = [state.pipeline_state] # <---- initialize rollout list with the first reset state

n_steps = 500
render_every = 2

for i in range(n_steps):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng) # this is where your model is making the inference!
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

  # Stop as soon as the env reports termination.
  if state.done:
    break

# Only every `render_every`-th frame is rendered, so the fps is divided by the
# same factor to keep playback at real-time speed.
media.show_video(
    eval_env.render(rollout[::render_every], camera='side'),
    fps=1.0 / eval_env.dt / render_every)

print("Done.")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
New frame:  4398
New frame:  4399
New frame:  4400
New frame:  4401
New frame:  4402
New frame:  4403
New frame:  4404
New frame:  4405
New frame:  4406
New frame:  4407
New frame:  4408
New frame:  4409
New frame:  4410
New frame:  4411
New frame:  4412
New frame:  4413
New frame:  4414
New frame:  4415
New frame:  4416
New frame:  4417
New frame:  4418
New frame:  4419
New frame:  4420
New frame:  4421
New frame:  4422
New frame:  4423
New frame:  4424
New frame:  4425
New frame:  4426
New frame:  4427
New frame:  4428
New frame:  4429
New frame:  4430
New frame:  4431
New frame:  4432
New frame:  4433
New frame:  4434
New frame:  4435
New frame:  4436
New frame:  4437
New frame:  4438
New frame:  4439
New frame:  4440
New frame:  4441
New frame:  4442
New frame:  4443
New frame:  4444
New frame:  4445
New frame:  4446
New frame:  4447
New frame:  4448
New frame:  4449
New frame:  4450
New frame:  4451
New frame:  4452


0
This browser does not support the video tag.


Done.
