Request for Implementation of BatchNorm #204
Labels: enhancement (New feature or request)

Hi all. I am new to AI and Deep Learning, and I am fascinated by this package. Looking through the documentation, I found that there are currently RMS/Group/Layer normalization layers, but no BatchNorm. I hope a BatchNorm layer can be added, since it helps reduce internal covariate shift in deep learning models.

Comments
Hi, I tried to implement the 1D BatchNorm in MLX and ran a comparison test against PyTorch. Due to the similarity of the equations for Layer Norm and Batch Norm, I just changed the axis for summation from the feature axis (axis=-1, as in LayerNorm) to the batch axis (axis=0):

```python
import mlx.core as mx
from mlx.nn.layers.base import Module


class MLX_BatchNorm1d(Module):
    r"""Applies batch normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively.
    """

    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x):
        # Normalize over the batch axis (axis=0) instead of the feature
        # axis (axis=-1) that LayerNorm uses.
        means = mx.mean(x, axis=0, keepdims=True)
        var = mx.var(x, axis=0, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x
```

Then I used the PyTorch example for 1D BatchNorm to compare the results:

```python
import torch
import torch.nn as nn

# With learnable parameters
torch.manual_seed(0)
m = nn.BatchNorm1d(100)
# Without learnable parameters
m = nn.BatchNorm1d(100, affine=False)
input = torch.randn(20, 100)
output = m(input)

# MLX / PyTorch comparison
m2 = MLX_BatchNorm1d(100, affine=False)
input2 = mx.array(input.numpy())
output2 = m2(input2)
```
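To check the two outputs numerically, here is a minimal sketch; the NumPy round-trip and the tolerance are my own choices, not part of the original comment:

```python
import numpy as np

# Both layers run with affine=False and batch statistics (training mode),
# so the outputs should agree up to floating-point error.
# atol is a judgment call.
assert np.allclose(np.array(output2), output.detach().numpy(), atol=1e-5)
```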
Nice! Just FYI, there is a PR open for BatchNorm: #217
This could possibly be closed, as #217 is merged.
Oh great! Thanks!
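For readers arriving after the merge: the layer landed as `mlx.nn.BatchNorm`. A minimal usage sketch, assuming the signature at the time of the merge (check the current MLX docs, as it may have changed):

```python
import mlx.core as mx
import mlx.nn as nn

bn = nn.BatchNorm(num_features=100)  # tracks running stats by default
x = mx.random.normal((20, 100))
y = bn(x)  # normalized over the batch axis in training mode
```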