In [2]:
import torch
%load_ext autoreload
!pwd

/Users/magnus/repos/dl-by-doing/dev


In [83]:
# Toy mini-batch: 4 samples, each a 2x3 grid of features, drawn from a
# standard normal and then rescaled to std 2 and shifted to mean 1.
batch_size = 4
dim = (2, 3)
batch = 1 + 2 * torch.randn(batch_size, *dim)
batch.shape

torch.Size([4, 2, 3])

In [84]:
# Batch normalization by hand, following the batch-norm paper
# (Ioffe & Szegedy, 2015). `batch` plays the role of x in the paper.
# Statistics are taken over axis 0 (the mini-batch axis), so every feature
# position gets its own mean and variance.

# mu: per dimension mean
mu = batch.mean(dim=0, keepdim=True)

# var: per dimension variance. The paper normalizes with the biased (1/m)
# estimate; Tensor.var applies Bessel's correction by default, so disable it.
var = batch.var(dim=0, keepdim=True, unbiased=False)

# x_hat: per dimension normalization, eps: for numerical stability
eps = 1e-10
x_hat = (batch - mu) / torch.sqrt(var + eps)

# scale and shift output (random gamma/beta, broadcast over the batch axis)
gamma = torch.randn(dim)
beta = torch.randn(dim)
y = gamma * x_hat + beta

In [86]:
# Sanity check: each feature position of x_hat should average ~0 over the batch axis.
torch.mean(x_hat, dim=0, keepdim=True)

tensor([[[-5.9605e-08,  1.4901e-08,  0.0000e+00],
         [ 0.0000e+00, -1.8626e-08, -1.7881e-07]]])

In [185]:
from torch import nn


class BatchNorm(nn.Module):
    """Batch normalization over the leading (mini-batch) axis.

    Training mode (``self.training`` is True): the input is normalized with
    the statistics of the current mini-batch, and those statistics are folded
    into a cumulative (simple) average over all batches seen so far.

    Eval mode: the accumulated population statistics are used instead of the
    batch statistics, and nothing is updated.

    Args:
        dim: Shape of one sample's features (an int or a tuple of ints).
            Inputs to ``forward`` are expected to have shape ``(batch, *dim)``.
        eps: Small constant added to the variance for numerical stability.

    Buffers (saved in ``state_dict`` and moved by ``.to()``):
        cum_mu: Cumulative average of the per-feature batch means.
        cum_var: Cumulative average of the per-feature unbiased batch variances.
        m: Number of training batches folded into the averages so far.
    """

    def __init__(self, dim, eps=1e-05):
        super().__init__()
        self.eps = eps
        # Learnable scale and shift, initialised to the identity transform.
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

        # Population statistics. Starting at mean 0 / var 1 makes an untrained
        # layer in eval mode behave (almost) like the identity; the initial
        # values are overwritten by the first training batch because m == 0.
        self.register_buffer("cum_mu", torch.zeros(dim))
        self.register_buffer("cum_var", torch.ones(dim))
        self.register_buffer("m", torch.tensor(0, dtype=torch.long))

    def forward(self, x):
        """Normalize ``x`` and apply the learned scale and shift.

        Args:
            x: Tensor whose first axis is the mini-batch axis.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # `self.training` is the mode flag toggled by .train() / .eval();
        # `self.train` is the method itself and would always be truthy.
        if self.training:
            mu = x.mean(dim=0, keepdim=True)
            # Biased (1/n) variance is what the normalization step uses.
            var = x.var(dim=0, keepdim=True, unbiased=False)
            x_hat = (x - mu) / torch.sqrt(var + self.eps)
            y = self.gamma * x_hat + self.beta

            # Accumulate population statistics while training only.
            self._update_cum_mu(x)
            self._update_cum_var(x)
            self.m += 1
        else:
            # Inference: fold normalization and affine transform into a single
            # per-feature scale and offset built from the population statistics.
            sigma = torch.sqrt(self.cum_var + self.eps)
            y = self.gamma / sigma * x + (self.beta - self.gamma * self.cum_mu / sigma)

        return y

    @torch.no_grad()
    def _update_cum_mu(self, x):
        """Fold the mean of batch ``x`` into the cumulative average ``cum_mu``."""
        batch_mu = x.mean(dim=0)
        self.cum_mu = (batch_mu + self.m * self.cum_mu) / (self.m + 1)

    @torch.no_grad()
    def _update_cum_var(self, x):
        """Fold the unbiased variance of batch ``x`` into ``cum_var``.

        The n / (n - 1) factor is Bessel's correction, where n is the batch
        size (not the number of batches seen). A single-sample batch has no
        unbiased estimate, so its biased variance (zero) is used as-is.
        """
        n = x.shape[0]
        bessel = n / (n - 1) if n > 1 else 1.0
        batch_var = bessel * x.var(dim=0, unbiased=False)
        self.cum_var = (batch_var + self.m * self.cum_var) / (self.m + 1)




In [186]:
# Larger 2-D input for comparing implementations: 4 samples x 100k features,
# rescaled to std 2 and mean 1, plus a hand-written layer sized to match.
batch_size = 4
dim = 100_000
batch = 1 + 2 * torch.randn(batch_size, dim)

bn = BatchNorm(dim)

In [187]:
# Output of the hand-written layer; it should keep the input's shape.
y1 = bn(batch)
y1.size()

torch.Size([4, 100000])

In [188]:
# Reference layer from PyTorch. momentum=None selects a cumulative (simple)
# moving average for the running statistics. momentum=0 would put zero weight
# on every new batch and leave them frozen at their initial values.
# The training-mode output y2 does not depend on this setting.
torch_bn = torch.nn.BatchNorm1d(num_features=dim, momentum=None, track_running_stats=True)
y2 = torch_bn(batch)
y2.shape

torch.Size([4, 100000])

In [189]:
# Per-sample mean over the feature axis of the PyTorch output (note: mu1 is derived from y2).
mu1 = torch.mean(y2, dim=1, keepdim=True)

In [190]:
# Per-sample mean over the feature axis of the hand-written layer's output (note: mu2 is derived from y1).
mu2 = torch.mean(y1, dim=1, keepdim=True)

In [191]:
# Per-sample biased variance over the feature axis of the PyTorch output.
var2 = torch.var(y2, dim=1, keepdim=True, unbiased=False)

In [192]:
# Per-sample biased variance over the feature axis of the hand-written layer's output.
var1 = torch.var(y1, dim=1, keepdim=True, unbiased=False)

In [193]:
# Do the per-sample variances of the two outputs agree within 0.1% relative tolerance?
var1.allclose(var2, rtol=1e-3)

True

In [194]:
# Do the per-sample means of the two outputs agree within 0.1% relative tolerance?
mu1.allclose(mu2, rtol=1e-3)

True