In [None]:
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# %%
import os
import pathlib

import matplotlib.pyplot as plt
import mediapy as media
import mujoco
import numpy as np
import PIL.Image

# set current directory: mujoco_mpc/python/mujoco_mpc
from mujoco_mpc import agent as agent_lib


In [None]:

# %matplotlib inline
# Root of the mujoco_mpc checkout. Set the MJPC_BASE_PATH environment variable
# to point elsewhere instead of editing this machine-specific default.
base_path = pathlib.Path(
    os.environ.get("MJPC_BASE_PATH", "/Users/eabe/Research/MyRepos/mujoco_mpc/")
)
# %%
# model
model_path = base_path / 'Archive/fruitfly/flytracking/task.xml'
# alternative model (unused):
# pathlib.Path('/Users/eabe/Research/MyRepos/BiomechControl/models/fruitfly_v2/flybody/fruitfly/assets/fruitfly.xml')

# Fail early with an actionable message rather than a generic XML-load error.
if not model_path.is_file():
  raise FileNotFoundError(
      f"Task XML not found at {model_path}; set MJPC_BASE_PATH to the root "
      "of your mujoco_mpc checkout."
  )

model = mujoco.MjModel.from_xml_path(model_path.as_posix())

# data
data = mujoco.MjData(model)

# renderer
renderer = mujoco.Renderer(model)

# %%
# agent: MPC planner for the "Fruitfly Stand" task, bound to the loaded model
agent = agent_lib.Agent(task_id="Fruitfly Stand", model=model)
# agent = agent_lib.Agent(task_id="Cartpole", model=model)

# weights: uncomment to override a default cost weight
# agent.set_cost_weights({"Velocity": 0.15})
current_weights = agent.get_cost_weights()
print("Cost weights:", current_weights)

# parameters: target height for the task, then echo all task parameters
agent.set_task_parameter("Height Goal", .1)
current_params = agent.get_task_parameters()
print("Parameters:", current_params)

# %%
# rollout horizon: number of recorded states (controls/costs have T - 1 entries)
T = 1500

# preallocated trajectory buffers, one column per time index
qpos = np.zeros((model.nq, T))
qvel = np.zeros((model.nv, T))
ctrl = np.zeros((model.nu, T - 1))
time = np.zeros(T)

# cost buffers: weighted total plus one row per cost term reported by the agent
n_cost_terms = len(agent.get_cost_term_values())
cost_total = np.zeros(T - 1)
cost_terms = np.zeros((n_cost_terms, T - 1))

# start the rollout from the model's default state
mujoco.mj_resetData(model, data)

# record the initial state as index 0
qpos[:, 0], qvel[:, 0], time[0] = data.qpos, data.qvel, data.time

# rendered frames; one frame is captured per physics step, so 1 / timestep is
# the real-time playback rate
frames = []
FPS = 1.0 / model.opt.timestep


In [None]:
# Overwrite the mocap body positions with fixed (x, y, z) targets, one row per
# mocap body (30 rows). NOTE(review): presumably these are tracking keypoints
# consumed by the task and the model has nmocap == 30 -- confirm against task.xml.
# Run this after the rollout-setup cell: its mj_resetData call presumably
# restores the model's default mocap positions and would discard these values.
data.mocap_pos = np.array([[ 0.03180882,  0.02074661, -0.02898251],
       [ 0.02675209,  0.0303128 , -0.06901883],
       [ 0.08600139,  0.07803469, -0.08055862],
       [ 0.10179283,  0.09320089, -0.1285635 ],
       [ 0.18741733,  0.13553   , -0.10618452],
       [ 0.02449044, -0.01045236, -0.02746203],
       [ 0.00544831, -0.00954052, -0.05942646],
       [ 0.05213638, -0.04126289, -0.04300317],
       [ 0.04389725, -0.03507727, -0.09113092],
       [ 0.12152389, -0.0347662 , -0.1136906 ],
       [-0.00334122,  0.00833813, -0.02897212],
       [-0.02852611,  0.0054291 , -0.05745497],
       [-0.02857925,  0.08880582, -0.01390682],
       [-0.03397397,  0.1140864 , -0.07873123],
       [-0.05046751,  0.17704488, -0.12401552],
       [-0.00794117, -0.00762688, -0.02882023],
       [-0.03114363, -0.01511973, -0.0537156 ],
       [-0.04529041, -0.10126639, -0.03102564],
       [-0.03436553, -0.11175767, -0.09636242],
       [-0.03460989, -0.17275089, -0.15247021],
       [-0.03451603,  0.00276523, -0.02583813],
       [-0.05442001,  0.00890758, -0.05109516],
       [-0.03822012,  0.07392365,  0.0150405 ],
       [-0.07363825,  0.08524816, -0.05112792],
       [-0.1031378 ,  0.162298  , -0.09543193],
       [-0.0371287 , -0.01212705, -0.02613554],
       [-0.05845902, -0.0220647 , -0.04891861],
       [-0.11405964, -0.09623906, -0.02873848],
       [-0.17524566, -0.08927459, -0.07793304],
       [-0.26040928, -0.12538023, -0.10520066]])

In [None]:
# Preview the initial state before simulating. mj_resetData only resets the
# state; derived quantities such as body (and mocap body) poses are computed by
# mj_forward. Without it the renderer draws stale poses and ignores the
# mocap_pos values just written.
mujoco.mj_forward(model, data)
renderer.update_scene(data)
pixels = renderer.render()
PIL.Image.fromarray(pixels)

In [None]:

# simulate: each iteration syncs the planner with the current simulation state,
# runs the planner, records costs and the chosen control, advances the physics
# by one timestep, and renders a frame.
num_steps = 10  # planner iterations per control step (loop-invariant)

try:
  for t in range(T - 1):
    if t % 100 == 0:
      print("t = ", t)

    # set planner state
    agent.set_state(
        time=data.time,
        qpos=data.qpos,
        qvel=data.qvel,
        act=data.act,
        mocap_pos=data.mocap_pos,
        mocap_quat=data.mocap_quat,
        userdata=data.userdata,
    )

    # run planner for num_steps
    for _ in range(num_steps):
      agent.planner_step()

    # get costs: total plus one value per cost term (dict iteration order)
    cost_total[t] = agent.get_total_cost()
    for i, value in enumerate(agent.get_cost_term_values().values()):
      cost_terms[i, t] = value

    # set ctrl from agent policy
    data.ctrl = agent.get_action()
    ctrl[:, t] = data.ctrl

    # step
    mujoco.mj_step(model, data)

    # cache the post-step state at index t + 1
    qpos[:, t + 1] = data.qpos
    qvel[:, t + 1] = data.qvel
    time[t + 1] = data.time

    # render and save frames
    renderer.update_scene(data)
    pixels = renderer.render()
    frames.append(pixels)
finally:
  # reset the planner even if the rollout is interrupted (e.g. a
  # KeyboardInterrupt during this long loop) or raises
  agent.reset()


In [None]:

# display video; FPS is the real-time rate, so SLOWDOWN < 1 plays back slower
SLOWDOWN = 0.5
playback_fps = SLOWDOWN * FPS
media.show_video(frames, fps=playback_fps)

# %%
# plot position: first two generalized coordinates over the rollout
fig, ax = plt.subplots()

ax.plot(time, qpos[0, :], label="q0", color="blue")
ax.plot(time, qpos[1, :], label="q1", color="orange")

ax.legend()
ax.set_xlabel("Time (s)")
ax.set_ylabel("Configuration")

# %%
# plot velocity: first two generalized velocities over the rollout
fig, ax = plt.subplots()

ax.plot(time, qvel[0, :], label="v0", color="blue")
ax.plot(time, qvel[1, :], label="v1", color="orange")

ax.legend()
ax.set_xlabel("Time (s)")
ax.set_ylabel("Velocity")

# %%
# plot control: only the first actuator is shown (ctrl has model.nu rows, one
# per actuator); label and title make that explicit, matching the sibling plots.
fig, ax = plt.subplots()

ax.plot(time[:-1], ctrl[0, :], label="u0", color="blue")

ax.legend()
ax.set_title(f"Control (actuator 0 of {model.nu})")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Control")

# %%
# plot costs: each recorded cost term per step, plus the weighted total
fig, ax = plt.subplots()

term_names = list(agent.get_cost_term_values())
for row, name in enumerate(term_names):
  ax.plot(time[:-1], cost_terms[row, :], label=name)

ax.plot(time[:-1], cost_total, label="Total (weighted)", color="black")

ax.legend()
ax.set_xlabel("Time (s)")
ax.set_ylabel("Costs")
