diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py index 1470f35e..f79ca42f 100644 --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -87,7 +87,7 @@ from jax.tree_util import tree_map, tree_flatten, tree_unflatten from neural_tangents.utils import utils from neural_tangents.utils.kernel import Kernel -from neural_tangents.utils.typing import InitFn, AnalyticKernelFn, LayerKernelFn, InternalLayer, Layer, Kernels, Shapes, Axes, Get +from neural_tangents.utils.typing import InitFn, AnalyticKernelFn, LayerKernelFn, InternalLayer, Layer, Kernels, Shapes, Axes, PyTree, Get import scipy as osp @@ -298,8 +298,8 @@ def serial(*layers: Layer) -> InternalLayer: Based on `jax.experimental.stax.serial`. Args: - *layers: a sequence of layers, each an `(init_fn, apply_fn, kernel_fn)` - triple. + *layers: + a sequence of layers, each an `(init_fn, apply_fn, kernel_fn)` triple. Returns: A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triple, @@ -325,8 +325,8 @@ def parallel(*layers: Layer) -> InternalLayer: `FanInSum`/`FanInConcat` layers. Based on `jax.experimental.stax.parallel`. Args: - *layers: a sequence of layers, each with a `(init_fn, apply_fn, kernel_fn)` - triple. + *layers: + a sequence of layers, each with a `(init_fn, apply_fn, kernel_fn)` triple. Returns: A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triples, @@ -361,30 +361,35 @@ def Dense( Based on `jax.experimental.stax.Dense`. Args: - out_dim: The output feature / channel dimension. This is ignored in by the - `kernel_fn` in NTK parameterization. + out_dim: + The output feature / channel dimension. This is ignored in by the + `kernel_fn` in `"ntk"` parameterization. - W_std: Specifies the standard deviation of the weights. + W_std: + Specifies the standard deviation of the weights. - b_std: Specifies the standard deviation of the biases. + b_std: + Specifies the standard deviation of the biases. - parameterization: Either `"ntk"` or `"standard"`. + parameterization: + Either `"ntk"` or `"standard"`. - Under ntk parameterization (https://arxiv.org/abs/1806.07572, page 3), + Under `"ntk"` parameterization (https://arxiv.org/abs/1806.07572, page 3), weights and biases are initialized as :math:`W_{ij} \sim \mathcal{N}(0,1)`, :math:`b_i \sim \mathcal{N}(0,1)`, and the finite width layer equation is :math:`z_i = \sigma_W / \sqrt{N} \sum_j W_{ij} x_j + \sigma_b b_i`. - Under standard parameterization (https://arxiv.org/abs/2001.07301), + Under `"standard"` parameterization (https://arxiv.org/abs/2001.07301), weights and biases are initialized as :math:`W_{ij} \sim \mathcal{N}(0, W_{std}^2/N)`, :math:`b_i \sim \mathcal{N}(0,\sigma_b^2)`, and the finite width layer equation is :math:`z_i = \sum_j W_{ij} x_j + b_i`. - batch_axis: Specifies which axis is contains different elements of the - batch. Defaults to `0`, the leading axis. + batch_axis: + Specifies which axis is contains different elements of the batch. + Defaults to `0`, the leading axis. channel_axis: Specifies which axis contains the features / channels. Defaults to `-1`, the trailing axis. 
For `kernel_fn`, channel size is @@ -440,7 +445,7 @@ def apply_fn(params, inputs, **kwargs): @_requires(batch_axis=batch_axis, channel_axis=channel_axis) def kernel_fn(k: Kernel, **kwargs): - """Compute the transformed kernels after a dense layer.""" + """Compute the transformed kernels after a `Dense` layer.""" cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk def fc(x): @@ -485,22 +490,29 @@ def GeneralConv( Based on `jax.experimental.stax.GeneralConv`. Args: - dimension_numbers: Specifies which axes should be convolved over. Should - match the specification in `jax.lax.dot_general_dilated`. - out_chan: The number of output channels / features of the - convolution. This is ignored in by the `kernel_fn` in NTK - parameterization. - filter_shape: The shape of the filter. The shape of the tuple should agree - with the number of spatial dimensions in `dimension_numbers`. - strides: The stride of the convolution. The shape of the tuple should agree - with the number of spatial dimensions in `dimension_nubmers`. - padding: Specifies padding for the convolution. Can be one of `"VALID"`, - `"SAME"`, or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions. - W_std: The standard deviation of the weights. - b_std: The standard deviation of the biases. - parameterization: Either `"ntk"` or `"standard"`. These parameterizations - are the direct analogues for convolution of the corresponding - parameterizations for `Dense` layers. + dimension_numbers: + Specifies which axes should be convolved over. Should match the + specification in `jax.lax.dot_general_dilated`. + out_chan: + The number of output channels / features of the convolution. This is + ignored in by the `kernel_fn` in `"ntk"` parameterization. + filter_shape: + The shape of the filter. The shape of the tuple should agree with the + number of spatial dimensions in `dimension_numbers`. + strides: + The stride of the convolution. The shape of the tuple should agree with + the number of spatial dimensions in `dimension_nubmers`. + padding: + Specifies padding for the convolution. Can be one of `"VALID"`, `"SAME"`, + or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions. + W_std: + standard deviation of the weights. + b_std: + standard deviation of the biases. + parameterization: + Either `"ntk"` or `"standard"`. These parameterizations are the direct + analogues for convolution of the corresponding parameterizations for + `Dense` layers. Returns: `(init_fn, apply_fn, kernel_fn)`. @@ -530,20 +542,27 @@ def Conv( Based on `jax.experimental.stax.Conv`. Args: - out_chan: The number of output channels / features of the - convolution. This is ignored in by the `kernel_fn` in NTK + out_chan: + The number of output channels / features of the + convolution. This is ignored in by the `kernel_fn` in `"ntk"` parameterization. - filter_shape: The shape of the filter. The shape of the tuple should agree - with the number of spatial dimensions in `dimension_numbers`. - strides: The stride of the convolution. The shape of the tuple should agree + filter_shape: + The shape of the filter. The shape of the tuple should agree with the + number of spatial dimensions in `dimension_numbers`. + strides: + The stride of the convolution. The shape of the tuple should agree with the number of spatial dimensions in `dimension_nubmers`. - padding: Specifies padding for the convolution. Can be one of `"VALID"`, - `"SAME"`, or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions. - W_std: The standard deviation of the weights. 
- b_std: The standard deviation of the biases. - parameterization: Either `"ntk"` or `"standard"`. These parameterizations - are the direct analogues for convolution of the corresponding - parameterizations for `Dense` layers. + padding: + Specifies padding for the convolution. Can be one of `"VALID"`, `"SAME"`, + or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions. + W_std: + standard deviation of the weights. + b_std: + standard deviation of the biases. + parameterization: + Either `"ntk"` or `"standard"`. These parameterizations are the direct + analogues for convolution of the corresponding parameterizations for + `Dense` layers. Returns: `(init_fn, apply_fn, kernel_fn)`. @@ -1217,10 +1236,11 @@ def Erf(a: float = 1., c: float = 0., do_backprop: bool = False) -> InternalLayer: """Affine transform of `Erf` nonlinearity, i.e. `a Erf(b * x) + c`. + Args: - a: a float. - b: a float. - c: a float. + a: output scale. + b: input scale. + c: output shift. do_backprop: set to `True` if you want to backpropagate through the kernel. Returns: @@ -1255,12 +1275,12 @@ def Gelu(do_backprop: bool = False) -> InternalLayer: def Sin(a: float = 1., b: float = 1., c: float = 0.) -> InternalLayer: - """Affine transform of `Sin` nonlinearity, i.e. `a sin(b*x + c)` + """Affine transform of `Sin` nonlinearity, i.e. `a sin(b*x + c)`. Args: - a: a float. - b: a float. - c: a float. + a: output scale. + b: input scale. + c: input phase shift. Returns: `(init_fn, apply_fn, kernel_fn)`. """ @@ -1270,16 +1290,16 @@ def Sin(a: float = 1., @layer @_supports_masking(remask_kernel=True) def Rbf(gamma: float = 1.0) -> InternalLayer: - """Returns the dual activation function layer for normalized RBF or sqaured exponential kernel. + """Returns the dual activation function layer for normalized RBF or squared exponential kernel. Dual activation function is `f(x) = sqrt(2)*sin(sqrt(2*gamma) x + pi/4)`. - - NNGP kernel transformation correspond to (with input dimension `d`) + NNGP kernel transformation correspond to (with input dimension `d`) `k = exp(- gamma / d * ||x - x'||^2) = exp(- gamma*(q11 + q22 - 2 * q12))`. Args: - gamma: related to characteristic length-scale (l) that controls width of - the kernel, where `gamma = 1 / (2 l^2)`. + gamma: + related to characteristic length-scale (l) that controls width of the + kernel, where `gamma = 1 / (2 l^2)`. Returns: `(init_fn, apply_fn, kernel_fn)`. @@ -1419,6 +1439,29 @@ def _NumericalActivation(fn: Callable[[float], float], quad_points=quad_points, do_backprop=do_backprop) +class PositionalEmbedding(enum.Enum): + """Type of positional embeddings to use in a `GlobalSelfAttention` layer.""" + NONE = 'NONE' + SUM = 'SUM' + CONCAT = 'CONCAT' + + +class AttentionMechanism(enum.Enum): + """Type of nonlinearity to use in a `GlobalSelfAttention` layer.""" + SOFTMAX = 'SOFTMAX' + IDENTITY = 'IDENTITY' + ABS = 'ABS' + RELU = 'RELU' + + def fn(self): + return { + 'softmax': ostax.softmax, + 'identity': lambda x: x, + 'abs': np.abs, + 'relu': lambda x: np.maximum(x, 0.) 
+ }[self.name.lower()] + + @layer @_supports_masking(remask_kernel=True) def GlobalSelfAttention( @@ -1426,36 +1469,44 @@ def GlobalSelfAttention( n_chan_key: int, n_chan_val: int, n_heads: int, - fixed: bool = True, + linear_scaling: bool = True, W_key_std: float = 1.0, W_value_std: float = 1.0, W_query_std: float = 1.0, W_out_std: float = 1.0, b_std: float = 0.0, + attention_mechanism: str = AttentionMechanism.SOFTMAX.name, + pos_emb_type: str = PositionalEmbedding.NONE.name, + pos_emb_p_norm: float = 2, + pos_emb_decay_fn: Callable[[float], float] = None, + n_chan_pos_emb: int = None, + W_pos_emb_std: float = 1.0, + val_pos_emb: bool = False, batch_axis: int = 0, channel_axis: int = -1) -> InternalLayer: - """Scaled dot-product self-attention with multiple attention heads. + """Layer construction function for (global) scaled dot-product self-attention. + + Infinite width results based on https://arxiv.org/abs/2006.10540. Two versions of attention are available (the version to be used is - determined by the argument `fixed`): + determined by the argument `linear_scaling`): - 1. Parametric: this is the standard scaled dot-product attention, i.e., + 1. `False`: this is the standard scaled dot-product attention, i.e., the dot product between keys and queries is scaled by the squared root of their dimension. The expression for `nngp`/`ntk` involves an integral with no known closed form and thus call to `kernel_fn` results in an error. - 2. Fixed: same as Parametric except for scaling the dot products - between keys and queries by their dimension instead of the square root - of the same quantity, and tying the key and query weight matrices. - This makes the `nngp`/`ntk` analytically tractable but for the price - that, unlike in the parametric case, the dot products of keys and queries - converge to a constant. Because this constant would be zero - if the key and query weights are independent, the variant where these + 2. `True`: scaling the dot products between keys and queries by their + dimension instead of the square root of the same quantity, AND tying the key + and query weight matrices. This makes the `nngp`/`ntk` analytically tractable + but for the price that, unlike in the `False` case, the dot products of keys + and queries converge to a constant. Because this constant would be zero + if the key and query weights were independent, the variant where these two weight matrices are tied was implemented resulting in non-constant attention weights. The final computation for single head is then - :math:`f_h (x) + softmax( Q(x) K(x)^T) V(x)` + :math:`f_h (x) + attention_mechanism( Q(x) K(x)^T) V(x)` and the output of this layer is computed as :math:`f(x) = concat[f_1(x) , ... , f_{} (x)] W_{out} + b` where the shape of of `b` is `(n_chan_out,)`, i.e., single bias per channel @@ -1465,174 +1516,430 @@ def GlobalSelfAttention( goes to infinity. Args: - n_chan_out: Number of feature dimensions of outputs. - n_chan_key: Number of feature dimensions of keys/queries. - n_chan_val: Number of feature dimensions of values. - n_heads: Number of attention heads - fixed: If `True`, the dot products between keys and queries are - scaled by `1 / n_chan_key` and the key and query weight matrices are tied; + n_chan_out: + number of feature dimensions of outputs. + n_chan_key: + number of feature dimensions of keys/queries. + n_chan_val: + number of feature dimensions of values. + n_heads: + number of attention heads. 
+ linear_scaling: + if `True`, the dot products between keys and queries are scaled by + `1 / n_chan_key` and the key and query weight matrices are tied; if `False`, the dot products are scaled by `1 / sqrt(n_chan_key)` and the key and query matrices are independent. - W_key_std: init standard deviation of the key weights values. - W_value_std: init standard deviation of the key weights values. - W_query_std: init standard deviation of the query weights values; if - `fixed` is `True` (and thus key and query weights are tied---see above) - then keys are computed with `WK = WK_std * W / sqrt(n_chan_in)` and the - queries are computed with `WQ = W_query_std * W / sqrt(n_chan_in)` weight - matrices - W_out_std: Initial standard deviation of the output weights values. - b_std: Initial standard deviation of the bias values. - batch_axis: Specifies the batch dimension. Defaults to `0`, the leading - axis. - channel_axis: Specifies the channel / feature dimension. Defaults to `-1`, - the trailing axis. For `kernel_fn`, channel size is considered to be - infinite. + W_key_std: + init standard deviation of the key weights values. Due to NTK + parameterization, influences computation only through the product + `W_key_std * W_query_std`. + W_value_std: + init standard deviation of the value weights values. Due to NTK + parameterization, influences computation only through the product + `W_out_std * W_value_std`. + W_query_std: + init standard deviation of the query weights values; if `linear_scaling` + is `True` (and thus key and query weights are tied - see above) then keys + are computed with `WK = W_key_std * W / sqrt(n_chan_in)` and queries are + computed with `WQ = W_query_std * W / sqrt(n_chan_in)` weight matrices. + Due to NTK parameterization, influences computation only through the + product `W_key_std * W_query_std`. + W_out_std: + initial standard deviation of the output weights values. Due to NTK + parameterization, influences computation only through the product + `W_out_std * W_value_std`. + b_std: + initial standard deviation of the bias values. + attention_mechanism: + a string, `"SOFTMAX"`, `"IDENTITY"`, `"ABS"`, or `"RELU"`, the + transformation applied to dot product attention weights. + pos_emb_type: + a string, `"NONE"`, `"SUM"`, or `"CONCAT"`, the type of positional + embeddings to use. In the infinite-width limit, `"SUM"` and `"CONCAT"` + are equivalent up to a scaling constant. Keep in mind that all `Dense` + sub-layers of the attention layer use the NTK parameterization, and weight + variances are always inversely proportional to the input channel size, + which leads to different effective variances when using `"SUM"` and + `"CONCAT"` embeddings, even if all variance scales like `W_key_std` etc. + are the same. + pos_emb_p_norm: + use the unnormalized L-`p` distance to the power of `p` (with + `p == pos_emb_p_norm`) to compute pairwise distances for positional + embeddings (see `pos_emb_decay_fn` for details). Used only if + `pos_emb_type != "NONE"` and `pos_emb_decay_fn is not None`. + pos_emb_decay_fn: + a function applied to the L-`p` distance to the power of `p` (with + `p == pos_emb_p_norm`) distance between two spatial positions to produce + the positional embeddings covariance matrix (e.g. power decay, + exponential decay, etc.). `None` is equivalent to an indicator function + `lambda d: d == 0`, and returns a diagonal covariance matrix. Used only + if `pos_emb_type != "NONE"`. + n_chan_pos_emb: + number of channels in positional embeddings. 
`None` means use the same + number of channels as in the layer inputs. Can be used to tune the + contribution of positional embeddings relative to contribution of inputs + if `pos_emb_type == "CONCAT"`. Used only if `pos_emb_type != "NONE"`. + Will trigger an error if `pos_emb_type == "SUM"` and `n_chan_pos_emb` is + not `None` or does not match the layer inputs channel size at runtime. + W_pos_emb_std: + init standard deviation of the random positional embeddings. Can be used + to tune the contribution of positional embeddings relative to the + contribution of inputs. Used only if `pos_emb_type != "NONE"`. To tune + the _relative_ (to the inputs) contribution, you can either use + `n_chan_pos_emb` when `pos_emb_type == "CONCAT"`, or, if + `pos_emb_type == "CONCAT"`, adjust `W_key_std` etc. relative to + `W_pos_emb_std`, to keep the total output variance fixed. + val_pos_emb: + `True` indicates using positional embeddings when computing all of the + keys/queries/values matrices, `False` makes them only used for keys and + queries, but not values. Used only if `pos_emb_type != "NONE"`. + batch_axis: + Specifies the batch dimension. Defaults to `0`, the leading axis. + channel_axis: + Specifies the channel / feature dimension. Defaults to `-1`, the trailing + axis. For `kernel_fn`, channel size is considered to be infinite. Returns: `(init_fn, apply_fn, kernel_fn)`. Raises: - NotImplementedError: If `fixed` is `False`, call to `kernel_fn` will result - in an error as there is no known analytic expression for the kernel. + NotImplementedError: If `linear_scaling` is `False`, calling `kernel_fn` + will result in an error as there is no known analytic expression for the + kernel for `attention_mechanism != "IDENTITY"`. + + NotImplementedError: If `apply_fn` is called with `pos_emb_decay_fn != None` + , since custom `pos_emb_decay_fn` is only implemented in the infinite + width regime currently. + """ + QK_std = W_query_std * W_key_std + OV_std = W_out_std * W_value_std + + pos_emb_type = PositionalEmbedding(pos_emb_type) + attention_mechanism = AttentionMechanism(attention_mechanism) - OV_gain = W_out_std * W_value_std - QK_gain = W_query_std * W_key_std - QK_prod_scaling = float(n_chan_key if fixed else n_chan_key**0.5) + @functools.lru_cache(1) + def get_pos_emb_L(spatial_shape): + size = utils.size_at(spatial_shape) + R = _pos_emb_pdist(spatial_shape, pos_emb_p_norm, pos_emb_decay_fn) + R = utils.unzip_axes(R) + L = np.linalg.cholesky(np.reshape(R, (size,) * 2)).reshape(R.shape) + return L def init_fn(rng, input_shape): _channel_axis = channel_axis % len(input_shape) - n_chan_in = input_shape[_channel_axis] output_shape = (input_shape[:_channel_axis] + (n_chan_out,) + input_shape[_channel_axis + 1:]) - rng_Q, rng_K, rng_V, rng_O, rng_b = random.split(rng, 5) - + rng_Q, rng_K, rng_V, rng_O, rng_b, rng_pe = random.split(rng, 6) rand = random.normal - key_matrices = rand(rng_K, shape=(n_heads, n_chan_in, n_chan_key)) - val_matrices = rand(rng_V, shape=(n_heads, n_chan_in, n_chan_val)) + + n_chan_in_keys = n_chan_in_vals = input_shape[channel_axis] + + # Generate and add / append positional embeddings. + if pos_emb_type == PositionalEmbedding.NONE: + pos_emb = None + else: + # `None` means positional embeddings have the same number of channels + # as inputs. 
+ _n_chan_pos_emb = (n_chan_in_keys if n_chan_pos_emb is None + else n_chan_pos_emb) + + pos_emb_shape = list(input_shape) + pos_emb_shape[channel_axis] = _n_chan_pos_emb + pos_emb_shape[batch_axis] = 1 + pos_emb = rand(rng_pe, shape=pos_emb_shape) + + if pos_emb_type == PositionalEmbedding.CONCAT: + n_chan_in_keys += _n_chan_pos_emb + if val_pos_emb: + n_chan_in_vals += _n_chan_pos_emb + + key_matrices = rand(rng_K, shape=(n_heads, n_chan_in_keys, n_chan_key)) + val_matrices = rand(rng_V, shape=(n_heads, n_chan_in_vals, n_chan_val)) W_out = rand(rng_O, shape=(n_chan_val * n_heads, n_chan_out)) b_shape = [1] * len(input_shape) b_shape[_channel_axis] = n_chan_out b = rand(rng_b, shape=b_shape) - if fixed: + if linear_scaling: query_matrices = None - warnings.warn('Fixed attention used -> query initialization ignored, ' - 'tying the weights (see docstring for more details).') + warnings.warn('Linear scaling attention used -> query initialization ' + 'ignored, tying the weights ' + '(see docstring for more details).') else: - query_matrices = rand(rng_Q, shape=(n_heads, n_chan_in, n_chan_key)) + query_matrices = rand(rng_Q, (n_heads, n_chan_in_keys, n_chan_key)) - return output_shape, (query_matrices, key_matrices, val_matrices, W_out, b) + return (output_shape, + (query_matrices, key_matrices, val_matrices, W_out, b, pos_emb)) - def apply_fn(params, inputs, mask=None, **kwargs): - query_matrices, key_matrices, val_matrices, W_out, b = params + def apply_fn(params: PyTree, + inputs: np.ndarray, + mask: np.ndarray = None, + **kwargs) -> np.ndarray: + query_matrices, key_matrices, val_matrices, W_out, b, pos_emb = params + spatial_shape, spatial_axes = utils.shape_and_axes( + inputs, (batch_axis, channel_axis)) n = inputs.shape[batch_axis] - _channel_axis = channel_axis % inputs.ndim - n_chan_in = inputs.shape[_channel_axis] - spatial_shape = tuple(s for i, s in enumerate(inputs.shape) - if i not in (batch_axis, _channel_axis)) - inputs = np.moveaxis(inputs, (batch_axis, _channel_axis), (0, -1)) - inputs = inputs.reshape((n, -1, n_chan_in)) + if pos_emb is not None: + # Generate positional embeddings. + if pos_emb_decay_fn is not None: + L = get_pos_emb_L(spatial_shape) + first = tuple(range(L.ndim // 2)) + last = tuple(range(L.ndim // 2, L.ndim)) + pos_emb = np.tensordot(L, pos_emb, (last, spatial_axes)) + pos_emb = np.moveaxis(pos_emb, first, spatial_axes) - def _inputs_dot(matrices, std): - ret = np.dot(inputs, std * matrices / np.sqrt(n_chan_in)) + # Mask positional embeddings. + if mask is not None: + pos_emb = np.where(mask, np.zeros((), pos_emb.dtype), pos_emb) + + pos_emb *= W_pos_emb_std + + # Add / concat positional embeddings. + if pos_emb_type == PositionalEmbedding.SUM: + inputs_val = None if val_pos_emb else inputs + inputs = pos_emb + inputs + + elif pos_emb_type == PositionalEmbedding.CONCAT: + inputs_val = inputs if not val_pos_emb else None + _n_chan_pos_emb = (inputs.shape[channel_axis] if n_chan_pos_emb is None + else n_chan_pos_emb) + _channel_axis = channel_axis % inputs.ndim + pos_emb = np.broadcast_to( + pos_emb, + inputs.shape[:_channel_axis] + (_n_chan_pos_emb,) + + inputs.shape[_channel_axis + 1:]) + inputs = np.concatenate([inputs, pos_emb], axis=channel_axis) + + elif pos_emb_type == PositionalEmbedding.NONE: + inputs_val = None + + # Prepare separate inputs for values if asked to not add positional + # embeddings to values. 
+ if inputs_val is not None: + inputs_val = np.moveaxis(inputs_val, (batch_axis, channel_axis), (0, -1)) + inputs_val = inputs_val.reshape((n, -1, inputs_val.shape[-1])) + + # Flatten all spatial dimensions and make input of shape + # `(batch_size, total_spatial_size, n_channels)`. + inputs = np.moveaxis(inputs, (batch_axis, channel_axis), (0, -1)) + inputs = inputs.reshape((n, -1, inputs.shape[-1])) + + def _inputs_dot(matrices, _inputs=inputs): + ret = np.dot(_inputs, matrices) return np.moveaxis(ret, 2, 0) - keys = _inputs_dot(key_matrices, W_key_std) - values = _inputs_dot(val_matrices, W_value_std) - if fixed: - queries = keys * W_query_std / W_key_std + # Drop positional embedding information for value matrices if requested. + if inputs_val is not None: + values = _inputs_dot(val_matrices, inputs_val) + n_chan_in = inputs_val.shape[-1] else: - queries = _inputs_dot(query_matrices, W_query_std) + values = _inputs_dot(val_matrices) + n_chan_in = inputs.shape[-1] + + keys = _inputs_dot(key_matrices) + if linear_scaling: + queries = keys + else: + queries = _inputs_dot(query_matrices) G_mat = np.matmul(queries, np.moveaxis(keys, -1, -2)) - G_mat /= QK_prod_scaling + norm = inputs.shape[-1] * n_chan_key ** (1 if linear_scaling else 0.5) + G_mat *= QK_std / norm if mask is not None: mask = np.all(mask, axis=channel_axis, keepdims=True) - mask_flat = np.moveaxis(mask, (batch_axis, channel_axis), (0, -1)) - mask_flat = mask_flat.reshape((1, mask.shape[0], 1, -1)) - G_mat = np.where(mask_flat, _NEG_INF, G_mat) - - G_mat = ostax.softmax(G_mat, axis=-1) + mask = np.moveaxis(mask, (batch_axis, channel_axis), (0, -1)) + mask = mask.reshape((1, mask.shape[0], 1, -1)) + + if attention_mechanism == AttentionMechanism.SOFTMAX: + G_mat = np.where(mask, _NEG_INF, G_mat) + elif attention_mechanism in (AttentionMechanism.IDENTITY, + AttentionMechanism.RELU, + AttentionMechanism.ABS): + G_mat = np.where(mask, np.zeros((), G_mat.dtype), G_mat) + else: + raise NotImplementedError(attention_mechanism, mask) + G_mat = attention_mechanism.fn()(G_mat) heads = np.matmul(G_mat, values) heads = np.moveaxis(heads, 0, -1) heads = np.reshape(heads, heads.shape[:-2] + (-1,)) - ret = np.matmul(heads, W_out_std * W_out / np.sqrt(n_chan_val * n_heads)) - ret = np.reshape(ret, (n,) + spatial_shape + (n_chan_out,)) - ret = np.moveaxis(ret, (0, -1), (batch_axis, _channel_axis)) + b_std * b - return ret + outputs = np.matmul(heads, W_out) + outputs *= OV_std / (n_chan_val * n_heads * n_chan_in) ** 0.5 + + outputs = np.reshape(outputs, (n,) + spatial_shape + (n_chan_out,)) + outputs = np.moveaxis(outputs, (0, -1), (batch_axis, channel_axis)) + return outputs + b_std * b @_requires(batch_axis=batch_axis, channel_axis=channel_axis, diagonal_spatial=False) def kernel_fn(k: Kernel, **kwargs): - cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk + # Generate (optional) positional embedding covariances. + R1, R12, R2 = _get_all_pos_emb(k, pos_emb_type, pos_emb_p_norm, + pos_emb_decay_fn) - if not fixed: - raise NotImplementedError('No known closed form expression.') + def _get_interpolation_coefficients(): + input_weight, pos_emb_weight = 1, W_pos_emb_std**2 - def _get_G_softmax(mat, mask): - if not k.diagonal_batch: - mat = np.moveaxis(np.diagonal(mat, axis1=0, axis2=1), -1, 0) + if pos_emb_type == PositionalEmbedding.CONCAT: + # Reweight based on relative widths of inputs and channels. 
+ n_chan_input = k.shape1[channel_axis] + _n_chan_pos_emb = (k.shape1[channel_axis] if n_chan_pos_emb is None + else n_chan_pos_emb) + n_chan_total = n_chan_input + _n_chan_pos_emb - if mask is not None: - mask = np.all(mask, axis=channel_axis, keepdims=True) - mask = np.moveaxis(mask, (batch_axis, channel_axis), (0, -1)) - mask = np.squeeze(mask, axis=-1) - if k.is_reversed: - mask = np.moveaxis(mask, range(1, mask.ndim), - range(mask.ndim -1, 0, -1)) - mask = utils.interleave_ones(mask, 1, mask.ndim, False) - mat = np.where(mask, _NEG_INF, mat) - - axes = tuple(range(mat.ndim)) - return ostax.softmax(QK_gain * mat, axis=axes[2::2]) - - def _transform_kernel(mat, G1, G2=None): - if mat is None or mat.ndim == 0: - return mat - - G2 = G1 if G2 is None else G2 - - # Spatial axes - G1_dims = tuple(range(1, G1.ndim)) - G2_dims = tuple(range(G1.ndim, G1.ndim + G2.ndim - 1)) - mat_dims = utils.zip_flat(G1_dims[1::2], G2_dims[1::2]) - res_dims = utils.zip_flat(G1_dims[::2], G2_dims[::2]) - - # Batch axes - if mat.ndim % 2: - G1_dims = (0,) + G1_dims - G2_dims = (0,) + G2_dims - mat_dims = (0,) + mat_dims - res_dims = (0,) + res_dims + input_weight *= n_chan_input / n_chan_total + pos_emb_weight *= _n_chan_pos_emb / n_chan_total + + return input_weight, pos_emb_weight + + def weighted_sum(x, y, x_weight, y_weight): + if x is None or y is None: + return x + return x_weight * x + y_weight * y + + # Generate kernel interpolations. + kern_weight, pos_emb_weight = _get_interpolation_coefficients() + + cov1_interp = weighted_sum(k.cov1, R1, kern_weight, pos_emb_weight) + cov2_interp = weighted_sum(k.cov2, R2, kern_weight, pos_emb_weight) + + if val_pos_emb or (not linear_scaling and + attention_mechanism == AttentionMechanism.IDENTITY): + # These interpolations need to be computed in `d^-1/2` scaling even if + # positional embeddings aren't used in `values`. 
+ nngp_interp = weighted_sum(k.nngp, R12, kern_weight, pos_emb_weight) + ntk_interp = weighted_sum(k.ntk, R12, kern_weight, pos_emb_weight) + + if linear_scaling: + + def _get_weighting(mat, mask): + if mat is None: + return None + + if not k.diagonal_batch: + mat = np.moveaxis(np.diagonal(mat, axis1=0, axis2=1), -1, 0) + + if mask is not None: + mask = np.all(mask, axis=channel_axis, keepdims=True) + mask = np.squeeze(np.moveaxis(mask, (batch_axis, channel_axis), + (0, -1)), -1) + if k.is_reversed: + mask = np.moveaxis(mask, + range(1, mask.ndim), + range(mask.ndim -1, 0, -1)) + mask = utils.interleave_ones(mask, 1, mask.ndim, x_first=False) + if attention_mechanism == AttentionMechanism.SOFTMAX: + mat = np.where(mask, _NEG_INF, mat) + else: + mat = np.where(mask, np.zeros((), mat.dtype), mat) + + if attention_mechanism == AttentionMechanism.SOFTMAX: + axes = tuple(range(mat.ndim)) + return attention_mechanism.fn()(QK_std * mat, axis=axes[2::2]) + else: + return attention_mechanism.fn()(QK_std * mat) + + def _weigh_kernel(mat, G1, G2=None): + if mat is not None and mat.ndim != 0: + G2 = G1 if G2 is None else G2 + + # Spatial axes + G1_dims = tuple(range(1, G1.ndim)) + G2_dims = tuple(range(G1.ndim, G1.ndim + G2.ndim - 1)) + mat_dims = utils.zip_flat(G1_dims[1::2], G2_dims[1::2]) + res_dims = utils.zip_flat(G1_dims[::2], G2_dims[::2]) + + G1_dims = (0,) + G1_dims + + # Batch axes + if mat.ndim % 2: # Even number of spatial axes + 1 or 2 batch axes + G2_dims = (0,) + G2_dims + mat_dims = (0,) + mat_dims + res_dims = (0,) + res_dims + + else: + G2_dims = (-1,) + G2_dims + mat_dims = (0, -1) + mat_dims + res_dims = (0, -1) + res_dims + + mat = np.einsum(G1, G1_dims, mat, mat_dims, G2, G2_dims, res_dims, + optimize=True) + return _affine(mat, OV_std, b_std) + + G1 = _get_weighting(cov1_interp, k.mask1) + G2 = _get_weighting(cov2_interp, k.mask2) + + cov1 = _weigh_kernel(cov1_interp if val_pos_emb else k.cov1, G1) + cov2 = _weigh_kernel(cov2_interp if val_pos_emb else k.cov2, G2) + + nngp = _weigh_kernel(nngp_interp if val_pos_emb else k.nngp, G1, G2) + if k.ntk is None: + ntk = None else: - G1_dims = (0,) + G1_dims - G2_dims = (-1,) + G2_dims - mat_dims = (0, -1) + mat_dims - res_dims = (0, -1) + res_dims + ntk = _weigh_kernel(ntk_interp if val_pos_emb else k.ntk, + G1, G2) + 2 * (nngp - b_std**2) - res = np.einsum(G1, G1_dims, mat, mat_dims, G2, G2_dims, res_dims, - optimize=True) - return _affine(res, OV_gain, b_std) + elif attention_mechanism == AttentionMechanism.IDENTITY: - G1 = _get_G_softmax(cov1, k.mask1) - G2 = _get_G_softmax(cov2, k.mask2) if cov2 is not None else G1 + def dot(lhs, rhs, diagonal_batch=k.diagonal_batch): + if lhs is None: + return None + + c_axes = tuple(range(1 if diagonal_batch else 2, lhs.ndim)) + if rhs is None: + return np.sum(lhs**2, axis=c_axes, keepdims=True) + + rhs = np.broadcast_to(rhs, lhs.shape) + b_axes = (0,) if diagonal_batch else (0, 1) + res = lax.dot_general(lhs, rhs, ((c_axes, c_axes), (b_axes, b_axes))) + return res.reshape(res.shape + (1,) * len(c_axes)) + + dot11 = dot(cov1_interp, None if val_pos_emb else k.cov1) + dot12 = dot(nngp_interp, None if val_pos_emb else k.nngp, False) + dot22 = dot(cov2_interp, None if val_pos_emb else k.cov2) - cov1 = _transform_kernel(cov1, G1) - cov2 = _transform_kernel(cov2, G2) if cov2 is not None else cov2 - nngp = _transform_kernel(nngp, G1, G2) - ntk = (_transform_kernel(ntk, G1, G2) + 2 * (nngp - b_std**2) - if ntk is not None else ntk) + std = QK_std * OV_std - return k.replace(cov1=cov1, nngp=nngp, 
cov2=cov2, ntk=ntk, is_gaussian=True) + nngp = _affine(dot12 * nngp_interp, std, b_std) + cov1 = _affine(dot11 * cov1_interp, std, b_std) + cov2 = _affine(None if dot22 is None else dot22 * cov2_interp, std, b_std) + + if ntk_interp is not None: + if val_pos_emb or pos_emb_type == PositionalEmbedding.NONE: + nngp_dot_ntk = dot(nngp_interp, ntk_interp, False) + ntk = 2 * nngp_dot_ntk + + else: + nngp_dot_ntk_1 = dot(nngp_interp, k.ntk, False) + nngp_dot_ntk_2 = dot(k.nngp, ntk_interp, False) + ntk = (nngp_dot_ntk_1 + nngp_dot_ntk_2) + + ntk = _affine( + ntk * nngp_interp + dot12 * (ntk_interp + 4 * nngp_interp), + std, + b_std) + + else: + ntk = None + + else: + raise NotImplementedError(f'No known closed form expression for square ' + f'root scaling and {attention_mechanism} ' + f'attention mechanism.') + + return k.replace(cov1=cov1, + nngp=nngp, + cov2=cov2, + ntk=ntk, + is_gaussian=True) def mask_fn(mask, input_shape): return np.all(mask, channel_axis, keepdims=True) @@ -1650,14 +1957,16 @@ def LayerNorm( """Layer normalisation. Args: - axis: Specifies dimensions over which to normalize. - eps: Specifies (small) positive constant to be added to the variance - estimates in order to prevent division by zero. - batch_axis: Specifies the batch dimension. Defaults to `0`, the leading - axis. - channel_axis: Specifies the channel / feature dimension. Defaults to `-1`, - the trailing axis. For `kernel_fn`, channel size is considered to be - infinite. + axis: + dimensions over which to normalize. + eps: + (small) positive constant to be added to the variance estimates in order + to prevent division by zero. + batch_axis: + batch dimension. Defaults to `0`, the leading axis. + channel_axis: + channel / feature dimension. Defaults to `-1`, the trailing axis. For + `kernel_fn`, channel size is considered to be infinite. Returns: `(init_fn, apply_fn, kernel_fn)`. @@ -1754,9 +2063,11 @@ def Dropout(rate: float, mode: str = 'train') -> InternalLayer: Based on `jax.experimental.stax.Dropout`. Args: - rate: Specifies the keep `rate`, e.g. `rate=1` is equivalent to - keeping all neurons. - mode: Either `train` or `test`. + rate: + Specifies the keep `rate`, e.g. `rate=1` is equivalent to keeping all + neurons. + mode: + Either `"train"` or `"test"`. Returns: `(init_fn, apply_fn, kernel_fn)`. @@ -1850,7 +2161,7 @@ def _get_input_req_attr(kernel_fns: List[LayerKernelFn]) -> Dict[str, bool]: return req -def _double_tuple(x): +def _double_tuple(x: tuple) -> tuple: return tuple(v for v in x for _ in range(2)) @@ -2039,6 +2350,8 @@ def _inputs_to_kernel( errors and try to use an atypical for inputs value. eps: a small number used to check whether x1 and x2 are the same up to `eps`. + **kwargs: other arguments passed to all intermediary `kernel_fn` calls (not + used here). Returns: The `Kernel` object containing inputs covariance[s]. @@ -2251,6 +2564,7 @@ def kernel_fn_any(x1_or_kernel: Union[np.ndarray, Kernels], width, width, ...)`). Defaults to least compute-heavy setting necessary to compute the output `nngp` [and `ntk`] covariance. + **kwargs: other arguments passed to all intermediary `kernel_fn` calls. Returns: If `get` is a string, returns the requested `np.ndarray`. If `get` is a @@ -2755,10 +3069,13 @@ def _affine( Gaussian biases with std `b_std`. Args: - mat: a `np.ndarray` containing sample-[sample-]position[-position] - covariances of inputs. - W_std: a float, standard deviation of a fully-connected layer weights. - b_std: a float, standard deviation of a fully-connected layer biases. 
+ mat: + a `np.ndarray` containing sample-[sample-]position[-position] covariances + of inputs. + W_std: + standard deviation of a fully-connected layer weights. + b_std: + standard deviation of a fully-connected layer biases. Returns: a `np.ndarray` containing sample-[sample-]position[-position] covariances @@ -3004,7 +3321,7 @@ def _conv_kernel_full_spatial( strides: Tuple[int, ...], padding: Padding, batch_ndim: int - ) -> Optional[np.ndarray]: +) -> Optional[np.ndarray]: """Compute covariance of the CNN outputs given inputs with covariance `mat`. Used when `kernel.diagonal_spatial == False`. @@ -3446,3 +3763,66 @@ def _pool_mask( f'please submit a bug to ' f'https://github.com/google/neural-tangents/issues/new.') return mask + + +# POSITIONAL EMBEDDINGS + + +@functools.lru_cache() +def _pos_emb_identity(shape: Tuple[int, ...]) -> np.ndarray: + size = utils.size_at(shape) + R = np.eye(size).reshape(shape * 2) + R = utils.zip_axes(R) + return R + + +@functools.lru_cache() +def _pos_emb_pdist(shape: Tuple[int, ...], + pos_emb_p_norm: Optional[float], + pos_emb_decay_fn: Optional[Callable[[float], float]] + ) -> np.ndarray: + if pos_emb_decay_fn is None: + # Identity / one-hot positional embeddings. + return _pos_emb_identity(shape) + + # Pairwise distance-based positional embeddings. + ndim = len(shape) + R = np.zeros((1,) * (ndim * 2)) + for axis in range(ndim): + d = np.arange(shape[axis]) + pd = utils.outer_prod(d, d, 0, d.ndim, op.sub) + pd = pd.reshape((1,) * (2 * axis) + + pd.shape + + (1,) * (2 * (ndim - axis - 1))) + R += np.abs(pd) ** pos_emb_p_norm + + R = pos_emb_decay_fn(R) + return R + + +def _get_all_pos_emb(k: Kernel, + pos_emb_type: PositionalEmbedding, + pos_emb_p_norm: float, + pos_emb_decay_fn: Optional[Callable[[float], float]] + ) -> Tuple[Optional[np.ndarray], + Optional[np.ndarray], + Optional[np.ndarray]]: + if pos_emb_type == PositionalEmbedding.NONE: + return None, None, None + + shape, _ = utils.shape_and_axes(k.shape1, (k.batch_axis, k.channel_axis)) + R = _pos_emb_pdist(shape, pos_emb_p_norm, pos_emb_decay_fn) + + if k.is_reversed: + R = utils.reverse_zipped(R) + + batch_ndim = 1 if k.diagonal_batch else 2 + R11 = np.expand_dims(R, tuple(range(batch_ndim))) + R12 = R11 if batch_ndim == 2 else np.expand_dims(R, (0, 1)) + R22 = None if k.cov2 is None else R11 + + mask11, mask12, mask22 = k._get_mask_prods(k.mask1, k.mask2) + R11 = utils.mask(R11, mask11) + R12 = utils.mask(R12, mask12) + R22 = utils.mask(R22, mask22) + return R11, R12, R22 diff --git a/neural_tangents/utils/kernel.py b/neural_tangents/utils/kernel.py index f877c890..bdcb87b9 100644 --- a/neural_tangents/utils/kernel.py +++ b/neural_tangents/utils/kernel.py @@ -159,27 +159,16 @@ def reverse(self) -> 'Kernel': `reverse(kernels).nngp` has shape `(batch_size_1, batch_size_2, ..., D, D, W, W, H, H)`. 
""" - # Number of spatial dimensions = total - (1 for batch + 1 for channels) - ndim = len(self.shape1) - 2 + batch_ndim = 1 if self.diagonal_batch else 2 + cov1 = utils.reverse_zipped(self.cov1, batch_ndim) + cov2 = utils.reverse_zipped(self.cov2, batch_ndim) + nngp = utils.reverse_zipped(self.nngp, 2) + ntk = utils.reverse_zipped(self.ntk, 2) - # ndim == 3: (-5, -6, -3, -4, -1, -2) - source_axes = tuple(j for i in range(-ndim * 2, 0, 2) for j in (i + 1, i)) - - # ndim == 3: (-1, -2, -3, -4, -5, -6) - target_axes = tuple(range(-1, -ndim * 2 - 1, -1)) - - def reverse(mat): - if mat is not None: - return np.moveaxis(mat, source_axes, target_axes) - return mat - - cov1, nngp, cov2, ntk = map(reverse, (self.cov1, - self.nngp, - self.cov2, - self.ntk)) return self.replace(cov1=cov1, nngp=nngp, - cov2=cov2, ntk=ntk, + cov2=cov2, + ntk=ntk, is_reversed=not self.is_reversed) def transpose(self, axes: Tuple[int, ...] = None) -> 'Kernel': @@ -221,15 +210,10 @@ def mask(self, """Mask all covariance matrices according to `mask1`, `mask2`""" mask11, mask12, mask22 = self._get_mask_prods(mask1, mask2) - def mask_mat(mat, mask): - if mat is None or mask is None: - return mat - return np.where(mask, np.zeros((), mat.dtype), mat) - - cov1 = mask_mat(self.cov1, mask11) - cov2 = mask_mat(self.cov2, mask22) - nngp = mask_mat(self.nngp, mask12) - ntk = mask_mat(self.ntk, mask12) + cov1 = utils.mask(self.cov1, mask11) + cov2 = utils.mask(self.cov2, mask22) + nngp = utils.mask(self.nngp, mask12) + ntk = utils.mask(self.ntk, mask12) return self.replace(cov1=cov1, nngp=nngp, diff --git a/neural_tangents/utils/utils.py b/neural_tangents/utils/utils.py index ec51a424..f834d764 100644 --- a/neural_tangents/utils/utils.py +++ b/neural_tangents/utils/utils.py @@ -314,6 +314,18 @@ def outer_prod(x, y, start_axis, end_axis, prod_op): return prod_op(x, y) +def reverse_zipped(mat: np.ndarray, start_axis: int = 0) -> np.ndarray: + if mat is not None: + source_axes = tuple(j + for i in range(mat.ndim - 2, start_axis - 1, -2) + for j in (i, i + 1)) + + target_axes = range(start_axis, mat.ndim) + mat = np.moveaxis(mat, source_axes, target_axes) + + return mat + + ArrayOrList = Union[Optional[np.ndarray], List[Optional[np.ndarray]]] @@ -346,26 +358,30 @@ def get_masked_array(x: ArrayOrList, return MaskedArray(*(list(f) for f in fields)) if x is None: - mask = None + mask_mat = None elif isinstance(x, MaskedArray): - x, mask = x.astuple() + x, mask_mat = x.astuple() elif isinstance(x, np.ndarray): if mask_constant is None: - mask = None + mask_mat = None else: - mask = lax.cond(np.isnan(mask_constant), - lambda x: np.isnan(x), - lambda x: x == mask_constant, - x) + mask_mat = lax.cond(np.isnan(mask_constant), + lambda x: np.isnan(x), + lambda x: x == mask_constant, + x) else: raise TypeError(x, type(x)) - if mask is not None: - x = np.where(mask, np.zeros((), x.dtype), x) + x = mask(x, mask_mat) + return MaskedArray(x, mask_mat) # pytype: disable=wrong-arg-count - return MaskedArray(x, mask) # pytype: disable=wrong-arg-count + +def mask(x: Optional[np.ndarray], mask_mat: Optional[np.ndarray]): + if x is None or mask_mat is None: + return x + return np.where(mask_mat, np.zeros((), x.dtype), x) def size_at(x: Union[np.ndarray, Sequence[int]], @@ -379,6 +395,18 @@ def size_at(x: Union[np.ndarray, Sequence[int]], return functools.reduce(operator.mul, [x[a] for a in axes], 1) +def shape_and_axes( + x: Union[np.ndarray, Sequence[int]], + ignore_axes: Iterable[int] = ()) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + if hasattr(x, 
'shape'): + x = x.shape + ndim = len(x) + ignore_axes = tuple(i % ndim for i in ignore_axes) + axes = tuple(i for i in range(ndim) if i not in ignore_axes) + shape = tuple(x[i] for i in axes) + return shape, axes + + def get_res_batch_dims(contracting_dims: List[int], batch_dims: List[int]) -> List[int]: res_batch_dims = [2 * b - i for i, b in enumerate(batch_dims)] diff --git a/tests/stax_test.py b/tests/stax_test.py index 4cdb10cb..6959a4f4 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -15,16 +15,17 @@ """Tests for stax.py.""" -import string - -import random as prandom import functools import itertools import logging +import random as prandom +import string +from typing import Tuple + from absl.testing import absltest -from jax.api import jit from jax import ops from jax import test_util as jtu +from jax.api import jit from jax.config import config as jax_config from jax.lib import xla_bridge import jax.numpy as np @@ -33,7 +34,6 @@ from neural_tangents.utils import monte_carlo from neural_tangents.utils import test_utils import numpy as onp -from typing import Tuple jax_config.parse_flags_with_absl() @@ -50,7 +50,7 @@ WIDTHS = [2**10] -N_SAMPLES = 100 +N_SAMPLES = 128 RTOL = 0.025 @@ -77,8 +77,7 @@ PROJECTIONS = [ 'FLAT', 'POOL', - 'ATTN_FIXED', - 'ATTN_PARAM' + 'ATTN', ] LAYER_NORM = [ @@ -240,14 +239,13 @@ def conv(out_chan): elif proj_into_2d.startswith('ATTN'): n_heads = int(np.sqrt(width)) n_chan_val = int(np.round(float(width) / n_heads)) - fixed = proj_into_2d == 'ATTN_FIXED' proj_layer = stax.serial( stax.GlobalSelfAttention( n_chan_out=width, n_chan_key=width, n_chan_val=n_chan_val, n_heads=n_heads, - fixed=fixed, + linear_scaling=True, W_key_std=W_std, W_value_std=W_std, W_query_std=W_std, @@ -370,8 +368,9 @@ def test_exact(self, model, width, strides, padding, phi, same_inputs, net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding, phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm, parameterization, use_dropout) - self._check_agreement_with_empirical(net, same_inputs, use_dropout, is_ntk, - proj_into_2d) + self._check_agreement_with_empirical( + net, same_inputs, use_dropout, is_ntk, + rtol=0.03 if proj_into_2d == 'ATTN' else RTOL) # pylint: disable=g-complex-comprehension @jtu.parameterized.named_parameters( @@ -425,8 +424,7 @@ def test_parameterizations(self, model, width, same_inputs, is_ntk, net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding, phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm, parameterization, use_dropout) - self._check_agreement_with_empirical(net, same_inputs, use_dropout, is_ntk, - proj_into_2d) + self._check_agreement_with_empirical(net, same_inputs, use_dropout, is_ntk) @jtu.parameterized.named_parameters( jtu.cases_from_list({ @@ -486,7 +484,7 @@ def test_layernorm(self, padding, phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm, parameterization, use_dropout) self._check_agreement_with_empirical(net, same_inputs, use_dropout, is_ntk, - proj_into_2d, True) + 0.05) @jtu.parameterized.named_parameters( jtu.cases_from_list({ @@ -513,7 +511,7 @@ def test_layernorm(self, strides, 'normalize_edges': normalize_edges - } for width in WIDTHS for same_inputs in [False, True] + } for width in WIDTHS for same_inputs in [False] for is_ntk in [False, True] for pool_type in POOL_TYPES for padding in PADDINGS for filter_shape in FILTER_SHAPES @@ -521,11 +519,8 @@ def test_layernorm(self, for normalize_edges in [True, False])) def test_pool(self, 
width, same_inputs, is_ntk, pool_type, padding, filter_shape, strides, normalize_edges): - is_conv = True use_dropout = False - proj_into_2d = 'POOL' # Check for duplicate / incorrectly-shaped NN configs / wrong backend. - if xla_bridge.get_backend().platform == 'cpu': raise absltest.SkipTest('Not running CNN models on CPU to save time.') if pool_type == 'SUM' and normalize_edges: @@ -533,8 +528,7 @@ def test_pool(self, width, same_inputs, is_ntk, pool_type, net = _get_net_pool(width, is_ntk, pool_type, padding, filter_shape, strides, normalize_edges) - self._check_agreement_with_empirical(net, same_inputs, use_dropout, is_ntk, - proj_into_2d) + self._check_agreement_with_empirical(net, same_inputs, use_dropout, is_ntk) def test_avg_pool(self): X1 = np.ones((4, 2, 3, 2)) @@ -652,8 +646,7 @@ def test_dropout(self, model, width, same_inputs, is_ntk, padding, strides, net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding, phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm, parameterization, use_dropout) - self._check_agreement_with_empirical(net, same_inputs, use_dropout, is_ntk, - proj_into_2d) + self._check_agreement_with_empirical(net, same_inputs, use_dropout, is_ntk) @jtu.parameterized.named_parameters( jtu.cases_from_list({ @@ -763,8 +756,8 @@ def _check_agreement_with_empirical( same_inputs, use_dropout, is_ntk, - proj_into_2d, - use_layer_norm=False): + rtol=RTOL + ): ((init_fn, apply_fn, kernel_fn), input_shape, device_count, channel_axis) = net @@ -784,37 +777,20 @@ def _check_agreement_with_empirical( def _get_empirical(n_samples, get): kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn( init_fn, apply_fn, key, n_samples, device_count=device_count, - trace_axes=(channel_axis,) - ) + trace_axes=(channel_axis,)) if same_inputs: assert x2 is None return kernel_fn_empirical(x1, x2, get) - if proj_into_2d == 'ATTN_PARAM': - # no analytic kernel available, just test forward/backward pass - _get_empirical(1, 'ntk' if is_ntk else 'nngp') + if is_ntk: + exact, shape1, shape2 = kernel_fn(x1, x2, ('ntk', 'shape1', 'shape2')) + empirical = _get_empirical(num_samples, 'ntk') else: - platform = xla_bridge.get_backend().platform - if proj_into_2d == 'ATTN_FIXED': - if platform == 'tpu': - rtol = 0.08 - else: - rtol = 0.04 - else: - if use_layer_norm and platform == 'tpu': - rtol = 0.05 - else: - rtol = RTOL - - if is_ntk: - exact, shape1, shape2 = kernel_fn(x1, x2, ('ntk', 'shape1', 'shape2')) - empirical = np.reshape(_get_empirical(num_samples, 'ntk'), exact.shape) - else: - exact, shape1, shape2 = kernel_fn(x1, x2, ('nngp', 'shape1', 'shape2')) - empirical = _get_empirical(num_samples, 'nngp') - test_utils.assert_close_matrices(self, exact, empirical, rtol) - self.assertEqual(shape1, x1_out_shape) - self.assertEqual(shape2, x2_out_shape) + exact, shape1, shape2 = kernel_fn(x1, x2, ('nngp', 'shape1', 'shape2')) + empirical = _get_empirical(num_samples, 'nngp') + test_utils.assert_close_matrices(self, exact, empirical, rtol) + self.assertEqual(shape1, x1_out_shape) + self.assertEqual(shape2, x2_out_shape) class ActivationTest(test_utils.NeuralTangentsTestCase): @@ -846,7 +822,6 @@ def kernel_fn(kernels, **kwargs): nngp=np.exp(-input_dim * gamma * (cov1 + cov2 - 2 * nngp))) return init_fn, apply_fn, kernel_fn - def _test_activation(self, activation_fn, same_inputs, model, get, rbf_gamma=None): platform = xla_bridge.get_backend().platform @@ -1015,7 +990,6 @@ def _test_activation(self, activation, fn, same_inputs, model, get): 
test_utils.assert_close_matrices(self, analytic_kernel, numerical_activation_kernel, rtol) - @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': @@ -1544,12 +1518,12 @@ def test_conv_nd(self, same_inputs, n, get, proj, use_attn, channels_first, n_chan_key=width, n_chan_val=n_chan_val, n_heads=n_heads, - fixed=True, + linear_scaling=True, W_key_std=2., W_value_std=1., W_query_std=1., W_out_std=1.0, - b_std=0.01, + b_std=0.1, channel_axis=channel_axis), proj) nn = stax.serial( @@ -1850,7 +1824,7 @@ def apply_mask(x): 'use_attn': use_attn, 'n': n } - for proj in ['avg', 'flatten'] + for proj in ['avg'] for use_attn in [True] for same_inputs in [False] for get in ['nngp', 'ntk'] @@ -2014,5 +1988,163 @@ def get_attn(): test_utils.assert_close_matrices(self, empirical, exact, tol) +class AttentionTest(test_utils.NeuralTangentsTestCase): + + @jtu.parameterized.named_parameters( + jtu.cases_from_list({ + 'testcase_name': + f'[same_inputs={same_inputs}_' + f'get={get}_' + f'axis={mask_axis}' + f'_mask={mask_constant}_' + f'p={p}_' + f'linear_scaling={linear_scaling}_' + f'n={n}_pos_emb_type={pos_emb_type}_' + f'n_chan_pos_emb={n_chan_pos_emb}' + f'_pos_emb_decay_fn={pos_emb_decay_fn}_' + f'val_pos_emb={val_pos_emb}_' + f'W_pos_emb_std={W_pos_emb_std}]', + 'same_inputs': same_inputs, + 'get': get, + 'n': n, + 'linear_scaling': linear_scaling, + 'mask_constant': mask_constant, + 'p': p, + 'mask_axis': mask_axis, + 'pos_emb_type': pos_emb_type, + 'n_chan_pos_emb': n_chan_pos_emb, + 'pos_emb_decay_fn': pos_emb_decay_fn, + 'val_pos_emb': val_pos_emb, + 'W_pos_emb_std': W_pos_emb_std + } + for same_inputs in [ + False + ] + for get in [ + 'nngp', + 'ntk' + ] + for n in [ + 2, + ] + for linear_scaling in [ + True, + False + ] + for mask_constant in [ + 10. 
+ ] + for p in [0.5] + for mask_axis in [(-1,)] + for pos_emb_type in [ + 'CONCAT', + 'SUM', + 'NONE' + ] + for n_chan_pos_emb in ([None] + if pos_emb_type != 'CONCAT' + else [None, 512]) + for pos_emb_decay_fn in [ + None, + 'linear' + ] + for val_pos_emb in ([ + True, + False + ] if pos_emb_type != 'NONE' else [True]) + for W_pos_emb_std in ([ + 2, + ] if pos_emb_type != 'NONE' else [0.]) + )) + def test_attention( + self, + same_inputs, + get, + n, + linear_scaling, + mask_constant, + p, + mask_axis, + pos_emb_type, + n_chan_pos_emb, + pos_emb_decay_fn, + val_pos_emb, + W_pos_emb_std): + if xla_bridge.get_backend().platform == 'cpu': + raise absltest.SkipTest('Skipping attention tests on CPU for speed.') + + width = 1024 + n_samples = 1024 + tol = 0.05 + key = random.PRNGKey(1) + n_chan_in = 2 + spatial_shape = (2, 3, 4, 3, 2, 1)[:n] + mask_axis = [i % (n + 2) for i in mask_axis] + + def get_x0(batch_size): + x0 = random.normal(key, (batch_size,) + spatial_shape + (n_chan_in,)) + if mask_constant is not None: + mask_shape = [1 if i in mask_axis else s + for i, s in enumerate(x0.shape)] + mask = random.bernoulli(key, p=p, shape=mask_shape) + x0 = np.where(mask, mask_constant, x0) + x0 = np.sort(x0, 1) + return x0 + + X0_1 = get_x0(2) + X0_2 = None if same_inputs else get_x0(4) + + pos_emb_fns = { + None: None, + 'one_hot': lambda x: x == 0, + 'linear': lambda x: 1 / (1 + 4 * x) + } + + def get_attn(): + return stax.GlobalSelfAttention( + linear_scaling=linear_scaling, + n_chan_out=width, + n_chan_key=width, + n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))), + n_heads=int(np.sqrt(width)), + n_chan_pos_emb=n_chan_pos_emb, + attention_mechanism='SOFTMAX' if linear_scaling else 'IDENTITY', + pos_emb_type=pos_emb_type, + W_pos_emb_std=W_pos_emb_std, + pos_emb_decay_fn=pos_emb_fns[pos_emb_decay_fn], + val_pos_emb=val_pos_emb, + W_key_std=0.9, + W_out_std=1.2, + W_query_std=0.7, + W_value_std=1.5, + b_std=0.9 + ) + + nn = stax.serial( + stax.Conv(width, (1,) * n, padding='SAME'), + get_attn(), + stax.Relu(), + stax.GlobalAvgPool() + ) + + if get == 'nngp': + init_fn, apply_fn, kernel_fn = nn + elif get == 'ntk': + init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1., 0.)) + else: + raise ValueError(get) + + kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn( + init_fn, apply_fn, key, n_samples, + device_count=-1 + ) + + kernel_fn = jit(kernel_fn, static_argnums=(2,)) + exact = kernel_fn(X0_1, X0_2, get, mask_constant=mask_constant) + + empirical = kernel_fn_mc(X0_1, X0_2, get=get, mask_constant=mask_constant) + test_utils.assert_close_matrices(self, empirical, exact, tol) + + if __name__ == '__main__': absltest.main()
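A minimal usage sketch of the `GlobalSelfAttention` API revised in this patch, assuming `stax` is importable from the `neural_tangents` package as in the tests above; the widths, filter shape, input sizes, and the `1 / (1 + d)` positional-embedding decay function are illustrative choices, not values taken from the library.

    # Sketch only: shapes, widths and the decay function below are arbitrary
    # illustrative choices; argument names follow the signature introduced in
    # this patch.
    import jax.numpy as np
    from jax import random
    from neural_tangents import stax

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Conv(512, (3,), padding='SAME'),          # 1D conv over a single spatial axis
        stax.GlobalSelfAttention(
            n_chan_out=512,
            n_chan_key=512,
            n_chan_val=512,
            n_heads=8,
            linear_scaling=True,                       # tied key/query weights, 1 / n_chan_key scaling
            attention_mechanism='SOFTMAX',
            pos_emb_type='SUM',                        # add random positional embeddings to inputs
            pos_emb_decay_fn=lambda d: 1. / (1. + d),  # covariance decay with squared distance
            W_pos_emb_std=1.,
            val_pos_emb=False),                        # embeddings enter keys/queries, not values
        stax.Relu(),
        stax.GlobalAvgPool())

    key1, key2 = random.split(random.PRNGKey(0))
    x1 = random.normal(key1, (4, 16, 3))               # (batch, spatial, channels)
    x2 = random.normal(key2, (2, 16, 3))

    nngp = kernel_fn(x1, x2, 'nngp')                   # closed-form NNGP covariance between the batches
    ntk = kernel_fn(x1, x2, 'ntk')                     # closed-form NTK

With `linear_scaling=True` the `"SOFTMAX"` mechanism admits the closed-form `kernel_fn` used here; with `linear_scaling=False` only `"IDENTITY"` does, per the `Raises` section of the `GlobalSelfAttention` docstring in this patch.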