# Installation Commands (you need to run these 6 cells before running anything else in this notebook!)

In [None]:
!pip install mujoco
!pip install mujoco_mjx
!pip install brax




In [None]:
#@title Check if MuJoCo installation was successful

from google.colab import files

import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags


Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl
Checking that the installation succeeded:
Installation successful.


In [None]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

Installing mediapy:


In [None]:
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict


import jax
from jax import numpy as jp # jp is jax numpy
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model


In [None]:
#@title Mount GDrive
from google.colab import drive
# Mount Google Drive (forcing a fresh mount on re-run); the mocap .bvh clips are
# read from under /content/drive/MyDrive.
_DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(_DRIVE_MOUNT_POINT, force_remount=True)

Mounted at /content/drive


In [None]:
!git clone https://github.com/effypelayotran/mujoco_resources.git

fatal: destination path 'mujoco_resources' already exists and is not an empty directory.


# Define Humanoid Env and Reward

In [None]:
#@title Parse the Motion Capture into qpos & qvels

from pathlib import Path
import os
from scipy.spatial.transform import Rotation as R
import numpy as np

# Parsed mocap clips, filled in by parse_bvh_file(): tasks[action_name] is a
# list of clips, each clip a list of per-frame dicts {joint_name: {...}} which
# bvh_to_mujoco() later extends with "qpos"/"qvel" entries.
tasks = {}

# Default location of the mocap dataset: one sub-folder per action, each
# holding .bvh clips.
MOCAP_FOLDER = Path('/content/drive/MyDrive/make_him_walk/mocap')

# Motion-channel floats per BVH frame. The per-joint layout below assumes 6 root
# channels (position + rotation) followed by 3 rotation channels per joint,
# i.e. 78 = 6 + 3 * 24 (presumably root + 24 joints -- confirm against the
# files' CHANNELS lines).
BVH_FLOATS_PER_FRAME = 78


def _bvh_values_to_frame(joints, values):
    """Build one frame dict {joint_name: {...}} from a flat list of BVH floats.

    The first joint (the root) gets "position" (np.ndarray, axes reordered) and
    "rotation" (a scipy Rotation). Every other joint gets "rotation" as an
    np.ndarray of radians with axes reordered the same way.
    """
    frame = {}
    for i, joint_name in enumerate(joints):
        if i == 0:
            # Root position: BVH axes (a, b, c) are reordered to (c, a, b).
            position = np.array(values[:3])[[2, 0, 1]]
            # NOTE(review): channels 3:6 are fed unchanged to an intrinsic 'YZX'
            # Euler sequence (kept from the original code); confirm this matches
            # the file's CHANNELS order after the axis remap.
            rotation = R.from_euler('YZX', values[3:6], degrees=True)
            frame[joint_name] = {"position": position, "rotation": rotation}
        else:
            start = (i + 1) * 3
            a, b, c = values[start:start + 3]
            # Same (a, b, c) -> (c, a, b) reorder; BVH angles are in degrees.
            frame[joint_name] = {"rotation": np.deg2rad([c, a, b])}
    return frame


def _read_bvh_clip(clip_path):
    """Parse one .bvh file into a list of frame dicts (velocities not yet set)."""
    joints = []
    clip = []
    pending = []
    in_motion = False
    with open(clip_path, 'r', encoding='utf-8') as bvh:
        for line in bvh:
            stripped = line.strip()
            if not stripped:
                continue
            if stripped.startswith(('ROOT', 'JOINT')):
                joints.append(stripped.split()[1])
            elif stripped.startswith('Frame Time:'):
                # Whitespace-separated motion floats follow this line.
                in_motion = True
            elif in_motion:
                pending.extend(map(float, stripped.split()))
                # Emit every complete frame buffered so far (a frame may span
                # several lines, or a line may complete more than one frame).
                while len(pending) >= BVH_FLOATS_PER_FRAME:
                    clip.append(_bvh_values_to_frame(joints, pending[:BVH_FLOATS_PER_FRAME]))
                    pending = pending[BVH_FLOATS_PER_FRAME:]
    return clip


def _add_velocities(clip):
    """Give every joint in every frame a "velocity": its change to the next frame.

    Root ("Hips") velocity is the relative rotation next * current^-1 as a
    rotation vector; other joints use the difference of their radian angles.
    Values are per frame (not divided by a timestep). The last frame reuses the
    previous frame's velocity; a single-frame clip gets zero velocities.
    """
    for frame, next_frame in zip(clip, clip[1:]):
        for joint, data in frame.items():
            if joint == "Hips":
                delta = next_frame[joint]["rotation"] * data["rotation"].inv()
                data["velocity"] = delta.as_rotvec()
            else:
                data["velocity"] = next_frame[joint]["rotation"] - data["rotation"]
    if len(clip) >= 2:
        for joint, data in clip[-1].items():
            data["velocity"] = clip[-2][joint]["velocity"]
    elif clip:
        # A lone frame has no neighbour to difference against.
        for data in clip[0].values():
            data["velocity"] = np.zeros(3)


def parse_bvh_file(mocap_folder=MOCAP_FOLDER):
    """Parse every action folder of .bvh clips into the global `tasks` dict.

    Each sub-directory of `mocap_folder` is an action; its .bvh files become
    clips (lists of frame dicts with per-joint rotation/velocity). Folders and
    clips are visited in sorted order so indices such as tasks[action][0] are
    deterministic. bvh_to_mujoco() is then run to attach "qpos"/"qvel" to every
    frame.

    Args:
      mocap_folder: directory containing one sub-folder per action.

    Returns:
      The module-level `tasks` dict (cleared and repopulated).
    """
    print("Parsing")
    mocap_folder = Path(mocap_folder)
    tasks.clear()
    for action in sorted(mocap_folder.iterdir()):
        if not action.is_dir():
            continue
        clips = []
        # Only .bvh files directly inside this action folder (not recursive).
        for clip_path in sorted(action.glob('*.bvh')):
            clip = _read_bvh_clip(clip_path)
            _add_velocities(clip)
            clips.append(clip)
        tasks[action.name] = clips
    # TODO: end-effector (hands/feet) world positions via forward kinematics
    # were prototyped here in commented-out code and are currently unused.
    bvh_to_mujoco()
    return tasks


# Unit vectors for the second hinge axis of each foot in the MuJoCo model; the
# (already axis-remapped) BVH foot rotation/velocity is projected onto these.
_RIGHT_FOOT_AXIS2 = np.array([1, 0, 0.5]) / np.linalg.norm(np.array([1, 0, 0.5]))
_LEFT_FOOT_AXIS2 = np.array([-1, 0, -0.5]) / np.linalg.norm(np.array([-1, 0, -0.5]))


def _hinge_values(frame, key):
    """Map one frame's BVH joint data to the 21 hinge-joint entries of the model.

    `key` is "rotation" (for qpos[7:]) or "velocity" (for qvel[6:]); both use the
    same joint order, component selection and sign flips.
    """
    spine = frame["Spine"][key]
    spine1 = frame["Spine1"][key]
    r_arm = frame["RightArm"][key]
    r_forearm = frame["RightForeArm"][key]
    l_arm = frame["LeftArm"][key]
    l_forearm = frame["LeftForeArm"][key]
    r_up_leg = frame["RightUpLeg"][key]
    r_leg = frame["RightLeg"][key]
    r_foot = frame["RightFoot"][key]
    l_up_leg = frame["LeftUpLeg"][key]
    l_leg = frame["LeftLeg"][key]
    l_foot = frame["LeftFoot"][key]
    return [
        # waist_lower: z rotation, then y rotation
        spine[2], spine[1],
        # torso y
        spine1[1],
        # upper_arm_right
        r_arm[0], -r_arm[1],
        # elbow_right
        r_forearm[2],
        # upper_arm_left
        l_arm[0], l_arm[1],
        # elbow_left (axis 0,-1,-1)
        -l_forearm[2],
        # thigh_right
        r_up_leg[0], r_up_leg[2], r_up_leg[1],
        # knee_right: negated because the axis is 0,-1,0
        -r_leg[1],
        # foot_right: y, then projection on the second axis (1, 0, .5)
        r_foot[1], np.dot(r_foot, _RIGHT_FOOT_AXIS2),
        # thigh_left
        -l_up_leg[0], -l_up_leg[2], l_up_leg[1],
        # knee_left: negated because the axis is 0,-1,0
        -l_leg[1],
        # foot_left: y, then projection on the second axis (-1, 0, -.5)
        l_foot[1], np.dot(l_foot, _LEFT_FOOT_AXIS2),
    ]


def bvh_to_mujoco(task_dict=None):
    """Attach MuJoCo-ordered "qpos" (28 values) and "qvel" (27 values) to frames.

    qpos = root position (3) + root quaternion in MuJoCo's [w, x, y, z] order (4)
    + 21 hinge angles. qvel = 3 zeros (root linear velocity is not implemented)
    + root rotation-vector velocity (3) + 21 hinge velocities.

    Args:
      task_dict: {task: [clip, ...]} as built by parse_bvh_file(); defaults to
        the module-level `tasks`. Frames are modified in place.
    """
    if task_dict is None:
        task_dict = tasks
    for task, clips in task_dict.items():
        print("TASK ", task)
        for clip in clips:
            for frame in clip:
                hips = frame["Hips"]
                # scipy returns [x, y, z, w]; MuJoCo expects [w, x, y, z].
                x, y, z, w = hips["rotation"].as_quat()
                frame["qpos"] = (
                    [hips["position"][0], hips["position"][1], hips["position"][2], w, x, y, z]
                    + _hinge_values(frame, "rotation"))
                root_vel = hips["velocity"]
                frame["qvel"] = (
                    [0, 0, 0, root_vel[0], root_vel[1], root_vel[2]]
                    + _hinge_values(frame, "velocity"))


#COMMENTING THIS OUT (WILL PROBABLY DELETE) SINCE WE'RE NOT USING QUATERNIONS ANYMORE
# def quat_from_xyz(x_angle=0.0, y_angle=0.0, z_angle=0.0):
# 	cx = jp.cos(x_angle / 2)
# 	sx = jp.sin(x_angle / 2)
# 	cy = jp.cos(y_angle / 2)
# 	sy = jp.sin(y_angle / 2)
# 	cz = jp.cos(z_angle / 2)
# 	sz = jp.sin(z_angle / 2)

# 	qx = jp.array([cx, sx, 0, 0])
# 	qy = jp.array([cy, 0, sy, 0])
# 	qz = jp.array([cz, 0, 0, sz])

# 	# Multiply quaternions: q = qx * (qy * qz)
# 	def quat_mul(q1, q2):
# 			w1, x1, y1, z1 = q1
# 			w2, x2, y2, z2 = q2
# 			return jp.array([
# 					w1*w2 - x1*x2 - y1*y2 - z1*z2,
# 					w1*x2 + x1*w2 + y1*z2 - z1*y2,
# 					w1*y2 - x1*z2 + y1*w2 + z1*x2,
# 					w1*z2 + x1*y2 - y1*x2 + z1*w2
# 			])

# 	q = quat_mul(qx, quat_mul(qy, qz))
# 	return q

# def mujoco_to_quaternion(qpos):

# 	quaternions=jp.array([])

# 	#hip data
# 	quaternions = jp.concatenate([quaternions, jp.array(qpos[:7])])

# 	#lower waist
# 	abdomen_z, abdomen_y = qpos[7:9]
# 	abdomen_quat = quat_from_xyz(y_angle=abdomen_y, z_angle=abdomen_z)
# 	quaternions = jp.concatenate([quaternions, abdomen_quat])


# 	#torso
# 	torso_y = qpos[9]
# 	torso_quat = quat_from_xyz(y_angle = torso_y)
# 	quaternions = jp.concatenate([quaternions, torso_quat])


# 	#upper right arm
# 	should1_right, should2_right = qpos[10:12]
# 	should_right_quat = quat_from_xyz(x_angle = should1_right, y_angle = should2_right)
# 	quaternions = jp.concatenate([quaternions, should_right_quat])


# 	#right elbow
# 	elbow_right = qpos[12]
# 	elbow_right_quat = quat_from_xyz(y_angle = elbow_right)
# 	quaternions = jp.concatenate([quaternions, elbow_right_quat])


# 	#upper left arm
# 	should1_left, should2_left = qpos[13:15]
# 	should_left_quat = quat_from_xyz(x_angle = should1_left, y_angle =should2_left)
# 	quaternions = jp.concatenate([quaternions, should_left_quat])


# 	#left elbow
# 	elbow_left = qpos[15]
# 	elbow_left_quat = quat_from_xyz(y_angle = elbow_left)
# 	quaternions = jp.concatenate([quaternions, elbow_left_quat])


# 	#right thigh
# 	hip_x_right, hip_z_right, hip_y_right = qpos[16:19]
# 	hip_right_quat = quat_from_xyz(x_angle=hip_x_right, y_angle=hip_y_right, z_angle=hip_z_right)
# 	quaternions = jp.concatenate([quaternions, hip_right_quat])


# 	#right shin
# 	knee_right = qpos[19]
# 	knee_right_quat = quat_from_xyz(y_angle=knee_right)
# 	quaternions = jp.concatenate([quaternions, knee_right_quat])


# 	#right foot
# 	ankle_y_right, ankle_x_right = qpos[20:22]
# 	ankle_right_quat = quat_from_xyz(x_angle=ankle_x_right, y_angle=ankle_y_right)
# 	quaternions = jp.concatenate([quaternions, ankle_right_quat])


# 	#left thigh
# 	hip_x_left, hip_z_left, hip_y_left = qpos[22:25]
# 	hip_left_quat = quat_from_xyz(x_angle=hip_x_left * -1.0, y_angle=hip_y_left, z_angle=hip_z_left * -1.0)
# 	quaternions = jp.concatenate([quaternions, hip_left_quat])


# 	#left shin
# 	knee_left = qpos[25]
# 	knee_left_quat = quat_from_xyz(y_angle=knee_left)
# 	quaternions = jp.concatenate([quaternions, knee_left_quat])


# 	#left foot
# 	ankle_y_left, ankle_x_left = qpos[26:28]
# 	ankle_left_quat = quat_from_xyz(x_angle=ankle_x_left * -1.0, y_angle=ankle_y_left)
# 	quaternions = jp.concatenate([quaternions, ankle_left_quat])

# 	print("QUATERNIONS LENGTH: ", len(quaternions))
# 	return quaternions





In [None]:
#@title View Mocap Only
# Model asset locations: the stock MJX humanoid bundled with the mujoco package,
# plus two variants from the cloned mujoco_resources repository.
_MUJOCO_RESOURCES_DIR = epath.Path('mujoco_resources')
HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'
TERRAIN_ROOT_PATH = _MUJOCO_RESOURCES_DIR / 'humanoid_terrain'
FIXED_ROOT_PATH = _MUJOCO_RESOURCES_DIR / 'humanoid_CMU_folder'
from jax import lax


class Humanoid(PipelineEnv):
  """MJX humanoid env that replays a mocap clip instead of simulating physics.

  Debug / visualisation variant: `step` ignores the action, re-initialises the
  pipeline state at the next mocap frame, returns a constant reward of 1 and
  never terminates. The forward/ctrl/healthy/heading settings are stored but
  not used by `step`.
  """

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=2.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(0.5, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      target_speed=1.0,
      heading_reward_weight=5.0,
      mocap_task='martial_arts',
      mocap_clip=0,
      **kwargs,
  ):
    """Load the humanoid model and the mocap clip to replay.

    Args:
      reset_noise_scale: half-width of the uniform noise added in `reset`.
      exclude_current_positions_from_observation: drop qpos[0:2] from obs.
      mocap_task: action folder name in the dict returned by parse_bvh_file().
      mocap_clip: index of the clip within that action.
      **kwargs: forwarded to PipelineEnv; `n_frames` defaults to 5 physics
        steps per control step and `backend` is forced to 'mjx'.
      Remaining arguments are stored on the instance for reward shaping.
    """
    # Other model options: HUMANOID_ROOT_PATH / 'humanoid.xml' (plain humanoid)
    # or TERRAIN_ROOT_PATH / 'humanoid.xml' (humanoid with height-field terrain).
    mj_model = mujoco.MjModel.from_xml_path(
        (FIXED_ROOT_PATH / 'humanoid_fixed_7.xml').as_posix())

    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    # Original rewards/costs.
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

    # Target-heading task parameters.
    self._target_speed = target_speed
    self._heading_weight = heading_reward_weight

    # Mocap clip to replay: each row of frames_qpos is one frame's full qpos.
    self.mocap_dict = parse_bvh_file()
    self.frames = self.mocap_dict[mocap_task][mocap_clip]
    frames_qpos = np.array([frame["qpos"] for frame in self.frames])
    print(frames_qpos.shape[0], ", ", frames_qpos.shape[1])
    self.frames_qpos = jp.array(frames_qpos)
    # Python-side counter kept for compatibility; the jit-safe per-episode
    # frame index is state.info['frame_count'].
    self.step_count = 0

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to a slightly perturbed initial state.

    qpos is qpos0 plus uniform noise and qvel is uniform noise, both in
    [-reset_noise_scale, reset_noise_scale]. A random unit 2D heading is stored
    as info['goal'] and the mocap frame counter starts at 0.
    """
    # NOTE: Python-side mutation; under jax.jit this only runs while tracing.
    self.step_count = 0
    rng, rng1, rng2, rng3 = jax.random.split(rng, 4)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    # Desired heading as a random 2D unit vector.
    theta = jax.random.uniform(rng3, (), minval=0.0, maxval=2 * jp.pi)
    d_star = jp.stack([jp.cos(theta), jp.sin(theta)])

    state_info = {
        'rng3': rng3,
        'goal': d_star,
        'v_xy': jp.zeros((2,), dtype=jp.float32),
        'frame_count': jp.array(0),
    }

    obs = self._get_obs(data, jp.zeros(self.sys.nu), state_info)
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }

    return State(data, obs, reward, done, metrics, state_info)

  def imit_reward(self, mo_pos, stt_pos):
    """Pose-imitation reward in (0, 1].

    Compares the joint entries after the 7-value root (position + quaternion):
    reward = exp(max(-2 * sum((mocap - sim)^2), -50)), so identical poses score
    1 and the exponent is clamped at -50.

    Args:
      mo_pos: mocap qpos vector.
      stt_pos: simulated qpos vector (same length as mo_pos).
    """
    mo_rot = jp.asarray(mo_pos)[7:]
    stt_rot = jp.asarray(stt_pos)[7:]
    pose_err = jp.sum(jp.square(mo_rot - stt_rot))
    return jp.exp(jp.maximum(-2.0 * pose_err, -50))

  def step(self, state: State, action: jp.ndarray) -> State:
    """Advances mocap playback by one frame; the action is ignored.

    The pipeline state is re-initialised with the qpos of mocap frame
    state.info['frame_count'] (keeping the previous qvel). Once the clip is
    exhausted the last frame is held.
    """
    current_frame = state.info['frame_count']
    last_frame = self.frames_qpos.shape[0] - 1
    # Out-of-range JAX gathers are clamped anyway; clamp explicitly so holding
    # the final frame is intentional and visible.
    frame_qpos = self.frames_qpos[jp.minimum(current_frame, last_frame)]
    data = self.pipeline_init(frame_qpos, state.pipeline_state.qvel)

    state_info = dict(state.info)
    state_info['frame_count'] = current_frame + 1

    obs = self._get_obs(data, action, state_info)

    # Same float dtype as the scalars created in reset(); a bare Python 1/False
    # would change the dtype of these State leaves between reset and step.
    reward = jp.array(1.0)
    done = jp.array(0.0)

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done, info=state_info
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray, state_info: dict[str, Any],
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles.

    Concatenates qpos (without qpos[0:2] when configured), qvel, cinert and cvel
    of every body except body 0, and qfrc_actuator. `action` and `state_info`
    are currently unused; external contact forces are excluded.
    """
    position = data.qpos
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    return jp.concatenate([
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
    ])

# Register the Humanoid class defined above under the name 'humanoid', then
# instantiate it through the Brax registry.
env_name = 'humanoid'
envs.register_environment(env_name, Humanoid)
env = envs.get_environment(env_name)

# JIT-compile reset/step so they run as compiled XLA programs (on the GPU here).
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

In [None]:
#@title Define Humanoid Env for Good Walk
# Model asset locations: the stock MJX humanoid bundled with the mujoco package,
# plus two variants from the cloned mujoco_resources repository.
_MUJOCO_RESOURCES_DIR = epath.Path('mujoco_resources')
HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'
TERRAIN_ROOT_PATH = _MUJOCO_RESOURCES_DIR / 'humanoid_terrain'
FIXED_ROOT_PATH = _MUJOCO_RESOURCES_DIR / 'humanoid_CMU_folder'
from jax import lax
from jax import random


def quat_inv(q):
    """Return the conjugate of quaternion q laid out as [w, x, y, z].

    For a unit quaternion the conjugate is also its inverse. `q` must unpack
    into exactly four components.
    """
    w, x, y, z = q
    negated_vector = [-x, -y, -z]
    return jp.array([w] + negated_vector)

def quat_mul(q1, q2):
    """Hamilton product q1 * q2 of two quaternions stored as [w, x, y, z]."""
    a_w, a_x, a_y, a_z = q1
    b_w, b_x, b_y, b_z = q2
    real = a_w*b_w - a_x*b_x - a_y*b_y - a_z*b_z
    i_part = a_w*b_x + a_x*b_w + a_y*b_z - a_z*b_y
    j_part = a_w*b_y - a_x*b_z + a_y*b_w + a_z*b_x
    k_part = a_w*b_z + a_x*b_y - a_y*b_x + a_z*b_w
    return jp.array([real, i_part, j_part, k_part])

def quat_to_rotvec(q):
    """Convert a quaternion [w, x, y, z] to an axis-angle (rotation) vector.

    The input is normalised, then flipped to the w >= 0 hemisphere. q and -q
    encode the same rotation, and the flip keeps the returned angle in [0, pi],
    which is the shortest rotation. Without the flip, w < 0 yields an angle in
    (pi, 2*pi], which overstates the rotation when the result is used as an
    error measure.

    Args:
      q: 1-D array of four components, scalar first. It need not be unit length.

    Returns:
      Array of three components: the rotation axis scaled by the angle in
      radians. Near-identity rotations (|xyz| <= 1e-6 after normalisation)
      return zeros.
    """
    unit_q = q / jp.linalg.norm(q)
    # Canonicalise the double cover so the angle is the shortest one.
    unit_q = jp.where(unit_q[0] < 0.0, -unit_q, unit_q)
    w, xyz = unit_q[0], unit_q[1:]
    sin_half_theta = jp.linalg.norm(xyz)
    angle = 2.0 * jp.arctan2(sin_half_theta, w)
    is_nontrivial = sin_half_theta > 1e-6
    # Divide by a safe value so the masked branch never produces inf/NaN
    # (which would otherwise leak into gradients through jp.where).
    safe_sin = jp.where(is_nontrivial, sin_half_theta, 1.0)
    axis = jp.where(is_nontrivial, xyz / safe_sin, jp.zeros_like(xyz))
    return angle * axis

class Humanoid(PipelineEnv):
  """MJX humanoid rewarded for heading toward +x, initialised from a mocap clip.

  Episodes start at a random frame of the "walking" clip returned by
  `parse_bvh_file()`. The observation is the standard humanoid observation
  followed by the reference (mocap) qpos/qvel of the frame the next step is
  scored against. That reference is zeros once the clip has run out.
  """

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=2.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(0.5, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      target_speed=1.0,
      heading_reward_weight=5.0,
      max_start_frame=4000,
      **kwargs,
  ):
    """Builds the MJX system and loads the reference motion.

    Args:
      forward_reward_weight: scale on the forward-direction reward term.
      ctrl_cost_weight: scale on the squared-action penalty.
      healthy_reward: per-step alive bonus.
      terminate_when_unhealthy: if True, leaving the healthy band sets done=1.
        Otherwise the alive bonus is simply withheld.
      healthy_z_range: (min_z, max_z) band for qpos[2]. The z of body 2
        (data.xpos[2][2]) must lie in the same band shifted up by 0.3.
      reset_noise_scale: half-width of the uniform noise added at reset.
      exclude_current_positions_from_observation: drop the first two qpos
        entries (presumably root x/y) from the observation.
      target_speed: stored for the (currently disabled) heading reward.
      heading_reward_weight: stored for the (currently disabled) heading reward.
      max_start_frame: exclusive upper bound on the random mocap start frame.
        It is additionally capped at the clip length.
      **kwargs: forwarded to `PipelineEnv`. `n_frames` defaults to 5, and
        `backend` is always forced to 'mjx'.
    """
    # Option 1: TRAIN WITHOUT HEIGHT FIELD -- plain humanoid
    # mj_model = mujoco.MjModel.from_xml_path(
    #     (HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix())

    # Option 2: TRAIN WITH HEIGHT FIELD -- humanoid with terrain
    # mj_model = mujoco.MjModel.from_xml_path(
    #     (TERRAIN_ROOT_PATH / 'humanoid.xml').as_posix())

    # Active option: fixed CMU-skeleton humanoid.
    mj_model = mujoco.MjModel.from_xml_path(
        (FIXED_ROOT_PATH / 'humanoid_fixed_7.xml').as_posix())

    # Conjugate-gradient solver capped at 6 (line-search) iterations.
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get('n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    # Original rewards/costs.
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

    # Heading-task parameters (the heading reward itself is disabled in step).
    self._target_speed = target_speed
    self._heading_weight = heading_reward_weight

    # Reference motion, one row per mocap frame.
    self.mocap_dict = parse_bvh_file()
    self.frames = self.mocap_dict["walking"][0]
    self.frames_qpos = jp.array(np.array([frame["qpos"] for frame in self.frames]))
    self.frames_qvel = jp.array(np.array([frame["qvel"] for frame in self.frames]))
    num_frames = self.frames_qpos.shape[0]
    print(num_frames, ", ", self.frames_qpos.shape[1])

    # JAX clamps out-of-range gather indices. A start bound larger than the
    # clip would therefore silently start episodes on the last frame, where
    # imitation reward is always 0. Cap the bound at the clip length.
    self._start_frame_limit = min(int(max_start_frame), num_frames)

    self.step_count = 0
    self.key = random.PRNGKey(42)  # Set the seed ONCE
    self.key, subkey = random.split(self.key)
    self.random_frame = random.randint(
        subkey, shape=(), minval=0, maxval=self._start_frame_limit)
    print("rand frame: ", self.random_frame)

  def _reference_frame(self, frame):
    """Returns (qpos, qvel) of mocap frame `frame`, or zeros past the clip.

    A frame counts as "in the clip" when frame < last index. This is the same
    gate the imitation reward uses.
    """
    last_index = self.frames_qpos.shape[0] - 1
    in_clip = frame < last_index
    idx = jp.clip(frame, 0, last_index)
    ref_qpos = jp.where(in_clip, self.frames_qpos[idx], 0.0)
    ref_qvel = jp.where(in_clip, self.frames_qvel[idx], 0.0)
    return ref_qpos, ref_qvel

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to a random frame of the reference clip.

    The first 7 qpos and first 6 qvel entries (the floating root, judging by
    the 7/6 split) come from qpos0/zero plus uniform noise. All remaining
    joint entries are copied from the sampled mocap frame.
    """
    # Independent keys for the start frame, each noise draw and the goal.
    _, frame_key, qpos_key, qvel_key, goal_key = jax.random.split(rng, 5)
    random_frame = random.randint(
        frame_key, shape=(), minval=0, maxval=self._start_frame_limit)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        qpos_key, (self.sys.nq,), minval=low, maxval=hi)
    qvel = jax.random.uniform(qvel_key, (self.sys.nv,), minval=low, maxval=hi)

    # RANDOMLY INITIALIZE TO SOME RANDOM FRAME IN THE MOCAP
    qpos = qpos.at[7:].set(self.frames_qpos[random_frame][7:])
    qvel = qvel.at[6:].set(self.frames_qvel[random_frame][6:])

    data = self.pipeline_init(qpos, qvel)

    # Heading goal: only read by the disabled heading reward, kept so the info
    # pytree layout stays unchanged.
    theta = jax.random.uniform(goal_key, (), minval=0.0, maxval=2 * jp.pi)
    d_star = jp.stack([jp.cos(theta), jp.sin(theta)])

    mocap_qpos, mocap_qvel = self._reference_frame(random_frame)
    state_info = {
        'rng3': goal_key,
        'goal': d_star,
        'v_xy': jp.zeros((2,), dtype=jp.float32),
        'frame_count': jp.array(random_frame),
        'mocap_qpos': mocap_qpos,
        'mocap_qvel': mocap_qvel,
    }

    obs = self._get_obs(data, jp.zeros(self.sys.nu), state_info)
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics, state_info)

  def imit_reward(self, mo_pos, stt_pos, mo_vel, stt_vel):
    """Imitation score in (0, 1] comparing the simulated state to a mocap frame.

    0.7 * exp(-2 * sum of squared qpos[7:] errors)
    + 0.3 * exp(-0.1 * sum of squared qvel[6:] errors).
    Each exponent is clamped at -50. Root position and orientation are ignored.
    """
    mo_pos = jp.asarray(mo_pos)
    stt_pos = jp.asarray(stt_pos)
    mo_vel = jp.asarray(mo_vel)
    stt_vel = jp.asarray(stt_vel)

    pose_err = jp.sum(jp.square(mo_pos[7:] - stt_pos[7:]))
    pose_reward = jp.exp(jp.maximum(-2.0 * pose_err, -50))

    vel_err = jp.sum(jp.square(mo_vel[6:] - stt_vel[6:]))
    vel_reward = jp.exp(jp.maximum(-0.1 * vel_err, -50))

    return 0.7 * pose_reward + 0.3 * vel_reward

  def step(self, state: State, action: jp.ndarray) -> State:
    """Advances the simulation by one control step and scores it.

    reward = forward_reward + healthy_reward + stiffness_reward - ctrl_cost.
    done is 1 when the state leaves the healthy height band and
    `terminate_when_unhealthy` is set.
    """
    # Mocap frame this step is scored against. It is advanced by one below.
    current_frame = state.info['frame_count']
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    # FOR DEBUGGING MOCAP DATA: replay the reference joints instead of physics.
    # data = data0
    # hip_pos = data.qpos[:7]
    # rest = self.frames_qpos[current_frame][7:]
    # new_qpos = jp.concatenate([hip_pos, rest])
    # data = self.pipeline_init(new_qpos, data.qvel)

    # Finite-difference velocity of subtree_com[1] over one control step.
    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt

    # Disabled alternative: target-heading task reward.
    # d_star    = state.info['goal']
    # v_xy   = velocity[:2] @ d_star
    # speed_err = jp.maximum(0.0, self._target_speed - v_xy)
    # heading_r = jp.exp(-2.5 * speed_err**2)
    # forward_reward = self._heading_weight * heading_r
    # state.info['v_xy'] = velocity[:2]

    # x-component of the unit velocity direction, i.e. the cosine between the
    # COM motion and +x. It rewards direction, not speed. The epsilon stops a
    # zero velocity from turning the reward into NaN.
    forward_reward = self._forward_reward_weight * (
        velocity[0] / (jp.linalg.norm(velocity) + 1e-6))

    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
    # Body 2 must also sit in the band shifted up by 0.3.
    is_healthy = jp.where(data.xpos[2][2] < (min_z + .3), 0.0, is_healthy)
    is_healthy = jp.where(data.xpos[2][2] > (max_z + .3), 0.0, is_healthy)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    # Imitation score against the current reference frame (0 past the clip).
    last_index = self.frames_qpos.shape[0] - 1
    ref_qpos, ref_qvel = self._reference_frame(current_frame)
    imit_reward = jp.where(
        current_frame < last_index,
        self.imit_reward(ref_qpos, data.qpos, ref_qvel, data.qvel),
        0.0,
    )

    # Advance the clip and publish the next reference in the *returned* info.
    # This is computed directly rather than assigned inside a lax.cond closure,
    # because such an assignment would never reach this scope.
    frame_count = current_frame + 1
    state_info = dict(state.info)
    state_info['frame_count'] = frame_count
    state_info['mocap_qpos'], state_info['mocap_qvel'] = (
        self._reference_frame(frame_count))

    obs = self._get_obs(data, action, state_info)

    # Bonus for keeping qpos[7:10] near zero (torso bending, per the original
    # comment): exp(-2 * ||qpos[7:10]||^2).
    stiffness_reward = jp.exp(-2.0 * jp.sum(jp.square(data.qpos[7:10])))

    # NOTE(review): imit_reward is computed but not added to the reward in
    # this variant -- confirm this is intentional.
    reward = forward_reward + healthy_reward + stiffness_reward - ctrl_cost

    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done, info=state_info
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray, state_info: dict[str, Any],
  ) -> jp.ndarray:
    """Builds the flat policy observation.

    Layout: qpos (first two entries dropped when
    `_exclude_current_positions_from_observation` is set), qvel, cinert and
    cvel of every body row except row 0, qfrc_actuator, then the reference
    mocap qpos/qvel stored in `state_info`. External contact forces are
    excluded. `action` is unused.
    """
    position = data.qpos
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    return jp.concatenate([
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
        state_info['mocap_qpos'],
        state_info['mocap_qvel'],
        # state_info['goal'],
        # state_info['v_xy'],
    ])


# Register the Humanoid class defined above under the name 'humanoid', then
# build an instance of it through the Brax environment registry.
env_name = 'humanoid'
envs.register_environment(env_name, Humanoid)
env = envs.get_environment(env_name)

# JIT-compile reset/step once so later rollouts run as compiled XLA code
# (on the GPU when the runtime has one).
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

In [None]:
#@title Define Humanoid Env for Walking in Circle
# Model locations: the stock MJX humanoid bundled with the mujoco package and
# the local CMU-skeleton model folder (relative to the CWD).
HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx' / 'test_data' / 'humanoid'
FIXED_ROOT_PATH = epath.Path('mujoco_resources') / 'humanoid_CMU_folder'
from jax import lax
from jax import random


def quat_inv(q):
    """Return the conjugate of quaternion q laid out as [w, x, y, z].

    For a unit quaternion the conjugate is also its inverse. `q` must unpack
    into exactly four components.
    """
    w, x, y, z = q
    negated_vector = [-x, -y, -z]
    return jp.array([w] + negated_vector)

def quat_mul(q1, q2):
    """Hamilton product q1 * q2 of two quaternions stored as [w, x, y, z]."""
    a_w, a_x, a_y, a_z = q1
    b_w, b_x, b_y, b_z = q2
    real = a_w*b_w - a_x*b_x - a_y*b_y - a_z*b_z
    i_part = a_w*b_x + a_x*b_w + a_y*b_z - a_z*b_y
    j_part = a_w*b_y - a_x*b_z + a_y*b_w + a_z*b_x
    k_part = a_w*b_z + a_x*b_y - a_y*b_x + a_z*b_w
    return jp.array([real, i_part, j_part, k_part])

def quat_to_rotvec(q):
    """Convert a quaternion [w, x, y, z] to an axis-angle (rotation) vector.

    The input is normalised, then flipped to the w >= 0 hemisphere. q and -q
    encode the same rotation, and the flip keeps the returned angle in [0, pi],
    which is the shortest rotation. Without the flip, w < 0 yields an angle in
    (pi, 2*pi], which overstates the rotation when the result is used as an
    error measure.

    Args:
      q: 1-D array of four components, scalar first. It need not be unit length.

    Returns:
      Array of three components: the rotation axis scaled by the angle in
      radians. Near-identity rotations (|xyz| <= 1e-6 after normalisation)
      return zeros.
    """
    unit_q = q / jp.linalg.norm(q)
    # Canonicalise the double cover so the angle is the shortest one.
    unit_q = jp.where(unit_q[0] < 0.0, -unit_q, unit_q)
    w, xyz = unit_q[0], unit_q[1:]
    sin_half_theta = jp.linalg.norm(xyz)
    angle = 2.0 * jp.arctan2(sin_half_theta, w)
    is_nontrivial = sin_half_theta > 1e-6
    # Divide by a safe value so the masked branch never produces inf/NaN
    # (which would otherwise leak into gradients through jp.where).
    safe_sin = jp.where(is_nontrivial, sin_half_theta, 1.0)
    axis = jp.where(is_nontrivial, xyz / safe_sin, jp.zeros_like(xyz))
    return angle * axis

class Humanoid(PipelineEnv):
  """MJX humanoid that walks along its facing direction while imitating mocap.

  Episodes start at a random frame of the "walking" clip returned by
  `parse_bvh_file()`, with the full state copied from that frame. The
  observation is the standard humanoid observation followed by the reference
  (mocap) qpos/qvel of the frame the next step is scored against. That
  reference is zeros once the clip has run out.
  """

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=2.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(0.5, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      target_speed=1.0,
      heading_reward_weight=5.0,
      max_start_frame=4000,
      **kwargs,
  ):
    """Builds the MJX system and loads the reference motion.

    Args:
      forward_reward_weight: scale on the facing/motion alignment reward.
      ctrl_cost_weight: scale on the squared-action penalty.
      healthy_reward: per-step alive bonus.
      terminate_when_unhealthy: if True, leaving the healthy band sets done=1.
        Otherwise the alive bonus is simply withheld.
      healthy_z_range: (min_z, max_z) band for qpos[2]. The z of body 2
        (data.xpos[2][2]) must lie in the same band shifted up by 0.3.
      reset_noise_scale: half-width of the uniform reset noise. See the note in
        `reset`: the noise is currently overwritten.
      exclude_current_positions_from_observation: drop the first two qpos
        entries (presumably root x/y) from the observation.
      target_speed: stored for the (currently disabled) heading reward.
      heading_reward_weight: stored for the (currently disabled) heading reward.
      max_start_frame: exclusive upper bound on the random mocap start frame.
        It is additionally capped at the clip length.
      **kwargs: forwarded to `PipelineEnv`. `n_frames` defaults to 5, and
        `backend` is always forced to 'mjx'.
    """
    # Option 1: TRAIN WITHOUT HEIGHT FIELD -- plain humanoid
    # mj_model = mujoco.MjModel.from_xml_path(
    #     (HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix())

    # Option 2: TRAIN WITH HEIGHT FIELD -- humanoid with terrain
    # mj_model = mujoco.MjModel.from_xml_path(
    #     (TERRAIN_ROOT_PATH / 'humanoid.xml').as_posix())

    # Active option: fixed CMU-skeleton humanoid.
    mj_model = mujoco.MjModel.from_xml_path(
        (FIXED_ROOT_PATH / 'humanoid_fixed_7.xml').as_posix())

    # Conjugate-gradient solver capped at 6 (line-search) iterations.
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get('n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    # Original rewards/costs.
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

    # Heading-task parameters (the heading reward itself is disabled in step).
    self._target_speed = target_speed
    self._heading_weight = heading_reward_weight

    # Reference motion, one row per mocap frame.
    self.mocap_dict = parse_bvh_file()
    self.frames = self.mocap_dict["walking"][0]
    self.frames_qpos = jp.array(np.array([frame["qpos"] for frame in self.frames]))
    self.frames_qvel = jp.array(np.array([frame["qvel"] for frame in self.frames]))
    num_frames = self.frames_qpos.shape[0]
    print(num_frames, ", ", self.frames_qpos.shape[1])

    # JAX clamps out-of-range gather indices. A start bound larger than the
    # clip would therefore silently start episodes on the last frame, where
    # imitation reward is always 0. Cap the bound at the clip length.
    self._start_frame_limit = min(int(max_start_frame), num_frames)

    self.step_count = 0
    self.key = random.PRNGKey(42)  # Set the seed ONCE
    self.key, subkey = random.split(self.key)
    self.random_frame = random.randint(
        subkey, shape=(), minval=0, maxval=self._start_frame_limit)
    print("rand frame: ", self.random_frame)

  def _reference_frame(self, frame):
    """Returns (qpos, qvel) of mocap frame `frame`, or zeros past the clip.

    A frame counts as "in the clip" when frame < last index. This is the same
    gate the imitation reward uses.
    """
    last_index = self.frames_qpos.shape[0] - 1
    in_clip = frame < last_index
    idx = jp.clip(frame, 0, last_index)
    ref_qpos = jp.where(in_clip, self.frames_qpos[idx], 0.0)
    ref_qvel = jp.where(in_clip, self.frames_qvel[idx], 0.0)
    return ref_qpos, ref_qvel

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to the full state of a random mocap frame."""
    # Independent keys for the start frame, each noise draw and the goal.
    _, frame_key, qpos_key, qvel_key, goal_key = jax.random.split(rng, 5)
    random_frame = random.randint(
        frame_key, shape=(), minval=0, maxval=self._start_frame_limit)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        qpos_key, (self.sys.nq,), minval=low, maxval=hi)
    qvel = jax.random.uniform(qvel_key, (self.sys.nv,), minval=low, maxval=hi)

    # RANDOMLY INITIALIZE TO SOME RANDOM FRAME IN THE MOCAP (root included).
    # NOTE(review): this overwrites every entry, so the noise above (and
    # reset_noise_scale) currently has no effect -- confirm intent.
    qpos = qpos.at[:].set(self.frames_qpos[random_frame])
    qvel = qvel.at[:].set(self.frames_qvel[random_frame])

    data = self.pipeline_init(qpos, qvel)

    # Heading goal: only read by the disabled heading reward, kept so the info
    # pytree layout stays unchanged.
    theta = jax.random.uniform(goal_key, (), minval=0.0, maxval=2 * jp.pi)
    d_star = jp.stack([jp.cos(theta), jp.sin(theta)])

    mocap_qpos, mocap_qvel = self._reference_frame(random_frame)
    state_info = {
        'rng3': goal_key,
        'goal': d_star,
        'v_xy': jp.zeros((2,), dtype=jp.float32),
        'frame_count': jp.array(random_frame),
        'mocap_qpos': mocap_qpos,
        'mocap_qvel': mocap_qvel,
    }

    obs = self._get_obs(data, jp.zeros(self.sys.nu), state_info)
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics, state_info)

  def imit_reward(self, mo_pos, stt_pos, mo_vel, stt_vel):
    """Imitation score in (0, 1] comparing the simulated state to a mocap frame.

    pose: squared rotation-vector error of the qpos[3:7] quaternions
      (mocap * conj(sim)) plus squared qpos[7:] errors, mapped via exp(-2*err).
    vel:  squared qvel[3:] errors, mapped via exp(-0.1*err).
    com:  distance between qpos[:3] positions, mapped via exp(-10*d).
    Result = 0.6*pose + 0.3*vel + 0.1*com. Every exponent is clamped at -50.
    """
    mo_pos = jp.asarray(mo_pos)
    stt_pos = jp.asarray(stt_pos)
    mo_vel = jp.asarray(mo_vel)
    stt_vel = jp.asarray(stt_vel)

    root_rot_err = quat_to_rotvec(quat_mul(mo_pos[3:7], quat_inv(stt_pos[3:7])))
    pose_err = (jp.sum(jp.square(root_rot_err))
                + jp.sum(jp.square(mo_pos[7:] - stt_pos[7:])))
    pose_reward = jp.exp(jp.maximum(-2.0 * pose_err, -50))

    vel_err = jp.sum(jp.square(mo_vel[3:] - stt_vel[3:]))
    vel_reward = jp.exp(jp.maximum(-0.1 * vel_err, -50))

    root_dist = jp.linalg.norm(mo_pos[:3] - stt_pos[:3])
    center_mass_reward = jp.exp(jp.maximum(-10.0 * root_dist, -50))

    return (0.6 * pose_reward) + (0.3 * vel_reward) + (0.1 * center_mass_reward)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Advances the simulation by one control step and scores it.

    reward = forward_reward + healthy_reward + stiffness_reward
             + 3 * imit_reward - ctrl_cost.
    done is 1 when the state leaves the healthy height band and
    `terminate_when_unhealthy` is set.
    """
    # Mocap frame this step is scored against. It is advanced by one below.
    current_frame = state.info['frame_count']
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    # FOR DEBUGGING MOCAP DATA: replay the reference joints instead of physics.
    # data = data0
    # hip_pos = data.qpos[:7]
    # rest = self.frames_qpos[current_frame][7:]
    # new_qpos = jp.concatenate([hip_pos, rest])
    # data = self.pipeline_init(new_qpos, data.qvel)

    # Finite-difference velocity of subtree_com[1] over one control step.
    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt

    # Disabled alternative: target-heading task reward.
    # d_star    = state.info['goal']
    # v_xy   = velocity[:2] @ d_star
    # speed_err = jp.maximum(0.0, self._target_speed - v_xy)
    # heading_r = jp.exp(-2.5 * speed_err**2)
    # forward_reward = self._heading_weight * heading_r
    # state.info['v_xy'] = velocity[:2]

    # Facing direction: body-frame +x rotated by the first body's orientation,
    # projected onto the ground plane and normalised.
    hip_quat = data.x.rot[0]
    forward_world = math.rotate(jp.array([1.0, 0.0, 0.0]), hip_quat)
    f_xy = forward_world[:2]
    f_xy = f_xy / (jp.linalg.norm(f_xy) + 1e-6)
    # Cosine between horizontal COM motion and facing direction.
    v_xy = velocity[:2] / (jp.linalg.norm(velocity[:2]) + 1e-6)
    forward_reward = self._forward_reward_weight * (v_xy @ f_xy)

    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
    # Body 2 must also sit in the band shifted up by 0.3.
    is_healthy = jp.where(data.xpos[2][2] < (min_z + .3), 0.0, is_healthy)
    is_healthy = jp.where(data.xpos[2][2] > (max_z + .3), 0.0, is_healthy)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    # Imitation score against the current reference frame (0 past the clip).
    last_index = self.frames_qpos.shape[0] - 1
    ref_qpos, ref_qvel = self._reference_frame(current_frame)
    imit_reward = jp.where(
        current_frame < last_index,
        self.imit_reward(ref_qpos, data.qpos, ref_qvel, data.qvel),
        0.0,
    )

    # Advance the clip and publish the next reference in the returned info.
    # This is computed directly rather than assigned inside a lax.cond closure,
    # because such an assignment would never reach this scope and the
    # observation would only ever see zeros.
    frame_count = current_frame + 1
    state_info = dict(state.info)
    state_info['frame_count'] = frame_count
    state_info['mocap_qpos'], state_info['mocap_qvel'] = (
        self._reference_frame(frame_count))

    obs = self._get_obs(data, action, state_info)

    # Bonus for keeping qpos[7:10] near zero (torso bending, per the original
    # comment): exp(-2 * ||qpos[7:10]||^2).
    stiffness_reward = jp.exp(-2.0 * jp.sum(jp.square(data.qpos[7:10])))

    reward = (forward_reward + healthy_reward + stiffness_reward
              + (3.0 * imit_reward) - ctrl_cost)

    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done, info=state_info
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray, state_info: dict[str, Any],
  ) -> jp.ndarray:
    """Builds the flat policy observation.

    Layout: qpos (first two entries dropped when
    `_exclude_current_positions_from_observation` is set), qvel, cinert and
    cvel of every body row except row 0, qfrc_actuator, then the reference
    mocap qpos/qvel stored in `state_info`. External contact forces are
    excluded. `action` is unused.
    """
    position = data.qpos
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    return jp.concatenate([
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
        state_info['mocap_qpos'],
        state_info['mocap_qvel'],
        # state_info['goal'],
        # state_info['v_xy'],
    ])


# Register the Humanoid class defined above under the name 'humanoid', then
# build an instance of it through the Brax environment registry.
env_name = 'humanoid'
envs.register_environment(env_name, Humanoid)
env = envs.get_environment(env_name)

# JIT-compile reset/step once so later rollouts run as compiled XLA code
# (on the GPU when the runtime has one).
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

In [None]:
#@title Define Humanoid Env for Martial Arts
# Model locations: the stock MJX humanoid bundled with the mujoco package and
# the local CMU-skeleton model folder (relative to the CWD).
HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx' / 'test_data' / 'humanoid'
FIXED_ROOT_PATH = epath.Path('mujoco_resources') / 'humanoid_CMU_folder'
from jax import lax
from jax import random


def quat_inv(q):
    """Return the conjugate of quaternion q laid out as [w, x, y, z].

    For a unit quaternion the conjugate is also its inverse. `q` must unpack
    into exactly four components.
    """
    w, x, y, z = q
    negated_vector = [-x, -y, -z]
    return jp.array([w] + negated_vector)

def quat_mul(q1, q2):
    """Hamilton product q1 * q2 of two quaternions stored as [w, x, y, z]."""
    a_w, a_x, a_y, a_z = q1
    b_w, b_x, b_y, b_z = q2
    real = a_w*b_w - a_x*b_x - a_y*b_y - a_z*b_z
    i_part = a_w*b_x + a_x*b_w + a_y*b_z - a_z*b_y
    j_part = a_w*b_y - a_x*b_z + a_y*b_w + a_z*b_x
    k_part = a_w*b_z + a_x*b_y - a_y*b_x + a_z*b_w
    return jp.array([real, i_part, j_part, k_part])

def quat_to_rotvec(q):
    """Convert a quaternion [w, x, y, z] to an axis-angle (rotation) vector.

    The input is normalised, then flipped to the w >= 0 hemisphere. q and -q
    encode the same rotation, and the flip keeps the returned angle in [0, pi],
    which is the shortest rotation. Without the flip, w < 0 yields an angle in
    (pi, 2*pi], which overstates the rotation when the result is used as an
    error measure.

    Args:
      q: 1-D array of four components, scalar first. It need not be unit length.

    Returns:
      Array of three components: the rotation axis scaled by the angle in
      radians. Near-identity rotations (|xyz| <= 1e-6 after normalisation)
      return zeros.
    """
    unit_q = q / jp.linalg.norm(q)
    # Canonicalise the double cover so the angle is the shortest one.
    unit_q = jp.where(unit_q[0] < 0.0, -unit_q, unit_q)
    w, xyz = unit_q[0], unit_q[1:]
    sin_half_theta = jp.linalg.norm(xyz)
    angle = 2.0 * jp.arctan2(sin_half_theta, w)
    is_nontrivial = sin_half_theta > 1e-6
    # Divide by a safe value so the masked branch never produces inf/NaN
    # (which would otherwise leak into gradients through jp.where).
    safe_sin = jp.where(is_nontrivial, sin_half_theta, 1.0)
    axis = jp.where(is_nontrivial, xyz / safe_sin, jp.zeros_like(xyz))
    return angle * axis

class Humanoid(PipelineEnv):
  """MJX humanoid env that imitates the 'martial_arts' motion-capture clip.

  Episodes start at a random frame of the clip. The per-step reward is
  alive bonus + torso-stiffness bonus + 6 * imitation reward - control cost.
  The observation is extended with the mocap reference (qpos and qvel) that
  the next step will be scored against.
  """

  # Exclusive upper bound on the clip frame an episode may start from.
  _MAX_START_FRAME = 2000

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=2.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(0.5, 10.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      target_speed = 1.0,
      heading_reward_weight= 5.0,
      **kwargs,
  ):
    """Loads the fixed CMU humanoid model and the reference mocap clip.

    `forward_reward_weight`, `target_speed` and `heading_reward_weight` are
    stored but unused by this task's reward. `reset_noise_scale` is stored
    but reset() copies the mocap frame exactly.
    """
    # Alternatives: (HUMANOID_ROOT_PATH / 'humanoid.xml') for the plain MJX
    # humanoid, or TERRAIN_ROOT_PATH for the height-field variant.
    mj_model = mujoco.MjModel.from_xml_path(
        (FIXED_ROOT_PATH / 'humanoid_fixed_7.2.xml').as_posix())

    # Conjugate-gradient solver with few iterations keeps MJX steps cheap.
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    # Reward / termination parameters.
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

    # Heading-task parameters (unused by the martial-arts reward).
    self._target_speed = target_speed
    self._heading_weight = heading_reward_weight

    # Reference clip: one row per mocap frame.
    self.mocap_dict = parse_bvh_file()
    self.frames = self.mocap_dict["martial_arts"][0]
    frames_qpos = np.array([frame["qpos"] for frame in self.frames])
    frames_qvel = np.array([frame["qvel"] for frame in self.frames])
    print(frames_qpos.shape[0], ", ", frames_qpos.shape[1])
    self.frames_qpos = jp.array(frames_qpos)
    self.frames_qvel = jp.array(frames_qvel)

    # Legacy attributes; reset() draws its own start frame from its rng.
    self.step_count = 0
    self.key = random.PRNGKey(42)
    self.key, subkey = random.split(self.key)
    self.random_frame = random.randint(
        subkey, shape=(), minval=0, maxval=self._MAX_START_FRAME)

  def reset(self, rng: jp.ndarray) -> State:
    """Starts an episode at a uniformly sampled frame of the mocap clip.

    qpos/qvel are copied from that frame (no reset noise is added). A random
    2D unit heading goal is also sampled into info['goal'].
    """
    # rng2 is unused but still split off so rng3 (the goal key) is unchanged.
    rng, rng1, rng2, rng3 = jax.random.split(rng, 4)

    # Start strictly before the last frame, so the first step still has an
    # imitation target (step() gives zero imitation reward at the last frame).
    max_frame = self.frames_qpos.shape[0] - 1
    max_start = max(1, min(self._MAX_START_FRAME, max_frame))
    random_frame = random.randint(rng1, shape=(), minval=0, maxval=max_start)

    qpos = self.frames_qpos[random_frame]
    qvel = self.frames_qvel[random_frame]
    data = self.pipeline_init(qpos, qvel)

    theta = jax.random.uniform(rng3, (), minval=0.0, maxval=2 * jp.pi)
    d_star = jp.stack([jp.cos(theta), jp.sin(theta)])

    state_info = {
        'rng3': rng3,
        'goal': d_star,
        'v_xy': jp.zeros((2,), dtype=jp.float32),
        'frame_count': jp.array(random_frame),
        # Reference for the frame stored in 'frame_count' (the next target).
        'mocap_qpos': qpos,
        'mocap_qvel': qvel,
    }

    obs = self._get_obs(data, jp.zeros(self.sys.nu), state_info)
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics, state_info)

  def imit_reward(self, mo_pos, stt_pos, mo_vel, stt_vel):
    """Imitation reward comparing the simulated state with one mocap frame.

    Args:
      mo_pos, stt_pos: mocap and simulated qpos. Entries 0:3 are treated as
        the root (hip) position, 3:7 as the root quaternion [w, x, y, z] and
        7: as the remaining joint coordinates.
      mo_vel, stt_vel: mocap and simulated qvel; entries from index 3 on are
        compared.

    Returns:
      0.6 * exp(-2 * pose_err) + 0.3 * exp(-0.1 * vel_err)
      + 0.1 * exp(-10 * root_dist), each exponent clipped at -50, where
      pose_err = |rotvec(root quat diff)|^2 + sum of squared joint-coordinate
      differences, vel_err = sum of squared velocity differences and
      root_dist = Euclidean distance between root positions.
    """
    mo_pos = jp.asarray(mo_pos)
    stt_pos = jp.asarray(stt_pos)
    mo_vel = jp.asarray(mo_vel)
    stt_vel = jp.asarray(stt_vel)

    # Pose: root orientation error as a rotation vector + joint coordinates.
    root_rot_diff = quat_mul(mo_pos[3:7], quat_inv(stt_pos[3:7]))
    pose_err = jp.sum(quat_to_rotvec(root_rot_diff) ** 2)
    pose_err += jp.sum((mo_pos[7:] - stt_pos[7:]) ** 2)
    # Zero error -> exp(0) = 1; larger error -> smaller reward.
    pose_reward = jp.exp(jp.maximum(-2.0 * pose_err, -50))

    vel_err = jp.sum((mo_vel[3:] - stt_vel[3:]) ** 2.0)
    vel_reward = jp.exp(jp.maximum(-0.1 * vel_err, -50))

    root_dist = jp.linalg.norm(mo_pos[:3] - stt_pos[:3])
    center_mass_reward = jp.exp(jp.maximum(-10.0 * root_dist, -50))

    return (0.6 * pose_reward) + (0.3 * vel_reward) + (0.1 * center_mass_reward)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one control step and scores it against the mocap clip.

    The simulated pose after the step is compared with the clip frame stored
    in state.info['frame_count'] (zero imitation reward from the last frame
    on). The returned info advances 'frame_count' by one and stores the
    reference for the next step in 'mocap_qpos'/'mocap_qvel' (zeros once the
    clip is exhausted), matching reset(); the observation is built from it.
    """
    current_frame = state.info['frame_count']
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    # Velocity of body 1's subtree centre of mass (reported in metrics only).
    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt

    # The martial-arts task has no locomotion term.
    forward_reward = jp.zeros(())

    # Unhealthy when the root height qpos[2] or body 2's height leaves the
    # configured range (body 2's bounds are shifted up by 0.3).
    min_z, max_z = self._healthy_z_range
    root_z = data.q[2]
    body2_z = data.xpos[2][2]
    is_healthy = jp.where(root_z < min_z, 0.0, 1.0)
    is_healthy = jp.where(root_z > max_z, 0.0, is_healthy)
    is_healthy = jp.where(body2_z < (min_z + .3), 0.0, is_healthy)
    is_healthy = jp.where(body2_z > (max_z + .3), 0.0, is_healthy)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    # Imitation reward against the current clip frame. Under vmap a lax.cond
    # lowers to a select anyway, so jp.where is equivalent here.
    max_frame = self.frames_qpos.shape[0] - 1
    imit_reward = jp.where(
        current_frame < max_frame,
        self.imit_reward(
            self.frames_qpos[current_frame], data.qpos,
            self.frames_qvel[current_frame], data.qvel),
        0.0,
    )

    # Advance the clip and store the next target (previously these were left
    # as zeros because the closure's assignment shadowed the outer names).
    frame_count = current_frame + 1
    has_target = frame_count < max_frame
    state_info = dict(state.info)
    state_info['frame_count'] = frame_count
    state_info['mocap_qpos'] = jp.where(
        has_target, self.frames_qpos[frame_count], 0.0)
    state_info['mocap_qvel'] = jp.where(
        has_target, self.frames_qvel[frame_count], 0.0)

    obs = self._get_obs(data, action, state_info)

    # Bonus for keeping qpos[7:10] (the torso rotation) near zero.
    stiffness_reward = jp.exp(-2.0 * jp.sum(jp.square(data.qpos[7:10])))

    reward = healthy_reward + stiffness_reward + (6.0 * imit_reward) - ctrl_cost

    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done, info=state_info
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray, state_info: dict[str, Any],
  ) -> jp.ndarray:
    """Builds the observation vector.

    Concatenates qpos (without its first two entries when
    exclude_current_positions_from_observation is set), qvel, the flattened
    cinert and cvel of every body except body 0, actuator forces, and the
    mocap reference in state_info. `action` is unused; external contact
    forces are not included.
    """
    position = data.qpos
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    return jp.concatenate([
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
        state_info['mocap_qpos'],
        state_info['mocap_qvel'],
    ])


# Register the Humanoid class defined in this cell under the name 'humanoid'
# and build an instance of it.
env_name = 'humanoid'
envs.register_environment(env_name, Humanoid)
env = envs.get_environment(env_name)

# JIT-compile reset/step so rollouts run as compiled XLA programs (GPU in Colab).
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

# Train


In [None]:
#@title Train
# PPO run: 5M environment steps over 3072 parallel envs, evaluated 5 times.
train_fn = functools.partial(
    ppo.train,
    num_timesteps=5_000_000,
    num_evals=5,
    reward_scaling=0.1,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=10,
    num_minibatches=24,
    num_updates_per_batch=8,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-3,
    num_envs=3072,
    batch_size=512,
    seed=0,
)

# Evaluation history filled in by `progress`; times[0] is the start time.
x_data, y_data, ydataerr = [], [], []
times = [datetime.now()]

max_y, min_y = 13000, 0

def progress(num_steps, metrics):
  """Records and prints one evaluation reported by ppo.train."""
  times.append(datetime.now())
  reward = metrics['eval/episode_reward']
  reward_std = metrics['eval/episode_reward_std']
  x_data.append(num_steps)
  y_data.append(reward)
  ydataerr.append(reward_std)
  print(f"Step: {num_steps} | Eval Reward: {reward:.2f} ± {reward_std:.2f}")

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

# The gap up to the first progress call is reported as compile ("jit") time.
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

In [None]:
#@title Train whilst Saving Checkpoints
# ppo.train passes the current params to save_policy_params_fn, which writes
# them to ckpt_path/<step>.
ckpt_path = epath.Path('/tmp/humanoid_imitating_hips_120_martial_arts_7.2/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def save_policy_params_fn(current_step, make_policy, params):
  """Saves `params` as an Orbax PyTree checkpoint named after the step."""
  checkpointer = ocp.PyTreeCheckpointer()
  args_for_save = orbax_utils.save_args_from_target(params)
  target_dir = ckpt_path / f'{current_step}'
  checkpointer.save(target_dir, params, force=True, save_args=args_for_save)

train_fn = functools.partial(
    ppo.train,
    num_timesteps=10_000_000,
    num_evals=20,
    reward_scaling=0.1,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=10,
    num_minibatches=24,
    num_updates_per_batch=8,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-3,
    num_envs=3072,
    batch_size=512,
    policy_params_fn=save_policy_params_fn,
    seed=0,
)

# Evaluation history filled in by `progress`; times[0] is the start time.
x_data, y_data, ydataerr = [], [], []
times = [datetime.now()]

max_y, min_y = 13000, 0

def progress(num_steps, metrics):
  """Records and prints one evaluation reported by ppo.train."""
  times.append(datetime.now())
  reward = metrics['eval/episode_reward']
  reward_std = metrics['eval/episode_reward_std']
  x_data.append(num_steps)
  y_data.append(reward)
  ydataerr.append(reward_std)
  print(f"Step: {num_steps} | Eval Reward: {reward:.2f} ± {reward_std:.2f}")

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

# The gap up to the first progress call is reported as compile ("jit") time.
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

In [None]:
#@title Save Checkpoints to Google Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
!cp -r /tmp/humanoid_imitating_hips_120_martial_arts_7.2/ckpts /content/drive/MyDrive/humanoid_72_martial_arts

Mounted at /content/drive


In [None]:
#@title Train from the Last Checkpoint
# Resume PPO from the newest checkpoint on Drive; new checkpoints are written
# locally under save_path_for_training_from_checkpoint.
ckpt_path = epath.Path('/content/drive/MyDrive/humanoid_72_martial_arts')

# Checkpoints are stored one directory per training step, named by the step
# number. Ignore any other entry (e.g. a nested 'ckpts' folder left by an
# earlier `cp -r`) instead of crashing in int(), and fail with a clear message
# when nothing usable is found (e.g. Drive not mounted).
ckpt_entries = list(ckpt_path.glob('*')) if ckpt_path.exists() else []
latest_ckpts = sorted(
    (p for p in ckpt_entries if p.name.isdigit()), key=lambda p: int(p.name))
if not latest_ckpts:
  raise FileNotFoundError(
      f'No step-numbered checkpoint directories found in {ckpt_path}. '
      'Mount Drive and copy the checkpoints there first.')
latest_ckpt = latest_ckpts[-1]

save_path_for_training_from_checkpoint = epath.Path('/tmp/humanoid_imitating_hips_72_40_M/ckpts')
save_path_for_training_from_checkpoint.mkdir(parents=True, exist_ok=True)

def save_policy_params_fn(current_step, make_policy, params):
  """Saves `params` as an Orbax PyTree checkpoint named after the step."""
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  path = save_path_for_training_from_checkpoint / f'{current_step}'
  orbax_checkpointer.save(path, params, force=True, save_args=save_args)


train_fn = functools.partial(
    ppo.train, num_timesteps=30_000_000, num_evals=20, reward_scaling=0.1,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=10, num_minibatches=24, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=3072,
    batch_size=512, policy_params_fn=save_policy_params_fn, seed=0,
    restore_checkpoint_path=latest_ckpt)

x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]

max_y, min_y = 13000, 0

def progress(num_steps, metrics):
  """Records and prints one evaluation reported by ppo.train."""
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics['eval/episode_reward'])
  ydataerr.append(metrics['eval/episode_reward_std'])

  reward = metrics['eval/episode_reward']
  reward_std = metrics['eval/episode_reward_std']

  print(f"Step: {num_steps} | Eval Reward: {reward:.2f} ± {reward_std:.2f}")

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

print("Done.")

Current Frame: Traced<ShapedArray(int32[])>with<BatchTrace> with
  val = Traced<ShapedArray(int32[128])>with<DynamicJaxprTrace>
  batch_dim = 0
Step: 0 | Eval Reward: 522.82 ± 129.33
Current Frame: Traced<ShapedArray(int32[])>with<BatchTrace> with
  val = Traced<ShapedArray(int32[3072])>with<DynamicJaxprTrace>
  batch_dim = 0
Step: 1597440 | Eval Reward: 538.78 ± 163.26
Step: 3194880 | Eval Reward: 571.99 ± 158.18
Step: 4792320 | Eval Reward: 582.93 ± 221.38
Step: 6389760 | Eval Reward: 605.13 ± 194.41
Step: 7987200 | Eval Reward: 638.58 ± 193.96
Step: 9584640 | Eval Reward: 624.46 ± 147.03
Step: 11182080 | Eval Reward: 608.33 ± 193.64
Step: 12779520 | Eval Reward: 629.80 ± 159.70
Step: 14376960 | Eval Reward: 648.22 ± 207.79
Step: 15974400 | Eval Reward: 643.68 ± 197.86
Step: 17571840 | Eval Reward: 645.43 ± 201.60
Step: 19169280 | Eval Reward: 681.43 ± 252.31
Step: 20766720 | Eval Reward: 662.79 ± 189.63
Step: 22364160 | Eval Reward: 671.14 ± 240.23
Step: 23961600 | Eval Reward: 544.

In [None]:
#@title Save Checkpoints to Drive
drive.mount('/content/drive', force_remount=True)
!cp -r /tmp/humanoid_imitating_hips_72_40_M/ckpts /content/drive/MyDrive/hi_ya_40_000_000

Mounted at /content/drive


In [None]:
#@title Save Trained Model
# Serialize the trained policy parameters to a local file ...
model_path = '/tmp/humanoid_fixed_7_hi_ya_40_M'
model.save_params(model_path, params)

# ... then push that file to the browser as a download (Colab only).
from google.colab import files
files.download(model_path)

# Render

In [None]:
#@title Visualize Currently Activated Env
# Roll out a constant control (every actuator at -0.1) from a fixed reset and
# render it, to sanity-check the env without a trained policy.
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

num_rollout_steps = 2300
print_qpos = False  # set True to print qpos every step (thousands of lines)
ctrl = -0.1 * jp.ones(env.sys.nu)  # loop-invariant, so build it once

for _ in range(num_rollout_steps):
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)
  if print_qpos:
    print("qpos:", state.pipeline_state.qpos)

media.show_video(env.render(rollout, camera='side'), fps=1.0 / env.dt)

In [None]:
#@title Render Trained Policy
# NOTE: Run the cell that defines the Humanoid env the policy was trained on
# before this one, so there are no shape mismatches.

# A zero-timestep PPO run, used only to obtain `make_inference_fn` for this
# network configuration; the trained weights are loaded from disk below.
train_fn = functools.partial(
    ppo.train,
    num_timesteps=0,
    num_evals=0,
    reward_scaling=0.1,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=10,
    num_minibatches=24,
    num_updates_per_batch=8,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-3,
    num_envs=3072,
    batch_size=512,
    seed=0,
)

x_data, y_data, ydataerr = [], [], []
times = [datetime.now()]
max_y, min_y = 13000, 0

def progress(num_steps, metrics):
  """Records and prints one evaluation reported by ppo.train."""
  times.append(datetime.now())
  reward = metrics['eval/episode_reward']
  reward_std = metrics['eval/episode_reward_std']
  x_data.append(num_steps)
  y_data.append(reward)
  ydataerr.append(reward_std)
  print(f"Step: {num_steps} | Eval Reward: {reward:.2f} ± {reward_std:.2f}")

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

print('Load Model...')
model_path = '/tmp/humanoid_fixed_7_hi_ya_40_M'
params = model.load_params(model_path)

# Bind the loaded weights to the policy network, then compile it.
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

# Fresh env instance with its own compiled reset/step.
eval_env = envs.get_environment('humanoid')
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

rng = jax.random.PRNGKey(4)
state = jit_reset(rng)
rollout = [state.pipeline_state]  # first entry is the reset state

n_steps = 5000
render_every = 2  # keep every 2nd simulated frame in the video

for i in range(n_steps):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)  # policy inference
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)
  if state.done:
    break

media.show_video(
    env.render(rollout[::render_every], camera='side', height=1080, width=1920),
    fps=1.0 / env.dt / render_every,
)

print("Done.")