In [None]:
# @title LICENSE
# Licensed under the Apache License, Version 2.0

## JaxPruner Quick Start
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/jaxpruner/blob/main/colabs/quick_start.ipynb)

This interactive colab provides a short overview of some of the key features of the `jaxpruner` library:

- One-shot Pruning
- Pruning during Optimization (Integration w/ optax)
- ConfigDict Integration
- Compatibility with JAX parallelization via `pmap` and `pjit`

In [None]:
import functools
import flax
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec
import jax.experimental.pjit
import numpy as np
import optax
import pprint

In [None]:
!pip3 install git+https://github.com/google-research/jaxpruner
import jaxpruner
import ml_collections


# One-shot Pruning
Pruning a given matrix to a desired level of sparsity is the building block of any pruning algorithm. Therefore jaxpruner provides a common API for one-shot
pruning. This is achieved by calling the `instant_sparsify` method.

In [None]:
# Side length of the square demo matrix.
matrix_size = 5
# NOTE(review): not referenced by any other cell visible in this notebook.
learning_rate = 0.01
# Fixed PRNG seed, so the sampled matrix is the same on every run.
matrix = jax.random.uniform(
    key=jax.random.PRNGKey(8), shape=(matrix_size,) * 2
)
print(matrix)

In [None]:
# Distribution function with the target sparsity pre-bound to 0.8.
sparsity_distribution = functools.partial(
    jaxpruner.sparsity_distributions.uniform,
    sparsity=0.8,
)
pruner = jaxpruner.MagnitudePruning(
    sparsity_distribution_fn=sparsity_distribution,
)
# `instant_sparsify` hands back a (pruned values, mask) pair.
pruned_matrix, mask = pruner.instant_sparsify(matrix)

# One item per line: the pruned matrix, the mask dtype, then the mask itself.
print(pruned_matrix, mask.dtype, mask, sep='\n')

We can quickly change the sparsity structure using the `sparsity_type` flag.

In [None]:
# Same magnitude pruner, now with an N-by-M sparsity type (N=1, M=5).
# NOTE(review): presumably keeps 1 value per group of 5 -- confirm against
# `jaxpruner.sparsity_types.NByM`.
pruner = jaxpruner.MagnitudePruning(
    sparsity_type=jaxpruner.sparsity_types.NByM(1, 5),
    sparsity_distribution_fn=sparsity_distribution,
)
pruned_matrix, mask = pruner.instant_sparsify(matrix)

# One item per line: the pruned matrix, the mask dtype, then the mask itself.
print(pruned_matrix, mask.dtype, mask, sep='\n')

`instant_sparsify` also supports parameter collections, which are commonly used in deep learning.

In [None]:
# Alternative container noted in the original cell:
#   params = [matrix, 1 - matrix]
params = dict(pos=matrix, inv=1 - matrix)
# The whole collection is pruned in a single call.
pruned_params, masks = pruner.instant_sparsify(params)
pprint.pprint(pruned_params)

It is common to choose different sparsities for different layers or keep them dense entirely. We provide some basic functions to distribute sparsity across different layers such as `uniform` (default) and `erk` under `jaxpruner.sparsity_distributions`. Users can also define their own distributions easily. Here we define a custom distribution function to set different sparsities for each variable.

In [None]:
def custom_distribution(params, sparsity=0.8):
  """Assigns a target sparsity to every entry of `params`.

  Args:
    params: Mapping (or other iterable of keys) naming the variables.
    sparsity: Sparsity used for every key that does not contain 'pos'.

  Returns:
    A dict with the same keys, in the same order: 0.4 for keys containing
    'pos', `sparsity` for all others.
  """
  targets = {}
  for name in params:
    targets[name] = 0.4 if 'pos' in name else sparsity
  return targets


# Prune using the per-variable targets produced by `custom_distribution`.
pruner = jaxpruner.MagnitudePruning(
    sparsity_distribution_fn=custom_distribution,
)
pruned_params, masks = pruner.instant_sparsify(params)
# Print the sparsity summary of the pruned collection.
pprint.pprint(
    jaxpruner.summarize_sparsity(pruned_params)
)

Masks used for enforcing sparsity use the same tree structure as the parameters pruned. We use `None` values to indicate dense parameters. We don't create masks for dense variables.

In [None]:
def custom_distribution2(params, sparsity=0.8):
  """Assigns `sparsity` to every key except those containing 'pos'.

  Args:
    params: Mapping (or other iterable of keys) naming the variables.
    sparsity: Sparsity used for every key that does not contain 'pos'.

  Returns:
    A dict with the same keys, in the same order: None for keys containing
    'pos', `sparsity` for all others.
  """
  targets = dict.fromkeys(params, sparsity)
  for name in targets:
    if 'pos' in name:
      targets[name] = None
  return targets


# Prune with a distribution that returns None for the 'pos' entry.
pruner = jaxpruner.MagnitudePruning(
    sparsity_distribution_fn=custom_distribution2,
)
# Only the masks are of interest here; the pruned values are discarded.
_, masks = pruner.instant_sparsify(params)
pprint.pprint(masks)

Changing the pruning algorithm is easy as they all inherit from the same `BaseUpdater`. We have the following baseline pruning and sparse training algorithms included in our library.

In [None]:
# Print every registered algorithm key next to the object stored under it.
for name, updater in jaxpruner.ALGORITHM_REGISTRY.items():
  print(name, updater)

Next we use a gradient-based saliency score for pruning. `SaliencyPruning` requires gradients to be passed to `pruner.instant_sparsify`. Gradients are multiplied with parameter values to obtain a first-order Taylor approximation of the change in loss.

In [None]:
# Gradient-based pruning.
pruner = jaxpruner.SaliencyPruning(
    sparsity_distribution_fn=sparsity_distribution,
)
# `1 - matrix` is passed as the gradients; element [0] of the result is the
# pruned matrix.
print(pruner.instant_sparsify(matrix, grads=1 - matrix)[0])

# Pruning as optimization (jaxpruner + optax)

Often state-of-the-art pruning algorithms require iterative adjustments to the sparsity masks used. Such iterative approaches are stateful, i.e. they require some additional variables like masks, counters and initial values. This is similar to common optimization algorithms such as Adam and SGD+Momentum which require moving averages.

The observation that *most iterative pruning and sparse training algorithms can be implemented as an optimizer* played a key role when designing `jaxpruner` and led us to integrate `jaxpruner` with the `optax` optimization library.

Here is an example training loop where we find an orthogonal matrix using gradient descent:

In [None]:
matrix_size = 5


def loss_fn(params):
  """Returns the squared Frobenius distance of `W @ W.T` from the identity.

  Args:
    params: Mapping holding a square matrix under the key 'w'.

  Returns:
    Scalar loss; it is zero when the rows of the matrix are orthonormal.
  """
  matrix = params['w']
  # Size the identity from the matrix itself instead of the notebook-global
  # `matrix_size`, which later cells reassign.
  identity = jnp.eye(matrix.shape[0])
  return jnp.sum((matrix @ matrix.T - identity) ** 2)


# Evaluates the loss and its gradient with respect to `params` in one call.
grad_fn = jax.value_and_grad(loss_fn)


@functools.partial(jax.jit, static_argnames=('optimizer',))
def update_fn(params, opt_state, optimizer):
  """Performs one jit-compiled optimization step.

  Args:
    params: Current parameters, as accepted by `grad_fn`.
    opt_state: Optimizer state matching `optimizer`.
    optimizer: Object exposing `update(grads, state, params)`; treated as a
      static argument by `jax.jit`.

  Returns:
    A tuple `(new_params, new_opt_state, loss)`, where `loss` was evaluated at
    the incoming `params`.
  """
  loss, grads = grad_fn(params)
  updates, new_opt_state = optimizer.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state, loss


def run_experiment(init_matrix):
  """Trains `init_matrix` for 20 SGD steps and returns the final matrix.

  The loss is printed on every fourth step, starting at step 0.

  Args:
    init_matrix: Initial value stored under the parameter key 'w'.

  Returns:
    The value of 'w' after the last step.
  """
  optimizer = optax.sgd(0.05)
  params = {'w': init_matrix}
  opt_state = optimizer.init(params)

  for i in range(20):
    params, opt_state, loss = update_fn(params, opt_state, optimizer)
    if i % 4:
      continue
    print(f'Step: {i}, loss: {loss}')
  return params['w']

First run the baseline training with a dense matrix.

In [None]:
# Dense baseline: start from a random matrix drawn with a fixed seed.
params = jax.random.uniform(
    key=jax.random.PRNGKey(8), shape=(matrix_size, matrix_size)
)
run_experiment(params)

Adding a pruner to an existing training loop requires just 2 lines. First we wrap an existing optimizer using the `pruner.wrap_optax` method. This wrapped optimizer ensures the masks are updated during the training. Second, we add a `pruner.post_gradient_update` call after our gradient step. This function defines algorithm specific parameter updates (like applying a mask to parameters) and provides flexibility when implementing various algorithms.

In [None]:
def run_pruning_experiment(init_matrix, pruner):
  """Trains `init_matrix` for 20 SGD steps with `pruner` in the loop.

  On every fourth step, starting at step 0, the loss and the total sparsity
  summary of the parameters are printed.

  Args:
    init_matrix: Initial value stored under the parameter key 'w'.
    pruner: Object providing `wrap_optax` and `post_gradient_update`.

  Returns:
    The value of 'w' after the last step.
  """
  # Modification #1: wrap the base optimizer with the pruner.
  optimizer = pruner.wrap_optax(optax.sgd(0.05))

  params = {'w': init_matrix}
  opt_state = optimizer.init(params)

  for i in range(20):
    params, opt_state, loss = update_fn(params, opt_state, optimizer)
    # Modification #2: let the pruner post-process the updated parameters.
    params = pruner.post_gradient_update(params, opt_state)

    if i % 4:
      continue
    print(f'Step: {i}, loss: {loss}')
    print(jaxpruner.summarize_sparsity(params, only_total_sparsity=True))
  return params['w']

Now, prune the matrix in one step (step=10).




In [None]:
# Magnitude pruning driven by a one-shot schedule with target_step=10.
pruner = jaxpruner.MagnitudePruning(
    scheduler=jaxpruner.sparsity_schedules.OneShotSchedule(target_step=10),
    sparsity_distribution_fn=sparsity_distribution,
)
# Same seed as the dense baseline, so both runs start from the same matrix.
params = jax.random.uniform(
    key=jax.random.PRNGKey(8), shape=(matrix_size, matrix_size)
)
run_pruning_experiment(params, pruner)

Alternatively we can prune it iteratively using the [polynomial schedule](https://arxiv.org/abs/1710.01878).

In [None]:
# Magnitude pruning driven by a polynomial schedule: updates every 4 steps
# between step 2 and step 14, per the arguments below.
pruner = jaxpruner.MagnitudePruning(
    scheduler=jaxpruner.sparsity_schedules.PolynomialSchedule(
        update_start_step=2, update_end_step=14, update_freq=4
    ),
    sparsity_distribution_fn=sparsity_distribution,
)
# Same seed as the dense baseline, so both runs start from the same matrix.
params = jax.random.uniform(
    key=jax.random.PRNGKey(8), shape=(matrix_size, matrix_size)
)
run_pruning_experiment(params, pruner)

# ml_collections.ConfigDict Integration

Many popular jax libraries like [scenic](https://github.com/google-research/scenic) and [big_vision](https://github.com/google-research/big_vision) use `ml_collections.ConfigDict` to configure experiments. `jaxpruner` provides a helper function (`jaxpruner.create_updater_from_config`) to make it easy to use a `ConfigDict` to generate pruner objects.

In [None]:
# Pruning settings collected in one ConfigDict, built from a dict literal.
sparsity_config = ml_collections.ConfigDict({
    'algorithm': 'magnitude',
    'update_freq': 2,
    'update_end_step': 15,
    'update_start_step': 5,
    'sparsity': 0.6,
    'dist_type': 'uniform',
})

In [None]:
# Build the pruner from the config, then create a dense matrix and sparsify
# it during training.
pruner = jaxpruner.create_updater_from_config(sparsity_config)
params = jax.random.uniform(
    key=jax.random.PRNGKey(8), shape=(matrix_size, matrix_size)
)
run_pruning_experiment(params, pruner)

# Parallelization with `pmap` and `pjit`

The `jaxpruner` library is in general compatible with JAX parallelization mechanisms like `pmap` and `pjit`. There are some minor points to watch out for,
which we will now demonstrate using parallelized versions of the previously introduced orthogonal matrix optimization example.

## `pmap`

First, we demonstrate compatibility with `pmap` where a model is replicated to run different shards of a batch on different devices. Note that this example
has no actual model "inputs" apart from the parameter matrix and the replication is thus not directly useful, but the general mechanisms are the same as for real training.

The main point to watch out for is that the optimizer must be wrapped with `jaxpruner` first; its state is then initialized and replicated **after** the wrapping.

In [None]:
matrix_size = 8


def loss_fn(params):
  """Returns the squared Frobenius distance of `W @ W.T` from the identity.

  Args:
    params: Mapping holding a square matrix under the key 'w'.

  Returns:
    Scalar loss; it is zero when the rows of the matrix are orthonormal.
  """
  matrix = params['w']
  # Size the identity from the matrix itself instead of the notebook-global
  # `matrix_size`, which other cells reassign.
  identity = jnp.eye(matrix.shape[0])
  return jnp.sum((matrix @ matrix.T - identity) ** 2)


# Evaluates the loss and its gradient with respect to `params` in one call.
grad_fn = jax.value_and_grad(loss_fn)


@functools.partial(
    jax.pmap,
    axis_name='batch',
    static_broadcasted_argnums=(2,),
    out_axes=(0, 0, None),
)
def update_fn(params, opt_state, optimizer):
  """Performs one pmapped optimization step.

  Args:
    params: Parameters, mapped over the leading device axis.
    opt_state: Optimizer state, mapped over the leading device axis.
    optimizer: Object exposing `update(grads, state, params)`; broadcast as a
      static argument (positional index 2).

  Returns:
    A tuple `(new_params, new_opt_state, loss)`. The first two keep their
    leading device axis; `loss` is returned unmapped (`out_axes` entry None).
  """
  local_loss, local_grads = grad_fn(params)
  # Average over the 'batch' axis so every replica sees the same values.
  mean_loss = jax.lax.pmean(local_loss, 'batch')
  mean_grads = jax.lax.pmean(local_grads, 'batch')
  updates, new_opt_state = optimizer.update(mean_grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state, mean_loss


# Distribution function with the target sparsity pre-bound to 0.8.
sparsity_distribution = functools.partial(
    jaxpruner.sparsity_distributions.uniform,
    sparsity=0.8,
)

# Magnitude pruning driven by a one-shot schedule with target_step=0.
pruner = jaxpruner.MagnitudePruning(
    scheduler=jaxpruner.sparsity_schedules.OneShotSchedule(target_step=0),
    sparsity_distribution_fn=sparsity_distribution,
)

# Wrap plain SGD with the pruner before any state is created.
optimizer = pruner.wrap_optax(optax.sgd(0.001))
params = {
    'w': jax.random.normal(
        jax.random.PRNGKey(0), shape=(matrix_size, matrix_size)
    )
}
opt_state = optimizer.init(params)
# Key step for pmap + jaxpruner: replicate the optimizer state only **after**
# the optimizer has been wrapped.
opt_state = flax.jax_utils.replicate(opt_state)
params = flax.jax_utils.replicate(params)

# 100 pmapped steps; the loss is printed on every fifth step.
for i in range(100):
  params, opt_state, loss = update_fn(params, opt_state, optimizer)
  params = pruner.post_gradient_update(params, opt_state)
  if i % 5:
    continue
  print(f'Step: {i}, loss: {loss}')
# Drop the leading device axis again before printing the final matrix.
params = flax.jax_utils.unreplicate(params)
print(params['w'])

## `pjit`

Next, we demonstrate tensor-sharded training with `pjit`. Here the key is that the partition specifications of the wrapped optimizer state also have to incorporate the `jaxpruner.base_updater.SparseState` produced by the pruning wrapper.

In [None]:
matrix_size = 8
# Use a 2x4 device mesh when the device count is a multiple of 8; otherwise
# fall back to a 1x1 mesh shape.
if jax.device_count() % 8 == 0:
  MESH_SHAPE = (2, 4)
else:
  MESH_SHAPE = (1, 1)


def loss_fn(params):
  """Returns the squared Frobenius distance of `W @ W.T` from the identity.

  Args:
    params: Mapping holding a square matrix under the key 'w'.

  Returns:
    Scalar loss; it is zero when the rows of the matrix are orthonormal.
  """
  matrix = params['w']
  # Size the identity from the matrix itself instead of the notebook-global
  # `matrix_size`, which other cells reassign.
  identity = jnp.eye(matrix.shape[0])
  return jnp.sum((matrix @ matrix.T - identity) ** 2)


# Evaluates the loss and its gradient with respect to `params` in one call.
grad_fn = jax.value_and_grad(loss_fn)

# Partition specs for pjit. Libraries for real models usually derive these
# more or less automatically, but a small adjustment like the one below is
# likely still required.

# Dimension 0 of 'w' is split over mesh axis 'X', dimension 1 over 'Y'.
params_partition = {'w': PartitionSpec('X', 'Y')}

# The main step needed to run jaxpruner together with pjit: define a
# partition spec for the wrapped `SparseState`.
opt_partition = jaxpruner.base_updater.SparseState(
    count=None,
    target_sparsities=None,
    inner_state=(None, None),  # other optimizers may require sharding
    masks=params_partition,  # masks reuse the parameter partitioning
)

# (params spec, optimizer-state spec), in the argument order of `update_fn`.
resources = (params_partition, opt_partition)


@functools.partial(
    jax.experimental.pjit.pjit,
    static_argnames=('optimizer',),
    in_shardings=resources,
    out_shardings=(*resources, None),
)
def update_fn(params, opt_state, optimizer):
  """Performs one pjit-compiled optimization step.

  Args:
    params: Parameters, sharded according to `resources[0]`.
    opt_state: Optimizer state, sharded according to `resources[1]`.
    optimizer: Object exposing `update(grads, state, params)`; treated as a
      static argument.

  Returns:
    A tuple `(new_params, new_opt_state, loss)`. The first two use the input
    shardings; the sharding of `loss` is left unspecified (None).
  """
  loss, grads = grad_fn(params)
  updates, new_opt_state = optimizer.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state, loss


# Distribution function with the target sparsity pre-bound to 0.8.
sparsity_distribution = functools.partial(
    jaxpruner.sparsity_distributions.uniform,
    sparsity=0.8,
)
# Magnitude pruning driven by a one-shot schedule with target_step=0.
pruner = jaxpruner.MagnitudePruning(
    scheduler=jaxpruner.sparsity_schedules.OneShotSchedule(target_step=0),
    sparsity_distribution_fn=sparsity_distribution,
)

# Wrap plain SGD with the pruner, then create the state from the wrapped
# optimizer.
optimizer = pruner.wrap_optax(optax.sgd(0.001))
params = {
    'w': jax.random.normal(
        jax.random.PRNGKey(0), shape=(matrix_size, matrix_size)
    )
}
opt_state = optimizer.init(params)

# A mesh needs exactly prod(MESH_SHAPE) devices. Take only that many, so the
# reshape also succeeds when the device count differs from the mesh size
# (e.g. 2 or 4 devices with the (1, 1) fallback, or 16 devices with (2, 4)).
num_mesh_devices = int(np.prod(MESH_SHAPE))
devices = np.asarray(jax.devices()[:num_mesh_devices]).reshape(MESH_SHAPE)
mesh = jax.sharding.Mesh(devices, ('X', 'Y'))

# Run the sharded training inside the mesh context so the 'X'/'Y' axis names
# used by the partition specs resolve to these devices.
with mesh:
  for i in range(100):
    params, opt_state, loss = update_fn(params, opt_state, optimizer)
    params = pruner.post_gradient_update(params, opt_state)
    if i % 5 == 0:
      print(f'Step: {i}, loss: {loss}')
  print(params['w'])
  # Show how the final matrix is laid out across the mesh devices.
  jax.debug.visualize_array_sharding(params['w'])