# License
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at:

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Instructions

This notebook reproduces all the experiments reported in the publication titled:

["*muNet: Evolving Pretrained Deep Neural Network into Scalable Auto-tuning Multitask Systems*" (2022)](https://arxiv.org/abs/2205.10937)

---

Set `EXPERIMENT_NAME` to a name of choice.

Set `BENCHMARK` to:

1. `ViT tiny 3 layers / characters benchmark` to reproduce the experiments on the "Multitask Character Classification Benchmark".
1. `ViT base / decathlon benchmark` to reproduce the experiments on the "Visual Domain Decathlon Benchmark".

Set `CONFIGURATION` to:
1. `muNet` to run the muNet evolutionary method with scale factor = 1.
1. `Size scale:X` to run muNet with scale factor = X/100.
1. `Finetune all` to run the corresponding full fine-tuning baseline model.
1. `Freeze bottom layers:X` to run fine-tuning baseline with X layers shared and frozen.
1. `Adapters:X` to run the corresponding residual adapters baseline with X inner dimension.

Select `AUTO_TUNE` to activate auto-tuning for muNet experiments.

Set `EXPERIMENTS_ROOT_DIR` to the desired root directory that will contain experiment directories storing configuration and state.

To reproduce the configuration of the experiments reported in the paper it is required to connect to a TPUv3 machine with 8 cores.

To start the configured experiment select "Run all" from the "Runtime" menu.

The output is printed after the last cell.

Note: this Colab connects by default to a free TPUv2 machine with 8 cores,
while the experiments reported in the paper are executed on a TPUv3 machine.
Thus, the system requirement for larger models (e.g. ViT B/16)
and datasets (e.g. visual_domain_decathlon/imagenet12)
may exceed the capacity of the default instance and may require a custom GCE VM.

In [None]:
# @title Experiment configuration
EXPERIMENT_NAME = 'Experiment'  # @param { type: 'string', isTemplate: true }
BENCHMARK = 'ViT tiny 3 layers / characters benchmark' # @param ['ViT tiny 3 layers / characters benchmark', 'ViT base / decathlon benchmark', 'ViT large / ViT benchmark'] { type: 'string', isTemplate: true }
CONFIGURATION = 'muNet'  # @param ['muNet', 'Size scale:98', 'Size scale:95', 'Size scale:90', 'Size scale:70', 'Size scale:30', 'Size scale:2', 'Finetune all', 'Freeze bottom layers:0', 'Freeze bottom layers:1', 'Freeze bottom layers:2', 'Freeze bottom layers:3', 'Freeze bottom layers:4', 'Freeze bottom layers:12', 'Adapters:8', 'Adapters:16', 'Adapters:32', 'Adapters:64', 'Adapters:128', 'Adapters:256', 'Adapters:512']  { type: 'string', isTemplate: true }
AUTO_TUNE = True  # @param [True, False] { type: 'boolean', isTemplate: true }
EXPERIMENTS_ROOT_DIR = '/tmp/' # @param { type: 'string', isTemplate: true }

# Auto-tuning is only accepted together with the muNet configurations
# ('muNet' or 'Size scale:X'); any baseline configuration fails fast here.
if AUTO_TUNE:
  assert CONFIGURATION == 'muNet' or CONFIGURATION.startswith('Size scale:'), \
      f'Invalid configuration for auto-tune: {CONFIGURATION}'

In [None]:
# @title Additional parameters
# Set to True to continue an interrupted experiment with a matching EXPERIMENT_NAME.
AUTO_CONTINUE = False  # @param [True, False] { type: 'boolean', isTemplate: true }
# Set to True to print debug statements.
VERBOSE = False  # @param [True, False] { type: 'boolean', isTemplate: true }

In [None]:
!pip install -q flax

In [None]:
!pip install -q ml_collections

In [None]:
![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer
!pip install -qr vision_transformer/vit_jax/requirements.txt
import sys
if './vision_transformer' not in sys.path:
  sys.path.append('./vision_transformer')

In [None]:
import copy
import datetime
import jax
import jax.numpy as jnp
import json
import math
import matplotlib
import numpy as np
import random
import re
import os
import optax
import pandas as pd
import time
from collections import defaultdict
from functools import partial
from matplotlib import pyplot as plt
from threading import Thread
from typing import Optional

In [None]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [None]:
import flax
import flax.linen as nn
from flax.training import checkpoints as flax_checkpoints

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
tf.compat.v1.enable_eager_execution()

In [None]:
from ml_collections import ConfigDict, FrozenConfigDict
from vision_transformer.vit_jax import input_pipeline
from vision_transformer.vit_jax import checkpoint
from vision_transformer.vit_jax.configs import models as models_config  # Model configurations.
from vision_transformer.vit_jax import models_vit as models # Actual model code.

In [None]:
# TFDS names of the image classification datasets known to this notebook.
# Ref Tfds catalog: https://www.tensorflow.org/datasets/catalog/beans
TFDS_IMAGE_CLASSIFCATON_DATASETS = {
    # Character classification datasets.
    'emnist/digits',
    'emnist/letters',
    'kmnist',
    'mnist',
    'omniglot',
    'cmaterdb/bangla',
    'cmaterdb/devanagari',
    'cmaterdb/telugu',
    # Visual Domain Decathlon datasets.
    'visual_domain_decathlon/imagenet12',
    'visual_domain_decathlon/svhn',
    'visual_domain_decathlon/cifar100',
    'visual_domain_decathlon/gtsrb',
    'visual_domain_decathlon/daimlerpedcls',
    'visual_domain_decathlon/omniglot',
    'visual_domain_decathlon/ucf101',
    'visual_domain_decathlon/aircraft',
    'visual_domain_decathlon/dtd',
    'visual_domain_decathlon/vgg-flowers',
}

In [None]:
# tfds.builder is slow; this builds a cache in the background using parallel threads.
# Call TfdsBuildersCache.regenerate() to force regeneration after editing the list of tasks.
class TfdsBuildersCache():
  """Cache of tfds builders that are created concurrently on background threads.

  The cache is stored in the module-level global TFDS_BUILDERS_CACHE, a dict
  mapping each name in TFDS_IMAGE_CLASSIFCATON_DATASETS to the Worker that
  creates its builder. All methods are used through the class, never through
  an instance.
  """

  class Worker():
    """Creates the builder of one dataset on its own thread."""

    def __init__(self, tfds_name):
      self.tfds_name = tfds_name
      # Both are set before the thread starts so that get_builder() always
      # finds them, even when the build fails.
      self.builder = None
      self.error = None
      self.thread = Thread(
        target=self.set_builder, args=())
      self.thread.start()

    def set_builder(self):
      """Thread target: stores the builder, or the exception raised building it."""
      try:
        self.builder = tfds.builder(self.tfds_name)
      except Exception as e:  # Deferred: re-raised by get_builder().
        self.error = e

    def get_builder(self):
      """Waits for the worker thread and returns the builder.

      Raises:
        Exception: the exception raised by tfds.builder() on the worker
          thread, if any. Without this the failure would only surface as an
          AttributeError on the missing `builder` attribute.
      """
      self.thread.join()
      if self.error is not None:
        raise self.error
      return self.builder

  @staticmethod
  def initalize():
    """Starts one Worker per dataset, unless the global cache already exists."""
    global TFDS_BUILDERS_CACHE
    if 'TFDS_BUILDERS_CACHE' in globals():
      return
    print('CREATING TFDS_BUILDERS_CACHE')
    TFDS_BUILDERS_CACHE = {}
    # Constructing a Worker starts its thread, so all builds run concurrently.
    workers = [
        TfdsBuildersCache.Worker(tfds_name)
        for tfds_name in TFDS_IMAGE_CLASSIFCATON_DATASETS]
    for worker in workers:
      assert worker.tfds_name not in TFDS_BUILDERS_CACHE
      TFDS_BUILDERS_CACHE[worker.tfds_name] = worker

  @staticmethod
  def get(tfds_name):
    """Returns the builder for `tfds_name`, blocking until it is ready.

    Raises:
      KeyError: if `tfds_name` has no entry in the cache.
    """
    return TFDS_BUILDERS_CACHE[tfds_name].get_builder()

  @staticmethod
  def regenerate():
    """Drops the global cache, if present, and rebuilds it from scratch."""
    global TFDS_BUILDERS_CACHE
    if 'TFDS_BUILDERS_CACHE' in globals():
      print('REGENERATING TFDS_BUILDERS_CACHE')
      del TFDS_BUILDERS_CACHE
    TfdsBuildersCache.initalize()

TfdsBuildersCache.initalize()
# TfdsBuildersCache.regenerate()

In [None]:
def get_splits(tfds_name):
  """Returns a dict mapping 'test', 'validation' and 'train' to split specs.

  A 'test' or 'validation' split that the dataset provides is used as is.
  Each missing one is replaced by a 5% slice carved from the start of 'train',
  and 'train' itself is reduced to whatever remains after the carved slices.
  """
  available = list(TfdsBuildersCache.get(tfds_name).info.splits.keys())
  assert 'train' in available, available

  slice_percent = 5
  carved_percent = 0
  mapping = {}
  for split in ('test', 'validation'):
    if split in available:
      mapping[split] = split
      continue
    # Missing split: take the next unused slice of the train split.
    mapping[split] = f'train[{carved_percent}%:{carved_percent+slice_percent}%]'
    carved_percent += slice_percent
  mapping['train'] = f'train[{carved_percent}%:]'
  return mapping

# Task names must be unique and immutable across experiments to allow reloads.
def add_dataset_config(
    tasks_configs,
    tfds_name,
    unique_name=None,
    private=False):
  """Adds the dataset and splits configuration of one task to `tasks_configs`.

  Args:
    tasks_configs: dict mapping a task's unique name to its FrozenConfigDict;
      updated in place.
    tfds_name: TFDS name of the dataset. 'imagenet_v2' and 'cifar10_1' are
      ignored since they only serve as validation sets of other tasks.
    unique_name: key under which the task is stored; defaults to `tfds_name`.
    private: stored as `config.private`.

  Raises:
    AssertionError: if a task with the same unique name is already present.
  """
  if tfds_name in ['imagenet_v2', 'cifar10_1']:
    return  # Used as validation set for other tasks.

  config = ConfigDict()
  if tfds_name == 'imagenet2012':
    config.dataset = {
        'train':'imagenet2012', 'validation':'imagenet_v2', 'test':'imagenet2012'}
    config.splits = {
        'train':'train', 'validation':'test', 'test':'validation'}
  elif tfds_name == 'cifar100':
    config.dataset = tfds_name
    config.splits = {
        'train':'train[:98%]', 'validation':'train[98%:]', 'test':'test'}
  elif tfds_name == 'cifar10':
    config.dataset = {
        'train':'cifar10', 'validation':'cifar10_1', 'test':'cifar10'}
    config.splits = {
        'train':'train', 'validation':'test', 'test':'test'}
  elif tfds_name.startswith('visual_domain_decathlon'):
    config.dataset = tfds_name
    # test has no labels, split validation in half.
    config.splits =  {
        'train':'train', 'validation':'validation[:50%]', 'test':'validation[50%:]'}
  elif tfds_name == 'omniglot':
    # test has no labels, and missing validation, use additional splits.
    config.dataset = tfds_name
    config.splits = {'train':'train', 'validation':'small1', 'test':'small2'}
  else:
    config.dataset = tfds_name
    config.splits = get_splits(tfds_name)
  config.unique_name = unique_name if unique_name else tfds_name
  config.private = private
  # Check the resolved name: the `unique_name` argument is None for tasks that
  # default to their tfds name, which would make the check always pass.
  assert config.unique_name not in tasks_configs, (
      f'Duplicate task name: {config.unique_name}')
  tasks_configs[config.unique_name] = FrozenConfigDict(config)

def get_task_configs():
  """Returns a dict with the configs of all the standard and private tasks.

  Every dataset in TFDS_IMAGE_CLASSIFCATON_DATASETS is added under its tfds
  name. Each 'visual_domain_decathlon/' dataset is additionally added as a
  private task named 'private:<tfds_name>'.
  """
  configs = {}

  # Standard tasks.
  for name in TFDS_IMAGE_CLASSIFCATON_DATASETS:
    add_dataset_config(configs, name)

  # Private tasks.
  private_names = [
      name for name in TFDS_IMAGE_CLASSIFCATON_DATASETS
      if name.startswith('visual_domain_decathlon/')]
  for name in private_names:
    add_dataset_config(
        configs, name, unique_name=f'private:{name}', private=True)

  return configs

In [None]:
def ids_str2ints(ids_str):
  """Parses a '_'-separated string of ids into a list of ints ([] if empty)."""
  if not ids_str:
    return []
  return list(map(int, ids_str.split('_')))
def ids_ints2str(ids_ints):
  """Formats ids as a '_'-separated string, sorted in ascending order."""
  return '_'.join(map(str, sorted(ids_ints)))

In [None]:
# Short aliases for the upstream vit_jax model components (`models` is
# vision_transformer.vit_jax.models_vit) used by the custom Encoder and the
# model constructors.
AddPositionEmbs = models.AddPositionEmbs
Encoder1DBlock = models.Encoder1DBlock
VisionTransformer = models.VisionTransformer

class ResidualAdapter(nn.Module):
  """Residual adapter: x + Dense(gelu(Dense(LayerNorm(x)))).

  Attributes:
    adapter_dim: number of features of the inner Dense layer.
  """
  adapter_dim: int

  @nn.compact
  def __call__(self, x):
    # The output projection maps back to the feature size of the input.
    out_features = x.shape[-1]
    normed = nn.LayerNorm()(x)
    inner = nn.gelu(nn.Dense(self.adapter_dim)(normed))
    # The output kernel is zero-initialized so that a newly added adapter does
    # not change the representation. Other options, kept for reference:
    #   default initialization: nn.Dense(hidden_dim)
    #   from https://arxiv.org/pdf/1902.00751.pdf:
    #     nn.Dense(hidden_dim, kernel_init=nn.initializers.normal(stddev=1e-3))
    delta = nn.Dense(
        out_features, kernel_init=jax.nn.initializers.zeros)(inner)
    return x + delta  # Residual.

# Modified from vision_transformer/vit_jax/models Encoder to add residual adapters.
class Encoder(nn.Module):
  """Transformer encoder that can insert a ResidualAdapter before encoder blocks.

  Lines tagged with <MOD ... MOD> (or <MOD>) are the changes relative to the
  upstream Encoder.
  """
  # Number of Encoder1DBlock layers that are stacked.
  num_layers: int
  # Forwarded to every Encoder1DBlock.
  mlp_dim: int
  num_heads: int
  # '_'-separated ids of the layers preceded by an adapter ('' for none), in
  # the format parsed by ids_str2ints().
  adapter_layers: str  # <MOD
  # Inner dimension of every inserted ResidualAdapter.
  adapter_dim: int  # MOD>
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1

  @nn.compact
  def __call__(self, inputs, *, train):
    """Encodes `inputs`; dropout is active only when `train` is True."""
    assert inputs.ndim == 3  # (batch, len, emb)

    x = AddPositionEmbs(
        posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
        name='posembed_input')(
            inputs)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

    # Input Encoder
    adapter_layers_ids = ids_str2ints(self.adapter_layers)  # <MOD>
    for lyr in range(self.num_layers):
      # The adapter is applied to the input of block `lyr`, i.e. before it.
      if lyr in adapter_layers_ids:  # <MOD
        x = ResidualAdapter(
            adapter_dim=self.adapter_dim,
            name=f'residual_adapter_{lyr}'
            )(x)  # MOD>
      x = Encoder1DBlock(
          mlp_dim=self.mlp_dim,
          dropout_rate=self.dropout_rate,
          attention_dropout_rate=self.attention_dropout_rate,
          name=f'encoderblock_{lyr}',
          num_heads=self.num_heads)(
              x, deterministic=not train)
    encoded = nn.LayerNorm(name='encoder_norm')(x)
    return encoded

In [None]:
def get_vit_filename(query):
  """Returns the filename of the single AugReg checkpoint matching `query`.

  `query` is a pandas query expression evaluated on the dataframe returned by
  checkpoint.get_augreg_df(). It must select exactly one distinct filename.
  """
  matches = checkpoint.get_augreg_df().query(query).filename.unique()
  assert len(matches) == 1
  return matches[0]

In [None]:
# When False, get_vit_config() sets both dropout rates of the model to 0.0.
# The flag is also passed as the `train` argument when initializing models.
USE_DROPOUT = False
# Customized ViT configs built by get_vit_config(), keyed by checkpoint query.
VIT_CONFIG_CACHE = {}

def get_vit_config(query):
  """Returns a fresh copy of the customized ViT config for `query`.

  On the first call for a query, the AugReg config is looked up from the model
  name (the checkpoint filename up to the first '-'). It is customized to use
  the adapter-aware Encoder with adapters disabled, and with dropout zeroed
  unless USE_DROPOUT is set. The result is cached in VIT_CONFIG_CACHE, and
  callers always receive a copy that they may modify.
  """
  cached = VIT_CONFIG_CACHE.get(query)
  if cached is None:
    model_name = get_vit_filename(query).split('-')[0]
    cached = models_config.AUGREG_CONFIGS[model_name].copy_and_resolve_references()
    # Overwrite with custom Encoder.
    cached.unlock()
    cached.encoder = Encoder
    cached.transformer.adapter_layers = ''
    cached.transformer.adapter_dim = -1
    if not USE_DROPOUT:
      cached.transformer.dropout_rate = 0.0
      cached.transformer.attention_dropout_rate = 0.0
    cached.lock()
    VIT_CONFIG_CACHE[query] = cached
  return cached.copy_and_resolve_references()

def get_max_num_layers(query):
  """Returns the number of transformer layers of the ViT selected by `query`."""
  return get_vit_config(query).transformer.num_layers

In [None]:
# Prefixes of the hyperparameter keys that refer to the dataset pipeline
# (e.g. 'ds_image_size') and to the optimizer (e.g. 'opt_lr').
DATASET_HPARAMS_KEYS_PRERFIX = 'ds_'
OPTIMIZER_HPARAMS_KEYS_PRERFIX = 'opt_'

def get_exp_config_ti3_chars():
  """Returns the experiment config for ViT Ti/16 capped at 3 layers on the character tasks.

  Reads the notebook-level EXPERIMENT_NAME and EXPERIMENTS_ROOT_DIR. The
  returned ConfigDict has been checked with exp_config_validate().
  """
  exp_config = ConfigDict()
  exp_config.experiment_name = EXPERIMENT_NAME
  exp_config.experiments_root_dir = EXPERIMENTS_ROOT_DIR
  # Capping the size of an epoch.
  exp_config.num_train_batches_between_validations_max = 100
  exp_config.num_validations_per_path_training = 5
  exp_config.num_validation_batches_max = 10
  exp_config.batch_size = 512
  exp_config.num_task_iters = 2
  exp_config.num_samples_per_task = 8*8
  exp_config.mutation_prob = 0.1
  exp_config.mutate_adapters = True
  # Force finetune last layer norm that technically is part of the head.
  exp_config.force_finetune_components = ['encoder_norm']
  # Population policy params:
  exp_config.policy_class = 'PPDecay'
  exp_config.policy_kwargs = {}
  # Scorer params:
  exp_config.scorer_class = 'ScorerDecay'
  exp_config.scorer_kwargs = dict(
      base=1.0,
      num_params=1_484_162,  # 1_484_162 params in Ti/16 with 3 layers.
      )

  # Seed models params:
  exp_config.load_rand_init = False
  exp_config.load_vit_checkpoint = True
  exp_config.load_vit_checkpoint_query = 'name=="Ti/16" and ds=="i21k" and aug=="light1" and wd==0.1 and sd==0.0'
  exp_config.load_experiment = False
  exp_config.load_experiment_dir = ''

  # Hyperparameters:
  exp_config.models_default_hparams = {
      # Default num_classes has no effect since it is always overwritten or used
      # for rand init models whose head is always replaced.
      'num_classes': 1,
      # Set to ids_ints2str(range(max_num_layers)) to activate all adapters.
      'adapter_layers': '',
      'num_layers': 3,
      'adapter_dim': 32,
      'opt_lr': 0.01,
      'opt_lr_schedule': 'cosine',
      'opt_lr_warmup_ratio': 0.1,
      'opt_momentum': 0.9,
      'opt_nesterov': False,
      'ds_image_size': 32,
      'ds_area_range_min': 0.05,
      'ds_aspect_ratio_range_min': 0.75,
      'ds_flip_left_right': True,
      'ds_brightness_delta': 0.0,
      'ds_contrast_delta': 0.0,
      'ds_saturation_delta': 0.0,
      'ds_hue_delta': 0.0,
  }

  # Only the number of layers can be mutated by default: any value from 1 up to
  # the default 'num_layers'. exp_config_add_auto_tune() adds further ranges.
  exp_config.models_mutation_ranges = {
      'num_layers': list(range(1, exp_config.models_default_hparams['num_layers']+1)),
  }

  # Tasks params:
  exp_config.task_configs = get_task_configs()
  # Tasks to train on during this experiment.
  exp_config.task_names = \
  [
   'emnist/digits',
   'emnist/letters',
   'kmnist',
   'mnist',
   'omniglot',
   'cmaterdb/bangla',
   'cmaterdb/devanagari',
   'cmaterdb/telugu',
   ]
  exp_config_validate(exp_config)
  return exp_config

def get_exp_config_base_deca():
  """Returns the experiment config for ViT B/16 on the visual domain decathlon tasks.

  Reads the notebook-level EXPERIMENT_NAME and EXPERIMENTS_ROOT_DIR. The
  returned ConfigDict has been checked with exp_config_validate().
  """
  exp_config = ConfigDict()
  exp_config.experiment_name = EXPERIMENT_NAME
  exp_config.experiments_root_dir = EXPERIMENTS_ROOT_DIR
  exp_config.num_train_batches_between_validations_max = 200
  exp_config.num_validations_per_path_training = 30
  exp_config.num_validation_batches_max = 10
  exp_config.batch_size = 256
  exp_config.num_task_iters = 2
  exp_config.num_samples_per_task = 8*8
  exp_config.mutation_prob = 0.1
  exp_config.mutate_adapters = True
  exp_config.force_finetune_components = ['encoder_norm']
  # Population policy params:
  exp_config.policy_class = 'PPDecay'
  exp_config.policy_kwargs = {}
  # Scorer params:
  exp_config.scorer_class = 'ScorerDecay'
  exp_config.scorer_kwargs = dict(
      base=1.0,
      num_params=85_652_738,  # 85_652_738 params in B/16
      )
  # Seed models params:
  exp_config.load_rand_init = False
  exp_config.load_vit_checkpoint = True
  exp_config.load_vit_checkpoint_query = 'name=="B/16" and ds=="i21k" and aug=="medium1" and wd==0.1 and sd==0'
  exp_config.load_experiment = False
  exp_config.load_experiment_dir = ''
  # Hyperparameters:
  # The default model uses all the layers of the selected checkpoint.
  max_num_layers = get_max_num_layers(exp_config.load_vit_checkpoint_query)
  exp_config.models_default_hparams = {
      'num_classes': 1,
      'adapter_layers': '',
      'num_layers': max_num_layers,
      'adapter_dim': 32,
      'opt_lr': 0.01,
      'opt_lr_schedule': 'cosine',
      'opt_lr_warmup_ratio': 0.1,
      'opt_momentum': 0.9,
      'opt_nesterov': False,
      'ds_image_size': 80,
      'ds_area_range_min': 0.05,
      'ds_aspect_ratio_range_min': 0.75,
      'ds_flip_left_right': True,
      'ds_brightness_delta': 0.0,
      'ds_contrast_delta': 0.0,
      'ds_saturation_delta': 0.0,
      'ds_hue_delta': 0.0,
  }

  # Only the number of layers can be mutated by default: any value from 1 up to
  # the default 'num_layers'. exp_config_add_auto_tune() adds further ranges.
  exp_config.models_mutation_ranges = {
      'num_layers': list(range(1, exp_config.models_default_hparams['num_layers']+1)),
  }

  exp_config.task_configs = get_task_configs()
  # Tasks to train on during this experiment.
  exp_config.task_names = [
      'visual_domain_decathlon/imagenet12',
      'visual_domain_decathlon/svhn',
      'visual_domain_decathlon/cifar100',
      'visual_domain_decathlon/gtsrb',
      'visual_domain_decathlon/daimlerpedcls',
      'visual_domain_decathlon/omniglot',
      'visual_domain_decathlon/ucf101',
      'visual_domain_decathlon/aircraft',
      'visual_domain_decathlon/dtd',
      'visual_domain_decathlon/vgg-flowers',
      ]
  exp_config_validate(exp_config)
  return exp_config

def get_exp_config_large():
  """Returns the experiment config for ViT L/16 on imagenet2012, cifar100 and cifar10.

  Reads the notebook-level EXPERIMENT_NAME and EXPERIMENTS_ROOT_DIR. No
  hyperparameter has a mutation range by default in this configuration.
  """
  exp_config = ConfigDict()
  exp_config.experiment_name = EXPERIMENT_NAME
  exp_config.experiments_root_dir = EXPERIMENTS_ROOT_DIR

  # 1/10th of epoch for imagenet to have similar ratio of exps reported in:
  # https://arxiv.org/abs/2106.10270
  exp_config.num_train_batches_between_validations_max = 4000
  exp_config.num_validations_per_path_training = 2
  # 312 * 32 ~= 10k size of imagenet2012 validation set.
  exp_config.num_validation_batches_max = 312
  # Reduced batch size to fit in HBM, but increased num batches.
  exp_config.batch_size = 32
  exp_config.num_task_iters = 32
  exp_config.num_samples_per_task = 8*2
  exp_config.mutation_prob = 0.1
  exp_config.mutate_adapters = True
  exp_config.force_finetune_components = ['encoder_norm']
  # Population policy params:
  exp_config.policy_class = 'PPDecay'
  exp_config.policy_kwargs = {}
  # Scorer params:
  exp_config.scorer_class = 'ScorerDecay'
  exp_config.scorer_kwargs = dict(
      base=1.0,
      num_params=303_303_682,  # Params in L/16
      )
  # Seed models params:
  exp_config.load_rand_init = False
  exp_config.load_vit_checkpoint = True
  exp_config.load_vit_checkpoint_query = 'name=="L/16" and ds=="i21k" and aug=="medium2" and wd==0.03 and sd==0.1'
  # 'name=="L/16" and ds=="i21k" and aug=="light1" and wd==0.1 and sd==0.0'
  exp_config.load_experiment = False
  exp_config.load_experiment_dir = ''
  # Hyperparameters:
  # The default model uses all the layers of the selected checkpoint.
  max_num_layers = get_max_num_layers(exp_config.load_vit_checkpoint_query)
  exp_config.models_default_hparams = {
      'num_classes': 1,
      'adapter_layers': '',
      'num_layers': max_num_layers,
      'adapter_dim': 32,
      'opt_lr': 0.01,
      'opt_lr_schedule': 'cosine',
      'opt_lr_warmup_ratio': 0.05,
      'opt_momentum': 0.9,
      'opt_nesterov': False,
      'ds_image_size': 384,
      'ds_area_range_min': 0.05,
      'ds_aspect_ratio_range_min': 0.75,
      'ds_flip_left_right': True,
      'ds_brightness_delta': 0.0,
      'ds_contrast_delta': 0.0,
      'ds_saturation_delta': 0.0,
      'ds_hue_delta': 0.0,
  }

  exp_config.models_mutation_ranges = {}

  exp_config.task_configs = get_task_configs()
  # NOTE(review): none of these names is listed in
  # TFDS_IMAGE_CLASSIFCATON_DATASETS, which is the only source of the configs
  # returned by get_task_configs() -- confirm that they are registered before
  # running this benchmark.
  exp_config.task_names = [
      'imagenet2012',
      'cifar100',
      'cifar10',
      ]
  exp_config_validate(exp_config)
  return exp_config

def exp_config_add_auto_tune(exp_config):
  """Adds the auto-tune mutation ranges to `exp_config` and returns it.

  Sets, in `exp_config.models_mutation_ranges`, the allowed values of the
  adapter dimension and of all the optimizer ('opt_*') and dataset ('ds_*')
  hyperparameters. Existing entries for other keys are left untouched.
  """
  auto_tune_ranges = {
      'adapter_dim': [8, 16, 32, 64, 128],
      'opt_lr': [0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1],
      'opt_lr_schedule': ['constant', 'cosine', 'restarts'],
      'opt_lr_warmup_ratio': [0.01, 0.02, 0.05, 0.1, 0.2, 0.3, 0.4],
      'opt_momentum': [None, 0.2, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99],
      'opt_nesterov': [True, False],
      # Multiples of 16 from 16 up to 384 included.
      'ds_image_size': list(range(16, 384 + 1, 16)),
      'ds_area_range_min': [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0],
      'ds_aspect_ratio_range_min': [0.25, 0.5, 0.75, 1.0],
      'ds_flip_left_right': [True, False],
      'ds_brightness_delta': [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
      'ds_contrast_delta': [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
      'ds_saturation_delta': [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
      'ds_hue_delta': [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
  }
  for hparam_name, allowed_values in auto_tune_ranges.items():
    exp_config.models_mutation_ranges[hparam_name] = allowed_values
  return exp_config

def exp_config_validate(exp_config):
  """Asserts that each default hparam is included in its mutation range, if any."""
  defaults = exp_config.models_default_hparams
  ranges = exp_config.models_mutation_ranges
  for hparam_name in defaults:
    if hparam_name not in ranges:
      continue  # Hparams without a mutation range are not constrained.
    assert defaults[hparam_name] in ranges[hparam_name]

def exp_config_set_size_scale(exp_config, base_percent:int):
  """Sets the scorer 'base' to base_percent/100 and returns `exp_config`."""
  scale = float(base_percent) / 100.0
  exp_config.scorer_kwargs['base'] = scale
  return exp_config

def exp_config_set_baseline_common(exp_config):
  """Applies the settings shared by all the baselines and returns `exp_config`.

  The number of validations per path training is scaled up by the number of
  task iterations times the number of sample rounds per iteration (samples per
  task divided by the local device count). The config is then reduced to a
  single task iteration with one sample per local device, with no mutation
  ranges and with the 'PPBaseline' population policy.
  """
  num_devices = jax.local_device_count()
  rounds_per_iter = exp_config.num_samples_per_task / num_devices
  # The samples per task must be an exact multiple of the device count.
  assert int(rounds_per_iter) == rounds_per_iter
  exp_config.num_validations_per_path_training *= (
      exp_config.num_task_iters * int(rounds_per_iter))
  exp_config.num_task_iters = 1
  exp_config.num_samples_per_task = num_devices
  exp_config.models_mutation_ranges = {}
  exp_config.policy_class = 'PPBaseline'
  exp_config.policy_kwargs = {}
  exp_config_validate(exp_config)
  return exp_config

def exp_config_set_baseline_finetune_all(exp_config):
  """Configures the full fine-tuning baseline and returns `exp_config`.

  On top of the common baseline settings: mutation probability 1.0, no adapter
  mutations and no adapter layers.
  """
  baseline = exp_config_set_baseline_common(exp_config)
  baseline.models_default_hparams['adapter_layers'] = ''
  baseline.mutate_adapters = False
  baseline.mutation_prob = 1.0
  exp_config_validate(baseline)
  return baseline

def exp_config_set_baseline_freeze_bottom_layers(exp_config, num_layers:int):
  """Configures the baseline with the bottom `num_layers` blocks not fine-tuned.

  On top of the common baseline settings: mutation probability 0.0, no
  adapters, and forced fine-tuning of 'encoder_norm' plus every encoder block
  with index from `num_layers` up to the default number of layers.
  """
  baseline = exp_config_set_baseline_common(exp_config)
  total_layers = baseline.models_default_hparams['num_layers']
  assert total_layers >= num_layers
  trainable_blocks = [
      f'encoderblock_{layer_idx}'
      for layer_idx in range(num_layers, total_layers)]
  baseline.force_finetune_components = ['encoder_norm', *trainable_blocks]
  baseline.mutation_prob = 0.0
  baseline.mutate_adapters = False
  baseline.models_default_hparams['adapter_layers'] = ''
  exp_config_validate(baseline)
  return baseline

def exp_config_set_baseline_adapters(exp_config, adapter_dim:int):
  """Configures the residual adapters baseline and returns `exp_config`.

  On top of the common baseline settings: mutation probability 0.0, adapter
  mutations enabled, and an adapter of inner dimension `adapter_dim` on every
  layer of the default model.
  """
  baseline = exp_config_set_baseline_common(exp_config)
  # To unfreeze all layer norms in the model also set GATHER_LAYER_NORMS to True.
  baseline.force_finetune_components = ['encoder_norm']
  baseline.mutation_prob = 0.0
  baseline.mutate_adapters = True
  default_hparams = baseline.models_default_hparams
  all_layer_ids = range(default_hparams['num_layers'])
  baseline.models_default_hparams['adapter_layers'] = ids_ints2str(all_layer_ids)
  baseline.models_default_hparams['adapter_dim'] = adapter_dim
  exp_config_validate(baseline)
  return baseline

In [None]:
def get_sample_image(image_size:int, batch_size:int):
  """Returns an all-zeros image batch of shape (batch, size, size, 3)."""
  batch_shape = (batch_size, image_size, image_size, 3)
  return np.zeros(shape=batch_shape)

def get_sample_label(batch_size:int):
  """Returns an all-zeros int32 label vector of length `batch_size`."""
  return np.zeros(shape=batch_size, dtype=np.int32)

In [None]:
def get_vit_checkpoint(image_size, query):
  """Loads the pretrained ViT parameters selected by `query`.

  A model is initialized on a single zero image of side `image_size`. The
  resulting parameters are passed as `init_params` to
  checkpoint.load_pretrained() together with the AugReg checkpoint path.
  """
  checkpoint_name = get_vit_filename(query)
  vit_config = get_vit_config(query)

  # num_classes unused.
  vit = VisionTransformer(**vit_config, num_classes=2)
  sample_batch = get_sample_image(image_size=image_size, batch_size=1)
  variables = vit.init(
      jax.random.PRNGKey(0), sample_batch, train=USE_DROPOUT)

  return checkpoint.load_pretrained(
      pretrained_path=f'gs://vit_models/augreg/{checkpoint_name}.npz',
      init_params=variables['params'],
      model_config=vit_config)

def get_vit_checkpoint_mapped(image_size, query):
  """Loads the pretrained ViT parameters in the flattened components layout."""
  model_params = get_vit_checkpoint(image_size, query)
  return params_model_to_comps(model_params)

def get_reshaped_posembed_component(image_size, query):
  """Returns a non-trainable Component with the pretrained position embeddings.

  The embeddings are taken from the checkpoint loaded for `image_size`.
  """
  comps = get_vit_checkpoint_mapped(image_size, query)
  return Component(
      name='posembed_input',
      params=comps['posembed_input'],
      train_locks=[NOT_TRAINABLE])

In [None]:
# Parameter mapping.
# Names of the entries seen under params['Transformer']. Filled by
# params_model_to_comps() and read by params_comps_to_model() to restore the
# nesting, so the former must run at least once before the latter.
TRANSFORMER_KEYS = set()
# Set this to True to unfreeze all the layernorms in the model.
# Can be useful for variants of the residual adapters baseline.
GATHER_LAYER_NORMS = False

def params_model_to_comps(params):
  """Converts model parameters into the flat components layout.

  The entries nested under 'Transformer' are moved to the top level, next to
  the other top-level entries, and their names are recorded in
  TRANSFORMER_KEYS. If GATHER_LAYER_NORMS is set, the 'LayerNorm_*' entries of
  every 'encoderblock_*' are moved under
  ['encoder_norm']['gathered'][<block name>].

  Returns:
    A frozen dict of components.
  """
  TRANSFORMER_KEYS.update(params['Transformer'].keys())

  flat = {}
  for name in params.keys():
    subtree = params[name]
    if name != 'Transformer':
      flat[name] = subtree
      continue
    for inner_name in subtree.keys():
      flat[inner_name] = subtree[inner_name]
  comps = flax.core.freeze(flat)

  if GATHER_LAYER_NORMS:
    comps = comps.unfreeze()
    gathered = {}
    comps['encoder_norm']['gathered'] = gathered
    for name in comps.keys():
      if not name.startswith('encoderblock_'):
        continue
      block = comps[name]
      # The names are listed upfront since entries are popped while moving.
      gathered[name] = {
          entry: block.pop(entry)
          for entry in list(block.keys())
          if entry.startswith('LayerNorm_')}

  return flax.core.freeze(comps)

def params_comps_to_model(params):
  """Converts components back to the nested model parameters layout.

  Inverse of params_model_to_comps(): gathered layer norms (if
  GATHER_LAYER_NORMS is set) are moved back into their encoder blocks, then
  every entry named in TRANSFORMER_KEYS is nested under 'Transformer'.

  Returns:
    A frozen dict of model parameters.
  """
  model_params = params.unfreeze()

  if GATHER_LAYER_NORMS:
    gathered = model_params['encoder_norm'].pop('gathered')
    for block_name, layer_norms in gathered.items():
      assert block_name.startswith('encoderblock_')
      assert block_name in model_params
      for ln_name in layer_norms.keys():
        assert ln_name.startswith('LayerNorm_')
        assert ln_name not in model_params[block_name]
        model_params[block_name][ln_name] = layer_norms[ln_name]

  # TRANSFORMER_KEYS is empty until params_model_to_comps() has been called.
  assert len(TRANSFORMER_KEYS) != 0
  transformer = {}
  model_params['Transformer'] = transformer
  for name in list(model_params.keys()):
    if name in TRANSFORMER_KEYS:
      transformer[name] = model_params.pop(name)
  return flax.core.freeze(model_params)

In [None]:
def get_model_kwargs(hparams, exp_config):
  """Returns the model construction kwargs derived from `hparams`.

  Asserts that every adapter layer id is lower than the number of layers.
  """
  # Validate adapters params.
  assert all(
      layer_id < hparams['num_layers']
      for layer_id in ids_str2ints(hparams['adapter_layers']))
  return dict(
      num_classes=int(hparams['num_classes']),
      num_layers=int(hparams['num_layers']),
      image_size=int(hparams['ds_image_size']),
      adapter_layers=str(hparams['adapter_layers']),
      adapter_dim=int(hparams['adapter_dim']),
      query=str(exp_config.load_vit_checkpoint_query),
  )

def get_vit_model(num_classes, num_layers, adapter_layers, adapter_dim, query):
  """Builds a VisionTransformer from the config of `query`.

  The number of layers and the adapters settings of the config are overridden
  with the given values.
  """
  vit_config = get_vit_config(query)
  transformer_config = vit_config['transformer']
  transformer_config['num_layers'] = num_layers
  transformer_config['adapter_layers'] = adapter_layers
  transformer_config['adapter_dim'] = adapter_dim
  return VisionTransformer(
      **FrozenConfigDict(vit_config), num_classes=num_classes)

def get_vit_model_and_params(
    num_classes, num_layers, image_size, adapter_layers, adapter_dim, query,
    rng_key=0):
  """Returns a ViT model together with its freshly initialized parameters.

  The parameters are initialized on a single zero image of side `image_size`,
  with a PRNG key seeded by `rng_key`.
  """
  vit = get_vit_model(
      num_classes, num_layers, adapter_layers, adapter_dim, query)
  sample_batch = get_sample_image(image_size=image_size, batch_size=1)
  variables = vit.init(
      jax.random.PRNGKey(rng_key), sample_batch, train=USE_DROPOUT)
  return vit, variables['params']

def get_vit_model_and_params_mapped(**kwargs):
  """Like get_vit_model_and_params(), with params in the components layout."""
  vit, model_params = get_vit_model_and_params(**kwargs)
  return vit, params_model_to_comps(model_params)

In [None]:
def format_params(a, b):
  """Merges two disjoint component dicts and converts them to model params."""
  merged = a.copy(b)
  # Dicts should not overlap.
  assert len(merged) == len(a) + len(b)
  return params_comps_to_model(merged)

In [None]:
def get_optimizer(
    lr: float,
    lr_schedule: str,
    lr_warmup_ratio: float,
    momentum: float,
    nesterov: bool,
    num_train_batches_between_validations: int,
    num_validations_per_path_training: int,
    ):
  """Returns SGD with global-norm clipping and the requested lr schedule.

  Args:
    lr: peak learning rate (halved and kept constant for 'constant').
    lr_schedule: one of 'constant', 'cosine' or 'restarts'.
    lr_warmup_ratio: fraction of a decay cycle used for warmup.
    momentum: SGD momentum. The auto-tune range of 'opt_momentum' also
      includes None.
    nesterov: whether to use Nesterov momentum.
    num_train_batches_between_validations: steps per validation cycle.
    num_validations_per_path_training: number of validation cycles.

  Raises:
    AssertionError: if `lr_schedule` is not a supported value.
  """
  if lr_schedule == 'constant':
    # Halved so that the average lr is the same as for the other schedules.
    schedule = 0.5 * lr
  elif lr_schedule == 'cosine':
    # A single warmup + cosine decay spanning the whole training.
    total_steps = int(num_train_batches_between_validations
                      * num_validations_per_path_training)
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=lr/100.0,
        peak_value=lr,
        warmup_steps=int(lr_warmup_ratio * total_steps),
        decay_steps=total_steps)
  elif lr_schedule == 'restarts':
    # One warmup + cosine decay cycle per validation cycle.
    cycle_steps = num_train_batches_between_validations
    cycle = dict(
        init_value=lr/100.0,
        peak_value=lr,
        warmup_steps=int(lr_warmup_ratio * cycle_steps),
        decay_steps=cycle_steps,
    )
    schedule = optax.sgdr_schedule(
        [cycle] * num_validations_per_path_training)
  else:
    assert False, f'Invalid lr schedule: {lr_schedule}'

  sgd = optax.sgd(
      learning_rate=schedule,
      momentum=momentum,
      nesterov=nesterov,
      accumulator_dtype=jnp.bfloat16)
  return optax.chain(optax.clip_by_global_norm(1.0), sgd)

In [None]:
class Task():
  """A classification task backed by TFDS dataset builder(s).

  Tasks whose name starts with `NOT_TRAINABLE` are placeholders: only
  `exp_config`, `name` and `private` are set on them. For every other task
  the constructor reads `exp_config.task_configs[name]` and derives the
  number of classes, batch sizes and validation sizing from dataset metadata.
  """

  def __init__(self, name, exp_config):
    """Initializes the task.

    Args:
      name: Task name; key into `exp_config.task_configs` unless it starts
        with `NOT_TRAINABLE`.
      exp_config: Experiment configuration. It is not dereferenced for
        placeholder tasks.
    """
    self.exp_config = exp_config
    if name.startswith(NOT_TRAINABLE):
      # Placeholder task: no config, dataset or batch sizing is attached.
      self.name = name
      self.private = False
      return
    self.config = exp_config.task_configs[name]
    self.name = name
    self.private = self.config.private
    self.num_classes = self.get_builder('train').info.features['label'].num_classes
    num_train_examples = self.get_builder('train').info.splits[self.config.splits['train']].num_examples
    self.train_batch_size = exp_config.batch_size
    # At most one pass over the train split between validations, capped by
    # the configured maximum.
    self.num_train_batches_between_validations = min(
        math.ceil(num_train_examples / self.train_batch_size),
        exp_config.num_train_batches_between_validations_max)
    # Cache the train split only if it is smaller than both 100k examples and
    # the number of examples consumed by one full path training.
    self.cache_train = num_train_examples < min(100_000, (
        exp_config.num_validations_per_path_training
        * self.num_train_batches_between_validations
        * self.train_batch_size))

    num_validation_examples_tot = self.get_builder('validation').info.splits[self.config.splits['validation']].num_examples
    num_validation_examples_max = exp_config.batch_size * exp_config.num_validation_batches_max
    if num_validation_examples_max <= num_validation_examples_tot:
      # Enough validation examples to fill the maximum number of full batches.
      self.num_validation_batches = exp_config.num_validation_batches_max
      self.validation_batch_size = exp_config.batch_size
    else:
      # Adjust batch_size and num_batches to cover the smaller validation sets.
      self.num_validation_batches = math.ceil(
          num_validation_examples_tot / exp_config.batch_size)
      self.validation_batch_size = math.floor(
          num_validation_examples_tot / self.num_validation_batches)
      assert num_validation_examples_tot >= (self.num_validation_batches*self.validation_batch_size)
    self.num_validation_examples = self.num_validation_batches * self.validation_batch_size

    print(f'Task: {self.name}')
    print(f'  Train batches between validations: {self.num_train_batches_between_validations}')
    print(f'  Validation batches: {self.num_validation_batches}')
    print(f'  Validation batch size: {self.validation_batch_size}')
    print(f'  Dataset {{\n{self.config.dataset}}}')
    print(f'  Splits {{\n{self.config.splits}}}')


  def get_builder(self, mode):
    """Returns the cached TFDS builder for `mode`.

    `config.dataset` is either a single dataset name shared by all modes or
    a mapping from mode to dataset name.
    """
    if type(self.config.dataset) == str:
      return TfdsBuildersCache.get(self.config.dataset)
    return TfdsBuildersCache.get(self.config.dataset[mode])

  def __str__(self):
    return f'Task_{self.name}'
  def is_trainable(self):
    """Returns False only for placeholder tasks (name starts with `NOT_TRAINABLE`)."""
    return not self.name.startswith(NOT_TRAINABLE)
  def is_private(self):
    """Returns the `private` flag of this task."""
    return self.private

  def get_ds(self, mode, hparams):
    """Builds the input pipeline for `mode` and returns it as numpy batches.

    Args:
      mode: Split key into `config.splits`; the code distinguishes 'train',
        'validation' and 'test'.
      hparams: Dict holding the `ds_*` preprocessing/augmentation values
        read below.

    Returns:
      `tfds.as_numpy` view of a batched, prefetched dataset whose elements
      are dicts with keys 'image' and 'label'.
    """
    builder = self.get_builder(mode)
    builder.download_and_prepare()
    data = builder.as_dataset(
        split=self.config.splits[mode],
        shuffle_files=mode=='train')

    def _pp(data):
      """Preprocesses one example; augmentations are applied only in 'train' mode."""
      im = data['image']
      im = tf.cast(im, tf.float32)
      # Must have 3 channels.
      if im.shape[-1] == 1:
        im = tf.squeeze(tf.stack([im] * 3, -1), axis=-2)
      assert im.shape[-1] == 3
      # Values in range [-1 , 1]
      im = im / 127.5 - 1

      if mode == 'train':
        if hparams['ds_area_range_min'] < 1.0:
          # Random crop with sampled area and aspect ratio.
          channels = im.shape[-1]
          begin, size, _ = tf.image.sample_distorted_bounding_box(
              tf.shape(im),
              tf.zeros([0, 0, 4], tf.float32),
              aspect_ratio_range=[hparams['ds_aspect_ratio_range_min'],
                                  1.0/hparams['ds_aspect_ratio_range_min']],
              area_range=[hparams['ds_area_range_min'], 1.0],
              # Minimum overlap with the bounding box; no boxes are passed,
              # so the bounding box defaults to the whole image in this case.
              min_object_covered=0,
              use_image_if_no_bounding_boxes=True)
          im = tf.slice(im, begin, size)
          # Restore the depth-dimension lost by the above operation.
          im.set_shape([None, None, channels])
        if hparams['ds_flip_left_right']:
          if tf.random.uniform(shape=[]) > 0.5:
            im = tf.image.flip_left_right(im)
        if hparams['ds_brightness_delta'] > 0.0:
          im = tf.image.random_brightness(
              im, max_delta=hparams['ds_brightness_delta'])
        if hparams['ds_contrast_delta'] > 0.0:
          im = tf.image.random_contrast(
              im, lower=1 - hparams['ds_contrast_delta'],
              upper=1 + hparams['ds_contrast_delta'])
        if hparams['ds_saturation_delta'] > 0.0:
          im = tf.image.random_saturation(
              im, lower=1 - hparams['ds_saturation_delta'],
              upper=1 + hparams['ds_saturation_delta'])
        if hparams['ds_hue_delta'] > 0.0:
          im = tf.image.random_hue(im, max_delta=hparams['ds_hue_delta'])

      im = tf.image.resize(im, [hparams['ds_image_size'],
                                hparams['ds_image_size']])
      # Clip back to [-1, 1] after augmentation and resizing.
      im = tf.clip_by_value(im, -1, 1)

      return {'image': im, 'label': data['label']}

    if mode == 'validation':
      # Use only as many examples as fit the configured validation batches.
      data = data.take(self.num_validation_examples)
    # Caching happens before `map`, so raw examples are cached and the
    # preprocessing in `_pp` is re-applied on every pass.
    if mode == 'validation' or (mode == 'train' and self.cache_train):
      data = data.cache()
    if mode != 'test':
      data = data.repeat()
    data = data.map(_pp, tf.data.AUTOTUNE)
    if mode == 'train':
      batch_size = self.train_batch_size
    else:
      batch_size = self.validation_batch_size
    data = data.batch(batch_size)
    if mode == 'train':
      # Applied after batching: shuffles whole batches with a buffer of 10.
      data = data.shuffle(10)
    return tfds.as_numpy(data.prefetch(tf.data.AUTOTUNE))

def get_task_factory_fn(exp_config):
  """Returns a factory that builds `Task`s bound to `exp_config`.

  The returned callable takes a single `task_name` argument, which may be
  passed by keyword.
  """
  def get_task(task_name):
    task = Task(name=task_name, exp_config=exp_config)
    return task
  return get_task

# Name prefix marking placeholder tasks; `Task.__init__` returns early for
# names starting with it and `Task.is_trainable()` reports False for them.
NOT_TRAINABLE = 'NOT_TRAINABLE'
# Shared placeholder task instance; no exp_config is needed for it.
not_trainable = Task(NOT_TRAINABLE, None)

In [None]:
def get_num_params(params):
  """Returns the total number of entries over all array leaves of `params`.

  Each leaf contributes the product of its `shape`.
  """
  leaf_sizes = [np.prod(leaf.shape) for leaf in jax.tree.leaves(params)]
  return sum(leaf_sizes)

In [None]:
def params2comps(params, train_locks, name=None):
  """Converts a dict of params into a list of `Component`s.

  Args:
    params: Mapping from component name to that component's parameters.
    train_locks: Train locks passed to every created component.
    name: If given, only the entry with this key is converted.

  Returns:
    List of new components, in the iteration order of `params`.
  """
  selected = [k for k in params if name is None or name == k]
  return [
      Component(name=k, params=params[k], train_locks=train_locks)
      for k in selected
  ]

def params2comp_names(params):
  """Returns the top-level keys of `params` as a list of component names."""
  names = [key for key in params.keys()]
  return names

In [None]:
def fingerprint_params(params):
  """Returns a scalar fingerprint: the sum of the sums of all leaves."""
  per_leaf_sums = [jnp.sum(leaf) for leaf in jax.tree.leaves(params)]
  return np.sum(np.array(per_leaf_sums))

class Component():
  """A named group of parameters that can be referenced by several paths.

  Every instance receives a unique, increasing `id` from the class-level
  `counter`. A component is trainable only while its `train_locks` set is
  empty.
  """
  counter = 0

  @staticmethod
  def reset_globals():
    """Resets the id counter; callable on the class or on an instance."""
    Component.counter = 0

  def __init__(self, name:str, params, train_locks:set):
    """Creates a component.

    Args:
      name: Component name.
      params: Parameter tree; fetched with `jax.device_get` before storing.
      train_locks: Iterable of lock names; copied into a new set.
    """
    self.name = name
    self.params = jax.device_get(params)
    # Computed lazily by get_num_params().
    self.num_params = None
    self.train_locks = set(train_locks)
    self.id = Component.counter
    Component.counter += 1

  def __str__(self):
    rtn = f'Component: {self.id}\n  Name: {self.name}'
    rtn += f'\n  Train locks: {self.train_locks}'
    rtn += f'\n  Fingerprint: {self.fingerprint()}'
    # Shows None until get_num_params() has been called at least once.
    rtn += f'\n  Num params: {self.num_params}'
    return rtn

  def get_num_params(self):
    """Returns the parameter count, computing and caching it on first use."""
    if self.num_params is None:
      self.num_params = get_num_params(self.params)
    return self.num_params

  def fingerprint(self):
    """Returns `fingerprint_params` of the current parameters."""
    return fingerprint_params(self.params)

  def is_trainable(self):
    """Returns True iff no train lock is held on this component."""
    return len(self.train_locks) == 0

  def clone(self):
    """Returns a new component with a deep copy of the params and no locks."""
    return Component(name=self.name,
                     params=copy.deepcopy(jax.device_get(self.params)),
                     train_locks=set())

In [None]:
class ObjectCache():
  """Memoizes a factory function that is called with keyword arguments only.

  The cache key is the JSON dump of the keyword arguments with sorted keys,
  so argument order does not matter.
  """

  def __init__(self, factory_fn):
    self.cache = {}
    self.factory_fn = factory_fn

  def __call__(self, *args, **kwargs):
    """Returns the cached object for `kwargs`, building it on first use."""
    assert not args
    cache_key = json.dumps(kwargs, sort_keys=True)
    if cache_key in self.cache:
      return self.cache[cache_key]
    built = self.factory_fn(**kwargs)
    self.cache[cache_key] = built
    # print(f"Added to cache: {self.factory_fn.__name__}({cache_key})  [cache size {len(self.cache)}]")
    return built

In [None]:
def incremental_mutation(value, values_list:list):
  """Returns a neighbour of `value` in `values_list`.

  Moves one position up or down with equal probability; the position is
  clamped to the list bounds, so the result may equal `value`.
  """
  assert value in values_list, f'{value} not in {values_list}'
  position = values_list.index(value)
  step = 1 if np.random.uniform() < 0.5 else -1
  last = len(values_list) - 1
  neighbour = min(last, position + step)
  neighbour = max(0, neighbour)
  return values_list[neighbour]

def random_mutation(values_list:list):
  """Returns one element of `values_list` drawn with `np.random.choice`."""
  picked = np.random.choice(values_list)
  return picked

In [None]:
class Path():
  """A model for one task, expressed as an ordered list of `Component`s.

  Class-level state (`exp_config`, `counter`, `paths`, `scorer` and the
  object caches) is installed by `Path.reset_globals(exp_config)`, which
  must run before any instance is created.
  """

  def reset_globals(exp_config):
    """Resets all class-level state. Call on the class: `Path.reset_globals(cfg)`."""
    Path.exp_config = exp_config
    Path.counter = 0
    Path.paths = []
    Path.scorer = None  # To be set to scorer of choice during init of exp.
    # Cache output of functions calls with same args.
    Path.tasks = ObjectCache(get_task_factory_fn(exp_config))
    Path.posembed_components = ObjectCache(get_reshaped_posembed_component)
    Path.optimizers = ObjectCache(get_optimizer)
    Path.models = ObjectCache(get_vit_model)

  def __init__(self, hparams, components, parent, task:Task):
    """Creates a path and registers it in `Path.paths`.

    Args:
      hparams: Dict of hyperparameters; must contain 'num_classes',
        'num_layers', 'adapter_layers' and 'adapter_dim'.
      components: List of components making up the model parameters.
      parent: Path this one was derived from, or None.
      task: Task this path is for.
    """
    self.components = components
    self.id = Path.counter
    Path.counter += 1
    self.task = task
    self.parent = parent
    self.hparams = hparams
    self.metrics = {
        'offsprings': 0,
        'reloads': 0,
        'generation': 0 if parent is None else parent.metrics['generation']+1,
        'private': task.is_private(),
    }
    # `self.exp_config` resolves to the class attribute set in reset_globals().
    self.model = Path.models(
        num_classes=int(hparams['num_classes']),
        num_layers=int(hparams['num_layers']),
        adapter_layers=str(hparams['adapter_layers']),
        adapter_dim=int(hparams['adapter_dim']),
        query=str(self.exp_config.load_vit_checkpoint_query))
    Path.paths.append(self)

  def __str__(self):
    rtn = f"Path: {self.id}"
    rtn += f"\n  Components: {[c.id for c in self.components]}"
    if self.parent:
      rtn += f"\n  Parent: {self.parent.id}"
    rtn += f"\n  Task: {self.task.name}"
    rtn += f"\n  Total Parameters: {get_num_params(self.get_all_params())}"
    rtn += f"\n  Accounted params: {self.accounted_num_params()}"
    for k,v in self.hparams.items():
      rtn += f"\n    {k}: {v}"
    for k,v in self.metrics.items():
      rtn += f"\n    {k}: {v}"
    rtn += f"\n  Score: {self.score()}"
    return rtn

  def is_trainable(self):
    """Returns whether the task of this path is trainable."""
    return self.task.is_trainable()

  def is_private(self):
    """Returns whether the task of this path is private."""
    return self.task.is_private()

  def score(self):
    """Returns the score assigned to this path by `Path.scorer`."""
    return Path.scorer.score(self)

  def get_all_params(self):
    """Returns a frozen dict mapping component name to params, all components."""
    params = {}
    for c in self.components:
      params[c.name] = c.params
    return flax.core.freeze(params)

  def get_trainable_params(self):
    """Returns a frozen dict with the params of trainable components only."""
    params = {}
    for c in self.components:
      if c.is_trainable():
        params[c.name] = c.params
    return flax.core.freeze(params)

  def get_fixed_params(self):
    """Returns a frozen dict with the params of non-trainable components only."""
    params = {}
    for c in self.components:
      if not c.is_trainable():
        params[c.name] = c.params
    return flax.core.freeze(params)

  def update_trainable(self, trained_params):
    """Writes `trained_params` back into the trainable components.

    Asserts that `trained_params` has exactly one entry per trainable
    component.
    """
    trainable_count = 0
    for c in self.components:
      if c.is_trainable():
        trainable_count += 1
        assert c.name in trained_params.keys()
        c.params = trained_params[c.name]
    assert len(trained_params.keys()) == trainable_count, (
        f'{len(trained_params.keys())} {trainable_count}')

  def accounted_num_params(self):
    """Returns the parameter count attributed to this path.

    Each component contributes its parameter count divided by the number of
    distinct task names using it (its train locks plus this path's task,
    ignoring NOT_TRAINABLE). Returns NaN if a component has no such task.
    """
    rtn = 0
    for c in self.components:
      # Work on a copy so the locks of the component are not modified.
      tl = copy.copy(c.train_locks)
      assert type(tl) is set
      tl.add(self.task.name)
      if NOT_TRAINABLE in tl:
        tl.remove(NOT_TRAINABLE)
      if len(tl) == 0:
        return np.nan
      rtn += c.get_num_params() / len(tl)
    return rtn

  def clone(
      self,
      task:Task,
      ds_hparams,
      policy,
      mutate:bool):
    """Creates a child path for `task` with this path as parent.

    Args:
      task: Task of the child path; must share `Path.exp_config`.
      ds_hparams: Dataset hparams that overwrite the inherited ones.
      policy: Object providing `do_mutate(comp_name=None)`.
      mutate: Whether hparams and components may be mutated.

    Returns:
      The new child `Path`.
    """
    exp_config = Path.exp_config
    assert exp_config == task.exp_config
    comps = []
    new_hparams = copy.deepcopy(self.hparams)
    new_hparams['num_classes'] = task.num_classes
    # Overwrite dataset hparams with those sampled for the generation batch.
    new_hparams.update(ds_hparams)

    def get_component_ref(c, clone):
      """Returns a private copy of `c` if it is trainable or `clone` is set, else `c`."""
      if c.is_trainable() or clone:
        # Clone trainable component.
        return c.clone()
      # Refer to frozen component.
      return c

    if mutate:
      # Mutate model size and optimizer hparams, each with its own draw.
      for k in exp_config.models_mutation_ranges:
        if (policy.do_mutate() and
            (k in ['num_layers', 'adapter_dim']
             or k.startswith(OPTIMIZER_HPARAMS_KEYS_PRERFIX))):
          new_hparams[k] = incremental_mutation(
              new_hparams[k],
              exp_config.models_mutation_ranges[k])
      new_hparams['adapter_layers'] = mutate_adapters(
          exp_config.mutate_adapters,
          new_hparams['adapter_layers'],
          new_hparams['num_layers'],
          policy)

    # Freshly initialized params give the component names and shapes that
    # the child architecture requires.
    _, init_params = get_vit_model_and_params_mapped(
        **get_model_kwargs(new_hparams, exp_config),
        # Use Path.counter so it is deterministic if we rerun same experiment.
        rng_key=Path.counter)
    new_comp_names = params2comp_names(init_params)
    for new_comp_name in new_comp_names:
      comp = None
      # Attempt to reuse a matching component from the closest ancestor.
      ancestor = self
      while ancestor is not None:
        comps_lookup = {c.name:c for c in ancestor.components}
        if new_comp_name in comps_lookup:
          # Head must be trainable; if no ancestor is of the same task this
          # falls back to a random init of the correct shape.
          if new_comp_name == 'head' and not comps_lookup[new_comp_name].is_trainable():
            assert task.name != ancestor.task.name, f'{task.name} != {ancestor.task.name}'
            ancestor = ancestor.parent
            continue

          # Check shapes match otherwise skip.
          if jax.tree.map(jnp.shape, init_params[new_comp_name]) != jax.tree.map(jnp.shape, comps_lookup[new_comp_name].params):
            if new_comp_name == 'posembed_input':
              # Change of image size changed shape of position embeddings,
              # this can happen if ds_image_size is tuned,
              # continue searching through ancestors for matching size.
              assert 'ds_image_size' in exp_config.models_mutation_ranges
              assert new_hparams['ds_image_size'] != ancestor.hparams['ds_image_size']
              ancestor = ancestor.parent
              continue
            if new_comp_name.startswith('residual_adapter_'):
              # Change of adapter inner dimension changed shape of dense layers,
              # this can happen if adapter_dim is tuned,
              # continue searching through ancestors for matching size.
              assert 'adapter_dim' in exp_config.models_mutation_ranges
              assert new_hparams['adapter_dim'] != ancestor.hparams['adapter_dim']
              ancestor = ancestor.parent
              continue

            print(f'WARNING: Shapes do not match for component: {new_comp_name}  {ancestor.task.name}->{task.name}')
            print(jax.tree.map(jnp.shape, init_params[new_comp_name]))
            print(jax.tree.map(jnp.shape, comps_lookup[new_comp_name].params))
            assert False  # Should not happen in current configuration.

          # Reuse the ancestor component; clone it if selected for mutation.
          comp = get_component_ref(comps_lookup[new_comp_name],
                                   clone=mutate and policy.do_mutate(new_comp_name))
          break
        ancestor = ancestor.parent

      # Get reshaped posembed_input.
      if comp is None and new_comp_name == 'posembed_input':
        pe_comp = Path.posembed_components(
            image_size=new_hparams['ds_image_size'],
            query=exp_config.load_vit_checkpoint_query)
        comp = get_component_ref(pe_comp, clone=mutate and policy.do_mutate(new_comp_name))

      # Otherwise create one from random init params.
      if comp is None:
        if VERBOSE:
          print('Init:', new_comp_name)
        # Possible in current configuration.
        assert (new_comp_name == 'head'
                or new_comp_name.startswith('residual_adapter_'))
        comp = params2comps(init_params, train_locks=[], name=new_comp_name)[0]
      assert comp is not None
      comps.append(comp)

    rtn = Path(new_hparams, comps, parent=self, task=task)
    # Offsprings are counted only when the child is for the same task.
    if task == self.task:
      self.metrics['offsprings'] = self.metrics.get('offsprings', 0) + 1
    return rtn

  def get_optimizer(self):
    """Returns the cached optimizer matching the optimizer hparams of this path."""
    return Path.optimizers(
        lr=float(self.hparams['opt_lr']),
        lr_schedule=str(self.hparams['opt_lr_schedule']),
        lr_warmup_ratio=float(self.hparams['opt_lr_warmup_ratio']),
        momentum=float(self.hparams['opt_momentum']),
        nesterov=bool(self.hparams['opt_nesterov']),
        num_train_batches_between_validations=int(
            self.task.num_train_batches_between_validations),
        num_validations_per_path_training=int(
            self.task.exp_config.num_validations_per_path_training),
    )

In [None]:
def mutate_adapters(mutate, adapter_layers_ids, num_layers, policy, allow_removal=False):
  """Returns the (possibly mutated) adapter layer ids, encoded as a string.

  Args:
    mutate: Whether mutations are enabled at all.
    adapter_layers_ids: Current adapter layer ids, as decoded by
      `ids_str2ints`.
    num_layers: Number of layers of the model.
    policy: Object whose `do_mutate()` is drawn once per layer.
    allow_removal: If True, a mutation on a layer that already has an
      adapter removes it; otherwise such a mutation is a no-op.

  Returns:
    The ids re-encoded with `ids_ints2str`.
  """
  active = set(ids_str2ints(adapter_layers_ids))
  for layer in (range(num_layers) if mutate else ()):
    if not policy.do_mutate():
      continue
    if layer not in active:
      active.add(layer)
    elif allow_removal:
      active.remove(layer)
  # Drop adapters of layers dropped by a possible mutation in num_layers.
  kept = [layer for layer in active if layer < num_layers]
  return ids_ints2str(kept)

In [None]:
class Scorer():
  """Interface for path scorers; subclasses override `score`."""

  def score(self, path):
    """Returns the score of `path`; the base implementation always fails."""
    assert False, 'Not implemented'

class ScorerQuality(Scorer):
  """Scores a path by its 'quality' metric alone."""

  def score(self, path):
    """Returns the quality of `path`, or None if it is missing or NaN."""
    metrics = path.metrics
    if 'quality' not in metrics:
      return None
    if math.isnan(metrics['quality']):
      return None
    assert path.metrics['quality'] >= 0, \
        f'{path.task.name} {path.metrics["quality"]}'
    quality = path.metrics['quality']
    assert quality >= 0
    return quality

class ScorerDecay(Scorer):
  """Scores a path by quality, decayed by its accounted parameter count.

  score = quality * base ** (accounted_num_params / num_params)
  """

  def __init__(self, base, num_params):
    """Stores the decay settings; requires 0 < base <= 1 and num_params > 0."""
    self.base = base
    assert self.base > 0.0
    assert self.base <= 1.0
    self.num_params = num_params
    assert self.num_params > 0

  def score(self, path):
    """Returns the decayed score of `path`, or None if quality is missing or NaN."""
    metrics = path.metrics
    if 'quality' not in metrics:
      return None
    if math.isnan(metrics['quality']):
      return None
    assert path.metrics['quality'] >= 0, \
        f'{path.task.name} {path.metrics["quality"]}'
    quality = path.metrics['quality']
    size_ratio = path.accounted_num_params() / self.num_params
    decayed = quality * (self.base ** size_ratio)
    assert decayed >= 0
    return decayed

In [None]:
class PopulationPolicy():
  """Interface for population policies; subclasses override `sample_parent`."""

  def sample_parent(self, paths):
    """Returns a parent chosen from `paths`; the base implementation always fails."""
    assert False, 'Not implemented'

# Random parent sampling policy.
# WARNING: Not used recently, may need updates.
class PPRand(PopulationPolicy):
  """Policy that picks the parent uniformly at random."""

  def sample_parent(self, paths):
    """Returns one element of `paths` at a uniformly drawn index."""
    idx = np.random.randint(0, len(paths))
    return paths[idx]

# Tournament policy similar to https://arxiv.org/abs/1802.01548
# WARNING: Not used recently, may need updates.
class PPTournament(PopulationPolicy):
  """Tournament selection with a population capped at `max_size` paths.

  `reset()` must be called before sampling: it initializes `seed_paths_id`,
  which is not set by the constructor.
  """

  def __init__(self, subset_size, max_size, exp_config):
    """Stores the settings.

    Args:
      subset_size: Number of paths drawn (with replacement) per tournament.
      max_size: Maximum number of paths kept per task by `prune`.
      exp_config: Experiment configuration.
    """
    self.subset_size = subset_size
    self.max_size = max_size
    self.exp_config = exp_config

  def reset(self):
    """Restarts the iteration over seed paths."""
    self.seed_paths_id = 0

  def prune(self, paths):
    """Removes the lowest-scoring paths in place until at most `max_size` remain."""
    while len(paths) > self.max_size:
      # subset = np.random.choice(paths, self.subset_size, replace=True).tolist()
      minp = min(paths, key=lambda x: x.score())
      paths.remove(minp)
      print(f'REMOVED: {minp.id} {minp.metrics["quality"]:.2f}')
      assert minp not in paths
    return paths

  def do_mutate(self, comp_name=None):
    """Returns whether to mutate.

    Always True for components listed in
    `exp_config.force_finetune_components`; otherwise True with
    probability `exp_config.mutation_prob`.
    """
    if comp_name:
      if comp_name in self.exp_config.force_finetune_components:
        return True
    return self.exp_config.mutation_prob>np.random.uniform()

  def allow_mutations(self, pop):
    """Returns True once all seed paths of `pop` have been consumed."""
    return not self.seed_paths_id < len(pop.seed_paths)

  def sample_parent(self, paths):
    """Returns the best-scoring path of a random subset drawn with replacement."""
    subset = np.random.choice(paths, self.subset_size, replace=True).tolist()
    sampled = max(subset, key=lambda x: x.score())
    return sampled

  def sample_path(self, pop, task:Task, ds_hparams):
    """Samples a parent for `task` and returns a child cloned from it.

    Parent choice, in order: the next unused seed path; a random seed path
    while the task has at most one path; a random task path while the
    population is below `max_size`; otherwise a tournament winner.
    """
    # Prune population to max_size if necessary.
    pop.paths[task] = self.prune(pop.paths[task])
    parent = None
    mutate = self.allow_mutations(pop)
    if self.seed_paths_id < len(pop.seed_paths):
      # Seed paths are evaluated on the new task without mutations.
      assert mutate == False
      parent = pop.seed_paths[self.seed_paths_id]
      if VERBOSE:
        print('Seed path', parent.id, parent.task.name)
      self.seed_paths_id += 1
    else:
      assert mutate == True
    if parent is None and len(pop.paths[task]) <= 1:
      # This case is needed to fill the first batch.
      parent = random.choice(pop.seed_paths)
      if VERBOSE:
        print('Rand seed', parent.id, parent.task.name)
    if parent is None and len(pop.paths[task]) < self.max_size:
      parent = random.choice(pop.paths[task])
      if VERBOSE:
        print('Rand parent', parent.id, parent.task.name)
    if parent is None:
      parent = self.sample_parent(pop.paths[task])

    child = parent.clone(task, ds_hparams, self, mutate)

    # Store record of mutations.
    mutations = {}
    for k in child.hparams:
      if parent.hparams.get(k) != child.hparams[k]:
        mutations[k] = (parent.hparams.get(k), child.hparams[k])
    child.metrics['mutations'] = json.dumps(mutations)
    if mutations:
      print(child.id, child.metrics['mutations'])
    return child

  def sample_ds_hparams(self, pop, task:Task):
    """Returns the dataset hparams to use for the next paths of `task`.

    Starts from the defaults whose key has the dataset prefix, overrides
    them with the values of the best path of the task (if any) and, when
    mutations are allowed, mutates each key that has a mutation range.
    """
    mutate = self.allow_mutations(pop)
    assert pop.exp_config is self.exp_config
    ds_hparams = {}
    for key in self.exp_config.models_default_hparams:
      if key.startswith(DATASET_HPARAMS_KEYS_PRERFIX):
        ds_hparams[key] = self.exp_config.models_default_hparams[key]
    best_path = pop.get_best_path(task)
    if best_path:
      ds_hparams.update(
          {k : best_path.hparams[k] for k in ds_hparams if k in best_path.hparams})
    if mutate:
      for k in ds_hparams:
        if (k in self.exp_config.models_mutation_ranges
            and pop.policy.do_mutate()):
          ds_hparams[k] = incremental_mutation(
              ds_hparams[k],
              self.exp_config.models_mutation_ranges[k])
    return ds_hparams

In [None]:
# muNet decay policy.
class PPDecay(PopulationPolicy):
  """Parent selection whose probability decays with the offspring count.

  `reset()` must be called before sampling: it initializes `seed_paths_id`,
  which is not set by the constructor.
  """

  def __init__(self, exp_config):
    self.exp_config = exp_config

  def reset(self):
    """Restarts the iteration over seed paths."""
    self.seed_paths_id = 0

  def do_mutate(self, comp_name=None):
    """Returns whether to mutate.

    Always True for components listed in
    `exp_config.force_finetune_components`; otherwise True with
    probability `exp_config.mutation_prob`.
    """
    if comp_name:
      # Read the config of this policy rather than a module-level global.
      if comp_name in self.exp_config.force_finetune_components:
        return True
    return self.exp_config.mutation_prob>np.random.uniform()

  def allow_mutations(self, pop):
    """Returns True once all seed paths of `pop` have been consumed."""
    return not self.seed_paths_id < len(pop.seed_paths)

  def sample_parent(self, paths):
    """Returns a parent from `paths`, or None if no candidate is selected.

    Candidates are visited by decreasing score and each one is accepted
    with probability 0.5 ** offsprings.
    """
    sorted_paths = sorted(paths, key=lambda p: p.score(), reverse=True)
    sampled = None
    for path in sorted_paths:
      offsprings = path.metrics['offsprings']
      assert not math.isnan(offsprings)
      print('>>> considering', path.id, offsprings)
      if np.random.uniform() < 0.5 ** offsprings:
        print(f'selected', path.id)
        sampled = path
        break
    return sampled

  def sample_path(self, pop, task:Task, ds_hparams):
    """Samples a parent for `task` and returns a child cloned from it.

    Parent choice, in order: the next unused seed path (mutations are
    disabled for these); the result of `sample_parent`; a uniformly random
    path among the seed paths and the paths of the task.
    """
    parent = None
    mutate = self.allow_mutations(pop)
    if self.seed_paths_id < len(pop.seed_paths):
      assert mutate == False
      parent = pop.seed_paths[self.seed_paths_id]
      print('Seed path', parent.id, parent.task.name)
      self.seed_paths_id += 1
    else:
      assert mutate == True

    if not parent:
      parent = self.sample_parent(pop.paths[task])

    if not parent:
      # No candidate was accepted: fall back to a uniformly random pick.
      parent = np.random.choice(pop.seed_paths + pop.paths[task])
      print('>>> seed', parent.id)
    child = parent.clone(task, ds_hparams, self, mutate=mutate)

    # Store record of mutations.
    mutations = {}
    for k in child.hparams:
      if parent.hparams.get(k) != child.hparams[k]:
        mutations[k] = (parent.hparams.get(k), child.hparams[k])
    child.metrics['mutations'] = json.dumps(mutations)
    if mutations:
      print(child.id, child.metrics['mutations'])
    return child

  def sample_ds_hparams(self, pop, task:Task):
    """Returns the dataset hparams to use for the next paths of `task`.

    Starts from the defaults whose key has the dataset prefix, overrides
    them with the values of the best path of the task (if any) and, when
    mutations are allowed, mutates each key that has a mutation range.
    """
    mutate = self.allow_mutations(pop)
    assert pop.exp_config is self.exp_config
    ds_hparams = {}
    for key in self.exp_config.models_default_hparams:
      if key.startswith(DATASET_HPARAMS_KEYS_PRERFIX):
        ds_hparams[key] = self.exp_config.models_default_hparams[key]
    best_path = pop.get_best_path(task)
    if best_path:
      ds_hparams.update(
          {k : best_path.hparams[k] for k in ds_hparams if k in best_path.hparams})
    if mutate:
      for k in ds_hparams:
        if (k in self.exp_config.models_mutation_ranges
            and pop.policy.do_mutate()):
          ds_hparams[k] = incremental_mutation(
              ds_hparams[k],
              self.exp_config.models_mutation_ranges[k])
    return ds_hparams

In [None]:
# Baselines policy.
class PPBaseline(PopulationPolicy):
  """Policy for the non-evolutionary baselines.

  Every path is cloned from the single NOT_TRAINABLE root path and
  `mutation_prob` must be exactly 0.0 or 1.0.
  """

  def __init__(self, exp_config):
    self.exp_config = exp_config

  def reset(self):
    """No per-task state to reset."""
    return None

  def sample_parent(self, paths):
    assert False, 'Baselines should not reach evolutionary codepath.'

  def do_mutate(self, comp_name=None):
    """Returns whether to mutate, deterministically.

    Always True for components listed in
    `exp_config.force_finetune_components`; otherwise decided by
    `exp_config.mutation_prob`, which must be 0.0 or 1.0.
    """
    if comp_name:
      # Read the config of this policy rather than a module-level global.
      if comp_name in self.exp_config.force_finetune_components:
        return True
    if self.exp_config.mutation_prob == 0.0:
      return False
    elif self.exp_config.mutation_prob == 1.0:
      return True
    else:
      assert False, self.exp_config.mutation_prob

  def sample_path(self, pop, task:Task, ds_hparams):
    """Returns a child of the single root path, with mutations enabled."""
    assert len(pop.paths[not_trainable]) == 1
    parent = pop.paths[not_trainable][0]
    mutate = True
    child = parent.clone(task, ds_hparams, self, mutate)
    return child

  def sample_ds_hparams(self, pop, task:Task):
    """Returns the default hparams whose key has the dataset prefix."""
    ds_hparams = {}
    for key in self.exp_config.models_default_hparams:
      if key.startswith(DATASET_HPARAMS_KEYS_PRERFIX):
        ds_hparams[key] = self.exp_config.models_default_hparams[key]
    return ds_hparams

In [None]:
class Population():
  """Live paths per task, plus accumulated statistics and the policy.

  Attributes:
    paths: defaultdict mapping a task to the list of its live paths.
    paths_df / comps_df: DataFrames accumulating per-path and per-component
      stats at the end of every task iteration.
    policy: Instance of the class named by `exp_config.policy_class`.
  """

  def __init__(self, exp_config):
    self.paths = defaultdict(list)
    self.exp_config = exp_config
    self.paths_df = pd.DataFrame()
    self.comps_df = pd.DataFrame()
    # The policy class is looked up by name among the module globals.
    self.policy = globals()[exp_config.policy_class](
        **exp_config.policy_kwargs,
        exp_config=exp_config)

  @staticmethod
  def _append_rows(df, new_rows):
    """Returns a new frame with `new_rows` appended to `df` and a fresh index.

    Replacement for `DataFrame.append(..., ignore_index=True)`, which was
    removed in pandas 2.0.
    """
    if df.empty and len(df.columns) == 0:
      # Nothing accumulated yet: avoid concatenating with an empty frame.
      return new_rows.reset_index(drop=True)
    return pd.concat([df, new_rows], ignore_index=True)

  def get_best_path(self, task:Task):
    """Returns the best-scoring path of `task`, or None if it has no paths."""
    if len(self.paths[task]) == 0:
      return None
    return max(self.paths[task], key=lambda p: p.score())

  def sample_path(self, task:Task, ds_hparams):
    """Delegates to the policy to create a new path for `task`."""
    return self.policy.sample_path(pop=self, task=task, ds_hparams=ds_hparams)

  def sample_ds_hparams(self, task:Task):
    """Delegates to the policy to sample dataset hparams for `task`."""
    return self.policy.sample_ds_hparams(pop=self, task=task)

  def add_train_locks(self, task:Task):
    """Locks every component of the paths of `task` with the task name."""
    # Check.
    for ps in self.paths.values():
      for p in ps:
        for c in p.components:
          assert task.name not in c.train_locks
    # Add locks.
    paths = self.paths[task]
    for p in paths:
      for c in p.components:
        c.train_locks.add(task.name)

  def rm_train_locks(self, task:Task):
    """Removes the lock of `task` from the components of its paths."""
    # Remove locks.
    paths = self.paths[task]
    for p in paths:
      for c in p.components:
        if task.name in c.train_locks:
          c.train_locks.remove(task.name)
    # Check.
    for ps in self.paths.values():
      for p in ps:
        for c in p.components:
          assert task.name not in c.train_locks

  def set_seed_paths(self, task:Task):
    """Collects as seeds the paths of all other, non-private tasks."""
    self.seed_paths = []
    for paths in self.paths.values():
      for path in paths:
        if path.task is task:
          continue
        if path.task.is_private():
          continue
        self.seed_paths.append(path)
    # random.shuffle(self.seed_paths)
    # Deterministic ordering.
    self.seed_paths = sorted(self.seed_paths, key=lambda p: p.id, reverse=True)

  def start_task(self, task:Task):
    """Prepares a task iteration: seeds, policy reset and lock removal."""
    self.set_seed_paths(task)
    self.policy.reset()
    self.rm_train_locks(task)

  def end_task(self, task:Task):
    """Closes a task iteration.

    Keeps only the best path of `task`, locks its components, records the
    stats of all paths created since the last call and re-links parents so
    that the ancestor tree contains live paths only.
    """
    # Keep only best one.
    best_path = self.get_best_path(task)
    assert best_path is not None
    self.paths[task] = [best_path]

    # Add train locks.
    self.add_train_locks(task)

    # Store stats before dropping references to trigger garbage collection
    # of unused paths, components and parameters.
    self.paths_df = Population._append_rows(
        self.paths_df, paths_to_df(Path.paths))
    self.comps_df = Population._append_rows(
        self.comps_df, components_to_df(Path.paths))

    # Drop unused paths generated in this task iteration for garbage collection.
    Path.paths = []
    # Simplify ancestor tree to contain only live paths.
    live_paths_ids = {p.id for paths in self.paths.values() for p in paths}
    # Notice that the simplification is done also for paths of other tasks,
    # since they may be pointing to a path of this task that was just pruned.
    for path in [path for paths in self.paths.values() for path in paths]:
      ancestor = path.parent
      if ancestor is None:
        continue
      # Walk up until the closest live ancestor is found.
      while True:
        if ancestor.id in live_paths_ids:
          path.parent = ancestor
          break
        ancestor = ancestor.parent

In [None]:
# Pandas display settings for printed DataFrames: do not wrap wide frames
# across multiple blocks, and show up to 100 columns and 100 rows.
pd.set_option('display.expand_frame_repr', False)
pd.set_option('display.max_columns', 100)
pd.set_option('display.max_rows', 100)

def pop_to_df(pop):
  """Returns the `paths_to_df` table of all live paths of `pop`."""
  live_paths = []
  for task_paths in pop.paths.values():
    live_paths.extend(task_paths)
  return paths_to_df(live_paths)

def paths_to_df(paths):
  """Returns a DataFrame with one row per path.

  Columns: task_name, id, parent_id (-1 without parent), parent_task_name,
  final_accounted_params, components (ids joined by '_'), one
  `hparams.<key>` / `metrics.<key>` column per key seen in any path (None
  where a path lacks the key) and score.
  """
  # Union of the hparams and metrics names over all paths.
  metric_names, hparam_names = set(), set()
  for p in paths:
    metric_names.update(p.metrics)
    hparam_names.update(p.hparams)

  columns = defaultdict(list)
  for p in paths:
    parent = p.parent
    columns['task_name'].append(p.task.name)
    columns['id'].append(p.id)
    columns['parent_id'].append(parent.id if parent else -1)
    columns['parent_task_name'].append(parent.task.name if parent else None)
    columns['final_accounted_params'].append(p.accounted_num_params())
    columns['components'].append('_'.join(str(c.id) for c in p.components))
    for name in hparam_names:
      value = p.hparams[name] if name in p.hparams else None
      columns[f'hparams.{name}'].append(value)
    for name in metric_names:
      value = p.metrics[name] if name in p.metrics else None
      columns[f'metrics.{name}'].append(value)
    columns['score'].append(p.score())
  return pd.DataFrame(columns)

def components_to_df(paths):
  """Returns a DataFrame with one row per distinct component used by `paths`.

  Columns: id, name, num_params and train_locks (lock names joined by ',').
  """
  # Collect all components, de-duplicated across paths.
  unique_comps = set()
  for path in paths:
    unique_comps.update(path.components)

  columns = defaultdict(list)
  for comp in unique_comps:
    columns['id'].append(comp.id)
    columns['name'].append(comp.name)
    columns['num_params'].append(comp.get_num_params())
    columns['train_locks'].append(','.join(comp.train_locks))
  return pd.DataFrame(columns)

def df_leaderboard(df):
  """Prints the table of trainable-task paths followed by average metrics.

  Rows of the NOT_TRAINABLE task are dropped and the score/quality columns
  are moved right after the first column.
  """
  trainable = df.loc[df['task_name'] != NOT_TRAINABLE]
  # Place columns on the left for readability.
  ordered = trainable.columns.tolist()
  for key in ('metrics.test_quality', 'metrics.quality', 'score'):
    if key not in ordered:
      continue
    ordered.remove(key)
    ordered.insert(1, key)
  view = trainable[ordered]
  print(view)
  avg_score = view["score"].mean()
  print(f'Avg score:        {avg_score:.6f}')
  avg_quality = view["metrics.quality"].mean()
  print(f'Avg quality:      {avg_quality:.6f}')
  if 'metrics.test_quality' in view:
    avg_test_quality = view["metrics.test_quality"].mean()
    print(f'Avg test quality: {avg_test_quality:.6f}')

In [None]:
def prp(path):
  """Returns a printable description of `path`.

  With VERBOSE set: the path followed by each of its components, one per
  line block. Otherwise just the path id.
  """
  if not VERBOSE:
    return str(path.id)
  blocks = [str(path)]
  blocks.extend(str(c) for c in path.components)
  return '\n'.join(blocks)

In [None]:
def df_write_to_file(df, dir_path, df_name):
  """Writes `df` as CSV (without the index) to `<dir_path>/<df_name>.csv`."""
  csv_path = os.path.join(dir_path, f'{df_name}.csv')
  with tf.io.gfile.GFile(csv_path, 'w') as csv_file:
    df.to_csv(csv_file, index=False)

def df_read_from_file(dir_path, df_name,):
  """Reads `<dir_path>/<df_name>.csv` into a DataFrame.

  NaNs in object-typed columns are replaced with empty strings.

  Args:
    dir_path: Directory containing the file.
    df_name: File name without the '.csv' extension.

  Returns:
    The loaded DataFrame.
  """
  filename_df = os.path.join(dir_path, f'{df_name}.csv')
  with tf.io.gfile.GFile(filename_df, 'r') as infile:
    df = pd.read_csv(infile)
  # Pandas read_csv() reads empty strings as NaNs. Set NaNs to empty strings in
  # columns with type strings/object.
  for c in df.columns:
    if df[c].dtype == np.object_:
      # Assign the result back: an in-place fillna on `df[c]` is a chained
      # assignment and does not update `df` under pandas Copy-on-Write.
      df[c] = df[c].fillna('')
  return df

def checkpoint_save(experiment_dir:str, pop:Population, step=None):
  """Saves the params of every distinct live component as a flax checkpoint.

  The checkpoint target is a dict keyed by '<component name>:<component id>'.
  """
  unique_comps = {
      c for paths in pop.paths.values() for p in paths for c in p.components}
  comps_params = {f'{c.name}:{c.id}': c.params for c in unique_comps}
  flax_checkpoints.save_checkpoint(
      ckpt_dir=experiment_dir,
      target=comps_params,
      step=step)

In [None]:
def load_population_from_checkpoint(
    pop:Population,
    ckpt_dir:str,
    population_df,
    step=None):
  """Rebuilds components and paths from a checkpoint and adds them to `pop`.

  Args:
    pop: Population whose `paths` mapping receives the reloaded paths.
    ckpt_dir: Directory holding the checkpoint of the component parameters.
    population_df: DataFrame with one row per path. This function reads the
      columns `id`, `components` (component ids joined by '_'), `task_name`,
      `parent_id`, and every `hparams.*` / `metrics.*` column.
    step: Checkpoint step to restore, forwarded to `restore_checkpoint`.

  Side effects:
    Sets `Path.counter` and `Component.counter` to one past the largest
    reloaded id, and resets `Path.paths` to an empty list.
  """
  loaded_params = flax.core.freeze(
      flax_checkpoints.restore_checkpoint(
          ckpt_dir=ckpt_dir, target=None, step=step))
  id_2_comp = {}
  for key in loaded_params.keys():
    # Keys have the form '<name>:<id>'. Split on the last ':' so that a name
    # containing ':' is still parsed correctly.
    name, comp_id = key.rsplit(':', 1)
    comp = Component(name=name, params=loaded_params[key], train_locks=[])
    comp.id = int(comp_id)
    assert comp.id not in id_2_comp
    id_2_comp[comp.id] = comp
  # For parent assignment.
  id_2_path = {}
  path_2_parent_id = {}
  for _, row in population_df.iterrows():
    comps = [
        id_2_comp[int(comp_id)] for comp_id in row['components'].split('_')]
    task_name = row['task_name']
    if task_name == NOT_TRAINABLE:
      task = not_trainable
    else:
      task = Path.tasks(task_name=task_name)
    # Retrieve hparams and metrics.
    hparams = {}
    metrics = {}
    for k in row.keys():
      if k.startswith('hparams.'):
        hparams[k[len('hparams.'):]] = row[k]
      if k.startswith('metrics.'):
        metrics[k[len('metrics.'):]] = row[k]
    # A CSV round trip can turn the adapter_layers string into a float (NaN
    # for an empty string). isinstance() also matches float subclasses such
    # as numpy.float64.
    adapter_layers = hparams['adapter_layers']
    if isinstance(adapter_layers, float):
      if math.isnan(adapter_layers):
        hparams['adapter_layers'] = ''
      else:
        hparams['adapter_layers'] = str(int(adapter_layers))
    metrics['reloads'] = metrics['reloads'] + 1
    # Create path.
    path = Path(
        hparams=hparams,
        components=comps,
        parent=None,
        task=task,
        )
    path.metrics = metrics
    path.id = int(row['id'])
    # Add train locks.
    for comp in path.components:
      comp.train_locks.add(task_name)
    pop.paths[task].append(path)
    assert path.id not in id_2_path
    id_2_path[path.id] = path
    if task_name != NOT_TRAINABLE:
      path_2_parent_id[path] = int(row['parent_id'])

  # Set parents once every path exists, since a parent may appear later in the
  # DataFrame than its child.
  for path, parent_id in path_2_parent_id.items():
    path.parent = id_2_path[parent_id]
  Path.counter = 1 + max(id_2_path)
  Component.counter = 1 + max(id_2_comp)
  Path.paths = []

In [None]:
@partial(jax.jit, static_argnames='model')
def eval_step(params, images, labels, model):
  """Returns the average accuracy of `model` with `params` on one batch."""
  logits = model.apply({'params': params}, images, train=USE_DROPOUT)
  predictions = logits.argmax(axis=-1)
  is_correct = predictions == labels
  return is_correct.mean()

In [None]:
@partial(jax.jit, static_argnames=['model', 'optimizer'], donate_argnums=[0, 2])
def train_step(params, fixed_params, opt_state, images, labels, model, optimizer):
  """Applies one optimizer update to `params` on a single batch.

  Gradients are taken only with respect to `params`; `fixed_params` are merged
  in through format_params() to run the model. Arguments 0 and 2 (`params` and
  `opt_state`) are donated to jit, so callers should use the returned values.

  Returns:
    A tuple (updated params, updated optimizer state).
  """
  def cross_entropy(trainable, frozen, inputs, targets):
    variables = {'params': format_params(trainable, frozen)}
    logits = model.apply(variables, inputs, train=USE_DROPOUT)
    one_hot_targets = jax.nn.one_hot(targets, logits.shape[-1])
    log_likelihood = jnp.sum(one_hot_targets * nn.log_softmax(logits), axis=-1)
    return -jnp.mean(log_likelihood)

  grad_fn = jax.grad(cross_entropy)
  grads = grad_fn(params, fixed_params, images, labels)
  updates, new_opt_state = optimizer.update(grads, opt_state, params=params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state

In [None]:
# Time at which the previous train_loop() call finished (initially the time
# this cell was run). Used to measure the time spent between training loops.
LOOP_START = time.time()

def train_loop(paths, ds_train, ds_validation, devices, exp_config):
  """Trains and periodically evaluates a batch of paths on the same task.

  All paths must share the same task and `ds_image_size` (asserted). Every
  training batch read from `ds_train` is fed to each path in turn. Results are
  stored in `path.metrics`; when a path improved on its starting score, its
  trainable parameters are replaced with the best ones found.

  Args:
    paths: Paths to train. Each is assigned a device round-robin.
    ds_train: Iterable of training batches with 'image' and 'label' entries.
    ds_validation: Iterable of validation batches with 'image' and 'label'
      entries.
    devices: Devices the paths are distributed over.
    exp_config: Experiment config; `num_validations_per_path_training` is read.

  Side effects:
    Updates the global LOOP_START and prints progress and timing tables.
  """
  global LOOP_START
  timing = {'start_time': time.time(),
            'start_time_loop': LOOP_START}
  task = paths[0].task
  # The following values should be shared by all paths in this generation batch.
  for path in paths:
    assert task == path.task
    assert paths[0].hparams['ds_image_size'] == path.hparams['ds_image_size']

  for p_id, path in enumerate(paths):
    if VERBOSE:
      print('Parent')
      print(prp(path.parent))
      print(prp(path))
    # Round-robin assignment of paths to devices.
    path.device = devices[p_id % len(devices)]
    path.optimizer = path.get_optimizer()
    path.optimizer_init_fn = jax.jit(
        path.optimizer.init,
        device=path.device)
    path.best_params_local = None
    path.best_quality = None
    # A path on the same task as its parent must beat the parent's score to be
    # recorded as an improvement; a path on a different task starts from -inf.
    path.best_score = path.parent.score() if path.task is path.parent.task else -np.inf
    path.evals = []

    # Launch parallel compilation of eval and train step functions.
    params_local = path.get_trainable_params()
    path.compile_params_device = jax.device_put(params_local, path.device)
    path.compile_fixed_params_device = jax.device_put(
        path.get_fixed_params(),
        path.device)
    # The train thread is only created here; it is started below once the
    # eval thread of the same path has been joined.
    path.compile_train = Thread(
        target=train_step,
        args=(path.compile_params_device,
              path.compile_fixed_params_device,
              path.optimizer_init_fn(params_local),
              get_sample_image(
                  image_size=path.hparams['ds_image_size'],
                  batch_size=task.train_batch_size),
              get_sample_label(
                  batch_size=task.train_batch_size),
              path.model,
              path.optimizer))
    path.compile_eval = Thread(
        target=eval_step,
        args=(
            format_params(
                path.compile_params_device,
                path.compile_fixed_params_device),
            get_sample_image(
                image_size=path.hparams['ds_image_size'],
                batch_size=task.validation_batch_size),
            get_sample_label(
                batch_size=task.validation_batch_size),
            path.model))
    path.compile_eval.start()

  for path in paths:
    path.compile_eval.join()
    del path.compile_eval
    # Overwritten on every iteration: the final value is the time at which the
    # last eval thread was joined.
    timing['end_compile_eval'] = time.time()
    path.compile_train.start()

  # A single validation iterator is shared by all eval rounds, so each round
  # continues from where the previous one stopped.
  iter_ds_validation = iter(ds_validation)
  # TRAIN
  # zip() stops at the shorter of the step budget and the training dataset.
  for t_step, batch in zip(
      range(exp_config.num_validations_per_path_training
            * task.num_train_batches_between_validations),
      ds_train,
  ):
    for p_id, path in enumerate(paths):
      if t_step == 0:
        # Wait for the train step compilation of this path, then move the
        # actual parameters and a fresh optimizer state to its device.
        path.compile_train.join()
        del path.compile_train
        del path.compile_params_device
        del path.compile_fixed_params_device
        timing['end_compile'] = time.time()
        path.params_device = jax.device_put(
            path.get_trainable_params(),
            path.device)
        path.fixed_params_device = jax.device_put(
            path.get_fixed_params(),
            path.device)
        path.opt_state_device = path.optimizer_init_fn(path.params_device)
        t_step_0_time = time.time()

      path.params_device, path.opt_state_device = train_step(
          path.params_device,
          path.fixed_params_device,
          path.opt_state_device,
          batch['image'],
          batch['label'],
          path.model,
          path.optimizer)
      if t_step == 0 and time.time() - t_step_0_time > 3 and p_id > 3:
        # Notice first step or first paths may overlap with compilation joined
        # in the first step of later paths, so this may fire at times.
        print(f'WARNING: First train step took: {time.time()-t_step_0_time:.2f} s')

    # EVAL
    if (t_step+1) % task.num_train_batches_between_validations == 0:
      first_eval = ((t_step+1) == task.num_train_batches_between_validations)
      if first_eval:
        timing['start_eval'] = time.time()
      for path in paths:
        path.accs = []
      for e_step, batch in zip(
          range(task.num_validation_batches),
          iter_ds_validation,
          ):
        for p_id, path in enumerate(paths):
          if first_eval and e_step == 0:
            e_step_0_time = time.time()
          path.accs.append(
              eval_step(
                  format_params(path.params_device, path.fixed_params_device),
                  batch['image'],
                  batch['label'],
                  path.model))
          if first_eval and e_step == 0 and time.time() - e_step_0_time > 1:
            print(f'WARNING: First eval step took: {time.time()-e_step_0_time:.2f} s')

      qs = []
      eval_idx = (t_step+1) // task.num_train_batches_between_validations
      for path in paths:
        # Unweighted mean of the per-batch accuracies.
        quality = np.mean(path.accs)
        del path.accs
        qs.append(f'{quality:.4f}')
        path.evals.append(quality)
        # Set quality in metrics for current score computation.
        path.metrics['quality'] = quality
        path_score = path.score()
        if path_score > path.best_score:
          # Keep a host-side copy of the best parameters seen so far.
          path.best_params_local = jax.device_get(path.params_device)
          path.best_score = path_score
          path.best_quality = quality
          # '*' marks a new best score for this path in the printed table.
          qs[-1] += '*'
      train_time = time.time() - timing['end_compile']
      avg_path_time = (train_time / eval_idx) / len(paths)
      print(('\t'.join(qs) + f'\t< Eval {eval_idx}').expandtabs(8),
            f'tot:{train_time:.1f}s', f'avg/path:{avg_path_time:.1f}s')

      if first_eval:
        timing['end_eval'] = time.time()

  # Drop references to device buffers and optimizers, which are no longer
  # needed after training.
  for path in paths:
    del path.params_device
    del path.fixed_params_device
    del path.opt_state_device
    del path.optimizer
    del path.optimizer_init_fn

  timing['end_train'] = time.time()

  # loop_time is the time elapsed between the end of the previous train_loop()
  # call (or cell execution) and the start of this one.
  loop_time = timing['start_time'] - LOOP_START
  compile_time = timing['end_compile'] - timing['start_time']
  compile_eval_time = timing['end_compile_eval'] - timing['start_time']
  compile_train_time = timing['end_compile'] - timing['end_compile_eval']
  train_time = timing['end_train'] - timing['end_compile']
  # Duration of the first eval round only.
  eval_time = timing['end_eval'] - timing['start_eval']
  LOOP_START = time.time()

  for path in paths:
    path.metrics['loop_time'] = loop_time
    path.metrics['compile_time'] = compile_time
    path.metrics['train_time'] = train_time
    path.metrics['eval_time'] = eval_time
    path.metrics['start_time'] = timing['start_time']
    path.metrics['start_time_loop'] = timing['start_time_loop']
    path.metrics['end_time'] = time.time()
    num_all_params = get_num_params(path.get_all_params())
    num_trainable_params = get_num_params(path.get_trainable_params())
    path.metrics['trainable_params_ratio'] = num_trainable_params/num_all_params
    path.metrics['num_trainable_params'] = num_trainable_params
    # Final quality is the best validation quality over all eval rounds.
    path.metrics['quality'] = max(path.evals)
    path.metrics['evals'] = json.dumps([float(v) for v in path.evals])
    path.metrics['training_accounted_params'] = path.accounted_num_params()
    path.metrics['training_score'] = path.score()

    # NOTE(review): this is a truthiness test, so an empty parameter container
    # would take the "not improved" branch and trip the asserts below.
    if path.best_params_local:
      path.metrics['improved'] = True
      path.update_trainable(path.best_params_local)
      assert path.best_quality == path.metrics['quality']
      assert path.best_score == path.metrics['training_score']
    else:
      path.metrics['improved'] = False
      # Path will be early pruned if not an improvement, so skip parameters update.
      assert path.best_params_local == None
      assert path.best_quality == None

    # Remove the temporary tracking attributes set at the top of this function.
    del path.best_params_local
    del path.best_score
    del path.best_quality
    del path.evals

    if VERBOSE:
      print('UPDATED:')
      print(prp(path))

  # Build the summary table: one column per path.
  pqs = []  # Parent quality.
  qs = []   # Path quality.
  psc = []  # Parent score.
  sc = []   # Path score.
  for path in paths:
    if path.task is path.parent.task:
      pqs.append(f'{path.parent.metrics["quality"]:.4f}')
      psc.append(f'{path.parent.score():.4f}')
    else:
      # The parent belongs to a different task, so its quality and score are
      # not shown.
      pqs.append('NEW')
      psc.append('NEW')
    qs.append(f'{path.metrics["quality"]:.4f}')
    sc.append(f'{path.score():.4f}')
    if path.metrics['improved']:
      # '+' marks paths that improved on their starting score.
      sc[-1] += '+'

  print(('\t'.join([f'{path.parent.id}' for path in paths]) +
        '\t< Parent id').expandtabs(8))
  print(('\t'.join([f'{path.id}' for path in paths]) +
        '\t< Path id').expandtabs(8))
  print(('\t'.join(pqs) + '\t< Parent best quality').expandtabs(8))
  print(('\t'.join(qs) + '\t< Path best quality').expandtabs(8))
  print(('\t'.join(psc) + '\t< Parent score').expandtabs(8))
  print(('\t'.join(sc) + '\t< Path score').expandtabs(8))

  print('time\tINIT\tCOMPevl\tCOMPtrn\tTRN+EVL\t1stEVAL'.expandtabs(8))
  print(f'(s)\t{loop_time:.1f}\t{compile_eval_time:.1f}\t{compile_train_time:.1f}\t{train_time:.1f}\t{eval_time:.1f}'.expandtabs(8))

In [None]:
# Run a full paths sampling iteration for a task.
def task_iter(task, devices, pop:Population, exp_config:FrozenConfigDict):
  """Samples, trains and selects new paths for `task`.

  Runs ceil(num_samples_per_task / num_devices) generation batches. Each batch
  samples one path per device, trains them together with train_loop(), and
  adds the paths marked as improved to `pop.paths[task]`.

  Args:
    task: Task to generate paths for.
    devices: Devices used for training; one path is sampled per device.
    pop: Population that is sampled from and extended in place.
    exp_config: Experiment config; `num_samples_per_task` is read here.
  """
  num_devices = len(devices)
  # Track best path.
  best_path = pop.get_best_path(task)
  num_gen_batches = math.ceil(exp_config.num_samples_per_task/num_devices)
  for generation_batch_id in range(num_gen_batches):
    print('----')
    print(f'GENERATION: [{generation_batch_id+1}/{num_gen_batches}]')
    ds_hparams = pop.sample_ds_hparams(task)
    ds_train = task.get_ds('train', ds_hparams)
    ds_validation = task.get_ds('validation', ds_hparams)
    paths = [pop.sample_path(task, ds_hparams) for _ in range(num_devices)]
    train_loop(paths, ds_train, ds_validation, devices, exp_config)
    for path in paths:
      if path.metrics['improved']:
        # Check against the task's list of paths. `pop.paths` itself is keyed
        # by task, so testing membership in it could never fail.
        assert path not in pop.paths[task]
        pop.paths[task].append(path)
    # Track best path.
    curr_best_path = pop.get_best_path(task)
    if curr_best_path != best_path:
      if best_path:
        assert curr_best_path.score() >= best_path.score()
      best_path = curr_best_path
      best_path.metrics['new_best'] = True
      print(f'Best id:{best_path.id}',
            f'score:{best_path.score():.4f}',
            f'quality:{best_path.metrics["quality"]:.4f}',
            f'gen:{generation_batch_id}',
            f'\n{best_path.hparams}')
  assert best_path in pop.paths[task]

In [None]:
TEST_MODELS_IMMUTABILITY = False

# Run final eval on test set.
def run_test_eval(path, ds_test):
  """Computes the accuracy of `path` on `ds_test` and stores it in its metrics.

  The result is written to `path.metrics['test_quality']`. If a value is
  already stored there, the new one is asserted to be close to it.
  """
  # Running on same device should allow to reuse the fn compiled for validation
  # if batch size matches.
  all_params = path.get_all_params()
  params_device = jax.device_put(params_comps_to_model(all_params), path.device)
  weighted_accs = []
  num_samples = 0
  # Warning: if repeat() is called on this dataset, then this loop never ends.
  for batch in ds_test:
    images = batch['image']
    batch_acc = jax.device_get(
        eval_step(params_device, images, batch['label'], path.model))
    num_in_batch = images.shape[0]
    # Weight each batch accuracy by its size: the last batch can have a
    # different size to allow for exact eval on the test set.
    weighted_accs.append(batch_acc * num_in_batch)
    num_samples += num_in_batch
  del params_device
  test_quality = np.sum(weighted_accs) / num_samples
  if 'test_quality' in path.metrics:
    assert np.isclose(path.metrics['test_quality'], test_quality), \
        f'{path.task.name} {path.metrics["test_quality"]} {test_quality}'
  path.metrics['test_quality'] = test_quality

def run_all_test_evals(pop):
  """Runs the test eval of the trainable paths of `pop`, one thread per path.

  Paths that already have a `test_quality` metric are skipped unless
  TEST_MODELS_IMMUTABILITY is set. Returns after all threads have finished.
  """
  trainable_paths = [
      candidate
      for task_paths in pop.paths.values()
      for candidate in task_paths
      if candidate.is_trainable()
  ]
  workers = []
  for path in trainable_paths:
    already_evaluated = 'test_quality' in path.metrics
    if already_evaluated and not TEST_MODELS_IMMUTABILITY:
      continue
    ds_test = path.task.get_ds('test', path.hparams)
    worker = Thread(target=run_test_eval, args=(path, ds_test))
    worker.start()
    workers.append(worker)
  for worker in workers:
    worker.join()

In [None]:
def reset_globals(exp_config):
  """Resets the class-level state of Path (using `exp_config`) and Component."""
  Path.reset_globals(exp_config)
  Component.reset_globals()

In [None]:
def init_population(exp_config:FrozenConfigDict, continue_exp:bool):
  """Creates the population for an experiment.

  Args:
    exp_config: Experiment configuration.
    continue_exp: If True, the state saved in `exp_config.experiment_dir` is
      reloaded and returned as is, without adding seed models.

  Returns:
    The initialized Population.
  """
  reset_globals(exp_config)

  scorer_cls = globals()[exp_config.scorer_class]
  Path.scorer = scorer_cls(**exp_config.scorer_kwargs)
  pop = Population(exp_config=exp_config)

  def restore_from(source_dir):
    # Reload dataframes and model parameters saved by a previous run.
    pop.paths_df = df_read_from_file(source_dir, df_name='paths')
    pop.comps_df = df_read_from_file(source_dir, df_name='components')
    reloaded_population_df = df_read_from_file(
        source_dir, df_name='population')
    load_population_from_checkpoint(pop, source_dir, reloaded_population_df)
    print('Loaded models from', source_dir, ':')
    df_leaderboard(pop_to_df(pop))
    # Continue the id sequences after the largest ids ever recorded.
    Path.counter = 1 + pop.paths_df['id'].max()
    Component.counter = 1 + pop.comps_df['id'].max()

  # Load population from previous experiment.
  if continue_exp:
    restore_from(exp_config.experiment_dir)
    return pop
  if exp_config.load_experiment:
    restore_from(exp_config.load_experiment_dir)

  # Add new seed models, if any is requested.
  if not (exp_config.load_rand_init or exp_config.load_vit_checkpoint):
    return pop

  hparams = exp_config.models_default_hparams.as_configdict()
  if exp_config.load_rand_init:
    # Seed with a randomly initialized model.
    _, rand_init_params = get_vit_model_and_params_mapped(
        **get_model_kwargs(hparams, exp_config))
    rand_init_comps = params2comps(
        rand_init_params, train_locks=[NOT_TRAINABLE])
    rand_init_path = Path(
        hparams, rand_init_comps, parent=None, task=not_trainable)
    pop.paths[not_trainable].append(rand_init_path)
  if exp_config.load_vit_checkpoint:
    # Seed with a model loaded from checkpoint.
    ckpt_params = get_vit_checkpoint_mapped(
        hparams['ds_image_size'],
        exp_config.load_vit_checkpoint_query)
    ckpt_comps = params2comps(ckpt_params, train_locks=[NOT_TRAINABLE])
    ckpt_path = Path(hparams, ckpt_comps, parent=None, task=not_trainable)
    pop.paths[not_trainable].append(ckpt_path)

  return pop

In [None]:
# Experiment setup.
def continue_exp(exp_dir):
  """Reloads the config and population of an existing experiment.

  Args:
    exp_dir: Directory of the experiment to continue. It must contain
      `config.json` and at least one checkpoint.

  Returns:
    A tuple (population, experiment config, loop id), where the loop id is the
    step number parsed from the latest checkpoint file name.
  """
  # Load configs.
  print('CONTINUING EXISTING EXPERIMENT:', exp_dir)
  load_config_dict_file = os.path.join(exp_dir, 'config.json')
  # Use a context manager so the config file handle is closed after reading.
  with tf.io.gfile.GFile(load_config_dict_file, 'r') as config_file:
    exp_config = FrozenConfigDict(json.load(config_file))
  pop = init_population(exp_config, continue_exp=True)
  # Get loop_id from checkpoint file name.
  checkpoint_path = flax_checkpoints.latest_checkpoint(exp_dir)
  assert checkpoint_path is not None, f'No checkpoint found in {exp_dir}'
  matched = re.findall(r'checkpoint_([0-9]+)$', checkpoint_path)
  assert len(matched)==1
  loop_id = int(matched[0])
  print('FROM CHECKPOINT:', loop_id)
  assert exp_config.experiment_dir == exp_dir
  return pop, exp_config, loop_id

def setup_new_experiment(exp_config):
  """Finalizes `exp_config` and creates the population of a new experiment.

  The experiment id is the experiment name followed by a timestamp, and the
  experiment directory is that id under `experiments_root_dir`.

  Returns:
    A tuple (population, frozen experiment config, 0), 0 being the loop id a
    new experiment starts from.
  """
  # Finalize config.
  timestamp = datetime.datetime.now().strftime(':%Y-%m-%d-%H-%M-%S')
  exp_config.experiment_id = exp_config.experiment_name + timestamp
  exp_config.experiment_dir = os.path.join(
      exp_config.experiments_root_dir, exp_config.experiment_id)
  frozen_config = FrozenConfigDict(exp_config)
  pop = init_population(frozen_config, continue_exp=False)
  print('NEW EXPERIMENT:', frozen_config.experiment_dir)
  return pop, frozen_config, 0

In [None]:
def setup_exp():
  """Builds the experiment selected by the notebook-level settings.

  Reads the globals BENCHMARK, CONFIGURATION, AUTO_TUNE and AUTO_CONTINUE. The
  experiment name is extended with one suffix per selected option. With
  AUTO_CONTINUE set, an existing experiment directory whose name starts with
  the resulting experiment name is continued instead of starting a new one.

  Returns:
    A tuple (population, experiment config, loop id).
  """
  if BENCHMARK == 'ViT tiny 3 layers / characters benchmark':
    exp_config = get_exp_config_ti3_chars()
    exp_config.experiment_name += ':t3-chars'
  elif BENCHMARK == 'ViT base / decathlon benchmark':
    exp_config = get_exp_config_base_deca()
    exp_config.experiment_name += ':b-deca'
  elif BENCHMARK == 'ViT large / ViT benchmark':
    exp_config = get_exp_config_large()
    exp_config.experiment_name += ':l-vit'
  else:
    assert False, BENCHMARK

  if AUTO_TUNE:
    # Auto-tuning is only supported by the muNet configurations.
    assert CONFIGURATION == 'muNet' or CONFIGURATION.startswith('Size scale:')
    exp_config.experiment_name += ':autotune'
    exp_config = exp_config_add_auto_tune(exp_config)

  if CONFIGURATION == 'Finetune all':
    exp_config = exp_config_set_baseline_finetune_all(exp_config)
    exp_config.experiment_name += ':finetune'
  elif CONFIGURATION.startswith('Freeze bottom layers'):
    num_layers = int(CONFIGURATION.split(':')[1])
    exp_config = exp_config_set_baseline_freeze_bottom_layers(
        exp_config, num_layers)
    exp_config.experiment_name += f':freeze{num_layers}'
  elif CONFIGURATION.startswith('Adapters:'):
    adapter_dim = int(CONFIGURATION.split(':')[1])
    exp_config = exp_config_set_baseline_adapters(exp_config, adapter_dim)
    exp_config.experiment_name += f':adapters{adapter_dim}'
  elif CONFIGURATION.startswith('Size scale:'):
    base_percent = int(CONFIGURATION.split(':')[1])
    exp_config = exp_config_set_size_scale(exp_config, base_percent)
    exp_config.experiment_name += f':size{base_percent}'
  elif CONFIGURATION == 'muNet':
    exp_config.experiment_name += ':munet'
  else:
    assert False, CONFIGURATION

  if AUTO_CONTINUE:
    exp_dir_prefix = os.path.join(exp_config.experiments_root_dir,
                                  exp_config.experiment_name)
    # glob() lists the existing paths matching the pattern. GFile() would
    # instead create a file object, which has no length and cannot be indexed.
    matching_dirs = tf.io.gfile.glob(exp_dir_prefix + '*')
    assert len(matching_dirs) < 2, \
        f'Multiple dirs matched for auto restart {matching_dirs}'
    if len(matching_dirs) == 1:
      print('AUTO CONTINE')
      return continue_exp(matching_dirs[0])

  return setup_new_experiment(exp_config)

In [None]:
# Main loop over tasks.
# Each loop runs one task iteration, then runs the test evals, writes a
# checkpoint with the data needed to resume, and prints the leaderboard.
pop, exp_config, loop_id = setup_exp()

devices = jax.local_devices()
print('DEVICE COUNT:', len(devices))
num_tasks = len(exp_config.task_names)
num_loops = exp_config.num_task_iters * num_tasks
for _ in range(num_loops):
  # loop_id can start above 0 when continuing an existing experiment.
  if loop_id >= num_loops:
    break
  t_i = loop_id // num_tasks
  # Tasks are visited in round-robin order.
  task_idx = loop_id % num_tasks
  task_name = exp_config.task_names[task_idx]
  print('\n\n====')
  print(f'LOOP: [{loop_id+1}/{exp_config.num_task_iters * num_tasks}]')
  print(f'TASK: {task_name}')
  task = Path.tasks(task_name=task_name)
  pop.start_task(task)
  task_iter(task, devices, pop, exp_config)
  pop.end_task(task)
  loop_id += 1

  end_loop_st = time.time()
  # Run test evals.
  run_all_test_evals(pop)
  pop_df = pop_to_df(pop)
  # Save data needed to resume exp.
  start_write = time.time()
  print('WRITING CHECKPOINT:', loop_id)
  if loop_id == 1:
    # First loop of a new experiment: create the directory and store the
    # config. The context manager makes sure the file is flushed and closed.
    tf.io.gfile.makedirs(exp_config.experiment_dir)
    config_path = os.path.join(exp_config.experiment_dir, 'config.json')
    with tf.io.gfile.GFile(config_path, 'wb') as config_file:
      json.dump(exp_config.as_configdict().to_dict(), config_file, indent=2)
  checkpoint_save(exp_config.experiment_dir, pop, step=loop_id)
  df_write_to_file(pop_df, exp_config.experiment_dir, 'population')
  df_write_to_file(pop.paths_df, exp_config.experiment_dir, 'paths')
  df_write_to_file(pop.comps_df, exp_config.experiment_dir, 'components')
  print(f'TEST EVAL TIME: {start_write - end_loop_st:.2f} s')
  print(f'WRITE TIME: {time.time() - start_write:.2f} s')
  # Display stats.
  df_leaderboard(pop_df)
  avg_time_per_sample = (
      pop.paths_df['metrics.end_time'].mean() \
          - pop.paths_df['metrics.start_time_loop'].mean()
      ) / len(devices)
  print(f'Avg time per path: {avg_time_per_sample:.2f} s')