diff --git a/CHANGELOG.md b/CHANGELOG.md index b58c2be91..6d6ed4b80 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ vNext - Added Optax update guide and deprecated `flax.optim`. - Added `sep` argument to `flax.traverse_util.flatten_dict()`. - +- Remove float32 dtype assumption throughout linen (except for reccurent + modules). - - Added locally-connected (unshared CNN) layer `flax.linen.ConvLocal`. - @@ -48,7 +50,7 @@ Breaking changes: - You can no longer pass an int as the `kernel_size` for a `flax.linen.Conv. Instead a type error is raised stating that a tuple/list should be provided. Stride and dilation arguments do support broadcasting a single int value now because this is not - ambigious when the kernel rank is known. + ambigious when the kernel rank is known. - `flax.linen.enable_named_call` and `flax.linen.disable_named_call` now work anywhere instead of only affecting Modules constructed after the enable/disable call. Additionally, there is now `flax.linen.override_named_call` that provided a context manager to locally disable/enable named_call. - NamedTuples are no longer converted to tuples on assignment to a `linen.Module`. @@ -64,7 +66,7 @@ Bugfixes: - Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated ([bug](https://github.com/google/flax/issues/1429)). - Mixed precision training with float16 now works correctly with the attention layers. - auto-generated linen Module `__hash__`, `__eq__`, `__repr__` no longer fail by default on non-init attributes. - + 0.3.4 @@ -72,7 +74,7 @@ Bugfixes: Possibly breaking changes: - When calling `init` the 'intermediates' collection is no longer mutable. - Therefore, intermediates will no longer be returned from initialization by default. + Therefore, intermediates will no longer be returned from initialization by default. - Don't update batch statistics during initialization. - When not using any non-determinism (e.g., dropout), it is not longer necessary to specify the `deterministic` argument in `MultiHeadDotProductAttention`. @@ -105,9 +107,9 @@ Possible breaking changes: latest checkpoint already saved. - MultiOptimizer now rejects the case where multiple sub optimizers update the same parameter. - + Other changes: - - Added custom error classes to many Linen errors. See: + - Added custom error classes to many Linen errors. See: https://flax.readthedocs.io/en/latest/flax.errors.html - Adds `Module.bind` for binding variables and RNGs to an interactive Module. - Adds `nn.apply` and `nn.init` for transforming arbitrary functions that take a `linen.Module` as their first argument. @@ -127,7 +129,7 @@ NOTE: You must now explicitly import `flax.nn` if you want to use the old 0.3.1 ------ -Many improvements to Linen, and the old `flax.nn` is officially reprecated! +Many improvements to Linen, and the old `flax.nn` is officially reprecated! Notably, there's a clean API for extracting intermediates from modules defined using `@nn.compact`, a more ergonomic API for using Batch Norm and Dropout in modules @@ -141,7 +143,7 @@ Possible breaking changes: is enforced by raising a TypeError in `__setattr__` after `setup`. - Pytrees of dicts and lists are transformed into FrozenDict and tuples during attribute assignment. - This avoids undetected submodules and inner state. + This avoids undetected submodules and inner state. - Bug Fix `flax.core.apply` and `Module.apply`. Now it returns a tuple containing the output and a frozen empty collection when `mutable` is specified as an empty list. diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 2d57f9621..c95d7c43b 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -17,22 +17,25 @@ # pylint: disable=g-multiple-import # re-export commonly used modules and functions -from .activation import (celu, elu, gelu, glu, leaky_relu, log_sigmoid, - log_softmax, relu, sigmoid, soft_sign, softmax, - softplus, swish, silu, tanh, PReLU) +from .activation import (PReLU, celu, elu, gelu, glu, leaky_relu, log_sigmoid, + log_softmax, relu, sigmoid, silu, soft_sign, softmax, + softplus, swish, tanh) from .attention import (MultiHeadDotProductAttention, SelfAttention, - dot_product_attention, dot_product_attention_weights, - make_attention_mask, make_causal_mask, combine_masks) -from ..core import broadcast, DenyList, FrozenDict + combine_masks, dot_product_attention, + dot_product_attention_weights, make_attention_mask, + make_causal_mask) +from ..core import DenyList, FrozenDict, broadcast +from .dtypes import canonicalize_inexact_dtypes, canonicalize_numeric_dtypes +from .initializers import ones, zeros from .linear import Conv, ConvLocal, ConvTranspose, Dense, DenseGeneral, Embed -from .module import (Module, compact, nowrap, enable_named_call, - disable_named_call, override_named_call, Variable, init, - init_with_output, apply, merge_param) +from .module import (Module, Variable, apply, compact, disable_named_call, + enable_named_call, init, init_with_output, merge_param, + nowrap, override_named_call) from .normalization import BatchNorm, GroupNorm, LayerNorm from .pooling import avg_pool, max_pool, pool -from .recurrent import GRUCell, LSTMCell, ConvLSTM, OptimizedLSTMCell +from .recurrent import ConvLSTM, GRUCell, LSTMCell, OptimizedLSTMCell from .stochastic import Dropout -from .transforms import jit, named_call, checkpoint, remat, remat_scan, scan, vmap, map_variables, vjp, jvp, custom_vjp -from .initializers import zeros, ones +from .transforms import (checkpoint, custom_vjp, jit, jvp, map_variables, + named_call, remat, remat_scan, scan, vjp, vmap) # pylint: enable=g-multiple-import diff --git a/flax/linen/activation.py b/flax/linen/activation.py index 268123ecc..f640b08c8 100644 --- a/flax/linen/activation.py +++ b/flax/linen/activation.py @@ -15,48 +15,33 @@ """Activation functions. """ +import jax.numpy as jnp # pylint: disable=unused-import -# re-export activation functions from jax.nn -from jax.nn import celu -from jax.nn import elu -from jax.nn import gelu -from jax.nn import glu -from jax.nn import leaky_relu -from jax.nn import log_sigmoid -from jax.nn import log_softmax -from jax.nn import normalize -from jax.nn import relu -from jax.nn import sigmoid -from jax.nn import soft_sign -from jax.nn import softmax -from jax.nn import softplus -from jax.nn import swish -from jax.nn import silu -from jax.nn import selu -from jax.nn import hard_tanh -from jax.nn import relu6 -from jax.nn import hard_sigmoid -from jax.nn import hard_swish - +# re-export activation functions from jax.nn and jax.numpy +from jax.nn import (celu, elu, gelu, glu, hard_sigmoid, hard_swish, hard_tanh, + leaky_relu, log_sigmoid, log_softmax, normalize, relu, + relu6, selu, sigmoid, silu, soft_sign, softmax, softplus, + swish) from jax.numpy import tanh # pylint: enable=unused-import -from typing import Any - -from flax.linen.module import Module, compact -import jax.numpy as jnp - - -Array = Any +from .dtypes import Array, FloatingDType, canonicalize_inexact_dtypes +from .module import Module, compact class PReLU(Module): """Parametric Rectified Linear Unit (PReLU) activation function. Attributes: - negative_slope_init: the value to initialize the negative slope. + dtype: the dtype of the computation (default: float32). + param_dtype: the dtype passed to parameter initializers (default: float32). + negative_slope_init: the value to initialize the negative slope + (default 0.01). """ + dtype: Optional[FloatingDType] = jnp.float32 + param_dtype: Optional[FloatingDType] = jnp.float32 negative_slope_init: float = 0.01 + @compact def __call__(self, inputs: Array) -> Array: """Applies an activation to the inputs. @@ -67,8 +52,13 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ + assert jnp.issubdtype(inputs.dtype, jnp.floating) + inputs = jnp.asarray(inputs, dtype) + param_dtype, dtype = canonicalize_inexact_dtypes(inputs.dtype, param_dtype, + self.dtype) negative_slope = self.param( 'negative_slope', - lambda k: jnp.asarray(self.negative_slope_init, jnp.float32) + lambda k: jnp.asarray(self.negative_slope_init, param_dtype) ) - return jnp.where(inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs) + negative_slope = jnp.asarray(negative_slope, dtype) + return jnp.where(inputs >= 0, inputs, negative_slope * inputs) diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 749ded6f3..5b90ad1e3 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -15,23 +15,34 @@ """Attention core modules for Flax.""" from functools import partial -from typing import (Any, Callable, Tuple, Optional) +from typing import Callable, Optional import jax -from jax import lax -from jax import random import jax.numpy as jnp -import numpy as np - -from flax.linen.linear import default_kernel_init -from flax.linen.linear import DenseGeneral -from flax.linen.module import Module, compact, merge_param -from flax.linen.initializers import zeros - -PRNGKey = Any -Shape = Tuple[int] -Dtype = Any -Array = Any +from jax import lax, random +from typing_extensions import Protocol + +from .dtypes import (Array, InexactDType, Initializer, PRNGKey, + canonicalize_inexact_dtypes) +from .initializers import zeros +from .linear import DenseGeneral, default_kernel_init +from .module import Module, compact, merge_param + + +class AttentionFunction(Protocol): + @staticmethod + def __call__(query: Array, + key: Array, + value: Array, + bias: Optional[Array] = None, + mask: Optional[Array] = None, + broadcast_dropout: bool = True, + dropout_rng: Optional[PRNGKey] = None, + dropout_rate: float = 0., + deterministic: bool = False, + dtype: InexactDType = jnp.float32, + precision: Optional[lax.Precision] = None) -> Array: + ... def dot_product_attention_weights(query: Array, @@ -42,7 +53,7 @@ def dot_product_attention_weights(query: Array, dropout_rng: Optional[PRNGKey] = None, dropout_rate: float = 0., deterministic: bool = False, - dtype: Dtype = jnp.float32, + dtype: InexactDType = jnp.float32, precision: Optional[lax.Precision] = None): """Computes dot-product attention weights given query and key. @@ -109,7 +120,7 @@ def dot_product_attention_weights(query: Array, keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) else: keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) - multiplier = (keep.astype(attn_weights.dtype) / + multiplier = (keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)) attn_weights = attn_weights * multiplier @@ -125,8 +136,8 @@ def dot_product_attention(query: Array, dropout_rng: Optional[PRNGKey] = None, dropout_rate: float = 0., deterministic: bool = False, - dtype: Dtype = jnp.float32, - precision: Optional[lax.Precision] = None): + dtype: InexactDType = jnp.float32, + precision: Optional[lax.Precision] = None) -> Array: """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on @@ -205,18 +216,18 @@ class MultiHeadDotProductAttention(Module): decode: whether to prepare and use an autoregressive cache. """ num_heads: int - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None qkv_features: Optional[int] = None out_features: Optional[int] = None broadcast_dropout: bool = True dropout_rate: float = 0. deterministic: Optional[bool] = None precision: Optional[lax.Precision] = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros use_bias: bool = True - attention_fn: Callable[[Array, Array, Array], Array] = dot_product_attention + attention_fn: AttentionFunction = dot_product_attention decode: bool = False @compact @@ -246,6 +257,10 @@ def __call__(self, Returns: output of shape `[batch_sizes..., length, features]`. """ + param_dtype, dtype = canonicalize_inexact_dtypes(jnp.result_type(inputs_q, + inputs_kv), + self.param_dtype, + self.dtype) features = self.out_features or inputs_q.shape[-1] qkv_features = self.qkv_features or inputs_q.shape[-1] assert qkv_features % self.num_heads == 0, ( @@ -254,8 +269,8 @@ def __call__(self, dense = partial(DenseGeneral, axis=-1, - dtype=self.dtype, - param_dtype=self.param_dtype, + dtype=dtype, + param_dtype=param_dtype, features=(self.num_heads, head_dim), kernel_init=self.kernel_init, bias_init=self.bias_init, @@ -323,7 +338,7 @@ def __call__(self, dropout_rate=self.dropout_rate, broadcast_dropout=self.broadcast_dropout, deterministic=m_deterministic, - dtype=self.dtype, + dtype=dtype, precision=self.precision) # pytype: disable=wrong-keyword-args # back to the original inputs dimensions out = DenseGeneral(features=features, @@ -331,8 +346,8 @@ def __call__(self, kernel_init=self.kernel_init, bias_init=self.bias_init, use_bias=self.use_bias, - dtype=self.dtype, - param_dtype=self.param_dtype, + dtype=dtype, + param_dtype=param_dtype, precision=self.precision, name='out')(x) return out @@ -350,11 +365,12 @@ def __call__(self, inputs_q: Array, mask: Optional[Array] = None, # mask-making utility functions -def make_attention_mask(query_input: Array, - key_input: Array, - pairwise_fn: Callable[..., Any] = jnp.multiply, - extra_batch_dims: int = 0, - dtype: Dtype = jnp.float32): +def make_attention_mask( + query_input: Array, + key_input: Array, + pairwise_fn: Callable[[Array, Array], Array] = jnp.multiply, + extra_batch_dims: int = 0, + dtype: InexactDType = jnp.float32): """Mask-making helper for attention weights. In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the @@ -381,7 +397,7 @@ def make_attention_mask(query_input: Array, def make_causal_mask(x: Array, extra_batch_dims: int = 0, - dtype: Dtype = jnp.float32) -> Array: + dtype: InexactDType = jnp.float32) -> Array: """Make a causal mask for self-attention. In case of 1d inputs (i.e., `[batch..., len]`, the self-attention weights @@ -403,7 +419,7 @@ def make_causal_mask(x: Array, def combine_masks(*masks: Optional[Array], - dtype: Dtype = jnp.float32) -> Array: + dtype: InexactDType = jnp.float32) -> Array: """Combine attention masks. Args: diff --git a/flax/linen/dtypes.py b/flax/linen/dtypes.py new file mode 100644 index 000000000..4e4e41ca2 --- /dev/null +++ b/flax/linen/dtypes.py @@ -0,0 +1,57 @@ +# Copyright 2022 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for working with dtypes.""" + +from typing import Any, Callable, Optional, Tuple, Type + +import jax.numpy as jnp +import numpy as np + + +Array = Any # pylint: disable=invalid-name +PRNGKey = Any # pylint: disable=invalid-name +Shape = Tuple[int, ...] +FloatingDType = Type[jnp.floating] +GenericDType = Type[np.generic] +InexactDType = Type[jnp.inexact] +NumericDType = Type[jnp.number] +Initializer = Callable[[PRNGKey, Shape, InexactDType], Array] + + +def canonicalize_inexact_dtypes( + input_dtype: InexactDType, + param_dtype: Optional[InexactDType], + computation_dtype: Optional[InexactDType]) -> Tuple[InexactDType, + InexactDType]: + returned_param_dtype = input_dtype if param_dtype is None else param_dtype + dtype = (jnp.result_type(input_dtype, returned_param_dtype) + if computation_dtype is None else computation_dtype) + + assert jnp.issubdtype(input_dtype, jnp.inexact) + return returned_param_dtype, dtype + + +def canonicalize_numeric_dtypes( + input_dtype: NumericDType, + param_dtype: Optional[NumericDType], + computation_dtype: Optional[NumericDType]) -> Tuple[NumericDType, + NumericDType]: + returned_param_dtype = input_dtype if param_dtype is None else param_dtype + dtype = (jnp.result_type(input_dtype, returned_param_dtype) + if computation_dtype is None else computation_dtype) + + assert jnp.issubdtype(input_dtype, jnp.number) + return returned_param_dtype, dtype + diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 74a590ac2..ded3b5196 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -16,25 +16,15 @@ import abc from dataclasses import field +from typing import Iterable, List, Optional, Sequence, Tuple, Union -from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple, - Union) - -from flax.linen.module import Module, compact -from flax.linen.initializers import lecun_normal, variance_scaling, zeros - -from jax import lax -from jax import eval_shape -from jax import ShapedArray import jax.numpy as jnp import numpy as np +from jax import ShapedArray, eval_shape, lax - -PRNGKey = Any -Shape = Tuple[int, ...] -Dtype = Any # this could be a real type? -Array = Any - +from .dtypes import Array, GenericDType, InexactDType, Initializer, NumericDType +from .initializers import lecun_normal, variance_scaling, zeros +from .module import Module, compact default_kernel_init = lecun_normal() @@ -71,10 +61,10 @@ class DenseGeneral(Module): axis: Union[int, Sequence[int]] = -1 batch_dims: Sequence[int] = () use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + dtype: Optional[InexactDType] = jnp.float32 + param_dtype: Optional[InexactDType] = jnp.float32 + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros precision: Optional[lax.Precision] = None @compact @@ -87,6 +77,10 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ + param_dtype, dtype = canonicalize_inexact_dtypes(inputs.dtype, + self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) batch_dims = _canonicalize_tuple(self.batch_dims) @@ -96,15 +90,13 @@ def __call__(self, inputs: Array) -> Array: raise ValueError('batch_dims %s must be consecutive leading ' 'dimensions starting from 0.' % str(batch_dims)) - inputs = jnp.asarray(inputs, self.dtype) - ndim = inputs.ndim n_batch_dims = len(batch_dims) axis = _normalize_axes(axis, ndim) batch_dims = _normalize_axes(batch_dims, ndim) n_axis, n_features = len(axis), len(features) - def kernel_init_wrap(rng, shape, dtype=jnp.float32): + def kernel_init_wrap(rng, shape, dtype): size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) flat_shape = (np.prod(shape[n_batch_dims:n_axis + n_batch_dims]), np.prod(shape[-n_features:]),) @@ -119,8 +111,8 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32): for ax in range(inputs.ndim) if ax not in axis) kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features kernel = self.param('kernel', kernel_init_wrap, batch_shape + kernel_shape, - self.param_dtype) - kernel = jnp.asarray(kernel, self.dtype) + param_dtype) + kernel = jnp.asarray(kernel, dtype) batch_ind = tuple(range(n_batch_dims)) contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) @@ -130,7 +122,7 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32): precision=self.precision) # dot_general output has shape [batch_dims/group_dims] + [feature_dims] if self.use_bias: - def bias_init_wrap(rng, shape, dtype=jnp.float32): + def bias_init_wrap(rng, shape, dtype): size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) flat_shape = (np.prod(shape[-n_features:]),) bias = jnp.concatenate([self.bias_init(rng, flat_shape, dtype) @@ -138,10 +130,10 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32): return jnp.reshape(bias, shape) bias = self.param('bias', bias_init_wrap, batch_shape + features, - self.param_dtype) + param_dtype) + bias = jnp.asarray(bias, dtype) # expand bias shape to broadcast bias over batch dims. bias = jnp.reshape(bias, expanded_batch_shape + features) - bias = jnp.asarray(bias, self.dtype) out = out + bias return out @@ -161,11 +153,11 @@ class Dense(Module): """ features: int use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = jnp.float32 + param_dtype: Optional[InexactDType] = jnp.float32 precision: Optional[lax.Precision] = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros @compact def __call__(self, inputs: Array) -> Array: @@ -177,19 +169,22 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ - inputs = jnp.asarray(inputs, self.dtype) + param_dtype, dtype = canonicalize_inexact_dtypes(inputs.dtype, + self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) kernel = self.param('kernel', self.kernel_init, (inputs.shape[-1], self.features), - self.param_dtype) - kernel = jnp.asarray(kernel, self.dtype) + param_dtype) + kernel = jnp.asarray(kernel, dtype) y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision) if self.use_bias: bias = self.param('bias', self.bias_init, (self.features,), - self.param_dtype) - bias = jnp.asarray(bias, self.dtype) + param_dtype) + bias = jnp.asarray(bias, dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y @@ -218,7 +213,8 @@ class _Conv(Module): high)` integer pairs that give the padding to apply before and after each spatial dimension. input_dilation: an integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` (default: 1). + dilation factor to apply in each spatial dimension of `inputs` (default: + 1). Convolution with input dilation `d` is equivalent to transposed convolution with stride `d`. kernel_dilation: an integer or a sequence of `n` integers, giving the @@ -243,11 +239,11 @@ class _Conv(Module): kernel_dilation: Union[None, int, Sequence[int]] = 1 feature_group_count: int = 1 use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[NumericDType] = jnp.float32 + param_dtype: Optional[NumericDType] = jnp.float32 precision: Optional[lax.Precision] = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros @property @abc.abstractmethod @@ -276,8 +272,10 @@ def __call__(self, inputs: Array) -> Array: Returns: The convolved data. """ - - inputs = jnp.asarray(inputs, self.dtype) + param_dtype, dtype = canonicalize_numeric_dtypes(inputs.dtype, + self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) if isinstance(self.kernel_size, int): raise TypeError('The kernel size must be specified as a' @@ -350,8 +348,8 @@ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> ( kernel_shape = conv_output_shape[1:-1] + ( np.prod(kernel_size) * in_features, self.features) - kernel = self.param('kernel', self.kernel_init, kernel_shape, self.param_dtype) - kernel = jnp.asarray(kernel, self.dtype) + kernel = self.param('kernel', self.kernel_init, kernel_shape, param_dtype) + kernel = jnp.asarray(kernel, dtype) if self.shared_weights: y = lax.conv_general_dilated( @@ -386,8 +384,8 @@ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> ( # One bias weight per output entry, unshared betwen pixels. bias_shape = y.shape[1:] - bias = self.param('bias', self.bias_init, bias_shape, self.param_dtype) - bias = jnp.asarray(bias, self.dtype) + bias = self.param('bias', self.bias_init, bias_shape, param_dtype) + bias = jnp.asarray(bias, dtype) bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape) y += bias @@ -443,11 +441,11 @@ class ConvTranspose(Module): padding: Union[str, Sequence[Tuple[int, int]]] = 'SAME' kernel_dilation: Optional[Sequence[int]] = None use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[NumericDType] = jnp.float32 + param_dtype: Optional[NumericDType] = jnp.float32 precision: Optional[lax.Precision] = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros @compact def __call__(self, inputs: Array) -> Array: @@ -464,7 +462,10 @@ def __call__(self, inputs: Array) -> Array: Returns: The convolved data. """ - inputs = jnp.asarray(inputs, self.dtype) + param_dtype, dtype = canonicalize_numeric_dtypes(inputs.dtype, + self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) kernel_size: Tuple[int, ...] if isinstance(self.kernel_size, int): @@ -482,8 +483,8 @@ def __call__(self, inputs: Array) -> Array: in_features = inputs.shape[-1] kernel_shape = kernel_size + (in_features, self.features) - kernel = self.param('kernel', self.kernel_init, kernel_shape, self.param_dtype) - kernel = jnp.asarray(kernel, self.dtype) + kernel = self.param('kernel', self.kernel_init, kernel_shape, param_dtype) + kernel = jnp.asarray(kernel, dtype) padding_lax: Union[str, Sequence[Tuple[int, int]]] if self.padding == 'CIRCULAR': @@ -532,8 +533,8 @@ def __call__(self, inputs: Array) -> Array: if is_single_input: y = jnp.squeeze(y, axis=0) if self.use_bias: - bias = self.param('bias', self.bias_init, (self.features,), self.param_dtype) - bias = jnp.asarray(bias, self.dtype) + bias = self.param('bias', self.bias_init, (self.features,), param_dtype) + bias = jnp.asarray(bias, dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y @@ -555,9 +556,9 @@ class Embed(Module): """ num_embeddings: int features: int - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 - embedding_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_embed_init + dtype: Optional[GenericDType] = jnp.float32 + param_dtype: GenericDType = jnp.float32 + embedding_init: Initializer = default_embed_init embedding: Array = field(init=False) @@ -577,10 +578,11 @@ def __call__(self, inputs: Array) -> Array: Output which is embedded input data. The output shape follows the input, with an additional `features` dimension appended. """ + dtype = self.param_dtype if self.dtype is None else self.dtype if not jnp.issubdtype(inputs.dtype, jnp.integer): raise ValueError('Input type must be an integer or unsigned integer.') # Use take because fancy indexing numpy arrays with JAX indices does not work correctly. - embedding = jnp.asarray(self.embedding, self.dtype) + embedding = jnp.asarray(self.embedding, dtype) return jnp.take(embedding, inputs, axis=0) def attend(self, query: Array) -> Array: @@ -595,6 +597,7 @@ def attend(self, query: Array) -> Array: Commonly used for weight-sharing between embeddings and logit transform in NLP models. """ - query = jnp.asarray(query, self.dtype) - embedding = jnp.asarray(self.embedding, self.dtype) + dtype = self.param_dtype if self.dtype is None else self.dtype + query = jnp.asarray(query, dtype) + embedding = jnp.asarray(self.embedding, dtype) return jnp.dot(query, embedding.T) diff --git a/flax/linen/module.py b/flax/linen/module.py index 1aae88c5a..86e250ef9 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -13,7 +13,6 @@ # limitations under the License. """Flax Modules.""" -from contextlib import contextmanager import dataclasses import enum import functools @@ -23,33 +22,30 @@ import types import typing import weakref - -from typing import (Any, Callable, Sequence, Iterable, List, Optional, Tuple, - Set, Type, Union, TypeVar, Generic, Dict, overload) +from contextlib import contextmanager +from typing import (Any, Callable, Dict, Generic, Iterable, List, Optional, + Sequence, Set, Tuple, Type, TypeVar, Union, overload) import jax +import numpy as np from jax import tree_util from jax._src.numpy.lax_numpy import isin -import numpy as np import flax -from flax import config -from flax import errors -from flax import traceback_util -from flax import traverse_util -from flax import serialization -from flax import core +from flax import (config, core, errors, serialization, traceback_util, + traverse_util) from flax.core import Scope -from flax.core.scope import CollectionFilter, DenyList, Variable, VariableDict, FrozenVariableDict, union_filters from flax.core.frozen_dict import FrozenDict, freeze +from flax.core.scope import (CollectionFilter, DenyList, FrozenVariableDict, + Variable, VariableDict, union_filters) from flax.struct import __dataclass_transform__ +from .dtypes import Array, PRNGKey + # from .dotgetter import DotGetter traceback_util.register_exclusion(__file__) -PRNGKey = Any # pylint: disable=invalid-name RNGSequences = Dict[str, PRNGKey] -Array = Any # pylint: disable=invalid-name T = TypeVar('T') @@ -603,7 +599,8 @@ def _wrap_module_methods(cls): wrapped_method = wrap_method_once(method) if key != 'setup': # We import named_call at runtime to avoid a circular import issue. - from flax.linen.transforms import named_call # pylint: disable=g-import-not-at-top + from flax.linen.transforms import \ + named_call # pylint: disable=g-import-not-at-top wrapped_method = named_call(wrapped_method, force=False) setattr(cls, key, wrapped_method) return cls diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index fe4d2e097..f22ef5010 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -14,21 +14,17 @@ """Normalization modules for Flax.""" -from typing import (Any, Callable, Optional, Tuple, Iterable, Union) +from typing import Any, Iterable, Optional, Tuple, Union +import jax.numpy as jnp from jax import lax from jax.nn import initializers -import jax.numpy as jnp - -from flax.linen.module import Module, compact, merge_param +from .dtypes import (Array, InexactDType, Initializer, + canonicalize_inexact_dtypes) +from .module import Module, compact, merge_param -PRNGKey = Any -Array = Any -Shape = Tuple[int, ...] -Dtype = Any # this could be a real type? - -Axes = Union[int, Iterable[int]] +Axes = Union[int, Tuple[int, ...]] def _canonicalize_axes(rank: int, axes: Axes) -> Tuple[int, ...]: @@ -47,7 +43,7 @@ def _abs_sq(x): def _compute_stats(x: Array, axes: Axes, axis_name: Optional[str] = None, - axis_index_groups: Any = None): + axis_index_groups: Any = None) -> Tuple[Array, Array]: """Computes mean and variance statistics. This implementation takes care of a few important details: @@ -79,15 +75,19 @@ def _compute_stats(x: Array, axes: Axes, def _normalize(mdl: Module, x: Array, mean: Array, var: Array, reduction_axes: Axes, feature_axes: Axes, - dtype: Dtype, param_dtype: Dtype, + dtype: InexactDType, param_dtype: InexactDType, epsilon: float, use_bias: bool, use_scale: bool, - bias_init: Callable[[PRNGKey, Shape, Dtype], Array], - scale_init: Callable[[PRNGKey, Shape, Dtype], Array]): - """"Normalizes the input of a normalization layer and optionally applies a learned scale and bias. + bias_init: Initializer, + scale_init: Initializer): + """"Normalizes the input of a normalization layer and optionally applies a + learned scale and bias. A seperate bias and scale is learned for each feature as specified by feature_axes. """ + input_dtype = jnp.result_type(x, mean, var) + param_dtype, dtype = canonicalize_inexact_dtypes(input_dtype, param_dtype, + dtype) reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) feature_axes = _canonicalize_axes(x.ndim, feature_axes) stats_shape = list(x.shape) @@ -100,16 +100,21 @@ def _normalize(mdl: Module, x: Array, mean: Array, var: Array, for ax in feature_axes: feature_shape[ax] = x.shape[ax] reduced_feature_shape.append(x.shape[ax]) + x = jnp.asarray(x, dtype) + mean = jnp.asarray(mean, dtype) + var = jnp.asarray(var, dtype) y = x - mean mul = lax.rsqrt(var + epsilon) if use_scale: scale = mdl.param('scale', scale_init, reduced_feature_shape, param_dtype).reshape(feature_shape) + scale = jnp.asarray(scale, dtype) mul *= scale y *= mul if use_bias: bias = mdl.param('bias', bias_init, reduced_feature_shape, param_dtype).reshape(feature_shape) + bias = jnp.asarray(bias, dtype) y += bias return jnp.asarray(y, dtype) @@ -171,17 +176,19 @@ class BatchNorm(Module): axis: int = -1 momentum: float = 0.99 epsilon: float = 1e-5 - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = jnp.float32 + param_dtype: Optional[InexactDType] = jnp.float32 use_bias: bool = True use_scale: bool = True - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros - scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + bias_init: Initializer = initializers.zeros + scale_init: Initializer = initializers.ones axis_name: Optional[str] = None axis_index_groups: Any = None @compact - def __call__(self, x, use_running_average: Optional[bool] = None): + def __call__(self, + x: Array, + use_running_average: Optional[bool] = None) -> Array: """Normalizes the input using batch statistics. NOTE: @@ -199,6 +206,9 @@ def __call__(self, x, use_running_average: Optional[bool] = None): Returns: Normalized inputs (the same shape as inputs). """ + param_dtype, dtype = canonicalize_inexact_dtypes(x.dtype, self.param_dtype, + self.dtype) + x = jnp.asarray(x, dtype) use_running_average = merge_param( 'use_running_average', self.use_running_average, use_running_average) @@ -210,10 +220,10 @@ def __call__(self, x, use_running_average: Optional[bool] = None): initializing = self.is_mutable_collection('params') ra_mean = self.variable('batch_stats', 'mean', - lambda s: jnp.zeros(s, jnp.float32), + lambda s: jnp.zeros(s, dtype), feature_shape) ra_var = self.variable('batch_stats', 'var', - lambda s: jnp.ones(s, jnp.float32), + lambda s: jnp.ones(s, dtype), feature_shape) if use_running_average: @@ -225,14 +235,14 @@ def __call__(self, x, use_running_average: Optional[bool] = None): axis_index_groups=self.axis_index_groups) if not initializing: - ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean + ra_mean.value = (self.momentum * ra_mean.value + (1 - self.momentum) * + mean) ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var return _normalize( - self, x, mean, var, reduction_axes, feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, - self.bias_init, self.scale_init) + self, x, mean, var, reduction_axes, feature_axes, dtype, param_dtype, + self.epsilon, self.use_bias, self.use_scale, self.bias_init, + self.scale_init) class LayerNorm(Module): @@ -256,12 +266,12 @@ class LayerNorm(Module): scale_init: Initializer for scale, by default, one. """ epsilon: float = 1e-6 - dtype: Any = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = jnp.float32 + param_dtype: Optional[InexactDType] = jnp.float32 use_bias: bool = True use_scale: bool = True - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros - scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + bias_init: Initializer = initializers.zeros + scale_init: Initializer = initializers.ones @compact def __call__(self, x): @@ -273,6 +283,9 @@ def __call__(self, x): Returns: Normalized inputs (the same shape as inputs). """ + param_dtype, dtype = canonicalize_inexact_dtypes(x.dtype, self.param_dtype, + self.dtype) + x = jnp.asarray(x, dtype) reduction_axes = (-1,) feature_axes = (-1,) @@ -280,10 +293,9 @@ def __call__(self, x): mean, var = _compute_stats(x, reduction_axes, None, None) return _normalize( - self, x, mean, var, reduction_axes, feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, - self.bias_init, self.scale_init) + self, x, mean, var, reduction_axes, feature_axes, dtype, param_dtype, + self.epsilon, self.use_bias, self.use_scale, self.bias_init, + self.scale_init) class GroupNorm(Module): @@ -302,23 +314,24 @@ class GroupNorm(Module): group_size: the number of channels in a group. epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + param_dtype: the dtype passed to parameter initializers (default: + float32). use_bias: If True, bias (beta) is added. - use_scale: If True, multiply by scale (gamma). When the next layer is linear - (also e.g. nn.relu), this can be disabled since the scaling will be done - by the next layer. + use_scale: If True, multiply by scale (gamma). When the next layer is + linear (also e.g. nn.relu), this can be disabled since the scaling will + be done by the next layer. bias_init: Initializer for bias, by default, zero. scale_init: Initializer for scale, by default, one. """ num_groups: Optional[int] = 32 group_size: Optional[int] = None epsilon: float = 1e-6 - dtype: Any = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = jnp.float32 + param_dtype: Optional[InexactDType] = jnp.float32 use_bias: bool = True use_scale: bool = True - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros - scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + bias_init: Initializer = initializers.zeros + scale_init: Initializer = initializers.ones @compact def __call__(self, x): @@ -332,7 +345,10 @@ def __call__(self, x): Returns: Normalized inputs (the same shape as inputs). """ - reduction_axes = list(range(1, x.ndim - 1)) + [-1] + param_dtype, dtype = canonicalize_inexact_dtypes(x.dtype, self.param_dtype, + self.dtype) + x = jnp.asarray(x, dtype) + reduction_axes = tuple(range(1, x.ndim - 1)) + (-1,) feature_axes = (-1,) if ((self.num_groups is None and self.group_size is None) or @@ -367,7 +383,6 @@ def broadcast_stat(stat): var = broadcast_stat(var) return _normalize( - self, x, mean, var, reduction_axes[:-1], feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, + self, x, mean, var, reduction_axes[:-1], feature_axes, dtype, + param_dtype, self.epsilon, self.use_bias, self.use_scale, self.bias_init, self.scale_init) diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index ade9f7006..eb91313c3 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -25,24 +25,20 @@ import abc from functools import partial -from typing import (Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple, - Type, Union) +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union -from flax.linen.module import Module, compact -from flax.linen.activation import sigmoid, tanh -from flax.linen.initializers import orthogonal, zeros -from flax.linen.linear import Conv, Dense, default_kernel_init - -from jax import numpy as jnp +import numpy as np from jax import lax +from jax import numpy as jnp from jax import random -import numpy as np +from .activation import sigmoid, tanh +from .dtypes import Array, PRNGKey, Shape +from .initializers import orthogonal, zeros +from .linear import Conv, Dense, default_kernel_init +from .module import Module, compact -PRNGKey = Any -Shape = Tuple[int] -Dtype = Any # this could be a real type? -Array = Any +Dtype = Any class RNNCellBase(Module): diff --git a/setup.py b/setup.py index dffe46fdb..a1903f8a3 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "matplotlib", # only needed for tensorboard export "msgpack", "optax", + "typing_extensions>=3.10", ] tests_require = [ diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index 890a53a49..a74114a44 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -657,8 +657,8 @@ def __call__(self, x): # attributes features = 3 use_bias = True - dtype = float32 - param_dtype = float32 + dtype = None + param_dtype = None precision = None kernel_init = init bias_init = zeros @@ -667,8 +667,8 @@ def __call__(self, x): # attributes features = 2 use_bias = True - dtype = float32 - param_dtype = float32 + dtype = None + param_dtype = None precision = None kernel_init = init bias_init = zeros