Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ vNext
- Added Optax update guide and deprecated `flax.optim`.
- Added `sep` argument to `flax.traverse_util.flatten_dict()`.
-
- Remove float32 dtype assumption throughout linen (except for recurrent
modules).
-
- Added locally-connected (unshared CNN) layer `flax.linen.ConvLocal`.
-
Expand Down Expand Up @@ -48,7 +50,7 @@ Breaking changes:
- You can no longer pass an int as the `kernel_size` for a `flax.linen.Conv`.
Instead a type error is raised stating that
a tuple/list should be provided. Stride and dilation arguments do support broadcasting a single int value now because this is not
ambiguous when the kernel rank is known.
ambiguous when the kernel rank is known.
- `flax.linen.enable_named_call` and `flax.linen.disable_named_call` now work anywhere instead of only affecting Modules constructed after the enable/disable call. Additionally, there is now `flax.linen.override_named_call` that provides a context manager to locally disable/enable named_call.
- NamedTuples are no longer converted to tuples on assignment to a `linen.Module`.

Expand All @@ -64,15 +66,15 @@ Bugfixes:
- Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated ([bug](https://github.com/google/flax/issues/1429)).
- Mixed precision training with float16 now works correctly with the attention layers.
- auto-generated linen Module `__hash__`, `__eq__`, `__repr__` no longer fail by default on non-init attributes.



0.3.4
------

Possibly breaking changes:
- When calling `init` the 'intermediates' collection is no longer mutable.
Therefore, intermediates will no longer be returned from initialization by default.
Therefore, intermediates will no longer be returned from initialization by default.
- Don't update batch statistics during initialization.
- When not using any non-determinism (e.g., dropout), it is no longer necessary to specify the `deterministic` argument in `MultiHeadDotProductAttention`.

Expand Down Expand Up @@ -105,9 +107,9 @@ Possible breaking changes:
latest checkpoint already saved.
- MultiOptimizer now rejects the case where multiple sub optimizers update the
same parameter.

Other changes:
- Added custom error classes to many Linen errors. See:
- Added custom error classes to many Linen errors. See:
https://flax.readthedocs.io/en/latest/flax.errors.html
- Adds `Module.bind` for binding variables and RNGs to an interactive Module.
- Adds `nn.apply` and `nn.init` for transforming arbitrary functions that take a `linen.Module` as their first argument.
Expand All @@ -127,7 +129,7 @@ NOTE: You must now explicitly import `flax.nn` if you want to use the old
0.3.1
------

Many improvements to Linen, and the old `flax.nn` is officially deprecated!
Many improvements to Linen, and the old `flax.nn` is officially deprecated!

Notably, there's a clean API for extracting intermediates from modules
defined using `@nn.compact`, a more ergonomic API for using Batch Norm and Dropout in modules
Expand All @@ -141,7 +143,7 @@ Possible breaking changes:
is enforced by raising a TypeError in `__setattr__` after `setup`.
- Pytrees of dicts and lists are transformed into FrozenDict and tuples during
attribute assignment.
This avoids undetected submodules and inner state.
This avoids undetected submodules and inner state.
- Bug Fix `flax.core.apply` and `Module.apply`. Now it returns a tuple
containing the output and a frozen empty
collection when `mutable` is specified as an empty list.
Expand Down
27 changes: 15 additions & 12 deletions flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,25 @@

# pylint: disable=g-multiple-import
# re-export commonly used modules and functions
from .activation import (celu, elu, gelu, glu, leaky_relu, log_sigmoid,
log_softmax, relu, sigmoid, soft_sign, softmax,
softplus, swish, silu, tanh, PReLU)
from .activation import (PReLU, celu, elu, gelu, glu, leaky_relu, log_sigmoid,
log_softmax, relu, sigmoid, silu, soft_sign, softmax,
softplus, swish, tanh)
from .attention import (MultiHeadDotProductAttention, SelfAttention,
dot_product_attention, dot_product_attention_weights,
make_attention_mask, make_causal_mask, combine_masks)
from ..core import broadcast, DenyList, FrozenDict
combine_masks, dot_product_attention,
dot_product_attention_weights, make_attention_mask,
make_causal_mask)
from ..core import DenyList, FrozenDict, broadcast
from .dtypes import canonicalize_inexact_dtypes, canonicalize_numeric_dtypes
from .initializers import ones, zeros
from .linear import Conv, ConvLocal, ConvTranspose, Dense, DenseGeneral, Embed
from .module import (Module, compact, nowrap, enable_named_call,
disable_named_call, override_named_call, Variable, init,
init_with_output, apply, merge_param)
from .module import (Module, Variable, apply, compact, disable_named_call,
enable_named_call, init, init_with_output, merge_param,
nowrap, override_named_call)
from .normalization import BatchNorm, GroupNorm, LayerNorm
from .pooling import avg_pool, max_pool, pool
from .recurrent import GRUCell, LSTMCell, ConvLSTM, OptimizedLSTMCell
from .recurrent import ConvLSTM, GRUCell, LSTMCell, OptimizedLSTMCell
from .stochastic import Dropout
from .transforms import jit, named_call, checkpoint, remat, remat_scan, scan, vmap, map_variables, vjp, jvp, custom_vjp
from .initializers import zeros, ones
from .transforms import (checkpoint, custom_vjp, jit, jvp, map_variables,
named_call, remat, remat_scan, scan, vjp, vmap)

# pylint: enable=g-multiple-import
54 changes: 22 additions & 32 deletions flax/linen/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,48 +15,33 @@
"""Activation functions.
"""

import jax.numpy as jnp
# pylint: disable=unused-import
# re-export activation functions from jax.nn
from jax.nn import celu
from jax.nn import elu
from jax.nn import gelu
from jax.nn import glu
from jax.nn import leaky_relu
from jax.nn import log_sigmoid
from jax.nn import log_softmax
from jax.nn import normalize
from jax.nn import relu
from jax.nn import sigmoid
from jax.nn import soft_sign
from jax.nn import softmax
from jax.nn import softplus
from jax.nn import swish
from jax.nn import silu
from jax.nn import selu
from jax.nn import hard_tanh
from jax.nn import relu6
from jax.nn import hard_sigmoid
from jax.nn import hard_swish

# re-export activation functions from jax.nn and jax.numpy
from jax.nn import (celu, elu, gelu, glu, hard_sigmoid, hard_swish, hard_tanh,
leaky_relu, log_sigmoid, log_softmax, normalize, relu,
relu6, selu, sigmoid, silu, soft_sign, softmax, softplus,
swish)
from jax.numpy import tanh
# pylint: enable=unused-import

from typing import Any, Optional

from flax.linen.module import Module, compact
import jax.numpy as jnp


Array = Any
from .dtypes import Array, FloatingDType, canonicalize_inexact_dtypes
from .module import Module, compact


class PReLU(Module):
  """Parametric Rectified Linear Unit (PReLU) activation function.

  Non-negative inputs pass through unchanged; negative inputs are multiplied
  by a scalar slope stored as the module parameter ``negative_slope``.

  Attributes:
    dtype: the dtype of the computation (default: float32).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    negative_slope_init: the value to initialize the negative slope
      (default 0.01).
  """
  dtype: Optional[FloatingDType] = jnp.float32
  param_dtype: Optional[FloatingDType] = jnp.float32
  negative_slope_init: float = 0.01

  @compact
  def __call__(self, inputs: Array) -> Array:
    """Applies an activation to the inputs.

    Args:
      inputs: the array to apply the activation to; its dtype must be a
        floating dtype.

    Returns:
      The transformed input.
    """
    assert jnp.issubdtype(inputs.dtype, jnp.floating)
    # Resolve the dtypes before anything uses them: `param_dtype` and `dtype`
    # are locals of this method, so they have to be bound from the module
    # attributes first (reading them earlier raises UnboundLocalError).
    param_dtype, dtype = canonicalize_inexact_dtypes(inputs.dtype,
                                                     self.param_dtype,
                                                     self.dtype)
    inputs = jnp.asarray(inputs, dtype)
    negative_slope = self.param(
        'negative_slope',
        lambda k: jnp.asarray(self.negative_slope_init, param_dtype)
    )
    # The parameter is created in `param_dtype`; cast it to the computation
    # dtype so the multiplication below happens in `dtype`.
    negative_slope = jnp.asarray(negative_slope, dtype)
    return jnp.where(inputs >= 0, inputs, negative_slope * inputs)
86 changes: 51 additions & 35 deletions flax/linen/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,34 @@
"""Attention core modules for Flax."""

from functools import partial
from typing import (Any, Callable, Tuple, Optional)
from typing import Callable, Optional

import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import numpy as np

from flax.linen.linear import default_kernel_init
from flax.linen.linear import DenseGeneral
from flax.linen.module import Module, compact, merge_param
from flax.linen.initializers import zeros

PRNGKey = Any
Shape = Tuple[int]
Dtype = Any
Array = Any
from jax import lax, random
from typing_extensions import Protocol

from .dtypes import (Array, InexactDType, Initializer, PRNGKey,
canonicalize_inexact_dtypes)
from .initializers import zeros
from .linear import DenseGeneral, default_kernel_init
from .module import Module, compact, merge_param


class AttentionFunction(Protocol):
  """Structural type describing the call signature of an attention function.

  Used as the annotation of `MultiHeadDotProductAttention.attention_fn`, whose
  default is `dot_product_attention`. `__call__` is declared as a
  staticmethod without a `self` parameter (presumably so that plain
  module-level functions match the protocol -- confirm with a type checker).

  The meaning of each argument is defined by the concrete attention function
  supplied; this class only fixes the parameter names, order, defaults and
  the `Array` return annotation.
  """

  @staticmethod
  def __call__(query: Array,
               key: Array,
               value: Array,
               bias: Optional[Array] = None,
               mask: Optional[Array] = None,
               broadcast_dropout: bool = True,
               dropout_rng: Optional[PRNGKey] = None,
               dropout_rate: float = 0.,
               deterministic: bool = False,
               dtype: InexactDType = jnp.float32,
               precision: Optional[lax.Precision] = None) -> Array:
    ...


def dot_product_attention_weights(query: Array,
Expand All @@ -42,7 +53,7 @@ def dot_product_attention_weights(query: Array,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.,
deterministic: bool = False,
dtype: Dtype = jnp.float32,
dtype: InexactDType = jnp.float32,
precision: Optional[lax.Precision] = None):
"""Computes dot-product attention weights given query and key.

Expand Down Expand Up @@ -109,7 +120,7 @@ def dot_product_attention_weights(query: Array,
keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
else:
keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
multiplier = (keep.astype(attn_weights.dtype) /
multiplier = (keep.astype(dtype) /
jnp.asarray(keep_prob, dtype=dtype))
attn_weights = attn_weights * multiplier

Expand All @@ -125,8 +136,8 @@ def dot_product_attention(query: Array,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.,
deterministic: bool = False,
dtype: Dtype = jnp.float32,
precision: Optional[lax.Precision] = None):
dtype: InexactDType = jnp.float32,
precision: Optional[lax.Precision] = None) -> Array:
"""Computes dot-product attention given query, key, and value.

This is the core function for applying attention based on
Expand Down Expand Up @@ -205,18 +216,18 @@ class MultiHeadDotProductAttention(Module):
decode: whether to prepare and use an autoregressive cache.
"""
num_heads: int
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
dtype: Optional[InexactDType] = None
param_dtype: Optional[InexactDType] = None
qkv_features: Optional[int] = None
out_features: Optional[int] = None
broadcast_dropout: bool = True
dropout_rate: float = 0.
deterministic: Optional[bool] = None
precision: Optional[lax.Precision] = None
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
kernel_init: Initializer = default_kernel_init
bias_init: Initializer = zeros
use_bias: bool = True
attention_fn: Callable[[Array, Array, Array], Array] = dot_product_attention
attention_fn: AttentionFunction = dot_product_attention
decode: bool = False

@compact
Expand Down Expand Up @@ -246,6 +257,10 @@ def __call__(self,
Returns:
output of shape `[batch_sizes..., length, features]`.
"""
param_dtype, dtype = canonicalize_inexact_dtypes(jnp.result_type(inputs_q,
inputs_kv),
self.param_dtype,
self.dtype)
features = self.out_features or inputs_q.shape[-1]
qkv_features = self.qkv_features or inputs_q.shape[-1]
assert qkv_features % self.num_heads == 0, (
Expand All @@ -254,8 +269,8 @@ def __call__(self,

dense = partial(DenseGeneral,
axis=-1,
dtype=self.dtype,
param_dtype=self.param_dtype,
dtype=dtype,
param_dtype=param_dtype,
features=(self.num_heads, head_dim),
kernel_init=self.kernel_init,
bias_init=self.bias_init,
Expand Down Expand Up @@ -323,16 +338,16 @@ def __call__(self,
dropout_rate=self.dropout_rate,
broadcast_dropout=self.broadcast_dropout,
deterministic=m_deterministic,
dtype=self.dtype,
dtype=dtype,
precision=self.precision) # pytype: disable=wrong-keyword-args
# back to the original inputs dimensions
out = DenseGeneral(features=features,
axis=(-2, -1),
kernel_init=self.kernel_init,
bias_init=self.bias_init,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
dtype=dtype,
param_dtype=param_dtype,
precision=self.precision,
name='out')(x)
return out
Expand All @@ -350,11 +365,12 @@ def __call__(self, inputs_q: Array, mask: Optional[Array] = None,
# mask-making utility functions


def make_attention_mask(query_input: Array,
key_input: Array,
pairwise_fn: Callable[..., Any] = jnp.multiply,
extra_batch_dims: int = 0,
dtype: Dtype = jnp.float32):
def make_attention_mask(
query_input: Array,
key_input: Array,
pairwise_fn: Callable[[Array, Array], Array] = jnp.multiply,
extra_batch_dims: int = 0,
dtype: InexactDType = jnp.float32):
"""Mask-making helper for attention weights.

In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the
Expand All @@ -381,7 +397,7 @@ def make_attention_mask(query_input: Array,

def make_causal_mask(x: Array,
extra_batch_dims: int = 0,
dtype: Dtype = jnp.float32) -> Array:
dtype: InexactDType = jnp.float32) -> Array:
"""Make a causal mask for self-attention.

In case of 1d inputs (i.e., `[batch..., len]`, the self-attention weights
Expand All @@ -403,7 +419,7 @@ def make_causal_mask(x: Array,


def combine_masks(*masks: Optional[Array],
dtype: Dtype = jnp.float32) -> Array:
dtype: InexactDType = jnp.float32) -> Array:
"""Combine attention masks.

Args:
Expand Down
Loading