Update Gemax to add an option to use the QTensor-saving Freezer.
PiperOrigin-RevId: 619056575
cdh4696 authored and Copybara-Service committed Apr 3, 2024
1 parent 153a7e2 commit 3771efe
Showing 5 changed files with 127 additions and 49 deletions.
68 changes: 61 additions & 7 deletions aqt/jax/v2/flax/aqt_flax.py
@@ -17,8 +17,7 @@
 # pylint: disable=g-importing-member
 import copy
 import functools
-from typing import Iterable
-from typing import Optional, Union
+from typing import Any, Callable, Iterable, Optional, Sequence, Union
 from aqt.jax.v2 import aqt_dot_general
 from aqt.jax.v2 import aqt_tensor
 from aqt.jax.v2 import config
@@ -27,11 +26,19 @@
 from aqt.jax.v2.flax import aqt_flax_dg_core
 from aqt.jax.v2.flax import freezer as general_freezer
 from aqt.jax.v2.flax.utils import QuantMode
+import flax.core.meta as nn_meta
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
 
 
+_AxisMetadataWrapper = Callable[
+    [jnp.ndarray, Sequence[utils.AxisIdx]], nn_meta.AxisMetadata
+]
+
+_default_wrapper = lambda x, no_sharding_axes: x
+
+
 class Freezer(nn.Module):
   """Identity function that can freeze its input.
@@ -129,18 +136,25 @@ class AqtDotGeneral(nn.Module):
   lhs_quant_mode: QuantMode = QuantMode.TRAIN
   # apply_quant_mode determines if using Freezer in cfg.get/set_tensor
   lhs_apply_quant_mode: bool = True
-  lhs_init: nn.initializers.Initializer = jnp.zeros
-  lhs_scale_init: nn.initializers.Initializer = jnp.zeros
   lhs_var_name: str = 'qlhs'
   lhs_qtensor: Optional[aqt_tensor.QTensor] = None
 
   rhs_quant_mode: QuantMode = QuantMode.TRAIN
   rhs_apply_quant_mode: bool = True
-  rhs_init: nn.initializers.Initializer = jnp.zeros
-  rhs_scale_init: nn.initializers.Initializer = jnp.zeros
   rhs_var_name: str = 'qrhs'
   rhs_qtensor: Optional[aqt_tensor.QTensor] = None
 
+  # Variables only for the legacy Freezer.
+  lhs_init: nn.initializers.Initializer = jnp.zeros
+  lhs_scale_init: nn.initializers.Initializer = jnp.zeros
+
+  rhs_init: nn.initializers.Initializer = jnp.zeros
+  rhs_scale_init: nn.initializers.Initializer = jnp.zeros
+
+  # Variables only for the new Freezer.
+  lhs_axis_metadata_wrapper: _AxisMetadataWrapper = _default_wrapper
+  rhs_axis_metadata_wrapper: _AxisMetadataWrapper = _default_wrapper
+
   # If you want to use 'params', make sure that there is another mechanism
   # to hide these variables from the optimizer.
   quant_collection: str = 'aqt'
Expand Down Expand Up @@ -213,16 +227,46 @@ def make_aqt_dg(
QuantMode.CONVERT: general_freezer.FreezerMode.WRITE,
QuantMode.SERVE: general_freezer.FreezerMode.READ,
}

def init_wrapper(
qt: aqt_tensor.QTensor,
axis_metadata_wrapper: _AxisMetadataWrapper,
):
# We are not doing any sharding for scale and scale_t, for now.
scale_non_shard_axis = range(qt.ndim)

qt = qt.replace(
qvalue=axis_metadata_wrapper(qt.qvalue, []),
scale=jax.tree_map(
lambda x: axis_metadata_wrapper(x, scale_non_shard_axis),
qt.scale,
),
scale_t=jax.tree_map(
lambda x: axis_metadata_wrapper(x, scale_non_shard_axis),
qt.scale_t,
),
)
return qt

lhs_init_wrapper = functools.partial(
init_wrapper, axis_metadata_wrapper=self.lhs_axis_metadata_wrapper
)
rhs_init_wrapper = functools.partial(
init_wrapper, axis_metadata_wrapper=self.rhs_axis_metadata_wrapper
)

lhs_freezer = general_freezer.Freezer(
name=self.lhs_var_name,
mode=quant_to_freezer_mode[lhs_qm],
collection=self.quant_collection,
axis_metadata_wrapper=lhs_init_wrapper,
)

rhs_freezer = general_freezer.Freezer(
name=self.rhs_var_name,
mode=quant_to_freezer_mode[rhs_qm],
collection=self.quant_collection,
axis_metadata_wrapper=rhs_init_wrapper,
)

prng_name = self.prng_name
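
For orientation, a layer might wire these wrappers in as in the hedged sketch below; config_v4() and the all-unsharded Partitioned boxing are assumptions for illustration, not part of this diff.

# Hedged sketch: an AqtDotGeneral whose rhs QTensor is written by the
# new Freezer with sharding metadata attached. config_v4() and the
# wrap_unsharded helper are illustrative assumptions.
from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2.flax import aqt_flax
from aqt.jax.v2.flax.utils import QuantMode
from flax.core import meta as nn_meta


def wrap_unsharded(x, no_sharding_axes):
  # Box with no sharding annotation; real code would map axes to mesh
  # names here instead of using all-None names.
  return nn_meta.Partitioned(x, names=(None,) * x.ndim)


dot_general = aqt_flax.AqtDotGeneral(
    aqt_config.config_v4(),
    rhs_quant_mode=QuantMode.CONVERT,  # CONVERT maps to FreezerMode.WRITE
    rhs_axis_metadata_wrapper=wrap_unsharded,
)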
@@ -326,11 +370,13 @@ class AqtEinsum(nn.Module):
   # TODO(lew): split out separate class for each side.
   lhs_quant_mode: QuantMode = QuantMode.TRAIN
   lhs_init: nn.initializers.Initializer = jnp.zeros
+  lhs_axis_metadata_wrapper: Optional[Callable[..., Any]] = None
   lhs_scale_init: nn.initializers.Initializer = jnp.zeros
   lhs_var_name: str = 'qlhs'
 
   rhs_quant_mode: QuantMode = QuantMode.TRAIN
   rhs_init: nn.initializers.Initializer = jnp.zeros
+  rhs_axis_metadata_wrapper: Optional[Callable[..., Any]] = None
   rhs_scale_init: nn.initializers.Initializer = jnp.zeros
   rhs_var_name: str = 'qrhs'
 
@@ -341,7 +387,7 @@ class AqtEinsum(nn.Module):
   assert_eqn: Optional[str] = None
   assert_lhs_shape: Optional[utils.ShapeTemplate] = None
   assert_rhs_shape: Optional[utils.ShapeTemplate] = None
-  tile_sizes: Optional[tiled_dot_general.EinsumTileSizes] = None
+  tile_sizes: Optional[utils.EinsumTileSizes] = None
 
   # If set to True, use the current Freezer. Otherwise, use the new
   # QTensorFreezer.
@@ -393,12 +439,14 @@ def __call__(
 
     lhs_quant_mode = self.lhs_quant_mode
     lhs_init = self.lhs_init
+    lhs_axis_metadata_wrapper = self.lhs_axis_metadata_wrapper
     lhs_scale_init = self.lhs_scale_init
     lhs_var_name = self.lhs_var_name
     lhs_qtensor = lhs_g if lhs_is_qt else None
 
     rhs_quant_mode = self.rhs_quant_mode
     rhs_init = self.rhs_init
+    rhs_axis_metadata_wrapper = self.rhs_axis_metadata_wrapper
     rhs_scale_init = self.rhs_scale_init
     rhs_var_name = self.rhs_var_name
     rhs_qtensor = rhs_g if rhs_is_qt else None
@@ -420,6 +468,10 @@ def __call__(
       lhs_var_name, rhs_var_name = rhs_var_name, lhs_var_name
       lhs_is_qt, rhs_is_qt = rhs_is_qt, lhs_is_qt
       lhs_qtensor, rhs_qtensor = rhs_qtensor, lhs_qtensor
+      lhs_axis_metadata_wrapper, rhs_axis_metadata_wrapper = (
+          rhs_axis_metadata_wrapper,
+          lhs_axis_metadata_wrapper,
+      )
       if tiling_config is not None:
         tiling_config = tiled_dot_general.Cfg(
             lhs=tiling_config.rhs, rhs=tiling_config.lhs
@@ -434,12 +486,14 @@ def __call__(
         # the qtensor passed to dg.
         lhs_apply_quant_mode=not lhs_is_qt,  # Freezer not used if lhs is qt
         lhs_init=lhs_init,
+        lhs_axis_metadata_wrapper=lhs_axis_metadata_wrapper,
         lhs_scale_init=lhs_scale_init,
         lhs_var_name=lhs_var_name,
         lhs_qtensor=lhs_qtensor,
         rhs_quant_mode=rhs_quant_mode,
         rhs_apply_quant_mode=not rhs_is_qt,  # Freezer not used if rhs is qt
         rhs_init=rhs_init,
+        rhs_axis_metadata_wrapper=rhs_axis_metadata_wrapper,
         rhs_scale_init=rhs_scale_init,
         rhs_var_name=rhs_var_name,
         rhs_qtensor=rhs_qtensor,
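The AqtEinsum side mirrors AqtDotGeneral: the wrappers are new per-side fields, and they swap together with the other lhs/rhs settings when the einsum arguments are reordered above. A hedged usage sketch; the cfg constructor and the equation are assumptions, not part of this diff.

# Hedged sketch of AqtEinsum with an rhs metadata wrapper; config_v4()
# and the example equation are illustrative assumptions.
from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2.flax import aqt_flax
from flax.core import meta as nn_meta

einsum = aqt_flax.AqtEinsum(
    aqt_config.config_v4(),
    rhs_axis_metadata_wrapper=lambda x, _: nn_meta.Partitioned(
        x, names=(None,) * x.ndim
    ),
)
# Inside a parent nn.Module: y = einsum('bd,dh->bh', acts, weights)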
18 changes: 7 additions & 11 deletions aqt/jax/v2/flax/freezer.py
@@ -17,6 +17,7 @@
 import enum
 from typing import Any, Callable
 
+from flax.core import meta as nn_meta
 import flax.linen as nn
 
 
@@ -45,30 +46,25 @@ class Freezer(nn.Module):
 
   collection: str
   mode: FreezerMode
-  init_wrapper: Callable[..., Any] | None = None
+  axis_metadata_wrapper: Callable[..., nn_meta.AxisMetadata] | None = None
 
   @nn.compact
   def _get_or_set(self, inputs: Any, is_set: bool) -> Any | None:
     def initializer():
-      if self.init_wrapper is not None:
-        # If we don't copy inputs here, the init_wrapper may change the internal
-        # structure of inputs, and it will be reflected when you apply
+      if self.axis_metadata_wrapper is not None:
+        # If we don't copy inputs here, the axis_metadata_wrapper may change the
+        # internal structure of inputs, and it will be reflected when you apply
         # s.value = input. That could result in an odd behavior.
-        return self.init_wrapper(copy.deepcopy(inputs))
+        return self.axis_metadata_wrapper(copy.deepcopy(inputs))
       return inputs
 
     if is_set:
       match self.mode:
         case FreezerMode.NONE:
           pass
         case FreezerMode.WRITE:
-          is_init = not self.has_variable(self.collection, _FREEZE_VAR_NAME)
           s = self.variable(self.collection, _FREEZE_VAR_NAME, initializer)
-          if not is_init:
-            # In case we are using the initialization wrapper, we should NOT
-            # call this line, since this line could overwrite the initialized
-            # value (which is wrapped using self.init_wrapper).
-            s.value = inputs
+          s.value = inputs
           return None
         case FreezerMode.READ:
           # Set in READ mode works as an initializer for checkpoint reading.
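A minimal WRITE round trip with this generic Freezer, reconstructed from the module above and its test. The collection name 'my_col' and module name are illustrative, and the sketch calls the private _get_or_set entry point shown above rather than assuming any public wrapper.

# Hedged sketch of a Freezer WRITE round trip; names are illustrative.
from aqt.jax.v2.flax import freezer
import flax.linen as nn
import jax
from jax import numpy as jnp


class Store(nn.Module):

  def setup(self):
    self.f = freezer.Freezer('my_col', mode=freezer.FreezerMode.WRITE)

  def __call__(self, x):
    # WRITE mode stores x under ('my_col', 'f', 'frozen').
    self.f._get_or_set(x, is_set=True)
    return x


m = Store()
variables = m.init(jax.random.PRNGKey(0), jnp.ones((2, 3)))
# variables['my_col']['f']['frozen'] now holds the stored tensor,
# boxed by axis_metadata_wrapper when one is supplied.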
45 changes: 33 additions & 12 deletions aqt/jax/v2/flax/freezer_test.py
@@ -20,6 +20,7 @@
 from aqt.jax.v2.flax import freezer
 import flax
 from flax import linen as nn
+from flax.core import meta as nn_meta
 import jax
 from jax import numpy as jnp
 
@@ -33,11 +34,13 @@ class _CustomStructure:
 
 class TestModel(nn.Module):
   freezer_mode: freezer.FreezerMode
-  init_wrapper: Callable[..., Any] | None = None
+  axis_metadata_wrapper: Callable[..., nn_meta.AxisMetadata] | None = None
 
   def setup(self):
     self.f = freezer.Freezer(
-        'freezer', mode=self.freezer_mode, init_wrapper=self.init_wrapper
+        'freezer',
+        mode=self.freezer_mode,
+        axis_metadata_wrapper=self.axis_metadata_wrapper,
     )
 
   def __call__(self, x):
@@ -78,14 +81,28 @@ def _assert_same_tree_shape_dtype(self, tree1, tree2):
       self.assertEqual(leaf1.dtype, leaf2.dtype)
 
   def test_freezer_get_set(self):
-    class CustomWrapper(flax.struct.PyTreeNode):
+    class CustomWrapper(flax.struct.PyTreeNode, nn_meta.AxisMetadata):
       value: Any
-      metadata: str
+      metadata: str = flax.struct.field(default=None, pytree_node=False)
 
-    def init_wrapper(x: _CustomStructure):
+      def unbox(self):
+        return self.value
+
+      def replace_boxed(self, val):
+        return self.replace(value=val)
+
+      def add_axis(self, index, params):
+        return self
+
+      def remove_axis(self, index, params):
+        return self
+
+    def axis_metadata_wrapper(x: _CustomStructure):
       ret = x.replace(
           member=CustomWrapper(x.member, 'member metadata'),
-          member_list=CustomWrapper(x.member_list, 'member list metadata'),
+          member_list=[
+              CustomWrapper(v, 'member list metadata') for v in x.member_list
+          ],
       )
       return ret
 
@@ -117,16 +134,18 @@ def unbox(x):
 
     # 2. WRITE mode test.
     tm_write = TestModel(
-        freezer_mode=freezer.FreezerMode.WRITE, init_wrapper=init_wrapper
+        freezer_mode=freezer.FreezerMode.WRITE,
+        axis_metadata_wrapper=axis_metadata_wrapper,
     )
     param_init_write = tm_write.init(subkeys[4], cs_for_init)
 
     # Check if the init parameters are properly wrapped.
     cs_frozen = param_init_write['freezer']['f']['frozen']
     self.assertIsInstance(cs_frozen.member, CustomWrapper)
-    self.assertIsInstance(cs_frozen.member_list, CustomWrapper)
     self.assertEqual(cs_frozen.member.metadata, 'member metadata')
-    self.assertEqual(cs_frozen.member_list.metadata, 'member list metadata')
+    for value in cs_frozen.member_list:
+      self.assertIsInstance(value, CustomWrapper)
+      self.assertEqual(value.metadata, 'member list metadata')
 
     # Unbox the initialization parameters.
     # The metadata in the box is used to shard the variables here.
@@ -156,16 +175,18 @@ def unbox(x):
 
     # 3. READ mode test.
     tm_read = TestModel(
-        freezer_mode=freezer.FreezerMode.READ, init_wrapper=init_wrapper
+        freezer_mode=freezer.FreezerMode.READ,
+        axis_metadata_wrapper=axis_metadata_wrapper,
    )
     param_init_read = tm_read.init(subkeys[5], cs_for_init)
 
     # Check if the init parameters are properly wrapped.
     cs_frozen = param_init_read['freezer']['f']['frozen']
     self.assertIsInstance(cs_frozen.member, CustomWrapper)
-    self.assertIsInstance(cs_frozen.member_list, CustomWrapper)
     self.assertEqual(cs_frozen.member.metadata, 'member metadata')
-    self.assertEqual(cs_frozen.member_list.metadata, 'member list metadata')
+    for value in cs_frozen.member_list:
+      self.assertIsInstance(value, CustomWrapper)
+      self.assertEqual(value.metadata, 'member list metadata')
 
     # Unbox the initialization parameters.
     param_init_read = jax.tree.map(
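The truncated unboxing steps follow the pattern sketched below, reconstructed from the visible jax.tree.map call and the unbox helper named in the hunk headers; treat it as a sketch of the test's intent rather than its exact code.

# Sketch of the unboxing step the test performs after init: strip the
# AxisMetadata boxes so the tree holds plain arrays again.
from flax.core import meta as nn_meta
import jax


def unbox(x):
  return x.unbox() if isinstance(x, nn_meta.AxisMetadata) else x


param_init_unboxed = jax.tree.map(
    unbox,
    param_init_write,  # the boxed variables returned by tm_write.init(...)
    is_leaf=lambda x: isinstance(x, nn_meta.AxisMetadata),
)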
(Diffs for the remaining 2 changed files are not shown.)
