In [1]:
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# %%
import os
import pathlib

import matplotlib.pyplot as plt
import mediapy as media
import mujoco
import numpy as np
import PIL.Image
from tqdm.auto import tqdm

# set current directory: mujoco_mpc/python/mujoco_mpc
from mujoco_mpc import agent as agent_lib


In [2]:

# %%
# Root of the mujoco_mpc checkout. Set the MJPC_ROOT environment variable to
# point at your own checkout instead of editing this machine-specific default.
base_path = pathlib.Path(
    os.environ.get("MJPC_ROOT", "/Users/eabe/Research/MyRepos/mujoco_mpc/")
)

# model: the fruit-fly qpos-tracking task definition
model_path = base_path / "mjpc/tasks/fruitfly/flytrackingqpos/task.xml"
if not model_path.is_file():
  # Fail early with an actionable message rather than a parser error.
  raise FileNotFoundError(
      f"MJPC task model not found at {model_path}; "
      "set MJPC_ROOT to your mujoco_mpc checkout."
  )
model = mujoco.MjModel.from_xml_path(model_path.as_posix())

# data: simulation state for `model`
data = mujoco.MjData(model)

# renderer: offscreen renderer at its default image size, used for rollout frames
renderer = mujoco.Renderer(model)

# %%
# agent: MJPC planning agent for this task, built from the same model
TASK_ID = "Fruitfly Track"
agent = agent_lib.Agent(task_id=TASK_ID, model=model)

# Optional overrides; empty means "use the task's defaults".
# Valid names are the keys of agent.get_cost_weights() and
# agent.get_task_parameters(), e.g. COST_WEIGHTS = {"Velocity": 0.15}.
COST_WEIGHTS = {}
TASK_PARAMETERS = {}

if COST_WEIGHTS:
  agent.set_cost_weights(COST_WEIGHTS)
  print("cost_weights", agent.get_cost_weights())

if TASK_PARAMETERS:
  for param_name, param_value in TASK_PARAMETERS.items():
    agent.set_task_parameter(param_name, param_value)
  print("Parameters:", agent.get_task_parameters())

# %%
# Rollout length in simulation steps. States are stored at all T time points;
# controls and costs at the T - 1 transitions between them.
T = 1500

# per-transition buffers (one column per step)
ctrl = np.zeros((model.nu, T - 1))
cost_total = np.zeros((T - 1,))
cost_terms = np.zeros((len(agent.get_cost_term_values()), T - 1))

# per-time-point buffers (one column per sample)
qpos = np.zeros((model.nq, T))
qvel = np.zeros((model.nv, T))
time = np.zeros((T,))

# start from the reset state and record it as sample 0
mujoco.mj_resetData(model, data)
qpos[:, 0], qvel[:, 0], time[0] = data.qpos, data.qvel, data.time

# rendered frames, and the playback rate for one frame per physics step
frames = list()
FPS = 1.0 / model.opt.timestep


I0000 00:00:1720566547.646397 14002044 agent_server.cc:55] Server listening on [::]:51026


In [10]:

# %%
# Snapshot of the current simulator state with contact points and contact
# forces drawn, using only the official `mujoco` bindings (`mujoco` has no
# `wrapper` submodule; that API belongs to dm_control).
SNAPSHOT_CAMERA = 1  # NOTE(review): assumes the model defines camera index 1 — confirm
SNAPSHOT_WIDTH, SNAPSHOT_HEIGHT = 640, 480

scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = 1  # make every site group visible
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True

# mujoco.Renderer fixes the image size at construction; the camera and scene
# options go to update_scene(), and render() takes no such arguments.
# NOTE(review): the size must fit the model's offscreen buffer
# (<visual><global offwidth/offheight>) — confirm for this model.
with mujoco.Renderer(
    model, height=SNAPSHOT_HEIGHT, width=SNAPSHOT_WIDTH
) as snapshot_renderer:
  snapshot_renderer.update_scene(
      data, camera=SNAPSHOT_CAMERA, scene_option=scene_option
  )
  pixels = snapshot_renderer.render()

PIL.Image.fromarray(pixels)

array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [3]:

# simulate: closed-loop MPC rollout. At each step the agent receives the
# current simulator state, runs its planner, and its action is applied.

# Planner iterations per control step (loop-invariant).
num_steps = 10

# Restart from the reset state on every run of this cell, so re-running it
# neither continues from a previous rollout's end state nor appends to old frames.
mujoco.mj_resetData(model, data)
qpos[:, 0] = data.qpos
qvel[:, 0] = data.qvel
time[0] = data.time
frames.clear()

try:
  for t in tqdm(range(T - 1), desc="rollout"):
    # set planner state
    agent.set_state(
        time=data.time,
        qpos=data.qpos,
        qvel=data.qvel,
        act=data.act,
        mocap_pos=data.mocap_pos,
        mocap_quat=data.mocap_quat,
        userdata=data.userdata,
    )

    # run planner
    for _ in range(num_steps):
      agent.planner_step()

    # record costs; term rows follow the order of get_cost_term_values()
    cost_total[t] = agent.get_total_cost()
    cost_terms[:, t] = list(agent.get_cost_term_values().values())

    # apply the agent's action and record it
    data.ctrl[:] = agent.get_action()
    ctrl[:, t] = data.ctrl

    # step physics, then cache the resulting state as sample t + 1
    mujoco.mj_step(model, data)
    qpos[:, t + 1] = data.qpos
    qvel[:, t + 1] = data.qvel
    time[t + 1] = data.time

    # render the post-step state: frames[t] shows time[t + 1]
    renderer.update_scene(data)
    frames.append(renderer.render())
finally:
  # Reset the agent even if the rollout is interrupted (e.g. KeyboardInterrupt).
  agent.reset()


In [None]:

# display video
# Frames are rendered once per physics step and FPS = 1 / timestep, so
# SLOWDOWN = 1 plays in real time and SLOWDOWN < 1 plays slower.
# Off by default: the video holds one frame per simulated step.
SHOW_VIDEO = False
SLOWDOWN = 0.5
if SHOW_VIDEO:
  media.show_video(frames, fps=SLOWDOWN * FPS)


In [None]:

# %%
# Rollout plots: configuration, velocity, control and costs versus time.


def _plot_first_two_rows(t, traj, label_prefix, ylabel, title):
  """Plot rows 0 and 1 of a (dim, len(t)) trajectory array against `t`.

  Args:
    t: time samples (s), one per column of `traj`.
    traj: 2-D array whose rows are state coordinates; needs at least 2 rows.
    label_prefix: legend prefix; rows are labelled f"{label_prefix}0/1".
    ylabel: y-axis label.
    title: axes title.

  Returns:
    The (fig, ax) pair holding the plot.
  """
  fig, ax = plt.subplots()
  ax.plot(t, traj[0, :], label=f"{label_prefix}0", color="blue")
  ax.plot(t, traj[1, :], label=f"{label_prefix}1", color="orange")
  ax.set(xlabel="Time (s)", ylabel=ylabel, title=title)
  ax.legend()
  return fig, ax


# position and velocity: first two coordinates only
_plot_first_two_rows(time, qpos, "q", "Configuration", "Configuration (qpos[0:2])")
_plot_first_two_rows(time, qvel, "v", "Velocity", "Velocity (qvel[0:2])")

# control: first actuator only; ctrl has one sample per transition, hence time[:-1]
fig, ax = plt.subplots()
ax.plot(time[:-1], ctrl[0, :], color="blue")
ax.set(xlabel="Time (s)", ylabel="Control", title="Control (ctrl[0])")

# costs: rows of cost_terms follow the order of get_cost_term_values(), the
# same order the rollout loop used when filling them
fig, ax = plt.subplots()
for i, term_name in enumerate(agent.get_cost_term_values()):
  ax.plot(time[:-1], cost_terms[i, :], label=term_name)
ax.plot(time[:-1], cost_total, label="Total (weighted)", color="black")
ax.set(xlabel="Time (s)", ylabel="Costs", title="Costs")
ax.legend()

plt.show()


In [None]:
# Show one rendered frame. frames[k] is rendered right after physics step k,
# so it shows the state at time[k + 1]. Use a non-negative index.
FRAME_INDEX = 1
fig, ax = plt.subplots()
ax.imshow(frames[FRAME_INDEX])
ax.set_title(f"Frame {FRAME_INDEX} (t = {time[FRAME_INDEX + 1]:.3f} s)")
ax.axis("off")
plt.show()