Skip to content

Commit

Permalink
Update Gemax to add an option to use the QTensor-saving Freezer.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 619056575
  • Loading branch information
cdh4696 authored and Copybara-Service committed Mar 27, 2024
1 parent 42082a2 commit 7790c56
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 13 deletions.
56 changes: 55 additions & 1 deletion aqt/jax/v2/flax/aqt_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import copy
import functools
from typing import Iterable
from typing import Optional, Union
from typing import Optional, Union, Callable, Any
from aqt.jax.v2 import aqt_dot_general
from aqt.jax.v2 import aqt_tensor
from aqt.jax.v2 import config
Expand Down Expand Up @@ -130,14 +130,18 @@ class AqtDotGeneral(nn.Module):
# apply_quant_mode determines if using Freezer in cfg.get/set_tensor
lhs_apply_quant_mode: bool = True
lhs_init: nn.initializers.Initializer = jnp.zeros
lhs_wrapper: Optional[Callable[..., Any]] = None
lhs_scale_init: nn.initializers.Initializer = jnp.zeros
lhs_scale_wrapper: Optional[Callable[..., Any]] = None
lhs_var_name: str = 'qlhs'
lhs_qtensor: Optional[aqt_tensor.QTensor] = None

rhs_quant_mode: QuantMode = QuantMode.TRAIN
rhs_apply_quant_mode: bool = True
rhs_init: nn.initializers.Initializer = jnp.zeros
rhs_wrapper: Optional[Callable[..., Any]] = None
rhs_scale_init: nn.initializers.Initializer = jnp.zeros
rhs_scale_wrapper: Optional[Callable[..., Any]] = None
rhs_var_name: str = 'qrhs'
rhs_qtensor: Optional[aqt_tensor.QTensor] = None

Expand Down Expand Up @@ -213,16 +217,49 @@ def make_aqt_dg(
QuantMode.CONVERT: general_freezer.FreezerMode.WRITE,
QuantMode.SERVE: general_freezer.FreezerMode.READ,
}

def init_wrapper(
qt: aqt_tensor.QTensor,
wrapper: Optional[Callable[..., Any]],
scale_wrapper: Optional[Callable[..., Any]],
):
if wrapper is not None:
if qt.qvalue is not None:
qt.qvalue = wrapper(lambda: qt.qvalue)()
if scale_wrapper is not None:
if qt.scale is not None:
qt.scale = jax.tree_map(
lambda x: scale_wrapper(lambda: x)(), qt.scale
)
if qt.scale_t is not None:
qt.scale_t = jax.tree_map(
lambda x: scale_wrapper(lambda: x)(), qt.scale_t
)
return qt

lhs_init_wrapper = functools.partial(
init_wrapper,
wrapper=self.lhs_wrapper,
scale_wrapper=self.lhs_scale_wrapper,
)
rhs_init_wrapper = functools.partial(
init_wrapper,
wrapper=self.rhs_wrapper,
scale_wrapper=self.rhs_scale_wrapper,
)

lhs_freezer = general_freezer.Freezer(
name=self.lhs_var_name,
mode=quant_to_freezer_mode[lhs_qm],
collection=self.quant_collection,
init_wrapper=lhs_init_wrapper,
)

rhs_freezer = general_freezer.Freezer(
name=self.rhs_var_name,
mode=quant_to_freezer_mode[rhs_qm],
collection=self.quant_collection,
init_wrapper=rhs_init_wrapper,
)

prng_name = self.prng_name
Expand Down Expand Up @@ -326,12 +363,16 @@ class AqtEinsum(nn.Module):
# TODO(lew): split out separate class for each side.
lhs_quant_mode: QuantMode = QuantMode.TRAIN
lhs_init: nn.initializers.Initializer = jnp.zeros
lhs_wrapper: Optional[Callable[..., Any]] = None
lhs_scale_init: nn.initializers.Initializer = jnp.zeros
lhs_scale_wrapper: Optional[Callable[..., Any]] = None
lhs_var_name: str = 'qlhs'

rhs_quant_mode: QuantMode = QuantMode.TRAIN
rhs_init: nn.initializers.Initializer = jnp.zeros
rhs_wrapper: Optional[Callable[..., Any]] = None
rhs_scale_init: nn.initializers.Initializer = jnp.zeros
rhs_scale_wrapper: Optional[Callable[..., Any]] = None
rhs_var_name: str = 'qrhs'

# If you want use 'params' make sure that there is another mechanism to hide
Expand Down Expand Up @@ -393,13 +434,17 @@ def __call__(

lhs_quant_mode = self.lhs_quant_mode
lhs_init = self.lhs_init
lhs_wrapper = self.lhs_wrapper
lhs_scale_init = self.lhs_scale_init
lhs_scale_wrapper = self.lhs_scale_wrapper
lhs_var_name = self.lhs_var_name
lhs_qtensor = lhs_g if lhs_is_qt else None

rhs_quant_mode = self.rhs_quant_mode
rhs_init = self.rhs_init
rhs_wrapper = self.rhs_wrapper
rhs_scale_init = self.rhs_scale_init
rhs_scale_wrapper = self.rhs_scale_wrapper
rhs_var_name = self.rhs_var_name
rhs_qtensor = rhs_g if rhs_is_qt else None

Expand All @@ -420,6 +465,11 @@ def __call__(
lhs_var_name, rhs_var_name = rhs_var_name, lhs_var_name
lhs_is_qt, rhs_is_qt = rhs_is_qt, lhs_is_qt
lhs_qtensor, rhs_qtensor = rhs_qtensor, lhs_qtensor
lhs_wrapper, rhs_wrapper = rhs_wrapper, lhs_wrapper
lhs_scale_wrapper, rhs_scale_wrapper = (
rhs_scale_wrapper,
lhs_scale_wrapper,
)
if tiling_config is not None:
tiling_config = tiled_dot_general.Cfg(
lhs=tiling_config.rhs, rhs=tiling_config.lhs
Expand All @@ -434,13 +484,17 @@ def __call__(
# the qtensor passed to dg.
lhs_apply_quant_mode=not lhs_is_qt, # Freezer not used if lhs is qt
lhs_init=lhs_init,
lhs_wrapper=lhs_wrapper,
lhs_scale_init=lhs_scale_init,
lhs_scale_wrapper=lhs_scale_wrapper,
lhs_var_name=lhs_var_name,
lhs_qtensor=lhs_qtensor,
rhs_quant_mode=rhs_quant_mode,
rhs_apply_quant_mode=not rhs_is_qt, # Freezer not used if rhs is qt
rhs_init=rhs_init,
rhs_wrapper=rhs_wrapper,
rhs_scale_init=rhs_scale_init,
rhs_scale_wrapper=rhs_scale_wrapper,
rhs_var_name=rhs_var_name,
rhs_qtensor=rhs_qtensor,
quant_collection=quant_collection,
Expand Down
31 changes: 21 additions & 10 deletions aqt/jax/v2/flax/freezer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.
"""Freezer for writing & storing general Flax structure."""

import copy
import enum
from typing import Any
from typing import Any, Callable

import flax.linen as nn

Expand Down Expand Up @@ -44,25 +45,35 @@ class Freezer(nn.Module):

collection: str
mode: FreezerMode
init_wrapper: Callable[..., Any] | None = None

@nn.compact
def _get_or_set(self, inputs: Any, is_set: bool) -> Any | None:
def initializer():
if self.init_wrapper is not None:
# If we don't copy inputs here, the init_wrapper may change the internal
# structure of inputs, and it will be reflected when you apply
# s.value = input. That could result in an odd behavior.
return self.init_wrapper(copy.deepcopy(inputs))
return inputs

if is_set:
match self.mode:
case FreezerMode.NONE:
pass
case FreezerMode.WRITE:
s = self.variable(
self.collection, _FREEZE_VAR_NAME, lambda: inputs
)
s.value = inputs
is_init = not self.has_variable(self.collection, _FREEZE_VAR_NAME)
s = self.variable(self.collection, _FREEZE_VAR_NAME, initializer)
if not is_init:
# In case we are using the initialization wrapper, we should NOT
# call this line, since this line could overwrite the initialized
# value (which is wrapped using self.init_wrapper).
s.value = inputs
return None
case FreezerMode.READ:
# Set in READ mode works as an initializer for checkpoint reading.
# we don't want to change the variable during the serving
_ = self.variable(
self.collection, _FREEZE_VAR_NAME, lambda: inputs
)
_ = self.variable(self.collection, _FREEZE_VAR_NAME, initializer)
return None
case _:
# Nothing matched.
Expand All @@ -80,11 +91,11 @@ def _get_or_set(self, inputs: Any, is_set: bool) -> Any | None:

msg = 'Initialization should not happen in Get mode, but in Set mode.'

def initializer():
def initializer_bugish():
assert False, msg

return self.variable(
self.collection, _FREEZE_VAR_NAME, initializer
self.collection, _FREEZE_VAR_NAME, initializer_bugish
).value
case _:
# Nothing matched.
Expand Down
32 changes: 30 additions & 2 deletions aqt/jax/v2/flax/freezer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Tests for freezer."""

from typing import Mapping, Sequence
from typing import Any, Mapping, Sequence, Callable

from absl.testing import absltest
from absl.testing import parameterized
Expand All @@ -33,9 +33,12 @@ class _CustomStructure:

class TestModel(nn.Module):
freezer_mode: freezer.FreezerMode
init_wrapper: Callable[..., Any] | None = None

def setup(self):
self.f = freezer.Freezer('freezer', mode=self.freezer_mode)
self.f = freezer.Freezer(
'freezer', mode=self.freezer_mode, init_wrapper=self.init_wrapper
)

def __call__(self, x):
"""Emulates basic routine on how to use the freezer."""
Expand Down Expand Up @@ -142,6 +145,31 @@ def test_freezer_get_set(self):

self.assertEqual(cs, get_after_set_read)

def test_custom_init_wrapper(self):
class CustomWrapper(flax.struct.PyTreeNode):
value: Any
metadata: str

def init_wrapper(x: _CustomStructure):
ret = x.replace(
member=CustomWrapper(x.member, 'member metadata'),
member_list=CustomWrapper(x.member_list, 'member list metadata'),
)
return ret

subkeys = jax.random.split(jax.random.PRNGKey(0), 2)
cs = self._create_custom_structure(subkeys[0])

tm_write = TestModel(
freezer_mode=freezer.FreezerMode.WRITE, init_wrapper=init_wrapper
)
param_init_write = tm_write.init(subkeys[1], cs)
cs_frozen = param_init_write['freezer']['f']['frozen']
self.assertIsInstance(cs_frozen.member, CustomWrapper)
self.assertIsInstance(cs_frozen.member_list, CustomWrapper)
self.assertEqual(cs_frozen.member.metadata, 'member metadata')
self.assertEqual(cs_frozen.member_list.metadata, 'member list metadata')


if __name__ == '__main__':
absltest.main()

0 comments on commit 7790c56

Please sign in to comment.