In [None]:
#@title LICENSE
# Licensed under the Apache License, Version 2.0

# JaxPruner Deep Dive
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/jaxpruner/blob/main/colabs/deep_dive.ipynb)

In this interactive colab we take a deep dive into the internals of the `jaxpruner` library and demonstrate how to implement a new algorithm. We
start by going over our base class, `BaseUpdater`, and then subclass it to implement the new algorithm.

In [None]:

import dataclasses
import chex
import jax
import jax.numpy as jnp
import optax
import pprint

In [None]:
!pip3 install git+https://github.com/google-research/jaxpruner
import jaxpruner

## The BaseUpdater
We define most of the common functionality used in pruning and sparse-training algorithms in the `BaseUpdater` class. All algorithms are expected to inherit from this class and implement at least the `calculate_scores` function. `BaseUpdater` provides two main entry points for the user:

- `instant_sparsify`: Used for instant (one-shot) pruning.
- `wrap_optax`: Used to wrap an existing optimizer with pruning related operations. 

Using `instant_sparsify` is straightforward and is demonstrated in our [Quick-Start colab](TODO). In this colab we focus on training/optimization-based algorithms, which use the `wrap_optax` functionality.

Optax optimizers are [gradient transformations](https://optax.readthedocs.io/en/latest/api.html#optax.GradientTransformation) that update provided state variables to control how gradients are used. The resulting updates are typically added to the parameters in a separate call outside the optimizer. This means every pruning algorithm can access the gradients directly and update them together with its own state (masks, counter, etc.).

Below we share a simplified version of the `wrap_optax` function. `self.init_state` is called during initialization, and `self.update_state` is called whenever `self.scheduler` indicates a mask update iteration. The state of the original optax transformation is stored under `.inner_state`.

In [None]:
def wrap_optax(self, inner: optax.GradientTransformation
               ) -> optax.GradientTransformation:
  """Wraps an existing transformation and adds sparsity related updates."""
  def init_fn(params):
    sparse_state = self.init_state(params)
    sparse_state = sparse_state._replace(inner_state=inner.init(params))
    return sparse_state

  def update_fn(updates, state, params):
    # Simplified
    if self.scheduler.is_mask_update_iter(state.count):
      new_state = self.update_state(state, params, updates)
    else:
      new_state = state
    new_updates, updated_inner_state = inner.update(updates, new_state.inner_state, params)
    new_state = new_state._replace(inner_state=updated_inner_state,
                                   count=new_state.count + 1)
    return new_updates, new_state

  return optax.GradientTransformation(init_fn, update_fn)
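
The scheduler decides both when masks are updated and how sparse they should be at that step. The following is a rough sketch of such a schedule, not the library's implementation: `ToyPeriodicSchedule`, its default values, and the cubic ramp are all assumptions made for illustration, and it uses plain Python numbers rather than traced values.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyPeriodicSchedule:
  """Hypothetical stand-in for a scheduler; not jaxpruner's implementation."""
  update_freq: int = 10
  update_start_step: int = 20
  update_end_step: int = 60

  def is_mask_update_iter(self, step: int) -> bool:
    # Update every `update_freq` steps between the start and end steps.
    in_range = self.update_start_step <= step <= self.update_end_step
    return in_range and (step - self.update_start_step) % self.update_freq == 0

  def get_sparsity_at_step(self, target_sparsity: float, step: int) -> float:
    # Assumed cubic ramp from 0 up to the target sparsity.
    progress = (step - self.update_start_step) / (
        self.update_end_step - self.update_start_step)
    progress = min(max(progress, 0.0), 1.0)
    return target_sparsity * (1.0 - (1.0 - progress) ** 3)


schedule = ToyPeriodicSchedule()
update_steps = [s for s in range(80) if schedule.is_mask_update_iter(s)]
print(update_steps)
```

With these toy settings the masks are refreshed five times, and the sparsity grows quickly at first and levels off as it approaches the target.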

### `init_state` and `update_state` 
Two main functions are called during optimizer initialization and usage. They implement the pruning and sparse-training related operations. Let's start with `init_state`. `jaxpruner` implements sparsity through binary masks. We create these masks at initialization and store them in the `SparseState`. The smallest data format in JAX is `int8` (even boolean variables use 8 bits per element), so we use `jnp.packbits` to compress the masks further when `use_packed_masks` is set. In addition to `masks`, we also store a `count` variable and `target_sparsities` in the `SparseState`.

In [None]:
import logging  # Standard-library logger; any logger with an `info` method works here.

def init_state(self, params: chex.ArrayTree):
  """Creates the sparse state."""
  if self.sparsity_distribution_fn is None:
    target_sparsities = None
  else:
    target_sparsities = self.sparsity_distribution_fn(params)
  logging.info('target_sparsities: %s', target_sparsities)
  masks = self.get_initial_masks(params, target_sparsities)
  if self.use_packed_masks:
    masks = jax.tree.map(jnp.packbits, masks)
  return jaxpruner.SparseState(
      masks=masks,
      target_sparsities=target_sparsities,
      count=jnp.zeros([], jnp.int32),
  )
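
To illustrate what packing buys us, the sketch below packs a toy 4x4 mask and restores it. The mask values are made up. Restoring the original shape by slicing and reshaping is one possible approach and is not necessarily how the library unpacks its masks.

```python
import jax.numpy as jnp

# A toy binary mask with 16 entries, stored with one byte per entry.
mask = jnp.array([[1, 0, 0, 1],
                  [1, 1, 0, 0],
                  [0, 0, 0, 0],
                  [1, 0, 1, 1]], dtype=jnp.uint8)

# `packbits` flattens the array and stores 8 mask entries per byte.
packed = jnp.packbits(mask)

# Unpacking yields a flat array padded to a multiple of 8 entries, so we
# slice to the original size and restore the original shape.
restored = jnp.unpackbits(packed)[:mask.size].reshape(mask.shape)

print(mask.size, packed.size)  # 16 entries become 2 bytes
```

Since the packed array is flat, the original parameter shape has to be available whenever the mask is unpacked.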



Different algorithms update masks in different ways. In `BaseUpdater` we implement a routine best suited to gradual pruning algorithms: it computes the sparsities for the current step count and creates new masks from the scores.

In [None]:
def update_state(self, sparse_state: jaxpruner.SparseState, params: chex.ArrayTree,
                 grads: chex.ArrayTree) -> jaxpruner.SparseState:
  """Updates the sparse state."""
  sparsities = self.scheduler.get_sparsity_at_step(
      sparse_state.target_sparsities, sparse_state.count
  )
  scores = self.calculate_scores(
      params, sparse_state=sparse_state, grads=grads
  )
  new_masks = self.create_masks(scores, sparsities)
  if self.use_packed_masks:
    new_masks = jax.tree.map(jnp.packbits, new_masks)
  return sparse_state._replace(masks=new_masks)
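
The step from scores to masks can be sketched as an unstructured top-k selection. This is a minimal illustration under stated assumptions rather than the library's `create_masks`: `topk_mask` is a hypothetical helper, the scores and sparsities are made up, and every parameter is assumed to have a numeric target sparsity.

```python
import jax
import jax.numpy as jnp


def topk_mask(scores: jnp.ndarray, sparsity: float) -> jnp.ndarray:
  """Hypothetical helper: keeps the highest-scoring (1 - sparsity) fraction."""
  flat = scores.reshape(-1)
  num_keep = int(round(flat.size * (1.0 - sparsity)))
  # Indices of the `num_keep` largest scores.
  _, keep_idx = jax.lax.top_k(flat, num_keep)
  mask = jnp.zeros(flat.shape, dtype=jnp.uint8).at[keep_idx].set(1)
  return mask.reshape(scores.shape)


# Toy scores and per-parameter sparsities (made up for illustration).
scores = {'w': jnp.array([[0.1, 2.0], [0.5, 3.0]]),
          'b': jnp.array([0.3, 0.2, 0.9, 0.4])}
sparsities = {'w': 0.5, 'b': 0.75}

masks = jax.tree.map(topk_mask, scores, sparsities)
```

Here `masks['w']` keeps the two largest of four scores and `masks['b']` keeps only the largest one.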

### Other helper functions

The `self.update_state` function above has access to the parameters but can't update them directly. However, most pruning and sparse training algorithms require some parameter updates, such as applying masks. To address this limitation in a unified manner, we provide two additional functions, which the user is expected to add to the training loop. Depending on the chosen algorithm, they perform the necessary updates on the parameters. These functions are:
- `post_gradient_update`: Intended to be called after applying gradients to the parameters. Since most algorithms keep parameters *sparse*, the default implementation under `BaseUpdater` applies masks to the parameters provided.
```python
def post_gradient_update(
      self, params: chex.ArrayTree, sparse_state: SparseState
  ) -> chex.ArrayTree:
    return self.apply_masks(params, sparse_state.masks)
```
- `pre_forward_update`: Intended to be called before the forward call (i.e. `flax_model.apply(params, data)`) to modify the parameters temporarily. This is useful when implementing algorithms like [STE](https://arxiv.org/abs/2102.04010), which apply a top-k operation to the parameters before every forward call but keep the original parameters as they are. Most algorithms don't need this call to work, and therefore the default implementation is an *identity* function.
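
The placement of the two hooks in a training loop can be sketched as follows. This is a toy, self-contained stand-in and not the library's API: `apply_masks` mimics the default `post_gradient_update`, the mask is fixed and made up, and a hand-written SGD step replaces the wrapped optax optimizer.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
  pred = x @ params['w']
  return jnp.mean((pred - y) ** 2)


def apply_masks(params, masks):
  # Stand-in for the masking done by the default `post_gradient_update`.
  return jax.tree.map(lambda p, m: p * m, params, masks)


params = {'w': jnp.array([1.0, -2.0, 0.5, 3.0])}
masks = {'w': jnp.array([1.0, 0.0, 0.0, 1.0])}  # Fixed toy mask.
x = jnp.ones((8, 4))
y = jnp.zeros((8,))

params = apply_masks(params, masks)
for _ in range(3):
  # A `pre_forward_update` call would go here; it is an identity by default.
  forward_params = params
  grads = jax.grad(loss_fn)(forward_params, x, y)
  # Hand-written SGD step standing in for the wrapped optimizer.
  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  # Equivalent of `post_gradient_update`: pruned weights are reset to zero.
  params = apply_masks(params, masks)
```

The gradients of the pruned entries are nonzero in this example, so without the final masking step those weights would drift away from zero.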


### Configuration
`BaseUpdater` has the following attributes:

- `scheduler`: Determines when masks are updated and how much sparsity is applied at a given step. Default is `NoUpdateSchedule`.
- `skip_gradients`: Returns zero gradients during mask update iterations, practically skipping the gradient update (default=False).
- `is_sparse_gradients`: Whether masks are applied to the gradients before they are passed to the wrapped optimizer (default=False).
- `sparsity_type`: One of `sparsity_types.{Unstructured/NbyM/Block}`. Determines the topk_function used by algorithms (default=Unstructured).
- `sparsity_distribution_fn(params, ...)`: Function that sets the target sparsity of each parameter. The default is the `uniform` distribution with `None` sparsity.
- `rng_seed`: Random seed; set it to override the default (default=8).
- `use_packed_masks`: If true, packs int8 masks into bits (default=False).



## Implementing new algorithms
Since `BaseUpdater` implements most of the routines needed for gradual pruning, implementing the gradual magnitude pruning algorithm only requires implementing the `calculate_scores` function.

In [None]:
@dataclasses.dataclass
class MagnitudePruning(jaxpruner.BaseUpdater):
  """Implements magnitude based pruning."""

  def calculate_scores(self, params, sparse_state=None, grads=None):
    del sparse_state, grads
    param_magnitudes = jax.tree.map(jnp.abs, params)
    return param_magnitudes
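
As a quick illustration of what such a score function returns, the sketch below applies the same `jax.tree.map(jnp.abs, ...)` call to a toy parameter tree. The tree layout and values are made up.

```python
import jax
import jax.numpy as jnp

# Toy parameter tree (layout and values made up for illustration).
params = {'dense': {'kernel': jnp.array([[0.5, -2.0], [-0.1, 1.5]]),
                    'bias': jnp.array([-0.3, 0.0])}}

# Scores mirror the structure of `params`, with one score per weight.
scores = jax.tree.map(jnp.abs, params)
```

The scores form a tree with the same structure and shapes as the parameters, so they can be compared entry by entry when masks are created.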

`calculate_scores` is used by the `update_state` and `instant_sparsify` functions. An alternative way to subclass `BaseUpdater` is to override these higher-level functions. Below we implement a static sparse training method that sparsifies the masks randomly at
initialization and keeps them fixed.

We can also overwrite fields defined in the `BaseUpdater` dataclass. When doing that, **we need to make sure new or overwritten
variables are defined with a type annotation (optionally together with `dataclasses.field`)**. If not,
these variables are treated as class variables, and the value set on the class is replaced by the inherited field default at initialization.
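
The effect of a missing annotation can be reproduced with plain dataclasses. In this minimal sketch, `Base` is a toy stand-in for `BaseUpdater` and the two subclasses are hypothetical.

```python
import dataclasses


@dataclasses.dataclass
class Base:
  """Toy stand-in for BaseUpdater."""
  is_sparse_gradients: bool = False


@dataclasses.dataclass
class WithAnnotation(Base):
  is_sparse_gradients: bool = True  # Redefines the dataclass field default.


@dataclasses.dataclass
class WithoutAnnotation(Base):
  is_sparse_gradients = True  # Plain class attribute; the field default stays False.


print(WithAnnotation().is_sparse_gradients)     # True
print(WithoutAnnotation().is_sparse_gradients)  # False
```

In the second subclass the generated `__init__` still assigns the inherited default to the instance, which hides the value set on the class.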

In [None]:
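# Assumed module path for the `pruners` helper used in `get_initial_masks`.
from jaxpruner.algorithms import pruners

@dataclasses.dataclass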
class StaticRandomSparse(jaxpruner.BaseUpdater):
  """Initializes sparsity randomly and optimizes using that sparsity."""

  is_sparse_gradients: bool = True

  def update_state(self, sparse_state, params, grads):
    """Returns sparse_state unmodified."""
    del params, grads
    return sparse_state

  def get_initial_masks(
      self, params: chex.ArrayTree, target_sparsities: chex.ArrayTree
  ) -> chex.ArrayTree:
    """Generate initial mask. This is only used when .wrap_optax() is called."""
    scores = pruners.generate_random_scores(params, self.rng_seed)
    init_masks = self.create_masks(scores, target_sparsities)
    return init_masks

  def instant_sparsify(self, params, grads=None):
    raise RuntimeError(
        'instant_sparsify function is not supported in sparse training methods.'
    )
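
Random scores of this kind can be sketched as follows. `random_scores` is a hypothetical stand-in and not the library's `generate_random_scores`; the uniform distribution and the per-leaf key splitting are assumptions.

```python
import jax
import jax.numpy as jnp


def random_scores(params, rng_key):
  """Hypothetical stand-in: one uniform random score per parameter entry."""
  leaves, treedef = jax.tree.flatten(params)
  # Use a separate key per leaf so that leaves get independent scores.
  keys = jax.random.split(rng_key, len(leaves))
  scores = [jax.random.uniform(k, leaf.shape) for k, leaf in zip(keys, leaves)]
  return jax.tree.unflatten(treedef, scores)


params = {'w': jnp.ones((3, 4)), 'b': jnp.ones((4,))}
scores_a = random_scores(params, jax.random.PRNGKey(8))
scores_b = random_scores(params, jax.random.PRNGKey(8))
```

The same seed gives the same scores, so the resulting random masks are reproducible across runs.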

There are often multiple ways to implement the same algorithm. Feel free to check the other algorithms for further inspiration.