# License
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at:

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Instructions

This notebook reproduces the experiments reported in the publication titled:

"[*Multipath Agents for Modular Multitask ML Systems*](https://arxiv.org/abs/2302.02721)" (2023)

---
To start an experiment:
---
1. Choose the agent type by setting the `AGENT` variable in the configuration below.
Select `Vit` for the single-path agent or `MultiVit` for the multipath agent.
These two agents use a ViT-Large root model.
Select the variants suffixed with `T3` (`VitT3`, `MultiVitT3`) to use a ViT-Tiny root model capped to 3 layers, or those suffixed with `B` (`VitB`, `MultiVitB`) to use a ViT-Base root model.

1. Set `TASK_NAME` to the string-id of the task assigned to the instantiation of the selected agent.
Refer to `TFDS_IMAGE_CLASSIFCATON_DATASETS` and `VTAB_TASKS` (below) for the lists of task ids that have been tested with the current code.
These lists contain task ids from the [Tensorflow Datasets Catalog](https://www.tensorflow.org/datasets/catalog/overview).
Note that some tasks require a manual download; refer to the corresponding catalog page for instructions. **WARNING**: The system state must contain at least one root model before an agent can be trained on any task. To generate the root model, set `TASK_NAME` to `"root_model/checkpoint"` to load a pretrained root model, or to `"root_model/random_init"` to generate a randomly initialized one.

1. Set `NUM_CYCLES_MAX` to the desired number of evolutionary cycles. Additional configuration parameters can be modified in the agent definitions below. The default configurations match the settings described in the publication.

1. Set `SYSTEM_STATE_RELATIVE_DIR` to the relative path where the system state is read from and written to.

1. By default, the system state is stored in a temporary folder (under `/tmp/`) on the Virtual Machine (VM). This temporary folder is deleted when the VM is stopped or restarted.
It is possible to store the system state folder on your Google Drive by activating the
`SAVE_SYSTEM_STATE_ON_GOOGLE_DRIVE` option. In this case, you will be prompted for access approval, and the system state folder will be saved in a folder named `"munet_experiments"` under your Google Drive root folder. It is also possible to store the system state in a Google Drive folder shared by multiple users: add a shortcut to the shared folder in your Google Drive, then set `GDRIVE_ROOT_DIR` (below) to the path of that shortcut.

1. To start the experiment, select "Connect to a hosted runtime" from the dropdown menu on the top right, and then select "Run all" from the "Runtime" menu. A free hosted CPU runtime is always available. Free access to GPU and TPU accelerators is occasionally provided, depending on availability.

---
During the experiment execution:
---

1. The print output is displayed after the last cell of this Colab.

1. The system state folder is populated with a subfolder for each agent.
The name of each agent folder is the `AGENT` type string followed by the `TASK_NAME`, joined by `~` (every `/` in the task name is also replaced by `~`).
Each agent directory is populated with incremental state subfolders containing the sharded state of the architectures and parameters generated by the agent during the corresponding evolutionary cycle.

1. Any number of agents can be started asynchronously and run in parallel.
It is possible to resume an interrupted agent training by restarting the execution with the same configuration.
It is possible to continue a completed training by increasing `NUM_CYCLES_MAX`.

1. To achieve a multi-agent execution, run multiple Colabs in parallel, each set to the same configuration but a different `TASK_NAME`.

1. To achieve a heterogeneous hardware execution, parallel Colab notebooks can be connected to runtimes of different types.
It is possible to switch between CPU, GPU and TPU by selecting `Change runtime type` in the `Resources` tab of this Colab notebook.
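
The agent folder naming described above can be sketched as follows. This mirrors the `format_agent_id` function defined in the Agents section of this notebook; `agent_dir_name` is a hypothetical name used only for this illustration.

```python
def agent_dir_name(agent_type: str, task_name: str) -> str:
    # Join "<AGENT>/<TASK_NAME>", then replace every "/" with "~" so the
    # result is a single folder name (same logic as format_agent_id).
    agent_id = f"{agent_type}/{task_name}"
    assert "~" not in agent_id, f"Invalid agent id: {agent_id}"
    return agent_id.replace("/", "~")


print(agent_dir_name("MultiVit", "cifar100/1k"))         # MultiVit~cifar100~1k
print(agent_dir_name("VitT3", "root_model/checkpoint"))  # VitT3~root_model~checkpoint
```

Using `~` as the separator keeps task ids with slashes (e.g. `cmaterdb/bangla`) from creating nested folders, which is why `~` itself is rejected in agent ids.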

In [None]:
# @title Agent parameters
AGENT = "VitT3" # @param ["VitT3", "VitB", "Vit", "MultiVitT3", "MultiVitB", "MultiVit"] { type: "string", isTemplate: true }
# Set TASK_NAME to "root_model/checkpoint" or "root_model/random_init" to initialize the population.
TASK_NAME = "root_model/checkpoint"  # @param { type: "string", isTemplate: true }
NUM_CYCLES_MAX = 1 # @param { type: "integer", isTemplate: true }
SYSTEM_STATE_RELATIVE_DIR = "munet_system_state/"  # @param { type: "string", isTemplate: true }

In [None]:
# Save the system state on Google Drive instead of in a temporary VM folder.
SAVE_SYSTEM_STATE_ON_GOOGLE_DRIVE = False  # @param { type: "boolean", isTemplate: true }
if SAVE_SYSTEM_STATE_ON_GOOGLE_DRIVE:
  from google.colab import drive
  drive.mount('/content/gdrive')
  GDRIVE_ROOT_DIR = "/content/gdrive/My Drive/munet_experiments/"
  SYSTEM_STATE_DIR = GDRIVE_ROOT_DIR + SYSTEM_STATE_RELATIVE_DIR
  print("Saving system state in Google Drive.")
else:
  SYSTEM_STATE_DIR = "/tmp/" + SYSTEM_STATE_RELATIVE_DIR
  print("WARNING: Saving system state in VM, state will be lost after reboot!")

In [None]:
# Test the immutability of published paths at the beginning of each cycle.
# The tolerance may need to be increased if the system runs in any different context, e.g. hardware, input preprocessing, library or dataset versions.
TEST_IMMUTABILITY = False
IMMUTABILITY_RELATIVE_TOLLERANCE = 0.001  # 0.1%
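
The exact comparison performed by `run_test_eval` is not shown in this section. As an assumption, a relative tolerance of 0.001 (0.1%) is usually interpreted as in the hypothetical helper below, which compares a re-measured score against the published one.

```python
def within_relative_tolerance(reference: float, observed: float, rel_tol: float) -> bool:
    # Hypothetical check (not the notebook's code): accept the re-measured
    # score if it deviates from the reference by at most rel_tol * |reference|.
    return abs(observed - reference) <= rel_tol * abs(reference)


print(within_relative_tolerance(0.9000, 0.9005, 0.001))  # True  (0.0005 <= 0.0009)
print(within_relative_tolerance(0.9000, 0.9012, 0.001))  # False (0.0012 > 0.0009)
```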

# Imports

In [None]:
!pip install --upgrade -q pip jax jaxlib
!pip install --upgrade -q git+https://github.com/google/flax.git
!pip install -q ml_collections
!pip install -q tensorflow_addons
![ -d task_adaptation ] || git clone --depth=1 https://github.com/google-research/task_adaptation
![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer

import sys
if './task_adaptation' not in sys.path:
  sys.path.append('./task_adaptation')
if './vision_transformer' not in sys.path:
  sys.path.append('./vision_transformer')

import jax.tools.colab_tpu
try:
  jax.tools.colab_tpu.setup_tpu()
except Exception:
  pass  # Not a TPU runtime.

In [None]:
import copy
import flax
import flax.linen as nn
import gc
import inspect
import jax
import jax.numpy as jnp
import json
import math
import numpy as np
import optax
import os
import pandas as pd
import random
import re
import tensorflow as tf
import tensorflow_datasets as tfds
import time
from collections import defaultdict
from functools import partial
from flax.training import checkpoints as flax_checkpoints
from ml_collections import ConfigDict, FrozenConfigDict
from tensorflow.io import gfile
from threading import Thread, Lock
from typing import Any, Type

tf.compat.v1.enable_eager_execution()

In [None]:
# ViT imports.
from vision_transformer.vit_jax import input_pipeline
from vision_transformer.vit_jax import checkpoint
from vision_transformer.vit_jax.configs import models as models_config  # Model configurations.
from vision_transformer.vit_jax import models_vit as models  # Actual model code.
# VTAB imports.
import task_adaptation.registry as task_adapt_registry
import task_adaptation.data.caltech
import task_adaptation.data.cifar
import task_adaptation.data.dtd
import task_adaptation.data.oxford_flowers102
import task_adaptation.data.oxford_iiit_pet
import task_adaptation.data.sun397
import task_adaptation.data.svhn
import task_adaptation.data.patch_camelyon
import task_adaptation.data.eurosat
import task_adaptation.data.resisc45
import task_adaptation.data.diabetic_retinopathy
import task_adaptation.data.clevr
import task_adaptation.data.dmlab
import task_adaptation.data.dsprites
import task_adaptation.data.kitti
import task_adaptation.data.smallnorb

# Utils

In [None]:
# Ref. TFDS catalog: https://www.tensorflow.org/datasets/catalog/overview
TFDS_IMAGE_CLASSIFCATON_DATASETS = set([
    "beans",
    "binary_alpha_digits",
    "caltech_birds2010",
    "caltech_birds2011",
    "cars196",
    "cassava",
    "cats_vs_dogs",
    "cifar10",
    "cifar100",
    "citrus_leaves",
    "cmaterdb/bangla",
    "cmaterdb/devanagari",
    "cmaterdb/telugu",
    "colorectal_histology",
    "controlled_noisy_web_labels/mini_imagenet_red",
    "controlled_noisy_web_labels/mini_imagenet_blue",
    "curated_breast_imaging_ddsm/patches",
    "cycle_gan/apple2orange",
    "cycle_gan/summer2winter_yosemite",
    "cycle_gan/horse2zebra",
    "cycle_gan/monet2photo",
    "cycle_gan/cezanne2photo",
    "cycle_gan/ukiyoe2photo",
    "cycle_gan/vangogh2photo",
    "cycle_gan/maps",
    "cycle_gan/cityscapes",
    "cycle_gan/facades",
    "cycle_gan/iphone2dslr_flower",
    "deep_weeds",
    "domainnet/real",
    "domainnet/painting",
    "domainnet/clipart",
    "domainnet/quickdraw",
    "domainnet/infograph",
    "domainnet/sketch",
    "emnist/balanced",
    "emnist/byclass",
    "emnist/bymerge",
    "emnist/digits",
    "emnist/letters",
    "emnist/mnist",
    "fashion_mnist",
    "food101",
    "horses_or_humans",
    "i_naturalist2017",
    "i_naturalist2018",
    "imagenet2012",
    "imagenet_a",
    "imagenet_lt",
    "imagenet_r",
    "imagenet_sketch",
    "imagenette",
    "imagewang",
    "kmnist",
    "malaria",
    "mnist",
    "omniglot",
    "pet_finder",
    "places365_small",
    "plant_village",
    "plantae_k",
    "quickdraw_bitmap",
    "rock_paper_scissors",
    "siscore/rotation",
    "siscore/size",
    "siscore/location",
    "stanford_dogs",
    "stanford_online_products",
    "stl10",
    "tf_flowers",
    "uc_merced",
    "visual_domain_decathlon/aircraft",
    "visual_domain_decathlon/cifar100",
    "visual_domain_decathlon/daimlerpedcls",
    "visual_domain_decathlon/dtd",
    "visual_domain_decathlon/gtsrb",
    "visual_domain_decathlon/imagenet12",
    "visual_domain_decathlon/omniglot",
    "visual_domain_decathlon/svhn",
    "visual_domain_decathlon/ucf101",
    "visual_domain_decathlon/vgg-flowers",
    ])

In [None]:
# Append suffix "/1k" to get the 1k version of each task.
VTAB_TASKS = [
    "caltech101",
    # cifar100 and cifar10 are already in TFDS_IMAGE_CLASSIFCATON_DATASETS with a slightly different validation split but the same test set, so only their 1k versions are added here.
    "cifar100/1k",
    "cifar10/1k",
    "dtd",
    "oxford_flowers102",
    "oxford_iiit_pet",
    "sun397",
    "svhn_cropped",
    "patch_camelyon",
    "eurosat",
    "resisc45",
    "diabetic_retinopathy_detection/btgraham-300",
    "clevr/count_cylinders",  # Not in results table.
    "clevr/count_all",  # Clevr-Count
    "clevr/closest_object_distance",  # Clevr-Dist
    "dmlab",
    "dsprites/label_x_position",  # dSpr-Loc
    "dsprites/label_orientation",  # dSpr-Ori
    "kitti/closest_object_distance",  # Not in results table.
    "kitti/count_vehicles",  # Not in results table.
    "kitti/closest_vehicle_distance",  # Kitti-dist
    "smallnorb/label_category",  # Not in results table.
    "smallnorb/label_lighting",  # Not in results table.
    "smallnorb/label_azimuth",  # Azim
    "smallnorb/label_elevation",  # Elev
    ]
for tn in VTAB_TASKS:
  assert tn not in TFDS_IMAGE_CLASSIFCATON_DATASETS, tn
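
Following the comment on `VTAB_TASKS`, the 1k version of a VTAB task is selected by appending `/1k`. The hypothetical helper below (not part of the notebook, which dispatches tasks elsewhere) sketches how a `TASK_NAME` could be checked against the tested lists, using small subsets of them.

```python
ROOT_MODEL_TASKS = ("root_model/checkpoint", "root_model/random_init")


def is_tested_task(task_name: str, tfds_tasks: set, vtab_tasks: list) -> bool:
    # Hypothetical helper: checks TASK_NAME against the tested task lists.
    if task_name.startswith("root_model/"):
        return task_name in ROOT_MODEL_TASKS
    if task_name in tfds_tasks or task_name in vtab_tasks:
        return True
    # VTAB tasks also come in a 1k version, selected with the "/1k" suffix.
    suffix = "/1k"
    return task_name.endswith(suffix) and task_name[:-len(suffix)] in vtab_tasks


tfds_tasks = {"beans", "cifar10", "cifar100"}
vtab_tasks = ["caltech101", "cifar100/1k", "cifar10/1k", "dtd"]
print(is_tested_task("dtd/1k", tfds_tasks, vtab_tasks))    # True (1k version)
print(is_tested_task("cifar100", tfds_tasks, vtab_tasks))  # True (TFDS list)
print(is_tested_task("dtd/5k", tfds_tasks, vtab_tasks))    # False
```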

In [None]:
def compute_flops_hlo(flax_module, *a, **kw):
  # Compute FLOPs on CPU for cross-platform consistency.
  analysis = jax.jit(flax_module, backend='cpu').lower(*a, **kw).cost_analysis()
  return analysis["flops"]

In [None]:
class ObjectCache():
  def __init__(self, factory_fn):
    self.factory_fn = factory_fn
    self.factory_fn_signature = inspect.signature(factory_fn)
    self.cache = {}

  def __call__(self, *args, **kwargs):
    assert not args, "No positional arguments allowed."
    kw_params = {}
    fn_name = self.factory_fn.__name__
    fn_params = self.factory_fn_signature.parameters
    for k_param, v_param in fn_params.items():
      if k_param in kwargs:
        kw_params[k_param] = kwargs[k_param]
      elif v_param.default is not inspect.Parameter.empty:
        # Fall back to the declared default value.
        kw_params[k_param] = v_param.default
      else:
        assert False, (
            f"Missing value for argument {k_param} for function {fn_name}")

      if v_param.annotation is not inspect.Parameter.empty:
        # Coerce the value to the annotated type (e.g. int, str) so that
        # equivalent calls produce the same cache key.
        assert isinstance(v_param.annotation, type), v_param.annotation
        kw_params[k_param] = v_param.annotation(kw_params[k_param])

    key = json.dumps(kw_params, sort_keys=True)
    if key not in self.cache:
      self.cache[key] = self.factory_fn(**kw_params)
      print(f"Added to cache: {fn_name}({key})  [cache size {len(self.cache)}]")
    return self.cache[key]
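
`ObjectCache` keys each call by the JSON dump of its fully resolved keyword arguments, so calls that differ only by omitted defaults share one cached object. The sketch below shows the same idea using the standard `inspect.Signature.bind` and `BoundArguments.apply_defaults`; unlike `ObjectCache`, it does not coerce values to their annotated types, and `cache_key` and `make_component` are hypothetical names.

```python
import inspect
import json


def cache_key(fn, **kwargs) -> str:
    # Bind keyword arguments to fn's signature (raises TypeError if a required
    # argument is missing), then fill in declared defaults before hashing.
    bound = inspect.signature(fn).bind(**kwargs)
    bound.apply_defaults()
    return json.dumps(bound.arguments, sort_keys=True)


calls = []


def make_component(agent_id: str, ds_image_size: int = 224):
    # Hypothetical expensive factory; records each real invocation.
    calls.append((agent_id, ds_image_size))
    return {"agent_id": agent_id, "ds_image_size": ds_image_size}


cache = {}
for kwargs in [dict(agent_id="a"), dict(agent_id="a", ds_image_size=224)]:
    key = cache_key(make_component, **kwargs)
    if key not in cache:
        cache[key] = make_component(**kwargs)

print(len(calls))  # 1: both calls resolve to the same key
```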

# Models

In [None]:
# Sample inputs
def get_sample_images(image_size, batch_size):
  return np.zeros((batch_size, image_size, image_size, 3))

In [None]:
def get_num_params(params):
  return sum(jax.tree_util.tree_flatten(
      jax.tree_util.tree_map(lambda p: np.prod(p.shape), params)
      )[0])

In [None]:
def get_optimizer(
    opt_lr: float,
    opt_lr_schedule: str,
    opt_lr_warmup_ratio: float,
    opt_momentum: float,
    opt_nesterov: bool,
    num_train_batches_between_validations: int,
    num_validations_per_path_training: int,
    ):
  min_lr = opt_lr / 1000.0
  if opt_lr_schedule == "constant":
    # Halve the lr so that its average matches the other schedules, whose warmup/decay shapes average about half of opt_lr.
    learning_rate = 0.5 * opt_lr
  elif opt_lr_schedule == "linear":
    train_steps = int(num_train_batches_between_validations * num_validations_per_path_training)
    warmup_steps = int(opt_lr_warmup_ratio * train_steps)
    schedules = [
        optax.linear_schedule(
            init_value=min_lr,
            end_value=opt_lr,
            transition_steps=warmup_steps),
        optax.linear_schedule(
            init_value=opt_lr,
            end_value=min_lr,
            transition_steps=train_steps-warmup_steps)]
    learning_rate = optax.join_schedules(schedules, [warmup_steps])
  elif opt_lr_schedule == "cosine":
    train_steps = int(num_train_batches_between_validations
                      * num_validations_per_path_training)
    learning_rate = optax.warmup_cosine_decay_schedule(
        init_value=min_lr,
        peak_value=opt_lr,
        warmup_steps=int(opt_lr_warmup_ratio * train_steps),
        decay_steps=train_steps)
  elif opt_lr_schedule == "restarts":
    train_steps = num_train_batches_between_validations
    repeats = num_validations_per_path_training
    kwargs = dict(
        init_value=min_lr,
        peak_value=opt_lr,
        warmup_steps=int(opt_lr_warmup_ratio * train_steps),
        decay_steps=train_steps,
    )
    kwargs = [kwargs] * repeats
    learning_rate = optax.sgdr_schedule(kwargs)
  else:
    assert False, f"Invalid lr schedule: {opt_lr_schedule}"

  return optax.chain(
      optax.clip_by_global_norm(1.0),
      optax.sgd(
          learning_rate=learning_rate,
          momentum=opt_momentum,
          nesterov=opt_nesterov,
          accumulator_dtype=jnp.bfloat16))
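
The `"constant"` schedule uses `0.5 * opt_lr` so that its average learning rate matches the decaying schedules. A quick NumPy check for the `"linear"` case: the sketch below rebuilds the two joined linear segments with `np.interp` (a stand-in for the `optax` schedules, not the notebook's code) and a hypothetical step count, and compares their mean with `0.5 * opt_lr`.

```python
import numpy as np


def linear_warmup_decay(step, total_steps, warmup_steps, peak_lr, min_lr):
    # Piecewise-linear ramp min_lr -> peak_lr during warmup, then back to
    # min_lr, like the two joined linear schedules in get_optimizer.
    return np.interp(step, [0, warmup_steps, total_steps], [min_lr, peak_lr, min_lr])


peak_lr = 0.02                           # opt_lr default in get_common_config_vit
min_lr = peak_lr / 1000.0                # as computed in get_optimizer
total_steps = 10_000                     # hypothetical number of training steps
warmup_steps = int(0.02 * total_steps)   # opt_lr_warmup_ratio default
lrs = linear_warmup_decay(np.arange(total_steps), total_steps, warmup_steps, peak_lr, min_lr)

print(f"mean linear lr: {lrs.mean():.5f}")    # about 0.01001
print(f"constant lr:    {0.5 * peak_lr:.5f}")  # 0.01000
```

The mean of a piecewise-linear ramp up and back down is the midpoint between its endpoints, `(min_lr + opt_lr) / 2`, which is almost exactly `0.5 * opt_lr` because `min_lr = opt_lr / 1000`.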

In [None]:
def merge_params(a, b):
  params = a.copy(b)
  assert len(params) == len(a) + len(b)
  return params

## Vit Model

In [None]:
class VitModelFactory():
  @staticmethod
  def get_model(hparams, config):
    return get_vit_model(hparams, config)

  @staticmethod
  def get_init_comps(hparams, config):
    return get_vit_init_comps(hparams, config)

  @staticmethod
  def get_comps2model_fn():
    return vit_comps2model

  @staticmethod
  def get_sample_input(hparams):
    return get_sample_images(image_size=hparams["ds_image_size"], batch_size=1)

In [None]:
def get_vit_filename(query):
  df = checkpoint.get_augreg_df()
  res = df.query(query).filename.unique()
  assert len(res) == 1
  return res[0]

In [None]:
VIT_CONFIG_CACHE = {}

def get_vit_config(query):
  global VIT_CONFIG_CACHE
  if query not in VIT_CONFIG_CACHE:
    filename = get_vit_filename(query)
    config = models_config.AUGREG_CONFIGS[filename.split("-")[0]].copy_and_resolve_references()
    config.unlock()
    # Disable dropout.
    config.transformer.dropout_rate = 0.0
    config.transformer.attention_dropout_rate = 0.0
    config.lock()
    VIT_CONFIG_CACHE[query] = config
  return VIT_CONFIG_CACHE[query].copy_and_resolve_references()

def get_set_vit_config(hparams, config):
  path_config = get_vit_config(config.vit_checkpoint_query)
  path_config.transformer.num_layers = int(hparams["num_layers"])
  path_config.unlock()
  path_config.num_classes = int(hparams["num_classes"])
  if "classifier" in hparams:
    path_config.classifier = hparams["classifier"]
  path_config.lock()
  path_config = FrozenConfigDict(path_config)
  return path_config

def get_max_num_layers(query):
  config = get_vit_config(query)
  return config.transformer.num_layers

In [None]:
# Get params from ViT checkpoints.
def get_vit_checkpoint_comps(image_size, query):
  filename = get_vit_filename(query)
  config = get_vit_config(query)
  model = models.VisionTransformer(**config, num_classes=1)  # num_classes unused.
  init_params = copy.deepcopy(jax.device_get(
      model.init(jax.random.PRNGKey(random.randrange(int(1e10))),
                 VitModelFactory.get_sample_input({"ds_image_size": image_size}),
                 train=False  # Disables dropout, no effect on params.
                 )["params"]))
  params = checkpoint.load_pretrained(
    pretrained_path=f"gs://vit_models/augreg/{filename}.npz",
    init_params=init_params,
    model_config=config)
  return vit_model2comps(params)

def get_vit_checkpoint_reshaped_posembed_component(
    agent_id: str, ds_image_size: int, query: str):
  params = get_vit_checkpoint_comps(ds_image_size, query)["posembed_input"]
  return Component(name="posembed_input",
                   agent_id=agent_id,
                   params=params,
                   train_locks=[])

In [None]:
# Get ViT model and init_params.
def get_vit_model(hparams, config):
  vit_config = get_set_vit_config(hparams, config)
  return models.VisionTransformer(**vit_config)

def get_vit_init_comps(hparams, config):
  model = get_vit_model(hparams, config)
  init_params = copy.deepcopy(jax.device_get(model.init(
      jax.random.PRNGKey(random.randrange(int(1e10))),
      VitModelFactory.get_sample_input(hparams),
      train=False  # Disables dropout, no effect on params.
      )["params"]))
  return vit_model2comps(init_params)

In [None]:
# ViT parameters mapping to components.
TRANSFORMER_KEYS = set(
    ["encoder_norm", "posembed_input"] +
    [f"encoderblock_{k}" for k in range(30)])

def vit_model2comps(params):
  new_params = {}
  for k in params.keys():
    if k == "Transformer":
      t_params = params[k]
      for t_k in t_params.keys():
        new_params[t_k] = t_params[t_k]
    else:
      new_params[k] = params[k]
  return flax.core.freeze(new_params)

def vit_comps2model(params):
  new_params = params.unfreeze()
  new_params["Transformer"] = {}
  for k in list(new_params.keys()):
    if k in TRANSFORMER_KEYS:
      new_params["Transformer"][k] = new_params.pop(k)
  assert len(new_params["Transformer"]) != 0
  return flax.core.freeze(new_params)
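
`vit_model2comps` hoists the entries under `"Transformer"` to the top level, so each encoder block becomes an independently shareable component, and `vit_comps2model` moves them back. The sketch below reproduces this round trip on plain dicts with placeholder string values; the notebook operates on flax `FrozenDict`s, so this is an illustration under that simplification.

```python
TRANSFORMER_KEYS = ({"encoder_norm", "posembed_input"}
                    | {f"encoderblock_{k}" for k in range(30)})


def model2comps(params: dict) -> dict:
    # Hoist every entry under "Transformer" to the top level.
    flat = {k: v for k, v in params.items() if k != "Transformer"}
    flat.update(params.get("Transformer", {}))
    return flat


def comps2model(comps: dict) -> dict:
    # Move the transformer-owned keys back under a "Transformer" sub-dict.
    nested = {k: v for k, v in comps.items() if k not in TRANSFORMER_KEYS}
    nested["Transformer"] = {k: v for k, v in comps.items() if k in TRANSFORMER_KEYS}
    assert nested["Transformer"], "no transformer components found"
    return nested


model_params = {
    "embedding": "patch-embed-params",
    "head": "head-params",
    "Transformer": {"encoder_norm": "ln", "posembed_input": "pos", "encoderblock_0": "blk0"},
}
comps = model2comps(model_params)
print(sorted(comps))  # all five components at the top level
assert comps2model(comps) == model_params
```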

## MultiVit Model

In [None]:
class MultiVitModelFactory():
  @staticmethod
  def get_model(hparams, config):
    return get_multivit_model(hparams, config)

  @staticmethod
  def get_init_comps(hparams, config):
    return get_multivit_init_comps(hparams, config)

  @staticmethod
  def get_comps2model_fn():
    return multivit_comps2model

  @staticmethod
  def get_sample_input(hparams):
    return {str(k): get_sample_images(image_size=k, batch_size=1) for k in hparams["ds_image_size"]}

In [None]:
def get_multivit_init_comps(hparams, config):
  model = get_multivit_model(hparams, config)
  init_params = copy.deepcopy(jax.device_get(
      model.init(
          jax.random.PRNGKey(random.randrange(int(1e10))),
          MultiVitModelFactory.get_sample_input(hparams),
          train=False  # Disables dropout, no effect on params.
          )["params"]))
  return multivit_model2comps(init_params)

In [None]:
def multivit_comps2model(params):
  params = params.unfreeze()
  for k in params:
    if k.startswith("path_"):
      params[k] = vit_comps2model(flax.core.freeze(params[k]))
  return flax.core.freeze(params)

def multivit_model2comps(params):
  # Mapping of paths component skipped since those are never used from rand init.
  return params

In [None]:
class MultipathRouter(nn.Module):
  init_main_path_weight: float
  num_paths: int
  lr_mult: float

  @nn.compact
  def __call__(self, x):
    assert self.num_paths > 0
    assert self.lr_mult >= 0 and self.lr_mult <= 1
    init_bias = np.log((1/(1/self.init_main_path_weight -1))*(self.num_paths-1))
    x = nn.LayerNorm()(x)
    x = nn.Dense(self.num_paths,
                 kernel_init=nn.initializers.zeros,
                 bias_init=nn.initializers.constant(np.asarray([init_bias]+[0]*(self.num_paths-1)))
                 )(x)
    x = nn.softmax(x)
    x = self.lr_mult * x + (1-self.lr_mult) * jax.lax.stop_gradient(x)
    return x

class Connector(nn.Module):
  out_dim: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.out_dim,
                 kernel_init=nn.initializers.zeros,
                 bias_init=nn.initializers.zeros,
                 )(x)
    return x

class MultiVitModel(nn.Module):
  config: Any
  path_module: Type[nn.Module] = models.VisionTransformer
  main_path_name: str = "path_0"
  @nn.compact
  def __call__(self, inputs, *, train):
    logits_0 = self.path_module(
        name=self.main_path_name,
        **self.config.paths_configs[self.main_path_name])(
            inputs[self.config.paths_image_size[self.main_path_name]], train=train)
    out_dim = logits_0.shape[-1]
    weights = MultipathRouter(
        name="multipath_router",
        num_paths=len(self.config.paths_configs),
        **self.config.router)(
            logits_0)
    all_logits = [logits_0]
    for path_name in self.config.paths_configs:
      if path_name == self.main_path_name:
        continue
      representation = self.path_module(
          name=path_name,
          **self.config.paths_configs[path_name])(
              inputs[self.config.paths_image_size[path_name]], train=train)
      path_logits = Connector(
          name=f"head_adapter_{path_name}", out_dim=out_dim)(representation)
      all_logits.append(path_logits)
    stacked = jnp.stack(all_logits, axis=-1)
    # Forward value: router-weighted mix of the path logits. Backward: the paths
    # get the gradient of the plain sum, and the router gets it via `weights`.
    logits_comb = jnp.einsum("BLp,Bp->BL", jax.lax.stop_gradient(stacked), weights)
    logits_sum = jnp.einsum("BLp,Bp->BL", stacked, jnp.ones_like(weights))
    logits_out = logits_comb - jax.lax.stop_gradient(logits_sum) + logits_sum
    return logits_out
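
The last lines of `MultiVitModel.__call__` combine the path logits with a stop-gradient trick. Reading the code per batch element and logit, and writing $\ell_p$ for the logits of path $p$, $w_p$ for the router weights and $\mathrm{sg}$ for `jax.lax.stop_gradient`, the output is:

$$
\ell_{\text{out}} = \sum_p w_p\,\mathrm{sg}(\ell_p) \;-\; \mathrm{sg}\Big(\sum_p \ell_p\Big) \;+\; \sum_p \ell_p
$$

The last two terms cancel in value but not in gradient, so:

$$
\ell_{\text{out}} \overset{\text{fwd}}{=} \sum_p w_p\,\ell_p,
\qquad
\frac{\partial \ell_{\text{out}}}{\partial \ell_p} = 1,
\qquad
\frac{\partial \ell_{\text{out}}}{\partial w_p} = \ell_p .
$$

In other words, the prediction is the router-weighted mixture, each path receives the unscaled loss gradient as if the logits were summed, and the router is trained only through $w_p$, whose gradient is further scaled by `lr_mult` inside `MultipathRouter`.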

In [None]:
def get_multivit_model(hparams, config):
  model_config = ConfigDict()
  model_config.paths_configs = {
      k: get_set_vit_config(hparams["paths"][k]["hparams"], config) for k in hparams["paths"]}
  model_config.paths_image_size = {
      k: str(hparams["paths"][k]["hparams"]["ds_image_size"]) for k in hparams["paths"]}
  model_config.router = {
      "init_main_path_weight": float(hparams["router_init_main_path_weight"]),
      "lr_mult": float(hparams["router_lr_mult"]),
  }
  return MultiVitModel(config=FrozenConfigDict(model_config))
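
`router_init_main_path_weight` sets the initial routing weight of the main path. Since `MultipathRouter` zero-initializes the Dense kernel, its initial logits equal the bias vector $[b, 0, \dots, 0]$, so the main weight is $e^b / (e^b + N - 1)$; solving for a target weight $w$ gives $b = \log\big(\tfrac{w}{1-w}(N-1)\big)$, which is the `init_bias` expression in the code. The NumPy check below uses a hypothetical $w = 0.8$ and $N = 4$ paths; the formula is finite only for $0 < w < 1$ and $N \ge 2$.

```python
import numpy as np


def router_init_bias(init_main_path_weight: float, num_paths: int) -> float:
    # Same expression as init_bias in MultipathRouter.__call__.
    w = init_main_path_weight
    return np.log((1 / (1 / w - 1)) * (num_paths - 1))


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - np.max(z))
    return e / e.sum()


# With a zero kernel the router logits are just the bias [b, 0, ..., 0].
w, num_paths = 0.8, 4
b = router_init_bias(w, num_paths)  # log(0.8 / 0.2 * 3) = log(12)
weights = softmax(np.array([b] + [0.0] * (num_paths - 1)))
print(weights)  # main path 0.8, each other path 0.2 / 3
```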

# Agents

In [None]:
def format_agent_id(class_name, task_name):
  agent_id = f"{class_name}/{task_name}"
  assert "~" not in agent_id, f"Invalid agent id: {agent_id}"
  return agent_id.replace("/", "~")

def get_agent_class(agent_id):
  return globals()[agent_id.split("~")[0]]

In [None]:
def incremental_mutation(value, values_list):
  assert value in values_list, f"{value} not in {values_list}"
  idx = values_list.index(value)
  idx += 1 if np.random.uniform() < 0.5 else -1
  idx = max(0, min(len(values_list)-1, idx))
  return values_list[idx]
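
`incremental_mutation` moves a hyperparameter one step up or down its ordered list with equal probability and clamps at the ends. A side effect of the clamp is that a value at either end of its list stays unchanged half of the time. The demo below copies the function verbatim so the block runs on its own, and samples it on the `opt_lr_warmup_ratio` range from the ViT configuration.

```python
import numpy as np


def incremental_mutation(value, values_list):
    # Copy of the notebook function: step one index up or down, then clamp.
    assert value in values_list, f"{value} not in {values_list}"
    idx = values_list.index(value)
    idx += 1 if np.random.uniform() < 0.5 else -1
    idx = max(0, min(len(values_list) - 1, idx))
    return values_list[idx]


np.random.seed(0)
warmup_ratios = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2, 0.3]
from_middle = {incremental_mutation(0.02, warmup_ratios) for _ in range(1000)}
from_edge = [incremental_mutation(0.0, warmup_ratios) for _ in range(1000)]
print(sorted(from_middle))          # only the neighbors 0.01 and 0.05
print(from_edge.count(0.0) / 1000)  # roughly 0.5: the clamp keeps the edge value
```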

In [None]:
DATASET_HPARAMS_KEYS_PRERFIX = "ds_"

In [None]:
class Agent():
  @property
  def class_name(self):
    return self.__class__.__name__

  @property
  def id(self):
    return self.config.agent_id

  @staticmethod
  def get_model_factory():
    raise NotImplementedError("Not implemented")

  def run(self):
    raise NotImplementedError("Not implemented")

  def complete_config(self, system_state_dir, task_name, num_cycles_max):
    self.config.system_state_dir = system_state_dir
    self.config.task_name = task_name
    self.config.agent_id = format_agent_id(self.class_name, task_name)
    self.config.agent_dir = os.path.join(system_state_dir, self.id)
    self.config.num_cycles_max = num_cycles_max

  def agent_classes_to_load(self):
    # Defaults to loading only agents of the same class. Extend this list to
    # access the structures and parameters produced by other agent types.
    return [self.class_name]

In [None]:
def run_cycles(agent):
  config = agent.config
  task_name = config.task_name
  num_cycles = config.num_cycles_max
  for _ in range(num_cycles):
    agent.load_state()
    if agent.cycle_id >= num_cycles:
      break
    print("\n\n====")
    print(f"CYCLE: [{agent.cycle_id+1}/{num_cycles}]")
    agent.pop.start_cycle()
    agent_cycle(agent)
    agent.pop.end_cycle()
    agent.cycle_id += 1
    agent.generation_id = 0
    save_state(agent)
    if agent.cycle_id >= num_cycles:
      break

def run_root_model(agent):
  agent.load_state()
  save_state(agent)

In [None]:
# Run one full cycle of path sampling, training and selection for the agent's task.
def agent_cycle(agent):
  pop = agent.pop
  config = agent.config
  task = Path.cached_tasks(task_name=config.task_name)
  best_path = pop.get_best_path()
  if TEST_IMMUTABILITY and best_path:
    run_test_eval(best_path, test_immutability=True)
  devices = jax.local_devices()
  print("DEVICE COUNT:", len(devices))
  num_gen_batches = math.ceil(config.num_samples_per_cycle/len(devices))
  for _ in range(num_gen_batches):
    if agent.generation_id >= num_gen_batches:
      break
    print(f"----\nGENERATION: [{agent.generation_id+1}/{num_gen_batches}]")
    ds_hparams = agent.sample_ds_hparams()
    ds_hparams["num_classes"] = task.num_classes
    paths = []
    for i in range(len(devices)):
      print(f"Sampling path {Path.counter}")
      paths.append(agent.sample_path(ds_hparams))
      gc.collect()
    ds_hparams = agent.finalize_ds_hparams(ds_hparams, paths)
    ds_train = task.get_ds("train", ds_hparams)
    ds_validation = task.get_ds("validation", ds_hparams)
    train_loop(paths, ds_train, ds_validation, devices, config)
    for path in paths:
      path.metrics["generation_id"] = agent.generation_id
      if path.metrics["improved"]:
        assert path not in pop.paths[config.agent_id]
        pop.paths[config.agent_id].append(path)
    pop.prune_population()
    # Track best path.
    curr_best_path = pop.get_best_path()
    if curr_best_path != best_path:
      if best_path:
        assert curr_best_path.score() >= best_path.score()
      best_path = curr_best_path
      best_path.metrics["new_best"] = True
      agent.print_best_path_summary()
    agent.generation_id += 1
    df_leaderboard(pop_to_df(pop))
    if agent.generation_id < num_gen_batches:
      save_state(agent)
  assert best_path in pop.paths[config.agent_id], best_path
  run_test_eval(best_path)
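
In `agent_cycle`, each generation samples one path per local device, so the number of generations is `ceil(num_samples_per_cycle / num_devices)`. The arithmetic below uses the ViT default of 16 samples per cycle; `num_generations` is a hypothetical helper restating that formula.

```python
import math


def num_generations(num_samples_per_cycle: int, num_devices: int) -> int:
    # Same formula as num_gen_batches in agent_cycle.
    return math.ceil(num_samples_per_cycle / num_devices)


for devices in (1, 8, 12):
    gens = num_generations(16, devices)
    print(f"{devices} devices: {gens} generations, {gens * devices} paths sampled")
```

With 1 device (CPU) a cycle takes 16 generations, and with 8 devices (e.g. an 8-core TPU) it takes 2. When the device count does not divide the sample count, the rounding up samples more paths than `num_samples_per_cycle` (24 with 12 devices).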

## Vit Agent

In [None]:
def get_common_config_vit():
  config = ConfigDict()
  config.num_train_examples_between_validations_max = 300_000
  config.num_validations_per_path_training = 4
  config.num_validation_examples_max = 10_000
  config.num_samples_per_cycle = 16
  config.max_task_population_size = 5
  # Always fine-tune (clone) the final encoder layer norm, which technically belongs to the head.
  config.force_mutations = ["clone:encoder_norm"]
  config.scorer_kwargs = dict(
      scale_factor=0.99,
      base_accounted_params=2_200_000_000,
      base_flops=3_800_000_000_000,
      )
  config.hparams_defaults = {
      "_mu_": 0.2,
      "opt_lr": 0.02,
      "opt_lr_schedule": "cosine",
      "opt_lr_warmup_ratio": 0.02,
      "opt_momentum": 0.8,
      "opt_nesterov": True,
      "ds_area_range_min": 1.0,
      "ds_aspect_ratio_range_min": 1.0,
      "ds_flip_left_right": False,
      "ds_brightness_delta": 0.0,
      "ds_contrast_delta": 0.0,
      "ds_saturation_delta": 0.0,
      "ds_hue_delta": 0.0,
      "ds_quality_delta": 0.0,
  }
  config.hparams_mutation_ranges = {
      "_mu_": [0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18, 0.20, 0.22, 0.24, 0.26, 0.28, 0.30],
      "opt_lr": [0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0],
      "opt_lr_warmup_ratio": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2, 0.3],
      "opt_momentum": [0.5, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.98, 0.99],
      "opt_nesterov": [True, False],
      "ds_area_range_min": [0.05, 0.5, 0.95, 1.0],
      "ds_aspect_ratio_range_min": [0.5, 0.75, 1.0],
      "ds_flip_left_right": [True, False],
      "ds_brightness_delta": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
      "ds_contrast_delta": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
      "ds_saturation_delta": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
      "ds_hue_delta": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
      "ds_quality_delta": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
  }
  return config

def get_config_vit_large():
  config = get_common_config_vit()
  config.batch_size = 16
  # The query is used to get the model configs even if the checkpoint is not loaded.
  config.vit_checkpoint_query = 'name=="L/16" and ds=="i21k" and aug=="medium2" and wd==0.03 and sd==0.1'
  max_num_layers = get_max_num_layers(config.vit_checkpoint_query)
  config.hparams_defaults["num_layers"] = max_num_layers
  config.hparams_mutation_ranges["num_layers"] = list(
      range(config.hparams_defaults["num_layers"]+1
            +1  # Allow to exceed root-model's layers by 1.
            ))
  config.hparams_defaults["ds_image_size"] = 384
  config.hparams_mutation_ranges["ds_image_size"] = [224, 384]
  return config

def get_config_vit_ti3():
  config = get_common_config_vit()
  config.batch_size = 512
  config.vit_checkpoint_query = 'name=="Ti/16" and ds=="i21k" and aug=="light1" and wd==0.1 and sd==0.0'
  config.hparams_defaults["num_layers"] = 3
  config.hparams_mutation_ranges["num_layers"] = list(range(config.hparams_defaults["num_layers"]+1))
  config.hparams_defaults["ds_image_size"] = 32
  config.hparams_mutation_ranges["ds_image_size"] = [16*i for i in (range(1, 1+int(112/16)))]
  return config

def get_config_vit_base():
  config = get_common_config_vit()
  config.batch_size = 256
  config.vit_checkpoint_query = 'name=="B/16" and ds=="i21k" and aug=="medium1" and wd==0.1 and sd==0'
  max_num_layers = get_max_num_layers(config.vit_checkpoint_query)
  config.hparams_defaults["num_layers"] = max_num_layers
  config.hparams_mutation_ranges["num_layers"] = list(range(config.hparams_defaults["num_layers"]+1))
  config.hparams_defaults["ds_image_size"] = 80
  config.hparams_mutation_ranges["ds_image_size"] = [16*i for i in (range(1, 1+int(112/16)))]
  return config

def config_validate(config):
  for khp in config.hparams_defaults:
    if khp in config.hparams_mutation_ranges:
      assert config.hparams_defaults[khp] in config.hparams_mutation_ranges[khp], khp
  for khp in config.hparams_mutation_ranges:
    assert khp in config.hparams_defaults, khp
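
`config_validate` requires every default to be one of its allowed mutation values, and every mutable hyperparameter to have a default. The sketch below repeats the same checks on plain dicts (the notebook uses a `ConfigDict`) and applies them to the image-size range used by the Tiny and Base configurations.

```python
def config_validate(hparams_defaults: dict, hparams_mutation_ranges: dict) -> None:
    # Same checks as the notebook, on plain dicts instead of a ConfigDict.
    for khp in hparams_defaults:
        if khp in hparams_mutation_ranges:
            assert hparams_defaults[khp] in hparams_mutation_ranges[khp], khp
    for khp in hparams_mutation_ranges:
        assert khp in hparams_defaults, khp


image_sizes = [16 * i for i in range(1, 1 + int(112 / 16))]
print(image_sizes)  # [16, 32, 48, 64, 80, 96, 112]
config_validate({"ds_image_size": 32}, {"ds_image_size": image_sizes})  # passes
try:
    config_validate({"ds_image_size": 40}, {"ds_image_size": image_sizes})
except AssertionError as e:
    print("invalid default for", e)  # invalid default for ds_image_size
```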

In [None]:
class Vit(Agent):
  """ViT large"""
  def __init__(self, system_state_dir, task_name, num_cycles_max):
    self.config = self.get_config()
    self.complete_config(system_state_dir, task_name, num_cycles_max)
    self.cached_posembed_components = ObjectCache(get_vit_checkpoint_reshaped_posembed_component)

  @staticmethod
  def get_model_factory():
    return VitModelFactory

  def load_state(self):
    task_name = self.config.task_name
    self.pop = Population(self.config)
    self.cycle_id = 0
    self.generation_id = 0
    # Root models.
    if task_name.startswith("root_model/"):
      hparams = self.config.hparams_defaults.as_configdict()
      if task_name == "root_model/random_init":
        hparams["num_classes"] = 0  # Removes head layer.
        path_params = self.get_model_factory().get_init_comps(hparams, self.config)
      else:
        assert task_name == "root_model/checkpoint", task_name
        path_params = get_vit_checkpoint_comps(
            hparams["ds_image_size"],
            self.config.vit_checkpoint_query)
      path = Path(
          hparams,
          params2comps(path_params, train_locks=[self.id], agent_id=self.id),
          parent=None,
          agent_id=self.id,
          task_name=task_name)
      self.pop.paths[self.id].append(path)
      return

    # Load latest agent state.
    def validate_df(df):
      assert len(df["agent_id"].unique()) == 1, len(df["agent_id"].unique())
      assert df["agent_id"].unique()[0] == self.id, df["agent_id"].unique()[0]
    agent_checkpoint = latest_checkpoint(
        os.path.join(self.config.agent_dir, "state_*_*/"))
    if agent_checkpoint:
      matched = re.findall(r"checkpoint_([0-9]+)_([0-9]+)$", agent_checkpoint)
      assert len(matched) == 1
      self.cycle_id = int(matched[0][0])
      self.generation_id = int(matched[0][1])
      state_dir = os.path.dirname(agent_checkpoint)
      self.pop.paths_df = df_read_from_csv(state_dir, "paths")
      self.pop.comps_df = df_read_from_csv(state_dir, "components")
      validate_df(self.pop.paths_df)
      validate_df(self.pop.comps_df)
      # Set globals.
      Path.paths = []
      Path.counter = 1 + int(self.pop.paths_df.id.max())
      Component.counter = 1 + int(self.pop.comps_df.id.max())
      # Get the ids of the last path and component saved in a non-intermediate checkpoint.
      non_intermediated_checkpoint = latest_checkpoint(
          os.path.join(self.config.agent_dir, "state_*_0/"))
      if non_intermediated_checkpoint:
        ni_paths_df = df_read_from_csv(
            os.path.dirname(non_intermediated_checkpoint), "paths")
        validate_df(ni_paths_df)
        Path.last_saved = int(ni_paths_df.id.max())
        ni_comps_df = df_read_from_csv(
            os.path.dirname(non_intermediated_checkpoint), "components")
        validate_df(ni_comps_df)
        Component.last_saved = int(ni_comps_df.id.max())
      print("CONTINUING FROM STATE", self.cycle_id, self.generation_id)

    # Load all available paths.
    all_agents_dirs = []
    for agent_class_to_load in self.agent_classes_to_load():
      all_agents_dirs.extend(
          gfile.glob(os.path.join(self.config.system_state_dir,
                                  agent_class_to_load+"~*")))
    assert all_agents_dirs, f"No state for agents: {self.agent_classes_to_load()}"
    state_dir = os.path.dirname(agent_checkpoint) if agent_checkpoint else None
    load_paths(self.pop, state_dir, all_agents_dirs)

    assert self.pop.paths, "Population is empty, run an agent creating a " \
        "root model to initialize the population."
    df_leaderboard(pop_to_df(self.pop))

  def get_config(self):
    return get_config_vit_large()

  def run(self):
    if self.config.task_name.startswith("root_model/"):
      run_root_model(self)
      return
    run_cycles(self)

  def complete_config(self, system_state_dir, task_name, num_cycles_max):
    super().complete_config(system_state_dir, task_name, num_cycles_max)
    self.config = FrozenConfigDict(self.config)
    config_validate(self.config)

  def do_mutate(self, hparams, mutation_name):
    """Returns True if mutation is sampled to be applied."""
    if mutation_name in self.config.get("force_mutations", []):
      return True
    mutation_prob_k = f"_mu_|{mutation_name}"
    # hparams lacks "_mu_" when sampling the dataset hparams shared by a
    # generation batch, hence the fallback to the default value.
    mu = hparams.get("_mu_", self.config.hparams_defaults["_mu_"])
    mutation_prob = hparams.get(mutation_prob_k, mu)
    if "_mu_" in self.config.hparams_mutation_ranges:
      if mu > np.random.uniform():
        mutation_prob = incremental_mutation(
            mutation_prob, self.config.hparams_mutation_ranges["_mu_"])
      hparams[mutation_prob_k] = mutation_prob
    return mutation_prob > np.random.uniform()

  def parent_decay_selection(self):
    for path in sorted(self.pop.paths[self.config.agent_id],
                       key=lambda p: p.score(),
                       reverse=True):
      offsprings = path.metrics.get("offsprings", 0)
      assert not math.isnan(offsprings)
      select_prob = 0.5 ** offsprings
      print(f" Candidate parent path {path.id},",
            f"selection probability: 0.5^{offsprings} == {select_prob}")
      if np.random.uniform() < select_prob:
        path.metrics["offsprings"] = path.metrics.get("offsprings", 0) + 1
        return path
    return None

  def sample_path(self, ds_hparams):
    parent = self.parent_decay_selection()
    if not parent:  # Random sample.
      parent = random.choice([p for paths in self.pop.paths.values() for p in paths])
      print(f" Randomly selected parent {parent.agent_id}:{parent.id}")
    return self.mutate_parent(parent, ds_hparams)

  def mutate_hparams(self, hparams):
    for k in sorted(self.config.hparams_mutation_ranges):
      if k in hparams and self.do_mutate(hparams, f"hp:{k}"):
        hparams[k] = incremental_mutation(
            hparams[k], self.config.hparams_mutation_ranges[k])
    return hparams

  def sample_ds_hparams(self):
    """Sample hparams that need to be shared across each generation of paths."""
    ds_hparams = {}
    # Initialize shared hparams with defaults.
    for key in self.config.hparams_defaults:
      if key.startswith(DATASET_HPARAMS_KEYS_PRERFIX):
        ds_hparams[key] = self.config.hparams_defaults[key]
    # Overwrite with values from best path if available.
    best_path = self.pop.get_best_path()
    if best_path:
      ds_hparams.update(
          {k : best_path.hparams[k] for k in ds_hparams if k in best_path.hparams})
      ds_hparams.update(
          {k : best_path.hparams[k] for k in best_path.hparams if k.startswith(
              f"_mu_|hp:{DATASET_HPARAMS_KEYS_PRERFIX}")})
      # Sample mutations.
      ds_hparams = self.mutate_hparams(ds_hparams)
    # Validate.
    for k in ds_hparams:
      assert (k.startswith(DATASET_HPARAMS_KEYS_PRERFIX) or
              k.startswith(f"_mu_|hp:{DATASET_HPARAMS_KEYS_PRERFIX}"))
    return ds_hparams

  def finalize_ds_hparams(self, ds_hparams, paths):
    # Validate shared params.
    for k in ds_hparams:
      if k.startswith(DATASET_HPARAMS_KEYS_PRERFIX):
        for path in paths:
          assert ds_hparams[k] == path.hparams[k]
    return ds_hparams

  def mutate_parent(self, parent, ds_hparams):
    config = self.config
    agent_id = config.agent_id
    task_name = config.task_name
    comps = []
    new_hparams = copy.deepcopy(parent.hparams)
    new_hparams = self.mutate_hparams(new_hparams)
    # Overwrite dataset hparams with those sampled for the generation batch.
    new_hparams.update(ds_hparams)

    def get_component_ref(c, clone):
      if c.is_trainable() or clone:
        # Clone trainable component.
        return c.clone(agent_id=agent_id)
      # Refer to frozen component.
      return c

    init_params = self.get_model_factory().get_init_comps(new_hparams, config)
    for new_comp_name in init_params:
      comp = None
      # Attempt to reuse a matching component from the closest ancestor.
      ancestor = parent
      while ancestor is not None:
        comps_lookup = {c.name:c for c in ancestor.components}
        if new_comp_name in comps_lookup:
          # The head is only reused from ancestors of the same agent; if there is
          # none, fall back to a random init of the correct shape.
          if new_comp_name == "head" and agent_id != ancestor.agent_id:
            assert agent_id != ancestor.agent_id, f"{agent_id} != {ancestor.agent_id}"
            ancestor = ancestor.parent
            continue
          # Check shapes match otherwise skip.
          if (jax.tree_util.tree_map(jnp.shape, init_params[new_comp_name]) !=
              jax.tree_util.tree_map(jnp.shape, comps_lookup[new_comp_name].params)):
            if new_comp_name == "posembed_input":
              # A change of image size changes the shape of the position
              # embeddings. This can happen if ds_image_size is tuned: continue
              # searching the ancestors for a matching size.
              assert "ds_image_size" in config.hparams_mutation_ranges
              assert new_hparams["ds_image_size"] != ancestor.hparams["ds_image_size"]
              ancestor = ancestor.parent
              continue

            print(f"WARNING: Shapes do not match for component: {new_comp_name}  {ancestor.agent_id}->{agent_id}")
            print(jax.tree_util.tree_map(jnp.shape, init_params[new_comp_name]))
            print(jax.tree_util.tree_map(jnp.shape, comps_lookup[new_comp_name].params))
            assert False  # Should not happen in current configuration.

          ancestor_comp = comps_lookup[new_comp_name]
          comp = get_component_ref(
              ancestor_comp, clone=(
                  ancestor_comp.is_trainable() or self.do_mutate(
                      new_hparams, f"clone:{new_comp_name}")))
          break
        ancestor = ancestor.parent
      # Get reshaped posembed_input from checkpoint.
      if comp is None and new_comp_name == "posembed_input":
        pe_comp = self.cached_posembed_components(
            agent_id=agent_id,
            query=config.vit_checkpoint_query,
            **new_hparams)
        # Clone to make the component trainable.
        comp = get_component_ref(pe_comp, clone=True)
      # Otherwise create one from random init params.
      if comp is None:
        # Combinations that can trigger a random init in the current configurations.
        assert (
            new_comp_name == "head"
            or (new_comp_name.startswith("encoderblock_")
                and config.hparams_defaults["num_layers"] < max(
                config.hparams_mutation_ranges.get("num_layers", [-1]))))
        comp = params2comps(
            init_params, train_locks=[],
            agent_id=agent_id, name=new_comp_name)[0]
      assert comp is not None
      comps.append(comp)
    return Path(new_hparams, comps, parent=parent, agent_id=agent_id, task_name=task_name)

  def print_best_path_summary(self):
    best_path = self.pop.get_best_path()
    print(f"Best id:{best_path.id}",
          f"score:{best_path.score():.4f}",
          f"quality:{best_path.metrics['quality']:.4f}",
          f"gen:{best_path.metrics['generation_id']}",
          f"\n{best_path.hparams}")

  def get_paths_to_publish(self):
    return [p for p in self.pop.paths[self.config.agent_id]]

class VitT3(Vit):
  """ViT tiny 3 layers"""
  def get_config(self):
    return get_config_vit_ti3()

class VitB(Vit):
  """ViT base"""
  def get_config(self):
    return get_config_vit_base()
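
`parent_decay_selection` visits the agent's paths from best to worst score and accepts each one with probability 0.5^offsprings. A path that has already produced k offspring is therefore picked with probability 1/2^k. If every candidate is rejected, `sample_path` falls back to a uniformly random parent. The sketch below reproduces that logic on a hypothetical `Candidate` dataclass and takes an injectable `uniform` function (an assumption, the notebook calls `np.random.uniform()` directly) so the behaviour can be checked deterministically.

```python
import random
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class Candidate:
    """Hypothetical stand-in for a Path: an id, a score and a metrics dict."""
    id: int
    score: float
    metrics: dict = field(default_factory=dict)


def decay_selection(candidates: List[Candidate],
                    uniform: Callable[[], float] = random.random) -> Optional[Candidate]:
    # Visit candidates from the highest to the lowest score.
    for c in sorted(candidates, key=lambda c: c.score, reverse=True):
        offsprings = c.metrics.get("offsprings", 0)
        select_prob = 0.5 ** offsprings  # Halves with every offspring produced.
        if uniform() < select_prob:
            c.metrics["offsprings"] = offsprings + 1
            return c
    return None  # The caller falls back to a random parent.


paths = [Candidate(0, 0.70), Candidate(1, 0.90), Candidate(2, 0.80)]
# With a draw of 0.6 the best path (0 offsprings, probability 1.0) is accepted.
first = decay_selection(paths, uniform=lambda: 0.6)
print(first.id)  # 1
# Path 1 now has probability 0.5 < 0.6, so the next best path is accepted.
second = decay_selection(paths, uniform=lambda: 0.6)
print(second.id)  # 2
```

The geometric decay keeps exploitation focused on top-scoring paths while steadily shifting offspring to lower-ranked ones as the best paths are reused.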

## MultiVit Agent

In [None]:
def set_multivit_common_config(config):
  config.force_mutations = []
  config.scorer_kwargs = {}
  for rm_k in ["num_layers", "ds_image_size"]:
    del config.hparams_defaults[rm_k]
    del config.hparams_mutation_ranges[rm_k]
  config.hparams_defaults["router_init_main_path_weight"] = 0.8
  config.hparams_defaults["router_lr_mult"] = 0.05
  config.hparams_defaults["num_paths"] = 2
  config.hparams_mutation_ranges["router_lr_mult"] = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1]
  config.hparams_mutation_ranges["num_paths"] = [2, 3]
  return config
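
`set_multivit_common_config` starts from a single-path ViT config. It removes the hparams that MultiVit inherits from its sub-paths (`num_layers`, `ds_image_size`) and adds the router hparams. The plain-dict sketch below shows the resulting defaults and ranges and checks them with the same default-in-range rule as `config_validate`. The starting values are a hypothetical subset of the full config.

```python
import copy

# Hypothetical subset of a single-path config (values from get_config_vit_ti3).
single_path_defaults = {"num_layers": 3, "ds_image_size": 32}
single_path_ranges = {
    "num_layers": [0, 1, 2, 3],
    "ds_image_size": [16, 32, 48, 64, 80, 96, 112],
}


def to_multivit(defaults, ranges):
    defaults, ranges = copy.deepcopy(defaults), copy.deepcopy(ranges)
    # Architecture and image size are inherited from the selected sub-paths.
    for key in ["num_layers", "ds_image_size"]:
        del defaults[key]
        del ranges[key]
    # Router hparams added by set_multivit_common_config.
    defaults["router_init_main_path_weight"] = 0.8  # No mutation range: fixed.
    defaults["router_lr_mult"] = 0.05
    defaults["num_paths"] = 2
    ranges["router_lr_mult"] = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1]
    ranges["num_paths"] = [2, 3]
    return defaults, ranges


defaults, ranges = to_multivit(single_path_defaults, single_path_ranges)
# Same invariant as config_validate.
assert all(defaults[k] in ranges[k] for k in defaults if k in ranges)
assert all(k in defaults for k in ranges)
print(sorted(defaults))
# ['num_paths', 'router_init_main_path_weight', 'router_lr_mult']
```

Because `router_init_main_path_weight` has no mutation range, `mutate_hparams` never touches it, so it keeps its default of 0.8.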

In [None]:
class MultiVit(Vit):
  def get_config(self):
    config = get_config_vit_large()
    return set_multivit_common_config(config)

  @staticmethod
  def get_model_factory():
    return MultiVitModelFactory

  def run(self):
    run_cycles(self)

  def agent_classes_to_load(self):
    return [self.single_path_agent_class()]

  def single_path_agent_class(self):
    return self.class_name.replace("Multi", "")

  def single_path_main_agent_id(self):
    return format_agent_id(self.single_path_agent_class(), self.config.task_name)

  def load_state(self):
    if self.config.task_name.startswith("root_model/"):
      assert False, (
          "Root models need to be generated with the corresponding " \
          f"single path agent: '{self.single_path_agent_class()}'.")
    super().load_state()
    assert self.single_path_main_agent_id() in self.pop.paths, (
        "Missing state for the corresponding single path main agent. " \
        f"Run agent '{self.single_path_agent_class()}' " \
        f"on the '{self.config.task_name}' task to generate it.")

  def sample_path(self, ds_hparams):
    selected_paths = {}
    selected_paths["path_0"] = random.choice(self.pop.paths[self.single_path_main_agent_id()])
    parent = self.parent_decay_selection()
    if parent is not None:
      new_hparams = copy.deepcopy(parent.hparams)
      print(" Selected")
    else:
      parent = selected_paths["path_0"]
      best_path = self.pop.get_best_path()
      if best_path:
        new_hparams = copy.deepcopy(best_path.hparams)
      else:
        new_hparams = self.config.hparams_defaults.to_dict()
      print(" New random")
    new_hparams["paths"] = {}
    new_hparams = self.mutate_hparams(new_hparams)
    # Overwrite dataset hparams with those sampled for the generation batch.
    new_hparams.update(ds_hparams)

    for path_name in [f"path_{i}" for i in range(int(new_hparams["num_paths"]))]:
      if path_name in selected_paths:
        continue
      if path_name in parent.hparams.get("paths", {}):
        selected_paths[path_name] = self.pop.get_path_from_full_id(parent.hparams["paths"][path_name]["agent_id"], parent.hparams["paths"][path_name]["id"])
        print(" Subpath from parent: ", path_name, selected_paths[path_name].full_id)
        continue
      selected_paths[path_name] = random.choice([
          p for paths in self.pop.paths.values() for p in paths if (
              p.agent_id not in [sp.agent_id for sp in selected_paths.values()]
              and p.agent_id.startswith(f"{self.single_path_agent_class()}~")
              and not p.agent_id.endswith("~1k")  # Excludes VTAB-1k tasks.
              # and (path_name != "path_1" or p.agent_id in ['Vit~i_naturalist2017'])  # Forces i_naturalist2017 selection.
              )])
      print(" Subpath rand selected:", path_name, selected_paths[path_name].full_id)

    for (path_name, path) in selected_paths.items():
      new_hparams["paths"][path_name] = {
          "id": path.id,
          "agent_id": path.agent_id,
          "hparams": copy.deepcopy(path.hparams)}
    print(" Sampled subpaths:",
          {k: v["agent_id"] for k, v in new_hparams["paths"].items()})
    # Use headless models (num_classes=0) for sub-paths trained on other tasks.
    for k in new_hparams["paths"]:
      if selected_paths[k].task_name != self.config.task_name:
        new_hparams["paths"][k]["hparams"]["num_classes"] = 0
    # Collect image sizes needed.
    image_sizes = set()
    for k in new_hparams["paths"]:
      image_sizes.add(int(new_hparams["paths"][k]["hparams"]["ds_image_size"]))
    new_hparams["ds_image_size"] = list(image_sizes)
    # Collect components.
    init_params = self.get_model_factory().get_init_comps(new_hparams, self.config)
    comps = []
    for new_comp_name in init_params:
      if new_comp_name in new_hparams["paths"]:
        comps.append(ComponentPath(name=new_comp_name,
                                   path=selected_paths[new_comp_name]))
      else:
        comps_lookup = {c.name:c for c in parent.components}
        if new_comp_name in comps_lookup and (
          jax.tree_util.tree_map(jnp.shape, init_params[new_comp_name]) ==
          jax.tree_util.tree_map(jnp.shape, comps_lookup[new_comp_name].params)):
          print(" COMP Reusing", new_comp_name)
          comp = comps_lookup[new_comp_name].clone(agent_id=self.config.agent_id)
        else:
          print(" COMP Init", new_comp_name)
          comp = Component(name=new_comp_name,
                           agent_id=self.config.agent_id,
                           params=init_params[new_comp_name],
                           train_locks=[])
        comps.append(comp)
    return Path(new_hparams, comps, parent=parent,
                agent_id=self.config.agent_id, task_name=self.config.task_name)

  def finalize_ds_hparams(self, ds_hparams, paths):
    # Validate shared params.
    for k in ds_hparams:
      if k.startswith(DATASET_HPARAMS_KEYS_PRERFIX):
        for path in paths:
          assert ds_hparams[k] == path.hparams[k], (k, ds_hparams[k], path.hparams[k])
    image_sizes = set()
    for path in paths:
      image_sizes.update(path.hparams["ds_image_size"])
    ds_hparams["ds_image_size"] = list(image_sizes)
    print("Image sizes:", ds_hparams["ds_image_size"])
    return ds_hparams

  def print_best_path_summary(self):
    super().print_best_path_summary()
    best_path = self.pop.get_best_path()
    print("Paths used by best model:",
          {k: best_path.hparams["paths"][k]["agent_id"] for k in best_path.hparams["paths"]})

  def get_paths_to_publish(self):
    paths_to_publish = [p for p in self.pop.paths[self.config.agent_id]]
    for path in list(paths_to_publish):
      for k in path.hparams["paths"]:
        paths_to_publish.append(self.pop.get_path_from_full_id(path.hparams["paths"][k]["agent_id"], path.hparams["paths"][k]["id"]))
    return paths_to_publish

class MultiVitT3(MultiVit):
  def get_config(self):
    config = get_config_vit_ti3()
    return set_multivit_common_config(config)

class MultiVitB(MultiVit):
  def get_config(self):
    config = get_config_vit_base()
    return set_multivit_common_config(config)
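
In `MultiVit.sample_path`, `path_0` is always a path of the single-path agent for the same task. The remaining sub-paths are either reused from the parent or sampled from agents of other tasks. Sub-paths from other tasks are made headless (`num_classes = 0`), and the image sizes of all sub-paths are collected, because each sub-path consumes images at its own resolution. The sketch below reproduces those two steps on hypothetical plain-dict sub-paths (task names, ids and hparams are illustrative only). It sorts the image sizes for a deterministic result, whereas the notebook uses an unordered `set`.

```python
import copy


def assemble_subpaths(selected, task_name):
    """selected maps 'path_i' -> dict(agent_id, id, task_name, hparams)."""
    paths = {}
    for name, sp in selected.items():
        paths[name] = {"id": sp["id"], "agent_id": sp["agent_id"],
                       "hparams": copy.deepcopy(sp["hparams"])}
        # Sub-paths trained on another task are used without their head.
        if sp["task_name"] != task_name:
            paths[name]["hparams"]["num_classes"] = 0
    # One input resolution per distinct sub-path image size.
    image_sizes = sorted({int(p["hparams"]["ds_image_size"]) for p in paths.values()})
    return paths, image_sizes


selected = {
    "path_0": {"agent_id": "Vit~cifar100", "id": 7, "task_name": "cifar100",
               "hparams": {"num_classes": 100, "ds_image_size": 384}},
    "path_1": {"agent_id": "Vit~imagenet2012", "id": 3, "task_name": "imagenet2012",
               "hparams": {"num_classes": 1000, "ds_image_size": 224}},
}
paths, sizes = assemble_subpaths(selected, "cifar100")
print(paths["path_1"]["hparams"]["num_classes"], sizes)  # 0 [224, 384]
```

The deep copy matters: setting `num_classes = 0` on the copied hparams leaves the stored hparams of the original single-path agent untouched.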

# Tasks

In [None]:
TFDS_BUILDERS_CACHE = {}
def get_tfds_builder(tfds_name):
  global TFDS_BUILDERS_CACHE
  if tfds_name not in TFDS_BUILDERS_CACHE:
    TFDS_BUILDERS_CACHE[tfds_name] = tfds.builder(tfds_name)
    TFDS_BUILDERS_CACHE[tfds_name].download_and_prepare()
  return TFDS_BUILDERS_CACHE[tfds_name]
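
`get_tfds_builder` memoizes builders in a module-level dict, so `download_and_prepare` runs at most once per dataset name during a session. The same pattern can be sketched with `functools.lru_cache`. `FakeBuilder` is a hypothetical stand-in for `tfds.builder(...)` that lets the example run without TensorFlow Datasets.

```python
import functools


class FakeBuilder:
    """Hypothetical stand-in for a TFDS builder; counts preparation calls."""
    prepared = 0

    def __init__(self, name):
        self.name = name

    def download_and_prepare(self):
        FakeBuilder.prepared += 1


@functools.lru_cache(maxsize=None)
def get_builder(name):
    builder = FakeBuilder(name)
    builder.download_and_prepare()  # Runs only on the first call per name.
    return builder


a = get_builder("cifar100")
b = get_builder("cifar100")
c = get_builder("imagenet_v2")
print(a is b, FakeBuilder.prepared)  # True 2
```

An explicit dict, as in the notebook, makes the cache easy to inspect or clear by hand; with `lru_cache` the equivalent is `get_builder.cache_clear()`.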

In [None]:
def get_default_splits(tfds_name):
  info = get_tfds_builder(tfds_name).info
  splits = list(info.splits.keys())
  assert "train" in splits, splits
  splits.remove("train")
  used_percent = 0
  slice_percent = 5
  pp = {}
  for k in ["test", "validation"]:
    if k in splits:
      pp[k] = k
      splits.remove(k)
    else:
      pp[k] = f"train[{used_percent}%:{used_percent+slice_percent}%]"
      used_percent += slice_percent
  pp["train"] = f"train[{used_percent}%:]"
  return pp

def get_dataset_and_splits(tfds_name: str):
  vtab_class = None
  if tfds_name in ["imagenet_v2", "cifar10_1"]:
    assert False, f"{tfds_name} is used as a validation set for other tasks."

  if tfds_name == "imagenet2012":
    dataset = {
        "train":"imagenet2012", "validation":"imagenet_v2", "test":"imagenet2012"}
    splits = {
        "train":"train", "validation":"test", "test":"validation"}
  elif tfds_name == "cifar100":
    dataset = tfds_name
    splits = {
        "train":"train[:98%]", "validation":"train[98%:]", "test":"test"}
  elif tfds_name == "cifar10":
    dataset = {
        "train":"cifar10", "validation":"cifar10_1", "test":"cifar10"}
    splits = {
        "train":"train", "validation":"test", "test":"test"}
  elif (tfds_name.startswith("visual_domain_decathlon/") or
        tfds_name in ["i_naturalist2017", "i_naturalist2018", "places365_small"]):
    dataset = tfds_name
    # Test has no labels, split validation in half.
    splits =  {
        "train":"train", "validation":"validation[:50%]", "test":"validation[50%:]"}
  elif tfds_name.startswith("cmaterdb/"):
    dataset = tfds_name
    # Increase size of validation set due to small dataset size.
    splits =  {
        "train":"train[20%:]", "validation":"train[:20%]", "test":"test"}
  elif tfds_name == "omniglot":
    # Test has no labels and there is no validation split: use the additional splits.
    dataset = tfds_name
    splits = {"train":"train", "validation":"small1", "test":"small2"}
  elif tfds_name.startswith("controlled_noisy_web_labels/"):
    dataset = tfds_name
    splits =  {
        "train":"train_00",
        "validation":"validation[:50%]",
        "test":"validation[50%:]"}
  elif tfds_name.startswith("cycle_gan/"):
    dataset = tfds_name
    splits =  {
        "train":"trainA[10%:]+trainB[10%:]",
        "validation":"trainA[:10%]+trainB[:10%]",
        "test":"testA+testB"}
  elif tfds_name in ["imagenet_a", "imagenet_r", "imagenet_sketch",
                     "siscore/rotation", "siscore/size", "siscore/location",]:
    # Only test split.
    dataset = tfds_name
    splits =  {
        "train":"test[10%:]",
        "validation":"test[5%:10%]",
        "test":"test[:5%]"}
  elif tfds_name in ["pet_finder"]:
    # Use only the train split, since the test split has no labels.
    dataset = tfds_name
    splits =  {
        "train":"train[10%:]",
        "validation":"train[5%:10%]",
        "test":"train[:5%]"}
  elif tfds_name == "quickdraw_bitmap":
    dataset = tfds_name
    # Cap size of test and validation set.
    splits =  {
        "train":"train[20000:]", "validation":"train[10000:20000]", "test":"train[:10000]"}
  elif tfds_name == "stanford_online_products":
    dataset = tfds_name
    # Use the first 10k test samples as validation since test has 60k.
    splits =  {
        "train":"train", "validation":"test[:10000]", "test":"test[10000:]"}
  elif tfds_name in VTAB_TASKS or (
      tfds_name.endswith("/1k") and tfds_name.replace("/1k", "") in VTAB_TASKS):
    is_vtab_1k = tfds_name.endswith("/1k")
    tfds_name = tfds_name.replace("/1k", "")
    registry_name = {
        "diabetic_retinopathy_detection/btgraham-300": "diabetic_retinopathy",
        "svhn_cropped": "svhn",
        "cifar100": "cifar",
        "cifar10": "cifar",
    }.get(tfds_name, tfds_name.split("/")[0])
    args = {
        "clevr/count_all": ("count_all",),
        "clevr/count_cylinders": ("count_cylinders",),
        "clevr/closest_object_distance": ("closest_object_distance",),
        "dsprites/label_x_position": ("label_x_position",),
        "dsprites/label_orientation": ("label_orientation",),
        "kitti/closest_object_distance": ("closest_object_distance",),
        "kitti/count_vehicles": ("count_vehicles",),
        "kitti/closest_vehicle_distance": ("closest_vehicle_distance",),
        "smallnorb/label_category": ("label_category",),
        "smallnorb/label_lighting": ("label_lighting",),
        "smallnorb/label_azimuth": ("label_azimuth",),
        "smallnorb/label_elevation": ("label_elevation",),
        "cifar100": (100,),
        "cifar10": (10,),
    }.get(tfds_name, ())
    vtab_class = task_adapt_registry.Registry.lookup(
        f"data.{registry_name}")(*args)
    vtab_splits = vtab_class._tfds_splits
    dataset = {
        "caltech101": "caltech101:3.*.*",
        "dtd": "dtd:3.*.*",
        "oxford_flowers102": "oxford_flowers102:2.*.*",
        "oxford_iiit_pet": "oxford_iiit_pet:3.*.*",
        "sun397": "sun397/tfds:4.*.*",
        "svhn": "svhn_cropped:3.*.*",
        "patch_camelyon": "patch_camelyon:2.*.*",
        "eurosat": "eurosat/rgb:2.*.*",
        "resisc45": "resisc45:3.*.*",
        "diabetic_retinopathy": "diabetic_retinopathy_detection/btgraham-300:3.*.*",
        "clevr": "clevr:3.*.*",
        "dmlab": "dmlab:2.0.1",
        "dsprites": "dsprites:2.*.*",
        "kitti": "kitti:3.2.0",
        "smallnorb": "smallnorb:2.*.*",
        "cifar" : "cifar100:3.*.*" if tfds_name == "cifar100" else "cifar10:3.*.*",
    }[registry_name]
    if is_vtab_1k:
      splits =  {
          "train": str(vtab_splits["train800"]),
          "validation": str(vtab_splits["val200"]),
          "test": str(vtab_splits["test"]),
          }
    else:
      splits =  {
          "train": str(vtab_splits["train"]),
          "validation": str(vtab_splits["val"]),
          "test": str(vtab_splits["test"]),
          }
  else:
    dataset = tfds_name
    splits = get_default_splits(tfds_name)
  return dataset, splits, vtab_class


class Task():
  def __init__(self, name, config):
    self.config = config

    self.dataset, self.splits, self.vtab_class = get_dataset_and_splits(name)
    self.name = name
    if self.vtab_class:
      self.num_classes = self.vtab_class.get_num_classes()
    else:
      self.num_classes = self.get_builder(
          "train").info.features[self.get_label_key()].num_classes
    num_train_examples = self.get_builder(
        "train").info.splits[self.splits["train"]].num_examples
    self.train_batch_size = config.batch_size
    self.num_train_batches_between_validations = math.ceil(
        min(num_train_examples,
            config.num_train_examples_between_validations_max)
        / self.train_batch_size)

    num_validation_examples_tot = self.get_builder(
        "validation").info.splits[self.splits["validation"]].num_examples
    if config.num_validation_examples_max <= num_validation_examples_tot:
      self.validation_batch_size = config.batch_size
      self.num_validation_batches = math.floor(
          config.num_validation_examples_max / self.validation_batch_size)
    else:
      # Adjust batch_size and num_batches to cover the smaller validation sets.
      self.num_validation_batches = math.ceil(
          num_validation_examples_tot / config.batch_size)
      self.validation_batch_size = math.floor(
          num_validation_examples_tot / self.num_validation_batches)
      assert num_validation_examples_tot >= (
          self.num_validation_batches*self.validation_batch_size)
    self.num_validation_examples = (
        self.num_validation_batches * self.validation_batch_size)

    print(f"Task: {self.name}")
    print(f"  Train batches between validations: {self.num_train_batches_between_validations}")
    print(f"  Validation batches: {self.num_validation_batches}")
    print(f"  Validation batch size: {self.validation_batch_size}")
    print(f"  Dataset {{\n{self.dataset}}}")
    print(f"  Splits {{\n{self.splits}}}")

  def get_label_key(self):
    return {
        "stanford_online_products": "super_class_id",
        }.get(self.name, "label")

  def get_builder(self, mode):
    if isinstance(self.dataset, str):
      return get_tfds_builder(self.dataset)
    return get_tfds_builder(self.dataset[mode])

  def get_ds(self, mode, hparams):
    data = self.get_builder(mode).as_dataset(
        split=self.splits[mode],
        shuffle_files=mode=="train")

    def _pp(data):
      im = data["image"]
      tf.debugging.assert_type(im, tf.uint8)

      if mode == "train":
        if hparams.get("ds_quality_delta", 0.0) > 0.0:
          im = tf.image.random_jpeg_quality(
              im,
              min_jpeg_quality=int(100 * (1 - hparams["ds_quality_delta"])),
              max_jpeg_quality=100)

      # Must have 3 channels.
      if im.shape[-1] == 1:
        im = tf.squeeze(tf.stack([im] * 3, -1), axis=-2)
      assert im.shape[-1] == 3
      im = tf.cast(im, tf.float32)
      if mode == "train":
        if hparams.get("ds_area_range_min", 1.0) < 1.0:
          channels = im.shape[-1]
          begin, size, _ = tf.image.sample_distorted_bounding_box(
              tf.shape(im),
              tf.zeros([0, 0, 4], tf.float32),
              aspect_ratio_range=[hparams["ds_aspect_ratio_range_min"],
                                  1.0/hparams["ds_aspect_ratio_range_min"]],
              area_range=[hparams["ds_area_range_min"], 1.0],
              # No bounding boxes are given, so the whole image is used as the
              # bounding box and no minimum overlap is required.
              min_object_covered=0,
              use_image_if_no_bounding_boxes=True)
          im = tf.slice(im, begin, size)
          # Restore the depth-dimension lost by the above operation.
          im.set_shape([None, None, channels])
        if hparams.get("ds_flip_left_right", False):
          if tf.random.uniform(shape=[]) > 0.5:
            im = tf.image.flip_left_right(im)
        if hparams.get("ds_brightness_delta", 0.0) > 0.0:
          im = tf.image.random_brightness(
              im, max_delta=hparams["ds_brightness_delta"])
        if hparams.get("ds_contrast_delta", 0.0) > 0.0:
          im = tf.image.random_contrast(
              im, lower=1-hparams["ds_contrast_delta"],
              upper=1+hparams["ds_contrast_delta"])
        if hparams.get("ds_saturation_delta", 0.0) > 0.0:
          im = tf.image.random_saturation(
              im, lower=1-hparams["ds_saturation_delta"],
              upper=1 + hparams["ds_saturation_delta"])
        if hparams.get("ds_hue_delta", 0.0) > 0.0:
          im = tf.image.random_hue(im, max_delta=hparams["ds_hue_delta"])

      def get_formatted_image(image, image_size):
        image = tf.image.resize(image, [image_size, image_size])
        # Scale values to the range [-1, 1].
        image = image / 127.5 - 1
        image = tf.clip_by_value(image, -1, 1)
        return image

      if type(hparams["ds_image_size"]) is list:
        out_im = {}
        for im_size in hparams["ds_image_size"]:
          out_im[str(im_size)] = get_formatted_image(im, int(im_size))
      else:
        out_im = get_formatted_image(im, int(hparams["ds_image_size"]))
      return {"image": out_im,
              "label": data[self.get_label_key()]}

    if mode == "validation":
      data = data.take(self.num_validation_examples).cache()
    if mode != "test":
      data = data.repeat()
    if self.vtab_class and self.vtab_class._base_preprocess_fn:
      data = data.map(self.vtab_class._base_preprocess_fn, tf.data.AUTOTUNE)
    data = data.map(_pp, tf.data.AUTOTUNE)
    if mode == "train":
      batch_size = self.train_batch_size
    else:
      batch_size = self.validation_batch_size
    data = data.batch(batch_size)
    if mode == "train":
      data = data.shuffle(10)
    return tfds.as_numpy(data.prefetch(tf.data.AUTOTUNE))

def get_task_factory_fn(config):
  def get_task(task_name: str):
    return Task(name=task_name, config=config)
  return get_task
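
Two pieces of bookkeeping in this cell are pure arithmetic and can be checked in isolation. `get_default_splits` carves 5% slices of `train` for whichever of `test` and `validation` the dataset lacks. `Task.__init__` shrinks the validation batch size when a validation set is smaller than `num_validation_examples_max`, so that every validation batch is full. The sketch below re-implements both on plain values. Split names are passed in directly instead of being read from a TFDS builder, and the numbers are illustrative, not the notebook's config values.

```python
import math


def default_splits(available):
    """Re-implementation of get_default_splits for a list of split names."""
    assert "train" in available, available
    splits = [s for s in available if s != "train"]
    used_percent, slice_percent = 0, 5
    pp = {}
    for k in ["test", "validation"]:
        if k in splits:
            pp[k] = k
        else:
            # Missing split: take the next 5% slice of train.
            pp[k] = f"train[{used_percent}%:{used_percent + slice_percent}%]"
            used_percent += slice_percent
    pp["train"] = f"train[{used_percent}%:]"
    return pp


def validation_batching(num_examples_tot, batch_size, num_examples_max):
    """Returns (num_batches, batch_size) as computed in Task.__init__."""
    if num_examples_max <= num_examples_tot:
        return math.floor(num_examples_max / batch_size), batch_size
    num_batches = math.ceil(num_examples_tot / batch_size)
    return num_batches, math.floor(num_examples_tot / num_batches)


print(default_splits(["train", "test"]))
# {'test': 'test', 'validation': 'train[0%:5%]', 'train': 'train[5%:]'}
print(validation_batching(3000, 512, 10000))   # (6, 500)
print(validation_batching(50000, 512, 10000))  # (19, 512)
```

In the small-set case, 3000 examples need ceil(3000/512) = 6 batches, and floor(3000/6) = 500 examples per batch covers all of them without a partial batch.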

# Components

In [None]:
def params2comps(params, train_locks, agent_id, name=None):
  """Convert a frozen dict of params to a list of components."""
  components = []
  for k in params:
    if name is None or name == k:
      c = Component(
          name=k, agent_id=agent_id,
          params=params[k], train_locks=train_locks)
      components.append(c)
  return components
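
`params2comps` turns each top-level entry of a parameter tree into one component. When `name` is set, it returns only the matching component, which is how `mutate_parent` extracts a single randomly initialized component. The sketch below uses a hypothetical `Comp` record and a plain dict of parameter shapes in place of the Flax parameter tree. The component names and agent ids are illustrative.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class Comp:
    """Hypothetical minimal component: name, owner agent, params, locks."""
    name: str
    agent_id: str
    params: Any
    train_locks: set


def params_to_comps(params: Dict[str, Any], train_locks, agent_id: str,
                    name: Optional[str] = None) -> List[Comp]:
    # One component per top-level key, optionally restricted to `name`.
    return [Comp(k, agent_id, params[k], set(train_locks))
            for k in params if name is None or name == k]


params = {"embedding": {"kernel": (16, 16, 3, 192)},
          "encoderblock_0": {"mlp": (192, 768)},
          "head": {"kernel": (192, 10)}}
all_comps = params_to_comps(params, train_locks=["Vit~root"], agent_id="Vit~root")
head = params_to_comps(params, train_locks=[], agent_id="Vit~cifar10", name="head")[0]
print([c.name for c in all_comps], head.train_locks)
# ['embedding', 'encoderblock_0', 'head'] set()
```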

In [None]:
def fingerprint_params(params):
  return np.sum(np.array(jax.tree_util.tree_leaves(
      jax.tree_util.tree_map(jnp.sum, params))))

class Component():
  counter = 0
  # Components of retained paths with id <= last_saved are saved in checkpoint.
  last_saved = -1

  def reset_globals():
    Component.counter = 0
    Component.last_saved = -1

  def __init__(
      self, name: str, agent_id: str, params, train_locks):
    self.name = name
    self.agent_id = agent_id
    self.params = jax.device_get(params)
    self.num_params = None
    self.train_locks = set(train_locks)
    self.id = Component.counter
    Component.counter += 1

  def get_num_params(self):
    if self.num_params is None:
      self.num_params = get_num_params(self.params)
    return self.num_params

  def fingerprint(self):
    return fingerprint_params(self.params)

  def is_trainable(self):
    return len(self.train_locks) == 0

  def clone(self, agent_id):
    return Component(name=self.name,
                     agent_id=agent_id,
                     params=copy.deepcopy(jax.device_get(self.params)),
                     train_locks=set())
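
Each `Component` draws a unique, increasing id from the class-level `counter`, and a component is trainable only while its `train_locks` set is empty. `clone` therefore produces an independent, trainable copy with a fresh id, which is how agents obtain trainable versions of frozen ancestor components. The stripped-down sketch below omits JAX: `MiniComponent` is hypothetical and uses `copy.deepcopy` without `jax.device_get`.

```python
import copy


class MiniComponent:
    counter = 0  # Next id to assign, shared by all instances.

    def __init__(self, name, agent_id, params, train_locks):
        self.name = name
        self.agent_id = agent_id
        self.params = params
        self.train_locks = set(train_locks)
        self.id = MiniComponent.counter
        MiniComponent.counter += 1

    def is_trainable(self):
        return len(self.train_locks) == 0

    def clone(self, agent_id):
        # Deep copy so that training the clone never alters the original.
        return MiniComponent(self.name, agent_id,
                             copy.deepcopy(self.params), train_locks=set())


root = MiniComponent("encoderblock_0", "Vit~root", {"w": [1.0, 2.0]},
                     train_locks=["Vit~root"])
child = root.clone(agent_id="Vit~cifar10")
child.params["w"][0] = 5.0
print(root.is_trainable(), child.is_trainable(), root.id, child.id, root.params["w"])
# False True 0 1 [1.0, 2.0]
```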

In [None]:
class ComponentPath():
  """Wraps a Path so that it can be used as a (frozen) Component."""
  def __init__(self, name, path):
    self.name = name
    self.path = path
    self.train_locks = set(["FROZEN"])

  def is_trainable(self):
    return False

  @property
  def params(self):
    return flax.core.freeze(self.path.get_all_params())
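
`ComponentPath` lets a whole sub-path appear inside another path's component list. It always reports `is_trainable() == False` and exposes the sub-path's parameters as a single frozen entry. As a result, `Path.get_trainable_params` and `Path.get_fixed_params` partition a MultiVit path into its own trainable components and its frozen sub-paths. The sketch below shows that partition with hypothetical lightweight classes. Plain dicts replace `flax.core.freeze`, and the component name `router` is illustrative.

```python
class Leaf:
    """Hypothetical trainable component."""
    def __init__(self, name, params):
        self.name, self.params = name, params

    def is_trainable(self):
        return True


class WrappedPath:
    """Hypothetical ComponentPath: a frozen sub-path exposed as one component."""
    def __init__(self, name, sub_components):
        self.name = name
        self.sub_components = sub_components

    def is_trainable(self):
        return False

    @property
    def params(self):
        return {c.name: c.params for c in self.sub_components}


def split_params(components):
    trainable, fixed = {}, {}
    for c in components:
        target = trainable if c.is_trainable() else fixed
        assert c.name not in target, c.name  # Component names must be unique.
        target[c.name] = c.params
    return trainable, fixed


comps = [WrappedPath("path_0", [Leaf("head", 1), Leaf("encoderblock_0", 2)]),
         Leaf("router", 3)]
trainable, fixed = split_params(comps)
print(sorted(trainable), fixed)
# ['router'] {'path_0': {'head': 1, 'encoderblock_0': 2}}
```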

# Paths & Population

In [None]:
class Path():
  @staticmethod
  def reset_globals(config):
    Path.config = config
    Path.counter = 0
    Path.last_saved = -1
    Path.paths = []
    Path.scorer = globals()[config.get("scorer_class", "ScorerDecay")](
        **config.get("scorer_kwargs", {}))
    # Cache output of functions calls with same args.
    Path.cached_tasks = ObjectCache(get_task_factory_fn(config))
    Path.cached_optimizers = ObjectCache(get_optimizer)

  def __init__(self, hparams, components, parent, agent_id, task_name):
    self.components = components
    self.id = Path.counter
    Path.counter += 1
    self.agent_id = agent_id
    self.task_name = task_name
    self.parent = parent
    self.hparams = hparams
    self._model = None
    self.metrics = {
        "generation": 0 if parent is None else parent.metrics["generation"] + 1,
    }
    Path.paths.append(self)

  @property
  def task(self):
    return Path.cached_tasks(task_name=self.task_name)

  @property
  def model_factory(self):
    return get_agent_class(self.agent_id).get_model_factory()

  @property
  def model(self):
    if self._model is None:
      self._model = self.model_factory.get_model(self.hparams, self.config)
    return self._model

  @property
  def full_id(self):
    return f"{self.agent_id}:{self.id}"

  def comps_only(self):  # Exclude wrapped paths.
    return [c for c in self.components if c.__class__ is Component]

  def score(self):
    return Path.scorer.score(self)

  def get_all_params(self):
    params = {}
    for c in self.components:
      assert c.name not in params, c.name
      params[c.name] = c.params
    return flax.core.freeze(params)

  def get_trainable_params(self):
    params = {}
    for c in self.components:
      if c.is_trainable():
        assert c.name not in params, c.name
        params[c.name] = c.params
    return flax.core.freeze(params)

  def get_fixed_params(self):
    params = {}
    for c in self.components:
      if not c.is_trainable():
        assert c.name not in params, c.name
        params[c.name] = c.params
    return flax.core.freeze(params)

  def update_trainable(self, trained_params):
    trainable_count = 0
    for c in self.components:
      if c.is_trainable():
        trainable_count += 1
        assert c.name in trained_params.keys()
        c.params = trained_params[c.name]
    assert len(trained_params.keys()) == trainable_count, (
        f"{len(trained_params.keys())} {trainable_count}")

  def get_num_accounted_params(self):
    rtn = 0
    for c in self.components:
      tl = copy.copy(c.train_locks)
      assert type(tl) is set
      tl.add(self.agent_id)
      assert tl
      rtn += c.get_num_params() / len(tl)
    return rtn

  def get_flops(self):
    return compute_flops_hlo(
          partial(self.model.apply, train=False),
          {"params": self.model_factory.get_comps2model_fn()(merge_params(
              self.get_trainable_params(),
              self.get_fixed_params()))},
          self.model_factory.get_sample_input(self.hparams))

  def get_optimizer(self):
    return Path.cached_optimizers(
        num_train_batches_between_validations=
            self.task.num_train_batches_between_validations,
        num_validations_per_path_training=
            self.task.config.num_validations_per_path_training,
        **self.hparams)
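`get_num_accounted_params` charges each component to a path only in proportion to how many agents share it: the component's size is divided by the number of distinct train locks, where the current agent is always counted. The following is a minimal standalone sketch of that accounting, assuming components are given as plain `(num_params, train_locks)` pairs (a hypothetical simplified representation, not the notebook's `Component` class):

```python
def accounted_params(components, agent_id):
  """Sum of per-component parameter counts, each split across the agents
  that lock it (the current agent always counts as one of them)."""
  total = 0.0
  for num_params, train_locks in components:
    # Copy the locks so that the caller's set is not modified.
    locks = set(train_locks)
    locks.add(agent_id)
    total += num_params / len(locks)
  return total

# Hypothetical path of agent "B": a component shared with agents "A" and "C",
# a component locked only by "A", and a new component not yet locked.
comps = [(900, {"A", "C"}), (100, {"A"}), (50, set())]
print(accounted_params(comps, "B"))  # 900/3 + 100/2 + 50/1 = 400.0
```

Reusing components that many agents already share is therefore cheap for a new path, which is what lets the scorer favor reuse over growth.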

In [None]:
class ScorerDecay():
  def __init__(self, scale_factor=1, base_accounted_params=0, base_flops=0):
    assert 0.0 < scale_factor <= 1.0
    self.scale_factor = scale_factor
    self.base_accounted_params = base_accounted_params
    self.base_flops = base_flops

  def score(self, path):
    if ("quality" not in path.metrics
        or math.isnan(path.metrics["quality"])):
      return None
    assert path.metrics["quality"] >= 0, (
        f"{path.task_name} {path.metrics['quality']}")
    score = path.metrics["quality"]
    if self.base_accounted_params > 0:
      # Accounted params needs to be updated since it depends on the
      # changing structure of the system.
      path.metrics["accounted_params"] = path.get_num_accounted_params()
      score *= self.scale_factor ** (
          path.metrics["accounted_params"] / self.base_accounted_params)
    if self.base_flops > 0:
      if "flops" not in path.metrics:
        path.metrics["flops"] = path.get_flops()
      score *= self.scale_factor ** (path.metrics["flops"] / self.base_flops)
    assert score >= 0
    path.metrics["score"] = score
    return score
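`ScorerDecay` multiplies the validation quality by `scale_factor ** (accounted_params / base_accounted_params)` and, when `base_flops > 0`, also by `scale_factor ** (flops / base_flops)`. With the default `scale_factor=1` both factors equal 1 and the score is the raw quality. The standalone function below restates the same rule on plain numbers (a sketch; `decayed_score` is a hypothetical name, not part of the notebook), which makes the trade-off easy to check by hand:

```python
import math

def decayed_score(quality, scale_factor=1.0, accounted_params=0.0,
                  base_accounted_params=0.0, flops=0.0, base_flops=0.0):
  """Quality discounted exponentially in relative size and compute cost."""
  assert 0.0 < scale_factor <= 1.0
  if quality is None or math.isnan(quality):
    return None
  score = quality
  if base_accounted_params > 0:
    score *= scale_factor ** (accounted_params / base_accounted_params)
  if base_flops > 0:
    score *= scale_factor ** (flops / base_flops)
  return score

# With scale_factor=0.5, a path accounting for exactly base_accounted_params
# keeps half of its quality; twice as many params keeps a quarter.
print(decayed_score(0.8, 0.5, accounted_params=1e6, base_accounted_params=1e6))
print(decayed_score(0.8, 0.5, accounted_params=2e6, base_accounted_params=1e6))
```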

In [None]:
class Population():
  def __init__(self, config):
    Path.reset_globals(config)
    Component.reset_globals()
    self.paths = defaultdict(list)
    self.config = config
    self.paths_df = pd.DataFrame()
    self.comps_df = pd.DataFrame()

  def get_best_path(self):
    if len(self.paths[self.config.agent_id]) == 0:
      return None
    # Oldest path achieving max score.
    return max(sorted(self.paths[self.config.agent_id], key=lambda p: p.id, reverse=False), key=lambda p: p.score())

  def prune_population(self):
    if self.config.get("max_task_population_size", None) and (
        len(self.paths[self.config.agent_id]) > self.config.max_task_population_size):
      self.paths[self.config.agent_id] = sorted(
          self.paths[self.config.agent_id], key=lambda p: p.score(), reverse=True
          )[:self.config.max_task_population_size]

  def add_train_locks(self):
    # Check.
    for ps in self.paths.values():
      for p in ps:
        for c in p.components:
          assert self.config.agent_id not in c.train_locks
    # Add locks.
    paths = self.paths[self.config.agent_id]
    for p in paths:
      for c in p.components:
        c.train_locks.add(self.config.agent_id)

  def rm_train_locks(self):
    # Remove locks.
    paths = self.paths[self.config.agent_id]
    for p in paths:
      for c in p.components:
        if self.config.agent_id in c.train_locks:
          c.train_locks.remove(self.config.agent_id)
    # Check.
    for ps in self.paths.values():
      for p in ps:
        for c in p.components:
          assert self.config.agent_id not in c.train_locks

  def start_cycle(self):
    self.rm_train_locks()

  def end_cycle(self):
    # Keep only best one.
    best_path = self.get_best_path()
    assert best_path is not None
    best_path.metrics["num_cycles"] = best_path.metrics.get("num_cycles", 0) + 1
    self.paths[self.config.agent_id] = [best_path]
    self.add_train_locks()
    self.garbage_collect_paths()

  def garbage_collect_paths(self):
    # Store history before dropping references to unused paths to trigger
    # garbage collection of components and parameters.
    self.paths_df = pd.concat(
        [self.paths_df, paths_to_df(Path.paths)], ignore_index=True
        ).query(f'agent_id=="{self.config.agent_id}" and id>{Path.last_saved}'
        # Drop duplicates generated by reloads. Notice that some state-based
        # metrics may vary (e.g. accounted parameters), so we match only the id
        # (agent_id is already matched by the preceding query).
        ).drop_duplicates("id")
    self.comps_df = pd.concat(
        [self.comps_df, components_to_df(Path.paths)], ignore_index=True
        ).query(f'agent_id=="{self.config.agent_id}" and id>{Component.last_saved}'
        ).drop_duplicates()
    # Drop unused paths generated in this agent cycle for garbage collection.
    Path.paths = []
    # Simplify ancestor tree to contain only live paths.
    # Notice that the simplification is done also for paths of other tasks,
    # since they may be pointing to a path of this task that was discarded.
    live_paths_ids = [p.full_id for paths in self.paths.values() for p in paths]
    for path in [path for paths in self.paths.values() for path in paths]:
      ancestor = path.parent
      if ancestor is None:
        continue
      while True:
        if ancestor.full_id in live_paths_ids:
          path.parent = ancestor
          break
        ancestor = ancestor.parent

  def get_path_from_full_id(self, agent_id, path_id):
    for p in self.paths[agent_id]:
      if p.id == path_id:
        return p
    assert False, f"Path not found {agent_id}:{path_id}"
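`end_cycle` keeps only the best path of the current agent, and `garbage_collect_paths` then re-points every live path's `parent` to its nearest live ancestor, so that no reference to a discarded path survives and its parameters can be garbage collected. A minimal sketch of that re-linking, assuming a hypothetical `Node` class that has only `full_id` and `parent`; unlike the notebook code, which assumes a live ancestor always exists, this version stops at `None`:

```python
class Node:
  """Hypothetical stand-in for Path: only a full id and a parent link."""
  def __init__(self, full_id, parent=None):
    self.full_id = full_id
    self.parent = parent

def relink_to_live_ancestors(live_nodes):
  """Point each live node's parent at its nearest live ancestor."""
  live_ids = {n.full_id for n in live_nodes}
  for node in live_nodes:
    ancestor = node.parent
    # Walk up the tree, skipping discarded ancestors.
    while ancestor is not None and ancestor.full_id not in live_ids:
      ancestor = ancestor.parent
    node.parent = ancestor

root = Node("root:0")
dropped = Node("a:1", parent=root)  # Discarded at the end of the cycle.
kept = Node("a:2", parent=dropped)
relink_to_live_ancestors([root, kept])
print(kept.parent.full_id)  # root:0
```

After re-linking, nothing live refers to `dropped`, so it becomes unreachable once the global `Path.paths` list is cleared.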

In [None]:
pd.set_option("display.expand_frame_repr", False)
pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 100)

def pop_to_df(pop):
  return paths_to_df([p for paths in pop.paths.values() for p in paths])

def paths_to_df(paths):
  # Collect all metrics names.
  metrics_keys = set()
  hparams_keys = set()
  for path in paths:
    path.score()  # Update scores.
    metrics_keys.update(path.metrics)
    hparams_keys.update(path.hparams)

  def _format(x):
    if type(x) in [dict, list]:
      return json.dumps(x)
    return x

  data = defaultdict(list)
  for path in paths:
    data["agent_id"].append(path.agent_id)
    data["task_name"].append(path.task_name)
    data["id"].append(path.id)
    data["parent_id"].append(path.parent.id if path.parent else -1)
    data["parent_agent_id"].append(path.parent.agent_id if path.parent else None)
    data["components"].append(",".join(
        [f"{c.agent_id}:{c.id}" for c in path.comps_only()]))
    for k in hparams_keys:
      data[f"hparams.{k}"].append(_format(path.hparams[k]) if k in path.hparams else None)
    for k in metrics_keys:
      data[f"metrics.{k}"].append(path.metrics[k] if k in path.metrics else None)
  return pd.DataFrame(data)

def components_to_df(paths):
  # Collect all components.
  comps = set()
  for p in paths:
    comps.update(p.comps_only())

  data = defaultdict(list)
  for c in comps:
    data["id"].append(c.id)
    data["name"].append(c.name)
    data["agent_id"].append(c.agent_id)
    data["num_params"].append(c.get_num_params())
  return pd.DataFrame(data)

def print_df_segments(df, segment_length:int = 5):
  tot_length = df.shape[0]
  # Pad column title with spaces to keep alignment across segments.
  def prepend_spaces(original_str, pad_to_len):
    return " " * (pad_to_len-len(original_str)) + original_str
  pad_to_len = max([len(tn) for tn in set(df["agent_id"].to_list())])+1
  df = df.rename(columns={
    "agent_id": prepend_spaces("agent_id", pad_to_len),
    "task_name": prepend_spaces("task_name", pad_to_len),
    "parent_agent_id": prepend_spaces("parent_agent_id", pad_to_len),
    })
  for x in range(0, tot_length, segment_length):
    print(df[x:min(x+segment_length, tot_length)])

def df_leaderboard(df):
  # Place columns on the left for readability.
  all_keys = sorted(df.columns.tolist())
  first_keys = ["agent_id", "task_name", "metrics.test_quality", "metrics.score",
                "metrics.quality", "metrics.accounted_params", "metrics.flops",
                "id", "parent_id", "parent_agent_id"]
  first_keys = [k for k in first_keys if k in all_keys]
  sorted_keys = first_keys + [k for k in all_keys if k not in first_keys]
  # Filter mu function parameters.
  sorted_keys = [k for k in sorted_keys if "_mu_|" not in k]
  df = df[sorted_keys]
  if "metrics.score" in df:
    df = df.sort_values(["agent_id", "metrics.score"], ascending=[True, False], ignore_index=True)
  else:
    df = df.sort_values("agent_id", ignore_index=True)
  print_df_segments(df)
  for k in ["metrics.score", "metrics.quality", "metrics.test_quality"]:
    if k in df:
      print(f"Avg {k}: {df[k].mean():.6f}")
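`paths_to_df` stores hyperparameters and metrics as prefixed columns (`hparams.<key>`, `metrics.<key>`), JSON-encoding dict and list hyperparameters, and `load_paths` reverses the mapping, skipping NaN cells (how pandas represents missing values after a CSV reload) and decoding strings that start with `{` or `[`. The sketch below shows that round trip on plain dicts instead of DataFrame rows; `flatten_record` and `unflatten_record` are hypothetical helper names:

```python
import json
import math

def flatten_record(hparams, metrics):
  """Flatten hparams/metrics into prefixed columns, JSON-encoding dicts/lists."""
  row = {}
  for k, v in hparams.items():
    row[f"hparams.{k}"] = json.dumps(v) if isinstance(v, (dict, list)) else v
  for k, v in metrics.items():
    row[f"metrics.{k}"] = v
  return row

def unflatten_record(row):
  """Inverse of flatten_record, treating NaN cells as missing values."""
  hparams, metrics = {}, {}
  for k, v in row.items():
    if isinstance(v, float) and math.isnan(v):
      continue
    if k.startswith("hparams."):
      if isinstance(v, str) and v.startswith(("{", "[")):
        v = json.loads(v)
      hparams[k[len("hparams."):]] = v
    elif k.startswith("metrics."):
      metrics[k[len("metrics."):]] = v
  return hparams, metrics

row = flatten_record({"lr": 0.01, "paths": {"x": {"id": 3}}}, {"quality": 0.9})
row["metrics.test_quality"] = float("nan")  # A column missing for this path.
print(unflatten_record(row))
# ({'lr': 0.01, 'paths': {'x': {'id': 3}}}, {'quality': 0.9})
```

One consequence of this convention is that a plain string hyperparameter starting with `{` or `[` would be decoded as JSON on reload.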

# Checkpointing

In [None]:
def df_write_to_csv(df, dir_path, df_name):
  filename_df = os.path.join(dir_path, f"{df_name}.csv")
  with gfile.GFile(filename_df, "w") as outfile:
    df.to_csv(outfile, index=False)

def df_read_from_csv(dir_path, df_name):
  filename_df = os.path.join(dir_path, f"{df_name}.csv")
  with gfile.GFile(filename_df, "r") as infile:
    df = pd.read_csv(infile)
  # Pandas read_csv() reads empty strings as NaNs. Set NaNs to empty strings in
  # columns with type string/object.
  for c in df.columns:
    if df[c].dtype == np.object_:
      df[c] = df[c].fillna("")
  return df

def get_comps_params_to_save(pop):
  comps_params = {}
  # All components generated by this agent.
  all_comps = set(
      [c for p in pop.paths[pop.config.agent_id] for c in p.comps_only() if c.agent_id == pop.config.agent_id])
  # Check that there are no duplicate ids.
  assert len(all_comps) == len(set([c.id for c in all_comps])), (
      [f"{c.name}:{c.agent_id}:{c.id}" for c in all_comps])
  for c in all_comps:
    if c.id <= Component.last_saved:
      continue
    assert c.agent_id == pop.config.agent_id
    c_id_string = f"{c.name}:{c.agent_id}:{c.id}"
    comps_params[c_id_string] = c.params
  return comps_params
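Components are checkpointed under `name:agent_id:id` keys, and only components with `id > Component.last_saved` are written, so each checkpoint holds just the components created since the last completed cycle. Because `load_paths` splits these keys on `:` and asserts exactly three fields, neither component names nor agent ids may contain `:`. A small sketch of the key format and the incremental selection; the helper names and the example name/agent id are hypothetical:

```python
def comp_key(name, agent_id, comp_id):
  """Checkpoint key for a component, in the name:agent_id:id format."""
  key = f"{name}:{agent_id}:{comp_id}"
  # load_paths splits on ":" and expects exactly three fields.
  assert len(key.split(":")) == 3, f"':' not allowed in name/agent_id: {key}"
  return key

def parse_comp_key(key):
  name, agent_id, comp_id = key.split(":")
  return name, agent_id, int(comp_id)

def select_unsaved(comp_ids, last_saved):
  """Incremental checkpointing: keep only ids newer than the last full save."""
  return sorted(i for i in comp_ids if i > last_saved)

key = comp_key("encoderblock_0", "agent_x", 17)  # Hypothetical values.
print(key)                                       # encoderblock_0:agent_x:17
print(parse_comp_key(key))                       # ('encoderblock_0', 'agent_x', 17)
print(select_unsaved([3, 12, 17, 20], last_saved=12))  # [17, 20]
```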

In [None]:
def latest_checkpoint(ckpt_dir, prefix = "checkpoint_"):
  ckpt_dir = os.fspath(ckpt_dir)
  glob_path = os.path.join(ckpt_dir, f"{prefix}*")
  checkpoint_files = flax_checkpoints.natural_sort(gfile.glob(glob_path))
  checkpoint_files = [f for f in checkpoint_files if not f.endswith("_tmp")]
  return checkpoint_files[-1] if checkpoint_files else None
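`latest_checkpoint` sorts with `flax_checkpoints.natural_sort` because plain lexicographic order places `checkpoint_0_10` before `checkpoint_0_9`, and it ignores files ending in `_tmp`, which the code treats as incomplete. The sketch below illustrates the same selection with a simple hand-written natural-sort key; it is a stand-in for illustration, not flax's implementation:

```python
import re

def natural_key(s):
  """Split digit runs from text so numeric parts compare as integers."""
  return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", s)]

def latest(files, tmp_suffix="_tmp"):
  """Pick the last complete checkpoint, ignoring in-progress temp files."""
  done = sorted((f for f in files if not f.endswith(tmp_suffix)), key=natural_key)
  return done[-1] if done else None

files = ["checkpoint_0_9", "checkpoint_0_10", "checkpoint_0_11_tmp"]
print(sorted(files)[-1])  # checkpoint_0_9 (plain lexicographic order)
print(latest(files))      # checkpoint_0_10
```

Since `re.split` with a capturing group always alternates text and digit runs, the key lists compare strings with strings and integers with integers.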

In [None]:
def save_checkpoint(ckpt_dir, comps_params, cycle_id, generation_id):
  print("SAVING", cycle_id, generation_id, comps_params.keys())
  # Write checkpoint.
  flax_checkpoints.save_checkpoint(
      ckpt_dir=ckpt_dir,
      target=comps_params,
      step=generation_id,
      prefix=f"checkpoint_{cycle_id}_",
      overwrite=True)
  # Delete intermediate checkpoint directories.
  if generation_id == 0:
    intermediate_ckpt_dirs = gfile.glob(
        os.path.join(os.path.dirname(ckpt_dir), "state_*_[^0]*"))
    for d in intermediate_ckpt_dirs:
      print("Deleting intermediate checkpoint:", d)
      gfile.rmtree(d)

def save_state(agent):
  pop = agent.pop
  cycle_id = agent.cycle_id
  generation_id = agent.generation_id
  config = agent.config
  write_start = time.time()
  # Save the data needed to resume the experiment.
  pop.garbage_collect_paths()
  state_dir = os.path.join(config.agent_dir, f"state_{cycle_id}_{generation_id}")
  gfile.makedirs(state_dir)
  assert not latest_checkpoint(state_dir), f"Checkpoint already present in folder: {state_dir}"
  print("WRITING CHECKPOINT:", cycle_id, generation_id)
  df_write_to_csv(paths_to_df(agent.get_paths_to_publish()), state_dir, "published")
  df_write_to_csv(paths_to_df([p for paths in pop.paths.values() for p in paths]), state_dir, "population")
  df_write_to_csv(pop.paths_df, state_dir, "paths")
  df_write_to_csv(pop.comps_df, state_dir, "components")
  with gfile.GFile(os.path.join(state_dir, "config.json"), "w") as outfile:
    json.dump(config.as_configdict().to_dict(), outfile, indent=2)
  save_checkpoint(state_dir, get_comps_params_to_save(pop), cycle_id, generation_id)
  # Update last saved.
  if generation_id == 0:
    Path.last_saved = pop.paths_df.id.max()
    Component.last_saved = pop.comps_df.id.max()
  print(f"STATE WRITE TIME: {time.time() - write_start:.2f} s")

In [None]:
def load_paths(pop, state_dir, all_agents_dirs):
  if state_dir:
    state_dir = state_dir.rstrip("/")
  load_start = time.time()

  # Load system state info.
  population_df = pd.DataFrame()
  skip_agent_dir = None
  if state_dir:
    # Load agent state, possibly intermediate.
    population_df = pd.concat([population_df, df_read_from_csv(state_dir, "published")])
    skip_agent_dir = os.path.dirname(state_dir)
  for agent_dir in all_agents_dirs:
    if agent_dir == skip_agent_dir:
      continue
    agent_checkpoint = latest_checkpoint(os.path.join(agent_dir, "state_*_0/"))
    if agent_checkpoint:
      population_df = pd.concat([
          population_df,
          df_read_from_csv(os.path.dirname(agent_checkpoint), "published")])

  # Load parameters from sharded system checkpoint.
  loaded_params = {}  # Dictionary to accumulate loaded parameters.
  lock = Lock()
  duplicate_keys = set()
  def append_loaded_params(add_chkp_dir: str):
    if latest_checkpoint(add_chkp_dir) is None:
      return  # Skip folders without a completed checkpoint.
    lp_add = flax_checkpoints.restore_checkpoint(
        ckpt_dir=add_chkp_dir,
        target=None)
    if lp_add:
      with lock:
        print("LOADED COMPONENTS", add_chkp_dir, lp_add.keys())
        duplicate_keys.update(loaded_params.keys() & lp_add.keys())
        loaded_params.update(lp_add)
  all_state_dirs = []
  if state_dir:
    # Include active agent state, possibly intermediate.
    all_state_dirs.append(state_dir)
    all_state_dirs.extend(gfile.glob(os.path.join(os.path.dirname(state_dir), "state_*_0")))
  for agent_dir in all_agents_dirs:
    all_state_dirs.extend(gfile.glob(os.path.join(agent_dir, "state_*_0")))
  threads = []
  for s_dir in set(all_state_dirs):
    threads.append(Thread(target=append_loaded_params, args=(s_dir,)))
    threads[-1].start()
  for t in threads:
    t.join()
  assert not duplicate_keys, duplicate_keys
  print(f"LOAD TIME: {time.time() - load_start:.2f} s")
  frozen_params = flax.core.freeze(loaded_params)
  sid_2_comp = {}
  for k in frozen_params.keys():
    assert len(k.split(":")) == 3, k
    name, agent_id, id = k.split(":")
    c = Component(
        name=name, agent_id=agent_id, params=frozen_params[k], train_locks=[])
    c.id = int(id)
    source_id = f"{agent_id}:{id}"
    assert source_id not in sid_2_comp, source_id
    sid_2_comp[source_id] = c
  # For parent assignment.
  sid_2_path = {}
  path_2_parent_sid = {}
  for index, row in population_df.iterrows():
    agent_id = row["agent_id"]
    path_id = int(row["id"])
    path_sid = f"{agent_id}:{path_id}"
    if path_sid in sid_2_path:
      continue
    comps_sids = row["components"].split(",")
    comps = []
    for sid in comps_sids:
      comps.append(sid_2_comp[sid])
    task_name = row["task_name"]
    # Retrieve hparams and metrics.
    hparams = {}
    metrics = {}
    for k in row.keys():
      v = row[k]
      if type(v) is float and math.isnan(v):
        continue
      if k.startswith("hparams."):
        if type(v) == str and (v.startswith("{") or v.startswith("[")):
          v = json.loads(v)
        hparams[k[len("hparams."):]] = v
      if k.startswith("metrics."):
        metrics[k[len("metrics."):]] = v
    # Create path.
    path = Path(hparams, comps, parent=None, agent_id=agent_id, task_name=task_name)
    path.metrics = metrics
    path.id = path_id
    # Add train locks.
    for c in path.components:
      c.train_locks.add(agent_id)
    pop.paths[agent_id].append(path)
    sid_2_path[path_sid] = path
    if row["parent_id"] >= 0:
      parent_sid = f'{row["parent_agent_id"]}:{row["parent_id"]}'
      path_2_parent_sid[path] = parent_sid
  # Set parents.
  for path, parent_sid in path_2_parent_sid.items():
    if parent_sid not in sid_2_path:
      # This can happen if parent is retired by a parallel agent.
      # In this case fall back to root model.
      for k in sid_2_path.keys():
        if "root_model" in k:
          parent_sid = k
      print(f"{path.agent_id}:{path.id} orphaned, fallback: {parent_sid}")
    path.parent = sid_2_path[parent_sid]
  # Set reference to components representing sub paths.
  for path in [p for paths in pop.paths.values() for p in paths]:
    if "paths" in path.hparams:
      for k in path.hparams["paths"]:
        sub_path = pop.get_path_from_full_id(path.hparams["paths"][k]["agent_id"], path.hparams["paths"][k]["id"])
        path.components.append(ComponentPath(name=k, path=sub_path))
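`load_paths` restores every completed `state_*_0` folder in its own thread, merges the results into a single dict under a `Lock`, and records keys that appear in more than one checkpoint so that the final assertion can reject them. The sketch below shows the same pattern with a generic loader; `load_all` and `load_fn` are hypothetical names, and an in-memory dict stands in for `flax_checkpoints.restore_checkpoint`:

```python
import threading

def load_all(dirs, load_fn):
  """Load dicts from several sources in parallel and merge them.

  Keys found in more than one source are reported instead of being
  silently overwritten.
  """
  merged = {}
  duplicates = set()
  lock = threading.Lock()

  def worker(d):
    part = load_fn(d)  # The slow I/O happens outside the lock.
    if not part:
      return
    with lock:
      duplicates.update(merged.keys() & part.keys())
      merged.update(part)

  # De-duplicate source folders, as load_paths does with set(all_state_dirs).
  threads = [threading.Thread(target=worker, args=(d,)) for d in set(dirs)]
  for t in threads:
    t.start()
  for t in threads:
    t.join()
  return merged, duplicates

# Hypothetical in-memory "checkpoints" standing in for state_*_0 folders.
fake = {"s0": {"a:x:1": 1}, "s1": {"b:x:2": 2}, "s2": {}}
merged, dups = load_all(["s0", "s1", "s2", "s0"], fake.get)
print(sorted(merged), dups)  # ['a:x:1', 'b:x:2'] set()
```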

# Training

In [None]:
@partial(jax.jit, static_argnames="model")
def eval_step(params, inputs, labels, model):
  logits = model.apply({"params": params}, inputs, train=False)
  # Avg accuracy on the batch.
  return (logits.argmax(axis=-1) == labels).mean()

In [None]:
@partial(jax.jit, static_argnames=["model", "optimizer", "format_params_fn"], donate_argnums=[0, 2])
def train_step(params, fixed_params, opt_state, inputs, labels, model, optimizer, format_params_fn):
  def loss_fn(params, fixed_params, inputs, labels):
    logits = model.apply(
        {"params": format_params_fn(merge_params(params, fixed_params))},
        inputs, train=True)
    labels = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.mean(jnp.sum(labels * nn.log_softmax(logits), axis=-1))
  grads = jax.grad(loss_fn)(params, fixed_params, inputs, labels)
  updates, opt_state = optimizer.update(grads, opt_state, params=params)
  params = optax.apply_updates(params, updates)
  return params, opt_state
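The key design in `train_step` is that only `params` (the first argument) is differentiated, while `fixed_params` is combined with it inside the loss and never receives gradients. A minimal sketch of the same split with plain JAX, assuming a toy linear model and hand-written SGD in place of the notebook's Flax model, parameter merging and optax optimizer (all names below are hypothetical):

```python
import jax
import jax.numpy as jnp

LR = 0.01  # Hypothetical learning rate for plain SGD.

def loss_fn(trainable, fixed, x, y):
  # Toy model: a frozen projection fixed["w"] followed by a trainable
  # linear head. Both parameter sets are used in the forward pass.
  feats = x @ fixed["w"]
  logits = feats @ trainable["w"] + trainable["b"]
  labels = jax.nn.one_hot(y, logits.shape[-1])
  # Softmax cross-entropy, as in train_step.
  return -jnp.mean(jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))

@jax.jit
def sgd_step(trainable, fixed, x, y):
  # jax.grad differentiates only w.r.t. its first argument by default, so the
  # frozen parameters get no gradient and are never updated.
  grads = jax.grad(loss_fn)(trainable, fixed, x, y)
  return jax.tree_util.tree_map(lambda p, g: p - LR * g, trainable, grads)

key_f, key_x = jax.random.split(jax.random.PRNGKey(0))
fixed = {"w": jax.random.normal(key_f, (4, 8))}
trainable = {"w": jnp.zeros((8, 3)), "b": jnp.zeros((3,))}
x = jax.random.normal(key_x, (16, 4))
y = jnp.arange(16) % 3

before = loss_fn(trainable, fixed, x, y)  # log(3) for an all-zero head.
for _ in range(20):
  trainable = sgd_step(trainable, fixed, x, y)
after = loss_fn(trainable, fixed, x, y)
print(float(before), float(after))
```

Passing the frozen parameters as a separate argument, rather than masking gradients, also lets `train_step` donate only the trainable parameters and optimizer state buffers (`donate_argnums=[0, 2]`).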

In [None]:
def execute_train_step(path, train_batch):
  path.params_device, path.opt_state_device = train_step(
      path.params_device,
      path.fixed_params_device,
      path.opt_state_device,
      train_batch["image"],
      train_batch["label"],
      path.model,
      path.optimizer,
      path.model_factory.get_comps2model_fn())

def execute_eval_step(path, eval_batch):
  path.accs.append(
      eval_step(
          path.model_factory.get_comps2model_fn()(merge_params(
              path.params_device, path.fixed_params_device)),
          eval_batch["image"],
          eval_batch["label"],
          path.model))

In [None]:
PREV_LOOP_END = time.time()

def train_loop(paths, ds_train, ds_validation, devices, config):
  global PREV_LOOP_END
  timing = {}
  task = paths[0].task
  for path in paths:
    assert task.name == path.task_name
  for p_id, path in enumerate(paths):
    path.device_id = p_id % len(devices)
    path.device = devices[path.device_id]
    path.optimizer = path.get_optimizer()
    path.best_params_local = None
    path.best_quality = None
    path.best_score = path.parent.score() if path.agent_id == path.parent.agent_id else -np.inf
    path.evals = []
    path.exe_thread = None
  gc.collect()
  # Transfer parameters to devices.
  for path in paths:
    path.params_device = jax.device_put(path.get_trainable_params(), path.device)
    path.fixed_params_device = jax.device_put(path.get_fixed_params(), path.device)
    path.opt_state_device = jax.jit(path.optimizer.init, device=path.device)(path.params_device)
  iter_ds_validation = iter(ds_validation)
  # Train loop.
  for t_step, train_batch in zip(
      range(config.num_validations_per_path_training
            * task.num_train_batches_between_validations),
      ds_train):
    if t_step == 0:
      timing["start_train"] = time.time()
    for p_id, path in enumerate(paths):
      train_batch = jax.device_put(train_batch, path.device)
      if path.exe_thread is not None:
        path.exe_thread.join()
      path.exe_thread = Thread(target=execute_train_step, args=(path, train_batch))
      path.exe_thread.start()
    if t_step == 0:
      [p.exe_thread.join() for p in paths]
      timing["end_train_compile"] = time.time()
    # Evaluation on validation set.
    if (t_step+1) % task.num_train_batches_between_validations == 0:
      for path in paths:
        path.accs = []
      for e_step, eval_batch in zip(range(task.num_validation_batches), iter_ds_validation):
        if e_step == 0:
          start_eval_round = time.time()
          if "start_eval" not in timing:
            timing["start_eval"] = start_eval_round
        for p_id, path in enumerate(paths):
          eval_batch = jax.device_put(eval_batch, path.device)
          path.exe_thread.join()
          path.exe_thread = Thread(target=execute_eval_step, args=(path,eval_batch))
          path.exe_thread.start()
        if e_step == 0 and "end_eval_compile" not in timing:
          [p.exe_thread.join() for p in paths]
          timing["end_eval_compile"] = time.time()

      # Get params of best models.
      qs = []
      eval_idx = (t_step+1) // task.num_train_batches_between_validations
      for path in paths:
        path.exe_thread.join()
        quality = np.mean(path.accs)
        del path.accs
        qs.append(f"{quality:.4f}")
        path.evals.append(quality)
        # Set quality in metrics for current score computation.
        path.metrics["quality"] = quality
        path_score = path.score()
        if path_score > path.best_score:
          path.best_params_local = jax.device_get(path.params_device)
          path.best_score = path_score
          path.best_quality = quality
          qs[-1] += "*"
      time_train = time.time() - PREV_LOOP_END
      avg_path_time = (time_train / eval_idx) / len(paths)
      print(("\t".join(qs) + f"\t< Eval {eval_idx}").expandtabs(8),
            f"tot:{time_train:.1f}s", f"avg/path:{avg_path_time:.1f}s")
      timing["time_eval"] = timing.get("time_eval", 0) + (time.time() - start_eval_round)
      del eval_batch
  del train_batch
  for path in paths:
    del path.params_device
    del path.fixed_params_device
    del path.opt_state_device
    del path.optimizer
    del path.exe_thread
  gc.collect()

  timing["end_train"] = time.time()

  time_init = timing["start_train"] - PREV_LOOP_END
  time_train_compile = timing["end_train_compile"] - timing["start_train"]
  time_eval_compile = timing["end_eval_compile"] - timing["start_eval"]
  time_eval = timing["time_eval"] - time_eval_compile
  time_train = timing["end_train"] - timing["end_train_compile"] - time_eval - time_eval_compile
  PREV_LOOP_END = timing["end_train"]

  for path in paths:
    path.metrics["time_init"] = time_init
    path.metrics["time_train_compile"] = time_train_compile
    path.metrics["time_eval_compile"] = time_eval_compile
    path.metrics["time_train"] = time_train
    path.metrics["time_eval"] = time_eval
    path.metrics["timestamp_end"] = PREV_LOOP_END
    path.metrics["num_params"] = get_num_params(path.get_all_params())
    path.metrics["num_trainable_params"] = get_num_params(path.get_trainable_params())
    path.metrics["quality"] = max(path.evals)
    path.metrics["evals"] = json.dumps([float(v) for v in path.evals])

    if path.best_params_local is not None:
      path.metrics["improved"] = True
      path.update_trainable(path.best_params_local)
      assert path.best_quality == path.metrics["quality"]
      assert path.best_score == path.score()
    else:
      path.metrics["improved"] = False
      # Sampled path will be dropped if not improved, so skip parameter update.
      assert path.best_quality is None

    del path.best_params_local
    del path.best_score
    del path.best_quality
    del path.evals

  pqs = []
  qs = []
  psc = []
  sc = []
  for path in paths:
    if path.task_name == path.parent.task_name:
      metric_suffix = "" if path.agent_id == path.parent.agent_id else "A"
      pqs.append(f"{path.parent.metrics['quality']:.4f}{metric_suffix}")
      psc.append(f"{path.parent.score():.4f}{metric_suffix}")
    else:
      pqs.append("NEW")
      psc.append("NEW")
    qs.append(f"{path.metrics['quality']:.4f}")
    sc.append(f"{path.score():.4f}")
    if path.metrics["improved"]:
      sc[-1] += "+"

  print(("\t".join([f"{path.parent.id}" for path in paths]) + "\t< Parent id").expandtabs(8))
  print(("\t".join([f"{path.id}" for path in paths]) + "\t< Path id").expandtabs(8))
  print(("\t".join(pqs) + "\t< Parent best quality").expandtabs(8))
  print(("\t".join(qs) + "\t< Path best quality").expandtabs(8))
  print(("\t".join(psc) + "\t< Parent score").expandtabs(8))
  print(("\t".join(sc) + "\t< Path score").expandtabs(8))
  print("time\tINIT\tCOMPtrn\tCOMPevl\tTRN\tEVAL".expandtabs(8))
  print(f"(s)\t{time_init:.1f}\t{time_train_compile:.1f}\t{time_eval_compile:.1f}\t{time_train:.1f}\t{time_eval:.1f}".expandtabs(8))
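In `train_loop`, a path snapshots its parameters only when the score at a validation round strictly beats the best score so far, which starts at the parent's score when the parent belongs to the same agent and at `-inf` otherwise. If no round beats it, the path is marked `improved=False` and its parameters are not updated. A minimal sketch of that selection rule on a list of per-round scores; `select_checkpoint` is a hypothetical helper, and `parent_score=None` models a parent from a different agent:

```python
import math

def select_checkpoint(eval_scores, parent_score=None):
  """Return (best_eval_index, best_score), or (None, None) if no improvement."""
  best_score = -math.inf if parent_score is None else parent_score
  best_idx = None
  for i, score in enumerate(eval_scores):
    # Strict improvement only: ties with the parent do not count.
    if score > best_score:
      best_idx, best_score = i, score
  return (best_idx, best_score) if best_idx is not None else (None, None)

print(select_checkpoint([0.70, 0.74, 0.73], parent_score=0.72))  # (1, 0.74)
print(select_checkpoint([0.70, 0.71], parent_score=0.72))        # (None, None)
print(select_checkpoint([0.10, 0.05]))                           # (0, 0.1)
```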

In [None]:
def has_test_quality(path):
  return ("test_quality" in path.metrics and not math.isnan(path.metrics["test_quality"]))

# Run final eval on test set.
def run_test_eval(path, test_immutability=False):
  # Skip if test_quality already computed and no immutability test required.
  if not test_immutability and has_test_quality(path):
    return
  eval_st = time.time()
  ds_test = path.task.get_ds("test", path.hparams)
  params = path.get_all_params()
  # Running on the same device allows reusing the function compiled for validation.
  if not hasattr(path, "device"):
    path.device = random.choice(jax.local_devices())
  params_device = jax.device_put(path.model_factory.get_comps2model_fn()(params), path.device)
  acc_sum = []
  tot_num_samples = 0
  # Warning: if repeat() is called on this dataset, then this loop never ends.
  for batch in ds_test:
    acc_avg = eval_step(params_device, batch["image"], batch["label"], path.model)
    batch_size = batch["label"].shape[0]
    # Weight each batch average by its size, since the last batch can be
    # smaller; this gives the exact accuracy over the whole test set.
    acc_sum.append(acc_avg * batch_size)
    tot_num_samples += batch_size
  del params_device
  acc_avg = np.sum(acc_sum) / tot_num_samples
  # If a test quality was already stored, check that it is reproduced (immutability test).
  if has_test_quality(path):
    print(f"Testing immutability of path {path.id} : {path.metrics['test_quality']} ~= {acc_avg}")
    assert test_immutability
    if not np.isclose(path.metrics["test_quality"], acc_avg, rtol=IMMUTABILITY_RELATIVE_TOLLERANCE):
      print("WARNING IMMUTABILITY TEST FAILED, delta:", acc_avg-path.metrics["test_quality"])
    assert np.isclose(path.metrics["test_quality"], acc_avg), \
        f"{path.task_name} {path.metrics['test_quality']} {acc_avg}"
  path.metrics["test_quality"] = acc_avg
  print(f"TEST QUALITY: {acc_avg}\nTEST TIME: {time.time()-eval_st:.2f}s")
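`run_test_eval` multiplies each batch's mean accuracy by its batch size before summing, because the last test batch can be smaller than the others; averaging the batch means directly would give that partial batch as much weight as a full one. A short sketch with made-up numbers (`exact_accuracy` is a hypothetical helper):

```python
def exact_accuracy(batch_means, batch_sizes):
  """Combine per-batch mean accuracies into the accuracy over all samples."""
  total = sum(m * n for m, n in zip(batch_means, batch_sizes))
  return total / sum(batch_sizes)

# Three full batches of 4 samples and a final partial batch of 1 sample.
means = [1.0, 1.0, 1.0, 0.0]
sizes = [4, 4, 4, 1]
print(exact_accuracy(means, sizes))  # 12/13, about 0.923
print(sum(means) / len(means))       # 0.75, biased by the small last batch
```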

# Main

In [None]:
agent = get_agent_class(AGENT)(system_state_dir=SYSTEM_STATE_DIR, task_name=TASK_NAME, num_cycles_max=NUM_CYCLES_MAX)
agent.run()