Copyright 2022 DeepMind Technologies Limited

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Demo: design optimization for the 3D WaterCourse environment


In [None]:
#@title Installation
!mkdir /content/inverse_design
!mkdir /content/inverse_design/src
!touch /content/inverse_design/__init__.py
!touch /content/inverse_design/src/__init__.py

!wget -O /content/inverse_design/src/connectivity_utils.py https://raw.githubusercontent.com/deepmind/master/inverse_design/src/connectivity_utils.py
!wget -O /content/inverse_design/src/graph_network.py https://raw.githubusercontent.com/deepmind/master/inverse_design/src/graph_network.py
!wget -O /content/inverse_design/src/learned_simulator.py https://raw.githubusercontent.com/deepmind/master/inverse_design/src/learned_simulator.py
!wget -O /content/inverse_design/src/model_utils.py https://raw.githubusercontent.com/deepmind/master/inverse_design/src/model_utils.py
!wget -O /content/inverse_design/src/normalizers.py https://raw.githubusercontent.com/deepmind/master/inverse_design/src/normalizers.py
!wget -O /content/inverse_design/src/watercourse_env.py https://raw.githubusercontent.com/deepmind/master/inverse_design/src/watercourse_env.py

!wget -O /content/requirements.txt https://raw.githubusercontent.com/deepmind/master/inverse_design/requirements.txt
!pip install -r requirements.txt

In [None]:
#@title Imports
from inverse_design.src import learned_simulator
from inverse_design.src import model_utils
from inverse_design.src import watercourse_env

In [None]:
#@title Load Pickled Dataset & Params
import pickle
from google.colab import auth
auth.authenticate_user()

!gsutil cp gs://dm_inverse_design_watercourse/init_sequence.pickle .
!gsutil cp gs://dm_inverse_design_watercourse/gns_params.pickle .

with open('init_sequence.pickle', "rb") as f:
  pickled_data = pickle.loads(f.read())
  gt_sequence = pickled_data['gt_sequence']
  meta = pickled_data['meta']

with open('gns_params.pickle', "rb") as f:
  pickled_params = pickle.loads(f.read())
  network = pickled_params['network']
  plan = pickled_params['plan']

In [None]:
#@title make ID control/loss functions
import functools
import jax

# Upper bound on the number of edges in one rollout step; the edge arrays
# are padded to this size.
MAX_EDGES = 2**16

# Haiku simulator constructor: a LearnedSimulator bound to the dataset's
# connectivity radius, the feature-flattening function, and the model
# keyword arguments from the saved plan.
connectivity_radius = meta["connectivity_radius"]
flatten_fn = functools.partial(
    model_utils.flatten_features,
    **plan['flatten_kwargs'],
)
haiku_model = functools.partial(
    learned_simulator.LearnedSimulator,
    connectivity_radius=connectivity_radius,
    flatten_features_fn=flatten_fn,
    **plan['model_kwargs'],
)

# Create the initial landscape (obstacle) for the scene and write it into the
# leading node positions of every frame of the ground-truth sequence.
obstacle_pos = watercourse_env.make_plain_obstacles()
num_obstacle_nodes = obstacle_pos.shape[0]
# Insert a singleton axis 1 so the obstacle positions broadcast across that
# axis of 'world_position' (presumably the time axis; TODO confirm).
obstacle_block = obstacle_pos[:, None]
for step_graph in gt_sequence:
  positions = step_graph.nodes['world_position'].copy()
  positions[:num_obstacle_nodes] = obstacle_block
  step_graph.nodes['world_position'] = positions


# Build the initial graph from the dataset sequence, skipping its first 15
# frames.
obstacle_edges, inflow_stack, initial_graph = watercourse_env.build_initial_graph(
    gt_sequence[15:], max_edges=MAX_EDGES)

# Infer the landscape size from the dataset (25 x 25).
# Note that this is not required: it is also possible to create a smaller
# or larger landscape (obstacle) as the design space.
import math
n_obs = int(initial_graph.nodes['obstacle_mask'].sum())
# The design is a square grid (later reshaped to (num_side, num_side)).
# Use an exact integer square root and fail loudly on a non-square count,
# instead of silently truncating a float sqrt and ending up with a design
# vector whose length disagrees with the obstacle mask.
num_side = math.isqrt(n_obs)
if num_side * num_side != n_obs:
  raise ValueError(
      f'Expected a square obstacle grid, but obstacle_mask selects {n_obs} '
      'nodes, which is not a perfect square.')

# rollout length (the final state is used for the loss computation)
length = 50
# radius within which to connect particles
radius = 0.1
# weight of the smoothness term in the total loss
smoothing_factor = 1e2

@jax.jit
def run(vars):
  """Roll out the learned simulator on a designed landscape and score it.

  Args:
    vars: design parameters, passed to `watercourse_env.design_fn` together
      with `initial_graph`.

  Returns:
    A `(total_loss, aux)` pair. `total_loss` is the sum of the values in
    `aux['losses']`. `aux` holds the design parameters ('design'), the
    per-term losses ('losses': 'objective' and the weighted 'smooth' term),
    and the rollout trajectory ('traj').
  """
  design_graph, raw_obstacle = watercourse_env.design_fn(vars, initial_graph)

  last_graph, trajectory = watercourse_env.rollout(
      design_graph,
      inflow_stack[:length],
      network,
      haiku_model,
      obstacle_edges,
      radius=radius,
  )

  objective = watercourse_env.max_x_loss_fn(last_graph)
  smoothness = watercourse_env.smooth_loss_fn(raw_obstacle, num_side=num_side)
  losses = dict(objective=objective, smooth=smoothing_factor * smoothness)

  total = sum(losses.values())
  return total, dict(design=vars, losses=losses, traj=trajectory)


In [None]:
from IPython.display import clear_output
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax

# Optimizer configuration: Adam with a fixed learning rate, run for a fixed
# number of steps.
LEARNING_RATE = 0.05
num_opt_steps = 100
optimizer = optax.adam(learning_rate=LEARNING_RATE)

# Start from an all-zero design vector of length n_obs (a flat landscape)
# and its matching optimizer state.
params = jnp.zeros((n_obs,), dtype=jnp.float32)
opt_state = optimizer.init(params)

# Per-step `aux` dicts returned by `run`; later cells read them for plotting.
opt_traj = []


@jax.jit
def opt_step(params, opt_state):
  """Apply one optimizer update to the design parameters.

  Returns the updated params, the updated optimizer state, and the `aux`
  dict that `run` produced at the incoming (pre-update) params.
  """
  grad_fn = jax.grad(run, has_aux=True)
  grads, aux = grad_fn(params)
  updates, new_opt_state = optimizer.update(grads, opt_state, params)
  return optax.apply_updates(params, updates), new_opt_state, aux

# Run the optimization loop and redraw the loss curves after every step.
# NOTE(review): each aux also holds the full rollout trajectory ('traj').
# In the visible cells only 'design' and 'losses' are read back from
# opt_traj, so dropping 'traj' here could save memory; confirm nothing else
# needs it before doing so.
for i in range(num_opt_steps):
  params, opt_state, aux = opt_step(params, opt_state)
  opt_traj.append(aux)

  clear_output(wait=True)
  fig, ax = plt.subplots(1, 1, figsize=(10, 5))
  loss_names = list(aux['losses'].keys())
  for key in loss_names:
    ax.plot([t['losses'][key] for t in opt_traj])
  ax.plot([sum(t['losses'].values()) for t in opt_traj])
  ax.legend(loss_names + ['total'])
  ax.set_xlabel('optimization step')
  ax.set_ylabel('loss')
  plt.show()
  # A new figure is created every iteration; close it so figures do not
  # accumulate (non-inline backends keep them open otherwise).
  plt.close(fig)

In [None]:
import numpy as np

# Plot design iterations: every 10th step, plus the last recorded step so the
# final design is always shown.
plot_every = 10
n_sam = list(range(0, len(opt_traj), plot_every))
if n_sam and n_sam[-1] != len(opt_traj) - 1:
  n_sam.append(len(opt_traj) - 1)

if not n_sam:
  # plt.subplots(1, 0) would raise; report instead.
  print('No optimization steps recorded; nothing to plot.')
else:
  fig, ax = plt.subplots(
      1, len(n_sam), figsize=(len(n_sam) * 10, 10), squeeze=False)
  for fi, idx in enumerate(n_sam):
    design = np.asarray(opt_traj[idx]['design'])
    # control function uses tanh as transformation, so mimic here to see heightfield
    fld = np.tanh(design.reshape((num_side, num_side)))
    ax[0, fi].imshow(fld, vmin=-1, vmax=1)
    ax[0, fi].set_title('step %d' % idx)
    ax[0, fi].set_axis_off()

In [None]:
from IPython.display import clear_output
# plot video of how particles move for optimized design and initial design

def _plt(ax, frame, i):
  """Draw frame `i` of a rollout trajectory onto a 3D matplotlib axis.

  Particles selected by `frame['mask'][i]` are drawn in blue, the first
  `num_side**2` positions (the obstacle nodes) in black, and a fixed green
  marker at (1.5, 1.5, 0). Position columns 0, 2 and 1 are mapped to the
  plot's x, y and z axes respectively. Axis limits are fixed so that
  successive frames line up.
  """
  positions = frame['pos'][i]
  particles = positions[frame['mask'][i]]
  obstacles = positions[:num_side**2]

  ax.scatter(particles[:, 0], particles[:, 2], particles[:, 1], c='b', s=10)
  ax.scatter(obstacles[:, 0], obstacles[:, 2], obstacles[:, 1], c='k', s=3)
  ax.scatter([1.5], [1.5], [0], c='g', s=20)

  ax.set_xlim([-0.6, 1.6])
  ax.set_ylim([-0.1, 1.6])
  ax.set_zlim([-0.1, 1.2])

# Re-run the rollout for the first and the last recorded designs.
# Note: opt_traj[-1]['design'] is the design evaluated at the last
# optimization step, i.e. before that step's parameter update was applied.
roll_fin0 = run(opt_traj[0]['design'])[1]['traj']
roll_fin1 = run(opt_traj[-1]['design'])[1]['traj']

# Play back both rollouts side by side, one frame at a time; bound by the
# shorter of the two so indexing is always valid.
num_frames = min(roll_fin0['pos'].shape[0], roll_fin1['pos'].shape[0])
for i in range(num_frames):
  clear_output(wait=True)
  fig = plt.figure(figsize=(20, 10))
  ax1 = fig.add_subplot(1, 2, 1, projection='3d')
  ax1.set_title('Initial design, frame %d' % i)
  _plt(ax1, roll_fin0, i)
  ax2 = fig.add_subplot(1, 2, 2, projection='3d')
  ax2.set_title('Design at final step, frame %d' % i)
  _plt(ax2, roll_fin1, i)
  plt.show()
  # Close each per-frame figure so figures do not accumulate.
  plt.close(fig)