# Visualization

## Import

In [1]:
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax
import pandas as pd

from common.cell import to_rgb, to_rgba, make_ellipse_mask
from common.pool import Pool
from common.nca import NCA
from common.vae import vae_dict
from common.utils import load_face, jnp2pil, visualize_nca, load_params

from tqdm import tqdm
from omegaconf import OmegaConf

## Load experiment

In [2]:
from pathlib import Path


# Directory of the training run to visualize, and the Hydra config stored with it.
# NOTE(review): absolute path is machine-specific; adjust before running elsewhere.
run_path = Path("/project/output/face/2023-08-18_095035_765053")
hydra_config_path = run_path.joinpath(".hydra", "config.yaml")
config = OmegaConf.load(hydra_config_path)

## Main

In [3]:
# Override a few experiment settings on top of the config loaded from the run.
# NOTE(review): presumably these must match what the checkpoint was trained with — confirm.
exp_overrides = {
    "dataset_size": 1000,
    "n_perceive_free": 6,
    "update_size": 128,
}
for key, value in exp_overrides.items():
    setattr(config.exp, key, value)

In [4]:
# Init a random key
random_key = jax.random.PRNGKey(config.seed)

# Load VAE: the NCA run points at the VAE run directory, which holds its own Hydra config
vae_dir = Path(config.exp.vae_dir)
vae_config = OmegaConf.load(vae_dir / ".hydra" / "config.yaml")

# Load list_attr_celeba.txt file into a pandas DataFrame.
# The separator is a raw string: "\s" is not a valid Python escape sequence, so the
# non-raw form triggers an invalid-escape warning on recent Python versions.
df_attr_celeba = pd.read_csv(vae_config.exp.attr_dir, sep=r"\s+", skiprows=1)
df_attr_celeba = df_attr_celeba.replace(to_replace=-1, value=0)  # replace -1 by 0

# Load list_landmarks_align_celeba.txt file into a pandas DataFrame
df_landmarks_align_celeba = pd.read_csv(vae_config.exp.landmarks_dir, sep=r"\s+", skiprows=1)

# Crop images from (218, 178) to (178, 178): the crop removes the same margin at the
# top and at the bottom, so every y landmark shifts up by half the height difference.
crop_offset_y = (218 - 178) / 2
for landmark in ("lefteye", "righteye", "nose", "leftmouth", "rightmouth"):
    column = f"{landmark}_y"
    df_landmarks_align_celeba[column] = df_landmarks_align_celeba[column] - crop_offset_y

# Resize images from (178, 178) to face_shape (all landmark columns are rescaled
# with the factor derived from face_shape[0])
df_landmarks_align_celeba /= 178/vae_config.exp.face_shape[0]

# Dataset
height, width = vae_config.exp.face_shape[:2]
if config.exp.dataset_size == -1:
    # -1 means "use every face listed in the landmarks file"
    dataset_size = df_landmarks_align_celeba.shape[0]
else:
    dataset_size = config.exp.dataset_size

# One channel for grayscale faces, three for RGB
n_channels = 1 if vae_config.exp.grayscale else 3
dataset_phenotypes_target = np.zeros((dataset_size, *vae_config.exp.face_shape, n_channels))

mask = np.zeros((dataset_size, height, width, 1,))
# Bound the iteration by slicing the frame instead of breaking out of the loop:
# the progress bar then reaches its total and an empty dataset is never indexed.
rows = df_landmarks_align_celeba.iloc[:dataset_size].iterrows()
for i, (index, row) in enumerate(tqdm(rows, total=dataset_size)):
    # `index` is the row label of the landmarks table; it is appended to dataset_dir
    # by plain string concatenation, so dataset_dir must end with a path separator.
    dataset_phenotypes_target[i] = load_face(vae_config.exp.dataset_dir + index, vae_config.exp.face_shape, vae_config.exp.grayscale)
    # Ellipse mask centered on the midpoint between the two eyes
    center = (row["lefteye_x"] + row["righteye_x"]) / 2, (row["lefteye_y"] + row["righteye_y"]) / 2
    mask[i, ..., 0] = make_ellipse_mask(center, width, height, 0.7*width/2, 0.9*height/2)
# Zero out everything outside the ellipse
dataset_phenotypes_target = dataset_phenotypes_target * mask

# VAE: build the architecture selected in the VAE run's config, initialise it to get
# the parameter structure, then overwrite the values with the saved checkpoint.
vae_cls = vae_dict[vae_config.exp.vae_index]
vae = vae_cls(img_shape=dataset_phenotypes_target[0].shape, latent_size=vae_config.exp.latent_size)
random_key, random_subkey_1, random_subkey_2 = jax.random.split(random_key, 3)
vae_params = vae.init(random_subkey_1, random_subkey_2, dataset_phenotypes_target[0])
vae_params = load_params(vae_params, vae_dir / "vae.pickle")
param_count = sum(leaf.size for leaf in jax.tree_util.tree_leaves(vae_params))
print("Number of parameters in VAE: ", param_count)

# Cell states: phenotype channels (1 if grayscale, 3 otherwise) + 1 alpha channel
# + hidden channels
phenotype_size = 1 if vae_config.exp.grayscale else 3
cell_state_size = phenotype_size + 1 + config.exp.hidden_size

@jax.jit
def phenotype_to_genotype(random_key, phenotype_target):
    """Encode one target image with the loaded VAE.

    Args:
        random_key: JAX PRNG key forwarded to the VAE encoder.
        phenotype_target: target image to encode.

    Returns:
        The first of the three values returned by ``vae.encode``
        (presumably the sampled latent vector — confirm against common.vae).
    """
    encoded = vae.apply(vae_params, random_key, phenotype_target, method=vae.encode)
    latent, _, _ = encoded
    return latent

@jax.jit
def init_cell_state():
    """Return the state vector of the seed cell.

    The phenotype channels are 0.0 and every remaining channel
    (alpha and hidden) is 1.0.
    """
    phenotype_channels = jnp.zeros((phenotype_size,))
    alpha_and_hidden_channels = jnp.ones((1+config.exp.hidden_size,))
    return jnp.concatenate([phenotype_channels, alpha_and_hidden_channels])

@jax.jit
def init_cells_state(_):
    """Return an all-zero (height, width, cell_state_size) grid with one seed cell.

    The seed cell (see ``init_cell_state``) is written at the grid center.
    The argument is ignored; it exists so the function can be mapped over an array.
    """
    center = (height//2, width//2)
    empty_grid = jnp.zeros((height, width, cell_state_size,))
    return empty_grid.at[center].set(init_cell_state())

def phenotype_to_genotype_scan(carry, x):
    """``jax.lax.scan`` body that encodes one (key, target image) pair.

    The carry is unused: an empty tuple is returned as the new carry and the
    encoded genotype is emitted as the per-step output.
    """
    sample_key, sample_target = x
    return (), phenotype_to_genotype(sample_key, sample_target)

# Encode every target image into its genotype, one PRNG key per sample.
# The last key of the split becomes the new running key.
n_samples = dataset_phenotypes_target.shape[0]
random_keys = jax.random.split(random_key, 1+n_samples)
random_key, random_keys = random_keys[-1], random_keys[:-1]
_, dataset_genotypes_target = jax.lax.scan(
    phenotype_to_genotype_scan,
    (),
    (random_keys, dataset_phenotypes_target),
    length=n_samples)

# Trainset - Testset phenotypes (mask appended as an extra last channel, 90/10 split)
dataset_phenotypes_target = np.concatenate([dataset_phenotypes_target, mask], axis=-1)
n_train = int(0.9 * len(dataset_phenotypes_target))
trainset_phenotypes_target = dataset_phenotypes_target[:n_train]
testset_phenotypes_target = dataset_phenotypes_target[n_train:]

# Trainset - Testset genotypes (same split index, so rows stay aligned with the phenotypes)
trainset_genotypes_target = dataset_genotypes_target[:n_train]
testset_genotypes_target = dataset_genotypes_target[n_train:]

# Pool
# Draw the initial targets with a dedicated subkey: the running key is split again
# further down, and a JAX key must not be both consumed and split.
random_key, random_subkey = jax.random.split(random_key)
phenotypes_target_idx_init = jax.random.choice(random_subkey, trainset_phenotypes_target.shape[0], shape=(config.exp.pool_size,), replace=True)
cells_states_init = jax.vmap(init_cells_state)(phenotypes_target_idx_init)
genotypes_target_init = jnp.take(trainset_genotypes_target, phenotypes_target_idx_init, axis=0)
pool = Pool(cells_states=cells_states_init, phenotypes_target_idx=phenotypes_target_idx_init)

# NCA: initialise the model to obtain the parameter structure
nca = NCA(
    cell_state_size=cell_state_size,
    n_perceive_free=config.exp.n_perceive_free,
    update_size=config.exp.update_size,
    fire_rate=config.exp.fire_rate,
)
random_key, random_subkey_1, random_subkey_2 = jax.random.split(random_key, 3)
params = nca.init(random_subkey_1, random_subkey_2, cells_states_init[0], genotypes_target_init[0])
params = nca.set_kernel(params)
param_count = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print("Number of parameters in NCA: ", param_count)

# Train state: learning rate goes linearly from its configured value to a tenth
# of it over 2000 steps
base_learning_rate = config.exp.learning_rate
lr_sched = optax.linear_schedule(init_value=base_learning_rate, end_value=0.1*base_learning_rate, transition_steps=2000)

def zero_grads():
    """Build an optax gradient transformation that replaces every update with zeros.

    Parameters routed to this transformation are effectively frozen: whatever
    gradients come in, the emitted updates are zero arrays with the same
    structure, shapes and dtypes.

    Returns:
        optax.GradientTransformation: a stateless transformation (its state is ``()``).
    """
    def init_fn(_):
        # No optimizer state is needed.
        return ()

    def update_fn(updates, state, params=None):
        # `state` and `params` are only accepted to match optax's update signature.
        # Fixed: the original called `jax.jax.tree_util.tree_map`, which raises
        # AttributeError because the `jax` module has no `jax` attribute.
        return jax.tree_util.tree_map(jnp.zeros_like, updates), ()
    return optax.GradientTransformation(init_fn, update_fn)

# Adam with global-norm gradient clipping for the trainable parameters
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(learning_rate=lr_sched))

# Parameters labelled True by the NCA's perceive mask get zeroed updates,
# the ones labelled False are optimized.
transforms_by_label = {False: optimizer, True: zero_grads()}
perceive_mask = nca.get_perceive_mask(params)
tx = optax.multi_transform(transforms_by_label, perceive_mask)

train_state = TrainState.create(apply_fn=nca.apply, params=params, tx=tx)

# Train
@jax.jit
def loss_f(cell_states, phenotype):
    """Mean squared error between the RGBA view of the cells and the target.

    The mean is taken over the last three axes, so any leading axes are kept.
    """
    squared_error = jnp.square(to_rgba(cell_states) - phenotype)
    return jnp.mean(squared_error, axis=(-1, -2, -3))

# Loss history (not filled anywhere in the code shown here)
loss_log = []

@jax.jit
def scan_apply(carry, random_key):
    """``jax.lax.scan`` body performing one NCA update step.

    Args:
        carry: tuple ``(params, cells_states, genotype_target)``.
        random_key: PRNG key for this step.

    Returns:
        The carry with ``cells_states`` replaced by the updated states, and an
        empty tuple as per-step output (no intermediate states are collected).
    """
    params, cells_states, genotype_target = carry
    # `train_state` is the notebook-level variable; only its `apply_fn` is used here.
    next_cells_states = train_state.apply_fn(params, random_key, cells_states, genotype_target)
    next_carry = (params, next_cells_states, genotype_target,)
    return next_carry, ()

@partial(jax.jit, static_argnames=("n_iterations",))
def train_step(random_key, train_state, cells_states, genotype_target, phenotypes_target, n_iterations):
    """Run ``n_iterations`` NCA steps and apply one gradient update.

    Args:
        random_key: PRNG key, split into one key per NCA step.
        train_state: train state holding the parameters and the optimizer.
        cells_states: cell states the rollout starts from.
        genotype_target: genotype(s) passed to the NCA at every step.
        phenotypes_target: target(s) the final cell states are compared to.
        n_iterations: number of NCA steps (static under jit).

    Returns:
        Tuple ``(updated train_state, loss, final cell states)``.
    """
    # The per-step keys do not depend on the parameters, so they are built once
    # outside the differentiated function.
    step_keys = jax.random.split(random_key, n_iterations)

    def loss_fn(params):
        initial_carry = (params, cells_states, genotype_target,)
        final_carry, _ = jax.lax.scan(scan_apply, initial_carry, step_keys, length=n_iterations)
        final_cells_states = final_carry[1]
        return loss_f(final_cells_states, phenotypes_target).mean(), final_cells_states

    value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, cells_states_), grads = value_and_grad_fn(train_state.params)
    updated_train_state = train_state.apply_gradients(grads=grads)

    return updated_train_state, loss, cells_states_

100%|█████████▉| 999/1000 [00:02<00:00, 334.69it/s]


Number of parameters in VAE:  34510979
Number of parameters in NCA:  100936


## Load NCA

In [11]:
# Training iteration of the checkpoint to load
i = 16700
checkpoint_path = run_path / "nca_{:07d}.pickle".format(i)
params = load_params(params, checkpoint_path)

# Rebuild the train state around the loaded parameters
train_state = TrainState.create(apply_fn=nca.apply, params=params, tx=tx)

## Visualize

### Testset

In [12]:
@jax.jit
def scan_apply(carry, random_key):
    """``jax.lax.scan`` body performing one NCA step and emitting the new states.

    Redefines the notebook's earlier ``scan_apply``: the carry handling is the
    same, but the updated cell states are also returned as per-step output so
    the whole rollout can be stacked for visualization.
    """
    params, cells_states, genotype_target = carry
    next_cells_states = train_state.apply_fn(params, random_key, cells_states, genotype_target)
    next_carry = (params, next_cells_states, genotype_target,)
    return next_carry, (next_cells_states,)

In [25]:
n_iterations = 500
phenotype_target_idx = 35

# Start from a single seed cell and grow towards one held-out (test-set) face
cells_state = init_cells_state(None)
phenotype_target = testset_phenotypes_target[phenotype_target_idx]
genotype_target = testset_genotypes_target[phenotype_target_idx]

# One PRNG key per NCA step; the scan stacks the states after every step in cells_states_
random_keys = jax.random.split(random_key, n_iterations)
initial_carry = (params, cells_state, genotype_target,)
final_carry, (cells_states_,) = jax.lax.scan(scan_apply, initial_carry, random_keys, length=n_iterations)
params, cells_state_, _ = final_carry

In [26]:
# Prepend the initial state so the sequence holds n_iterations+1 frames
cells_states = jnp.concatenate([cells_state[None, ...], cells_states_], axis=0)

# Join each rollout frame with the (repeated) target along axis 2 for comparison
rollout_rgba = to_rgba(cells_states)
target_rgba = jnp.repeat(to_rgba(phenotype_target)[None, ...], n_iterations+1, axis=0)
imgs = jnp.concatenate([rollout_rgba, target_rgba], axis=2)

In [27]:
# Convert every frame to a PIL image and write them as a looping GIF in the run
# directory, named after the target index (100 ms per frame).
# NOTE(review): `imgs` is rebound from an array to a list of PIL images, so this
# cell cannot be re-run without first re-running the cell that builds `imgs`.
imgs = [jnp2pil(to_rgb(img)) for img in imgs]
gif_path = run_path / "{:06d}.gif".format(phenotype_target_idx)
imgs[0].save(gif_path, save_all=True, append_images=imgs[1:], duration=100, loop=0)

### Trainset

In [39]:
n_iterations = 500
phenotype_target_idx = 36

# Start from a single seed cell and grow towards one face of the training set.
# Fixed: this cell sits under the "Trainset" header but was indexing the test
# split (copy-paste from the Testset cell); it now reads from the train split.
cells_state = init_cells_state(None)
phenotype_target = trainset_phenotypes_target[phenotype_target_idx]
genotype_target = trainset_genotypes_target[phenotype_target_idx]

# One PRNG key per NCA step; the scan stacks the states after every step in cells_states_
random_keys = jax.random.split(random_key, n_iterations)
(params, cells_state_, _,), (cells_states_,) = jax.lax.scan(
    scan_apply,
    (params, cells_state, genotype_target,),
    random_keys,
    length=n_iterations,)