From 9a9539b8cf91ec443da5776b575a4fd1f1b68e32 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 2 Sep 2022 10:23:33 -0500 Subject: [PATCH] add attribute docs to Conv and ConvLocal --- flax/linen/linear.py | 74 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 2 deletions(-) diff --git a/flax/linen/linear.py b/flax/linen/linear.py index b5872f64a..ca86465ca 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -467,7 +467,42 @@ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> ( class Conv(_Conv): - """Convolution Module wrapping `lax.conv_general_dilated`.""" + """Convolution Module wrapping `lax.conv_general_dilated`. + + Attributes: + features: number of convolution filters. + kernel_size: shape of the convolutional kernel. For 1D convolution, + the kernel size can be passed as an integer. For all other cases, it must + be a sequence of integers. + strides: an integer or a sequence of `n` integers, representing the + inter-window strides (default: 1). + padding: either the string `'SAME'`, the string `'VALID'`, the string + `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. A single int is interpeted as applying the same padding + in all dims and passign a single int in a sequence causes the same padding + to be used on both sides. `'CAUSAL'` padding for a 1D convolution will + left-pad the convolution axis, resulting in same-sized output. + input_dilation: an integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + kernel_dilation: an integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + feature_group_count: integer, default 1. If specified divides the input + features into groups. + use_bias: whether to add a bias to the output (default: True). + mask: Optional mask for the weights during masked convolution. The mask must + be the same shape as the convolution weight matrix. + dtype: the dtype of the computation (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the convolutional kernel. + bias_init: initializer for the bias. + """ @property def shared_weights(self) -> bool: @@ -475,7 +510,42 @@ def shared_weights(self) -> bool: class ConvLocal(_Conv): - """Local convolution Module wrapping `lax.conv_general_dilated_local`.""" + """Local convolution Module wrapping `lax.conv_general_dilated_local`. + + Attributes: + features: number of convolution filters. + kernel_size: shape of the convolutional kernel. For 1D convolution, + the kernel size can be passed as an integer. For all other cases, it must + be a sequence of integers. + strides: an integer or a sequence of `n` integers, representing the + inter-window strides (default: 1). + padding: either the string `'SAME'`, the string `'VALID'`, the string + `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. A single int is interpeted as applying the same padding + in all dims and passign a single int in a sequence causes the same padding + to be used on both sides. `'CAUSAL'` padding for a 1D convolution will + left-pad the convolution axis, resulting in same-sized output. + input_dilation: an integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + kernel_dilation: an integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + feature_group_count: integer, default 1. If specified divides the input + features into groups. + use_bias: whether to add a bias to the output (default: True). + mask: Optional mask for the weights during masked convolution. The mask must + be the same shape as the convolution weight matrix. + dtype: the dtype of the computation (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the convolutional kernel. + bias_init: initializer for the bias. + """ @property def shared_weights(self) -> bool: