Unify normalization layers #3

Closed
seanmor5 opened this issue Feb 23, 2021 · 1 comment

seanmor5 commented Feb 23, 2021

The current API has 4 normalization layers:

  • Batch Normalization
  • Instance Normalization
  • Group Normalization
  • Layer Normalization

All of these implementations are built on the same fundamental formula:

defn normalize(input, mean, variance, gamma, bias, opts \\ []) do
  opts = keyword!(opts, epsilon: 1.0e-6)

  # scale = gamma / sqrt(variance + epsilon)
  scale =
    variance
    |> Nx.add(opts[:epsilon])
    |> Nx.rsqrt()
    |> Nx.multiply(gamma)

  # output = (input - mean) * scale + bias
  input
  |> Nx.subtract(mean)
  |> Nx.multiply(scale)
  |> Nx.add(bias)
end

They differ in how they compute the mean and variance across the input (see the sketch after this list):

  • Batch Normalization - calculated for each channel across all samples and all spatial dimensions.
    • reduction_axes: [:batch, :height, :width, ...]
  • Instance Normalization - calculated for each channel of each individual sample across the spatial dimensions.
    • reduction_axes: [:height, :width, ...]
  • Layer Normalization - calculated for each individual sample across all channels and spatial dimensions.
    • reduction_axes: [:channels, :height, :width, ...]
  • Group Normalization - calculated for each group of channels (of the given group size) across the spatial dimensions.
    • reduction_axes: [:groups, :height, :width, ...] (after some reshaping to get :groups)
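
As a concrete illustration, here is a minimal sketch of a shared mean_and_var helper driven by a list of reduction axes, and how its output feeds normalize/6. The :axes option and the keep_axes behavior are assumptions for the sketch, not a settled API:

defn mean_and_var(input, opts \\ []) do
  opts = keyword!(opts, [:axes])
  axes = opts[:axes]

  # Mean over the reduction axes, keeping them so the result
  # broadcasts back against the input.
  mean = Nx.mean(input, axes: axes, keep_axes: true)

  # Biased variance: mean of squared deviations over the same axes.
  diff = Nx.subtract(input, mean)
  variance = Nx.mean(Nx.multiply(diff, diff), axes: axes, keep_axes: true)

  {mean, variance}
end

For example, layer norm statistics for a {batch, channels, height, width} input would use axes [1, 2, 3], and the result plugs straight into normalize(input, mean, variance, gamma, bias).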

Additionally, some of these layers are stateful (batch/instance norm) and some are stateless (layer/group norm). Stateful normalization layers return the transformed input and a running average mean and variance adjusted with momentum, relying on the state to compute the next iteration of normalization. Stateless normalization layers return just the transformed input.
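
As a sketch of the stateful half of that contract (the function name and the momentum convention here are assumptions, not an existing API), each training step would fold the freshly computed statistics into the running averages:

defn update_running_stats(ra_mean, ra_var, mean, var, opts \\ []) do
  opts = keyword!(opts, momentum: 0.99)
  momentum = opts[:momentum]

  # running <- momentum * running + (1 - momentum) * current
  new_mean = Nx.add(Nx.multiply(ra_mean, momentum), Nx.multiply(mean, 1 - momentum))
  new_var = Nx.add(Nx.multiply(ra_var, momentum), Nx.multiply(var, 1 - momentum))

  {new_mean, new_var}
end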

To unify these normalization layers under the lower-level functional API, rather than having an individualized function for each layer we will instead have:

In the layers API:

  • normalize - see above

In a separate module:

  • batch_norm_stats(input, ra_mean, ra_var, opts \\ []) - returns {mean, var}
  • instance_norm_stats(input, ra_mean, ra_var, opts \\ []) - returns {mean, var}
  • group_norm_stats(input, opts \\ []) - returns {mean, var}
  • layer_norm_stats(input, opts \\ []) - returns {mean, var}
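As a sketch of the simplest of these (layer_norm_stats takes no running state), building on the mean_and_var helper sketched above; the default axes assume a rank-4 {batch, channels, height, width} input and are only illustrative:

defn layer_norm_stats(input, opts \\ []) do
  # Reduce over everything except the leading batch axis; a real
  # implementation would derive the axes from the input's rank.
  opts = keyword!(opts, axes: [1, 2, 3])
  mean_and_var(input, axes: opts[:axes])
end

The batch/instance variants would additionally thread ra_mean and ra_var through, per the signatures above.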

In a separate module (probably an updates.ex or something that has gradient/parameter transforms):

  • ema(x, momentum) - returns a scaled x, exponential moving average
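
Read literally ("returns a scaled x"), ema could be as small as the one-liner below, with the running-average update then expressed by the caller as ema(ra_mean, momentum) + ema(mean, 1 - momentum). This is an assumed reading of the proposal, not a settled signature:

defn ema(x, momentum), do: Nx.multiply(x, momentum)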

I think this limits code duplication and still enables us to easily build these normalization layers into a high-level API.

seanmor5 (Contributor, Author) commented

Mostly resolved here with shared normalize and mean_and_var functions: 5428ec2

ema will be addressed with other state management.
