
implement-batch-norm-layer #217

Merged (23 commits into ml-explore:main, Dec 25, 2023)
Conversation

@m0saan (Contributor) commented Dec 19, 2023

Proposed changes

Description

This pull request introduces an implementation of Batch Normalization, following the specification outlined in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

Changes Made

  • Added a new class BatchNorm1d that extends the Module class.
  • The implementation includes the forward pass logic for Batch Normalization (see the formula below).
  • Added configurable parameters such as eps (a small constant for numerical stability), momentum (for the running mean and variance updates), and affine (whether to include learnable affine parameters).
  • Provided examples in the documentation to demonstrate how to use the BatchNorm1d module with and without learnable parameters.
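
For reference, the normalization computed in the forward pass, as described in the cited paper, is

y = \gamma \cdot \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} + \beta

where \gamma and \beta are the learnable affine parameters (present only when affine=True) and \epsilon is the numerical stability constant.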

Usage

import mlx.core as mx
import mlx.nn as nn

# With Learnable Parameters
m = nn.BatchNorm1d(100)
# Without Learnable Parameters
m = nn.BatchNorm1d(4, affine=False)
input = mx.random.normal((20, 4))
output = m(input)

Notes

  • The implementation ensures compatibility with mlx conventions and practices.

Please review and provide feedback.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@m0saan m0saan marked this pull request as draft December 19, 2023 02:52
@m0saan m0saan marked this pull request as ready for review December 19, 2023 06:53
@m0saan m0saan requested a review from dc-dc-dc December 19, 2023 09:18
@m0saan (Contributor Author) commented Dec 19, 2023

Hey @awni, I've been thinking about how we should structure the batch normalization module. Do you think it's a good idea to have one class that covers all batch normalization types (like BatchNorm1d, BatchNorm2d, BatchNorm3d), or do we go down the road of having separate classes for each type? I'd love to know what you think!

@awni (Member) commented Dec 19, 2023

Adds #204 #216

@awni (Member) commented Dec 19, 2023

From an implementation standpoint, I think having a single BatchNorm with the ability to specify a tuple of axes is a good idea. I would also go with that from an API standpoint for now, since it's pretty clean and general (compared to adding N-D options).

@gboduljak has an implementation in #216

@m0saan (Contributor Author) commented Dec 19, 2023

From an implementation standpoint, I think having a single BatchNorm with the ability to specify a tuple of axes is a good idea. I would also go with that from an API standpoint for now, since it's pretty clean and general (compared to adding N-D options).

@gboduljak has an implementation in #216

I fully agree with the idea of a unified BatchNorm. It's a clean and versatile approach for both implementation and the API.

@m0saan (Contributor Author) commented Dec 19, 2023

@awni Here is an updated, more general version of BatchNorm:

from typing import Tuple

import mlx.core as mx
from mlx.nn.layers.base import Module

class BatchNorm(Module):
    def __init__(
        self,
        num_features: int,
        num_dims: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
    ):
        super().__init__()

        dims_dict = {
            2: ((1, num_features), (0,)),
            3: ((1, num_features, 1), (0, 2)),
            4: ((1, num_features, 1, 1), (0, 2, 3)),
        }

        if num_dims not in dims_dict:
            raise ValueError(f"expected num_dims to be 2, 3, or 4 (got {num_dims})")

        shape, self.reduction_axes = dims_dict[num_dims]
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        if self.affine:
            self.weight = mx.ones(shape)
            self.bias = mx.zeros(shape)

        if self.track_running_stats:
            self.running_mean = mx.zeros(shape)
            self.running_var = mx.ones(shape)

    def _extra_repr(self):
        return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"

    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        """
        Calculate the mean and variance of the input tensor.

        Args:
            x (mx.array): Input tensor.

        Returns:
            tuple: Tuple containing mean and variance.
        """

        
        means = mx.mean(x, axis=self.reduction_axes, keepdims=True)
        var = mx.var(x, axis=self.reduction_axes, keepdims=True)

        if self.track_running_stats and self.training:
            self.running_mean = (
                1 - self.momentum
            ) * self.running_mean + self.momentum * means
            self.running_var = (
                1 - self.momentum
            ) * self.running_var + self.momentum * var
        return means, var

    def __call__(self, x: mx.array):
        """
        Forward pass of BatchNorm.

        Args:
            x (mx.array): Input tensor.

        Returns:
            mx.array: Output tensor.
        """

        if self.training or not self.track_running_stats:
            means, var = self._calc_stats(x)
        else:
            means, var = self.running_mean, self.running_var
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x

and it can be used as follows:

batch_size = 4
num_features = 32
input = mx.random.normal((batch_size, num_features))
bn = BatchNorm(num_features=num_features, num_dims=2)
output = bn(input)

@m0saan (Contributor Author) commented Dec 19, 2023

We can remove the num_dims parameter by updating the implementation like so.

class BatchNorm(Module):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
    ):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        if self.affine:
            self.weight = mx.ones((num_features,))
            self.bias = mx.zeros((num_features,))

        if self.track_running_stats:
            self.running_mean = mx.zeros((num_features,))
            self.running_var = mx.ones((num_features,))

    def _extra_repr(self):
        return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
    
    def _check_and_expand_dims(self, x: mx.array):
        """
        Check whether the input is a 2D, 3D, or 4D tensor and expand the weight, bias, running mean, and running variance accordingly.

        Args:
            x (mx.array): Input tensor.
        """
        
        num_dims = len(x.shape)
        dims_dict = {
            2: ((1, self.num_features), (0,)),
            3: ((1, self.num_features, 1), (0, 2)),
            4: ((1, self.num_features, 1, 1), (0, 2, 3)),
        }

        if num_dims not in dims_dict:
            raise ValueError(f"expected num_dims to be 2, 3, or 4 (got {num_dims})")

        _, self.reduction_axes = dims_dict[num_dims]
        
        if self.affine and self.weight.ndim != num_dims:
            self.weight = mx.expand_dims(self.weight, self.reduction_axes)
            self.bias = mx.expand_dims(self.bias, self.reduction_axes)
        
        if self.track_running_stats and self.running_mean.ndim != num_dims:
            self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
            self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)

    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        """
        Calculate the mean and variance of the input tensor.

        Args:
            x (mx.array): Input tensor.

        Returns:
            tuple: Tuple containing mean and variance.
        """

        means = mx.mean(x, axis=self.reduction_axes, keepdims=True)
        var = mx.var(x, axis=self.reduction_axes, keepdims=True)

        if self.track_running_stats and self.training:
            self.running_mean = (
                1 - self.momentum
            ) * self.running_mean + self.momentum * means
            self.running_var = (
                1 - self.momentum
            ) * self.running_var + self.momentum * var
        return means, var

    def __call__(self, x: mx.array):
        """
        Forward pass of BatchNorm.

        Args:
            x (mx.array): Input tensor.

        Returns:
            mx.array: Output tensor.
        """
        
        self._check_and_expand_dims(x)

        if self.training or not self.track_running_stats:
            means, var = self._calc_stats(x)
        else:
            means, var = self.running_mean, self.running_var
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x
    

@gboduljak (Contributor) commented Dec 19, 2023

We can remove the num_dims parameter by updating the implementation like so. [quoted code from the previous comment omitted here]

@m0saan your final suggestion looks great. Your dims_dict captures commonly used scenarios well. However, I think more flexibility in the selection of reduction axes and feature axes would be beneficial; maybe we cannot always infer these from the shape of the array. To this end, we could keep reduction_axes, feature_axes, and num_features arguments in the constructor of BatchNorm, and extract your dims_dict outside of the BatchNorm class and use it in an enum.

Then we can use BatchNorm as follows,

bn = nn.BatchNorm(
    num_features=16,
    reduction_axes=BatchNormReductionAxes.DIMS_2,
    feature_axes=BatchNormFeatureAxes.DIMS_2,
)

We could implement some validations, e.g. an axis cannot be both a reduction axis and a feature axis. However, I am not sure this generality is necessary, and I am also not sure what the performance implications of using expand_dims are. Thus, a 'static' implementation like the one above may be faster.
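
A rough sketch of what such an enum could look like (purely illustrative; names like BatchNormReductionAxes, DIMS_4, and param_shape are hypothetical and not part of mlx):

from enum import Enum

class BatchNormReductionAxes(Enum):
    """Hypothetical enum mapping common input layouts to reduction axes."""

    DIMS_2 = (0,)       # input (N, C): normalize over the batch axis
    DIMS_3 = (0, 2)     # input (N, C, L): normalize over batch and length
    DIMS_4 = (0, 2, 3)  # input (N, C, H, W): normalize over batch and spatial axes

def param_shape(num_features: int, axes: BatchNormReductionAxes) -> tuple:
    """Build the broadcastable parameter shape for a given layout."""
    ndim = len(axes.value) + 1
    return tuple(1 if i in axes.value else num_features for i in range(ndim))

print(param_shape(16, BatchNormReductionAxes.DIMS_4))  # (1, 16, 1, 1)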

@gboduljak (Contributor) commented Dec 19, 2023

In this discussion post, I included some ideas for testing batch norm layers. I think it is important to verify that we match the PyTorch and/or JAX implementations. Maybe you can use these tests; I can also add them. Your tests look good as well, but we may want to test whether we are tracking the moving stats correctly. It would also be beneficial to test that BatchNorm behaves correctly in train/eval mode.
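
A rough sketch of such a test, using NumPy as the reference (it assumes the BatchNorm variant above that infers dims from the input, with its eps attribute and training-mode default; the shapes and tolerance are arbitrary):

import numpy as np
import mlx.core as mx

def check_batchnorm_against_numpy(bn_cls, num_features=8, batch=16, atol=1e-5):
    """Compare training-mode BatchNorm output against a NumPy reference."""
    x_np = np.random.randn(batch, num_features).astype(np.float32)
    x = mx.array(x_np)

    bn = bn_cls(num_features)  # weight=1 and bias=0 by default, so the affine step is a no-op
    y = np.array(bn(x))

    # Reference: normalize with the biased (ddof=0) batch statistics.
    mean = x_np.mean(axis=0, keepdims=True)
    var = x_np.var(axis=0, keepdims=True)
    y_ref = (x_np - mean) / np.sqrt(var + bn.eps)

    assert np.allclose(y, y_ref, atol=atol)

Similar checks could cover eval mode (compare against the tracked running_mean / running_var) and the running-stats update itself.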

@m0saan (Contributor Author) commented Dec 20, 2023

Hey @gboduljak, thanks a lot for your input! I really value your ideas on making the BatchNorm class more flexible. We would like to maintain the simplicity and user-friendliness of the framework, and while your proposed changes provide additional options, they may also introduce complexity that is unnecessary for many use cases. Maybe @awni has some thoughts on this too?

Regarding your point on the performance impact of using expand_dims, I get your concern. We're doing it just once to ensure all created arrays have shapes suitable for broadcasting. In my opinion, it's not a big issue, but I'd love to hear @awni's thoughts to make sure!

@m0saan (Contributor Author) commented Dec 20, 2023

In this discussion post, I included some ideas for testing batch norm layers. I think it is important to verify that we match the PyTorch and/or JAX implementations. Maybe you can use these tests; I can also add them. Your tests look good as well, but we may want to test whether we are tracking the moving stats correctly. It would also be beneficial to test that BatchNorm behaves correctly in train/eval mode.

I will incorporate these ideas into the testing of BN. If you have additional tests to add, feel free to include them, and we can collaborate to ensure comprehensive testing.

@rickypang0219 commented Dec 20, 2023

Could someone explain what the axes in dims_dict refer to? I know that with this dict we can generalise the BatchNorm class to higher dimensions, but I do not understand the meaning of (1, self.num_features), (1, self.num_features, 1), and (1, self.num_features, 1, 1).

dims_dict = {
            2: ((1, self.num_features), (0,)),
            3: ((1, self.num_features, 1), (0, 2)),
            4: ((1, self.num_features, 1, 1), (0, 2, 3)),
        }

Besides, if we generalise to an N-dimensional BatchNorm, would dims_dict look like this?

dims_dict = {
    N: ((1, self.num_features, *(1 for _ in range(N - 2))), tuple(i for i in range(N) if i != 1)),
}

@gboduljak (Contributor) commented Dec 20, 2023

Could someone explain what the axes in dims_dict refer to? [...]

Depending on the BatchNorm you want to implement (e.g. 1D, 2D), you want to normalize over different axes.
For example, in BatchNorm1D, your input is of shape [batch_dim, num_features] and you normalize over the batch axis (0). This means you also have num_features scale and shift parameters. To broadcast correctly when scaling and shifting your normalized inputs, your parameter shape is [1, num_features]. If you use BatchNorm2D, your input is of shape [batch_dim, num_features, height, width] and you want to normalize over all axes except the channel axis (axis 1). Thus, you normalize over axes (0, 2, 3). You also have num_features parameters, but you need to keep them in a tensor of shape [1, num_features, 1, 1] to broadcast correctly when scaling and shifting after normalization. Hope this helps :)

@gboduljak (Contributor) commented Dec 20, 2023

We can remove the num_dims parameter by updating the implementation like so. [quoted code from the earlier comment omitted here]

I just came up with a new idea to avoid repeatedly calling _check_and_expand_dims. In (very) large models, this may cause some overhead due to the cost of setting up a new call frame in Python. We could use the condition self.weight.ndim != num_dims within __call__ to determine whether it is necessary to expand dims; if it is, we can call a _setup_parameter_shape method implementing your logic. Alternatively, we may just include all the expansion logic in __call__. The same holds for _calc_stats: perhaps we can inline _calc_stats within __call__ to eliminate the function call.

@m0saan what do you think?

@m0saan (Contributor Author) commented Dec 22, 2023

Hello @gboduljak, I apologize for the delayed response. Using self.weight.ndim != num_dims may not be effective in all scenarios. Consider a situation where self.affine is set to False: in that case there is no weight parameter in the BN module, so the condition self.weight.ndim != num_dims cannot be evaluated. To address this, we could introduce an additional check, such as if self.affine and self.weight.ndim != num_dims. However, that check poses a challenge because it would prevent the reshaping of running_mean and running_var when affine is disabled.
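
As a small illustration of this point, using the BatchNorm class sketched earlier in the thread:

bn = BatchNorm(num_features=4, affine=False)

print("weight" in bn)  # False: no learnable scale/shift parameters were created

# A guard like `self.weight.ndim != num_dims` would therefore fail before the
# running statistics ever get reshaped.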

@m0saan (Contributor Author) commented Dec 22, 2023

@awni can you please review?

@awni (Member) commented Dec 22, 2023

Sorry for the delay in reviewing this, we were busy getting v0.0.6 out yesterday w/ quantization etc. I will get on this asap.

@m0saan (Contributor Author) commented Dec 22, 2023

Sorry for the delay in reviewing this, we were busy getting v0.0.6 out yesterday w/ quantization etc. I will get on this asap.

Sure, thanks. I have a question: while verifying that the stats calculated by the BatchNorm layer are correct, I noticed that the variance is slightly different from the one calculated by PyTorch:

[Screenshot: variance computed with the MLX BatchNorm vs. torch.var, values differ slightly]

but it is the same as NumPy:

[Screenshot: the MLX variance matches np.var]

@robertmccraith (Contributor) left a comment

Based on trying this layer in MIMM, using mimm/scripts/train.py to train ImageNet.

Review comments on python/mlx/nn/layers/normalization.py (outdated, resolved)
@awni (Member) left a comment

This is really nice! I have one request that I think will make it perfect, and then we can merge it.

Comment on lines 199 to 200
The input tensor shape is specified as (N, C) or (N, C, L), representing the batch size (N), the number of features or channels (C), and optionally, the sequence length (L). The output tensor maintains the same shape as the input, adhering to (N, C) or (N, C, L).
For three-dimensional tensors, the shape is denoted as (N, C, H, W), where N signifies the batch size, C represents the number of channels, H corresponds to the height, and W denotes the width.
awni (Member) commented:

This looks great!

I think one change can make it more consistent and a lot simpler:

For our convolutions (and in general) we follow the convention that the channels are last, so inputs to convolutions are NLC or NHWC. We should change two things:

  1. Batch norm should also follow that convention.
  2. Since it follows that convention, it will broadcast easily with the inputs, so you can remove the whole check_and_expand_dims machinery and just let broadcasting handle it (it's super cheap to expand dims at runtime, so from a perf perspective it should be trivial). See the sketch below.
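
For example, with channels last a plain (C,)-shaped parameter already lines up with the trailing axis of an NLC or NHWC input, so no reshaping is needed (a minimal illustration with arbitrary shapes):

import mlx.core as mx

x = mx.random.normal((8, 16, 16, 3))  # NHWC input with C = 3

weight = mx.ones((3,))   # per-channel scale, shape (C,)
bias = mx.zeros((3,))

# Statistics over the batch and spatial axes; shape (1, 1, 1, 3).
mean = mx.mean(x, axis=(0, 1, 2), keepdims=True)
var = mx.var(x, axis=(0, 1, 2), keepdims=True)

# (C,) broadcasts directly against the trailing channel axis of (N, H, W, C).
y = weight * (x - mean) * mx.rsqrt(var + 1e-5) + bias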

m0saan (Contributor Author) replied:

Done, I've updated the batch norm implementation and tests to handle inputs of shape NLC and NHWC!

m0saan and others added 17 commits December 24, 2023 23:08
Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
Update BatchNorm to support NLC and NHWC input formats

In our convolution operations, we follow the convention that the channels are the last dimension. This commit updates the BatchNorm implementation to support inputs where the channels are the last dimension (NLC or NHWC). This involves changing the dimensions we normalize over and the dimensions we expand our parameters over.

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
@awni (Member) left a comment

🚀 This looks awesome, thanks for adding it!

@awni (Member) commented Dec 25, 2023

@m0saan

torch.var uses a bias correction by default whereas MLX and NumPy do not; that is why you see the slight difference. I think it is the right call for now to use the uncorrected variance in our BN, as I believe PyTorch also uses an uncorrected variance in its normalization layers.
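
To make the difference concrete (a minimal comparison; it assumes torch and numpy are installed alongside mlx):

import numpy as np
import torch
import mlx.core as mx

x = np.random.randn(20, 4).astype(np.float32)

print(torch.var(torch.tensor(x), dim=0))                  # unbiased: divides by N - 1
print(torch.var(torch.tensor(x), dim=0, unbiased=False))  # matches NumPy and MLX
print(np.var(x, axis=0))                                  # ddof=0 by default
print(mx.var(mx.array(x), axis=0))                        # ddof=0 by default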

@awni (Member) commented Dec 25, 2023

PS @gboduljak, @dc-dc-dc, @robertmccraith thanks for the extra reviews / discussion!

PS @robertmccraith I'm following mimm eagerly, keep us posted on how it's going and what else you need to get it fully operational!

@awni awni merged commit a123c3c into ml-explore:main Dec 25, 2023
@m0saan (Contributor Author) commented Dec 25, 2023

Thanks @awni for your input!

awni added a commit that referenced this pull request Dec 25, 2023
- Add batch normalization layer

---------

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Jyun1998 pushed a commit to Jyun1998/mlx that referenced this pull request Jan 7, 2024
- Add batch normalization layer

---------

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>