# Pax Workshop
## Pax Layer Basics

Goal: This lab describes the basics of authoring a new Pax layer.

*   Pax is built on top of Flax `nn.Module`.
*   Familiarity with the basics of Flax will help users understand the Pax layer
    API. See https://flax.readthedocs.io/en/latest/overview.html for Flax basics.
*   Pax has some of its roots in Lingvo. We highlight API differences with Lingvo to help Lingvo users become familiar with the new API.

```python
from pprint import pprint
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np
from praxis import base_layer
from praxis import pax_fiddle
from praxis import py_utils
from praxis import pytypes
from praxis.layers import activations

# Introduce some common aliases.
NestedMap = py_utils.NestedMap
WeightInit = base_layer.WeightInit
WeightHParams = base_layer.WeightHParams
LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
template_field = base_layer.template_field
instance_field = base_layer.instance_field
instantiate = base_layer.instantiate

PARAMS = base_layer.PARAMS
RANDOM = base_layer.RANDOM
```

## Layer definition

A 'Pax layer' represents an arbitrary function, possibly with trainable parameters.

Layers are the essential building blocks of models. They inherit from the Flax nn.Module. A layer can contain other layers as children.

A *Pax layer* always inherits from `base_layer.BaseLayer`, which itself inherits from Flax `nn.Module`. All non-trivial Pax layers have one or more *fields* and a `__call__` method. A `setup` method can also be defined to initialize variables and create child layers from templates. Here is a quick preview of what a layer looks like:
```
class Linear(base_layer.BaseLayer):
  # Hyperparameters:
  input_dims: int = 0
  ...

  def setup(self):
    self.create_variable('w', WeightHParams(shape=[self.input_dims, ...

  def __call__(self, inputs):
    return jnp.matmul(inputs, self.theta.w)
```

Pax layer fields can be divided into three groups:

* *Hyperparameters*
* *Child layers*
* *Layer templates*

### PAX Layer Fields: hyperparameters

The ***hyperparameters*** for a layer are declared using dataclass field syntax:

```
  <name>: <type> = <default_value>
```

All hyperparameters are currently required to have default values.  Hyperparameter values are frozen when a layer is instantiated.
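Flax modules are Python dataclasses, so a plain frozen dataclass is a reasonable mental model for how hyperparameters behave. The sketch below uses only the standard library. `ToyLinearHParams` is a hypothetical stand-in and not a Pax class. Real Pax layers also carry variables and child layers.

```python
import dataclasses


# Plain-Python analogy, not Pax itself: hyperparameters are declared with
# dataclass field syntax, every field has a default, and the values are
# frozen once the object has been constructed.
@dataclasses.dataclass(frozen=True)
class ToyLinearHParams:
  input_dims: int = 0
  output_dims: int = 0


hp = ToyLinearHParams(input_dims=4, output_dims=8)
print(hp)  # ToyLinearHParams(input_dims=4, output_dims=8)

try:
  hp.input_dims = 16  # Rejected: the instance is frozen.
  assignment_rejected = False
except dataclasses.FrozenInstanceError:
  assignment_rejected = True
print(assignment_rejected)  # True
```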

### PAX Layer Fields: child layers

***Child layers*** can be declared using the following syntax: 

```
<name>: <type> = instance_field(<factory>)
```

Where `<factory>` is typically a Layer class name (e.g., `Bias`), but can also be a factory function that returns a Layer.

Alternatively, child layers can be constructed from layer templates using `self.create_child`, as described below.  Child layers that are constructed from templates should *not* be declared using dataclass field syntax.
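Declaring a child with a factory means each parent instance builds its own child. The standard-library analogy below assumes that `instance_field(<factory>)` behaves much like a dataclass field with `default_factory=<factory>`. `ToyBias` and `ToyParent` are hypothetical names, not Pax classes.

```python
import dataclasses


@dataclasses.dataclass
class ToyBias:
  dims: int = 0


@dataclasses.dataclass
class ToyParent:
  # Assumed to be roughly what `bias: Bias = instance_field(Bias)` expresses
  # in Pax: the factory runs once per parent instance.
  bias: ToyBias = dataclasses.field(default_factory=ToyBias)


a, b = ToyParent(), ToyParent()
print(a.bias is b.bias)  # False: each parent got its own ToyBias.
```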

### PAX Layer Fields: layer templates

***Layer templates*** are declared using the following syntax:

```
<name>: pax_fiddle.Config[<type>] = template_field(<factory>)
```

Where `<factory>` is typically a Layer class name (e.g., `Bias`), but can also be a factory function that returns a Layer.  Layer templates are used to create child layers, by calling `self.create_child` from the `setup` method.

### setup

`setup` declares the layer variables/weights and its sublayers using

- `self.create_variable`
- `self.create_child`

### setup: creating variables

For example, the following declares a trainable weight `w`. Note that Pax requires users to statically annotate the shape and dtype of the weight, in addition to the usual initializer method.
```
self.create_variable(
    'w',
    WeightHParams(
        shape=[self.input_dims, self.output_dims],
        init=self.params_init,
        dtype=self.dtype))
```
Trainable weights can be accessed as `self.theta.w`.

The following declares a non-trainable weight `moving_mean`. `REQUIRES_MEAN_SYNC` tells the training loop to sync the mean of this variable after each train step; you can ignore it for now.
```
mva = WeightHParams(
    shape=[self.output_dims],
    init=WeightInit.Constant(0.0),
    dtype=self.dtype,
    collections=[base_layer.WeightHParamsCollection.REQUIRES_MEAN_SYNC])
self.create_variable(
  'moving_mean',
  mva,
  trainable=False)
```

Non-trainable weights can be accessed via `self.get_var('moving_mean')`.

### setup: creating child layers from templates

`create_child(<name>, <layer_tpl>)` creates a child layer named `self.<name>` based on the layer template `<layer_tpl>`.  Here's an example of how to create a sublayer `self.linear` from a layer template `linear_tpl`:

```
def setup(self):
  linear_tpl = self.linear_tpl.clone()
  linear_tpl.set(
      input_dims=self.input_dims,
      output_dims=self.output_dims)
  self.create_child('linear', linear_tpl)
```
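The clone-then-set pattern produces a customized copy for the child and leaves the shared template field untouched. The standard-library sketch below imitates the pattern with a hypothetical `MiniConfig` class. It is not `pax_fiddle.Config`, which does considerably more. The method `MiniConfig.instantiate` stands in for what `create_child` does with a template.

```python
import copy
import dataclasses


class MiniConfig:
  """Hypothetical template: records a class and kwargs, builds on request."""

  def __init__(self, cls, **kwargs):
    self.cls = cls
    self.kwargs = dict(kwargs)

  def clone(self):
    # Deep copy so that later edits never leak back into the original.
    return copy.deepcopy(self)

  def set(self, **kwargs):
    self.kwargs.update(kwargs)
    return self  # Returning self allows chaining, as in clone().set(...).

  def instantiate(self):
    return self.cls(**self.kwargs)


@dataclasses.dataclass(frozen=True)
class ToyLinear:
  input_dims: int = 0
  output_dims: int = 0


shared_tpl = MiniConfig(ToyLinear)
# Clone before editing, mirroring `self.linear_tpl.clone()` in setup.
linear_tpl = shared_tpl.clone().set(input_dims=4, output_dims=8)
linear = linear_tpl.instantiate()
print(linear)  # ToyLinear(input_dims=4, output_dims=8)
print(shared_tpl.kwargs)  # {}
```

The shared template stays empty because only the clone was edited. This is presumably why the Pax example clones `self.linear_tpl` rather than setting fields on it directly.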

### \_\_call\_\_

`__call__` defines the forward-pass computation:
- `self.theta.w` refers to trainable weight `w`
- `self.get_var('moving_mean')` refers to non-trainable weight `moving_mean`
- Trainable weights are immutable in `__call__` while non-trainable weights can be updated with `self.update_var('moving_mean', new_moving_mean)`.
- Sublayer `__call__` can be expressed as
  `projected_inputs = self.linear(inputs)`
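One way to think about these rules is that Flax, and therefore Pax, runs `__call__` as a pure function over variable collections. The plain-Python sketch below reads the trainable weight without changing it. It then returns an updated copy of the non-trainable state instead of mutating it in place. This is roughly what `self.update_var` amounts to once `layer.apply(..., mutable=True)` hands the new values back. The function `toy_feed_forward` and the scalar stand-ins for arrays are hypothetical.

```python
def toy_feed_forward(variables, inputs):
  """Hypothetical pure-function view of a layer's forward pass."""
  w = variables['params']['w']  # Trainable: read only.
  old_mean = variables['non_trainable']['moving_mean']
  outputs = [x * w for x in inputs]
  # Counterpart of self.update_var('moving_mean', old_mean + 1.0): build new
  # state rather than mutating the input dicts.
  new_variables = {
      'params': variables['params'],  # Unchanged.
      'non_trainable': {'moving_mean': old_mean + 1.0},
  }
  return outputs, new_variables


variables = {'params': {'w': 2.0}, 'non_trainable': {'moving_mean': 0.0}}
outputs, new_variables = toy_feed_forward(variables, [1.0, 3.0])
print(outputs)  # [2.0, 6.0]
print(new_variables['non_trainable'])  # {'moving_mean': 1.0}
print(variables['non_trainable'])  # {'moving_mean': 0.0}, the input is intact
```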

### randomness

Users often want randomness in `__call__`. Every `jax.random` sampling function (for example `jax.random.uniform`) takes a PRNG key, such as `jax.random.PRNGKey(some_int)`, as its first argument.

**Important:** The same key always maps to the same random bits. To a first approximation, users can think of a JAX random function as a deterministic hash from a key to some random bits. Because of this, users must use a new random key for every invocation.

To avoid key reuse, JAX users typically split the key with `key, subkey = jax.random.split(key)` and use `subkey`. For convenience, Pax `BaseLayer` handles the key splitting internally and provides `self.next_prng_key()`, which returns a new key on every invocation.

See https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html?highlight=random# for more on JAX randomness.

An example of `self.next_prng_key` in action:

```

def __call__(self, inputs):
  prng_key = self.next_prng_key()
  random_tensor = jax.random.uniform(
      prng_key, inputs.shape, dtype=inputs.dtype)
  inputs_with_noise = inputs + random_tensor
  ...
```
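The sketch below illustrates the two ideas above using only the standard library; it is not JAX code. `toy_uniform` plays the role of a JAX random function, returning the same numbers for the same key. `ToyKeyStream.next_prng_key` is a hypothetical helper that derives a fresh key on every call by hashing a base key together with a counter. In real code, use `jax.random.split` (or `jax.random.fold_in`), and inside a Pax layer use `self.next_prng_key()`. This sketch does not show how Pax derives its keys internally.

```python
import hashlib
import random


def toy_uniform(key, n):
  """Toy random function: a deterministic map from a key to n floats."""
  rng = random.Random(key)
  return [rng.random() for _ in range(n)]


def toy_derive_key(key, counter):
  """Toy key derivation: hash the parent key together with a counter."""
  digest = hashlib.sha256(f'{key}:{counter}'.encode()).digest()
  return int.from_bytes(digest[:8], 'little')


class ToyKeyStream:
  """Hypothetical helper mimicking the idea behind self.next_prng_key()."""

  def __init__(self, base_key):
    self.base_key = base_key
    self.count = 0

  def next_prng_key(self):
    self.count += 1
    return toy_derive_key(self.base_key, self.count)


# Reusing a key reproduces exactly the same "random" numbers.
print(toy_uniform(123, 3) == toy_uniform(123, 3))  # True

stream = ToyKeyStream(base_key=123)
k1 = stream.next_prng_key()
k2 = stream.next_prng_key()
print(k1 != k2)  # True: every call hands out a fresh key.
```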

### summaries

Users may want to report summaries to be shown in TensorBoard. Inside layer `__call__`, users can do so with `self.add_summary`. Example:

```
self.add_summary(
  'inputs_mean', jnp.mean(inputs))
```

### Note to Lingvo users

Compared with Lingvo:

  - `__call__` (the counterpart of `fprop`) no longer takes an explicit `theta` argument; weights are read via `self.theta`
  - Same APIs for creating variables and sublayers:
    - `self.create_variable`
    - `self.create_child`
  - Slightly different ways to add summaries and aux losses:
    - `self.add_summary`
    - `self.add_aux_loss`

Let's put all of this into action in a few layer definitions. The `__call__` logic below is for illustration purposes only.

```python
class Linear(base_layer.BaseLayer):
  """A simple linear layer.

  Attributes:
    input_dims: Depth of the input.
    output_dims: Depth of the output.
  """
  input_dims: int = 0
  output_dims: int = 0

  def setup(self):
    # create_variable creates a trainable variable, similar to Flax's self.param.
    self.create_variable(
        'w',
        WeightHParams(
            shape=[self.input_dims, self.output_dims],
            init=self.params_init,
            dtype=self.dtype))

  def __call__(self, inputs):
    # Use self.theta.
    return jnp.matmul(inputs, self.theta.w)


class Bias(base_layer.BaseLayer):
  """A simple bias layer.

  Attributes:
    dims: Depth of the input.
  """
  dims: int = 0

  def setup(self):
    self.create_variable(
        'b',
        WeightHParams(
            shape=[self.dims],
            init=WeightInit.Constant(0.0),
            dtype=self.dtype))

  def __call__(self, inputs):
    return inputs + self.theta.b


class FeedForward(base_layer.BaseLayer):
  """A basic feed-forward layer.

  Attributes:
    input_dims: Depth of the input.
    output_dims: Depth of the output.
    has_bias: Whether to add bias weights.
    linear_tpl: Linear layer template.
    activation_tpl: Activation layer template.
  """
  input_dims: int = 0
  output_dims: int = 0
  has_bias: bool = True
  linear_tpl: LayerTpl = template_field(Linear)
  activation_tpl: pax_fiddle.Config[activations.BaseActivation] = template_field(
      activations.ReLU)

  def setup(self):
    linear_tpl = self.linear_tpl.clone()
    linear_tpl.input_dims = self.input_dims
    linear_tpl.output_dims = self.output_dims
    # Provide type hint.
    self.linear: Linear
    self.create_child('linear', linear_tpl)

    if self.has_bias:
      bias_layer_tpl = pax_fiddle.Config(Bias, dims=self.output_dims)
      # Provide type hint.
      self.bias: Bias
      self.create_child('bias', bias_layer_tpl)

    # Provide type hint.
    self.activation: activations.BaseActivation
    self.create_child('activation', self.activation_tpl)

    # Demonstrate how to add a non-trainable var, e.g. batch stats:
    # set trainable=False.
    mva = WeightHParams(
        shape=[self.output_dims],
        init=WeightInit.Constant(0.0),
        dtype=self.dtype,
        collections=[base_layer.WeightHParamsCollection.REQUIRES_MEAN_SYNC])
    self.create_variable(
        'moving_mean',
        mva,
        trainable=False)

  # For illustration purposes, we demo how to add a summary, get randomness,
  # add an aux loss and update non-trainable vars. You wouldn't normally do
  # this in a FeedForward layer.
  def __call__(self, inputs):
    # Add a summary to the `summaries` variable collection.
    self.add_summary(
        'inputs_mean', jnp.mean(inputs))

    # Demonstrate how to get randomness with self.next_prng_key.
    prng_key = self.next_prng_key()
    random_tensor = jax.random.uniform(
        prng_key, inputs.shape, dtype=inputs.dtype)
    inputs_with_noise = inputs + random_tensor

    # No longer need to provide self.theta.linear.
    projected_inputs = self.linear(inputs_with_noise)
    if self.has_bias:
      projected_inputs = self.bias(projected_inputs)
    output = self.activation(projected_inputs)

    # Add a dummy aux_loss.
    self.add_aux_loss('dummy_aux_loss', jnp.mean(output))

    # Read and update the non-trainable var.
    old_v = self.get_var('moving_mean')
    new_v = old_v + 1.0
    self.update_var('moving_mean', new_v)

    return output
```

## Instantiate a layer and introspect how variable collections are tracked

Colab is a convenient place to experiment interactively with the example layers and build a mental model of Pax and its relationship to Flax nn.Module. The following sections are best read after gaining some familiarity with the Flax nn.Module API, so that the similarities and differences with Pax stand out. See https://flax.readthedocs.io/en/latest/index.html.

- A `pax_fiddle.Config[Layer]` object `p` can be instantiated via `instantiate(p)` to get an instance of `Layer` itself.
- As with the Flax nn.Module API, Pax layers are initialized with `layer.init` and run with `layer.apply`, which invokes `layer.__call__`.

```python
# Create a new layer using Fiddle.
ffn_cfg = pax_fiddle.Config(
    FeedForward,
    name='ffn', input_dims=1, output_dims=2)
ffn: FeedForward = instantiate(ffn_cfg)
```

### initial_vars

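`ffn.init` takes a dict of RNG keys, one per RNG stream: `PARAMS` for variable initialization and `RANDOM` for the `self.next_prng_key()` calls made while `__call__` runs. These keys are followed by sample inputs. The returned `initial_vars` is a nested dict of the variable collections that were initialized:

- The `'params'` collection includes the trainable variables
- The `'non_trainable'` collection includes the non-trainable ones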
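A small recursive helper can list every leaf of `initial_vars` together with its path and shape. The hypothetical `describe_vars` below uses only the standard library. It assumes that the collections behave like `Mapping`s, which holds for plain dicts and, to our knowledge, for Flax's `FrozenDict`. It also assumes that leaves expose a `.shape` attribute, as JAX arrays do. The demo uses invented stand-in values, because the real layout of `initial_vars` is best checked with the `pprint` call in the next cell.

```python
from collections.abc import Mapping


def describe_vars(tree, prefix=''):
  """Returns one 'path: shape' line for every leaf of a nested mapping."""
  lines = []
  for name, value in tree.items():
    path = f'{prefix}/{name}' if prefix else name
    if isinstance(value, Mapping):
      lines.extend(describe_vars(value, path))
    else:
      # Arrays report their shape; leaves without .shape show ().
      lines.append(f'{path}: {getattr(value, "shape", ())}')
  return lines


class FakeArray:
  """Invented stand-in for an array; only carries a shape."""

  def __init__(self, shape):
    self.shape = shape


fake_vars = {
    'params': {'linear': {'w': FakeArray((1, 2))}, 'bias': {'b': FakeArray((2,))}},
    'non_trainable': {'moving_mean': FakeArray((2,))},
}
for line in describe_vars(fake_vars):
  print(line)
```

With the real variables, run `print('\n'.join(describe_vars(initial_vars)))` after the next cell.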

```python
npy_inputs = np.random.normal(1.0, 0.5, [2, 2, ffn_cfg.input_dims]).astype(np.float32)
inputs = jnp.asarray(npy_inputs)

prng_key = jax.random.PRNGKey(seed=123)
prng_key, init_key, random_key, noise_key = jax.random.split(prng_key, 4)
# Similar to Flax, use layer.init to initialize the layer and return initial_vars.
# Note: Pax doesn't support shape inference; variable shapes come from the
# hyperparameters (here input_dims and output_dims).
initial_vars = ffn.init({PARAMS: init_key, RANDOM: random_key}, inputs)
pprint(initial_vars)
```

### layer.apply

`layer.apply` returns the outputs without mutating any variable collections by default.

- Provide `rngs={RANDOM: noise_key}` to pass in the RNG key stream named `RANDOM`, which `self.next_prng_key()` draws from

```python
outputs = ffn.apply(initial_vars, inputs, method=ffn.__call__, rngs={RANDOM: noise_key})
pprint(outputs)
```

### layer.apply(..., mutable=True)
`layer.apply(..., mutable=True)` returns a tuple of the outputs and the updated variable collections.
- Set `mutable=True` to get `updated_vars`
- Note that `summaries` and `aux_loss` are passed around just like `params` and `non_trainable`, as separate Flax variable collections
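You can diff the top-level keys to see which collections `apply(..., mutable=True)` returns beyond those produced by `init`. The helper `collection_diff` below is a hypothetical standard-library sketch. The demo runs on small made-up dicts rather than on real Pax output.

```python
def collection_diff(before, after):
  """Hypothetical helper: top-level collections added or dropped."""
  added = sorted(set(after) - set(before))
  dropped = sorted(set(before) - set(after))
  return added, dropped


# Made-up stand-ins for initial_vars and updated_vars (not real Pax output).
before = {'params': {}, 'non_trainable': {}}
after = {'params': {}, 'non_trainable': {}, 'summaries': {}, 'aux_loss': {}}
print(collection_diff(before, after))  # (['aux_loss', 'summaries'], [])
```

With the real variables, call `collection_diff(initial_vars, updated_vars)` after the next cell.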

```python
# Setting mutable=True returns both the outputs and the updated variable collections.
outputs, updated_vars = ffn.apply(initial_vars, inputs, method=ffn.__call__, rngs={RANDOM: noise_key}, mutable=True)
pprint(outputs)
pprint(updated_vars)
```