
Request for Implementation of BatchNorm #204

Closed

rickypang0219 opened this issue Dec 18, 2023 · 4 comments
Labels: enhancement (New feature or request)

Comments

@rickypang0219

Hi all. I am new to AI and deep learning and I am fascinated by this package. When I looked at the documentation, I found that there are currently RMS/Group/Layer normalization layers, but no BatchNorm. I hope a BatchNorm layer can be added, since it helps reduce internal covariate shift.

rickypang0219 (Author) commented Dec 18, 2023

Hi, I tried to implement 1D BatchNorm in MLX and ran a comparison test against PyTorch. Because the Layer Norm and Batch Norm equations are so similar, I just changed the reduction axis from axis=-1 to axis=0.

import mlx.core as mx
from mlx.nn.layers.base import Module

class MLX_BatchNorm1d(Module): 
    r"""Applies layer normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively.

    """
    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x):
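        # Training-mode batch norm: normalize each feature using statistics
        # computed over the batch axis (axis=0).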
        means = mx.mean(x, axis=0, keepdims=True)
        var = mx.var(x, axis=0, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x

Then I used the PyTorch example for 1D BatchNorm to compare the results.

import torch
import torch.nn as nn

# With learnable parameters
torch.manual_seed(0)
m = nn.BatchNorm1d(100)
# Without learnable parameters (this is the module used below)
m = nn.BatchNorm1d(100, affine=False)
input = torch.randn(20, 100)
output = m(input)

# MLX / Torch comparison
m2 = MLX_BatchNorm1d(100, affine=False)
input2 = mx.array(input.numpy())
output2 = m2(input2)

And the result matches.
[screenshot: printed MLX and PyTorch outputs showing matching values]
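
For a numerical check (instead of eyeballing the printed arrays), something like the following should work, assuming the variables from the snippets above:

import numpy as np

# Compare the MLX and PyTorch outputs element-wise.
print(np.allclose(np.array(output2), output.detach().numpy(), atol=1e-5))  # True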

awni added the enhancement label on Dec 19, 2023
awni (Member) commented Dec 19, 2023

Nice! Just FYI there is a PR open for BatchNorm #217

dastrobu (Contributor) commented

This could possibly be closed now that #217 is merged.

awni (Member) commented Dec 31, 2023

Oh great! Thanks!

awni closed this as completed on Dec 31, 2023