# Fiddle `select()` and Placeholder APIs

_Please see https://github.com/google/fiddle/blob/main/docs/colabs.md for more colabs._

The `select()` and `Placeholder` APIs for Fiddle let users concisely change many values in a larger configuration structure.

The `select()` API makes it easy to set parameters across all occurrences of specific functions or classes within a config. For example:

```python
# Set all Dropout classes to have rate 0.1.
select(root_cfg, nn.Dropout).set(rate=0.1)
```

Placeholders are values that are tagged with a key, making it easy to set values that are shared in many places all at once. For example:

```python
# Set all tagged dtypes, which may be on different functions/classes.
fdl.set_placeholder(root_cfg, dtype_key, jnp.bfloat16)
```

Both APIs make it easy to factor a configuration into a declaration of the base model (say, `base_model.py`) and several experiment override files (say, `my_experiment_1.py`, `my_experiment_2.py`). The override files set values tagged by placeholders, or modify specific functions or classes using `select()`.

In [None]:
!pip install fiddle-config


import fiddle as fdl
from fiddle import graphviz
from fiddle import printing
from fiddle.experimental import selectors
import fiddle.extensions.jax

fiddle.extensions.jax.enable()  # Nicer printout for JAX types; non-essential.

## Running example

Let's first consider a simple structure of Flax modules. `AddRange` adds a range
(e.g. the array `[0, 1, 2, 3]`) to its input, and `AddTwoRanges` applies two `AddRange` modules in sequence.

In [None]:
from typing import Any

from flax import linen as nn
from jax import numpy as jnp


class AddRange(nn.Module):
  start: int
  stop: int
  dtype: Any

  def __call__(self, x):
    return x + jnp.arange(self.start, self.stop, dtype=self.dtype)


class AddTwoRanges(nn.Module):
  add_range_1: AddRange
  add_range_2: AddRange

  def __call__(self, x):
    return self.add_range_2(self.add_range_1(x))


cfg = fdl.Config(AddTwoRanges)
cfg.add_range_1 = fdl.Config(AddRange, 0, 4, jnp.float32)
cfg.add_range_2 = fdl.Config(AddRange, 0, 4, jnp.float32)
graphviz.render(cfg)

This model can be run as follows (see [this colab](https://colab.sandbox.google.com/github/google/flax/blob/master/docs/notebooks/linen_intro.ipynb) for an introduction to Flax APIs):

In [None]:
model = fdl.build(cfg)
model.apply({}, jnp.array([1, 2, 1, 2]))

## `select()` API

To make widespread modifications of a configuration easier, we provide a simple tool that selects nodes across the configuration DAG so that new values can be set on them.

The main entry point is `select()`. It currently takes a root config and a
function or class to select, and returns a `Selection` object:

In [None]:
selectors.select(cfg, AddRange)

This `Selection` object supports iteration over the selected nodes:

In [None]:
list(selectors.select(cfg, AddRange))
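
Conceptually, selecting by function or class is a walk over the configuration DAG. The toy, pure-Python sketch below illustrates that idea. It is not Fiddle's implementation: `ToyConfig`, `toy_select`, `add_range`, and `add_two_ranges` are hypothetical names. The sketch matches nodes by callable identity and, as a choice made for this sketch, yields a node that is referenced from several places only once.

```python
from typing import Any, Callable, Dict, Iterator


class ToyConfig:
  """Hypothetical stand-in for `fdl.Config`: a callable plus its arguments."""

  def __init__(self, fn_or_cls: Callable[..., Any], **arguments: Any):
    self.fn_or_cls = fn_or_cls
    self.arguments: Dict[str, Any] = dict(arguments)


def toy_select(root: Any, fn_or_cls: Callable[..., Any]) -> Iterator[ToyConfig]:
  """Toy selector: yields each ToyConfig under `root` calling `fn_or_cls`."""
  seen = set()  # ids of visited configs, so a shared sub-config is yielded once

  def visit(value: Any) -> Iterator[ToyConfig]:
    if isinstance(value, ToyConfig):
      if id(value) in seen:
        return
      seen.add(id(value))
      if value.fn_or_cls is fn_or_cls:
        yield value
      for child in value.arguments.values():
        yield from visit(child)
    elif isinstance(value, (list, tuple)):
      for child in value:
        yield from visit(child)

  yield from visit(root)


# Hypothetical callables standing in for the Flax modules above.
def add_range(start, stop):
  return list(range(start, stop))


def add_two_ranges(first, second):
  return first + second


r1 = ToyConfig(add_range, start=0, stop=4)
r2 = ToyConfig(add_range, start=0, stop=4)
root = ToyConfig(add_two_ranges, first=r1, second=r2)
print(len(list(toy_select(root, add_range))))  # 2

# Referencing the same node twice still selects it only once.
shared_root = ToyConfig(add_two_ranges, first=r1, second=r1)
print(len(list(toy_select(shared_root, add_range))))  # 1
```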

Say we want an integer version of our model. Because the `arange` calls take a `dtype` (hyper)parameter, just passing integer inputs doesn't work: JAX promotes the integer operand of the addition to `float32`, so the output stays floating point:

In [None]:
# Notice that the output dtype is float32.
model.apply({}, jnp.array([1, 2, 1, 2], dtype=jnp.int32))

We can use the `select()` API to set the dtype of both `AddRange` configs to `int32`:

In [None]:
selectors.select(cfg, AddRange).dtype = jnp.int32
graphviz.render(cfg)

and the rebuilt model now correctly produces an integer output:

In [None]:
model = fdl.build(cfg)
model.apply({}, jnp.array([1, 2, 1, 2], dtype=jnp.int32))

The selection also has a `set()` shorthand for setting multiple values at once:

In [None]:
selectors.select(cfg, AddRange).set(start=1, stop=10)
graphviz.render(cfg)

An API to get all values of a particular field is also provided; this may be
useful for unit testing or debugging:

In [None]:
list(selectors.select(cfg, AddRange).get('dtype'))
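
The `set()` and `get()` conveniences can be thought of as loops over the selected nodes. Here is a toy sketch of that behavior using hypothetical `ToyConfig` and `ToySelection` classes (not Fiddle's `Selection`), with strings standing in for dtypes:

```python
from typing import Any, Callable, Dict, Iterator


class ToyConfig:
  """Hypothetical stand-in for `fdl.Config`: a callable plus its arguments."""

  def __init__(self, fn_or_cls: Callable[..., Any], **arguments: Any):
    self.fn_or_cls = fn_or_cls
    self.arguments: Dict[str, Any] = dict(arguments)


class ToySelection:
  """Hypothetical selection of every ToyConfig under `root` calling `fn_or_cls`."""

  def __init__(self, root: ToyConfig, fn_or_cls: Callable[..., Any]):
    self.root = root
    self.fn_or_cls = fn_or_cls

  def __iter__(self) -> Iterator[ToyConfig]:
    seen = set()  # yield a shared sub-config only once

    def visit(node: Any) -> Iterator[ToyConfig]:
      if isinstance(node, ToyConfig) and id(node) not in seen:
        seen.add(id(node))
        if node.fn_or_cls is self.fn_or_cls:
          yield node
        for child in node.arguments.values():
          yield from visit(child)

    return visit(self.root)

  def set(self, **kwargs: Any) -> None:
    """Sets each keyword argument on every selected node."""
    for node in list(self):  # materialize first, then mutate
      node.arguments.update(kwargs)

  def get(self, name: str) -> Iterator[Any]:
    """Yields the value of argument `name` from each selected node."""
    for node in self:
      yield node.arguments[name]


def add_range(start, stop, dtype):  # hypothetical callable
  return list(range(start, stop))


def add_two_ranges(first, second):  # hypothetical callable
  return first + second


root = ToyConfig(
    add_two_ranges,
    first=ToyConfig(add_range, start=0, stop=4, dtype="float32"),
    second=ToyConfig(add_range, start=0, stop=4, dtype="float32"),
)
selection = ToySelection(root, add_range)
selection.set(dtype="int32", stop=10)
print(list(selection.get("dtype")))  # ['int32', 'int32']
print(list(selection.get("stop")))  # [10, 10]
```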

### Advanced use notes

The `Selection` object does not hold references to the nodes it selects, so if the configuration is modified in the meantime, the selection
picks up any added or deleted nodes. Think of it as having declarative semantics. To demonstrate:

In [None]:
cfg = fdl.Config(AddTwoRanges)
selection = selectors.select(cfg, AddRange)
print("Current selection:", list(selection))
cfg.add_range_1 = fdl.Config(AddRange, 0, 4, jnp.float32)
print("After adding a node:", list(selection))
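
This behavior resembles a lazily re-evaluated query rather than an eager list. In the toy sketch below (hypothetical `ToyConfig` and `ToySelection` classes, not Fiddle's), the selection stores only the root and the target callable, so every iteration walks the current config, while a `list(...)` taken earlier is a fixed snapshot:

```python
from typing import Any, Callable, Dict, Iterator


class ToyConfig:
  """Hypothetical stand-in for `fdl.Config`: a callable plus its arguments."""

  def __init__(self, fn_or_cls: Callable[..., Any], **arguments: Any):
    self.fn_or_cls = fn_or_cls
    self.arguments: Dict[str, Any] = dict(arguments)


class ToySelection:
  """Hypothetical selection: stores only the root and re-walks it on each use."""

  def __init__(self, root: ToyConfig, fn_or_cls: Callable[..., Any]):
    self.root = root
    self.fn_or_cls = fn_or_cls

  def __iter__(self) -> Iterator[ToyConfig]:
    def visit(node: Any) -> Iterator[ToyConfig]:
      if isinstance(node, ToyConfig):
        if node.fn_or_cls is self.fn_or_cls:
          yield node
        for child in node.arguments.values():
          yield from visit(child)

    return visit(self.root)


def add_range(start, stop):  # hypothetical callable
  return list(range(start, stop))


def add_two_ranges(first=None, second=None):  # hypothetical callable
  return (first or []) + (second or [])


root = ToyConfig(add_two_ranges)
selection = ToySelection(root, add_range)
snapshot = list(selection)  # eager copy of the nodes matched right now

root.arguments["first"] = ToyConfig(add_range, start=0, stop=4)
print(len(snapshot))  # 0: the snapshot does not change
print(len(list(selection)))  # 1: the selection sees the added node

del root.arguments["first"]
print(len(list(selection)))  # 0: deletions are picked up too
```

In this model, calling `list(...)` on a selection is how you would pin down the nodes matched at a particular moment.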

### Suggested usage patterns

For larger configuration modifications, you may find it useful to bind `select` to a root config with `functools.partial`, since this makes modifying many nodes quick and concise:

In [None]:
import functools


class DropoutResidualBlock(nn.Module):
  """Applies `body`, then dropout, and adds the input back as a residual."""

  dropout: nn.Module
  body: nn.Module

  def __call__(self, x):
    residual = x
    x = self.dropout(self.body(x))
    return residual + x


# Base experiment definition, typically defined in some kind of `base_model.py`.
cfg = fdl.Config(DropoutResidualBlock)
cfg.body = fdl.Config(
    AddTwoRanges,
    fdl.Config(AddRange, 0, 4, jnp.float32),
    fdl.Config(AddRange, 0, 4, jnp.float32),
)
cfg.dropout = fdl.Config(nn.Dropout, deterministic=False)

# Experimental modifications, typically in some kind of `my_experiment.py`.
select = functools.partial(selectors.select, cfg)
select(AddRange).set(start=5, stop=9)
select(nn.Dropout).set(rate=0.2)
graphviz.render(cfg)

Running the model with different RNG keys shows that which outputs get zeroed by dropout depends on the key:

In [None]:
import jax.random

model = fdl.build(cfg)
inputs = jnp.array([0, 0, 0, 0], dtype=jnp.float32)
print(model.apply({}, inputs, rngs={"dropout": jax.random.PRNGKey(0)}))
print(model.apply({}, inputs, rngs={"dropout": jax.random.PRNGKey(1)}))
print(model.apply({}, inputs, rngs={"dropout": jax.random.PRNGKey(2)}))

## Placeholder API

When we only need to set specific attributes of a single function or class, `select()` works well. When the value we want to change is shared across multiple functions or classes (e.g. `dtype`), it becomes cumbersome: `select()` operates on one function or class at a time, and different classes may name their `dtype` parameter differently.

Therefore, Fiddle introduces the concept of placeholders: values that are tagged
with a shared key and can be set all at once. Let's make our example a little
more complicated, so that it adds a range and then a constant:

In [None]:
from typing import List


class AddConstant(nn.Module):
  value: Any
  dtype: Any

  def __call__(self, x):
    return x + jnp.array(self.value, dtype=self.dtype)


class Sequential(nn.Module):
  submodules: List[nn.Module]

  def __call__(self, x):
    for module in self.submodules:
      x = module(x)
    return x


cfg = fdl.Config(
    Sequential,
    submodules=[
        fdl.Config(AddRange, 0, 4, jnp.float32),
        fdl.Config(AddConstant, 1, jnp.float32),
    ])
graphviz.render(cfg)

Just to demonstrate the output of this model:

In [None]:
model = fdl.build(cfg)
model.apply({}, jnp.array([1, 2, 1, 2], dtype=jnp.int32))

Now, let's tag the dtypes with a placeholder key:

In [None]:
dtype_key = fdl.PlaceholderKey(name="activation_dtype")


def base_config() -> fdl.Config[Sequential]:
  add_range = fdl.Config(AddRange, 0, 4,
                         fdl.Placeholder(dtype_key, default=jnp.float32))
  add_const = fdl.Config(AddConstant, 1,
                         fdl.Placeholder(dtype_key, default=jnp.float32))
  return fdl.Config(Sequential, submodules=[add_range, add_const])

graphviz.render(base_config())

We can now write an override configuration that changes both dtypes to `int32` with a single `set_placeholder()` call:

In [None]:
def experiment_config() -> fdl.Config[Sequential]:
  cfg = base_config()
  fdl.set_placeholder(cfg, dtype_key, jnp.int32)
  return cfg

graphviz.render(experiment_config())

and the resulting model has `int32` output, as desired:

In [None]:
model: Sequential = fdl.build(experiment_config())
model.apply({}, jnp.array([1, 2, 1, 2], dtype=jnp.int32))
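
One way to picture the mechanism: each placeholder records its key and a current value (its default at first), `set_placeholder()` walks the config and overwrites the value of every placeholder whose key matches, and building substitutes those values. The toy sketch below models that idea under these assumptions; it is not Fiddle's implementation. `ToyPlaceholderKey`, `ToyPlaceholder`, `ToyConfig`, `toy_set_placeholder`, `toy_build`, and the callables are all hypothetical, with strings in place of JAX dtypes. The two callables deliberately name their dtype argument differently (`dtype` vs. `param_dtype`), and a single key still covers both.

```python
from typing import Any, Callable, Dict


class ToyPlaceholderKey:
  """Hypothetical key; keys are compared by identity, names are for display."""

  def __init__(self, name: str):
    self.name = name


class ToyPlaceholder:
  """Hypothetical tagged value: a key plus a current value (the default at first)."""

  def __init__(self, key: ToyPlaceholderKey, default: Any):
    self.key = key
    self.value = default


class ToyConfig:
  """Hypothetical stand-in for `fdl.Config`: a callable plus its arguments."""

  def __init__(self, fn_or_cls: Callable[..., Any], **arguments: Any):
    self.fn_or_cls = fn_or_cls
    self.arguments: Dict[str, Any] = dict(arguments)


def toy_set_placeholder(node: Any, key: ToyPlaceholderKey, value: Any) -> None:
  """Sets every placeholder tagged with `key` anywhere under `node`."""
  if isinstance(node, ToyPlaceholder):
    if node.key is key:
      node.value = value
  elif isinstance(node, ToyConfig):
    for child in node.arguments.values():
      toy_set_placeholder(child, key, value)
  elif isinstance(node, list):
    for child in node:
      toy_set_placeholder(child, key, value)


def toy_build(node: Any) -> Any:
  """Replaces placeholders with their values, then calls each config's callable."""
  if isinstance(node, ToyPlaceholder):
    return node.value
  if isinstance(node, ToyConfig):
    kwargs = {name: toy_build(arg) for name, arg in node.arguments.items()}
    return node.fn_or_cls(**kwargs)
  if isinstance(node, list):
    return [toy_build(child) for child in node]
  return node


# Hypothetical callables; note the differently named dtype parameters.
def make_range(stop, dtype):
  return ("range", stop, dtype)


def make_constant(value, param_dtype):
  return ("constant", value, param_dtype)


def sequential(submodules):
  return tuple(submodules)


dtype_key = ToyPlaceholderKey("activation_dtype")
cfg = ToyConfig(
    sequential,
    submodules=[
        ToyConfig(make_range, stop=4,
                  dtype=ToyPlaceholder(dtype_key, "float32")),
        ToyConfig(make_constant, value=1,
                  param_dtype=ToyPlaceholder(dtype_key, "float32")),
    ])
print(toy_build(cfg))  # (('range', 4, 'float32'), ('constant', 1, 'float32'))
toy_set_placeholder(cfg, dtype_key, "int32")
print(toy_build(cfg))  # (('range', 4, 'int32'), ('constant', 1, 'int32'))
```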

### Advanced use notes

There are a few things to note about the `Placeholder` API.

First, `set_placeholder()` matches placeholders by the identity of their key objects; the key names are essentially debugging information. Here we tag the `AddConstant` config's dtype (`submodules[1]`) with a different key object that happens to have the same name, and show that it is not updated by the `set_placeholder` call. In other words, only `cfg_two_keys.submodules[0].dtype` ends up getting set:

In [None]:
cfg_two_keys = base_config()
other_dtype_key = fdl.PlaceholderKey(name="activation_dtype")
cfg_two_keys.submodules[1].dtype = fdl.Placeholder(other_dtype_key, jnp.int32)
fdl.set_placeholder(cfg_two_keys, dtype_key, jnp.float64)
for submodule_cfg in cfg_two_keys.submodules:
  print(submodule_cfg.dtype)
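
Matching by identity means that two keys that happen to share a name never interfere with each other, for example if separate libraries each declared their own `activation_dtype` key. Below is a minimal toy sketch of that rule, with hypothetical `ToyPlaceholderKey`, `ToyPlaceholder`, and `toy_set_placeholder` names (not Fiddle's classes) and the placeholders kept in a flat list for brevity:

```python
from typing import Any, List


class ToyPlaceholderKey:
  """Hypothetical key; like any class without __eq__, compares by identity."""

  def __init__(self, name: str):
    self.name = name


class ToyPlaceholder:
  """Hypothetical tagged value with a mutable current value."""

  def __init__(self, key: ToyPlaceholderKey, default: Any):
    self.key = key
    self.value = default


def toy_set_placeholder(placeholders: List[ToyPlaceholder],
                        key: ToyPlaceholderKey, value: Any) -> None:
  """Sets placeholders whose key *is* `key`; an equal name is not enough."""
  for placeholder in placeholders:
    if placeholder.key is key:
      placeholder.value = value


dtype_key = ToyPlaceholderKey("activation_dtype")
other_dtype_key = ToyPlaceholderKey("activation_dtype")  # same name, new object
first = ToyPlaceholder(dtype_key, "float32")
second = ToyPlaceholder(other_dtype_key, "int32")

toy_set_placeholder([first, second], dtype_key, "float64")
print(first.value, second.value)  # float64 int32
print(dtype_key == other_dtype_key)  # False
```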

Second, by convention, the `Placeholder` objects themselves are not shared; only `PlaceholderKey`s are. This means you can set the `dtype` on just a sub-network of a model, e.g. only the encoder of an encoder-decoder model. Here we set it on the `AddConstant` submodule (`submodules[1]`) alone:

In [None]:
cfg_one_set = base_config()
fdl.set_placeholder(cfg_one_set.submodules[1], dtype_key, jnp.bfloat16)
for submodule_cfg in cfg_one_set.submodules:
  print(submodule_cfg.dtype)
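
The reason for this convention is easiest to see if `set_placeholder()` is modeled as updating placeholder objects in place. Under that assumption (a toy sketch with hypothetical `Toy*` classes and `toy_set_placeholder`, not Fiddle's code), calling it on a subtree reaches only the placeholders inside that subtree, unless a placeholder object is shared with another subtree, in which case the update leaks:

```python
from typing import Any, Dict


class ToyPlaceholderKey:
  """Hypothetical key, compared by identity."""

  def __init__(self, name: str):
    self.name = name


class ToyPlaceholder:
  """Hypothetical tagged value whose `value` is updated in place."""

  def __init__(self, key: ToyPlaceholderKey, default: Any):
    self.key = key
    self.value = default


class ToyConfig:
  """Hypothetical config reduced to a bag of arguments, for brevity."""

  def __init__(self, **arguments: Any):
    self.arguments: Dict[str, Any] = dict(arguments)


def toy_set_placeholder(node: Any, key: ToyPlaceholderKey, value: Any) -> None:
  """Mutates every placeholder tagged with `key` that is reachable from `node`."""
  if isinstance(node, ToyPlaceholder):
    if node.key is key:
      node.value = value
  elif isinstance(node, ToyConfig):
    for child in node.arguments.values():
      toy_set_placeholder(child, key, value)


key = ToyPlaceholderKey("activation_dtype")

# Convention: a fresh placeholder at each use site; only the key is shared.
encoder = ToyConfig(dtype=ToyPlaceholder(key, "float32"))
decoder = ToyConfig(dtype=ToyPlaceholder(key, "float32"))
toy_set_placeholder(encoder, key, "bfloat16")
print(encoder.arguments["dtype"].value)  # bfloat16
print(decoder.arguments["dtype"].value)  # float32

# Anti-pattern: one placeholder object shared by both sub-configs.
shared = ToyPlaceholder(key, "float32")
encoder2 = ToyConfig(dtype=shared)
decoder2 = ToyConfig(dtype=shared)
toy_set_placeholder(encoder2, key, "bfloat16")
print(decoder2.arguments["dtype"].value)  # bfloat16: the update leaked
```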

Finally, we've added a little logic to Graphviz rendering and printing so that placeholders are visible in your configuration:

In [None]:
print(printing.as_str_flattened(base_config()))

### Suggested usage patterns

We believe the object-identity semantics of `PlaceholderKey`s will encourage good code organization. Projects should generally have a file like `fiddle_placeholders.py` that declares the keys for relevant shared values. Think of each key as naming a collection of attributes with a similar meaning. For example:

```py
import fiddle as fdl

activation_dtype = fdl.PlaceholderKey("activation_dtype")
embed_dim = fdl.PlaceholderKey("embed_dim")
```