In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from torchvision.models.resnet import ResNet, BasicBlock

In [16]:
# from: http://cthorey.github.io./backpropagation/
class BatchNormFN2(torch.autograd.Function):
    """Training-mode batch normalization with a hand-written backward pass.

    Statistics are computed per channel (dim 1) by reducing over every other
    dimension, using the biased (divide-by-count) variance. ``gamma`` and
    ``beta`` must broadcast against ``x`` with the channel on dim 1, e.g.
    shape ``(1, C, 1, 1)`` for an ``(N, C, H, W)`` input. Their gradients are
    returned with the reduced dimensions kept (size 1).

    Call it through ``BatchNormFN2.apply(x, gamma, beta)``. An optional fourth
    positional argument overrides the numerical stabilizer ``eps``
    (default ``1e-5``).
    """

    @staticmethod
    def _normalize(x, eps):
        """Return ``(reduce_dims, centered, std)`` for ``x``.

        ``centered`` is ``x`` minus its per-channel mean and ``std`` is
        ``sqrt(biased_var + eps)``; both keep the reduced dims as size 1.
        """
        # Every dimension except the channel dimension (dim 1).
        reduce_dims = tuple(d for d in range(x.dim()) if d != 1)
        mu = x.mean(dim=reduce_dims, keepdim=True)
        centered = x - mu
        var = (centered ** 2).mean(dim=reduce_dims, keepdim=True)
        return reduce_dims, centered, torch.sqrt(var + eps)

    @staticmethod
    def forward(ctx, x, gamma, beta, eps=1e-5):
        """Compute ``gamma * xhat + beta`` where ``xhat`` is the normalized ``x``.

        Args:
            ctx: autograd context used to pass state to ``backward``.
            x: input tensor with channels on dim 1.
            gamma: scale, broadcastable against ``x``.
            beta: shift, broadcastable against ``x``.
            eps: value added to the variance before the square root.

        Returns:
            The normalized, scaled and shifted tensor.
        """
        _, centered, std = BatchNormFN2._normalize(x, eps)
        xhat = centered / std

        # Only inputs are saved; backward recomputes the normalization from x
        # instead of stashing intermediates computed here.
        ctx.save_for_backward(x, gamma)
        ctx.eps = eps

        return gamma * xhat + beta

    @staticmethod
    def backward(ctx, dout):
        """Return gradients for ``(x, gamma, beta, eps)`` given ``dout``.

        ``eps`` is not differentiable, so its slot is always ``None``.
        Gradients for inputs that do not require grad are returned as ``None``.
        """
        x, gamma = ctx.saved_tensors
        need_dx, need_dgamma, need_dbeta = ctx.needs_input_grad[:3]

        reduce_dims, centered, std = BatchNormFN2._normalize(x, ctx.eps)
        xhat = centered / std

        # Number of elements that contribute to each channel's statistics.
        count = 1
        for d in reduce_dims:
            count *= x.shape[d]

        dout_sum = dout.sum(dim=reduce_dims, keepdim=True)
        dout_xhat_sum = (dout * xhat).sum(dim=reduce_dims, keepdim=True)

        dx = dgamma = dbeta = None
        if need_dx:
            # Closed-form batch-norm input gradient:
            # gamma / (count * std) * (count * dout - sum(dout) - xhat * sum(dout * xhat))
            dx = (gamma / (count * std)) * (count * dout - dout_sum - xhat * dout_xhat_sum)
        if need_dgamma:
            dgamma = dout_xhat_sum
        if need_dbeta:
            dbeta = dout_sum

        return dx, dgamma, dbeta, None

In [17]:
# Functional handle to the custom op; autograd Functions are invoked through `.apply`.
# The class defined in this notebook is `BatchNormFN2` (no `BatchNormFN` exists on a fresh kernel).
bn = BatchNormFN2.apply

In [18]:
# Sample tensors for the demo: one 2-channel 4x4 input drawn uniformly from [0, 1),
# plus an identity affine transform (scale 1, shift 0). All three are autograd leaves.
x = torch.rand(1, 2, 4, 4).requires_grad_()
gamma = torch.ones(1, 2, 1, 1).requires_grad_()
beta = torch.zeros(1, 2, 1, 1).requires_grad_()

In [19]:
# Forward pass: call `bn` (bound above to an autograd Function's `.apply`) on the sample tensors.
out = bn(x, gamma, beta)

In [20]:
# Reduce to a scalar so backward() needs no explicit upstream gradient.
temp = out.sum()
# Runs the backward pass; x, gamma and beta were created with requires_grad=True,
# so their .grad attributes are populated here.
temp.backward()

In [21]:
# Mean and variance of the raw input over all elements.
# Tensor.var() applies Bessel's correction (unbiased estimator) by default.
x.mean(), x.var()

(tensor(0.5422, grad_fn=<MeanBackward0>),
 tensor(0.0944, grad_fn=<VarBackward0>))

In [22]:
# Mean and variance of the normalized output over all elements.
# The op normalizes each channel with the biased variance (a plain mean of squared
# deviations), while Tensor.var() is unbiased by default, so with 32 elements a value
# near 32/31 ~= 1.032 rather than exactly 1 is expected.
out.mean(), out.var()

(tensor(0., grad_fn=<MeanBackward0>), tensor(1.0321, grad_fn=<VarBackward0>))

In [23]:
from torch.autograd import gradcheck

# gradcheck compares analytical gradients with finite differences, so the inputs
# are rebuilt in double precision (rebinding x, gamma and beta from the cells above).
x = torch.rand((1, 2, 4, 4), dtype=torch.double, requires_grad=True)
gamma = torch.ones((1, 2, 1, 1), dtype=torch.double, requires_grad=True)
beta = torch.zeros((1, 2, 1, 1), dtype=torch.double, requires_grad=True)

# Named `gradcheck_inputs` rather than `input` so the builtin input() is not
# shadowed for the rest of the kernel session.
gradcheck_inputs = (x, gamma, beta)
test = gradcheck(bn, gradcheck_inputs, eps=1e-6, atol=1e-4)
print(test)

True


In [None]:
# NOTE(review): `ActivatedBatchNormAutograd` is not defined or imported in any cell
# visible above — confirm where it comes from before running this cell.
# The argument 2 is presumably the channel count (the input below has 2 channels) — TODO confirm.
mod = ActivatedBatchNormAutograd(2)

In [None]:
# Fresh double-precision random input of shape (1, 2, 4, 4), tracked by autograd.
# This rebinds `x` from the earlier cells.
x = torch.rand((1, 2, 4, 4), dtype=torch.double, requires_grad=True)

In [None]:
# Forward pass through `mod`; rebinds `out` from the earlier cells.
out = mod(x)

In [None]:
# Mean and variance of the module output over all elements
# (Tensor.var() is the unbiased estimator by default).
out.mean(), out.var()