feat(train): log norm and histograms (#143)
* feat(train): log norm and histograms
* feat: update shampoo
borisdayma authored Mar 19, 2022
1 parent 7939874 commit b7b619a
Showing 3 changed files with 159 additions and 28 deletions.
13 changes: 9 additions & 4 deletions tools/train/scalable_shampoo/distributed_shampoo.py
@@ -832,8 +832,11 @@ def sharded_init_fn(params):
if not _skip_preconditioning(param):
sizes = [s[0] for s in shapes]
shapes = preconditioner.shapes_for_preconditioners()
statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes]
preconditioners = [jnp.eye(max_size) for s in shapes]
statistics = [
matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32)
for s in shapes
]
preconditioners = [jnp.eye(max_size, dtype=jnp.float32) for s in shapes]
padded_statistics.extend(statistics)
padded_preconditioners.extend(preconditioners)
exponent = (
@@ -1244,8 +1247,10 @@ def _init(param):
preconditioners = []
if not _skip_preconditioning(param):
shapes = preconditioner.shapes_for_preconditioners()
statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
preconditioners = [jnp.eye(s[0]) for s in shapes]
statistics = [
matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes
]
preconditioners = [jnp.eye(s[0], dtype=jnp.float32) for s in shapes]

diagonal_statistics = []
if _graft_type_has_diagonal_statistics():
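
(Note: the following sketch is not part of the commit; it only illustrates the change above.) The diff pins the Shampoo statistics and preconditioners to float32 by passing an explicit dtype to jnp.eye, so the optimizer state keeps a fixed precision independent of JAX's default float type (float64 when jax_enable_x64 is enabled). A minimal standalone sketch with made-up values for matrix_epsilon and the preconditioner shapes:

import jax.numpy as jnp

matrix_epsilon = 1e-6          # hypothetical value
shapes = [(8, 8), (16, 16)]    # hypothetical per-dimension preconditioner shapes

# Identity-initialized statistics scaled by matrix_epsilon, always stored in float32.
statistics = [matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes]
preconditioners = [jnp.eye(s[0], dtype=jnp.float32) for s in shapes]
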
105 changes: 87 additions & 18 deletions tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py
@@ -16,10 +16,11 @@
"""JAX Ops for symmetric matrices used by the Shampoo optimizer."""

import functools
from typing import List, Union
from typing import Any, List, Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from jax import lax

@@ -41,6 +42,7 @@ class SlicedSymmetricMatrix:
def product_with_transpose(
mat1,
mat2,
axes,
precision=lax.Precision.DEFAULT,
):
"""Returns mat1 * mat2^T for two matrices (possibly batched).
@@ -50,50 +52,85 @@ def product_with_transpose(
Args:
mat1: First matrix.
mat2: Second matrix.
axes: The axes over which to apply the product.
precision: JAX precision to use for the multiplication.
"""
return jnp.einsum("...ij,...kj->...ik", mat1, mat2, precision=precision)
return jnp.tensordot(a=mat1, b=mat2, axes=axes, precision=precision)


@functools.partial(jax.jit, static_argnames=("block_size", "precision"))
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
def sliced_transposed_product(
mat,
block_size,
axes=(-1,),
precision=lax.Precision.DEFAULT,
):
"""Returns the blocked slices representing a symmetric matrix mat*mat^T.
"""Returns the blocked slices representing a symmetric contraction.
Specifically, the output is a contraction of the input mat with itself, in the
specified axes.
Args:
mat: The matrix for which we will compute mat*mat^T. It does not need to be
square, and may be batched.
mat: The matrix for which we will compute a contraction with itself.
block_size: The size of row blocks to compute.
axes: Axes to use for the contraction.
precision: The precision to use in each computation.
Raises:
ValueError: Raised when the specified block size does not evenly divide
the number of rows of the input mat.
"""
num_rows = mat.shape[-2]
rank = len(mat.shape)

def _make_axis_positive(ax):
assert -rank <= ax < rank
return ax + rank if ax < 0 else ax

positive_axes = [_make_axis_positive(ax) for ax in axes]
assert len(positive_axes) == len(axes)
remaining_axes = set(range(rank)) - set(positive_axes)
assert len(remaining_axes) == 1
remaining_ax = remaining_axes.pop()

num_rows = mat.shape[remaining_ax]
if num_rows % block_size != 0:
raise ValueError(
"The row dimension must be divisible by block_size. "
f"Instead got row dimension={num_rows} and block_size={block_size}."
)
block_rows = [
product_with_transpose(
mat[Ellipsis, i * block_size : (i + 1) * block_size, :],
mat[Ellipsis, 0 : (i + 1) * block_size, :],
precision,

block_rows = []
for i in range(num_rows // block_size):
start_indices = [0] * rank
start_indices[remaining_ax] = i * block_size

slice_sizes = list(mat.shape)
slice_sizes[remaining_ax] = block_size

slice_sizes_full = list(mat.shape)
slice_sizes_full[remaining_ax] = (i + 1) * block_size

block_rows.append(
product_with_transpose(
lax.dynamic_slice(
mat, start_indices=start_indices, slice_sizes=slice_sizes
),
lax.dynamic_slice(
mat, start_indices=[0] * rank, slice_sizes=slice_sizes_full
),
axes=(axes, axes),
precision=precision,
)
)
for i in range(num_rows // block_size)
]

return SlicedSymmetricMatrix(block_rows=block_rows)


@functools.partial(jax.jit, static_argnames=("block_size", "precision"))
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
def sliced_transposed_product_concat(
mat,
block_size,
axes=(-1,),
precision=lax.Precision.DEFAULT,
):
"""Returns the concatenated slices representing mat*mat^T.
@@ -102,14 +139,15 @@ def sliced_transposed_product_concat(
mat: The matrix for which we will compute mat*mat^T. It does not need to be
square, and may be batched.
block_size: The size of row blocks to compute.
axes: Axes to use for the contraction.
precision: The precision to use in each computation.
Raises:
ValueError: Raised when the specified block size does not evenly divide
the number of rows of the input mat.
"""
sliced_symmetric_matrix = sliced_transposed_product(
mat=mat, block_size=block_size, precision=precision
mat=mat, block_size=block_size, axes=axes, precision=precision
)
return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)

@@ -179,12 +217,13 @@ def materialize_matrix_from_concat(
return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))


@functools.partial(jax.jit, static_argnames=("alpha", "beta"))
@functools.partial(jax.jit, static_argnames=("alpha", "beta", "axes"))
def update_sliced_rows(
symmetric_matrix,
mat,
alpha,
beta,
axes=(-1,),
):
"""Implements the blocked equivalent of SYRK.
@@ -197,15 +236,45 @@ def update_sliced_rows(
should match that of symmetric_matrix.
alpha: The weight for the update.
beta: The weight for the original symmetric matrix.
axes: Axes to use for the contraction of the update.
Returns:
The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
"""
block_size = symmetric_matrix.block_rows[0].shape[-2]
sym_prod = sliced_transposed_product(mat=mat, block_size=block_size)
sym_prod = sliced_transposed_product(mat=mat, block_size=block_size, axes=axes)
return SlicedSymmetricMatrix(
block_rows=[
update * alpha + row * beta
for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows)
]
)


def find_num_blocks(block_rows_concat):
"""Returns the number of (row) blocks representing the concatenated matrix.
For example, an input with dimensions [256, 2560] represents 10 square blocks,
which matches 4 lower-triangular block rows (1+2+3+4). So this function will
return 4.
Use ordinary numpy functions here so that the returned value is static.
Args:
block_rows_concat: The concatenated block array.
Raises:
ValueError: When the dimensions of the matrix do not correspond to a lower
triangular block representation.
"""
# Compute the number of square blocks used to represent the matrix.
total_blocks = block_rows_concat.shape[-1] / block_rows_concat.shape[-2]
# Determine the number of block rows by inverting y = x*(x+1)/2.
num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
if num_blocks * (num_blocks + 1) / 2 != total_blocks:
raise ValueError(
"Could not determine an appropriate number of blocks for "
"the concatenated matrix."
)
else:
return num_blocks
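
(Note: worked example for illustration only, not part of the commit; it assumes sliced_transposed_product_concat and find_num_blocks from this file are in scope.) A 12x4 input contracted over its last axis with block_size=4 yields 3 lower-triangular block rows holding 1 + 2 + 3 = 6 square 4x4 blocks, and find_num_blocks recovers the 3 by inverting x * (x + 1) / 2:

import jax.numpy as jnp

mat = jnp.arange(48.0).reshape(12, 4)
concat = sliced_transposed_product_concat(mat, block_size=4, axes=(-1,))

# 6 blocks of size 4x4 concatenated along the last axis.
assert concat.shape == (4, 24)

# total_blocks = 24 / 4 = 6, and 6 = 3 * (3 + 1) / 2, so 3 block rows
# (same arithmetic as the [256, 2560] -> 4 example in the docstring).
assert find_num_blocks(concat) == 3
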
69 changes: 63 additions & 6 deletions tools/train/train.py
@@ -37,7 +37,7 @@
import transformers
import wandb
from datasets import Dataset
from flax.core.frozen_dict import FrozenDict, freeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import onehot
@@ -405,6 +405,12 @@ class TrainingArguments:
default=False,
metadata={"help": "Log model to wandb at `save_steps` frequency."},
)
log_histograms: bool = field(
default=False,
metadata={
"help": "Log parameters and gradients histograms. Slows down training."
},
)

seed_model: int = field(
default=42,
@@ -514,10 +520,22 @@ def update_state_metrics(self, state):

def log(self, metrics, prefix=None):
if jax.process_index() == 0:
log_metrics = {
f"{prefix}/{k}" if prefix is not None else k: v
for k, v in metrics.items()
}
log_metrics = {}
for k, v in metrics.items():
if prefix is not None:
k = f"{prefix}/{k}"
if "_norm" in k:
log_metrics[f"{k}/"] = unfreeze(v)
elif "_hist" in k:
v = jax.tree_map(lambda x: jax.device_get(x), unfreeze(v))
v = jax.tree_map(
lambda x: wandb.Histogram(np_histogram=x),
v,
is_leaf=lambda x: isinstance(x, tuple),
)
log_metrics[f"{k}/"] = v
else:
log_metrics[k] = v
wandb.log({**log_metrics, **self.state_dict})
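
(Note: hypothetical example, not part of the commit.) With the branches above, a key containing "_norm" is logged as a nested dict of scalars under "{prefix}/{key}/", while a key containing "_hist" holds (counts, bin_edges) tuples from jnp.histogram that get wrapped into wandb.Histogram(np_histogram=...) objects. A metrics pytree handed to this logger could look like:

import jax.numpy as jnp

x = jnp.arange(10.0)
metrics = {
    "loss": jnp.float32(1.23),
    # nested dict of per-parameter norms
    "params_norm": {"dense": {"kernel": jnp.linalg.norm(x)}},
    # nested dict of (counts, bin_edges) tuples
    "params_hist": {"dense": {"kernel": jnp.histogram(x, density=True)}},
}
# metrics_logger.log(metrics, prefix="train")   # metrics_logger is a hypothetical instance
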


@@ -1024,20 +1042,59 @@ def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
lambda x: x / training_args.gradient_accumulation_steps, (loss, grads)
)

# update state
grads = with_sharding_constraint(grads, param_spec)

# update state
state = state.apply_gradients(
grads=grads,
dropout_rng=dropout_rng,
train_time=state.train_time + delta_time,
train_samples=state.train_samples + batch_size_per_step,
)

# get norm and histogram of grads and params
zeros_norm = jax.tree_map(lambda _: jnp.float32(0), state.params)

def maybe_fn(fn, val, zeros):
"""Call fn only if it is a logging step"""
return jax.lax.cond(
state.step % training_args.logging_steps == 0,
fn,
lambda _: zeros,
val,
)

def norm(val):
return jax.tree_map(lambda x: jnp.linalg.norm(x), val)

gradients_norm = maybe_fn(norm, grads, zeros_norm)
params_norm = maybe_fn(norm, state.params, zeros_norm)

metrics = {
"loss": loss,
"learning_rate": learning_rate_fn(state.step),
"gradients_norm": gradients_norm,
"params_norm": params_norm,
}

if training_args.log_histograms:
zeros_hist = jax.tree_map(
lambda _: jnp.histogram(jnp.zeros(1), density=True), state.params
)

def histogram(val):
return jax.tree_map(lambda x: jnp.histogram(x, density=True), val)

gradients_hist = maybe_fn(histogram, grads, zeros_hist)
params_hist = maybe_fn(histogram, state.params, zeros_hist)

metrics.update(
{
"params_hist": params_hist,
"gradients_hist": gradients_hist,
}
)

return state, metrics

# Define eval fn
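
(Note: standalone sketch for illustration, not part of the commit.) maybe_fn relies on jax.lax.cond, which requires both branches to return pytrees with identical structure and dtypes; the zeros_norm / zeros_hist placeholders exist so the expensive computation only runs on logging steps while the output signature stays fixed. A stripped-down version of the same pattern with made-up logging_steps and params:

import jax
import jax.numpy as jnp

logging_steps = 40                                            # hypothetical
params = {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros(4)}   # hypothetical

def norm(val):
    return jax.tree_map(jnp.linalg.norm, val)

def maybe_norm(step, val):
    # The skip branch returns zeros matching the norms' structure and dtype,
    # so both lax.cond branches have the same output signature.
    zeros = jax.tree_map(lambda _: jnp.float32(0), val)
    return jax.lax.cond(step % logging_steps == 0, norm, lambda _: zeros, val)

print(maybe_norm(80, params))   # logging step: real norms
print(maybe_norm(81, params))   # non-logging step: zero placeholders
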
