# Download and run a LatentModulatedSIREN model pre-trained on CelebA-HQ-64
This demo shows how to load the pretrained weights of a LatentModulatedSIREN model from the paper [From data to functa: Your data point is a function and you can treat it like one](https://arxiv.org/abs/2201.12204) (Dupont, Kim, Eslami, Rezende, Rosenbaum, 2022). It uses code from the official [JAX](https://github.com/google/jax) + [Haiku](https://github.com/deepmind/dm-haiku) implementation.


It's recommended to use `Runtime->Change Runtime Type` to pick a GPU for speed.

In [None]:
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Install the third-party packages this notebook imports below.
!pip install chex
!pip install dm-haiku
!pip install dill
!pip install matplotlib
!pip install optax
# Fetch the functa source tree into ./functa (relative to the current dir).
!git clone https://github.com/deepmind/functa/
import dill
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import os
# Move into the cloned checkout before importing its modules on the next line.
# NOTE(review): the clone and the chdir are both relative to the current
# working directory, so this cell is presumably meant to be run once per
# session -- confirm before re-running it.
os.chdir('functa')
import function_reps, pytree_conversions, helpers

# Load pretrained weights 

In [None]:
# Load params of LatentModulatedSiren model
mod_dim = 512  # choose one of 64, 128, 256, 512, 1024
# Download pretrained weights
os.environ['MOD_DIM'] = str(mod_dim)
!wget https://storage.googleapis.com/dm-functa/celeba_params_${MOD_DIM}_latents.npz
# Load pretrained weights
path = f'celeba_params_{mod_dim}_latents.npz'
with open(path, 'rb') as f:
  ckpt = dill.load(f)
params = ckpt['params']
config = ckpt['config']
assert config['model']['type'] == 'latent_modulated_siren'
print(f'Loaded params for model with {mod_dim} latent dimensions.')
# Create haiku transformed model that runs the forward pass.
# Only keep configs needed for model construction from model config
# `None` below ensures no error is given when already removed
model_config = config['model'].copy()
model_config.pop('type', None)
model_config.pop('l2_weight', None)
model_config.pop('noise_std', None)


def model_net(coords):
  """Builds a LatentModulatedSiren and evaluates it at `coords`.

  Intended to be wrapped with `hk.transform`. The number of output channels
  comes from the loaded dataset config; all remaining constructor arguments
  come from `model_config`.

  Args:
    coords: Coordinates passed straight through to the model.

  Returns:
    The model output for `coords`.
  """
  num_channels = config['dataset']['num_channels']
  siren = function_reps.LatentModulatedSiren(
      out_channels=num_channels, **model_config)
  return siren(coords)


# Turn `model_net` into a Haiku transformed function. `without_apply_rng`
# lets `model.apply` be called with just (params, coords), as done in
# `render_image`.
model = hk.without_apply_rng(hk.transform(model_net))

# Define function that renders image from a single modulation
# Split the checkpoint params into two parts; presumably the shared network
# weights and the modulation parameters -- confirm in `function_reps`.
weights, init_modulation = function_reps.partition_params(params)
# Convert the modulation pytree to an array. `concat_idx` and `tree_def` are
# kept because `array_to_pytree` needs them to rebuild the pytree later.
init_modulation, concat_idx, tree_def = pytree_conversions.pytree_to_array(
    init_modulation)


def render_image(modulation, coords):
  """Evaluates the model at `coords` using the given modulation array.

  The modulation array is converted back to a pytree (using the module-level
  `concat_idx` and `tree_def`), merged with the shared `weights`, and the
  merged params are passed to `model.apply`.

  Args:
    modulation: Modulation in array form.
    coords: Coordinates at which to evaluate the model.

  Returns:
    The output of `model.apply` for the merged params and `coords`.
  """
  mod_pytree = pytree_conversions.array_to_pytree(
      modulation, concat_idx, tree_def)
  full_params = function_reps.merge_params(weights, mod_pytree)
  return model.apply(full_params, coords)


# Use jit and vmap to render faster on a batch of modulations.
# This rebinds the name `render_image`: with vmap's default axes, both
# arguments are mapped over their leading axis, so callers must now pass a
# batch of modulations and a matching batch of coords.
render_image = jax.jit(jax.vmap(render_image))

# Load modulation dataset and grab a batch of modulations

In [None]:
# Download and load a batch of modulations.
# Ensure that the modulation dataset has been downloaded to the correct dir.
!wget https://storage.googleapis.com/dm-functa/celeba_modulations_${MOD_DIM}_latents.npz
path = f'celeba_modulations_{mod_dim}_latents.npz'
with open(path, 'rb') as f:
  data = dill.load(f)
  train_dict = data['train']
  test_dict = data['test']
bs = 9
test_mods = test_dict['modulation'][:bs]
assert test_mods.shape == (bs, mod_dim)

# Reconstruct batch of modulations and visualize reconstructions

In [None]:
# Build the coordinate grid at the dataset resolution and replicate it once
# per modulation, because the vmapped `render_image` maps over both arguments.
coords = function_reps.get_coordinate_grid(config['dataset']['resolution'])
coords = jnp.stack([coords] * bs)  # (bs, H, W, 2)
# Decode the test modulations back into images.
rec = render_image(test_mods, coords)  # (bs, H, W, 3)

# Lay the reconstructions out as a single grid image and display it with the
# axes hidden.
im_batch = helpers.image_grid_from_batch(rec)
gridsize = int(np.floor(np.sqrt(bs)))
figsize = 4
fig, ax = plt.subplots(figsize=(gridsize * figsize,) * 2)
ax.imshow(im_batch)
ax.set_axis_off()
plt.show()