Copyright 2023 Google LLC.

SPDX-License-Identifier: Apache-2.0

In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Instructions

Follow steps 2–4 of the [Colab local runtimes guide](https://research.google.com/colaboratory/local-runtimes.html) to connect this notebook to a local runtime.

In [None]:
# Shell out to nvidia-smi to list the NVIDIA GPUs visible to this runtime.
!nvidia-smi

# Notebook-specific installs

In [None]:
!pip3 install opencv-python h5py tensorboard
!pip install moviepy --upgrade

In [None]:
import os
import flax.linen as nn
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import tqdm

import chex
import pickle

import datetime
import flax.jax_utils as flax_utils

import matplotlib.pyplot as plt
from importlib import reload

In [None]:
# Number of devices local to this host; passed later to
# `next_pmapped_batch(num_devices)` when drawing train/eval batches.
num_devices = jax.local_device_count()
# Last expression of the cell, so the device list is displayed as its output.
jax.devices()

In [None]:
from hct import ndp
from hct.common import utils

In [None]:
import h5py

In [None]:
# Download file to local folder first
# Walk the file and print the name of every object that `visit` reports.
with h5py.File('episode_1.hdf5', "r") as episode_file:
  episode_file.visit(print)

In [None]:
# Read the episode into in-memory numpy arrays while the file is open.
with h5py.File('episode_1.hdf5', "r") as episode_file:
  observations = episode_file['observations']
  actions = np.array(episode_file['action'])
  images = np.array(observations['images']['top'])
  qpos = np.array(observations['qpos'])
  qvel = np.array(observations['qvel'])

In [None]:
# Print the four array shapes on one line, in the order
# actions, images, qpos, qvel.
print(*(arr.shape for arr in (actions, images, qpos, qvel)))
# Size of the second axis of `actions`.
action_dim = actions.shape[1]

In [None]:
from moviepy.editor import ImageSequenceClip

In [None]:
# Turn the loaded frames into a clip and render it inline at 50 fps.
clip = ImageSequenceClip(list(images), fps=50)
clip.ipython_display(fps=50)

## Pre-process data: raw_data --> raw_norm_data --> hct_data

In [None]:
import cv2

In [None]:
# Resize images & normalize all data
# cv2.resize takes the target size as (width, height).
images_resized = [cv2.resize(frame, (320, 240)) for frame in images]

# qpos and qvel concatenated along axis 1 form the high-frequency observation.
hf_obs = np.concatenate((qpos, qvel), axis=1)

raw_data = dict(images=np.array(images_resized), hf_obs=hf_obs, actions=actions)

# `norm_stats` is indexable (used again when unnormalizing predictions) and is
# unpacked positionally into `normalize`.
norm_stats = utils.compute_norm_stats(raw_data)
raw_norm_data = utils.normalize(raw_data, *norm_stats)

In [None]:
"""Batch the normalized episode into chunks of `num_actions` control steps.

For t = 0, 1, ... (one example per chunk):
example = {
  'images': s_t,
  'actions': U_t: (num_actions, action_dim) = [u_t(tau_0), ..., u_t(tau_{M-1})]
  'hf_obs': X_t: (num_actions, num_hf_obs_per_action+1, x_dim) = [x_t^0, ..., x_t^{M-1}]
}
where X_t[0][-1] = x_t(0), coinciding with s_t and U_t[0].

Chunks do not overlap (the loop strides by `num_actions`); trailing steps that
do not fill a complete chunk are dropped.
"""
norm_images = raw_norm_data['images']
norm_hf_obs = raw_norm_data['hf_obs']
norm_actions = raw_norm_data['actions']

# Set constants
num_actions = 5  # M: number of control actions per chunk
num_hf_obs_per_action = 1  # current dataset has action-freq = hf_state-freq

# Pre-tile hf_obs array with first observation, so that the window ending at
# the very first observation already has `num_hf_obs_per_action` predecessors.
# After this shift, original observation j sits at index
# j + num_hf_obs_per_action.
init_tile = np.array([norm_hf_obs[0]]*num_hf_obs_per_action)
norm_hf_obs = np.concatenate((init_tile, norm_hf_obs), axis=0)

# Per-key lists of chunks, filled by the loop below.
hct_data = {
    'images': [],
    'hf_obs': [],
    'actions': []
}

# Step through control actions (can also step through images if that is the base frequency)
# `final_idx` is the last start index that still leaves a full chunk of actions.
final_idx = len(norm_actions) - num_actions
for idx in range(0, final_idx+1, num_actions):
  hct_data['images'].append(norm_images[idx])  # current dataset has image-freq = control-freq
  hct_data['actions'].append(norm_actions[idx:idx+num_actions,...])

  # Index, in the pre-tiled array, of the observation aligned with action `idx`.
  hf_obs_idx = num_hf_obs_per_action + idx * num_hf_obs_per_action

  # For each action in the chunk, take the window of
  # num_hf_obs_per_action + 1 observations that ends at (and includes) the
  # observation aligned with that action.
  state_obs = []
  for _ in range(num_actions):
    state_obs.append(norm_hf_obs[hf_obs_idx-num_hf_obs_per_action:hf_obs_idx+1,...])
    hf_obs_idx += num_hf_obs_per_action

  # Stack the per-action windows along a new leading axis.
  hct_data['hf_obs'].append(np.stack(state_obs, axis=0))

In [None]:
# Convert each per-key list of chunks into a single array, then report shapes.
hct_data = {name: np.array(chunks) for name, chunks in hct_data.items()}
print([(name, stacked.shape) for name, stacked in hct_data.items()])

In [None]:
# One panel per action dimension on a 7x2 grid; each chunk of `num_actions`
# steps is drawn as its own segment at its position along the step axis.
fig, axs = plt.subplots(7, 2, figsize=(10, 20))

for i in range(action_dim):
  ax = axs[np.unravel_index(i, (7, 2))]
  for t, chunk in enumerate(hct_data['actions']):
    steps = t * num_actions + np.arange(num_actions)
    _ = ax.plot(steps, chunk[:, i])

## Setup data-loader

In [None]:
# For NDP - only use x_t(0): the last observation in the window of the first
# action of every chunk.
# The slice overwrites hct_data['hf_obs'] in place, so it is guarded to make
# re-running this cell a no-op: straight after batching the array is 4-D
# (chunk, action, window, x_dim); once sliced it is 2-D, and indexing it a
# second time would raise IndexError.
if hct_data['hf_obs'].ndim == 4:
  hct_data['hf_obs'] = hct_data['hf_obs'][:, 0, -1, ...]

In [None]:
# Seeded key stream; keys are drawn in order for the shuffle, the train batch
# manager and the eval batch manager.
data_prng = hk.PRNGSequence(jax.random.PRNGKey(654321))
num_data = len(hct_data['images'])
shuffle_idxs = jax.random.permutation(next(data_prng), num_data)

# Split the shuffled example indices 80/20 into train and eval.
train_ratio = 0.8
num_train = int(train_ratio * num_data)
train_idxs, eval_idxs = shuffle_idxs[:num_train], shuffle_idxs[num_train:]

training_data = {key: arr[train_idxs] for key, arr in hct_data.items()}
eval_data = {key: arr[eval_idxs] for key, arr in hct_data.items()}

# Each training batch holds half of the training examples.
train_data_manager = utils.BatchManager(
    next(data_prng), training_data, batch_size=int(0.5*num_train))

# The eval manager is given the size of the whole eval split.
eval_data_manager = utils.BatchManager(
    next(data_prng), eval_data, len(eval_data['images']))

# Example batch, used below when creating the train state.
sample_batch = train_data_manager.next_batch()

## Setup Model and Trainstate

In [None]:
# Separate seeded key stream for model initialisation.
model_prng = hk.PRNGSequence(jax.random.PRNGKey(123456))


def loss(u_true, u_pred):
  """Sum of squared differences between true and predicted actions."""
  return jnp.sum(jnp.square(u_true - u_pred))


# Size of the last axis of the chunked actions.
action_dim = hct_data['actions'].shape[-1]

model = ndp.NDP(
    action_dim=action_dim,
    num_actions=num_actions,
    loss_fnc=loss,
    activation=nn.relu,
    zs_dim=32,
    zs_width=64,
    zo_dim=16,
    num_basis_fncs=5,
)

In [None]:
# Optimiser hyper-parameters; weight decay is set to zero.
learning_rate = 1e-2
weight_decay = 0.
# Build the train state from the model, a fresh key from `model_prng`, the
# hyper-parameters above and the example batch's images and hf_obs.
# NOTE(review): the example inputs are presumably used only for shape
# inference during initialisation — confirm in ndp.create_ndp_train_state.
train_state = ndp.create_ndp_train_state(
    model, next(model_prng), learning_rate, weight_decay,
    sample_batch['images'], sample_batch['hf_obs']
)

In [None]:
# Display the result of utils.param_count on the freshly created parameters.
utils.param_count(train_state.params)

In [None]:
# Replicate model across devices
# NOTE(review): this rebinds `train_state` to its replicated form, so re-running
# the cell presumably replicates an already-replicated state — re-create the
# train state first if this cell needs to be run again.
train_state = flax_utils.replicate(train_state)

## Setup Eval

In [None]:
def eval(eval_batch_manager: utils.BatchManager,
         ts: utils.TrainStateBN,
         key: chex.PRNGKey,
         num_devices: int) -> float:
  """Average the loss over all batches of the eval batch manager.

  Note that this function shadows the builtin `eval`; the name is kept because
  the training loop calls it.

  Args:
    eval_batch_manager: Source of eval batches; `num_batches` of them are drawn
      with `next_pmapped_batch`.
    ts: Train state passed to `ndp.optimize_ndp` (the replicated state, as
      supplied by the training loop).
    key: Unused. Kept so that existing call sites remain valid.
    num_devices: Number of devices each batch is split across.

  Returns:
    The mean, over eval batches, of the first entry of the per-batch loss.
  """
  del key  # Unused by the evaluation below.

  num_eval_batches = eval_batch_manager.num_batches

  eval_loss = 0.
  for _ in range(num_eval_batches):
    eval_batch = eval_batch_manager.next_pmapped_batch(num_devices)
    # Only the loss is used; the train state returned alongside it is
    # discarded, so the caller's `ts` binding is never updated from here.
    # NOTE(review): this reuses the training step to obtain the loss — confirm
    # that `ndp.optimize_ndp` leaves the `ts` passed in usable afterwards.
    batch_loss, _ = ndp.optimize_ndp(
        ts, eval_batch['images'], eval_batch['hf_obs'], eval_batch['actions'])
    # Keep the first entry of the returned loss, as the training loop does.
    eval_loss += batch_loss[0]

  return eval_loss / num_eval_batches


## Create Save dirs

In [None]:
# Timestamp identifying this run, formatted as YYYY-MM-DD-HH:MM:SS.
timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
# Use os.path.join so the run directory lands inside /tmp; plain string
# concatenation would yield '/tmp<timestamp>', i.e. a new directory in the
# filesystem root.
exp_dir = os.path.join('/tmp', timestamp)
# Checkpoints for this model go in the 'ndp' sub-directory of the run.
chk_subdir = 'ndp'
chk_dir = os.path.join(exp_dir, chk_subdir)

## Do Optimization

In [None]:
num_train_steps = 300
# Cadences, in training steps, for logging, evaluation and checkpointing.
log_every, eval_every, save_every = 5, 10, 50

for idx in tqdm.trange(num_train_steps):
  # One optimisation step on the next training batch.
  train_batch = train_data_manager.next_pmapped_batch(num_devices)
  batch_loss, train_state = ndp.optimize_ndp(
      train_state,
      train_batch['images'],
      train_batch['hf_obs'],
      train_batch['actions'])

  if not idx % log_every:
    print('idx: ', idx, 'train_loss:', batch_loss[0])

  if not idx % eval_every:
    eval_loss = eval(eval_data_manager, train_state, next(data_prng), num_devices)
    print('idx: ', idx, 'eval_loss:', eval_loss)

  # Checkpoints are labelled with the number of completed steps (1-based).
  steps_done = idx + 1
  if not steps_done % save_every:
    utils.save_model(chk_dir, steps_done, save_every,
                     flax_utils.unreplicate(train_state))

## Test eval on full episode

In [None]:
# Skip if proceeding directly from training
# Otherwise: restore from `chk_dir`, passing the unreplicated current train
# state, and replicate the restored state across devices again.
train_state = utils.restore_model(chk_dir, flax_utils.unreplicate(train_state))
train_state = flax_utils.replicate(train_state)

In [None]:
# Variables for `model.apply`, taken from a single unreplicated copy of the
# train state.
host_state = flax_utils.unreplicate(train_state)
model_params = {
    'params': host_state.params,
    'batch_stats': host_state.batch_stats,
}

In [None]:
# Generate some predictions
# Run the model's `compute_augmented_flow` method over every chunk of the
# episode (train and eval examples alike, in episode order). It returns the
# predicted actions together with the losses.
actions_pred, losses = model.apply(model_params,
                                   hct_data['images'],
                                   hct_data['hf_obs'],
                                   hct_data['actions'],
                                   method=model.compute_augmented_flow)

In [None]:
# Unnormalize the predictions
# Map `utils.unnormalize` over the leading axis; the two 'actions' statistics
# are shared across that axis.
unnormalize_chunks = jax.vmap(utils.unnormalize, in_axes=(0, None, None))
action_stats = (norm_stats[0]['actions'], norm_stats[1]['actions'])

u_true = unnormalize_chunks(hct_data['actions'], *action_stats)
u_pred = unnormalize_chunks(actions_pred, *action_stats)

In [None]:
# Plot comparison

# Half-open range [start, stop) of chunk indices to draw.
eval_range = [0, 20]

fig, axs = plt.subplots(7, 2, figsize=(30, 30))
for i in range(action_dim):
  ax = axs[np.unravel_index(i, (7, 2))]
  for t in range(*eval_range):
    steps = t * num_actions + np.arange(num_actions)

    # Ground truth: dashed line, colour chosen by matplotlib.
    true_line = ax.plot(steps, u_true[t, :, i], '--', linewidth=2)[0]

    # Prediction: solid line in the same colour as its ground truth.
    _ = ax.plot(steps, u_pred[t, :, i], '-', linewidth=1.5,
                color=true_line.get_color())