In [3]:
import os
import torch
import numpy as np
import pandas as pd

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Lambda

import matplotlib.pyplot as plt

In [4]:
# Pick the compute device and report what is available.
available = torch.cuda.is_available()
device = torch.device("cuda:0" if available else "cpu")
device_count = torch.cuda.device_count()  # 0 when no CUDA device is present

# current_device() and get_device_name() raise when CUDA cannot be
# initialised, so only query them when a CUDA device is actually available.
# This keeps the cell runnable on CPU-only machines, matching the CPU
# fallback used for `device` above.
curr_device = torch.cuda.current_device() if available else None
device_name = torch.cuda.get_device_name(0) if available else None

print(f'Cuda available: {available}')
print(f'Current device: {curr_device}')
print(f'Device: {device}')
print(f'Device count: {device_count}')
print(f'Device name: {device_name}')

#device = torch.device("cpu")

Cuda available: True
Current device: 0
Device: cuda:0
Device count: 1
Device name: GeForce GTX 1070


In [157]:
class MyBatchNorm2d(nn.modules.batchnorm._NormBase):
    '''Hand-written 2D batch normalisation without learnable affine parameters.

    Follows the behaviour of ``nn.BatchNorm2d(num_features, affine=False)``:
    in training mode the input is normalised with the per-channel batch mean
    and *biased* batch variance, and the running statistics are updated
    (``running_var`` with the *unbiased* batch variance). In eval mode the
    stored running statistics are used instead.

    Unlike ``nn.BatchNorm2d``, ``forward`` returns a tuple ``(y, denom)``.

    Partially based on:
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d
        https://discuss.pytorch.org/t/implementing-batchnorm-in-pytorch-problem-with-updating-self-running-mean-and-self-running-var/49314/5

    Args:
        num_features: number of channels C of the expected (N, C, H, W) input.
        eps: value added to the variance before the square root.
        momentum: weight of the current batch in the running-statistics
            update. ``None`` selects a cumulative (simple) moving average,
            as in ``nn.BatchNorm2d``.
        device: forwarded to the base-class constructor.
        dtype: forwarded to the base-class constructor.
    '''
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        device=None,
        dtype=None
    ):
        # affine is fixed to False and running statistics are always tracked.
        factory_kwargs = {'device': device, 'dtype': dtype, 'affine': False, 'track_running_stats': True}
        super().__init__(num_features, eps, momentum, **factory_kwargs)

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 4-dimensional."""
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))

    def forward(self, input):
        """Normalise ``input`` per channel.

        Args:
            input: 4D tensor; dimension 1 must have ``num_features`` entries.

        Returns:
            Tuple ``(y, denom)`` where ``y`` is the normalised input and
            ``denom`` is ``var + eps`` expanded to the shape of ``input``
            (i.e. the value *before* the square root is taken).

        Raises:
            ValueError: if ``input`` is not 4-dimensional.
        """
        self._check_input_dim(input)

        if self.training:
            reduce_dims = [0, 2, 3]  # every dimension except the channel one
            # Biased batch statistics are used for the normalisation itself
            # and stay part of the autograd graph.
            var, mean = torch.var_mean(input, reduce_dims, unbiased=False)

            # The running statistics are bookkeeping, not part of the model's
            # differentiable computation. Updating them under no_grad keeps
            # the buffers free of autograd history; otherwise they would
            # become graph-attached tensors (requires_grad=True) and keep the
            # graph of every batch alive.
            with torch.no_grad():
                self.num_batches_tracked.add_(1)
                if self.momentum is None:
                    # Cumulative moving average, as nn.BatchNorm2d does.
                    factor = 1.0 / float(self.num_batches_tracked)
                else:
                    factor = self.momentum

                # PyTorch normalises with the biased variance but accumulates
                # the *unbiased* variance into running_var:
                # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L190
                unbiased_var = torch.var(input, reduce_dims, unbiased=True)

                self.running_mean.copy_(
                    (1.0 - factor) * self.running_mean + factor * mean
                )
                self.running_var.copy_(
                    (1.0 - factor) * self.running_var + factor * unbiased_var
                )
        else:
            mean = self.running_mean
            var = self.running_var

        # Reshape the per-channel statistics so they line up with the input.
        stat_shape = [1, self.num_features, 1, 1]
        current_mean = mean.view(stat_shape).expand_as(input)
        current_var = var.view(stat_shape).expand_as(input)

        denom = current_var + self.eps
        y = (input - current_mean) / denom.sqrt()

        return y, denom


In [158]:
# Compare MyBatchNorm2d against nn.BatchNorm2d (affine=False) on random input,
# first in training mode and then in eval mode.

# Original tolerance; it is below float32 resolution (~1e-7 for values near 1),
# so it is not used for the element-wise checks below.
eps = 1e-8
# Tolerances suited to float32 element-wise comparison.
RTOL, ATOL = 1e-5, 1e-6


def assert_bn_match(step, ref_out, my_out, ref_bn, my_bn):
    """Assert that outputs and running statistics agree on EVERY element.

    The previous checks used ``torch.any(diff < eps)``, which passes as soon
    as a single element agrees and therefore cannot detect a mismatch in the
    remaining elements.
    """
    assert torch.allclose(ref_out, my_out, rtol=RTOL, atol=ATOL), (step, ref_out, my_out)
    assert torch.allclose(ref_bn.running_mean, my_bn.running_mean, rtol=RTOL, atol=ATOL), (
        step, ref_bn.running_mean, my_bn.running_mean)
    assert torch.allclose(ref_bn.running_var, my_bn.running_var, rtol=RTOL, atol=ATOL), (
        step, ref_bn.running_var, my_bn.running_var)


torch.manual_seed(0)  # make the random inputs reproducible

# Training mode: fresh modules each iteration, one running-stat update each.
for i in range(10):
    input = torch.randn(2, 3, 2, 2)

    # Without Learnable Parameters
    m = nn.BatchNorm2d(3, affine=False, momentum=0.5)
    output = m(input)

    m2 = MyBatchNorm2d(3, momentum=0.5)
    output2, _ = m2(input)

    assert_bn_match(i, output, output2, m, m2)

# Eval mode: reuse the last pair of modules from the training loop, so both
# normalise with the running statistics accumulated there.
m.eval()
m2.eval()
for i in range(10):
    input = torch.randn(2, 3, 2, 2)

    output = m(input)
    output2, _ = m2(input)

    assert_bn_match(i, output, output2, m, m2)

# 2022-02-21

* Strange: PyTorch's implementation normalizes each batch with the *biased* batch variance, but
  updates `running_var` with the *unbiased* batch variance!
* https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L190
* Anyway, I used their method of estimating the running variance to get 1e-8 matching precision with their implementation.