Copyright 2022 DeepMind Technologies Limited

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

 https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Demo rollout for a real cube trajectory from the ContactNets dataset (Pfrommer et al., 2020)

In [None]:
#@title Installation (skip if running locally)
# Note, this should be skipped if running locally.
!mkdir /content/gnn_single_rigids
!mkdir /content/gnn_single_rigids/src
!touch /content/gnn_single_rigids/__init__.py
!touch /content/gnn_single_rigids/src/__init__.py

!wget -O /content/gnn_single_rigids/src/graph_network.py https://raw.githubusercontent.com/deepmind/master/gnn_single_rigids/src/graph_network.py
!wget -O /content/gnn_single_rigids/src/learned_simulator.py https://raw.githubusercontent.com/deepmind/master/gnn_single_rigids/src/learned_simulator.py
!wget -O /content/gnn_single_rigids/src/utils.py https://raw.githubusercontent.com/deepmind/master/gnn_single_rigids/src/utils.py
!wget -O /content/gnn_single_rigids/src/normalizers.py https://raw.githubusercontent.com/deepmind/master/gnn_single_rigids/src/normalizers.py
!wget -O /content/gnn_single_rigids/src/meshtools.py https://raw.githubusercontent.com/deepmind/master/gnn_single_rigids/src/meshtools.py
!wget -O /content/gnn_single_rigids/src/rollout.py https://raw.githubusercontent.com/deepmind/master/gnn_single_rigids/src/rollout.py

!wget -O /content/requirements.txt https://raw.githubusercontent.com/deepmind/master/gnn_single_rigids/requirements.txt
!pip install -r requirements.txt

In [None]:
#@title Download weights from google cloud storage if running in colab 
#(skip if running locally)
from google.colab import auth
auth.authenticate_user()

!gsutil cp gs://dm_gnn_single_rigids/example_real_toss.pkl .
!gsutil cp gs://dm_gnn_single_rigids/gns_params.pkl .

In [None]:
#@title imports
import pickle
import functools
import jraph
import numpy as np
import scipy

from gnn_single_rigids.src import utils
from gnn_single_rigids.src import learned_simulator
from gnn_single_rigids.src import rollout
from gnn_single_rigids.src import meshtools

In [None]:
#@title convert data formatted as [translation, quaternion] for consecutive
# time-steps into a sequence of graphs tuples
def convert_states_to_graph_sequence(src_data,
                                     cube_size=2.0,
                                     num_history_frames=3):
  """Converts a cube state trajectory into a sequence of `jraph.GraphsTuple`s.

  Args:
    src_data: per-frame cube states (array-like, one row per frame). Entries
      0:3 of a row are the centre-of-mass translation and entries 3:7 are the
      rotation quaternion in scalar-first (w, i, j, k) order.
    cube_size: scale applied to `meshtools.make_unit_box()` to build the cube
      mesh whose vertices become the graph nodes.
    num_history_frames: number of past frames of vertex positions stored in
      each graph's "world_position" node feature.

  Returns:
    A list of `len(src_data) - num_history_frames` graphs (empty when there are
    not more frames than `num_history_frames`). The graph built for frame `t`
    holds the vertex positions of frames `t - num_history_frames` .. `t - 1`.
  """
  # Explicit submodule import: a bare `import scipy` does not load
  # `scipy.spatial` on SciPy < 1.9, so `scipy.spatial...` may raise
  # AttributeError there.
  from scipy.spatial.transform import Rotation

  # Row slicing and the fancy index below require a NumPy array.
  src_data = np.asarray(src_data)

  # Cube mesh: vertices, triangulated faces, and the edges implied by them.
  box = meshtools.transform(meshtools.make_unit_box(), scale=cube_size)
  verts = box.verts
  faces = meshtools.triangulate(box.faces)
  mesh_senders, mesh_receivers = meshtools.triangles_to_edges(faces)
  num_verts = verts.shape[0]
  num_edges = mesh_senders.shape[0]

  # Rigidly transform the mesh vertices into world space for every frame.
  node_pos = []
  for state in src_data:
    # NOTE(review): vertices are scaled by `cube_size` above, but translations
    # by `2.0 / cube_size`; the two agree only for the default cube_size=2.0.
    # Confirm the intended scaling before using other cube sizes.
    trans_com = state[0:3] * 2.0 / cube_size
    wijk_quat = state[3:7]
    # scipy's from_quat expects scalar-last (i, j, k, w) order, so reorder.
    rot_com = Rotation.from_quat(wijk_quat[[1, 2, 3, 0]])
    node_pos.append(rot_com.apply(verts) + trans_com)

  # One graph per frame that has a full window of history before it.
  graph_sequence = []
  for frame_idx in range(num_history_frames, len(node_pos)):
    node_features = {
        "external_mask": np.zeros((num_verts,)),
        # Stacked on axis 1: (num_verts, num_history_frames, 3), oldest first.
        "world_position": np.stack(
            node_pos[frame_idx - num_history_frames:frame_idx], axis=1),
        "mesh_position": verts,
    }
    graph_sequence.append(jraph.GraphsTuple(
        n_node=np.array([num_verts]),
        n_edge=np.array([num_edges]),
        nodes=node_features, edges={}, senders=mesh_senders,
        receivers=mesh_receivers, globals=np.array([[]])))

  return graph_sequence

In [None]:
#@title Load data
# NOTE: unpickling can execute arbitrary code; only load these files from a
# trusted source.
with open("gns_params.pkl", "rb") as f:
  pickled_data = pickle.load(f)
# Trained network parameters, network state, and the model configuration.
network = pickled_data["network"]
state = pickled_data["state"]
plan = pickled_data["plan"]

# tosses contains 1 example real cube tossing trajectory from
# Pfrommer et al CoRL 2020
with open("example_real_toss.pkl", "rb") as f:
  tosses_data = pickle.load(f)


In [None]:
#@title convert data to sequence of graphs tuples
# Build graphs for the first (and only) toss in the example file, keeping three
# frames of vertex-position history per graph.
cube_size = 2.0
graph_sequence = convert_states_to_graph_sequence(
    tosses_data[0], cube_size=cube_size, num_history_frames=3)

In [None]:
#@title make a predicted rollout with graph network simulator
# Feature flattening and graph-network settings both come from the saved plan.
flatten_fn = functools.partial(utils.flatten_features,
                               **plan['flatten_kwargs'])
haiku_model = functools.partial(
    learned_simulator.LearnedSimulator,
    flatten_features_fn=flatten_fn,
    graph_network_kwargs=plan['graph_network_kwargs'],
)

# Parameters and state loaded from gns_params.pkl, packed as the rollout expects.
model_params = {'state': state, 'params': network}
p_rollout = rollout.get_predicted_trajectory(
    graph_sequence, model_params, haiku_model, utils.forward_graph)

In [None]:
#@title visualize ground truth compared to predicted
import plotly.graph_objects as go

# Sequences to compare, plus the accumulators the figure is built from.
gt_sequence, rollout_sequence = graph_sequence, p_rollout
elements, frames = [], []
sliders = None
# Cube mesh whose faces are reused with each frame's vertex positions.
box = meshtools.transform(meshtools.make_unit_box(), scale=cube_size)

def plot_mesh(data, mesh, color, flat=True, opacity=1.0):
  """Appends a `go.Mesh3d` trace for `mesh` to the list `data`.

  The mesh faces are passed through `meshtools.triangulate` and the resulting
  triangle vertex indices become the trace's i/j/k arrays.
  """
  triangles = meshtools.triangulate(mesh.faces)
  xs, ys, zs = (mesh.verts[:, axis] for axis in range(3))
  trace = go.Mesh3d(
      x=xs, y=ys, z=zs,
      i=triangles[:, 0], j=triangles[:, 1], k=triangles[:, 2],
      color=color,
      opacity=opacity,
      flatshading=flat,
      lighting=dict(fresnel=0, specular=0, ambient=0.2, diffuse=1),
      lightposition=dict(x=-10, y=-10, z=50),
  )
  data.append(trace)

def plot_floor(data, mesh, color):
  """Plots only the lowest faces of `mesh` (all vertices at minimum z)."""
  heights = mesh.verts[:, 2]
  # Tolerance absorbs float noise when comparing against the minimum height.
  on_floor = heights <= heights.min() + 1e-6
  bottom_faces = mesh.faces[on_floor[mesh.faces].all(axis=1)]
  plot_mesh(data, meshtools.Mesh(mesh.verts, bottom_faces), color)

slider_steps = []
# Running bounding box over all plotted vertices. Starting at +/-inf is neutral
# for min/max; a finite sentinel such as +/-1000 would clip trajectories that
# lie beyond it.
b_min = np.full(3, np.inf)
b_max = np.full(3, -np.inf)

for frame_idx, gt_frame in enumerate(gt_sequence):
  p_frame = rollout_sequence[frame_idx]
  # Last entry along axis 1 of world_position (the history axis in graphs
  # built by convert_states_to_graph_sequence).
  gt_pos = gt_frame.nodes['world_position'][:, -1]
  p_pos = p_frame.nodes['world_position'][:, -1]

  # track the minimum / maximum for plotting
  b_min = np.minimum(b_min, np.minimum(gt_pos.min(0), p_pos.min(0)))
  b_max = np.maximum(b_max, np.maximum(gt_pos.max(0), p_pos.max(0)))

  # make meshes for plotting
  gt_mesh = meshtools.Mesh(verts=gt_pos, faces=box.faces)
  p_mesh = meshtools.Mesh(verts=p_pos, faces=box.faces)

  # Ground truth is translucent red; the prediction is opaque blue.
  data = []
  plot_mesh(data, gt_mesh, "red", flat=False, opacity=0.2)
  plot_mesh(data, p_mesh, "royalblue", flat=False, opacity=1)

  # Plotly frame names are strings; the slider step animates to the frame with
  # the same name, so use one explicit string for both.
  frame_name = str(frame_idx)
  frames.append(go.Frame(name=frame_name, data=data))
  slider_steps.append(dict(
      method="animate",
      label=frame_name,
      args=[[frame_name], dict(
          frame=dict(duration=0),
          mode="immediate",
          transition=dict(duration=0))]))

# The slider only depends on the collected steps, so build it once.
sliders = [dict(
    transition=dict(duration=0, easing="linear"),
    steps=slider_steps)]

# The figure's initial traces are those of the first frame.
elements.extend(frames[0]["data"])

# Floor: lowest face of a box built from the trajectory bounds, with the z lower
# bound set to 0 and the box centre nudged up by 1e-3.
b_min[2] = 0
b_center = 0.5 * (b_min + b_max)
b_center[2] += 1e-3
b_size = b_max - b_min
domain = meshtools.transform(meshtools.make_unit_box(),
                             scale=b_size, translate=b_center)
plot_floor(elements, domain, "lightgray")

# create figure for viewing
fig = go.Figure(data=elements, frames=frames, layout=go.Layout(
    scene_aspectmode="data",
    showlegend=False,
    width=800,
    height=800,
    sliders=sliders,
    scene=dict(camera=dict(up=dict(x=0, y=0, z=1))),
))
fig.show()