Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions flax/linen/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,15 +467,85 @@ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> (


class Conv(_Conv):
"""Convolution Module wrapping `lax.conv_general_dilated`."""
"""Convolution Module wrapping `lax.conv_general_dilated`.

Attributes:
features: number of convolution filters.
kernel_size: shape of the convolutional kernel. For 1D convolution,
the kernel size can be passed as an integer. For all other cases, it must
be a sequence of integers.
strides: an integer or a sequence of `n` integers, representing the
inter-window strides (default: 1).
padding: either the string `'SAME'`, the string `'VALID'`, the string
`'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
high)` integer pairs that give the padding to apply before and after each
spatial dimension. A single int is interpeted as applying the same padding
in all dims and passign a single int in a sequence causes the same padding
to be used on both sides. `'CAUSAL'` padding for a 1D convolution will
left-pad the convolution axis, resulting in same-sized output.
input_dilation: an integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `inputs`
(default: 1). Convolution with input dilation `d` is equivalent to
transposed convolution with stride `d`.
kernel_dilation: an integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of the convolution
kernel (default: 1). Convolution with kernel dilation
is also known as 'atrous convolution'.
feature_group_count: integer, default 1. If specified divides the input
features into groups.
use_bias: whether to add a bias to the output (default: True).
mask: Optional mask for the weights during masked convolution. The mask must
be the same shape as the convolution weight matrix.
dtype: the dtype of the computation (default: infer from input and params).
param_dtype: the dtype passed to parameter initializers (default: float32).
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer for the convolutional kernel.
bias_init: initializer for the bias.
"""

@property
def shared_weights(self) -> bool:
return True


class ConvLocal(_Conv):
"""Local convolution Module wrapping `lax.conv_general_dilated_local`."""
"""Local convolution Module wrapping `lax.conv_general_dilated_local`.

Attributes:
features: number of convolution filters.
kernel_size: shape of the convolutional kernel. For 1D convolution,
the kernel size can be passed as an integer. For all other cases, it must
be a sequence of integers.
strides: an integer or a sequence of `n` integers, representing the
inter-window strides (default: 1).
padding: either the string `'SAME'`, the string `'VALID'`, the string
`'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
high)` integer pairs that give the padding to apply before and after each
spatial dimension. A single int is interpeted as applying the same padding
in all dims and passign a single int in a sequence causes the same padding
to be used on both sides. `'CAUSAL'` padding for a 1D convolution will
left-pad the convolution axis, resulting in same-sized output.
input_dilation: an integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `inputs`
(default: 1). Convolution with input dilation `d` is equivalent to
transposed convolution with stride `d`.
kernel_dilation: an integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of the convolution
kernel (default: 1). Convolution with kernel dilation
is also known as 'atrous convolution'.
feature_group_count: integer, default 1. If specified divides the input
features into groups.
use_bias: whether to add a bias to the output (default: True).
mask: Optional mask for the weights during masked convolution. The mask must
be the same shape as the convolution weight matrix.
dtype: the dtype of the computation (default: infer from input and params).
param_dtype: the dtype passed to parameter initializers (default: float32).
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer for the convolutional kernel.
bias_init: initializer for the bias.
"""

@property
def shared_weights(self) -> bool:
Expand Down