In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [152]:
class BatchNormFN(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, gamma, beta):

        mu = x.mean(dim=(0, 2, 3), keepdim=True)

        xmu = x - mu
        sq = xmu ** 2

        var = sq.mean(dim=(0, 2, 3), keepdim=True)

        sqrtvar = torch.sqrt(var + 1e-5)

        ivar = 1.0 / sqrtvar

        xhat = xmu * ivar

        gammax = gamma * xhat

        out = gammax + beta

        ctx.save_for_backward(xhat, gamma, xmu, ivar, sqrtvar, var)

        return out

    @staticmethod
    def backward(ctx, dout):
        xhat, gamma, xmu, ivar, sqrtvar, var = ctx.saved_tensors

        dx = dgamma = dbeta = None

        dbeta = dout.sum(dim=(0, 2, 3), keepdim=True)

        dgammax = dout

        dgamma = torch.sum(dgammax * xhat, dim=(0, 2, 3), keepdim=True)
        dxhat = dgammax * gamma

        divar = torch.sum(dxhat * xmu, dim=(0, 2, 3), keepdim=True)
        dxmu1 = dxhat * ivar

        dsqrtvar = -1.0 / (sqrtvar ** 2) * divar

        dvar = 0.5 * 1.0 / torch.sqrt(var + 1e-5) * dsqrtvar

        dsq = 1.0 / (dout.shape[0] * dout.shape[2] * dout.shape[3]) * torch.ones_like(dout) * dvar

        dxmu2 = 2.0 * xmu * dsq

        dx1 = dxmu1 + dxmu2
        dmu = -1.0 * torch.sum(dxmu1 + dxmu2, dim=(0, 2, 3), keepdim=True)

        dx2 = 1.0 / (dout.shape[0] * dout.shape[2] * dout.shape[3]) * torch.ones_like(dout) * dmu

        dx = dx1 + dx2

        return dx, dgamma, dbeta

In [153]:
class BatchNormAutograd(nn.Module):
    def __init__(self, num_features):
        super(BatchNormAutograd, self).__init__()

        shape = (1, num_features, 1, 1)

        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))

        device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        self.moving_mean = torch.zeros(shape, device=device)
        self.moving_var = torch.zeros(shape, device=device)
    
    def forward(self, x):
        if self.training:
            mean = x.mean(dim=(0, 2, 3), keepdim=True)
            var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)

            self.moving_mean = 0.9 * self.moving_mean * (1 - 0.9) * mean

            self.moving_var = 0.9 * self.moving_var * (1 - 0.9) * var
        
            out = BatchNormFN.apply(x, self.gamma, self.beta)
        else:
            out = (x - self.moving_mean) / torch.sqrt(self.moving_var + 1e-5)

        return out

In [154]:
bn = BatchNormFN.apply

In [155]:
x = torch.rand((1, 2, 4, 4), requires_grad=True)
gamma = torch.ones((1, 2, 1, 1), requires_grad=True)
beta = torch.zeros((1, 2, 1, 1), requires_grad=True)

In [156]:
out = bn(x, gamma, beta)

In [157]:
temp = out.sum()
temp.backward()

In [158]:
x.mean(), x.var()

(tensor(0.5352, grad_fn=<MeanBackward0>),
 tensor(0.0985, grad_fn=<VarBackward0>))

In [159]:
out.mean(), out.var()

(tensor(-1.1176e-07, grad_fn=<MeanBackward0>),
 tensor(1.0321, grad_fn=<VarBackward0>))

In [160]:
from torch.autograd import gradcheck

# input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))

x = torch.rand((1, 2, 4, 4), dtype=torch.double, requires_grad=True)
gamma = torch.ones((1, 2, 1, 1), dtype=torch.double, requires_grad=True)
beta = torch.zeros((1, 2, 1, 1), dtype=torch.double, requires_grad=True)

input = (x, gamma, beta)
test = gradcheck(bn, input, eps=1e-6, atol=1e-4)
print(test)

True


In [161]:
mod = BatchNorm(2)

In [168]:
x = torch.rand((1, 2, 4, 4), dtype=torch.double, requires_grad=True)

In [169]:
out = mod(x)

In [170]:
out.mean(), out.var()

(tensor(-4.8572e-17, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(1.0321, dtype=torch.float64, grad_fn=<VarBackward0>))