# License
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at:

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Instructions

---
To start an experiment:
---
1. Choose the agent type by setting the `AGENT_CLASS` variable in the configuration below.
Select `Vit` to choose the single-path agent, or `MultiVit` to choose the multipath agent.
These agents use a ViT-Large root model.
Select the agent types suffixed with `T3` to use a ViT-Tiny root model capped to 3 layers, or those suffixed with `B` to use a ViT-Base root model.

1. Set `TASK_NAME` to the string id of the task assigned to this instance of the selected agent.
Refer to `TFDS_IMAGE_CLASSIFCATON_TASKS` and `VTAB_TASKS` (below) for the lists of task ids that have been tested with the current code.
These lists contain task ids from the [TensorFlow Datasets Catalog](https://www.tensorflow.org/datasets/catalog/overview).
Note that some tasks require a manual download; refer to the corresponding catalog page for instructions. **WARNING**: The system state needs to be populated with at least one root model before an agent can be trained on any task. To generate the root model, set `TASK_NAME` to `"root_model/checkpoint"` to load a pretrained root model, or to `"root_model/random_init"` to generate a randomly initialized one.

1. Set `NUM_PATHS_SAMPLED_MAX` to the desired number of paths to sample. Additional configuration parameters can be modified in the agent definitions code below; their defaults match the settings described in the publication.

1. Set `SYSTEM_STATE_RELATIVE_DIR` to a relative path from where the system state will be read and written.

1. By default, the system state is stored in a temporary folder on the Virtual Machine (VM). This temporary folder is deleted when the VM is stopped or restarted.
It is possible to store the system state folder on your Google Drive by enabling the
`SAVE_SYSTEM_STATE_ON_GOOGLE_DRIVE` option. In this case, you will be prompted to approve access, and the system state will be saved in a folder named `"munet_system_state"` (the value of `SYSTEM_STATE_RELATIVE_DIR`) under your Google Drive root folder. It is also possible to store the system state in a Google Drive folder shared by multiple users: add a shortcut to the shared folder in your Google Drive, and then set `GDRIVE_ROOT_DIR` (below) to the path of that shortcut, including a trailing `/` since the two paths are concatenated.

1. To start the experiment, select "Connect to a hosted runtime" from the dropdown menu on the top right, and then select "Run all" from the "Runtime" menu. A free hosted CPU runtime is always available. Free access to GPU and TPU accelerators is occasionally provided by the service, depending on availability.
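As an illustration of the steps above, a first run populates the system state with the root model and later runs train on tasks. The values below are only an example (`cifar100` is taken from `TFDS_IMAGE_CLASSIFCATON_TASKS`), and the `runs` list is a hypothetical way of writing down successive settings of the "Agent parameters" cell. The root model is created with the same `AGENT_CLASS` because, by default, agents only load the paths of agents of their own class (see `Agent.agents_to_load`).

```python
# Hypothetical run plan: each dict is one setting of the "Agent parameters" cell,
# followed by "Run all".
AGENT_CLASS = "VitT3"

runs = [
  # Run 1: populate the system state with the pretrained root model.
  {"AGENT_CLASS": AGENT_CLASS, "TASK_NAME": "root_model/checkpoint"},
  # Run 2 (and later, possibly in parallel Colabs): train on tasks.
  {"AGENT_CLASS": AGENT_CLASS, "TASK_NAME": "cifar100", "NUM_PATHS_SAMPLED_MAX": 16},
]

for i, run in enumerate(runs, start=1):
  print(f"Run {i}: {run}")
```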

---
During the experiment execution:
---

1. The print output is displayed after the last cell of this Colab.

1. The system state folder is populated with a subfolder for each agent.
Each agent folder is named after the `AGENT_CLASS` string followed by the `TASK_NAME`, separated by `~` (any `/` in the task name is also replaced by `~`).
Each agent folder is populated with incremental state subfolders containing the sharded state of the architectures and parameters generated by the agent during the corresponding evolutionary cycle.

1. Agents can be started asynchronously, and a varying number of them can run in parallel.
An interrupted agent training can be resumed by restarting the execution with the same configuration.
A completed training can be continued by increasing `NUM_PATHS_SAMPLED_MAX`.

1. To achieve a multi-agent execution, run multiple Colabs in parallel, each set to the same configuration but a different `TASK_NAME`.

1. To achieve heterogeneous hardware execution, parallel Colab notebooks can be connected to runtimes of different types.
It is possible to switch between CPU, GPU and TPU runtimes by selecting `Change runtime type` in the `Resources` tab of this Colab notebook.
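The folder naming rule described above follows `format_agent_id()` and `Agent.__init__` (defined further below), which join `system_state_dir` with the agent id. The standalone restatement here (the name `agent_folder_name` is ours, not the notebook's) shows where a given agent reads and writes its state.

```python
import os

def agent_folder_name(agent_class: str, task_name: str) -> str:
  """Restates format_agent_id(): "<AGENT_CLASS>/<TASK_NAME>" with "/" -> "~"."""
  agent_id = f"{agent_class}/{task_name}"
  # "~" is reserved as the separator, so it may not appear in the inputs.
  assert "~" not in agent_id, f"Invalid agent id: {agent_id}"
  return agent_id.replace("/", "~")

system_state_dir = "/tmp/munet_system_state/"
print(os.path.join(system_state_dir, agent_folder_name("VitT3", "cifar100/1k")))
# -> /tmp/munet_system_state/VitT3~cifar100~1k
```

Because `~` never occurs in a class name, `get_agent_class()` can recover the agent class by splitting the folder name on the first `~`.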

In [None]:
# @title Agent parameters
AGENT_CLASS = "VitT3" # @param ["VitT3", "VitB", "Vit", "MultiVitT3", "MultiVitB", "MultiVit"] { type: "string", isTemplate: true }
# Set TASK_NAME to "root_model/checkpoint" or "root_model/random_init" to initialize the population.
TASK_NAME = "root_model/checkpoint"  # @param { type: "string", isTemplate: true }
NUM_PATHS_SAMPLED_MAX = 16 # @param { type: "integer", isTemplate: true }

In [None]:
# Save the system state on Google Drive instead of in a temporary VM folder.
SYSTEM_STATE_RELATIVE_DIR = "munet_system_state/"  # @param { type: "string", isTemplate: true }
SAVE_SYSTEM_STATE_ON_GOOGLE_DRIVE = False  # @param { type: "boolean", isTemplate: true }
if SAVE_SYSTEM_STATE_ON_GOOGLE_DRIVE:
  from google.colab import drive
  drive.mount('/content/gdrive')
  GDRIVE_ROOT_DIR = "/content/gdrive/My Drive/"
  SYSTEM_STATE_DIR = GDRIVE_ROOT_DIR + SYSTEM_STATE_RELATIVE_DIR
  print("Saving system state in Google Drive.")
else:
  SYSTEM_STATE_DIR = "/tmp/" + SYSTEM_STATE_RELATIVE_DIR
  print("WARNING: Saving system state in VM, state will be lost after reboot!")

In [None]:
# Test the immutability of published paths at the beginning of each cycle.
# The tolerance may need to be increased if the system is run in a different context, e.g. different hardware, input preprocessing, library or dataset versions.
TEST_IMMUTABILITY = False
IMMUTABILITY_RELATIVE_TOLLERANCE = 0.001  # 0.1%
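`IMMUTABILITY_RELATIVE_TOLLERANCE` is consumed by the immutability test (`run_test_eval(..., test_immutability=True)`), which is not part of this section, so the exact comparison below is an assumption. A relative tolerance is typically applied like this: a re-evaluated metric passes if it deviates from the published value by at most the tolerance times that value.

```python
def within_relative_tolerance(published: float, measured: float,
                              rel_tol: float = 0.001) -> bool:
  """Hypothetical check: |measured - published| <= rel_tol * |published|."""
  return abs(measured - published) <= rel_tol * abs(published)

# With rel_tol = 0.001 (0.1%), a published accuracy of 0.9 tolerates +/-0.0009.
print(within_relative_tolerance(0.9, 0.9005))  # True
print(within_relative_tolerance(0.9, 0.902))   # False
```

Python's `math.isclose(a, b, rel_tol=...)` is a symmetric variant that scales the tolerance by the larger of the two magnitudes.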

# Imports

In [None]:
!pip install --upgrade -q pip jax jaxlib
!pip install --upgrade -q git+https://github.com/google/flax.git
!pip install -q ml_collections
!pip install -q tensorflow_addons
![ -d task_adaptation ] || git clone --depth=1 https://github.com/google-research/task_adaptation
![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer

import sys
if './task_adaptation' not in sys.path:
  sys.path.append('./task_adaptation')
if './vision_transformer' not in sys.path:
  sys.path.append('./vision_transformer')

try:
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
except Exception:
  pass  # Not a TPU runtime.

In [None]:
import copy
import flax
import flax.linen as nn
import gc
import inspect
import jax
import jax.extend
import jax.numpy as jnp
import json
import math
import numpy as np
import optax
import os
import pandas as pd
import random
import tensorflow as tf
import tensorflow_datasets as tfds
import time
from collections import defaultdict
from functools import partial
from flax.training import checkpoints as flax_checkpoints
from ml_collections import ConfigDict, FrozenConfigDict
from tensorflow.io import gfile
from threading import Thread, Lock
from typing import Any, Type
tf.compat.v1.enable_eager_execution()

In [None]:
# ViT imports.
from vision_transformer.vit_jax import checkpoint as vit_checkpoint
from vision_transformer.vit_jax.configs import models as vit_models_config
from vision_transformer.vit_jax import models_vit as vit_models
# VTAB imports.
import task_adaptation.registry as task_adapt_registry
import task_adaptation.data.caltech
import task_adaptation.data.cifar
import task_adaptation.data.dtd
import task_adaptation.data.oxford_flowers102
import task_adaptation.data.oxford_iiit_pet
import task_adaptation.data.sun397
import task_adaptation.data.svhn
import task_adaptation.data.patch_camelyon
import task_adaptation.data.eurosat
import task_adaptation.data.resisc45
import task_adaptation.data.diabetic_retinopathy
import task_adaptation.data.clevr
import task_adaptation.data.dmlab
import task_adaptation.data.dsprites
import task_adaptation.data.kitti
import task_adaptation.data.smallnorb

# Utils

In [None]:
# Ref. Tfds catalog: https://www.tensorflow.org/datasets/catalog/overview
TFDS_IMAGE_CLASSIFCATON_TASKS = {
    "beans": {"version":"0.1.0"},
    "binary_alpha_digits": {"version":"1.0.0"},
    "caltech_birds2010": {"version":"0.1.1"},
    "caltech_birds2011": {"version":"0.1.1"},
    "cars196": {"version":"2.1.0"},
    "cassava": {"version":"0.1.0"},
    "cats_vs_dogs": {"version":"4.0.0"},
    "cifar10": {"version":"3.0.2"},
    "cifar100": {"version":"3.0.2"},
    "citrus_leaves": {"version":"0.1.2"},
    "cmaterdb/bangla": {"version":"1.0.0"},
    "cmaterdb/devanagari": {"version":"1.0.0"},
    "cmaterdb/telugu": {"version":"1.0.0"},
    "colorectal_histology": {"version":"2.0.0"},
    "controlled_noisy_web_labels/mini_imagenet_red": {"version":"1.0.0"},
    "controlled_noisy_web_labels/mini_imagenet_blue": {"version":"1.0.0"},
    "curated_breast_imaging_ddsm/patches": {"version":"3.0.0"},
    "cycle_gan/apple2orange": {"version":"2.0.0"},
    "cycle_gan/summer2winter_yosemite": {"version":"2.0.0"},
    "cycle_gan/horse2zebra": {"version":"2.0.0"},
    "cycle_gan/monet2photo": {"version":"2.0.0"},
    "cycle_gan/cezanne2photo": {"version":"2.0.0"},
    "cycle_gan/ukiyoe2photo": {"version":"2.0.0"},
    "cycle_gan/vangogh2photo": {"version":"2.0.0"},
    "cycle_gan/maps": {"version":"2.0.0"},
    "cycle_gan/cityscapes": {"version":"2.0.0"},
    "cycle_gan/facades": {"version":"2.0.0"},
    "cycle_gan/iphone2dslr_flower": {"version":"2.0.0"},
    "deep_weeds": {"version":"3.0.0"},
    "domainnet/real": {"version":"1.0.0"},
    "domainnet/painting": {"version":"1.0.0"},
    "domainnet/clipart": {"version":"1.0.0"},
    "domainnet/quickdraw": {"version":"1.0.0"},
    "domainnet/infograph": {"version":"1.0.0"},
    "domainnet/sketch": {"version":"1.0.0"},
    "emnist/digits": {"version":"3.0.0"},
    "emnist/letters": {"version":"3.0.0"},
    "fashion_mnist": {},  # Not yet included in the µNet system checkpoint.
    "food101": {"version":"2.0.0"},
    "horses_or_humans": {"version":"3.0.0"},
    "i_naturalist2017": {"version":"0.1.0"},
    "i_naturalist2018": {"version":"1.0.0"},
    "i_naturalist2021": {},  # Not yet included in the µNet system checkpoint.
    "imagenet2012": {"version":"5.1.0"},
    "imagenet_a": {"version":"0.1.0"},
    "imagenet_lt": {"version":"1.0.0"},
    "imagenet_r": {"version":"0.2.0"},
    "imagenet_sketch": {"version":"1.0.0"},
    "imagenette": {"version":"1.0.0"},
    "imagewang": {"version":"2.0.0"},
    "kmnist": {"version":"3.0.1"},
    "malaria": {"version":"1.0.0"},
    "mnist": {"version":"3.0.1"},
    "omniglot": {"version":"3.0.0"},
    "pet_finder": {"version":"1.0.0"},
    "places365_small": {"version":"2.1.0"},
    "plant_village": {"version":"1.0.2"},
    "plantae_k": {"version":"0.1.0"},
    "quickdraw_bitmap": {"version":"3.0.0"},
    "rock_paper_scissors": {"version":"3.0.0"},
    "siscore/rotation": {"version":"1.0.0"},
    "siscore/size": {"version":"1.0.0"},
    "siscore/location": {"version":"1.0.0"},
    "stanford_dogs": {"version":"0.2.0"},
    "stanford_online_products": {"version":"1.0.0"},
    "stl10": {"version":"1.0.0"},
    "tf_flowers": {"version":"3.0.1"},
    "uc_merced": {"version":"2.0.0"},
    "visual_domain_decathlon/aircraft": {"version":"1.2.0"},
    "visual_domain_decathlon/cifar100": {"version":"1.2.0"},
    "visual_domain_decathlon/daimlerpedcls": {"version":"1.2.0"},
    "visual_domain_decathlon/dtd": {"version":"1.2.0"},
    "visual_domain_decathlon/gtsrb": {"version":"1.2.0"},
    "visual_domain_decathlon/imagenet12": {"version":"1.2.0"},
    "visual_domain_decathlon/omniglot": {"version":"1.2.0"},
    "visual_domain_decathlon/svhn": {"version":"1.2.0"},
    "visual_domain_decathlon/ucf101": {"version":"1.2.0"},
    "visual_domain_decathlon/vgg-flowers": {"version":"1.2.0"},
}

In [None]:
# Append suffix "/1k" to get the 1k version of each task.
VTAB_TASKS = [
    "caltech101",
    # cifar100 and cifar10 are already in TFDS_IMAGE_CLASSIFCATON_TASKS with a slightly different validation split but the same test set, so only their 1k versions are added here.
    "cifar100/1k",
    "cifar10/1k",
    "dtd",
    "oxford_flowers102",
    "oxford_iiit_pet",
    "sun397",
    "svhn_cropped",
    "patch_camelyon",
    "eurosat",
    "resisc45",
    "diabetic_retinopathy_detection/btgraham-300",
    "clevr/count_cylinders",  # Not in results table.
    "clevr/count_all",  # Clevr-Count
    "clevr/closest_object_distance",  # Clevr-Dist
    "dmlab",
    "dsprites/label_x_position",  # dSpr-Loc
    "dsprites/label_orientation",  # dSpr-Ori
    "kitti/closest_object_distance",  # Not in results table.
    "kitti/count_vehicles",  # Not in results table.
    "kitti/closest_vehicle_distance",  # Kitti-dist
    "smallnorb/label_category",  # Not in results table.
    "smallnorb/label_lighting",  # Not in results table.
    "smallnorb/label_azimuth",  # Azim
    "smallnorb/label_elevation",  # Elev
    ]
for tn in VTAB_TASKS:
  assert tn not in TFDS_IMAGE_CLASSIFCATON_TASKS.keys(), tn
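The comment above `VTAB_TASKS` states that appending `/1k` selects the 1k version of a task. A minimal sketch of that string convention follows; `vtab_1k_task_ids` is a hypothetical helper, and it assumes the resulting ids are passed unchanged as `TASK_NAME` (how `Task` parses them is not shown in this section).

```python
def vtab_1k_task_ids(vtab_tasks):
  """Hypothetical helper: append "/1k" to each VTAB task id, without doubling
  the suffix of ids that already carry it (e.g. "cifar100/1k")."""
  return [t if t.endswith("/1k") else t + "/1k" for t in vtab_tasks]

print(vtab_1k_task_ids(["caltech101", "cifar100/1k", "clevr/count_all"]))
# ['caltech101/1k', 'cifar100/1k', 'clevr/count_all/1k']
```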

In [None]:
def compute_flops_hlo(flax_module, *a, **kw):
  m = jax.xla_computation(flax_module)(*a, **kw).as_hlo_module()
  # Compute flops on CPU for cross-platform consistency.
  client = jax.extend.backend.get_backend("cpu")
  analysis = jax.lib.xla_client._xla.hlo_module_cost_analysis(client, m)
  return analysis["flops"]

# Models

In [None]:
def get_sample_images(image_size, batch_size):
  return np.zeros((batch_size, image_size, image_size, 3))

In [None]:
def get_num_params(params):
  return sum(jax.tree_util.tree_flatten(
      jax.tree_util.tree_map(lambda p: np.prod(p.shape), params))[0])

In [None]:
def get_optimizer(
    opt_lr: float,
    opt_lr_schedule: str,
    opt_lr_warmup_ratio: float,
    opt_momentum: float,
    opt_nesterov: bool,
    num_train_batches_between_validations: int,
    num_validations_per_path_training: int,
    ):
  min_lr = opt_lr / 1000.0
  if opt_lr_schedule == "constant":
    # Halve the lr so that its average matches that of the other schedules.
    learning_rate = 0.5 * opt_lr
  elif opt_lr_schedule == "linear":
    train_steps = int(num_train_batches_between_validations * num_validations_per_path_training)
    warmup_steps = int(opt_lr_warmup_ratio * train_steps)
    schedules = [
        optax.linear_schedule(
            init_value=min_lr,
            end_value=opt_lr,
            transition_steps=warmup_steps),
        optax.linear_schedule(
            init_value=opt_lr,
            end_value=min_lr,
            transition_steps=train_steps-warmup_steps)]
    learning_rate = optax.join_schedules(schedules, [warmup_steps])
  elif opt_lr_schedule == "cosine":
    train_steps = int(num_train_batches_between_validations
                      * num_validations_per_path_training)
    learning_rate = optax.warmup_cosine_decay_schedule(
        init_value=min_lr,
        peak_value=opt_lr,
        warmup_steps=int(opt_lr_warmup_ratio * train_steps),
        decay_steps=train_steps)
  elif opt_lr_schedule == "restarts":
    train_steps = num_train_batches_between_validations
    repeats = num_validations_per_path_training
    kwargs = dict(
        init_value=min_lr,
        peak_value=opt_lr,
        warmup_steps=int(opt_lr_warmup_ratio * train_steps),
        decay_steps=train_steps,
    )
    kwargs = [kwargs] * repeats
    learning_rate = optax.sgdr_schedule(kwargs)
  else:
    raise ValueError(f"Invalid lr schedule: {opt_lr_schedule}")

  return optax.chain(
      optax.clip_by_global_norm(1.0),
      optax.sgd(
          learning_rate=learning_rate,
          momentum=opt_momentum,
          nesterov=opt_nesterov,
          accumulator_dtype=jnp.bfloat16))
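The `"constant"` branch halves `opt_lr` so that its average matches the other schedules. A stdlib sketch of the `"linear"` schedule shows why: warming up from `opt_lr / 1000` to `opt_lr` and decaying back averages to about `0.5 * opt_lr`. This is our own re-implementation of the `optax.linear_schedule` plus `optax.join_schedules` composition used above, not the optax code itself.

```python
def linear_warmup_decay_lr(step: int, peak_lr: float, warmup_steps: int,
                           train_steps: int) -> float:
  """Learning rate at `step` for the "linear" schedule of get_optimizer()."""
  min_lr = peak_lr / 1000.0
  if step < warmup_steps:
    # Warmup: linear ramp from min_lr up to peak_lr.
    return min_lr + (peak_lr - min_lr) * step / warmup_steps
  # Decay: linear ramp from peak_lr back down to min_lr, clipped at the end.
  decay_steps = train_steps - warmup_steps
  frac = min((step - warmup_steps) / decay_steps, 1.0)
  return peak_lr + (min_lr - peak_lr) * frac

peak, train_steps = 0.02, 1000
warmup_steps = int(0.02 * train_steps)  # opt_lr_warmup_ratio = 0.02
lrs = [linear_warmup_decay_lr(s, peak, warmup_steps, train_steps)
       for s in range(train_steps)]
print(sum(lrs) / len(lrs) / peak)  # Close to 0.5, the factor used by "constant".
```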

In [None]:
def merge_params(a, b):
  params = a.copy(b)
  assert len(params) == len(a) + len(b)
  return params

## Vit Model

In [None]:
class VitModelFactory():
  @staticmethod
  def get_model(hparams, config):
    return get_vit_model(hparams, config)

  @staticmethod
  def get_init_comps(hparams, config):
    return get_vit_init_comps(hparams, config)

  @staticmethod
  def get_comps2model_fn():
    return vit_comps2model

  @staticmethod
  def get_sample_input(hparams):
    return get_sample_images(image_size=hparams["ds_image_size"], batch_size=1)

In [None]:
def get_vit_filename(query):
  df = vit_checkpoint.get_augreg_df()
  res = df.query(query).filename.unique()
  assert len(res) == 1
  return res[0]

In [None]:
VIT_CONFIG_CACHE = {}

def get_vit_config(query):
  global VIT_CONFIG_CACHE
  if query not in VIT_CONFIG_CACHE:
    filename = get_vit_filename(query)
    config = vit_models_config.AUGREG_CONFIGS[filename.split("-")[0]].copy_and_resolve_references()
    config.unlock()
    # Disable dropout.
    config.transformer.dropout_rate = 0.0
    config.transformer.attention_dropout_rate = 0.0
    config.lock()
    VIT_CONFIG_CACHE[query] = config
  return VIT_CONFIG_CACHE[query].copy_and_resolve_references()

def get_set_vit_config(hparams, config):
  path_config = get_vit_config(config.vit_checkpoint_query)
  path_config.transformer.num_layers = int(hparams["num_layers"])
  path_config.unlock()
  path_config.num_classes = int(hparams["num_classes"])
  if "classifier" in hparams:
    path_config.classifier = hparams["classifier"]
  path_config.lock()
  path_config = FrozenConfigDict(path_config)
  return path_config

def get_max_num_layers(query):
  config = get_vit_config(query)
  return config.transformer.num_layers

In [None]:
# Get params from ViT checkpoints.
def get_vit_checkpoint_comps(image_size, query):
  filename = get_vit_filename(query)
  config = get_vit_config(query)
  model = vit_models.VisionTransformer(**config, num_classes=1)  # num_classes unused.
  init_params = copy.deepcopy(jax.device_get(
      model.init(jax.random.PRNGKey(random.randrange(int(1e10))),
                 VitModelFactory.get_sample_input({"ds_image_size": image_size}),
                 train=False  # Disables dropout, no effect on params.
                 )["params"]))
  params = vit_checkpoint.load_pretrained(
    pretrained_path=f"gs://vit_models/augreg/{filename}.npz",
    init_params=init_params,
    model_config=config)
  return vit_model2comps(params)

In [None]:
# Get ViT model and init_params.
def get_vit_model(hparams, config):
  vit_config = get_set_vit_config(hparams, config)
  return vit_models.VisionTransformer(**vit_config)

def get_vit_init_comps(hparams, config):
  model = get_vit_model(hparams, config)
  init_params = copy.deepcopy(jax.device_get(model.init(
      jax.random.PRNGKey(random.randrange(int(1e10))),
      VitModelFactory.get_sample_input(hparams),
      train=False  # Disables dropout, no effect on params.
      )["params"]))
  return vit_model2comps(init_params)

In [None]:
# ViT parameters mapping to components.
TRANSFORMER_KEYS = set(
    ["encoder_norm", "posembed_input" ] + \
    [f"encoderblock_{k}" for k in range(30)])

def vit_model2comps(params):
  new_params = {}
  for k in params.keys():
    if k == "Transformer":
      t_params = params[k]
      for t_k in t_params.keys():
        new_params[t_k] = t_params[t_k]
    else:
      new_params[k] = params[k]
  return flax.core.freeze(new_params)

def vit_comps2model(params):
  new_params = params.unfreeze()
  new_params["Transformer"] = {}
  for k in list(new_params.keys()):
    if k in TRANSFORMER_KEYS:
      new_params["Transformer"][k] = new_params.pop(k)
  assert len(new_params["Transformer"]) != 0
  return flax.core.freeze(new_params)
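`vit_model2comps` flattens the `Transformer` sub-tree so that the position embeddings, each encoder block and the encoder norm become top-level components, and `vit_comps2model` nests them back. The sketch below applies the same mapping to plain nested dicts (no `flax.core.freeze`, placeholder strings instead of arrays); the function names are ours, chosen to avoid shadowing the notebook's.

```python
# Keys that live under the "Transformer" sub-tree of a ViT parameter tree.
TRANSFORMER_KEYS_SKETCH = ({"encoder_norm", "posembed_input"}
                           | {f"encoderblock_{k}" for k in range(30)})

def flatten_transformer(params: dict) -> dict:
  """Moves the entries of params["Transformer"] to the top level."""
  comps = {}
  for k, v in params.items():
    if k == "Transformer":
      comps.update(v)
    else:
      comps[k] = v
  return comps

def nest_transformer(comps: dict) -> dict:
  """Moves the Transformer components back under "Transformer"."""
  model = {"Transformer": {}}
  for k, v in comps.items():
    if k in TRANSFORMER_KEYS_SKETCH:
      model["Transformer"][k] = v
    else:
      model[k] = v
  assert model["Transformer"], "No Transformer components found."
  return model

# Placeholder values stand in for the parameter arrays.
vit_params = {
    "embedding": "E", "cls": "C", "head": "H",
    "Transformer": {"posembed_input": "P", "encoderblock_0": "B0",
                    "encoderblock_1": "B1", "encoder_norm": "N"},
}
comps = flatten_transformer(vit_params)
print(sorted(comps))
print(nest_transformer(comps) == vit_params)  # True: the round trip restores the tree.
```

Flattening lets each encoder block be shared, cloned or mutated as an independent component, while the nested form is what `VisionTransformer` expects.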

## MultiVit Model

In [None]:
class MultiVitModelFactory():
  @staticmethod
  def get_model(hparams, config):
    return get_multivit_model(hparams, config)

  @staticmethod
  def get_init_comps(hparams, config):
    return get_multivit_init_comps(hparams, config)

  @staticmethod
  def get_comps2model_fn():
    return multivit_comps2model

  @staticmethod
  def get_sample_input(hparams):
    return {str(k): get_sample_images(image_size=k, batch_size=1) for k in hparams["ds_image_size"]}

In [None]:
def get_multivit_init_comps(hparams, config):
  model = get_multivit_model(hparams, config)
  init_params = copy.deepcopy(jax.device_get(
      model.init(
          jax.random.PRNGKey(random.randrange(int(1e10))),
          MultiVitModelFactory.get_sample_input(hparams),
          train=False  # Disables dropout, no effect on params.
          )["params"]))
  return multivit_model2comps(init_params)

In [None]:
def multivit_comps2model(params):
  params = params.unfreeze()
  for k in params:
    if k.startswith("path_"):
      params[k] = vit_comps2model(flax.core.freeze(params[k]))
  return flax.core.freeze(params)

def multivit_model2comps(params):
  # Path components are not remapped, since randomly initialized path parameters are never used.
  return params

In [None]:
class MultipathRouter(nn.Module):
  init_main_path_weight: float
  num_paths: int
  lr_mult: float

  @nn.compact
  def __call__(self, x):
    assert self.num_paths > 0
    assert self.lr_mult >= 0 and self.lr_mult <= 1
    init_bias = np.log((1/(1/self.init_main_path_weight -1))*(self.num_paths-1))
    x = nn.LayerNorm()(x)
    x = nn.Dense(self.num_paths,
                 kernel_init=nn.initializers.zeros,
                 bias_init=nn.initializers.constant(np.asarray([init_bias]+[0]*(self.num_paths-1)))
                 )(x)
    x = nn.softmax(x)
    x = self.lr_mult * x + (1-self.lr_mult) * jax.lax.stop_gradient(x)
    return x

class Connector(nn.Module):
  out_dim: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.out_dim,
                 kernel_init=nn.initializers.zeros,
                 bias_init=nn.initializers.zeros,
                 )(x)
    return x

class MultiVitModel(nn.Module):
  config: Any
  path_module: Type[nn.Module] = vit_models.VisionTransformer
  main_path_name: str = "path_0"
  @nn.compact
  def __call__(self, inputs, *, train):
    logits_0 = self.path_module(
        name=self.main_path_name,
        **self.config.paths_configs[self.main_path_name])(
            inputs[self.config.paths_image_size[self.main_path_name]], train=train)
    out_dim = logits_0.shape[-1]
    weights = MultipathRouter(
        name="multipath_router",
        num_paths=len(self.config.paths_configs),
        **self.config.router)(
            logits_0)
    all_logits = [logits_0]
    for path_name in self.config.paths_configs:
      if path_name == self.main_path_name:
        continue
      representation = self.path_module(
          name=path_name,
          **self.config.paths_configs[path_name])(
              inputs[self.config.paths_image_size[path_name]], train=train)
      path_logits = Connector(
          name=f"head_adapter_{path_name}", out_dim=out_dim)(representation)
      all_logits.append(path_logits)
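    stacked = jnp.stack(all_logits, axis=-1)
    # Forward value: router-weighted combination of the paths' logits. The logits
    # are stop-gradient'ed here, so this term only propagates gradients to the router.
    logits_comb = jnp.einsum("BLp,Bp->BL", jax.lax.stop_gradient(stacked), weights)
    # Unweighted sum of the paths' logits: its forward value cancels out below, but
    # it passes the full, unweighted gradient to every path (straight-through).
    logits_sum = jnp.einsum("BLp,Bp->BL", stacked, jnp.ones_like(weights))
    logits_out = logits_comb - jax.lax.stop_gradient(logits_sum) + logits_sum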
    return logits_out
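With the router's kernel initialized to zero, its initial logits equal the bias `[b, 0, ..., 0]` whatever the input, so the main path's initial softmax weight is e^b / (e^b + num_paths - 1). Setting this equal to `init_main_path_weight` = w gives b = log(w / (1 - w) * (num_paths - 1)), which is the `init_bias` expression in `MultipathRouter` since 1 / (1/w - 1) = w / (1 - w). The stdlib check below reuses that expression; the formula needs 0 < w < 1 and at least two paths (with a single path, `np.log(0)` yields `-inf`).

```python
import math

def router_init_bias(init_main_path_weight: float, num_paths: int) -> float:
  """Same expression as init_bias in MultipathRouter."""
  w = init_main_path_weight
  return math.log((1 / (1 / w - 1)) * (num_paths - 1))

def softmax(logits):
  m = max(logits)  # Subtract the max for numerical stability.
  exps = [math.exp(x - m) for x in logits]
  total = sum(exps)
  return [e / total for e in exps]

# Zero kernel: the initial router logits are just the bias [b, 0, ..., 0].
for w, n in [(0.8, 3), (0.5, 2), (0.95, 10)]:
  b = router_init_bias(w, n)
  main_weight = softmax([b] + [0.0] * (n - 1))[0]
  print(w, n, round(main_weight, 6))  # The main path weight equals w.
```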

In [None]:
def get_multivit_model(hparams, config):
  model_config = ConfigDict()
  model_config.paths_configs = {
      k: get_set_vit_config(hparams["paths"][k]["hparams"], config) for k in hparams["paths"]}
  model_config.paths_image_size = {
      k: str(hparams["paths"][k]["hparams"]["ds_image_size"]) for k in hparams["paths"]}
  model_config.router = {
      "init_main_path_weight": float(hparams["router_init_main_path_weight"]),
      "lr_mult": float(hparams["router_lr_mult"]),
  }
  return MultiVitModel(config=FrozenConfigDict(model_config))
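`MultipathRouter` and `MultiVitModel` both use `jax.lax.stop_gradient` to decouple forward values from gradients. Writing $\mathrm{sg}(\cdot)$ for stop-gradient (identity in the forward pass, zero gradient), $\tilde w$ for the router's softmax output, $\lambda$ for `lr_mult`, $w_p$ for the router weight of path $p$ and $s_p$ for the logits of path $p$ (through its `head_adapter` connector for non-main paths), the two expressions reduce, per example, to:

$$
\begin{aligned}
w &= \lambda\,\tilde w + (1-\lambda)\,\mathrm{sg}(\tilde w)
  \;\Rightarrow\; w = \tilde w \ \text{(forward)}, \quad \frac{\partial w}{\partial \tilde w} = \lambda I, \\
y &= \sum_p \mathrm{sg}(s_p)\, w_p - \mathrm{sg}\Big(\sum_p s_p\Big) + \sum_p s_p
  \;\Rightarrow\; y = \sum_p w_p\, s_p \ \text{(forward)}, \quad \frac{\partial y}{\partial s_p} = I, \quad \frac{\partial y}{\partial w_p} = s_p.
\end{aligned}
$$

In other words, gradients flowing into the router are scaled by `lr_mult`, while every path receives the full, unweighted gradient of the loss regardless of its current router weight.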

# Agents

In [None]:
def format_agent_id(class_name, task_name):
  agent_id = f"{class_name}/{task_name}"
  assert "~" not in agent_id, f"Invalid agent id: {agent_id}"
  return agent_id.replace("/", "~")

def get_agent_class(agent_id):
  return globals()[agent_id.split("~")[0]]

In [None]:
def incremental_mutation(value, values_list):
  assert value in values_list, f"{value} not in {values_list}"
  idx = values_list.index(value)
  idx += 1 if np.random.uniform() < 0.5 else -1
  idx = max(0, min(len(values_list)-1, idx))
  return values_list[idx]
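`incremental_mutation` therefore only moves a hyperparameter to an adjacent entry of its ordered mutation range, clipping at both ends, which keeps child values close to the parent's. The stdlib restatement below is a sketch: the injectable `rng` argument (anything with a `.random()` method, in place of `np.random.uniform()`) is our addition for reproducibility.

```python
import random

def incremental_mutation(value, values_list, rng=random):
  """Moves `value` one step up or down its ordered list, clipping at the ends."""
  assert value in values_list, f"{value} not in {values_list}"
  idx = values_list.index(value)
  idx += 1 if rng.random() < 0.5 else -1
  idx = max(0, min(len(values_list) - 1, idx))
  return values_list[idx]

# The "opt_momentum" range from get_common_config_vit().
opt_momentum_range = [0.5, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.98, 0.99]
rng = random.Random(0)
print([incremental_mutation(0.8, opt_momentum_range, rng) for _ in range(5)])
# Every result is a neighbour of 0.8, i.e. 0.75 or 0.85.
print(incremental_mutation(0.5, opt_momentum_range, rng))
# 0.6, or 0.5 again when the step down is clipped at the first entry.
```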

In [None]:
def decay_selection(paths):
  for path in sorted(paths, key=lambda p: p.score(), reverse=True):
    offsprings = path.metrics.get("offsprings", 0)
    assert not math.isnan(offsprings)
    select_prob = 0.5 ** offsprings
    print(f"  Candidate path {path.id},",
          f"selection probability: 0.5^{offsprings}=={select_prob}")
    if np.random.uniform() < select_prob:
      print(f"  Selected path {path.id}")
      return path
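`decay_selection` visits the candidates from best to worst score and accepts each with probability 0.5^offsprings, so a path with no offspring yet is always accepted, and paths that already produced many are increasingly skipped. If none is accepted it returns `None`, and `prune_population` falls back to `get_best_path`. The resulting probabilities can be computed in closed form; `selection_distribution` is a hypothetical helper that takes the offspring counts already sorted by decreasing score.

```python
def selection_distribution(offsprings_by_rank):
  """Probability that decay_selection() returns each candidate, or None.

  offsprings_by_rank: offspring counts of the candidates, sorted by decreasing
  score (the order in which decay_selection() visits them).
  """
  probs = []
  not_selected_yet = 1.0
  for offsprings in offsprings_by_rank:
    p = 0.5 ** offsprings
    probs.append(not_selected_yet * p)
    not_selected_yet *= 1.0 - p
  return probs, not_selected_yet  # The second value is P(returns None).

probs, p_none = selection_distribution([2, 0, 1])
print(probs, p_none)  # [0.25, 0.75, 0.0] 0.0
```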

In [None]:
class Agent():
  def __init__(self, system_state_dir, task_name, num_paths_sampled_max):
    self.config = self.get_config()
    self.config.system_state_dir = system_state_dir
    self.config.task_name = task_name
    self.config.agent_id = format_agent_id(self.class_name, task_name)
    self.config.agent_dir = os.path.join(system_state_dir, self.id)
    self.config.num_paths_sampled_max = num_paths_sampled_max
    self.config = FrozenConfigDict(self.config)
    config_validate(self.config)
    self.task = None if task_name.startswith("root_model/") else Task(name=task_name, config=self.config)

  @property
  def class_name(self):
    return self.__class__.__name__

  @property
  def id(self):
    return self.config.agent_id

  def agents_to_load(self):
    # Defaults to loading only agents of the same class. Extend this list to allow
    # access to the structures and parameters produced by other agent types.
    return [self.class_name+"~*"]

  def prune_population(self):
    print("Prune population")
    candidates = copy.copy(self.pop.paths[self.config.agent_id])
    survivors = []
    while candidates and len(survivors) < self.config.max_population_size:
      print("Sampling survivor:", len(survivors)+1)
      selected = decay_selection(candidates) if survivors else None
      if selected is None:
        selected = get_best_path(candidates)
        print("  Selected best candidate:", selected.id)
      assert selected is not None
      survivors.append(selected)
      candidates.remove(selected)
    assert len(survivors) == len(set(survivors))  # Check no duplicates.
    self.pop.paths[self.config.agent_id] = survivors

In [None]:
def run_evolution(agent):
  config = agent.config
  devices = jax.local_devices()
  print("DEVICE COUNT:", len(devices))
  while True:
    print("\n\n====")
    agent.load_state()
    print(f"\nSAMPLES: [{agent.num_paths_sampled}/{config.num_paths_sampled_max}]")
    if agent.num_paths_sampled >= config.num_paths_sampled_max:
      break
    task = agent.task
    best_path = agent.pop.get_best_path()
    if TEST_IMMUTABILITY and best_path:
      run_test_eval(best_path, task, test_immutability=True)
    gen_hparams = agent.sample_gen_hps()
    gen_hparams["num_classes"] = task.num_classes
    paths = []
    for i in range(len(devices)):
      print(f"Sampling path: {Path.counter}")
      paths.append(agent.sample_path(gen_hparams))
      agent.num_paths_sampled += 1
      gc.collect()
    gen_hparams = agent.finalize_hps(gen_hparams, paths)
    ds_train = task.get_ds("train", gen_hparams)
    ds_validation = task.get_ds("validation", gen_hparams)
    train_loop(agent, paths, ds_train, ds_validation, devices)
    for path in paths:
      if path.metrics["improved"]:
        assert path not in agent.pop.paths[config.agent_id]
        agent.pop.paths[config.agent_id].append(path)
    # Track best path.
    curr_best_path = agent.pop.get_best_path()
    if curr_best_path != best_path:
      if best_path:
        assert curr_best_path.score() >= best_path.score()
      best_path = curr_best_path
      best_path.metrics["new_best"] = True
      agent.print_best_path_summary()
    df_leaderboard(pop_to_df(agent.pop))
    agent.prune_population()
    assert best_path in agent.pop.paths[config.agent_id], best_path
    best_path.metrics["num_cycles"] = best_path.metrics.get("num_cycles", 0) + 1
    run_test_eval(best_path, task)
    save_state(agent)
    if agent.num_paths_sampled >= config.num_paths_sampled_max:
      break

def run_root_model(agent):
  agent.load_state()
  save_state(agent)

## Vit Agent

In [None]:
def get_common_config_vit():
  config = ConfigDict()
  config.num_train_examples_between_validations_max = 300_000
  config.num_validations_per_path_training = 4
  config.num_validation_examples_max = 10_000
  config.max_population_size = 4
  # Always fine-tune the final layer norm, which is technically part of the head.
  config.force_mutations = ["clone:encoder_norm"]
  config.scorer_kwargs = dict(
      scale_factor=0.99,
      base_accounted_params=2_200_000_000,
      base_flops=3_800_000_000_000,
      )
  config.hparams_defaults = {
      "_mu_": 0.2,
      "opt_lr": 0.02,
      "opt_lr_schedule": "cosine",
      "opt_lr_warmup_ratio": 0.02,
      "opt_momentum": 0.8,
      "opt_nesterov": True,
      "ds_area_range_min": 1.0,
      "ds_aspect_ratio_range_min": 1.0,
      "ds_flip_left_right": False,
      "ds_brightness_delta": 0.0,
      "ds_contrast_delta": 0.0,
      "ds_saturation_delta": 0.0,
      "ds_hue_delta": 0.0,
      "ds_quality_delta": 0.0,
  }
  config.hparams_mutation_ranges = {
      "_mu_": [0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18, 0.20, 0.22, 0.24, 0.26, 0.28, 0.30],
      "opt_lr": [0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0],
      "opt_lr_warmup_ratio": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2, 0.3],
      "opt_momentum": [0.5, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.98, 0.99],
      "opt_nesterov": [True, False],
      "ds_area_range_min": [0.05, 0.5, 0.95, 1.0],
      "ds_aspect_ratio_range_min": [0.5, 0.75, 1.0],
      "ds_flip_left_right": [True, False],
      "ds_brightness_delta": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
      "ds_contrast_delta": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
      "ds_saturation_delta": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
      "ds_hue_delta": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
      "ds_quality_delta": [ 0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
  }
  return config

def get_config_vit_large():
  config = get_common_config_vit()
  config.batch_size = 16
  # The query is used to get the model configs even if the checkpoint is not loaded.
  config.vit_checkpoint_query = 'name=="L/16" and ds=="i21k" and aug=="medium2" and wd==0.03 and sd==0.1'
  max_num_layers = get_max_num_layers(config.vit_checkpoint_query)
  config.hparams_defaults["num_layers"] = max_num_layers
  config.hparams_mutation_ranges["num_layers"] = list(
      range(config.hparams_defaults["num_layers"]+1
            +1  # Allow exceeding the root model's number of layers by 1.
            ))
  config.hparams_defaults["ds_image_size"] = 384
  config.hparams_mutation_ranges["ds_image_size"] = [224, 384]
  return config

def get_config_vit_ti3():
  config = get_common_config_vit()
  config.batch_size = 512
  config.vit_checkpoint_query = 'name=="Ti/16" and ds=="i21k" and aug=="light1" and wd==0.1 and sd==0.0'
  config.hparams_defaults["num_layers"] = 3
  config.hparams_mutation_ranges["num_layers"] = list(range(config.hparams_defaults["num_layers"]+1))
  config.hparams_defaults["ds_image_size"] = 32
  config.hparams_mutation_ranges["ds_image_size"] = [16*i for i in (range(1, 1+int(112/16)))]
  return config

def get_config_vit_base():
  config = get_common_config_vit()
  config.batch_size = 256
  config.vit_checkpoint_query = 'name=="B/16" and ds=="i21k" and aug=="medium1" and wd==0.1 and sd==0'
  max_num_layers = get_max_num_layers(config.vit_checkpoint_query)
  config.hparams_defaults["num_layers"] = max_num_layers
  config.hparams_mutation_ranges["num_layers"] = list(range(config.hparams_defaults["num_layers"]+1))
  config.hparams_defaults["ds_image_size"] = 80
  config.hparams_mutation_ranges["ds_image_size"] = [16*i for i in (range(1, 1+int(112/16)))]
  return config

def config_validate(config):
  for khp in config.hparams_defaults:
    if khp in config.hparams_mutation_ranges:
      assert config.hparams_defaults[khp] in config.hparams_mutation_ranges[khp], khp
  for khp in config.hparams_mutation_ranges:
    assert khp in config.hparams_defaults, khp
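The mutation ranges are plain lists. `config_validate` checks two invariants: each default lies inside its mutation range, and each mutation range has a default. Below is a minimal sketch of the same checks on plain dicts. It assumes dicts in place of the `ConfigDict` objects used above, and `image_size_range` and `validate` are hypothetical helper names.

```python
def image_size_range(max_size=112, patch=16):
    """Multiples of the patch size up to max_size, as built in get_config_vit_ti3()."""
    return [patch * i for i in range(1, 1 + int(max_size / patch))]


def validate(defaults, mutation_ranges):
    """Same two invariants as config_validate(), on plain dicts."""
    for k, v in defaults.items():
        if k in mutation_ranges:
            assert v in mutation_ranges[k], k  # Default must be a reachable value.
    for k in mutation_ranges:
        assert k in defaults, k  # Every mutable hparam needs a default.


# ViT-Tiny-style settings: up to 3 layers and 16..112 pixel images.
defaults = {"num_layers": 3, "ds_image_size": 32}
mutation_ranges = {
    "num_layers": list(range(defaults["num_layers"] + 1)),  # [0, 1, 2, 3]
    "ds_image_size": image_size_range(),
}
validate(defaults, mutation_ranges)
print(mutation_ranges["ds_image_size"])  # [16, 32, 48, 64, 80, 96, 112]
```

With these ranges the ViT-Base default of 80 also validates. A default such as 100, which is not a multiple of 16, would trip the first assertion.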

In [None]:
class Vit(Agent):  # ViT large.
  @staticmethod
  def get_model_factory():
    return VitModelFactory

  def load_state(self):
    task_name = self.config.task_name
    self.pop = Population(self.config)
    self.num_paths_sampled = 0
    if task_name.startswith("root_model/"):  # Root models.
      hparams = self.config.hparams_defaults.as_configdict()
      if task_name == "root_model/random_init":
        hparams["num_classes"] = 0  # Removes head layer.
        path_params = self.get_model_factory().get_init_comps(hparams, self.config)
      else:
        assert task_name == "root_model/checkpoint", task_name
        path_params = get_vit_checkpoint_comps(hparams["ds_image_size"],
                                               self.config.vit_checkpoint_query)
      path = Path(hparams, [Component(k, self.id, v) for k, v in path_params.items()],
                  parent=None, agent_id=self.id, task_name=task_name)
      self.pop.paths[self.id].append(path)
    else:
      load_paths(self)
      assert self.pop.paths, "Empty population! Run an agent that creates a root model first."
      df_leaderboard(pop_to_df(self.pop))

  def get_config(self):
    return get_config_vit_large()

  def run(self):
    if self.config.task_name.startswith("root_model/"):
      run_root_model(self)
      return
    run_evolution(self)

  def do_mutate(self, hparams, mutation_name):
    """Returns True if mutation (∂) is sampled to be applied."""
    if mutation_name in self.config.get("force_mutations", []):
      return True
    mutation_prob_k = f"_mu_|{mutation_name}"   # Key for ∂ in µ() lookup table.
    if "_mu_" in self.config.hparams_mutation_ranges:  # Mutate µ(∂) value.
      if hparams["_mu_"] > np.random.uniform():
        hparams[mutation_prob_k] = incremental_mutation(
            hparams.get(mutation_prob_k, hparams["_mu_"]),
            self.config.hparams_mutation_ranges["_mu_"])
    return hparams.get(mutation_prob_k, hparams["_mu_"]) > np.random.uniform()

  def parent_decay_selection(self):
    parent = decay_selection(self.pop.paths[self.config.agent_id])
    if parent is not None:
      parent.metrics["offsprings"] = parent.metrics.get("offsprings", 0) + 1
    return parent

  def sample_path(self, gen_hparams):
    parent = self.parent_decay_selection()
    if not parent:  # Random sample.
      parent = random.choice([p for paths in self.pop.paths.values() for p in paths])
      print(f"  Randomly selected parent {parent.agent_id}:{parent.id}")
    return self.mutate_parent(parent, gen_hparams)

  def mutate_hparams(self, hparams):
    for k in sorted(self.config.hparams_mutation_ranges):
      if k in hparams and self.do_mutate(hparams, f"hp:{k}"):
        hparams[k] = incremental_mutation(
            hparams[k], self.config.hparams_mutation_ranges[k])

  def is_generation_hp(self, k):  # True if the hp is shared across a generation.
    return (k.startswith("ds_") or k.startswith("_mu_|hp:ds_") or k in ["_mu_"])
    
  def sample_gen_hps(self):
    """Sample hparams that need to be shared across all paths of a generation."""
    gen_hparams = {k:v for k,v in self.config.hparams_defaults.items() if self.is_generation_hp(k)}
    # Overwrite with values from best path if available.
    best_path = self.pop.get_best_path()
    if best_path:
      gen_hparams.update({k:v for k,v in best_path.hparams.items() if self.is_generation_hp(k)})
      self.mutate_hparams(gen_hparams)  # Sample mutations.
    return gen_hparams

  def finalize_hps(self, hparams, paths):
    # Validate shared params.
    for k in hparams:
      if self.is_generation_hp(k):
        for path in paths:
          assert hparams[k] == path.hparams[k]
    return hparams

  def allow_component_reuse(self, comp_name, parent):
    # The head must be trainable. If the parent is not from the same agent, the
    # head falls back to a random init of the correct shape.
    return not(comp_name == "head" and self.config.agent_id != parent.agent_id)

  def mutate_parent(self, parent, gen_hparams):
    config = self.config
    agent_id = config.agent_id
    task_name = config.task_name
    comps = []
    new_hparams = copy.deepcopy(parent.hparams)
    self.mutate_hparams(new_hparams)
    # Overwrite hparams shared with paths in the current generation.
    new_hparams.update(gen_hparams)

    init_params = self.get_model_factory().get_init_comps(new_hparams, config)
    # Attempt to reuse matching components from the parent.
    comps_lookup = {c.name:c for c in parent.components}
    for new_comp_name in init_params:
      comp = None
      if new_comp_name in comps_lookup and self.allow_component_reuse(new_comp_name, parent):
        parent_comp = comps_lookup[new_comp_name]
        # Check shapes match otherwise skip.
        if (jax.tree_util.tree_map(jnp.shape, init_params[new_comp_name]) ==
            jax.tree_util.tree_map(jnp.shape, parent_comp.params)):
          if parent_comp.is_trainable() or self.do_mutate(new_hparams, f"clone:{new_comp_name}"):
            comp = parent_comp.clone(agent_id)  # Clone trainable component.
          else:
            comp = parent_comp  # Refer to frozen component.
        else:
          if new_comp_name == "posembed_input":
            # A change of image size changes the shape of the position embeddings;
            # this can happen if ds_image_size is tuned.
            assert "ds_image_size" in config.hparams_mutation_ranges
            assert new_hparams["ds_image_size"] != parent.hparams["ds_image_size"]
          else:
            print(f"WARNING: Shapes mismatch for component: {new_comp_name}  {parent.agent_id}->{agent_id}")
            print(jax.tree_util.tree_map(jnp.shape, init_params[new_comp_name]))
            print(jax.tree_util.tree_map(jnp.shape, parent_comp.params))
            assert False
      if comp is None:  # Custom component mutation.
        comp = self.mutate_component(new_comp_name, new_hparams, init_params, parent)
      if comp is None:  # Random init.
        print(f"    Initialized parameters of \"{new_comp_name}\" module")
        comp = Component(new_comp_name, agent_id, init_params[new_comp_name])
      comps.append(comp)
    return Path(new_hparams, comps, parent=parent, agent_id=agent_id, task_name=task_name)

  def mutate_component(self, comp_name, hparams, init_params, parent):
    if comp_name == "posembed_input":
      parent_posembed = [c for c in parent.components if c.name == comp_name]
      assert len(parent_posembed) == 1
      parent_posembed = parent_posembed[0].params["pos_embedding"]
      mapped_posembed = vit_checkpoint.interpolate_posembed(
          parent_posembed, init_params[comp_name]["pos_embedding"].shape[1], True)
      return Component(
          name="posembed_input", agent_id=self.config.agent_id,
          params=flax.core.freeze({"pos_embedding": mapped_posembed}))

  def print_best_path_summary(self):
    best_path = self.pop.get_best_path()
    print(f"Best id:{best_path.id}", f"score:{best_path.score():.4f}",
          f"quality:{best_path.metrics['quality']:.4f}", f"\n{best_path.hparams}")

  def get_paths_to_publish(self):
    return [self.pop.get_best_path()]

class VitT3(Vit):  # ViT tiny 3 layers.
  def get_config(self):
    return get_config_vit_ti3()

class VitB(Vit):  # ViT base.
  def get_config(self):
    return get_config_vit_base()
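`Vit.do_mutate` makes mutation probabilities self-adaptive. Each mutation name gets its own probability under the key `_mu_|<name>`, initialized from the global `_mu_`. That probability is itself mutated (with probability `_mu_`) before it is used. `incremental_mutation` is not shown in this section, so the sketch below substitutes a hypothetical version that moves one step up or down the range. The real helper may behave differently. `random.Random` replaces `np.random.uniform` so the sketch has no dependencies.

```python
import random


def incremental_mutation(value, values_range, rng=random):
    """Hypothetical stand-in: step to a neighbouring entry of values_range."""
    idx = values_range.index(value)
    new_idx = min(max(idx + rng.choice([-1, 1]), 0), len(values_range) - 1)
    return values_range[new_idx]


def do_mutate(hparams, mutation_name, mutation_ranges, force_mutations=(), rng=random):
    """Mirror of Vit.do_mutate() with the config passed in explicitly."""
    if mutation_name in force_mutations:
        return True
    mutation_prob_k = f"_mu_|{mutation_name}"  # Per-mutation probability key.
    if "_mu_" in mutation_ranges:
        # With probability _mu_, first perturb this mutation's own probability.
        if hparams["_mu_"] > rng.random():
            hparams[mutation_prob_k] = incremental_mutation(
                hparams.get(mutation_prob_k, hparams["_mu_"]),
                mutation_ranges["_mu_"], rng)
    # Apply the mutation with its (possibly just updated) probability.
    return hparams.get(mutation_prob_k, hparams["_mu_"]) > rng.random()


rng = random.Random(0)
hparams = {"_mu_": 0.2}
mutation_ranges = {"_mu_": [0.1, 0.2, 0.3]}
applied = sum(do_mutate(hparams, "hp:lr", mutation_ranges, rng=rng) for _ in range(1000))
print(applied, hparams.get("_mu_|hp:lr"))
```

The per-mutation probability random-walks within `mutation_ranges["_mu_"]`. Each hparam therefore drifts toward the mutation rate that produced good offspring along its lineage.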

## MultiVit Agent

In [None]:
def set_multivit_common_config(config):
  config.force_mutations = []
  config.scorer_kwargs = {}
  for rm_k in ["num_layers", "ds_image_size"]:
    del config.hparams_defaults[rm_k]
    del config.hparams_mutation_ranges[rm_k]
  config.hparams_defaults["router_init_main_path_weight"] = 0.8
  config.hparams_defaults["router_lr_mult"] = 0.05
  config.hparams_defaults["num_paths"] = 2
  config.hparams_mutation_ranges["router_lr_mult"] = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1]
  config.hparams_mutation_ranges["num_paths"] = [2, 3]
  return config
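`set_multivit_common_config` removes `num_layers` and `ds_image_size` and adds the router hparams. In `MultiVit.sample_path`, the image sizes are collected from the selected subpaths instead. The dict-based sketch below assumes plain dicts rather than `ConfigDict`, and it copies the config instead of mutating it in place. It shows that the resulting defaults still satisfy the `config_validate` invariants. `router_init_main_path_weight` has no mutation range, so it stays fixed during evolution. `multivit_config` is a hypothetical name, and the layer counts are illustrative.

```python
import copy


def multivit_config(config):
    """Dict-based mirror of set_multivit_common_config(); returns a copy."""
    config = copy.deepcopy(config)
    config["force_mutations"] = []
    config["scorer_kwargs"] = {}
    for rm_k in ["num_layers", "ds_image_size"]:
        del config["hparams_defaults"][rm_k]
        del config["hparams_mutation_ranges"][rm_k]
    config["hparams_defaults"].update(
        router_init_main_path_weight=0.8, router_lr_mult=0.05, num_paths=2)
    config["hparams_mutation_ranges"].update(
        router_lr_mult=[0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1], num_paths=[2, 3])
    return config


single_path = {
    "hparams_defaults": {"num_layers": 24, "ds_image_size": 384},
    "hparams_mutation_ranges": {"num_layers": list(range(26)),
                                "ds_image_size": [224, 384]},
}
multi = multivit_config(single_path)
defaults, ranges = multi["hparams_defaults"], multi["hparams_mutation_ranges"]
# Same invariants as config_validate().
assert all(defaults[k] in ranges[k] for k in defaults if k in ranges)
assert all(k in defaults for k in ranges)
print(sorted(defaults))  # ['num_paths', 'router_init_main_path_weight', 'router_lr_mult']
```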

In [None]:
class MultiVit(Vit):
  def get_config(self):
    config = get_config_vit_large()
    return set_multivit_common_config(config)

  @staticmethod
  def get_model_factory():
    return MultiVitModelFactory

  def run(self):
    run_evolution(self)

  def agents_to_load(self):
    return [self.single_path_agent_class()+"~*", self.config.agent_id]

  def single_path_agent_class(self):
    return self.class_name.removeprefix("Multi")

  def single_path_main_agent_id(self):
    return format_agent_id(self.single_path_agent_class(), self.config.task_name)

  def load_state(self):
    if self.config.task_name.startswith("root_model/"):
      assert False, (
          "Root models need to be generated with the corresponding "
          f"single path agent: '{self.single_path_agent_class()}'.")
    super().load_state()
    assert self.single_path_main_agent_id() in self.pop.paths, (
        "Missing state for the corresponding single path main agent. "
        f"Run agent '{self.single_path_agent_class()}' "
        f"on the '{self.config.task_name}' task to generate it.")

  def sample_path(self, gen_hparams):
    selected_paths = {}
    selected_paths["path_0"] = random.choice(self.pop.paths[self.single_path_main_agent_id()])
    parent = self.parent_decay_selection()
    if parent is not None:
      new_hparams = copy.deepcopy(parent.hparams)
    else:
      parent = selected_paths["path_0"]
      best_path = self.pop.get_best_path()
      if best_path:
        new_hparams = copy.deepcopy(best_path.hparams)
      else:
        new_hparams = self.config.hparams_defaults.to_dict()
      print("  New random")
    new_hparams["paths"] = {}
    self.mutate_hparams(new_hparams)
    # Overwrite hparams shared with paths in the current generation.
    new_hparams.update(gen_hparams)

    for path_name in [f"path_{i}" for i in range(int(new_hparams["num_paths"]))]:
      if path_name in selected_paths:
        continue
      if path_name in parent.hparams.get("paths", {}):
        selected_paths[path_name] = self.pop.get_path_from_full_id(parent.hparams["paths"][path_name]["agent_id"], parent.hparams["paths"][path_name]["id"])
        print("  Subpath from parent: ", path_name, selected_paths[path_name].full_id)
        continue
      selected_paths[path_name] = random.choice([
          p for paths in self.pop.paths.values() for p in paths if (
              p.agent_id not in [sp.agent_id for sp in selected_paths.values()]
              and p.agent_id.startswith(f"{self.single_path_agent_class()}~")
              and not p.agent_id.endswith("~1k")  # Excludes VTAB-1k tasks.
              # and (path_name != "path_1" or p.agent_id in ['Vit~i_naturalist2017'])  # Forces i_naturalist2017 selection.
              )])
      print("  Subpath rand selected:", path_name, selected_paths[path_name].full_id)

    for (path_name, path) in selected_paths.items():
      new_hparams["paths"][path_name] = {
          "id": path.id,
          "agent_id": path.agent_id,
          "hparams": copy.deepcopy(path.hparams)}
    print("  Sampled subpaths:",
          {k: v["agent_id"] for k, v in new_hparams["paths"].items()})
    # Use a headless model config (num_classes=0) for subpaths of other tasks.
    for k in new_hparams["paths"]:
      if selected_paths[k].task_name != self.config.task_name:
        new_hparams["paths"][k]["hparams"]["num_classes"] = 0
    # Collect image sizes needed.
    image_sizes = set()
    for k in new_hparams["paths"]:
      image_sizes.add(int(new_hparams["paths"][k]["hparams"]["ds_image_size"]))
    new_hparams["ds_image_size"] = list(image_sizes)
    # Collect components.
    init_params = self.get_model_factory().get_init_comps(new_hparams, self.config)
    comps = []
    for new_comp_name in init_params:
      if new_comp_name in new_hparams["paths"]:
        comps.append(ComponentPath(name=new_comp_name,
                                   path=selected_paths[new_comp_name]))
      else:
        comps_lookup = {c.name:c for c in parent.components}
        if new_comp_name in comps_lookup and (
          jax.tree_util.tree_map(jnp.shape, init_params[new_comp_name]) ==
          jax.tree_util.tree_map(jnp.shape, comps_lookup[new_comp_name].params)):
          print("  COMP Reusing", new_comp_name)
          comp = comps_lookup[new_comp_name].clone(agent_id=self.config.agent_id)
        else:
          print("  COMP Init", new_comp_name)
          comp = Component(name=new_comp_name, agent_id=self.config.agent_id,
                           params=init_params[new_comp_name])
        comps.append(comp)
    return Path(new_hparams, comps, parent=parent,
                agent_id=self.config.agent_id, task_name=self.config.task_name)

  def finalize_hps(self, hparams, paths):
    # Validate shared params.
    for k in hparams:
      if self.is_generation_hp(k) and k != "ds_image_size":
        for path in paths:
          assert hparams[k] == path.hparams[k], (k, hparams[k], path.hparams[k])
    image_sizes = set()
    for path in paths:
      image_sizes.update(path.hparams["ds_image_size"])
    hparams["ds_image_size"] = list(image_sizes)
    print("Image sizes:", hparams["ds_image_size"])
    return hparams

  def print_best_path_summary(self):
    super().print_best_path_summary()
    best_path = self.pop.get_best_path()
    print("Paths used by best model:",
          {k: best_path.hparams["paths"][k]["agent_id"] for k in best_path.hparams["paths"]})

  def get_paths_to_publish(self):
    paths_to_publish = [self.pop.get_best_path()]
    for path in list(paths_to_publish):
      for k in path.hparams["paths"]:
        paths_to_publish.append(self.pop.get_path_from_full_id(path.hparams["paths"][k]["agent_id"], path.hparams["paths"][k]["id"]))
    return paths_to_publish

class MultiVitT3(MultiVit):
  def get_config(self):
    return set_multivit_common_config(get_config_vit_ti3())

class MultiVitB(MultiVit):
  def get_config(self):
    return set_multivit_common_config(get_config_vit_base())
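When `MultiVit.sample_path` fills a subpath slot that the parent does not provide, it draws uniformly from paths of other agents of the same single-path class. It skips agents already used in the current model and VTAB-1k agents. The image sizes of the chosen subpaths are then merged into `ds_image_size`. The sketch below covers those two steps. It uses a hypothetical `FakePath` with only the fields involved, and hypothetical agent ids; the real ids come from `format_agent_id`, which is not shown here.

```python
import random
from dataclasses import dataclass


@dataclass
class FakePath:
    """Hypothetical stand-in for Path, keeping only the fields used here."""
    agent_id: str
    ds_image_size: int


def subpath_candidates(all_paths, selected_paths, single_path_class="Vit"):
    """Same filter as the random.choice() list in MultiVit.sample_path()."""
    used_agents = {sp.agent_id for sp in selected_paths.values()}
    return [p for p in all_paths
            if p.agent_id not in used_agents
            and p.agent_id.startswith(f"{single_path_class}~")
            and not p.agent_id.endswith("~1k")]  # Excludes VTAB-1k tasks.


def collect_image_sizes(selected_paths):
    # Sorted for readability; the notebook keeps list(set(...)) order.
    return sorted({int(p.ds_image_size) for p in selected_paths.values()})


paths = [FakePath("Vit~cifar10", 384), FakePath("Vit~cifar100", 224),
         FakePath("Vit~dtd~1k", 384), FakePath("VitB~cifar10", 80)]
selected = {"path_0": paths[0]}
candidates = subpath_candidates(paths, selected)
selected["path_1"] = random.Random(0).choice(candidates)
print([p.agent_id for p in candidates])  # ['Vit~cifar100']
print(collect_image_sizes(selected))     # [224, 384]
```

Because of the `~` separator in the prefix check, `VitB~cifar10` is excluded when sampling for a `MultiVit` agent. Each multipath agent therefore only combines paths from its own model size.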

# Tasks

In [None]:
TFDS_BUILDERS_CACHE = {}
def get_tfds_builder(tfds_name):
  global TFDS_BUILDERS_CACHE
  if tfds_name not in TFDS_BUILDERS_CACHE:
    TFDS_BUILDERS_CACHE[tfds_name] = tfds.builder(tfds_name)
    TFDS_BUILDERS_CACHE[tfds_name].download_and_prepare()
  return TFDS_BUILDERS_CACHE[tfds_name]
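`get_tfds_builder` memoizes builders in a module-level dict, so `download_and_prepare()` runs at most once per dataset name in a session. The same pattern can be written with `functools.cache` (Python 3.9+). The sketch swaps in a hypothetical `FakeBuilder` for `tfds.builder` so it runs without TensorFlow Datasets.

```python
import functools


class FakeBuilder:
    """Hypothetical stand-in for a tfds builder that counts prepare calls."""
    prepare_calls = 0

    def __init__(self, name):
        self.name = name

    def download_and_prepare(self):
        FakeBuilder.prepare_calls += 1


@functools.cache
def get_builder(name):
    builder = FakeBuilder(name)
    builder.download_and_prepare()  # Only runs on a cache miss.
    return builder


first = get_builder("cifar100")
print(get_builder("cifar100") is first, FakeBuilder.prepare_calls)  # True 1
```

An explicit dict, as used in the notebook, keeps the cache inspectable and lets single entries be dropped. `functools.cache` only offers `cache_clear()` for the whole cache.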

In [None]:
def get_versioned_tfds_name(tfds_name):
  version = TFDS_IMAGE_CLASSIFCATON_TASKS.get(tfds_name, {}).get("version", "")
  if version:
    return tfds_name + ":" + version
  return tfds_name

def get_default_splits(tfds_name):
  info = get_tfds_builder(tfds_name).info
  splits = list(info.splits.keys())
  assert "train" in splits, splits
  splits.remove("train")
  used_percent = 0
  slice_percent = 5
  pp = {}
  for k in ["test", "validation"]:
    if k in splits:
      pp[k] = k
      splits.remove(k)
    else:
      pp[k] = f"train[{used_percent}%:{used_percent+slice_percent}%]"
      used_percent += slice_percent
  pp["train"] = f"train[{used_percent}%:]"
  return pp

def get_dataset_and_splits(tfds_name: str):
  vtab_class = None
  if tfds_name in ["imagenet_v2", "cifar10_1"]:
    assert False, f"{tfds_name} is used as a validation set for other tasks."
  dataset = get_versioned_tfds_name(tfds_name)
  if tfds_name == "imagenet2012":
    dataset = {"train":dataset, "validation":"imagenet_v2:3.0.0", "test":dataset}
    splits = {"train":"train", "validation":"test", "test":"validation"}
  elif tfds_name == "cifar100":
    splits = {"train":"train[:98%]", "validation":"train[98%:]", "test":"test"}
  elif tfds_name == "cifar10":
    dataset = {"train":dataset, "validation":"cifar10_1:1.1.0", "test":dataset}
    splits = {"train":"train", "validation":"test", "test":"test"}
  elif (tfds_name.startswith("visual_domain_decathlon/") or
        tfds_name in ["i_naturalist2017", "i_naturalist2018", "places365_small"]):
    # Test has no labels; split validation in half.
    splits = {"train":"train", "validation":"validation[:50%]", "test":"validation[50%:]"}
  elif tfds_name.startswith("cmaterdb/"):
    # Increase size of validation set due to small dataset size.
    splits = {"train":"train[20%:]", "validation":"train[:20%]", "test":"test"}
  elif tfds_name == "omniglot":
    # Test has no labels and there is no validation split; use the additional splits.
    splits = {"train":"train", "validation":"small1", "test":"small2"}
  elif tfds_name.startswith("controlled_noisy_web_labels/"):
    splits = {"train":"train_00", "validation":"validation[:50%]", "test":"validation[50%:]"}
  elif tfds_name.startswith("cycle_gan/"):
    splits = {"train":"trainA[10%:]+trainB[10%:]",
              "validation":"trainA[:10%]+trainB[:10%]",
              "test":"testA+testB"}
  elif tfds_name in ["imagenet_a", "imagenet_r", "imagenet_sketch",
                     "siscore/rotation", "siscore/size", "siscore/location",]:
    # Only a test split is available.
    splits = {"train":"test[10%:]", "validation":"test[5%:10%]", "test":"test[:5%]"}
  elif tfds_name in ["pet_finder"]:
    # Use only the train split, since e.g. the test split has no labels.
    splits = {
        "train":"train[10%:]",
        "validation":"train[5%:10%]",
        "test":"train[:5%]"}
  elif tfds_name == "quickdraw_bitmap":
    # Cap the size of the test and validation sets.
    splits = {"train":"train[20000:]", "validation":"train[10000:20000]", "test":"train[:10000]"}
  elif tfds_name == "stanford_online_products":
    # Use the first 10k test samples as validation since test has 60k.
    splits = {"train":"train", "validation":"test[:10000]", "test":"test[10000:]"}
  elif tfds_name in VTAB_TASKS or (
      tfds_name.endswith("/1k") and tfds_name.replace("/1k", "") in VTAB_TASKS):
    is_vtab_1k = tfds_name.endswith("/1k")
    tfds_name = tfds_name.replace("/1k", "")
    registry_name = {
        "diabetic_retinopathy_detection/btgraham-300": "diabetic_retinopathy",
        "svhn_cropped": "svhn",
        "cifar100": "cifar",
        "cifar10": "cifar",
    }.get(tfds_name, tfds_name.split("/")[0])
    args = {
        "clevr/count_all": ("count_all",),
        "clevr/count_cylinders": ("count_cylinders",),
        "clevr/closest_object_distance": ("closest_object_distance",),
        "dsprites/label_x_position": ("label_x_position",),
        "dsprites/label_orientation": ("label_orientation",),
        "kitti/closest_object_distance": ("closest_object_distance",),
        "kitti/count_vehicles": ("count_vehicles",),
        "kitti/closest_vehicle_distance": ("closest_vehicle_distance",),
        "smallnorb/label_category": ("label_category",),
        "smallnorb/label_lighting": ("label_lighting",),
        "smallnorb/label_azimuth": ("label_azimuth",),
        "smallnorb/label_elevation": ("label_elevation",),
        "cifar100": (100,),
        "cifar10": (10,),
    }.get(tfds_name, ())
    vtab_class = task_adapt_registry.Registry.lookup(
        f"data.{registry_name}")(*args)
    vtab_splits = vtab_class._tfds_splits
    dataset = {
        "caltech101": "caltech101:3.*.*",
        "dtd": "dtd:3.*.*",
        "oxford_flowers102": "oxford_flowers102:2.*.*",
        "oxford_iiit_pet": "oxford_iiit_pet:3.*.*",
        "sun397": "sun397/tfds:4.*.*",
        "svhn": "svhn_cropped:3.*.*",
        "patch_camelyon": "patch_camelyon:2.*.*",
        "eurosat": "eurosat/rgb:2.*.*",
        "resisc45": "resisc45:3.*.*",
        "diabetic_retinopathy": "diabetic_retinopathy_detection/btgraham-300:3.*.*",
        "clevr": "clevr:3.*.*",
        "dmlab": "dmlab:2.0.1",
        "dsprites": "dsprites:2.*.*",
        "kitti": "kitti:3.2.0",
        "smallnorb": "smallnorb:2.*.*",
        "cifar" : "cifar100:3.*.*" if tfds_name == "cifar100" else "cifar10:3.*.*",
    }[registry_name]
    if is_vtab_1k:
      splits = {"train": str(vtab_splits["train800"]),
                "validation": str(vtab_splits["val200"]),
                "test": str(vtab_splits["test"])}
    else:
      splits = {"train": str(vtab_splits["train"]),
                "validation": str(vtab_splits["val"]),
                "test": str(vtab_splits["test"])}
  else:
    splits = get_default_splits(tfds_name)
  return dataset, splits, vtab_class

class Task():
  def __init__(self, name, config):
    self.config = config

    self.dataset, self.splits, self.vtab_class = get_dataset_and_splits(name)
    self.name = name
    if self.vtab_class:
      self.num_classes = self.vtab_class.get_num_classes()
    else:
      self.num_classes = self.get_builder(
          "train").info.features[self.get_label_key()].num_classes
    num_train_examples = self.get_builder(
        "train").info.splits[self.splits["train"]].num_examples
    self.train_batch_size = config.batch_size
    self.num_train_batches_between_validations = math.ceil(
        min(num_train_examples,
            config.num_train_examples_between_validations_max)
        / self.train_batch_size)

    num_validation_examples_tot = self.get_builder(
        "validation").info.splits[self.splits["validation"]].num_examples
    if config.num_validation_examples_max <= num_validation_examples_tot:
      self.validation_batch_size = config.batch_size
      self.num_validation_batches = math.floor(
          config.num_validation_examples_max / self.validation_batch_size)
    else:
      # Adjust batch_size and num_batches to cover the smaller validation sets.
      self.num_validation_batches = math.ceil(
          num_validation_examples_tot / config.batch_size)
      self.validation_batch_size = math.floor(
          num_validation_examples_tot / self.num_validation_batches)
      assert num_validation_examples_tot >= (
          self.num_validation_batches*self.validation_batch_size)
    self.num_validation_examples = (
        self.num_validation_batches * self.validation_batch_size)

    print(f"Task: {self.name}")
    print(f"  Train batches between validations: {self.num_train_batches_between_validations}")
    print(f"  Validation batches: {self.num_validation_batches}")
    print(f"  Validation batch size: {self.validation_batch_size}")
    print(f"  Dataset {self.dataset}")
    print(f"  Splits {self.splits}")

  def get_label_key(self):
    return {
        "stanford_online_products": "super_class_id",
        }.get(self.name, "label")

  def get_builder(self, mode):
    if type(self.dataset) == str:
      return get_tfds_builder(self.dataset)
    return get_tfds_builder(self.dataset[mode])

  def get_ds(self, mode, hparams):
    data = self.get_builder(mode).as_dataset(
        split=self.splits[mode],
        shuffle_files=mode=="train")

    def _pp(data):  # Preprocessing function.
      im = data["image"]
      tf.debugging.assert_type(im, tf.uint8)

      if mode == "train":
        if hparams.get("ds_quality_delta", 0.0) > 0.0:
          im = tf.image.random_jpeg_quality(
              im,
              min_jpeg_quality=int(100 * (1 - hparams["ds_quality_delta"])),
              max_jpeg_quality=100)

      # Must have 3 channels.
      if im.shape[-1] == 1:
        im = tf.squeeze(tf.stack([im] * 3, -1), axis=-2)
      assert im.shape[-1] == 3
      im = tf.cast(im, tf.float32)
      if mode == "train":
        if hparams.get("ds_area_range_min", 1.0) < 1.0:
          begin, size, _ = tf.image.sample_distorted_bounding_box(
              tf.shape(im),
              tf.zeros([0, 0, 4], tf.float32),
              aspect_ratio_range=[hparams["ds_aspect_ratio_range_min"],
                                  1.0/hparams["ds_aspect_ratio_range_min"]],
              area_range=[hparams["ds_area_range_min"], 1.0],
              min_object_covered=0,  # Min image overlap already set by area_range.
              use_image_if_no_bounding_boxes=True)
          im = tf.slice(im, begin, size)
          im.set_shape([None, None, 3]) # Restore dimension lost by tf.slice().
        if hparams.get("ds_flip_left_right", False):
          if tf.random.uniform(shape=[]) > 0.5:
            im = tf.image.flip_left_right(im)
        if hparams.get("ds_brightness_delta", 0.0) > 0.0:
          im = tf.image.random_brightness(
              im, max_delta=hparams["ds_brightness_delta"])
        if hparams.get("ds_contrast_delta", 0.0) > 0.0:
          im = tf.image.random_contrast(
              im, lower=1-hparams["ds_contrast_delta"],
              upper=1+hparams["ds_contrast_delta"])
        if hparams.get("ds_saturation_delta", 0.0) > 0.0:
          im = tf.image.random_saturation(
              im, lower=1-hparams["ds_saturation_delta"],
              upper=1+hparams["ds_saturation_delta"])
        if hparams.get("ds_hue_delta", 0.0) > 0.0:
          im = tf.image.random_hue(im, max_delta=hparams["ds_hue_delta"])

      def get_formatted_image(image, image_size):
        image = tf.image.resize(image, [image_size, image_size])
        # Values in range [-1, 1].
        image = image / 127.5 - 1
        image = tf.clip_by_value(image, -1, 1)
        return image

      if type(hparams["ds_image_size"]) is list:
        out_im = {}
        for im_size in hparams["ds_image_size"]:
          out_im[str(im_size)] = get_formatted_image(im, int(im_size))
      else:
        out_im = get_formatted_image(im, int(hparams["ds_image_size"]))
      return {"image": out_im,
              "label": data[self.get_label_key()]}

    if mode == "validation":
      data = data.take(self.num_validation_examples).cache()
    if mode != "test":
      data = data.repeat()
    if self.vtab_class and self.vtab_class._base_preprocess_fn:
      data = data.map(self.vtab_class._base_preprocess_fn, tf.data.AUTOTUNE)
    data = data.map(_pp, tf.data.AUTOTUNE)
    if mode == "train":
      batch_size = self.train_batch_size
    else:
      batch_size = self.validation_batch_size
    data = data.batch(batch_size)
    if mode == "train":
      data = data.shuffle(10)
    return tfds.as_numpy(data.prefetch(tf.data.AUTOTUNE))
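Two pieces of arithmetic in this cell are easy to check in isolation. First, `get_default_splits` carves 5% slices off the start of `train` for whichever of `test` and `validation` is missing. Second, `Task.__init__` shrinks the validation batch size when the validation split is smaller than `num_validation_examples_max`, so that a whole number of batches covers almost every example. The dependency-free sketch below uses hypothetical helper names, and the split names are passed as a list instead of being read from a TFDS builder.

```python
import math


def default_splits(split_names, slice_percent=5):
    """Mirror of get_default_splits() on a plain list of split names."""
    assert "train" in split_names, split_names
    splits = [s for s in split_names if s != "train"]
    used_percent = 0
    pp = {}
    for k in ["test", "validation"]:
        if k in splits:
            pp[k] = k
        else:  # Take the next slice from the start of train.
            pp[k] = f"train[{used_percent}%:{used_percent + slice_percent}%]"
            used_percent += slice_percent
    pp["train"] = f"train[{used_percent}%:]"
    return pp


def validation_batching(num_examples_tot, batch_size, num_examples_max):
    """Mirror of the validation sizing in Task.__init__: (num_batches, batch_size)."""
    if num_examples_max <= num_examples_tot:
        return math.floor(num_examples_max / batch_size), batch_size
    num_batches = math.ceil(num_examples_tot / batch_size)
    return num_batches, math.floor(num_examples_tot / num_batches)


print(default_splits(["train"]))
# {'test': 'train[0%:5%]', 'validation': 'train[5%:10%]', 'train': 'train[10%:]'}
print(validation_batching(1000, batch_size=256, num_examples_max=10000))   # (4, 250)
print(validation_batching(50000, batch_size=256, num_examples_max=10000))  # (39, 256)
```

In the small-split case, 4 batches of 250 cover all 1000 examples. A fixed batch size of 256 would leave a partial final batch instead.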

# Components

In [None]:
class Component():
  def reset_globals():
    Component.counter = 0
    Component.last_saved = -1

  def __init__(self, name, agent_id, params, id=None):
    self._name = name
    self.agent_id = agent_id
    self._params = jax.device_get(params)
    self._num_params = None
    if id is None:
      self.id = Component.counter
      Component.counter += 1
    else:
      self.id = id
    self.full_id = f"{self.agent_id}:{self.id}"

  @property
  def name(self):
    if self._name is None:
      self._name = get_comp_data(self, "name")
    return self._name

  @property
  def params(self):
    if self._params is None:
      self._params = get_comp_params(self)
    return self._params

  @property
  def num_params(self):
    if self._num_params is None:
      self._num_params = get_num_params(self.params)
    return self._num_params

  def is_trainable(self):
    return len(Population.unmutable_comp2agents[self.full_id]) == 0

  def clone(self, agent_id):
    return Component(name=self.name, agent_id=agent_id,
                     params=copy.deepcopy(jax.device_get(self.params)))
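Each `Component` takes a process-wide counter value as its id, unless an id is given (as when loading saved state). `full_id` prefixes that id with the owning agent. `clone` deep-copies the parameters under a new id and agent, so training the clone cannot modify the frozen original. The trimmed-down sketch below uses plain lists as parameters; `Comp` and the agent ids are hypothetical.

```python
import copy


class Comp:
    """Hypothetical, trimmed-down Component: ids and cloning only."""
    counter = 0

    def __init__(self, name, agent_id, params, id=None):
        self.name = name
        self.agent_id = agent_id
        self.params = params
        if id is None:  # New component: take the next global id.
            self.id = Comp.counter
            Comp.counter += 1
        else:  # Component restored from saved state.
            self.id = id
        self.full_id = f"{self.agent_id}:{self.id}"

    def clone(self, agent_id):
        return Comp(self.name, agent_id, copy.deepcopy(self.params))


root = Comp("encoder", "Vit~root", {"w": [1.0, 2.0]})
child = root.clone("Vit~cifar10")
child.params["w"][0] = 0.0  # Simulate a training update on the clone.
print(root.full_id, child.full_id, root.params["w"])  # Vit~root:0 Vit~cifar10:1 [1.0, 2.0]
```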

In [None]:
class ComponentPath():
  """Wraps a Path so that it can be used as a Component."""
  def __init__(self, name, path):
    self.name = name
    self.path = path

  def is_trainable(self):
    return False

  @property
  def params(self):
    return flax.core.freeze(self.path.get_all_params())

In [None]:
AID2CDF = {}  # Maps agent ids to a dataframe with component metadata.

def load_sharded_df(agent_dir, filename = "components"):
  dfs = []
  threads = []
  lock = Lock()
  def load_df(s_dir):
    df = df_read_from_csv(s_dir, filename)
    with lock:
      dfs.append(df)
  for s_dir in gfile.glob(agent_dir + "/state_*"):
    threads.append(Thread(target=load_df, args=(s_dir,)))
    threads[-1].start()
  for t in threads:
    t.join()
  return pd.concat(dfs)

def get_comp_data(comp, data_name):
  if comp.agent_id not in AID2CDF or comp.id not in AID2CDF[comp.agent_id]['id'].values:
    AID2CDF[comp.agent_id] = load_sharded_df(os.path.join(SYSTEM_STATE_DIR, comp.agent_id))
  c_data = AID2CDF[comp.agent_id].loc[AID2CDF[comp.agent_id]['id'] == comp.id]
  assert len(c_data) == 1
  return c_data[data_name].values[0]
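`load_sharded_df` reads one CSV per `state_*` directory on its own thread and concatenates the results. `get_comp_data` then expects exactly one row per component id. The sketch below follows the same flow with two substitutions: `concurrent.futures.ThreadPoolExecutor` replaces the manual `Thread`/`Lock` pair, and lists of dicts replace DataFrames. `load_shard` is a hypothetical loader standing in for `df_read_from_csv`.

```python
from concurrent.futures import ThreadPoolExecutor


def load_shard(shard_dir):
    """Hypothetical loader: one component row per shard, keyed by shard index."""
    state_iid = int(shard_dir.rsplit("_", 1)[1])
    return [{"id": state_iid, "name": f"comp_{state_iid}", "state_iid": state_iid}]


def load_sharded(shard_dirs, load_fn=load_shard):
    # Each result comes back from pool.map in input order, so no shared list or lock.
    with ThreadPoolExecutor() as pool:
        shards = list(pool.map(load_fn, shard_dirs))
    return [row for rows in shards for row in rows]


def get_comp_field(rows, comp_id, field):
    matches = [r for r in rows if r["id"] == comp_id]
    assert len(matches) == 1, comp_id  # Ids must be unique within an agent.
    return matches[0][field]


rows = load_sharded([f"agent/state_{i}" for i in range(3)])
print(get_comp_field(rows, 2, "name"))  # comp_2
```

Unlike the appended-under-lock version, the row order here does not depend on which thread finishes first.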

In [None]:
CFID2PARAMS = {}  # Maps component full ids to component parameters.

def get_comp_params(comp):
  if comp.full_id not in CFID2PARAMS:
    stt = time.time()
    state_iid = get_comp_data(comp, "state_iid")
    agent_dir = os.path.join(SYSTEM_STATE_DIR, comp.agent_id)
    chkpnt = os.path.join(agent_dir, f"state_{state_iid}/checkpoint_{state_iid}")
    loaded_params = flax.core.freeze(flax_checkpoints.restore_checkpoint(
        ckpt_dir=os.path.dirname(chkpnt), target=None))
    for k in loaded_params.keys():
      _, c_agent_id, c_id = k.split(":")
      full_id = f"{c_agent_id}:{c_id}"
      assert full_id not in CFID2PARAMS, full_id
      CFID2PARAMS[full_id] = loaded_params[k]
    print(f"  Loaded components: {time.time()-stt:.2f}s", chkpnt, loaded_params.keys())
  return CFID2PARAMS[comp.full_id]
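`get_comp_params` restores a whole `state_<iid>` checkpoint at once. It indexes every stored component by `<agent_id>:<comp_id>`, taken from the last two `:`-separated fields of each checkpoint key. For this reason an agent id must not itself contain `:`. The sketch below shows the indexing step on a plain dict. The `comp` key prefix is a hypothetical placeholder, since this section does not show how the keys are written.

```python
def index_checkpoint(loaded_params, cache):
    """Adds every component of one restored checkpoint to cache, keyed by full id."""
    for k, params in loaded_params.items():
        # Raises ValueError unless the key has exactly three ":"-separated fields.
        _, c_agent_id, c_id = k.split(":")
        full_id = f"{c_agent_id}:{c_id}"
        assert full_id not in cache, full_id  # Each component is stored once.
        cache[full_id] = params
    return cache


cache = index_checkpoint({"comp:Vit~cifar10:3": [0.1], "comp:Vit~cifar10:4": [0.2]}, {})
print(sorted(cache))  # ['Vit~cifar10:3', 'Vit~cifar10:4']
```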

# Paths & Population

In [None]:
class Path():
  def reset_globals(config):
    Path.config = config
    Path.counter = 0
    Path.last_saved = -1
    Path.paths = []
    Path.scorer = ScorerDecay(**config.get("scorer_kwargs", {}))

  def __init__(self, hparams, components, parent, agent_id, task_name, id=None):
    self.components = components
    if id is None:
      self.id = Path.counter
      Path.counter += 1
    else:
      self.id = id
    self.agent_id = agent_id
    self.full_id = f"{self.agent_id}:{self.id}"
    self.task_name = task_name
    self.hparams = hparams
    self.parent = parent
    if self.parent is not None:
      self.hparams["parent_agent_id"] = parent.agent_id
      self.hparams["parent_id"] = int(parent.id)
    self._model = None
    self.metrics = {"generation": parent.metrics["generation"]+1 if parent else 0}
    Path.paths.append(self)

  @property
  def model_factory(self):
    return get_agent_class(self.agent_id).get_model_factory()

  @property
  def model(self):
    if self._model is None:
      self._model = self.model_factory.get_model(self.hparams, self.config)
    return self._model

  def comps_only(self):  # Exclude wrapped paths.
    return [c for c in self.components if c.__class__ is Component]

  def score(self):
    return Path.scorer.score(self)

  def get_all_params(self):
    params = {}
    for c in self.components:
      assert c.name not in params, c.name
      params[c.name] = c.params
    return flax.core.freeze(params)

  def get_trainable_params(self):
    params = {}
    for c in self.components:
      if c.is_trainable():
        assert c.name not in params, c.name
        params[c.name] = c.params
    return flax.core.freeze(params)

  def get_fixed_params(self):
    params = {}
    for c in self.components:
      if not c.is_trainable():
        assert c.name not in params, c.name
        params[c.name] = c.params
    return flax.core.freeze(params)

  def update_trainable(self, trained_params):
    trainable_count = 0
    for c in self.components:
      if c.is_trainable():
        trainable_count += 1
        assert c.name in trained_params.keys()
        c._params = trained_params[c.name]
    assert len(trained_params.keys()) == trainable_count, (
        f"{len(trained_params.keys())} {trainable_count}")

  def get_num_accounted_params(self):
    rtn = 0
    for c in sorted(self.components, key=lambda c: c.id):
      agnts = copy.copy(Population.unmutable_comp2agents[c.full_id])
      agnts.add(self.agent_id)
      rtn += c.num_params/len(agnts)
    return rtn

  def get_flops(self):
    return compute_flops_hlo(
          partial(self.model.apply, train=False),
          {"params": self.model_factory.get_comps2model_fn()(merge_params(
              self.get_trainable_params(),
              self.get_fixed_params()))},
          self.model_factory.get_sample_input(self.hparams))
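`Path.get_num_accounted_params` charges each component's parameter count to the path in proportion to how many agents share it. The sharing agents are those that reference the component as frozen (from `Population.unmutable_comp2agents`), plus the path's own agent. The sketch below represents components as dicts and uses a hypothetical `comp2agents` mapping.

```python
def num_accounted_params(components, agent_id, comp2agents):
    """Sum of num_params / number of agents sharing each component."""
    total = 0.0
    for c in components:
        agents = set(comp2agents.get(c["full_id"], set()))  # Copy; don't mutate.
        agents.add(agent_id)
        total += c["num_params"] / len(agents)
    return total


components = [
    {"full_id": "Vit~root:0", "num_params": 300},    # Frozen, shared.
    {"full_id": "Vit~cifar10:5", "num_params": 10},  # Trainable head.
]
comp2agents = {"Vit~root:0": {"Vit~dtd", "Vit~svhn"}}
print(num_accounted_params(components, "Vit~cifar10", comp2agents))  # 110.0
```

Reusing a frozen component therefore gets cheaper for every agent as more agents reuse it. This is what lets a parameter-based score penalty favour reuse over growth.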

In [None]:
class ScorerDecay():
  def __init__(self, scale_factor=1, base_accounted_params=0, base_flops=0):
    assert 0.0 < scale_factor <= 1.0
    self.scale_factor = scale_factor
    self.base_accounted_params = base_accounted_params
    self.base_flops = base_flops

  def score(self, path):
    if "quality" not in path.metrics or math.isnan(path.metrics["quality"]):
      return None
    assert path.metrics["quality"] >= 0, (
        f"{path.task_name} {path.metrics['quality']}")
    score = path.metrics["quality"]
    if self.base_accounted_params > 0:
      if "accounted_params" not in path.metrics:
        path.metrics["accounted_params"] = path.get_num_accounted_params()
      score *= self.scale_factor ** (path.metrics["accounted_params"]/self.base_accounted_params)
    if self.base_flops > 0:
      if "flops" not in path.metrics:
        path.metrics["flops"] = path.get_flops()
      score *= self.scale_factor ** (path.metrics["flops"]/self.base_flops)
    assert score >= 0
    path.metrics["score"] = score
    return score
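
`ScorerDecay` multiplies the validation quality by `scale_factor ** (accounted_params / base_accounted_params)`. If `base_flops > 0`, it also multiplies by `scale_factor ** (flops / base_flops)`. With `scale_factor < 1`, larger or more expensive paths are penalized smoothly, while `scale_factor=1` leaves the quality unchanged. Below is a plain-number sketch of the same formula. `decayed_score` is a hypothetical helper, not part of the notebook.

```python
import math

def decayed_score(quality, scale_factor, accounted_params=0, base_accounted_params=0,
                  flops=0, base_flops=0):
  """Mirrors ScorerDecay.score on plain numbers (hypothetical helper)."""
  assert 0.0 < scale_factor <= 1.0
  if quality is None or math.isnan(quality):
    return None  # No validation quality yet: the path cannot be scored.
  score = quality
  if base_accounted_params > 0:
    score *= scale_factor ** (accounted_params / base_accounted_params)
  if base_flops > 0:
    score *= scale_factor ** (flops / base_flops)
  return score

# With scale_factor=0.99, using exactly base_accounted_params costs 1% of quality,
# and using twice as many costs about 2% (0.99**2 = 0.9801).
print(decayed_score(0.9, 0.99, 1e6, 1e6))
print(decayed_score(0.9, 0.99, 2e6, 1e6))
```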

In [None]:
def get_best_path(paths):
  if paths:  # Oldest path achieving max score.
    return max(sorted(paths, key=lambda p: p.id), key=lambda p: p.score())
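
`get_best_path` breaks ties in favor of the oldest path. The list is first sorted by id, and Python's built-in `max` returns the first maximal element it encounters. A small check with a hypothetical `FakePath` stand-in:

```python
from dataclasses import dataclass

@dataclass
class FakePath:  # Hypothetical stand-in for Path with a precomputed score.
  id: int
  s: float

  def score(self):
    return self.s

def get_best_path(paths):
  if paths:  # max() keeps the first maximal item, i.e. the lowest id after sorting.
    return max(sorted(paths, key=lambda p: p.id), key=lambda p: p.score())

paths = [FakePath(5, 0.8), FakePath(2, 0.8), FakePath(9, 0.7)]
print(get_best_path(paths).id)  # 2
```

An empty list falls through the `if` and returns `None`.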

In [None]:
class Population():
  def __init__(self, config):
    self.config = config
    Path.reset_globals(config)
    Component.reset_globals()
    self.paths = defaultdict(list)
    self.unmutable_comp2agents = None

  def get_best_path(self):
    return get_best_path(self.paths[self.config.agent_id])

  def get_path_from_full_id(self, agent_id, path_id):
    for p in self.paths[agent_id]:
      if p.id == path_id:
        return p
    assert False, f"Path not found {agent_id}:{path_id}"

In [None]:
pd.set_option("display.expand_frame_repr", False)
pd.set_option("display.max_columns", 100)

def pop_to_df(pop):
  return paths_to_df([p for paths in pop.paths.values() for p in paths])

def paths_to_df(paths):
  def _format(x):
    if type(x) in [dict, list]:
      return json.dumps(x)
    return x
  metrics_keys = set()
  hparams_keys = set()
  for path in paths:
    metrics_keys.update(path.metrics)
    hparams_keys.update(path.hparams)
  data = defaultdict(list)
  for path in paths:
    data["agent_id"].append(path.agent_id)
    data["task_name"].append(path.task_name)
    data["id"].append(path.id)
    data["components"].append(",".join(
        [f"{c.agent_id}:{c.id}" for c in path.comps_only()]))
    for k in hparams_keys:
      data[f"hparams.{k}"].append(_format(path.hparams[k]) if k in path.hparams else None)
    for k in metrics_keys:
      data[f"metrics.{k}"].append(path.metrics[k] if k in path.metrics else None)
  return pd.DataFrame(data)
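
`paths_to_df` flattens each path into one row. It prefixes hyperparameters with `hparams.` and metrics with `metrics.`, and serializes `dict`/`list` hyperparameters to JSON so they fit in a single CSV cell. `load_paths` later reverses this with `json.loads`. Below is a dependency-free sketch of the flattening for a single path. `path_to_row` is hypothetical. Unlike the notebook's exact `type` check, it uses `isinstance`, which also matches `dict`/`list` subclasses.

```python
import json

def path_to_row(agent_id, path_id, hparams, metrics):
  """Flattens one path into a flat dict using paths_to_df's column prefixes."""
  row = {"agent_id": agent_id, "id": path_id}
  for k, v in hparams.items():
    # Nested structures become JSON strings so they fit in one CSV cell.
    row[f"hparams.{k}"] = json.dumps(v) if isinstance(v, (dict, list)) else v
  for k, v in metrics.items():
    row[f"metrics.{k}"] = v
  return row

row = path_to_row("agent_a", 3,
                  {"opt_lr": 0.01, "paths": {"x": {"agent_id": "b", "id": 1}}},
                  {"quality": 0.9})
print(row["hparams.paths"])  # {"x": {"agent_id": "b", "id": 1}}
```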

def components_to_df(paths, agent_id, last_saved_iid, state_iid):
  comps = set()
  for p in paths:
    comps.update(p.comps_only())
  data = defaultdict(list)
  for c in comps:
    if c.agent_id != agent_id or c.id <= last_saved_iid:
      continue
    data["agent_id"].append(c.agent_id)
    data["id"].append(c.id)
    data["name"].append(c.name)
    data["num_params"].append(c.num_params)
    data["state_iid"].append(state_iid)
  return pd.DataFrame(data)

def print_df_segments(df, segment_length = 5):
  tot_length = df.shape[0]
  # Pad column title with spaces to keep alignment across segments.
  def prepend_spaces(original_str, pad_to_len):
    return " " * (pad_to_len-len(original_str)) + original_str
  pad_to_len = max([len(tn) for tn in set(df["agent_id"].to_list())])+1
  df = df.rename(columns={"agent_id": prepend_spaces("agent_id", pad_to_len),
                          "task_name": prepend_spaces("task_name", pad_to_len),
                          "hparams.parent_agent_id": prepend_spaces("hparams.parent_agent_id", pad_to_len)})
  for x in range(0, tot_length, segment_length):
    print(df[x:min(x+segment_length, tot_length)])

def df_leaderboard(df):
  # Place columns on the left for readability.
  all_keys = sorted(df.columns.tolist())
  first_keys = ["agent_id", "id", "task_name", "metrics.test_quality", "metrics.score",
                "metrics.quality", "metrics.accounted_params", "metrics.flops",
                "hparams.parent_agent_id", "hparams.parent_id"]
  first_keys = [k for k in first_keys if k in all_keys]
  sorted_keys = first_keys + [k for k in all_keys if k not in first_keys]
  sorted_keys = [k for k in sorted_keys if "_mu_|" not in k]  # Remove mu function parameters.
  df = df[sorted_keys]
  if "metrics.score" in df:
    df = df.sort_values(["agent_id", "metrics.score"], ascending=[True, False], ignore_index=True)
  else:
    df = df.sort_values("agent_id", ignore_index=True)
  print_df_segments(df)
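
`df_leaderboard` orders the columns in three steps. It puts a fixed set of key columns first, appends the remaining columns alphabetically, and drops columns containing `_mu_|`. It then sorts rows by `agent_id` ascending and `metrics.score` descending. The sketch below applies the same ordering rules to plain column lists and row dicts; both helpers are hypothetical.

```python
PREFERRED = ["agent_id", "id", "task_name", "metrics.test_quality", "metrics.score",
             "metrics.quality", "metrics.accounted_params", "metrics.flops",
             "hparams.parent_agent_id", "hparams.parent_id"]

def leaderboard_columns(columns):
  """Preferred keys first (in PREFERRED order), then the rest alphabetically."""
  all_keys = sorted(columns)
  first = [k for k in PREFERRED if k in all_keys]
  ordered = first + [k for k in all_keys if k not in first]
  return [k for k in ordered if "_mu_|" not in k]  # Drop mu function parameters.

def sort_rows(rows):
  """Sorts row dicts by agent_id ascending, then metrics.score descending."""
  return sorted(rows, key=lambda r: (r["agent_id"], -r["metrics.score"]))
```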

# Checkpointing

In [None]:
def df_write_to_csv(df, dir_path, df_name):
  filename_df = os.path.join(dir_path, f"{df_name}.csv")
  with gfile.GFile(filename_df, "w") as outfile:
    df.to_csv(outfile, index=False)

def df_read_from_csv(dir_path, df_name):
  filename_df = os.path.join(dir_path, f"{df_name}.csv")
  with gfile.GFile(filename_df, "r") as infile:
    df = pd.read_csv(infile)
  # pandas read_csv() reads empty strings as NaNs. Set NaNs back to empty strings
  # in columns of string/object dtype.
  for c in df.columns:
    if df[c].dtype == np.object_:
      df[c] = df[c].fillna("")
  return df

def get_comps_params_to_save(pop):
  comps_params = {}
  # All components generated by this agent.
  all_comps = set(
      [c for p in pop.paths[pop.config.agent_id] for c in p.comps_only() if c.agent_id == pop.config.agent_id])
  # Check that there are no duplicate ids.
  assert len(all_comps) == len(set([c.id for c in all_comps])), (
      [f"{c.name}:{c.agent_id}:{c.id}" for c in sorted(all_comps, key=lambda c: c.id)])
  for c in all_comps:
    if c.id <= Component.last_saved:
      continue
    assert c.agent_id == pop.config.agent_id
    c_id_string = f"{c.name}:{c.agent_id}:{c.id}"
    comps_params[c_id_string] = c.params
  return comps_params
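
`get_comps_params_to_save` saves incrementally. It writes only components that the current agent created with an id greater than `Component.last_saved`, keyed by `name:agent_id:id`. Below is a minimal sketch of that filter. It assumes components are given as hypothetical `(name, agent_id, id, params)` tuples, with each component listed once.

```python
def comps_to_save(components, agent_id, last_saved):
  """Selects this agent's components created since the last saved state."""
  own = [c for c in components if c[1] == agent_id]
  ids = [c[2] for c in own]
  assert len(ids) == len(set(ids)), "duplicate component ids"
  # Key format "name:agent_id:id" as in get_comps_params_to_save.
  return {f"{name}:{aid}:{cid}": params
          for name, aid, cid, params in own if cid > last_saved}

comps = [("head", "b", 3, "p3"), ("head", "b", 5, "p5"), ("root", "a", 0, "p0")]
print(comps_to_save(comps, "b", 3))  # {'head:b:5': 'p5'}
```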

In [None]:
def latest_checkpoint(ckpt_dir, prefix = "checkpoint_"):
  ckpt_dir = os.fspath(ckpt_dir)
  glob_path = os.path.join(ckpt_dir, f"{prefix}*")
  checkpoint_files = flax_checkpoints.natural_sort(gfile.glob(glob_path))
  checkpoint_files = [f for f in checkpoint_files if not f.endswith("_tmp")]
  return checkpoint_files[-1] if checkpoint_files else None
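
`latest_checkpoint` relies on natural sorting, so that `checkpoint_10` ranks after `checkpoint_9`; a plain string sort would put it first. It also skips temporary `_tmp` files. Below is a stdlib sketch of the same selection over file names. `natural_key` and `latest_checkpoint_name` are illustrative stand-ins, not the implementation of `flax_checkpoints.natural_sort`.

```python
import re

def natural_key(s):
  """Splits s into text and integer chunks, e.g. "ckpt_10" -> ["ckpt_", 10, ""]."""
  # With a capturing group, re.split puts the digit runs at the odd indices.
  return [int(t) if i % 2 else t for i, t in enumerate(re.split(r"(\d+)", s))]

def latest_checkpoint_name(names, prefix="checkpoint_"):
  candidates = [n for n in names if n.startswith(prefix) and not n.endswith("_tmp")]
  candidates.sort(key=natural_key)
  return candidates[-1] if candidates else None

names = ["checkpoint_9", "checkpoint_10", "checkpoint_11_tmp", "config.json"]
print(latest_checkpoint_name(names))  # checkpoint_10
print(max(n for n in names if n.startswith("checkpoint_")))  # checkpoint_9 (plain string order is wrong here)
```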

In [None]:
def save_checkpoint(ckpt_dir, comps_params, num_paths_sampled):
  print("  Saving components", num_paths_sampled, comps_params.keys())
  flax_checkpoints.save_checkpoint(
      ckpt_dir, target=comps_params, step=num_paths_sampled, overwrite=True)

def save_state(agent):
  write_start = time.time()
  # Save data needed to resume exp.
  state_dir = os.path.join(agent.config.agent_dir, f"state_{agent.num_paths_sampled}")
  gfile.makedirs(state_dir)
  assert not latest_checkpoint(state_dir), f"Checkpoint already present in folder: {state_dir}"
  print("Saving state:", agent.num_paths_sampled)
  df_write_to_csv(paths_to_df(agent.get_paths_to_publish()), state_dir, "published")
  df_write_to_csv(paths_to_df([p for p in agent.pop.paths[agent.config.agent_id]]), state_dir, "population")
  df_write_to_csv(paths_to_df(Path.paths).query(f'agent_id=="{agent.config.agent_id}" and id>{Path.last_saved}'), state_dir, "paths")
  df_write_to_csv(components_to_df(Path.paths, agent.config.agent_id, Component.last_saved, agent.num_paths_sampled), state_dir, "components")
  with gfile.GFile(os.path.join(state_dir, "config.json"), "w") as outfile:
    json.dump(agent.config.as_configdict().to_dict(), outfile, indent=2)
  save_checkpoint(state_dir, get_comps_params_to_save(agent.pop), agent.num_paths_sampled)
  print(f"  State save time: {time.time() - write_start:.2f}s")
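
`save_state` writes each state into a directory named `state_{num_paths_sampled}`. Assuming Flax's default `checkpoint_` prefix, the checkpoint inside that directory carries the same number. When resuming, `load_paths` recovers `num_paths_sampled` from the checkpoint file name. Below is a sketch of the naming round trip with hypothetical helpers; it requires Python 3.9+ for `str.removeprefix`.

```python
import posixpath

def state_paths(agent_dir, num_paths_sampled):
  """Returns the state directory and checkpoint path save_state would produce."""
  state_dir = posixpath.join(agent_dir, f"state_{num_paths_sampled}")
  ckpt = posixpath.join(state_dir, f"checkpoint_{num_paths_sampled}")
  return state_dir, ckpt

def num_paths_from_checkpoint(ckpt_path, prefix="checkpoint_"):
  """Recovers num_paths_sampled from the checkpoint file name, as load_paths does."""
  return int(posixpath.basename(ckpt_path).removeprefix(prefix))

state_dir, ckpt = state_paths("/tmp/system_state/agent_a", 12)
print(ckpt)  # /tmp/system_state/agent_a/state_12/checkpoint_12
```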

In [None]:
def load_paths(agent):
  # Load system state metadata.
  population_df = []
  Population.unmutable_comp2agents = defaultdict(set)
  for agent_to_load in agent.agents_to_load():
    for agent_dir in gfile.glob(os.path.join(agent.config.system_state_dir, agent_to_load)):
      agent_checkpoint = latest_checkpoint(os.path.join(agent_dir, "state_*/"))
      if agent_checkpoint:
        add_pdf = df_read_from_csv(os.path.dirname(agent_checkpoint), "published")
        if agent_dir.endswith("/"+agent.id):
          agent.num_paths_sampled = int(os.path.basename(agent_checkpoint).removeprefix("checkpoint_"))
          state_dir = os.path.dirname(agent_checkpoint)
          paths_df = df_read_from_csv(state_dir, "paths")
          comps_df = df_read_from_csv(state_dir, "components")
          def validate_df(df):
            assert len(df["agent_id"].unique()) == 1, len(df["agent_id"].unique())
            assert df["agent_id"].unique()[0] == agent.id, df["agent_id"].unique()[0]
          validate_df(paths_df)
          validate_df(comps_df)
          Path.paths = []
          Path.last_saved = int(paths_df.id.max())
          Component.last_saved = int(comps_df.id.max())
          Path.counter = Path.last_saved + 1
          Component.counter = Component.last_saved + 1
          print("CONTINUING FROM STATE", agent.num_paths_sampled)
          population_df.append(df_read_from_csv(state_dir, "population"))
        else:
          for _, arow in add_pdf.iterrows():
            for cid in arow["components"].split(","):
              Population.unmutable_comp2agents[cid].add(arow["agent_id"])
        population_df.append(add_pdf)
  population_df = pd.concat(population_df)
  # Create paths.
  cfid2comp = {}
  for _, row in population_df.drop_duplicates(["agent_id", "id"]).iterrows():
    hparams = {}
    metrics = {}
    for k, v in row.items():
      if type(v) is float and math.isnan(v):
        continue
      if k.startswith("hparams."):
        if type(v) == str and (v.startswith("{") or v.startswith("[")):
          v = json.loads(v)
        hparams[k.removeprefix("hparams.")] = v
      if k.startswith("metrics."):
        metrics[k.removeprefix("metrics.")] = v
    if "accounted_params" in metrics:
      del metrics["accounted_params"]  # Triggers recomputation.
    comps = []
    for c_fid in row["components"].split(","):
      if c_fid not in cfid2comp:
        c_agent_id, cid = c_fid.split(':')
        cfid2comp[c_fid] = Component(name=None, agent_id=c_agent_id, params=None, id=int(cid))
      comps.append(cfid2comp[c_fid])
    path = Path(hparams, comps, parent=None, agent_id=row["agent_id"],
                task_name=row["task_name"], id=int(row["id"]))
    path.metrics = metrics
    agent.pop.paths[row["agent_id"]].append(path)
  # Add components representing subpaths.
  for path in [p for paths in agent.pop.paths.values() for p in paths]:
    if "paths" in path.hparams:
      for k in path.hparams["paths"]:
        sub_path = agent.pop.get_path_from_full_id(
            path.hparams["paths"][k]["agent_id"], path.hparams["paths"][k]["id"])
        path.components.append(ComponentPath(name=k, path=sub_path))
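
`load_paths` rebuilds components from the `agent_id:id` strings in the `components` column. It caches them in `cfid2comp`, so paths that list the same component share one object. Below is a minimal sketch of that caching, with hypothetical dicts standing in for `Component`.

```python
def build_paths(rows):
  """Rebuilds {path_id: [component, ...]} from "agent:id,agent:id" strings."""
  cache = {}  # Full component id -> shared component object.
  paths = {}
  for path_id, comps_str in rows:
    comps = []
    for c_fid in comps_str.split(","):
      if c_fid not in cache:
        c_agent_id, cid = c_fid.split(":")
        cache[c_fid] = {"agent_id": c_agent_id, "id": int(cid)}
      comps.append(cache[c_fid])
    paths[path_id] = comps
  return paths

paths = build_paths([(1, "root:0,a:4"), (2, "root:0,a:5")])
print(paths[1][0] is paths[2][0])  # True: both paths share the root component.
```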

# Training

In [None]:
@partial(jax.jit, static_argnames="model")
def eval_step(params, inputs, labels, model):
  logits = model.apply({"params": params}, inputs, train=False)
  # Avg accuracy on the batch.
  return (logits.argmax(axis=-1) == labels).mean()

In [None]:
@partial(jax.jit, static_argnames=["model", "optimizer", "format_params_fn"], donate_argnums=[0, 2])
def train_step(params, fixed_params, opt_state, inputs, labels, model, optimizer, format_params_fn):
  def loss_fn(params, fixed_params, inputs, labels):
    logits = model.apply(
        {"params": format_params_fn(merge_params(params, fixed_params))},
        inputs, train=True)
    labels = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.mean(jnp.sum(labels * nn.log_softmax(logits), axis=-1))
  grads = jax.grad(loss_fn)(params, fixed_params, inputs, labels)
  updates, opt_state = optimizer.update(grads, opt_state, params=params)
  params = optax.apply_updates(params, updates)
  return params, opt_state
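
The loss in `train_step` is the softmax cross-entropy against one-hot labels, averaged over the batch. Below is a dependency-free sketch of the same formula on Python lists, for illustration only. In JAX code, `optax.softmax_cross_entropy_with_integer_labels` computes the per-example loss directly from integer labels.

```python
import math

def log_softmax(logits):
  m = max(logits)  # Subtract the max for numerical stability.
  log_z = m + math.log(sum(math.exp(x - m) for x in logits))
  return [x - log_z for x in logits]

def cross_entropy(batch_logits, labels):
  """Batch mean of -sum(one_hot(label) * log_softmax(logits))."""
  losses = []
  for logits, label in zip(batch_logits, labels):
    one_hot = [1.0 if i == label else 0.0 for i in range(len(logits))]
    losses.append(-sum(o * lp for o, lp in zip(one_hot, log_softmax(logits))))
  return sum(losses) / len(losses)

# Uniform logits over 2 classes give a loss of log(2).
print(cross_entropy([[0.0, 0.0]], [0]))
```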

In [None]:
def execute_train_step(path, train_batch):
  path.params_device, path.opt_state_device = train_step(
      path.params_device,
      path.fixed_params_device,
      path.opt_state_device,
      train_batch["image"],
      train_batch["label"],
      path.model,
      path.optimizer,
      path.model_factory.get_comps2model_fn())

def execute_eval_step(path, eval_batch):
  path.accs.append(
      eval_step(
          path.model_factory.get_comps2model_fn()(merge_params(
              path.params_device, path.fixed_params_device)),
          eval_batch["image"],
          eval_batch["label"],
          path.model))
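
`train_loop` launches each path's train or eval step in its own `Thread`. Before starting a new step for a path, it joins that path's previous thread. Each path's steps therefore run in order, while steps of different paths, placed on different devices, can overlap. This is presumably meant to keep host-side dispatch from serializing the devices. Below is a minimal stdlib sketch of this chaining pattern, with a hypothetical `Worker` standing in for a path.

```python
import threading

class Worker:
  """Hypothetical stand-in for a path: holds its last launched thread and a log."""

  def __init__(self, name):
    self.name = name
    self.thread = None
    self.log = []

def launch(worker, step_fn, *args):
  # Wait for this worker's previous step so its own steps stay ordered,
  # while other workers' threads may still be running concurrently.
  if worker.thread is not None:
    worker.thread.join()
  worker.thread = threading.Thread(target=step_fn, args=(worker, *args))
  worker.thread.start()

def record_step(worker, step):
  worker.log.append(step)

workers = [Worker("p0"), Worker("p1")]
for step in range(3):
  for w in workers:
    launch(w, record_step, step)
for w in workers:
  w.thread.join()
print([w.log for w in workers])  # [[0, 1, 2], [0, 1, 2]]
```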

In [None]:
PREV_LOOP_END = time.time()

def train_loop(agent, paths, ds_train, ds_validation, devices):
  global PREV_LOOP_END
  timing = {}
  task = agent.task
  for p_id, path in enumerate(paths):
    path.device_id = p_id % len(devices)
    path.device = devices[path.device_id]
    path.optimizer = get_optimizer(
        opt_lr=path.hparams["opt_lr"],
        opt_lr_schedule=path.hparams["opt_lr_schedule"],
        opt_lr_warmup_ratio=path.hparams["opt_lr_warmup_ratio"],
        opt_momentum=path.hparams["opt_momentum"],
        opt_nesterov=path.hparams["opt_nesterov"],
        num_train_batches_between_validations=task.num_train_batches_between_validations,
        num_validations_per_path_training=agent.config.num_validations_per_path_training)
    path.best_params_local = None
    path.best_quality = None
    path.best_score = path.parent.score() if path.agent_id == path.parent.agent_id else -np.inf
    path.evals = []
    path.exe_thread = None
  gc.collect()
  # Transfer parameters to devices.
  for path in paths:
    path.params_device = jax.device_put(path.get_trainable_params(), path.device)
    path.fixed_params_device = jax.device_put(path.get_fixed_params(), path.device)
    path.opt_state_device = jax.jit(path.optimizer.init, device=path.device)(path.params_device)
  iter_ds_validation = iter(ds_validation)
  # Train loop.
  print(("\t".join([f"{path.id}" for path in paths]) + "\t< Path id").expandtabs(8))
  for t_step, train_batch in zip(
      range(agent.config.num_validations_per_path_training
            * task.num_train_batches_between_validations),
      ds_train):
    if t_step == 0:
      timing["start_train"] = time.time()
    for path in paths:
      train_batch = jax.device_put(train_batch, path.device)
      if path.exe_thread is not None:
        path.exe_thread.join()
      path.exe_thread = Thread(target=execute_train_step, args=(path, train_batch))
      path.exe_thread.start()
    if t_step == 0:
      for p in paths:
        p.exe_thread.join()
      timing["end_train_compile"] = time.time()
    # Evaluation on validation set.
    if (t_step+1) % task.num_train_batches_between_validations == 0:
      for path in paths:
        path.accs = []
      for e_step, eval_batch in zip(range(task.num_validation_batches), iter_ds_validation):
        if e_step == 0:
          start_eval_round = time.time()
          if "start_eval" not in timing:
            timing["start_eval"] = start_eval_round
        for path in paths:
          eval_batch = jax.device_put(eval_batch, path.device)
          path.exe_thread.join()
          path.exe_thread = Thread(target=execute_eval_step, args=(path, eval_batch))
          path.exe_thread.start()
        if e_step == 0 and "end_eval_compile" not in timing:
          for p in paths:
            p.exe_thread.join()
          timing["end_eval_compile"] = time.time()
      # Get params of best models.
      qs = []
      eval_idx = (t_step+1) // task.num_train_batches_between_validations
      for path in paths:
        path.exe_thread.join()
        quality = np.mean(path.accs)
        del path.accs
        qs.append(f"{quality:.4f}")
        path.evals.append(quality)
        # Set quality in metrics for current score computation.
        path.metrics["quality"] = quality
        path_score = path.score()
        if path_score > path.best_score:
          path.best_params_local = jax.device_get(path.params_device)
          path.best_score = path_score
          path.best_quality = quality
          qs[-1] += "*"
      time_train = time.time() - PREV_LOOP_END
      avg_path_time = (time_train / eval_idx) / len(paths)
      print(("\t".join(qs) + f"\t< Eval {eval_idx}").expandtabs(8),
            f"tot:{time_train:.1f}s", f"avg/path:{avg_path_time:.1f}s")
      timing["time_eval"] = timing.get("time_eval", 0) + (time.time() - start_eval_round)
      del eval_batch
  del train_batch
  for path in paths:
    del path.params_device
    del path.fixed_params_device
    del path.opt_state_device
    del path.optimizer
    del path.exe_thread
  gc.collect()
  timing["end_train"] = time.time()
  time_init = timing["start_train"] - PREV_LOOP_END
  time_train_compile = timing["end_train_compile"] - timing["start_train"]
  time_eval_compile = timing["end_eval_compile"] - timing["start_eval"]
  time_eval = timing["time_eval"] - time_eval_compile
  time_train = timing["end_train"] - timing["end_train_compile"] - time_eval - time_eval_compile
  PREV_LOOP_END = timing["end_train"]
  for path in paths:
    path.metrics["time_init"] = time_init
    path.metrics["time_train_compile"] = time_train_compile
    path.metrics["time_eval_compile"] = time_eval_compile
    path.metrics["time_train"] = time_train
    path.metrics["time_eval"] = time_eval
    path.metrics["timestamp_end"] = PREV_LOOP_END
    path.metrics["num_params"] = get_num_params(path.get_all_params())
    path.metrics["num_trainable_params"] = get_num_params(path.get_trainable_params())
    path.metrics["quality"] = max(path.evals)
    path.metrics["evals"] = json.dumps([float(v) for v in path.evals])
    if path.best_params_local is not None:
      path.metrics["improved"] = True
      path.update_trainable(path.best_params_local)
      assert path.best_quality == path.metrics["quality"]
      assert path.best_score == path.score()
    else:
      path.metrics["improved"] = False
      # Sampled path will be dropped if not improved, so skip parameter update.
      assert path.best_quality is None
    del path.best_params_local
    del path.best_score
    del path.best_quality
    del path.evals
  pqs = []
  qs = []
  psc = []
  sc = []
  for path in paths:
    if path.task_name == path.parent.task_name:
      metric_suffix = "" if path.agent_id == path.parent.agent_id else "A"
      pqs.append(f"{path.parent.metrics['quality']:.4f}{metric_suffix}")
      psc.append(f"{path.parent.score():.4f}{metric_suffix}")
    else:
      pqs.append("NEW")
      psc.append("NEW")
    qs.append(f"{path.metrics['quality']:.4f}")
    sc.append(f"{path.score():.4f}")
    if path.metrics["improved"]:
      sc[-1] += "+"
  print(("\t".join(pqs) + "\t< Parent best quality").expandtabs(8))
  print(("\t".join(qs) + "\t< Path best quality").expandtabs(8))
  print(("\t".join(psc) + "\t< Parent score").expandtabs(8))
  print(("\t".join(sc) + "\t< Path score").expandtabs(8))
  print("time\tINIT\tCOMPtrn\tCOMPevl\tTRN\tEVAL".expandtabs(8))
  print(f"(s)\t{time_init:.1f}\t{time_train_compile:.1f}\t{time_eval_compile:.1f}\t{time_train:.1f}\t{time_eval:.1f}".expandtabs(8))
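
At every validation round, `train_loop` keeps a host copy of a path's parameters only if the path's score beats the best seen so far. The baseline is the parent's score when the parent belongs to the same agent, and `-inf` otherwise. A path whose parameters were never stored is marked `improved=False` and will be dropped. Below is a sketch of that selection on plain lists; `select_best` is hypothetical.

```python
def select_best(eval_scores, eval_params, parent_score):
  """Returns (best_params or None, improved) over successive validation rounds."""
  best_score = parent_score
  best_params = None
  for score, params in zip(eval_scores, eval_params):
    if score > best_score:  # Strictly greater: ties keep the earlier candidate.
      best_score = score
      best_params = params
  return best_params, best_params is not None

print(select_best([0.5, 0.7, 0.6], ["a", "b", "c"], 0.6))  # ('b', True)
```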

In [None]:
def has_test_quality(path):
  return ("test_quality" in path.metrics and not math.isnan(path.metrics["test_quality"]))

# Run final eval on test set.
def run_test_eval(path, task, test_immutability=False):
  if not test_immutability and has_test_quality(path):
    return  # Skip if test_quality already computed and no immutability test required.
  eval_st = time.time()
  ds_test = task.get_ds("test", path.hparams)
  params = path.get_all_params()
  # Running on the same device allows reusing the function compiled for validation.
  if not hasattr(path, "device"):
    path.device = random.choice(jax.local_devices())  # Otherwise pick random device.
  params_device = jax.device_put(path.model_factory.get_comps2model_fn()(params), path.device)
  acc_sum = []
  tot_num_samples = 0
  # Warning: if repeat() is called on this dataset, then this loop never ends.
  for batch in ds_test:
    acc_avg = eval_step(params_device, batch["image"], batch["label"], path.model)
    batch_size = batch["label"].shape[0]
    # Weight the batch mean by the batch size, because the last batch can be
    # smaller (it is kept so that the whole test set is evaluated exactly).
    acc_sum.append(acc_avg * batch_size)
    tot_num_samples += batch_size
  del params_device
  acc_avg = np.sum(acc_sum) / tot_num_samples
  if has_test_quality(path):  # Test that new accuracy is close to recorded one.
    assert test_immutability
    print(f"Immutability test of path {path.id}: {path.metrics['test_quality']}~={acc_avg}")
    assert np.isclose(path.metrics["test_quality"], acc_avg, rtol=IMMUTABILITY_RELATIVE_TOLLERANCE)
  path.metrics["test_quality"] = acc_avg
  print(f"TEST QUALITY: {acc_avg}\nTEST TIME: {time.time()-eval_st:.2f}s")
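
`run_test_eval` multiplies each batch's mean accuracy by its batch size before averaging. Without this weighting, a smaller final batch would count as much as a full one. A small numeric illustration with a hypothetical `dataset_accuracy` helper:

```python
def dataset_accuracy(batch_accuracies, batch_sizes):
  """Exact dataset accuracy from per-batch mean accuracies."""
  total_correct = sum(a * n for a, n in zip(batch_accuracies, batch_sizes))
  return total_correct / sum(batch_sizes)

# Two full batches of 4 examples and a final batch of 1:
accs, sizes = [1.0, 0.5, 0.0], [4, 4, 1]
print(dataset_accuracy(accs, sizes))  # 0.6666666666666666
print(sum(accs) / len(accs))  # 0.5 (the unweighted mean over-weights the last batch)
```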

# Main

In [None]:
AGENT = get_agent_class(AGENT_CLASS)(system_state_dir=SYSTEM_STATE_DIR, task_name=TASK_NAME, num_paths_sampled_max=NUM_PATHS_SAMPLED_MAX)
AGENT.run()