In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import gradcheck

In [2]:
from torchvision.models.resnet import ResNet, BasicBlock

In [28]:
class CheckpointABNFN(torch.autograd.Function):
    """Fused batch normalization + ReLU that recomputes activations in backward.

    Per-channel statistics are taken over dims ``(0, 2, 3)`` of ``x``; the
    normalized tensor is scaled by ``gamma``, shifted by ``beta`` and clamped
    at zero.  Only ``x``, ``gamma`` and ``beta`` are saved for backward: the
    normalized activations and the ReLU mask are recomputed there rather than
    stored.

    Backward formulas follow: http://cthorey.github.io/backpropagation/
    """

    # Added to the variance before the square root for numerical stability.
    _EPS = 1e-5
    # Statistics are reduced over every axis except the channel axis (dim 1).
    _REDUCE_DIMS = (0, 2, 3)

    @staticmethod
    def _normalize(x):
        """Return ``(xhat, std)`` where ``xhat = (x - mean) / std`` per channel.

        ``std`` is ``sqrt(var + _EPS)`` with the biased (mean of squares)
        variance, kept with ``keepdim=True`` so it broadcasts against ``x``.
        """
        dims = CheckpointABNFN._REDUCE_DIMS
        centered = x - x.mean(dim=dims, keepdim=True)
        var = (centered ** 2).mean(dim=dims, keepdim=True)
        std = torch.sqrt(var + CheckpointABNFN._EPS)
        return centered / std, std

    @staticmethod
    def forward(ctx, x, gamma, beta):
        """Compute ``relu(gamma * xhat + beta)``.

        Args:
            ctx: autograd context; receives ``x``, ``gamma`` and ``beta`` for
                use in ``backward``.
            x: input tensor, reduced over dims ``(0, 2, 3)`` (so 4-D input is
                expected).
            gamma: scale, must broadcast against ``x``.
            beta: shift, must broadcast against ``x``.

        Returns:
            Tensor with the same shape as ``x``, with negatives clamped to 0.
        """
        xhat, _ = CheckpointABNFN._normalize(x)

        # Save the inputs only; everything else is rebuilt in backward.
        ctx.save_for_backward(x, gamma, beta)

        return (gamma * xhat + beta).clamp(min=0)

    @staticmethod
    def backward(ctx, dout):
        """Return gradients of the loss w.r.t. ``x``, ``gamma`` and ``beta``.

        Args:
            ctx: autograd context populated by ``forward``.
            dout: gradient of the loss w.r.t. the output of ``forward``.

        Returns:
            ``(dx, dgamma, dbeta)``.  An entry is ``None`` when
            ``ctx.needs_input_grad`` says the matching input needs no gradient.
            ``dgamma`` and ``dbeta`` keep the reduced dims (``keepdim=True``).
        """
        x, gamma, beta = ctx.saved_tensors
        need_dx, need_dgamma, need_dbeta = ctx.needs_input_grad

        # Recompute the forward intermediates instead of having stored them.
        xhat, std = CheckpointABNFN._normalize(x)

        # ReLU backward: zero the upstream gradient where the pre-activation
        # was not positive.
        relu_mask = (gamma * xhat + beta) > 0
        dout = dout * relu_mask.to(dout.dtype)

        dims = CheckpointABNFN._REDUCE_DIMS
        # These two per-channel reductions are also the beta / gamma gradients.
        sum_dout = dout.sum(dim=dims, keepdim=True)
        sum_dout_xhat = (dout * xhat).sum(dim=dims, keepdim=True)

        dx = None
        if need_dx:
            # Number of elements that contributed to each channel's statistics.
            n = dout.shape[0] * dout.shape[2] * dout.shape[3]
            dx = gamma / (n * std) * (n * dout - sum_dout - xhat * sum_dout_xhat)

        dgamma = sum_dout_xhat if need_dgamma else None
        dbeta = sum_dout if need_dbeta else None

        return dx, dgamma, dbeta

In [29]:
# Custom autograd Functions are invoked through `.apply`; alias it so the fused
# batch-norm + ReLU op can be called like a plain function.
bn = CheckpointABNFN.apply

In [30]:
# Toy inputs for a smoke test: one sample, two channels, 4x4 spatial grid.
# gamma (all ones) and beta (all zeros) are shaped (1, 2, 1, 1) so they
# broadcast per channel; all three are leaves that track gradients.
x = torch.rand(1, 2, 4, 4).requires_grad_(True)
gamma = torch.ones(1, 2, 1, 1).requires_grad_(True)
beta = torch.zeros(1, 2, 1, 1).requires_grad_(True)

In [31]:
# Forward pass of the fused batch-norm + ReLU op on the toy inputs.
out = bn(x, gamma, beta)

In [32]:
# Reduce to a scalar so backward() needs no explicit gradient argument; this
# runs the hand-written CheckpointABNFN.backward end to end.
temp = out.sum()
temp.backward()

In [33]:
# Mean and variance of the raw input, taken over all elements.
x.mean(), x.var()

(tensor(0.5034, grad_fn=<MeanBackward0>),
 tensor(0.0732, grad_fn=<VarBackward0>))

In [34]:
# Same statistics for the op's output. The output is post-ReLU (negatives were
# clamped to 0), so it is not expected to be zero-mean / unit-variance.
out.mean(), out.var()

(tensor(0.4292, grad_fn=<MeanBackward0>),
 tensor(0.2602, grad_fn=<VarBackward0>))

In [35]:
# Numerically verify the hand-written backward of `bn` against finite
# differences. Inputs are double precision, as gradcheck expects, and seeded so
# the check is reproducible across kernel restarts.
# (`gradcheck` is imported in the notebook's import cell.)
torch.manual_seed(0)

x = torch.rand((1, 2, 4, 4), dtype=torch.double, requires_grad=True)
gamma = torch.ones((1, 2, 1, 1), dtype=torch.double, requires_grad=True)
beta = torch.zeros((1, 2, 1, 1), dtype=torch.double, requires_grad=True)

# Named to avoid shadowing the `input` builtin.
gradcheck_inputs = (x, gamma, beta)

# gradcheck raises on mismatch by default, so reaching the last line means the
# analytic and numerical gradients agree within `atol`.
gradcheck_passed = gradcheck(bn, gradcheck_inputs, eps=1e-6, atol=1e-4)
gradcheck_passed

True


In [19]:
# NOTE(review): ActivatedBatchNormAutograd is not defined in any cell shown
# here, and the recorded run of this cell raised NameError. This cell, and the
# later cells that use `mod`, will fail on a fresh kernel until the class is
# defined or imported above — TODO confirm where it is supposed to come from.
mod = ActivatedBatchNormAutograd(2)

NameError: name 'ActivatedBatchNormAutograd' is not defined

In [None]:
# Fresh double-precision input with the same (1, 2, 4, 4) shape as before;
# this rebinds `x`.
x = torch.rand((1, 2, 4, 4), dtype=torch.double, requires_grad=True)

In [None]:
# Forward pass through `mod`; only runs if the cell creating `mod` succeeded.
out = mod(x)

In [None]:
# Mean and variance of the module output, taken over all elements.
out.mean(), out.var()