# Fiddle `select()` and Tag APIs

*Please see https://github.com/google/fiddle/blob/main/docs/colabs.md for more colabs.*

The `select()` and `Tag`/`TaggedValue` APIs for Fiddle let users concisely
change many values in a larger configuration structure.

The `select()` API makes it easy to set parameters across all occurrences of
specific functions or classes within a config. For example:

```python
# Set all Dropout classes to have rate 0.1.
select(root_cfg, nn.Dropout).set(rate=0.1)
```

Values can also be tagged with one or more tags, making it easy to set values
that are shared in many places all at once. For example:

```python
# Set all tagged dtypes, which may be on different functions/classes.
select(root_cfg, tag=ActivationDType).replace(value=jnp.bfloat16)
```

Both APIs let a configuration be factored into a declaration of the base model
(say, `base_model.py`) and several experiment override files (say,
`my_experiment_1.py`, `my_experiment_2.py`), where the latter set values tagged
with `TaggedValue`s or modify specific functions or classes using `select()`.

In [None]:
!pip install fiddle


import fiddle as fdl
from fiddle import graphviz
from fiddle import printing
from fiddle import tagging
from fiddle import selectors
import fiddle.extensions.jax

fiddle.extensions.jax.enable()  # Nicer printout for JAX types; non-essential.

## Running example

Let's first consider a simple structure of Flax modules. These modules add a
range (e.g. an array `[0, 1, 2, 3]`) to their input.

In [None]:
from typing import Any

from flax import linen as nn
from jax import numpy as jnp


class AddRange(nn.Module):
  start: int
  stop: int
  dtype: Any

  def __call__(self, x):
    return x + jnp.arange(self.start, self.stop, dtype=self.dtype)


class AddTwoRanges(nn.Module):
  add_range_1: AddRange
  add_range_2: AddRange

  def __call__(self, x):
    return self.add_range_2(self.add_range_1(x))


cfg = fdl.Config(AddTwoRanges)
cfg.add_range_1 = fdl.Config(AddRange, 0, 4, jnp.float32)
cfg.add_range_2 = fdl.Config(AddRange, 0, 4, jnp.float32)
graphviz.render(cfg)

This model can be run as follows (see
[this colab](https://colab.sandbox.google.com/github/google/flax/blob/master/docs/notebooks/linen_intro.ipynb)
for an introduction to Flax APIs),

In [None]:
model = fdl.build(cfg)
model.apply({}, jnp.array([1, 2, 1, 2]))

## `select()` API

To make widespread modifications to a configuration easier, Fiddle provides a
simple tool that selects nodes across the configuration DAG and then sets new
values on them.

The main entry point is `select()`. It currently takes a root config and a
function or class to select, and returns a `Selection` object,

In [None]:
selectors.select(cfg, AddRange)

This `Selection` object supports iteration over selected nodes,

In [None]:
list(selectors.select(cfg, AddRange))

Let's say we wanted an integer version of our model. Because the `arange` calls
have a dtype (hyper)parameter, just sending integer inputs doesn't work (JAX
promotes the integer operand of the addition to `float32`),

In [None]:
# Notice that the output dtype is float32.
model.apply({}, jnp.array([1, 2, 1, 2], dtype=jnp.int32))

We can use the `select()` API to set both `AddRange` dtypes to `int32`,

In [None]:
selectors.select(cfg, AddRange).set(dtype=jnp.int32)
graphviz.render(cfg)

and this will correctly have an integer output,

In [None]:
model = fdl.build(cfg)
model.apply({}, jnp.array([1, 2, 1, 2], dtype=jnp.int32))

`set()` also allows setting multiple values,

In [None]:
selectors.select(cfg, AddRange).set(start=1, stop=10)
graphviz.render(cfg)

An API to get all values for a particular field is also provided; this may be
useful for unit testing or debugging,

In [None]:
list(selectors.select(cfg, AddRange).get('dtype'))

### Advanced use notes

The `Selection` object does not hold references to the nodes it selects.
Instead, it re-traverses the configuration each time it is used, so if the
configuration is modified in the meantime, the selection reflects any added or
deleted nodes. Think of it as having declarative semantics. To demonstrate,

In [None]:
cfg = fdl.Config(AddTwoRanges)
selection = selectors.select(cfg, AddRange)
print("Current selection:", list(selection))
cfg.add_range_1 = fdl.Config(AddRange, 0, 4, jnp.float32)
print("After adding a node:", list(selection))
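
The same behavior can be modeled in plain Python. The following is a minimal sketch, not Fiddle's implementation: `Node` and `LazySelection` are hypothetical names invented for illustration. It assumes only what the text above states, namely that a selection stores the root and the query, and looks for matches each time it is iterated.

```python
from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass
class Node:
  """Hypothetical stand-in for a config node; not a Fiddle class."""
  kind: str
  children: List["Node"] = field(default_factory=list)


@dataclass(frozen=True)
class LazySelection:
  """Stores only the root and the query; matches are found on iteration."""
  root: Node
  kind: str

  def __iter__(self) -> Iterator[Node]:
    # Walk the tree from the root on every iteration.
    stack = [self.root]
    while stack:
      node = stack.pop()
      if node.kind == self.kind:
        yield node
      stack.extend(node.children)


root = Node("AddTwoRanges")
selection = LazySelection(root, "AddRange")
before = len(list(selection))  # No AddRange nodes exist yet.
root.children.append(Node("AddRange"))
after = len(list(selection))  # The same selection object now finds one.
print(before, after)
```

Because nothing is cached, a selection created early in an experiment file stays valid after later edits to the config, at the cost of one traversal per use.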

### Suggested usage patterns

Finally, for larger configuration modifications, users might find the coding
pattern of binding `select` to a root config using `functools.partial` useful,
since it allows modifying multiple nodes quickly.

In [None]:
import functools


class DropoutResidualBlock(nn.Module):
  """Module that runs dropout after applying a body computation."""

  dropout: nn.Module
  body: nn.Module

  def __call__(self, x):
    residual = x
    x = self.dropout(self.body(x))
    return residual + x


# Base experiment definition, typically defined in some kind of `base_model.py`.
cfg = fdl.Config(DropoutResidualBlock)
cfg.body = fdl.Config(
    AddTwoRanges,
    fdl.Config(AddRange, 0, 4, jnp.float32),
    fdl.Config(AddRange, 0, 4, jnp.float32),
)
cfg.dropout = fdl.Config(nn.Dropout, deterministic=False)

# Experimental modifications, typically in some kind of `my_experiment.py`.
select = functools.partial(selectors.select, cfg)
select(AddRange).set(start=5, stop=9)
select(nn.Dropout).set(rate=0.2)
graphviz.render(cfg)

Running the model shows that dropout zeroes some outputs, depending on the RNG key,

In [None]:
import jax.random

model = fdl.build(cfg)
inputs = jnp.array([0, 0, 0, 0], dtype=jnp.float32)
print(model.apply({}, inputs, rngs={"dropout": jax.random.PRNGKey(0)}))
print(model.apply({}, inputs, rngs={"dropout": jax.random.PRNGKey(1)}))
print(model.apply({}, inputs, rngs={"dropout": jax.random.PRNGKey(2)}))

## TaggedValue API

For simple cases where we only need to set some specific attributes of a single
function/class, `select()` with a function/class works well. When the attribute
we want to modify affects multiple functions/classes, e.g. `dtype`, this can
become cumbersome, because `select()` operates by class, and some classes could
name their `dtype` parameter differently.

Therefore, Fiddle introduces the concept of `TaggedValue`: values that are
tagged with one or more tags, and can be set all at once. Let's make our example
a little more complicated, adding a constant and a range,

In [None]:
from typing import List


class AddConstant(nn.Module):
  value: Any
  dtype: Any

  def __call__(self, x):
    return x + jnp.array(self.value, dtype=self.dtype)


class Sequential(nn.Module):
  submodules: List[nn.Module]

  def __call__(self, x):
    for module in self.submodules:
      x = module(x)
    return x


cfg = fdl.Config(
    Sequential,
    submodules=[
        fdl.Config(AddRange, 0, 4, jnp.float32),
        fdl.Config(AddConstant, 1, jnp.float32),
    ])
graphviz.render(cfg)

Just to demonstrate the output of this model:

In [None]:
model = fdl.build(cfg)
model.apply({}, jnp.array([1, 2, 1, 2], dtype=jnp.int32))

### Tagging values in a configuration

Now, let's tag the dtypes with a tag,

In [None]:
class ActivationDType(fdl.Tag):
  """The requested data-type for module outputs."""


def base_config() -> fdl.Config[Sequential]:
  add_range = fdl.Config(AddRange, 0, 4,
                         ActivationDType.new(default=jnp.float32))
  add_const = fdl.Config(AddConstant, 1,
                         ActivationDType.new(default=jnp.float32))
  return fdl.Config(Sequential, submodules=[add_range, add_const])


graphviz.render(base_config())

We can now express an override configuration, which changes both dtypes to
`int32`,

In [None]:
def experiment_config() -> fdl.Config[Sequential]:
  cfg = base_config()
  selectors.select(cfg, tag=ActivationDType).replace(value=jnp.int32)
  return cfg


graphviz.render(experiment_config())

and this model has `int32` output, as desired,

In [None]:
model: Sequential = fdl.build(experiment_config())
model.apply({}, jnp.array([1, 2, 1, 2], dtype=jnp.int32))

### Basic API reference

In more detail, the `TagSubclass.new()` syntax creates a `TaggedValue`, with an
optional default value.

In [None]:
ActivationDType.new(default=jnp.float32)

This is equivalent to explicitly constructing a `TaggedValue`,

In [None]:
fdl.TaggedValue(tags={ActivationDType}, default=jnp.float32)

We'll see later how to use `TaggedValue`s that have a set of tags.

If you build a configuration with `TaggedValue`s that do not have a default, you
will get an error,

In [None]:
try:
  fdl.build(ActivationDType.new())
except Exception as e:
  name = e.__class__.__name__
  %html <span style="color:red">{name}: {e}</span>
else:
  raise AssertionError("Expected an exception to be thrown!")

### Listing all tags in a configuration

When configurations get very large, it can be very useful to list all available
tags. Tags sometimes serve as a high-level "API" to large configurations.

In [None]:
for tag in tagging.list_tags(experiment_config()):
  print(tag.name, "-", tag.description)

### Tag subclassing and sets of tags

Since tags are types, we have a natural way of specifying a hierarchy of tags,
through subclassing. We could use these to make finer-grained tags, separating
the activation dtypes of intermediate layers and final layers.

In [None]:
class IntermediateLayerActivationDtype(ActivationDType):
  """DType for intermediate layer neural network computations."""


class FinalLayerActivationDtype(ActivationDType):
  """DType for final layer neural network computations."""


def fine_tag_types_config() -> fdl.Config[Sequential]:
  add_range = fdl.Config(
      AddRange, 0, 4, IntermediateLayerActivationDtype.new(default=jnp.float32))
  add_const = fdl.Config(AddConstant, 1,
                         FinalLayerActivationDtype.new(default=jnp.float32))
  return fdl.Config(Sequential, submodules=[add_range, add_const])


graphviz.render(fine_tag_types_config())

We can now just set the intermediate layers to have a lower precision,

In [None]:
cfg = fine_tag_types_config()
selectors.select(
    cfg, tag=IntermediateLayerActivationDtype).replace(value=jnp.bfloat16)
graphviz.render(cfg)

or we can achieve the same result by setting all dtypes to bfloat16, and then
setting the final ones to float32,

In [None]:
cfg = fine_tag_types_config()
selectors.select(cfg, tag=ActivationDType).replace(value=jnp.bfloat16)
selectors.select(cfg, tag=FinalLayerActivationDtype).replace(value=jnp.float32)
graphviz.render(cfg)

Finally, you can create tagged values that manually specify a set of tags. The
`select()` API will select any `TaggedValue`s that contain the specified tag.

In [None]:
class MyTagA(fdl.Tag):
  """An example tag."""


class MyTagB(fdl.Tag):
  """Another example tag."""


def foo(a, b, ab):
  return {"a": a, "b": b, "ab": ab}


cfg = fdl.Config(
    foo,
    a=MyTagA.new(
        default=1),  # Equivalent to fdl.TaggedValue({MyTagA}, default=1).
    b=MyTagB.new(default=2),
    ab=fdl.TaggedValue({MyTagA, MyTagB}, default=3))
selectors.select(cfg, tag=MyTagB).replace(value=4)
print(fdl.build(cfg))
selectors.select(cfg, tag=MyTagA).replace(value=7)
print(fdl.build(cfg))
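
To make the matching rule concrete, here is a toy model in plain Python. It is a sketch under one assumption, the rule described in this section: a tagged value is selected when any of its tags is the requested tag or a subclass of it. The `Tag`, `Tagged`, and `replace_tagged` names are hypothetical stand-ins, not Fiddle APIs.

```python
class Tag:
  """Hypothetical base class for tags; stands in for `fdl.Tag`."""


class MyTagA(Tag):
  """An example tag."""


class MyTagB(Tag):
  """Another example tag."""


class Tagged:
  """Toy tagged value: a set of tag classes plus a current value."""

  def __init__(self, tags, value):
    self.tags = frozenset(tags)
    self.value = value


def replace_tagged(fields, tag, value):
  """Sets `value` on every field carrying `tag` or a subclass of it."""
  for tagged in fields.values():
    if any(issubclass(t, tag) for t in tagged.tags):
      tagged.value = value


fields = {
    "a": Tagged({MyTagA}, 1),
    "b": Tagged({MyTagB}, 2),
    "ab": Tagged({MyTagA, MyTagB}, 3),
}
replace_tagged(fields, MyTagB, 4)
after_b = {name: t.value for name, t in fields.items()}
replace_tagged(fields, MyTagA, 7)
after_a = {name: t.value for name, t in fields.items()}
print(after_b)  # {'a': 1, 'b': 4, 'ab': 4}
print(after_a)  # {'a': 7, 'b': 4, 'ab': 7}
```

The `ab` field carries both tags, so it is updated by both calls and ends up with the value from the later one.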

### Advanced use notes

There are a few things to note about the `TaggedValue` API.

First, `select(cfg, tag=<tag>)` matches tags by their class hierarchy. If you
redefine a Tag class (e.g. by re-running its cell in Colab) and `cfg` still
refers to the old class while `<tag>` is the new one, the selection matches
nothing and no tagged values are updated.
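
The cause is ordinary Python behavior rather than anything specific to Fiddle: executing a `class` statement again creates a new class object, even if the name and body are identical. The sketch below simulates re-running a cell with a hypothetical `define_tag()` helper, and uses a plain class in place of an `fdl.Tag` subclass.

```python
def define_tag():
  """Simulates running a notebook cell that defines a tag class."""

  class ActivationDType:
    """Hypothetical tag; a plain class stands in for an `fdl.Tag` subclass."""

  return ActivationDType


first = define_tag()   # The class captured in an existing config.
second = define_tag()  # The class after re-running the cell.

same_name = first.__name__ == second.__name__
same_class = first is second
matches = issubclass(first, second)
print(same_name, same_class, matches)  # True False False
```

Rebuilding the config after re-running the tag cell makes both sides refer to the same class again.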

Secondly, by convention, `TaggedValue` objects are not shared; only `Tag`s are.
This means that if you want to set the `dtype` on a sub-network of a model, e.g.
just the encoder of an encoder-decoder model, you can do so. Here we demonstrate
setting it on just the second submodule (the `AddConstant` config),

In [None]:
cfg_one_set = base_config()
selectors.select(
    cfg_one_set.submodules[1], tag=ActivationDType).replace(value=jnp.bfloat16)
for submodule_cfg in cfg_one_set.submodules:
  print(submodule_cfg.dtype)

Finally, we've added a little logic to the Graphviz rendering and printing, so
you can see tags in your configuration,

In [None]:
print(printing.as_str_flattened(base_config()))

### Suggested usage patterns

Projects should generally have a file like `fiddle_tags.py`, which declares tags
for relevant values. Think of them as declaring collections of attributes with a
similar meaning. This enables their reuse throughout the project (and in
dependent projects too). Documentation on the tags is *required*; please help
others (including future-you!) by writing a good docstring.

```python
class ActivationDType(fdl.Tag):
  """Outputs of a module/layer should have this dtype."""

class EmbeddingDimension(fdl.Tag):
  """The size of the embedding dimension."""
```