In [1]:
import torch
import torch.nn as nn

In [2]:
# Shapes shared by the reference nn.BatchNorm1d and the custom layer: batches are (batch_size, input_size).
input_size = 3
batch_size = 5
# Much larger than nn.BatchNorm1d's default (1e-5); presumably chosen so eps visibly affects the
# outputs being compared — confirm. Note: the original CustomBatchNorm1d reads this global directly.
eps = 1e-1

# Seed torch's global RNG so the random weights and input batches drawn below are reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [3]:
class CustomBatchNorm1d:
    """Hand-written re-implementation of ``torch.nn.BatchNorm1d`` for (batch, features) input.

    Statistics are computed along dim 0 only, so this mirrors nn.BatchNorm1d for 2-D input;
    3-D (N, C, L) input is not handled the way PyTorch handles it.

    Args:
        weight: per-feature scale (gamma); presumably a 1-D tensor of length
            ``num_features`` — confirm against callers.
        bias: per-feature shift (beta), same shape as ``weight``.
        eps: value added to the variance before taking the square root.
        momentum: weight given to the *current* batch when updating the running
            statistics, following PyTorch's convention
            ``running = (1 - momentum) * running + momentum * batch_stat``.
    """

    def __init__(self, weight, bias, eps, momentum):
        self.gamma = weight
        self.beta = bias
        self.eps = eps
        self.momentum = momentum
        # Running statistics start at mean 0 / variance 1 like nn.BatchNorm1d. They are 0-dim
        # tensors (not Python ints) so eval() before any training still works with torch.sqrt;
        # they broadcast until the first training batch replaces them with per-feature tensors.
        self.exp_avg = torch.tensor(0.0)
        self.exp_var = torch.tensor(1.0)
        # False -> training mode (batch statistics); True -> eval mode (running statistics).
        self.click = False

    def __call__(self, input_tensor):
        """Normalize ``input_tensor`` along dim 0, then apply ``gamma`` / ``beta``.

        In training mode the current batch's statistics are used and the running
        estimates are updated; in eval mode the running estimates are used.

        Raises:
            ValueError: in training mode when the batch has fewer than 2 rows, since the
                unbiased variance stored in the running estimate is undefined (n - 1 == 0).
        """
        if self.click is False:
            n = input_tensor.shape[0]
            if n < 2:
                raise ValueError(
                    "Expected more than 1 value per channel when training, "
                    f"got input size {tuple(input_tensor.shape)}"
                )
            mean = input_tensor.mean(dim=0)
            # The output is normalized with the biased (population) variance...
            var = input_tensor.var(dim=0, unbiased=False)
            # ...while the running estimate stores the unbiased one, as nn.BatchNorm1d does.
            # Running stats are bookkeeping, not part of the forward graph: don't let them
            # accumulate autograd history across batches.
            with torch.no_grad():
                self.exp_avg = (1 - self.momentum) * self.exp_avg + self.momentum * mean
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var * n / (n - 1)
        else:
            mean, var = self.exp_avg, self.exp_var

        return (input_tensor - mean) / torch.sqrt(var + self.eps) * self.gamma + self.beta

    def eval(self):
        """Switch to eval mode (use running statistics). Returns ``self`` for chaining."""
        self.click = True
        return self

    def train(self, mode=True):
        """Switch to training mode (``mode=True``) or eval mode (``mode=False``). Returns ``self``."""
        self.click = not mode
        return self

In [4]:
# Reference layer, configured up front. A momentum of 0.5 weights the running estimate and the
# current batch equally.
batch_norm = nn.BatchNorm1d(input_size, eps=eps, momentum=0.5)

# Replace the default affine parameters (ones / zeros) with random values so the affine step is
# actually exercised. Bias is drawn before weight to keep the RNG consumption order.
with torch.no_grad():
    batch_norm.bias.copy_(torch.randn(input_size, dtype=torch.float))
    batch_norm.weight.copy_(torch.randn(input_size, dtype=torch.float))

# The custom layer shares the reference layer's parameter storage and hyper-parameters.
custom_batch_norm1d = CustomBatchNorm1d(
    weight=batch_norm.weight.data,
    bias=batch_norm.bias.data,
    eps=eps,
    momentum=batch_norm.momentum,
)

In [5]:
all_correct = True

for i in range(8):
    torch_input = torch.randn(batch_size, input_size, dtype=torch.float)
    all_correct &= torch.allclose(batch_norm(torch_input), custom_batch_norm1d(torch_input))
    
batch_norm.eval()
custom_batch_norm1d.eval()

for i in range(8):
    torch_input = torch.randn(batch_size, input_size, dtype=torch.float)
    all_correct &= torch.allclose(batch_norm(torch_input), custom_batch_norm1d(torch_input))

print(all_correct)

True
