From 89a7b4f045d9002ae70554a2224df38359151388 Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Mon, 24 Jan 2022 17:32:34 -0500 Subject: [PATCH 1/8] Remove float32 dtype assumption * Infer dtypes from inputs where possible. * LSTM dtype assumption persists; this is repaired in a separate pull request. --- CHANGELOG.md | 16 ++-- flax/linen/activation.py | 7 +- flax/linen/attention.py | 152 +++++++++++++++++------------ flax/linen/linear.py | 158 ++++++++++++++++++++----------- flax/linen/normalization.py | 105 ++++++++++++-------- flax/linen/recurrent.py | 4 +- setup.py | 1 + tests/linen/linen_module_test.py | 8 +- 8 files changed, 278 insertions(+), 173 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b58c2be91..6d6ed4b80 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ vNext - Added Optax update guide and deprecated `flax.optim`. - Added `sep` argument to `flax.traverse_util.flatten_dict()`. - +- Remove float32 dtype assumption throughout linen (except for reccurent + modules). - - Added locally-connected (unshared CNN) layer `flax.linen.ConvLocal`. - @@ -48,7 +50,7 @@ Breaking changes: - You can no longer pass an int as the `kernel_size` for a `flax.linen.Conv. Instead a type error is raised stating that a tuple/list should be provided. Stride and dilation arguments do support broadcasting a single int value now because this is not - ambigious when the kernel rank is known. + ambigious when the kernel rank is known. - `flax.linen.enable_named_call` and `flax.linen.disable_named_call` now work anywhere instead of only affecting Modules constructed after the enable/disable call. Additionally, there is now `flax.linen.override_named_call` that provided a context manager to locally disable/enable named_call. - NamedTuples are no longer converted to tuples on assignment to a `linen.Module`. @@ -64,7 +66,7 @@ Bugfixes: - Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated ([bug](https://github.com/google/flax/issues/1429)). - Mixed precision training with float16 now works correctly with the attention layers. - auto-generated linen Module `__hash__`, `__eq__`, `__repr__` no longer fail by default on non-init attributes. - + 0.3.4 @@ -72,7 +74,7 @@ Bugfixes: Possibly breaking changes: - When calling `init` the 'intermediates' collection is no longer mutable. - Therefore, intermediates will no longer be returned from initialization by default. + Therefore, intermediates will no longer be returned from initialization by default. - Don't update batch statistics during initialization. - When not using any non-determinism (e.g., dropout), it is not longer necessary to specify the `deterministic` argument in `MultiHeadDotProductAttention`. @@ -105,9 +107,9 @@ Possible breaking changes: latest checkpoint already saved. - MultiOptimizer now rejects the case where multiple sub optimizers update the same parameter. - + Other changes: - - Added custom error classes to many Linen errors. See: + - Added custom error classes to many Linen errors. See: https://flax.readthedocs.io/en/latest/flax.errors.html - Adds `Module.bind` for binding variables and RNGs to an interactive Module. - Adds `nn.apply` and `nn.init` for transforming arbitrary functions that take a `linen.Module` as their first argument. @@ -127,7 +129,7 @@ NOTE: You must now explicitly import `flax.nn` if you want to use the old 0.3.1 ------ -Many improvements to Linen, and the old `flax.nn` is officially reprecated! 
+Many improvements to Linen, and the old `flax.nn` is officially reprecated! Notably, there's a clean API for extracting intermediates from modules defined using `@nn.compact`, a more ergonomic API for using Batch Norm and Dropout in modules @@ -141,7 +143,7 @@ Possible breaking changes: is enforced by raising a TypeError in `__setattr__` after `setup`. - Pytrees of dicts and lists are transformed into FrozenDict and tuples during attribute assignment. - This avoids undetected submodules and inner state. + This avoids undetected submodules and inner state. - Bug Fix `flax.core.apply` and `Module.apply`. Now it returns a tuple containing the output and a frozen empty collection when `mutable` is specified as an empty list. diff --git a/flax/linen/activation.py b/flax/linen/activation.py index 268123ecc..2514f2eda 100644 --- a/flax/linen/activation.py +++ b/flax/linen/activation.py @@ -67,8 +67,11 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ + dtype = inputs.dtype + assert jnp.issubdtype(dtype, jnp.floating) negative_slope = self.param( 'negative_slope', - lambda k: jnp.asarray(self.negative_slope_init, jnp.float32) + lambda k: jnp.asarray(self.negative_slope_init, dtype) ) - return jnp.where(inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs) + assert negative_slope.shape == () + return jnp.where(inputs >= 0, inputs, negative_slope * inputs) diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 749ded6f3..fe1819bd6 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -15,23 +15,42 @@ """Attention core modules for Flax.""" from functools import partial -from typing import (Any, Callable, Tuple, Optional) +from typing import Any, Callable, Tuple, Type, Optional import jax from jax import lax from jax import random import jax.numpy as jnp import numpy as np +from typing_extensions import Protocol -from flax.linen.linear import default_kernel_init +from flax.linen.initializers import zeros from flax.linen.linear import DenseGeneral +from flax.linen.linear import _canonicalize_dtypes +from flax.linen.linear import default_kernel_init from flax.linen.module import Module, compact, merge_param -from flax.linen.initializers import zeros PRNGKey = Any -Shape = Tuple[int] -Dtype = Any +Shape = Tuple[int, ...] +InexactDType = Type[jnp.inexact] Array = Any +Initializer = Callable[[PRNGKey, Shape, InexactDType], Array] + + +class AttentionFunction(Protocol): + @staticmethod + def __call__(query: Array, + key: Array, + value: Array, + bias: Optional[Array] = None, + mask: Optional[Array] = None, + broadcast_dropout: bool = True, + dropout_rng: Optional[PRNGKey] = None, + dropout_rate: float = 0., + deterministic: bool = False, + dtype: InexactDType = jnp.float32, + precision: Optional[lax.Precision] = None) -> Array: + ... def dot_product_attention_weights(query: Array, @@ -42,7 +61,7 @@ def dot_product_attention_weights(query: Array, dropout_rng: Optional[PRNGKey] = None, dropout_rate: float = 0., deterministic: bool = False, - dtype: Dtype = jnp.float32, + dtype: InexactDType = jnp.float32, precision: Optional[lax.Precision] = None): """Computes dot-product attention weights given query and key. 
@@ -109,7 +128,7 @@ def dot_product_attention_weights(query: Array, keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) else: keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) - multiplier = (keep.astype(attn_weights.dtype) / + multiplier = (keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)) attn_weights = attn_weights * multiplier @@ -125,8 +144,8 @@ def dot_product_attention(query: Array, dropout_rng: Optional[PRNGKey] = None, dropout_rate: float = 0., deterministic: bool = False, - dtype: Dtype = jnp.float32, - precision: Optional[lax.Precision] = None): + dtype: InexactDType = jnp.float32, + precision: Optional[lax.Precision] = None) -> Array: """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on @@ -179,52 +198,27 @@ def dot_product_attention(query: Array, precision=precision) -class MultiHeadDotProductAttention(Module): - """Multi-head dot-product attention. - - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - dtype: the dtype of the computation (default: float32) - param_dtype: the dtype passed to parameter initializers (default: float32). - qkv_features: dimension of the key, query, and value. - out_features: dimension of the last projection - broadcast_dropout: bool: use a broadcasted dropout along batch dims. - dropout_rate: dropout rate - deterministic: if false, the attention weight is masked randomly - using dropout, whereas if true, the attention weights - are deterministic. - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - kernel_init: initializer for the kernel of the Dense layers. - bias_init: initializer for the bias of the Dense layers. - use_bias: bool: whether pointwise QKVO dense transforms use bias. - attention_fn: dot_product_attention or compatible function. Accepts - query, key, value, and returns output of shape - `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` - decode: whether to prepare and use an autoregressive cache. - """ +class _BaseMultiHeadDotProductAttention(Module): num_heads: int - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None qkv_features: Optional[int] = None out_features: Optional[int] = None broadcast_dropout: bool = True dropout_rate: float = 0. deterministic: Optional[bool] = None precision: Optional[lax.Precision] = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros use_bias: bool = True - attention_fn: Callable[[Array, Array, Array], Array] = dot_product_attention + attention_fn: AttentionFunction = dot_product_attention decode: bool = False - @compact - def __call__(self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None): + def _apply(self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None): """Applies multi-head dot product attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, @@ -246,6 +240,10 @@ def __call__(self, Returns: output of shape `[batch_sizes..., length, features]`. 
""" + param_dtype, dtype = _canonicalize_dtypes(jnp.result_type(inputs_q, + inputs_kv), + self.param_dtype, + self.dtype) features = self.out_features or inputs_q.shape[-1] qkv_features = self.qkv_features or inputs_q.shape[-1] assert qkv_features % self.num_heads == 0, ( @@ -254,8 +252,8 @@ def __call__(self, dense = partial(DenseGeneral, axis=-1, - dtype=self.dtype, - param_dtype=self.param_dtype, + dtype=dtype, + param_dtype=param_dtype, features=(self.num_heads, head_dim), kernel_init=self.kernel_init, bias_init=self.bias_init, @@ -323,7 +321,7 @@ def __call__(self, dropout_rate=self.dropout_rate, broadcast_dropout=self.broadcast_dropout, deterministic=m_deterministic, - dtype=self.dtype, + dtype=dtype, precision=self.precision) # pytype: disable=wrong-keyword-args # back to the original inputs dimensions out = DenseGeneral(features=features, @@ -331,30 +329,64 @@ def __call__(self, kernel_init=self.kernel_init, bias_init=self.bias_init, use_bias=self.use_bias, - dtype=self.dtype, - param_dtype=self.param_dtype, + dtype=dtype, + param_dtype=param_dtype, precision=self.precision, name='out')(x) return out -class SelfAttention(MultiHeadDotProductAttention): - """Self-attention special case of multi-head dot-product attention.""" +class MultiHeadDotProductAttention(_BaseMultiHeadDotProductAttention): + """Multi-head dot-product attention. + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + dtype: the dtype of the computation (default: float32) + param_dtype: the dtype passed to parameter initializers (default: float32). + qkv_features: dimension of the key, query, and value. + out_features: dimension of the last projection + broadcast_dropout: bool: use a broadcasted dropout along batch dims. + dropout_rate: dropout rate + deterministic: if false, the attention weight is masked randomly + using dropout, whereas if true, the attention weights + are deterministic. + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the kernel of the Dense layers. + bias_init: initializer for the bias of the Dense layers. + use_bias: bool: whether pointwise QKVO dense transforms use bias. + attention_fn: dot_product_attention or compatible function. Accepts + query, key, value, and returns output of shape + `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` + decode: whether to prepare and use an autoregressive cache. 
+ """ + @compact + def __call__(self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None): + return self._apply(inputs_q, inputs_kv, mask, deterministic=deterministic) + + +class SelfAttention(_BaseMultiHeadDotProductAttention): + """Self-attention special case of multi-head dot-product attention.""" @compact def __call__(self, inputs_q: Array, mask: Optional[Array] = None, deterministic: Optional[bool] = None): - return super().__call__(inputs_q, inputs_q, mask, deterministic=deterministic) + return self._apply(inputs_q, inputs_q, mask, deterministic=deterministic) # mask-making utility functions -def make_attention_mask(query_input: Array, - key_input: Array, - pairwise_fn: Callable[..., Any] = jnp.multiply, - extra_batch_dims: int = 0, - dtype: Dtype = jnp.float32): +def make_attention_mask( + query_input: Array, + key_input: Array, + pairwise_fn: Callable[[Array, Array], Array] = jnp.multiply, + extra_batch_dims: int = 0, + dtype: InexactDType = jnp.float32): """Mask-making helper for attention weights. In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the @@ -381,7 +413,7 @@ def make_attention_mask(query_input: Array, def make_causal_mask(x: Array, extra_batch_dims: int = 0, - dtype: Dtype = jnp.float32) -> Array: + dtype: InexactDType = jnp.float32) -> Array: """Make a causal mask for self-attention. In case of 1d inputs (i.e., `[batch..., len]`, the self-attention weights @@ -403,7 +435,7 @@ def make_causal_mask(x: Array, def combine_masks(*masks: Optional[Array], - dtype: Dtype = jnp.float32) -> Array: + dtype: InexactDType = jnp.float32) -> Array: """Combine attention masks. Args: diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 74a590ac2..2bd2027a2 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -18,7 +18,7 @@ from dataclasses import field from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple, - Union) + Type, Union) from flax.linen.module import Module, compact from flax.linen.initializers import lecun_normal, variance_scaling, zeros @@ -32,8 +32,11 @@ PRNGKey = Any Shape = Tuple[int, ...] -Dtype = Any # this could be a real type? 
+InexactDType = Type[jnp.inexact] +NumericDType = Type[jnp.number] +GenericDType = Type[np.generic] Array = Any +Initializer = Callable[[PRNGKey, Shape, InexactDType], Array] default_kernel_init = lecun_normal() @@ -51,6 +54,38 @@ def _canonicalize_tuple(x: Union[Sequence[int], int]) -> Tuple[int, ...]: return (x,) +def _canonicalize_dtypes( + input_dtype: InexactDType, + param_dtype: Optional[InexactDType], + computation_dtype: Optional[InexactDType]) -> Tuple[InexactDType, + InexactDType]: + returned_param_dtype = input_dtype if param_dtype is None else param_dtype + dtype = (jnp.result_type(input_dtype, returned_param_dtype) + if computation_dtype is None else computation_dtype) + + assert jnp.issubdtype(input_dtype, jnp.inexact) + if jnp.issubdtype(input_dtype, jnp.complexfloating): + assert jnp.issubdtype(returned_param_dtype, jnp.complexfloating) + assert jnp.issubdtype(dtype, jnp.complexfloating) + return returned_param_dtype, dtype + + +def _canonicalize_numeric_dtypes( + input_dtype: NumericDType, + param_dtype: Optional[NumericDType], + computation_dtype: Optional[NumericDType]) -> Tuple[NumericDType, + NumericDType]: + returned_param_dtype = input_dtype if param_dtype is None else param_dtype + dtype = (jnp.result_type(input_dtype, returned_param_dtype) + if computation_dtype is None else computation_dtype) + + assert jnp.issubdtype(input_dtype, jnp.number) + if jnp.issubdtype(input_dtype, jnp.complexfloating): + assert jnp.issubdtype(returned_param_dtype, jnp.complexfloating) + assert jnp.issubdtype(dtype, jnp.complexfloating) + return returned_param_dtype, dtype + + class DenseGeneral(Module): """A linear transformation with flexible axes. @@ -60,8 +95,8 @@ class DenseGeneral(Module): (-2, -1) will apply the transformation to the last two axes. batch_dims: tuple with batch axes. use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: None). kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. precision: numerical precision of the computation see `jax.lax.Precision` @@ -71,10 +106,10 @@ class DenseGeneral(Module): axis: Union[int, Sequence[int]] = -1 batch_dims: Sequence[int] = () use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros precision: Optional[lax.Precision] = None @compact @@ -87,6 +122,9 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ + param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) batch_dims = _canonicalize_tuple(self.batch_dims) @@ -96,15 +134,13 @@ def __call__(self, inputs: Array) -> Array: raise ValueError('batch_dims %s must be consecutive leading ' 'dimensions starting from 0.' 
% str(batch_dims)) - inputs = jnp.asarray(inputs, self.dtype) - ndim = inputs.ndim n_batch_dims = len(batch_dims) axis = _normalize_axes(axis, ndim) batch_dims = _normalize_axes(batch_dims, ndim) n_axis, n_features = len(axis), len(features) - def kernel_init_wrap(rng, shape, dtype=jnp.float32): + def kernel_init_wrap(rng, shape, dtype): size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) flat_shape = (np.prod(shape[n_batch_dims:n_axis + n_batch_dims]), np.prod(shape[-n_features:]),) @@ -119,8 +155,8 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32): for ax in range(inputs.ndim) if ax not in axis) kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features kernel = self.param('kernel', kernel_init_wrap, batch_shape + kernel_shape, - self.param_dtype) - kernel = jnp.asarray(kernel, self.dtype) + param_dtype) + kernel = jnp.asarray(kernel, dtype) batch_ind = tuple(range(n_batch_dims)) contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) @@ -130,7 +166,7 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32): precision=self.precision) # dot_general output has shape [batch_dims/group_dims] + [feature_dims] if self.use_bias: - def bias_init_wrap(rng, shape, dtype=jnp.float32): + def bias_init_wrap(rng, shape, dtype): size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) flat_shape = (np.prod(shape[-n_features:]),) bias = jnp.concatenate([self.bias_init(rng, flat_shape, dtype) @@ -138,10 +174,10 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32): return jnp.reshape(bias, shape) bias = self.param('bias', bias_init_wrap, batch_shape + features, - self.param_dtype) + param_dtype) + bias = jnp.asarray(bias, dtype) # expand bias shape to broadcast bias over batch dims. bias = jnp.reshape(bias, expanded_batch_shape + features) - bias = jnp.asarray(bias, self.dtype) out = out + bias return out @@ -152,8 +188,8 @@ class Dense(Module): Attributes: features: the number of output features. use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: None). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer function for the weight matrix. @@ -161,11 +197,11 @@ class Dense(Module): """ features: int use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None precision: Optional[lax.Precision] = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros @compact def __call__(self, inputs: Array) -> Array: @@ -177,19 +213,21 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. 
""" - inputs = jnp.asarray(inputs, self.dtype) + param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) kernel = self.param('kernel', self.kernel_init, (inputs.shape[-1], self.features), self.param_dtype) - kernel = jnp.asarray(kernel, self.dtype) + kernel = jnp.asarray(kernel, dtype) y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision) if self.use_bias: bias = self.param('bias', self.bias_init, (self.features,), self.param_dtype) - bias = jnp.asarray(bias, self.dtype) + bias = jnp.asarray(bias, dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y @@ -218,7 +256,8 @@ class _Conv(Module): high)` integer pairs that give the padding to apply before and after each spatial dimension. input_dilation: an integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` (default: 1). + dilation factor to apply in each spatial dimension of `inputs` (default: + 1). Convolution with input dilation `d` is equivalent to transposed convolution with stride `d`. kernel_dilation: an integer or a sequence of `n` integers, giving the @@ -228,8 +267,8 @@ class _Conv(Module): feature_group_count: integer, default 1. If specified divides the input features into groups. use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: None). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer for the convolutional kernel. @@ -243,11 +282,11 @@ class _Conv(Module): kernel_dilation: Union[None, int, Sequence[int]] = 1 feature_group_count: int = 1 use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[NumericDType] = None + param_dtype: Optional[NumericDType] = None precision: Optional[lax.Precision] = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros @property @abc.abstractmethod @@ -276,8 +315,10 @@ def __call__(self, inputs: Array) -> Array: Returns: The convolved data. """ - - inputs = jnp.asarray(inputs, self.dtype) + param_dtype, dtype = _canonicalize_numeric_dtypes(inputs.dtype, + self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) if isinstance(self.kernel_size, int): raise TypeError('The kernel size must be specified as a' @@ -350,8 +391,8 @@ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> ( kernel_shape = conv_output_shape[1:-1] + ( np.prod(kernel_size) * in_features, self.features) - kernel = self.param('kernel', self.kernel_init, kernel_shape, self.param_dtype) - kernel = jnp.asarray(kernel, self.dtype) + kernel = self.param('kernel', self.kernel_init, kernel_shape, param_dtype) + kernel = jnp.asarray(kernel, dtype) if self.shared_weights: y = lax.conv_general_dilated( @@ -386,8 +427,8 @@ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> ( # One bias weight per output entry, unshared betwen pixels. 
bias_shape = y.shape[1:] - bias = self.param('bias', self.bias_init, bias_shape, self.param_dtype) - bias = jnp.asarray(bias, self.dtype) + bias = self.param('bias', self.bias_init, bias_shape, param_dtype) + bias = jnp.asarray(bias, dtype) bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape) y += bias @@ -430,8 +471,8 @@ class ConvTranspose(Module): kernel. Convolution with kernel dilation is also known as 'atrous convolution'. use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: None). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer for the convolutional kernel. @@ -443,11 +484,11 @@ class ConvTranspose(Module): padding: Union[str, Sequence[Tuple[int, int]]] = 'SAME' kernel_dilation: Optional[Sequence[int]] = None use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[NumericDType] = None + param_dtype: Optional[NumericDType] = None precision: Optional[lax.Precision] = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros @compact def __call__(self, inputs: Array) -> Array: @@ -464,7 +505,10 @@ def __call__(self, inputs: Array) -> Array: Returns: The convolved data. """ - inputs = jnp.asarray(inputs, self.dtype) + param_dtype, dtype = _canonicalize_numeric_dtypes(inputs.dtype, + self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) kernel_size: Tuple[int, ...] if isinstance(self.kernel_size, int): @@ -482,8 +526,8 @@ def __call__(self, inputs: Array) -> Array: in_features = inputs.shape[-1] kernel_shape = kernel_size + (in_features, self.features) - kernel = self.param('kernel', self.kernel_init, kernel_shape, self.param_dtype) - kernel = jnp.asarray(kernel, self.dtype) + kernel = self.param('kernel', self.kernel_init, kernel_shape, param_dtype) + kernel = jnp.asarray(kernel, dtype) padding_lax: Union[str, Sequence[Tuple[int, int]]] if self.padding == 'CIRCULAR': @@ -532,8 +576,8 @@ def __call__(self, inputs: Array) -> Array: if is_single_input: y = jnp.squeeze(y, axis=0) if self.use_bias: - bias = self.param('bias', self.bias_init, (self.features,), self.param_dtype) - bias = jnp.asarray(bias, self.dtype) + bias = self.param('bias', self.bias_init, (self.features,), param_dtype) + bias = jnp.asarray(bias, dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y @@ -549,15 +593,15 @@ class Embed(Module): Attributes: num_embeddings: number of embeddings. features: number of feature dimensions for each embedding. - dtype: the dtype of the embedding vectors (default: float32). + dtype: the dtype of the embedding vectors (default: None). param_dtype: the dtype passed to parameter initializers (default: float32). embedding_init: embedding initializer. 
""" num_embeddings: int features: int - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 - embedding_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_embed_init + dtype: Optional[GenericDType] = None + param_dtype: GenericDType = jnp.float32 + embedding_init: Initializer = default_embed_init embedding: Array = field(init=False) @@ -577,10 +621,11 @@ def __call__(self, inputs: Array) -> Array: Output which is embedded input data. The output shape follows the input, with an additional `features` dimension appended. """ + dtype = self.param_dtype if self.dtype is None else self.dtype if not jnp.issubdtype(inputs.dtype, jnp.integer): raise ValueError('Input type must be an integer or unsigned integer.') # Use take because fancy indexing numpy arrays with JAX indices does not work correctly. - embedding = jnp.asarray(self.embedding, self.dtype) + embedding = jnp.asarray(self.embedding, dtype) return jnp.take(embedding, inputs, axis=0) def attend(self, query: Array) -> Array: @@ -595,6 +640,7 @@ def attend(self, query: Array) -> Array: Commonly used for weight-sharing between embeddings and logit transform in NLP models. """ - query = jnp.asarray(query, self.dtype) - embedding = jnp.asarray(self.embedding, self.dtype) + dtype = self.param_dtype if self.dtype is None else self.dtype + query = jnp.asarray(query, dtype) + embedding = jnp.asarray(self.embedding, dtype) return jnp.dot(query, embedding.T) diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index fe4d2e097..bde9b13e7 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -14,21 +14,24 @@ """Normalization modules for Flax.""" -from typing import (Any, Callable, Optional, Tuple, Iterable, Union) +from typing import Any, Callable, Optional, Tuple, Type, Iterable, Union from jax import lax from jax.nn import initializers import jax.numpy as jnp +import numpy as np from flax.linen.module import Module, compact, merge_param +from flax.linen.linear import _canonicalize_dtypes PRNGKey = Any Array = Any Shape = Tuple[int, ...] -Dtype = Any # this could be a real type? +InexactDType = Type[jnp.inexact] +Initializer = Callable[[PRNGKey, Shape, InexactDType], Array] -Axes = Union[int, Iterable[int]] +Axes = Union[int, Tuple[int, ...]] def _canonicalize_axes(rank: int, axes: Axes) -> Tuple[int, ...]: @@ -47,7 +50,7 @@ def _abs_sq(x): def _compute_stats(x: Array, axes: Axes, axis_name: Optional[str] = None, - axis_index_groups: Any = None): + axis_index_groups: Any = None) -> Tuple[Array, Array]: """Computes mean and variance statistics. This implementation takes care of a few important details: @@ -79,15 +82,18 @@ def _compute_stats(x: Array, axes: Axes, def _normalize(mdl: Module, x: Array, mean: Array, var: Array, reduction_axes: Axes, feature_axes: Axes, - dtype: Dtype, param_dtype: Dtype, + dtype: InexactDType, param_dtype: InexactDType, epsilon: float, use_bias: bool, use_scale: bool, - bias_init: Callable[[PRNGKey, Shape, Dtype], Array], - scale_init: Callable[[PRNGKey, Shape, Dtype], Array]): - """"Normalizes the input of a normalization layer and optionally applies a learned scale and bias. + bias_init: Initializer, + scale_init: Initializer): + """"Normalizes the input of a normalization layer and optionally applies a + learned scale and bias. A seperate bias and scale is learned for each feature as specified by feature_axes. 
""" + input_dtype = jnp.result_type(x, mean, var) + param_dtype, dtype = _canonicalize_dtypes(input_dtype, param_dtype, dtype) reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) feature_axes = _canonicalize_axes(x.ndim, feature_axes) stats_shape = list(x.shape) @@ -100,16 +106,21 @@ def _normalize(mdl: Module, x: Array, mean: Array, var: Array, for ax in feature_axes: feature_shape[ax] = x.shape[ax] reduced_feature_shape.append(x.shape[ax]) + x = jnp.asarray(x, dtype) + mean = jnp.asarray(mean, dtype) + var = jnp.asarray(var, dtype) y = x - mean mul = lax.rsqrt(var + epsilon) if use_scale: scale = mdl.param('scale', scale_init, reduced_feature_shape, param_dtype).reshape(feature_shape) + scale = jnp.asarray(scale, dtype) mul *= scale y *= mul if use_bias: bias = mdl.param('bias', bias_init, reduced_feature_shape, param_dtype).reshape(feature_shape) + bias = jnp.asarray(bias, dtype) y += bias return jnp.asarray(y, dtype) @@ -151,8 +162,8 @@ class BatchNorm(Module): momentum: decay rate for the exponential moving average of the batch statistics. epsilon: a small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: None). use_bias: if True, bias (beta) is added. use_scale: if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled @@ -171,17 +182,19 @@ class BatchNorm(Module): axis: int = -1 momentum: float = 0.99 epsilon: float = 1e-5 - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None use_bias: bool = True use_scale: bool = True - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros - scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + bias_init: Initializer = initializers.zeros + scale_init: Initializer = initializers.ones axis_name: Optional[str] = None axis_index_groups: Any = None @compact - def __call__(self, x, use_running_average: Optional[bool] = None): + def __call__(self, + x: Array, + use_running_average: Optional[bool] = None) -> Array: """Normalizes the input using batch statistics. NOTE: @@ -199,6 +212,9 @@ def __call__(self, x, use_running_average: Optional[bool] = None): Returns: Normalized inputs (the same shape as inputs). 
""" + param_dtype, dtype = _canonicalize_dtypes(x.dtype, self.param_dtype, + self.dtype) + x = jnp.asarray(x, dtype) use_running_average = merge_param( 'use_running_average', self.use_running_average, use_running_average) @@ -210,10 +226,10 @@ def __call__(self, x, use_running_average: Optional[bool] = None): initializing = self.is_mutable_collection('params') ra_mean = self.variable('batch_stats', 'mean', - lambda s: jnp.zeros(s, jnp.float32), + lambda s: jnp.zeros(s, dtype), feature_shape) ra_var = self.variable('batch_stats', 'var', - lambda s: jnp.ones(s, jnp.float32), + lambda s: jnp.ones(s, dtype), feature_shape) if use_running_average: @@ -225,14 +241,14 @@ def __call__(self, x, use_running_average: Optional[bool] = None): axis_index_groups=self.axis_index_groups) if not initializing: - ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean + ra_mean.value = (self.momentum * ra_mean.value + (1 - self.momentum) * + mean) ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var return _normalize( - self, x, mean, var, reduction_axes, feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, - self.bias_init, self.scale_init) + self, x, mean, var, reduction_axes, feature_axes, dtype, param_dtype, + self.epsilon, self.use_bias, self.use_scale, self.bias_init, + self.scale_init) class LayerNorm(Module): @@ -246,8 +262,8 @@ class LayerNorm(Module): Attributes: epsilon: A small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: None). use_bias: If True, bias (beta) is added. use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done @@ -256,12 +272,12 @@ class LayerNorm(Module): scale_init: Initializer for scale, by default, one. """ epsilon: float = 1e-6 - dtype: Any = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None use_bias: bool = True use_scale: bool = True - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros - scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + bias_init: Initializer = initializers.zeros + scale_init: Initializer = initializers.ones @compact def __call__(self, x): @@ -273,6 +289,9 @@ def __call__(self, x): Returns: Normalized inputs (the same shape as inputs). """ + param_dtype, dtype = _canonicalize_dtypes(x.dtype, self.param_dtype, + self.dtype) + x = jnp.asarray(x, dtype) reduction_axes = (-1,) feature_axes = (-1,) @@ -280,10 +299,9 @@ def __call__(self, x): mean, var = _compute_stats(x, reduction_axes, None, None) return _normalize( - self, x, mean, var, reduction_axes, feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, - self.bias_init, self.scale_init) + self, x, mean, var, reduction_axes, feature_axes, dtype, param_dtype, + self.epsilon, self.use_bias, self.use_scale, self.bias_init, + self.scale_init) class GroupNorm(Module): @@ -301,8 +319,9 @@ class GroupNorm(Module): proposed by the original group normalization paper. group_size: the number of channels in a group. epsilon: A small float added to variance to avoid dividing by zero. 
- dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: + None). use_bias: If True, bias (beta) is added. use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done @@ -313,12 +332,12 @@ class GroupNorm(Module): num_groups: Optional[int] = 32 group_size: Optional[int] = None epsilon: float = 1e-6 - dtype: Any = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None use_bias: bool = True use_scale: bool = True - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros - scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + bias_init: Initializer = initializers.zeros + scale_init: Initializer = initializers.ones @compact def __call__(self, x): @@ -332,7 +351,10 @@ def __call__(self, x): Returns: Normalized inputs (the same shape as inputs). """ - reduction_axes = list(range(1, x.ndim - 1)) + [-1] + param_dtype, dtype = _canonicalize_dtypes(x.dtype, self.param_dtype, + self.dtype) + x = jnp.asarray(x, dtype) + reduction_axes = tuple(range(1, x.ndim - 1)) + (-1,) feature_axes = (-1,) if ((self.num_groups is None and self.group_size is None) or @@ -367,7 +389,6 @@ def broadcast_stat(stat): var = broadcast_stat(var) return _normalize( - self, x, mean, var, reduction_axes[:-1], feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, + self, x, mean, var, reduction_axes[:-1], feature_axes, dtype, + param_dtype, self.epsilon, self.use_bias, self.use_scale, self.bias_init, self.scale_init) diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index ade9f7006..a33a6c619 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -40,8 +40,8 @@ import numpy as np PRNGKey = Any -Shape = Tuple[int] -Dtype = Any # this could be a real type? +Shape = Tuple[int, ...] 
+Dtype = Any Array = Any diff --git a/setup.py b/setup.py index dffe46fdb..a1903f8a3 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "matplotlib", # only needed for tensorboard export "msgpack", "optax", + "typing_extensions>=3.10", ] tests_require = [ diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index 890a53a49..a74114a44 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -657,8 +657,8 @@ def __call__(self, x): # attributes features = 3 use_bias = True - dtype = float32 - param_dtype = float32 + dtype = None + param_dtype = None precision = None kernel_init = init bias_init = zeros @@ -667,8 +667,8 @@ def __call__(self, x): # attributes features = 2 use_bias = True - dtype = float32 - param_dtype = float32 + dtype = None + param_dtype = None precision = None kernel_init = init bias_init = zeros From f5f625c24354d7ccef14e23b499a9a11f41cc177 Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Fri, 4 Mar 2022 12:12:53 -0500 Subject: [PATCH 2/8] Revert Liskov error fix --- flax/linen/attention.py | 76 ++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 42 deletions(-) diff --git a/flax/linen/attention.py b/flax/linen/attention.py index fe1819bd6..0e91dfbd6 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -198,7 +198,31 @@ def dot_product_attention(query: Array, precision=precision) -class _BaseMultiHeadDotProductAttention(Module): +class MultiHeadDotProductAttention(Module): + """Multi-head dot-product attention. + + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + dtype: the dtype of the computation (default: float32) + param_dtype: the dtype passed to parameter initializers (default: float32). + qkv_features: dimension of the key, query, and value. + out_features: dimension of the last projection + broadcast_dropout: bool: use a broadcasted dropout along batch dims. + dropout_rate: dropout rate + deterministic: if false, the attention weight is masked randomly + using dropout, whereas if true, the attention weights + are deterministic. + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the kernel of the Dense layers. + bias_init: initializer for the bias of the Dense layers. + use_bias: bool: whether pointwise QKVO dense transforms use bias. + attention_fn: dot_product_attention or compatible function. Accepts + query, key, value, and returns output of shape + `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` + decode: whether to prepare and use an autoregressive cache. + """ num_heads: int dtype: Optional[InexactDType] = None param_dtype: Optional[InexactDType] = None @@ -214,11 +238,12 @@ class _BaseMultiHeadDotProductAttention(Module): attention_fn: AttentionFunction = dot_product_attention decode: bool = False - def _apply(self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None): + @compact + def __call__(self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None): """Applies multi-head dot product attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, @@ -336,46 +361,13 @@ def _apply(self, return out -class MultiHeadDotProductAttention(_BaseMultiHeadDotProductAttention): - """Multi-head dot-product attention. 
- - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - dtype: the dtype of the computation (default: float32) - param_dtype: the dtype passed to parameter initializers (default: float32). - qkv_features: dimension of the key, query, and value. - out_features: dimension of the last projection - broadcast_dropout: bool: use a broadcasted dropout along batch dims. - dropout_rate: dropout rate - deterministic: if false, the attention weight is masked randomly - using dropout, whereas if true, the attention weights - are deterministic. - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - kernel_init: initializer for the kernel of the Dense layers. - bias_init: initializer for the bias of the Dense layers. - use_bias: bool: whether pointwise QKVO dense transforms use bias. - attention_fn: dot_product_attention or compatible function. Accepts - query, key, value, and returns output of shape - `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` - decode: whether to prepare and use an autoregressive cache. - """ - @compact - def __call__(self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None): - return self._apply(inputs_q, inputs_kv, mask, deterministic=deterministic) - - -class SelfAttention(_BaseMultiHeadDotProductAttention): +class SelfAttention(MultiHeadDotProductAttention): """Self-attention special case of multi-head dot-product attention.""" + @compact def __call__(self, inputs_q: Array, mask: Optional[Array] = None, deterministic: Optional[bool] = None): - return self._apply(inputs_q, inputs_q, mask, deterministic=deterministic) + return super().__call__(inputs_q, inputs_q, mask, deterministic=deterministic) # mask-making utility functions From dd42ae248b6d37ec57f8faa4c54ee5d63a4af0cd Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Fri, 4 Mar 2022 14:16:52 -0500 Subject: [PATCH 3/8] Add param_dtype to PReLU --- flax/linen/activation.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/flax/linen/activation.py b/flax/linen/activation.py index 2514f2eda..65abdc047 100644 --- a/flax/linen/activation.py +++ b/flax/linen/activation.py @@ -43,10 +43,12 @@ from typing import Any +from flax.linen.linear import _canonicalize_dtypes from flax.linen.module import Module, compact import jax.numpy as jnp +FloatingDType = Type[jnp.floating] Array = Any @@ -54,9 +56,15 @@ class PReLU(Module): """Parametric Rectified Linear Unit (PReLU) activation function. Attributes: - negative_slope_init: the value to initialize the negative slope. + dtype: the dtype of the computation (default: float32). + param_dtype: the dtype passed to parameter initializers (default: float32). + negative_slope_init: the value to initialize the negative slope + (default 0.01). """ + dtype: Optional[FloatingDType] = jnp.float32 + param_dtype: Optional[FloatingDType] = jnp.float32 negative_slope_init: float = 0.01 + @compact def __call__(self, inputs: Array) -> Array: """Applies an activation to the inputs. @@ -67,11 +75,13 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. 
""" - dtype = inputs.dtype - assert jnp.issubdtype(dtype, jnp.floating) + assert jnp.issubdtype(inputs.dtype, jnp.floating) + inputs = jnp.asarray(inputs, dtype) + param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, param_dtype, + self.dtype) negative_slope = self.param( 'negative_slope', - lambda k: jnp.asarray(self.negative_slope_init, dtype) + lambda k: jnp.asarray(self.negative_slope_init, param_dtype) ) - assert negative_slope.shape == () + negative_slope = jnp.asarray(negative_slope, dtype) return jnp.where(inputs >= 0, inputs, negative_slope * inputs) From 4ec2e81149220e0ab1c291223be57ac1424afdb5 Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Fri, 4 Mar 2022 14:17:17 -0500 Subject: [PATCH 4/8] Correct typo --- flax/linen/linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 2bd2027a2..234b729e4 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -219,14 +219,14 @@ def __call__(self, inputs: Array) -> Array: kernel = self.param('kernel', self.kernel_init, (inputs.shape[-1], self.features), - self.param_dtype) + param_dtype) kernel = jnp.asarray(kernel, dtype) y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision) if self.use_bias: bias = self.param('bias', self.bias_init, (self.features,), - self.param_dtype) + param_dtype) bias = jnp.asarray(bias, dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y From 9d3bd2ab5757b02270f75ff850c961e805642e43 Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Fri, 4 Mar 2022 14:52:56 -0500 Subject: [PATCH 5/8] Remove assertions --- flax/linen/linear.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 234b729e4..6ee9f5e3e 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -64,9 +64,6 @@ def _canonicalize_dtypes( if computation_dtype is None else computation_dtype) assert jnp.issubdtype(input_dtype, jnp.inexact) - if jnp.issubdtype(input_dtype, jnp.complexfloating): - assert jnp.issubdtype(returned_param_dtype, jnp.complexfloating) - assert jnp.issubdtype(dtype, jnp.complexfloating) return returned_param_dtype, dtype @@ -80,9 +77,6 @@ def _canonicalize_numeric_dtypes( if computation_dtype is None else computation_dtype) assert jnp.issubdtype(input_dtype, jnp.number) - if jnp.issubdtype(input_dtype, jnp.complexfloating): - assert jnp.issubdtype(returned_param_dtype, jnp.complexfloating) - assert jnp.issubdtype(dtype, jnp.complexfloating) return returned_param_dtype, dtype From b96c2d19cdd139f776c1af8ea78cec3ee7363cc4 Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Fri, 4 Mar 2022 14:55:44 -0500 Subject: [PATCH 6/8] Correct dtype defaults as per FLIP --- flax/linen/linear.py | 36 ++++++++++++++++++------------------ flax/linen/normalization.py | 30 +++++++++++++++--------------- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 6ee9f5e3e..2d87a6b61 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -89,8 +89,8 @@ class DenseGeneral(Module): (-2, -1) will apply the transformation to the last two axes. batch_dims: tuple with batch axes. use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: None). - param_dtype: the dtype passed to parameter initializers (default: None). + dtype: the dtype of the computation (default: float32). 
+ param_dtype: the dtype passed to parameter initializers (default: float32). kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. precision: numerical precision of the computation see `jax.lax.Precision` @@ -100,8 +100,8 @@ class DenseGeneral(Module): axis: Union[int, Sequence[int]] = -1 batch_dims: Sequence[int] = () use_bias: bool = True - dtype: Optional[InexactDType] = None - param_dtype: Optional[InexactDType] = None + dtype: Optional[InexactDType] = jnp.float32 + param_dtype: Optional[InexactDType] = jnp.float32 kernel_init: Initializer = default_kernel_init bias_init: Initializer = zeros precision: Optional[lax.Precision] = None @@ -182,8 +182,8 @@ class Dense(Module): Attributes: features: the number of output features. use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: None). - param_dtype: the dtype passed to parameter initializers (default: None). + dtype: the dtype of the computation (default: float32). + param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer function for the weight matrix. @@ -191,8 +191,8 @@ class Dense(Module): """ features: int use_bias: bool = True - dtype: Optional[InexactDType] = None - param_dtype: Optional[InexactDType] = None + dtype: Optional[InexactDType] = jnp.float32 + param_dtype: Optional[InexactDType] = jnp.float32 precision: Optional[lax.Precision] = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = zeros @@ -261,8 +261,8 @@ class _Conv(Module): feature_group_count: integer, default 1. If specified divides the input features into groups. use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: None). - param_dtype: the dtype passed to parameter initializers (default: None). + dtype: the dtype of the computation (default: float32). + param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer for the convolutional kernel. @@ -276,8 +276,8 @@ class _Conv(Module): kernel_dilation: Union[None, int, Sequence[int]] = 1 feature_group_count: int = 1 use_bias: bool = True - dtype: Optional[NumericDType] = None - param_dtype: Optional[NumericDType] = None + dtype: Optional[NumericDType] = jnp.float32 + param_dtype: Optional[NumericDType] = jnp.float32 precision: Optional[lax.Precision] = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = zeros @@ -465,8 +465,8 @@ class ConvTranspose(Module): kernel. Convolution with kernel dilation is also known as 'atrous convolution'. use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: None). - param_dtype: the dtype passed to parameter initializers (default: None). + dtype: the dtype of the computation (default: float32). + param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer for the convolutional kernel. 
@@ -478,8 +478,8 @@ class ConvTranspose(Module): padding: Union[str, Sequence[Tuple[int, int]]] = 'SAME' kernel_dilation: Optional[Sequence[int]] = None use_bias: bool = True - dtype: Optional[NumericDType] = None - param_dtype: Optional[NumericDType] = None + dtype: Optional[NumericDType] = jnp.float32 + param_dtype: Optional[NumericDType] = jnp.float32 precision: Optional[lax.Precision] = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = zeros @@ -587,13 +587,13 @@ class Embed(Module): Attributes: num_embeddings: number of embeddings. features: number of feature dimensions for each embedding. - dtype: the dtype of the embedding vectors (default: None). + dtype: the dtype of the embedding vectors (default: float32). param_dtype: the dtype passed to parameter initializers (default: float32). embedding_init: embedding initializer. """ num_embeddings: int features: int - dtype: Optional[GenericDType] = None + dtype: Optional[GenericDType] = jnp.float32 param_dtype: GenericDType = jnp.float32 embedding_init: Initializer = default_embed_init diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index bde9b13e7..19ba0ebaa 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -162,8 +162,8 @@ class BatchNorm(Module): momentum: decay rate for the exponential moving average of the batch statistics. epsilon: a small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: None). - param_dtype: the dtype passed to parameter initializers (default: None). + dtype: the dtype of the computation (default: float32). + param_dtype: the dtype passed to parameter initializers (default: float32). use_bias: if True, bias (beta) is added. use_scale: if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled @@ -182,8 +182,8 @@ class BatchNorm(Module): axis: int = -1 momentum: float = 0.99 epsilon: float = 1e-5 - dtype: Optional[InexactDType] = None - param_dtype: Optional[InexactDType] = None + dtype: Optional[InexactDType] = jnp.float32 + param_dtype: Optional[InexactDType] = jnp.float32 use_bias: bool = True use_scale: bool = True bias_init: Initializer = initializers.zeros @@ -262,8 +262,8 @@ class LayerNorm(Module): Attributes: epsilon: A small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: None). - param_dtype: the dtype passed to parameter initializers (default: None). + dtype: the dtype of the computation (default: float32). + param_dtype: the dtype passed to parameter initializers (default: float32). use_bias: If True, bias (beta) is added. use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done @@ -272,8 +272,8 @@ class LayerNorm(Module): scale_init: Initializer for scale, by default, one. """ epsilon: float = 1e-6 - dtype: Optional[InexactDType] = None - param_dtype: Optional[InexactDType] = None + dtype: Optional[InexactDType] = jnp.float32 + param_dtype: Optional[InexactDType] = jnp.float32 use_bias: bool = True use_scale: bool = True bias_init: Initializer = initializers.zeros @@ -319,21 +319,21 @@ class GroupNorm(Module): proposed by the original group normalization paper. group_size: the number of channels in a group. epsilon: A small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: None). + dtype: the dtype of the computation (default: float32). 
     param_dtype: the dtype passed to parameter initializers (default:
-      None).
+      float32).
     use_bias: If True, bias (beta) is added.
-    use_scale: If True, multiply by scale (gamma). When the next layer is linear
-      (also e.g. nn.relu), this can be disabled since the scaling will be done
-      by the next layer.
+    use_scale: If True, multiply by scale (gamma). When the next layer is
+      linear (also e.g. nn.relu), this can be disabled since the scaling will
+      be done by the next layer.
     bias_init: Initializer for bias, by default, zero.
     scale_init: Initializer for scale, by default, one.
   """
   num_groups: Optional[int] = 32
   group_size: Optional[int] = None
   epsilon: float = 1e-6
-  dtype: Optional[InexactDType] = None
-  param_dtype: Optional[InexactDType] = None
+  dtype: Optional[InexactDType] = jnp.float32
+  param_dtype: Optional[InexactDType] = jnp.float32
   use_bias: bool = True
   use_scale: bool = True
   bias_init: Initializer = initializers.zeros

From 4f2927dfb82b030174acb5458639d24b7caf5791 Mon Sep 17 00:00:00 2001
From: Neil Girdhar
Date: Fri, 4 Mar 2022 15:04:16 -0500
Subject: [PATCH 7/8] Expose canonicalize_inexact_dtypes,
 canonicalize_numeric_dtypes

---
 flax/linen/__init__.py      |  3 ++-
 flax/linen/activation.py    |  6 +++---
 flax/linen/attention.py     | 10 +++++-----
 flax/linen/linear.py        | 18 ++++++++++--------
 flax/linen/normalization.py | 17 +++++++++--------
 5 files changed, 29 insertions(+), 25 deletions(-)

diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py
index 2d57f9621..3135d4d8e 100644
--- a/flax/linen/__init__.py
+++ b/flax/linen/__init__.py
@@ -24,7 +24,8 @@
                         dot_product_attention, dot_product_attention_weights,
                         make_attention_mask, make_causal_mask, combine_masks)
 from ..core import broadcast, DenyList, FrozenDict
-from .linear import Conv, ConvLocal, ConvTranspose, Dense, DenseGeneral, Embed
+from .linear import (Conv, ConvLocal, ConvTranspose, Dense, DenseGeneral, Embed,
+                     canonicalize_inexact_dtypes, canonicalize_numeric_dtypes)
 from .module import (Module, compact, nowrap, enable_named_call,
                      disable_named_call, override_named_call, Variable, init,
                      init_with_output, apply, merge_param)
diff --git a/flax/linen/activation.py b/flax/linen/activation.py
index 65abdc047..56ebaed07 100644
--- a/flax/linen/activation.py
+++ b/flax/linen/activation.py
@@ -43,7 +43,7 @@
 
 from typing import Any
 
-from flax.linen.linear import _canonicalize_dtypes
+from flax.linen.linear import canonicalize_inexact_dtypes
 from flax.linen.module import Module, compact
 
 import jax.numpy as jnp
@@ -77,8 +77,8 @@ def __call__(self, inputs: Array) -> Array:
     """
     assert jnp.issubdtype(inputs.dtype, jnp.floating)
-    inputs = jnp.asarray(inputs, dtype)
-    param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, param_dtype,
-                                              self.dtype)
+    param_dtype, dtype = canonicalize_inexact_dtypes(
+        inputs.dtype, self.param_dtype, self.dtype)
+    inputs = jnp.asarray(inputs, dtype)
     negative_slope = self.param(
         'negative_slope',
         lambda k: jnp.asarray(self.negative_slope_init, param_dtype)
diff --git a/flax/linen/attention.py b/flax/linen/attention.py
index 0e91dfbd6..2e244c487 100644
--- a/flax/linen/attention.py
+++ b/flax/linen/attention.py
@@ -26,7 +26,7 @@
 
 from flax.linen.initializers import zeros
 from flax.linen.linear import DenseGeneral
-from flax.linen.linear import _canonicalize_dtypes
+from flax.linen.linear import canonicalize_inexact_dtypes
 from flax.linen.linear import default_kernel_init
 from flax.linen.module import Module, compact, merge_param
 
@@ -265,10 +265,10 @@ def __call__(self,
     Returns:
       output of shape `[batch_sizes..., length, features]`.
""" - param_dtype, dtype = _canonicalize_dtypes(jnp.result_type(inputs_q, - inputs_kv), - self.param_dtype, - self.dtype) + param_dtype, dtype = canonicalize_inexact_dtypes(jnp.result_type(inputs_q, + inputs_kv), + self.param_dtype, + self.dtype) features = self.out_features or inputs_q.shape[-1] qkv_features = self.qkv_features or inputs_q.shape[-1] assert qkv_features % self.num_heads == 0, ( diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 2d87a6b61..4f1ed41f9 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -54,7 +54,7 @@ def _canonicalize_tuple(x: Union[Sequence[int], int]) -> Tuple[int, ...]: return (x,) -def _canonicalize_dtypes( +def canonicalize_inexact_dtypes( input_dtype: InexactDType, param_dtype: Optional[InexactDType], computation_dtype: Optional[InexactDType]) -> Tuple[InexactDType, @@ -67,7 +67,7 @@ def _canonicalize_dtypes( return returned_param_dtype, dtype -def _canonicalize_numeric_dtypes( +def canonicalize_numeric_dtypes( input_dtype: NumericDType, param_dtype: Optional[NumericDType], computation_dtype: Optional[NumericDType]) -> Tuple[NumericDType, @@ -116,8 +116,9 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ - param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype, - self.dtype) + param_dtype, dtype = canonicalize_inexact_dtypes(inputs.dtype, + self.param_dtype, + self.dtype) inputs = jnp.asarray(inputs, dtype) features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) @@ -207,8 +208,9 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ - param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype, - self.dtype) + param_dtype, dtype = canonicalize_inexact_dtypes(inputs.dtype, + self.param_dtype, + self.dtype) inputs = jnp.asarray(inputs, dtype) kernel = self.param('kernel', self.kernel_init, @@ -309,7 +311,7 @@ def __call__(self, inputs: Array) -> Array: Returns: The convolved data. """ - param_dtype, dtype = _canonicalize_numeric_dtypes(inputs.dtype, + param_dtype, dtype = canonicalize_numeric_dtypes(inputs.dtype, self.param_dtype, self.dtype) inputs = jnp.asarray(inputs, dtype) @@ -499,7 +501,7 @@ def __call__(self, inputs: Array) -> Array: Returns: The convolved data. """ - param_dtype, dtype = _canonicalize_numeric_dtypes(inputs.dtype, + param_dtype, dtype = canonicalize_numeric_dtypes(inputs.dtype, self.param_dtype, self.dtype) inputs = jnp.asarray(inputs, dtype) diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 19ba0ebaa..835356d4a 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -22,7 +22,7 @@ import numpy as np from flax.linen.module import Module, compact, merge_param -from flax.linen.linear import _canonicalize_dtypes +from flax.linen.linear import canonicalize_inexact_dtypes PRNGKey = Any @@ -93,7 +93,8 @@ def _normalize(mdl: Module, x: Array, mean: Array, var: Array, A seperate bias and scale is learned for each feature as specified by feature_axes. """ input_dtype = jnp.result_type(x, mean, var) - param_dtype, dtype = _canonicalize_dtypes(input_dtype, param_dtype, dtype) + param_dtype, dtype = canonicalize_inexact_dtypes(input_dtype, param_dtype, + dtype) reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) feature_axes = _canonicalize_axes(x.ndim, feature_axes) stats_shape = list(x.shape) @@ -212,8 +213,8 @@ def __call__(self, Returns: Normalized inputs (the same shape as inputs). 
""" - param_dtype, dtype = _canonicalize_dtypes(x.dtype, self.param_dtype, - self.dtype) + param_dtype, dtype = canonicalize_inexact_dtypes(x.dtype, self.param_dtype, + self.dtype) x = jnp.asarray(x, dtype) use_running_average = merge_param( @@ -289,8 +290,8 @@ def __call__(self, x): Returns: Normalized inputs (the same shape as inputs). """ - param_dtype, dtype = _canonicalize_dtypes(x.dtype, self.param_dtype, - self.dtype) + param_dtype, dtype = canonicalize_inexact_dtypes(x.dtype, self.param_dtype, + self.dtype) x = jnp.asarray(x, dtype) reduction_axes = (-1,) feature_axes = (-1,) @@ -351,8 +352,8 @@ def __call__(self, x): Returns: Normalized inputs (the same shape as inputs). """ - param_dtype, dtype = _canonicalize_dtypes(x.dtype, self.param_dtype, - self.dtype) + param_dtype, dtype = canonicalize_inexact_dtypes(x.dtype, self.param_dtype, + self.dtype) x = jnp.asarray(x, dtype) reduction_axes = tuple(range(1, x.ndim - 1)) + (-1,) feature_axes = (-1,) From f451757a5043244dd56fdd02d64ecda8683d1c42 Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Mon, 7 Mar 2022 12:06:21 -0500 Subject: [PATCH 8/8] Factor out dtypes.py --- flax/linen/__init__.py | 30 ++++++++++--------- flax/linen/activation.py | 39 ++++++------------------- flax/linen/attention.py | 22 +++++--------- flax/linen/dtypes.py | 57 +++++++++++++++++++++++++++++++++++++ flax/linen/linear.py | 49 ++++--------------------------- flax/linen/module.py | 27 ++++++++---------- flax/linen/normalization.py | 17 ++++------- flax/linen/recurrent.py | 20 ++++++------- 8 files changed, 118 insertions(+), 143 deletions(-) create mode 100644 flax/linen/dtypes.py diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 3135d4d8e..c95d7c43b 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -17,23 +17,25 @@ # pylint: disable=g-multiple-import # re-export commonly used modules and functions -from .activation import (celu, elu, gelu, glu, leaky_relu, log_sigmoid, - log_softmax, relu, sigmoid, soft_sign, softmax, - softplus, swish, silu, tanh, PReLU) +from .activation import (PReLU, celu, elu, gelu, glu, leaky_relu, log_sigmoid, + log_softmax, relu, sigmoid, silu, soft_sign, softmax, + softplus, swish, tanh) from .attention import (MultiHeadDotProductAttention, SelfAttention, - dot_product_attention, dot_product_attention_weights, - make_attention_mask, make_causal_mask, combine_masks) -from ..core import broadcast, DenyList, FrozenDict -from .linear import (Conv, ConvLocal, ConvTranspose, Dense, DenseGeneral, Embed - canonicalize_inexact_dtypes, canonicalize_numeric_dtypes) -from .module import (Module, compact, nowrap, enable_named_call, - disable_named_call, override_named_call, Variable, init, - init_with_output, apply, merge_param) + combine_masks, dot_product_attention, + dot_product_attention_weights, make_attention_mask, + make_causal_mask) +from ..core import DenyList, FrozenDict, broadcast +from .dtypes import canonicalize_inexact_dtypes, canonicalize_numeric_dtypes +from .initializers import ones, zeros +from .linear import Conv, ConvLocal, ConvTranspose, Dense, DenseGeneral, Embed +from .module import (Module, Variable, apply, compact, disable_named_call, + enable_named_call, init, init_with_output, merge_param, + nowrap, override_named_call) from .normalization import BatchNorm, GroupNorm, LayerNorm from .pooling import avg_pool, max_pool, pool -from .recurrent import GRUCell, LSTMCell, ConvLSTM, OptimizedLSTMCell +from .recurrent import ConvLSTM, GRUCell, LSTMCell, OptimizedLSTMCell from .stochastic 
import Dropout -from .transforms import jit, named_call, checkpoint, remat, remat_scan, scan, vmap, map_variables, vjp, jvp, custom_vjp -from .initializers import zeros, ones +from .transforms import (checkpoint, custom_vjp, jit, jvp, map_variables, + named_call, remat, remat_scan, scan, vjp, vmap) # pylint: enable=g-multiple-import diff --git a/flax/linen/activation.py b/flax/linen/activation.py index 56ebaed07..f640b08c8 100644 --- a/flax/linen/activation.py +++ b/flax/linen/activation.py @@ -15,41 +15,18 @@ """Activation functions. """ +import jax.numpy as jnp # pylint: disable=unused-import -# re-export activation functions from jax.nn -from jax.nn import celu -from jax.nn import elu -from jax.nn import gelu -from jax.nn import glu -from jax.nn import leaky_relu -from jax.nn import log_sigmoid -from jax.nn import log_softmax -from jax.nn import normalize -from jax.nn import relu -from jax.nn import sigmoid -from jax.nn import soft_sign -from jax.nn import softmax -from jax.nn import softplus -from jax.nn import swish -from jax.nn import silu -from jax.nn import selu -from jax.nn import hard_tanh -from jax.nn import relu6 -from jax.nn import hard_sigmoid -from jax.nn import hard_swish - +# re-export activation functions from jax.nn and jax.numpy +from jax.nn import (celu, elu, gelu, glu, hard_sigmoid, hard_swish, hard_tanh, + leaky_relu, log_sigmoid, log_softmax, normalize, relu, + relu6, selu, sigmoid, silu, soft_sign, softmax, softplus, + swish) from jax.numpy import tanh # pylint: enable=unused-import -from typing import Any - -from flax.linen.linear import canonicalize_inexact_dtypes -from flax.linen.module import Module, compact -import jax.numpy as jnp - - -FloatingDType = Type[jnp.floating] -Array = Any +from .dtypes import Array, FloatingDType, canonicalize_inexact_dtypes +from .module import Module, compact class PReLU(Module): diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 2e244c487..5b90ad1e3 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -15,26 +15,18 @@ """Attention core modules for Flax.""" from functools import partial -from typing import Any, Callable, Tuple, Type, Optional +from typing import Callable, Optional import jax -from jax import lax -from jax import random import jax.numpy as jnp -import numpy as np +from jax import lax, random from typing_extensions import Protocol -from flax.linen.initializers import zeros -from flax.linen.linear import DenseGeneral -from flax.linen.linear import canonicalize_inexact_dtypes -from flax.linen.linear import default_kernel_init -from flax.linen.module import Module, compact, merge_param - -PRNGKey = Any -Shape = Tuple[int, ...] -InexactDType = Type[jnp.inexact] -Array = Any -Initializer = Callable[[PRNGKey, Shape, InexactDType], Array] +from .dtypes import (Array, InexactDType, Initializer, PRNGKey, + canonicalize_inexact_dtypes) +from .initializers import zeros +from .linear import DenseGeneral, default_kernel_init +from .module import Module, compact, merge_param class AttentionFunction(Protocol): diff --git a/flax/linen/dtypes.py b/flax/linen/dtypes.py new file mode 100644 index 000000000..4e4e41ca2 --- /dev/null +++ b/flax/linen/dtypes.py @@ -0,0 +1,57 @@ +# Copyright 2022 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for working with dtypes.""" + +from typing import Any, Callable, Optional, Tuple, Type + +import jax.numpy as jnp +import numpy as np + + +Array = Any # pylint: disable=invalid-name +PRNGKey = Any # pylint: disable=invalid-name +Shape = Tuple[int, ...] +FloatingDType = Type[jnp.floating] +GenericDType = Type[np.generic] +InexactDType = Type[jnp.inexact] +NumericDType = Type[jnp.number] +Initializer = Callable[[PRNGKey, Shape, InexactDType], Array] + + +def canonicalize_inexact_dtypes( + input_dtype: InexactDType, + param_dtype: Optional[InexactDType], + computation_dtype: Optional[InexactDType]) -> Tuple[InexactDType, + InexactDType]: + returned_param_dtype = input_dtype if param_dtype is None else param_dtype + dtype = (jnp.result_type(input_dtype, returned_param_dtype) + if computation_dtype is None else computation_dtype) + + assert jnp.issubdtype(input_dtype, jnp.inexact) + return returned_param_dtype, dtype + + +def canonicalize_numeric_dtypes( + input_dtype: NumericDType, + param_dtype: Optional[NumericDType], + computation_dtype: Optional[NumericDType]) -> Tuple[NumericDType, + NumericDType]: + returned_param_dtype = input_dtype if param_dtype is None else param_dtype + dtype = (jnp.result_type(input_dtype, returned_param_dtype) + if computation_dtype is None else computation_dtype) + + assert jnp.issubdtype(input_dtype, jnp.number) + return returned_param_dtype, dtype + diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 4f1ed41f9..ded3b5196 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -16,28 +16,15 @@ import abc from dataclasses import field +from typing import Iterable, List, Optional, Sequence, Tuple, Union -from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple, - Type, Union) - -from flax.linen.module import Module, compact -from flax.linen.initializers import lecun_normal, variance_scaling, zeros - -from jax import lax -from jax import eval_shape -from jax import ShapedArray import jax.numpy as jnp import numpy as np +from jax import ShapedArray, eval_shape, lax - -PRNGKey = Any -Shape = Tuple[int, ...] 
-InexactDType = Type[jnp.inexact] -NumericDType = Type[jnp.number] -GenericDType = Type[np.generic] -Array = Any -Initializer = Callable[[PRNGKey, Shape, InexactDType], Array] - +from .dtypes import Array, GenericDType, InexactDType, Initializer, NumericDType +from .initializers import lecun_normal, variance_scaling, zeros +from .module import Module, compact default_kernel_init = lecun_normal() @@ -54,32 +41,6 @@ def _canonicalize_tuple(x: Union[Sequence[int], int]) -> Tuple[int, ...]: return (x,) -def canonicalize_inexact_dtypes( - input_dtype: InexactDType, - param_dtype: Optional[InexactDType], - computation_dtype: Optional[InexactDType]) -> Tuple[InexactDType, - InexactDType]: - returned_param_dtype = input_dtype if param_dtype is None else param_dtype - dtype = (jnp.result_type(input_dtype, returned_param_dtype) - if computation_dtype is None else computation_dtype) - - assert jnp.issubdtype(input_dtype, jnp.inexact) - return returned_param_dtype, dtype - - -def canonicalize_numeric_dtypes( - input_dtype: NumericDType, - param_dtype: Optional[NumericDType], - computation_dtype: Optional[NumericDType]) -> Tuple[NumericDType, - NumericDType]: - returned_param_dtype = input_dtype if param_dtype is None else param_dtype - dtype = (jnp.result_type(input_dtype, returned_param_dtype) - if computation_dtype is None else computation_dtype) - - assert jnp.issubdtype(input_dtype, jnp.number) - return returned_param_dtype, dtype - - class DenseGeneral(Module): """A linear transformation with flexible axes. diff --git a/flax/linen/module.py b/flax/linen/module.py index 1aae88c5a..86e250ef9 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -13,7 +13,6 @@ # limitations under the License. """Flax Modules.""" -from contextlib import contextmanager import dataclasses import enum import functools @@ -23,33 +22,30 @@ import types import typing import weakref - -from typing import (Any, Callable, Sequence, Iterable, List, Optional, Tuple, - Set, Type, Union, TypeVar, Generic, Dict, overload) +from contextlib import contextmanager +from typing import (Any, Callable, Dict, Generic, Iterable, List, Optional, + Sequence, Set, Tuple, Type, TypeVar, Union, overload) import jax +import numpy as np from jax import tree_util from jax._src.numpy.lax_numpy import isin -import numpy as np import flax -from flax import config -from flax import errors -from flax import traceback_util -from flax import traverse_util -from flax import serialization -from flax import core +from flax import (config, core, errors, serialization, traceback_util, + traverse_util) from flax.core import Scope -from flax.core.scope import CollectionFilter, DenyList, Variable, VariableDict, FrozenVariableDict, union_filters from flax.core.frozen_dict import FrozenDict, freeze +from flax.core.scope import (CollectionFilter, DenyList, FrozenVariableDict, + Variable, VariableDict, union_filters) from flax.struct import __dataclass_transform__ +from .dtypes import Array, PRNGKey + # from .dotgetter import DotGetter traceback_util.register_exclusion(__file__) -PRNGKey = Any # pylint: disable=invalid-name RNGSequences = Dict[str, PRNGKey] -Array = Any # pylint: disable=invalid-name T = TypeVar('T') @@ -603,7 +599,8 @@ def _wrap_module_methods(cls): wrapped_method = wrap_method_once(method) if key != 'setup': # We import named_call at runtime to avoid a circular import issue. 
- from flax.linen.transforms import named_call # pylint: disable=g-import-not-at-top + from flax.linen.transforms import \ + named_call # pylint: disable=g-import-not-at-top wrapped_method = named_call(wrapped_method, force=False) setattr(cls, key, wrapped_method) return cls diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 835356d4a..f22ef5010 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -14,22 +14,15 @@ """Normalization modules for Flax.""" -from typing import Any, Callable, Optional, Tuple, Type, Iterable, Union +from typing import Any, Iterable, Optional, Tuple, Union +import jax.numpy as jnp from jax import lax from jax.nn import initializers -import jax.numpy as jnp -import numpy as np - -from flax.linen.module import Module, compact, merge_param -from flax.linen.linear import canonicalize_inexact_dtypes - -PRNGKey = Any -Array = Any -Shape = Tuple[int, ...] -InexactDType = Type[jnp.inexact] -Initializer = Callable[[PRNGKey, Shape, InexactDType], Array] +from .dtypes import (Array, InexactDType, Initializer, + canonicalize_inexact_dtypes) +from .module import Module, compact, merge_param Axes = Union[int, Tuple[int, ...]] diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index a33a6c619..eb91313c3 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -25,24 +25,20 @@ import abc from functools import partial -from typing import (Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple, - Type, Union) +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union -from flax.linen.module import Module, compact -from flax.linen.activation import sigmoid, tanh -from flax.linen.initializers import orthogonal, zeros -from flax.linen.linear import Conv, Dense, default_kernel_init - -from jax import numpy as jnp +import numpy as np from jax import lax +from jax import numpy as jnp from jax import random -import numpy as np +from .activation import sigmoid, tanh +from .dtypes import Array, PRNGKey, Shape +from .initializers import orthogonal, zeros +from .linear import Conv, Dense, default_kernel_init +from .module import Module, compact -PRNGKey = Any -Shape = Tuple[int, ...] Dtype = Any -Array = Any class RNNCellBase(Module):
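
Usage sketch (illustrative, not part of the patches above): assuming this
series is applied, the dtype helpers factored out into flax/linen/dtypes.py and
re-exported from flax.linen resolve dtypes as follows: the parameter dtype
falls back to the input dtype, and the computation dtype falls back to
jnp.result_type(input dtype, parameter dtype), while explicitly requested
dtypes are used as given. The half-precision input below is hypothetical.

    # Minimal sketch; assumes the patches in this series are applied so that
    # canonicalize_inexact_dtypes is exported from flax.linen.
    import jax.numpy as jnp
    from flax import linen as nn

    x = jnp.ones((4, 3), jnp.float16)  # hypothetical half-precision input

    # Passing None for both dtypes infers everything from the input.
    param_dtype, dtype = nn.canonicalize_inexact_dtypes(x.dtype, None, None)
    assert param_dtype == jnp.float16 and dtype == jnp.float16

    # Mixed precision: float32 parameters promote the computation dtype via
    # jnp.result_type(float16, float32) -> float32.
    param_dtype, dtype = nn.canonicalize_inexact_dtypes(x.dtype, jnp.float32, None)
    assert param_dtype == jnp.float32 and dtype == jnp.float32

    # An explicitly requested computation dtype is used as-is.
    _, dtype = nn.canonicalize_inexact_dtypes(x.dtype, jnp.float32, jnp.bfloat16)
    assert dtype == jnp.bfloat16

Passing None for both attributes therefore gives full dtype inference from the
inputs, while an explicit param_dtype supports mixed-precision setups in which
parameters stay in float32.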