In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [125]:
from torchvision.models.resnet import ResNet, BasicBlock

In [127]:
class CheckpointBNFN(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, gamma, beta):

        mu = x.mean(dim=(0, 2, 3), keepdim=True)

        xmu = x - mu
        sq = xmu ** 2

        var = sq.mean(dim=(0, 2, 3), keepdim=True)

        sqrtvar = torch.sqrt(var + 1e-5)

        ivar = 1.0 / sqrtvar

        xhat = xmu * ivar

        gammax = gamma * xhat

        bn_out = gammax + beta

        # ctx.save_for_backward(xhat, gamma, xmu, ivar, sqrtvar, var, bn_out)
        ctx.save_for_backward(x, gamma, beta)

        return bn_out

    @staticmethod
    def backward(ctx, dout):
        # xhat, gamma, xmu, ivar, sqrtvar, var, bn_out = ctx.saved_tensors
        x, gamma, beta = ctx.saved_tensors

        # recompute
        mu = x.mean(dim=(0, 2, 3), keepdim=True)

        xmu = x - mu
        sq = xmu ** 2

        var = sq.mean(dim=(0, 2, 3), keepdim=True)

        sqrtvar = torch.sqrt(var + 1e-5)

        ivar = 1.0 / sqrtvar

        xhat = xmu * ivar

        gammax = gamma * xhat

        bn_out = gammax + beta

        # backwards pass

        dx = dgamma = dbeta = None

        dbeta = dout.sum(dim=(0, 2, 3), keepdim=True)

        dgammax = dout

        dgamma = torch.sum(dgammax * xhat, dim=(0, 2, 3), keepdim=True)
        dxhat = dgammax * gamma

        divar = torch.sum(dxhat * xmu, dim=(0, 2, 3), keepdim=True)
        dxmu1 = dxhat * ivar

        dsqrtvar = -1.0 / (sqrtvar ** 2) * divar

        dvar = 0.5 * 1.0 / torch.sqrt(var + 1e-5) * dsqrtvar

        dsq = 1.0 / (dout.shape[0] * dout.shape[2] *
                     dout.shape[3]) * torch.ones_like(dout) * dvar

        dxmu2 = 2.0 * xmu * dsq

        dx1 = dxmu1 + dxmu2
        dmu = -1.0 * torch.sum(dxmu1 + dxmu2, dim=(0, 2, 3), keepdim=True)

        dx2 = 1.0 / (dout.shape[0] * dout.shape[2] *
                     dout.shape[3]) * torch.ones_like(dout) * dmu

        dx = dx1 + dx2

        return dx, dgamma, dbeta

In [128]:
class CheckpointBN(nn.Module):
    def __init__(self, num_features):
        super(CheckpointBN, self).__init__()

        shape = (1, num_features, 1, 1)

        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))

        device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        self.moving_mean = torch.zeros(shape, device=device)
        self.moving_var = torch.zeros(shape, device=device)

    def forward(self, x):
        if not torch.is_grad_enabled():
            out = (x - self.moving_mean) / torch.sqrt(self.moving_var + 1e-5)
            out = self.gamma * out + self.beta
        else:

            out = CheckpointBNFN.apply(x, self.gamma, self.beta)

            with torch.no_grad():
                mean = x.mean(dim=(0, 2, 3), keepdim=True)
                var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)

                self.moving_mean = 0.9 * self.moving_mean + (1 - 0.9) * mean
                self.moving_var = 0.9 * self.moving_var + (1 - 0.9) * var

        return out

In [136]:
from torchvision import datasets, transforms

In [138]:
    train_dataloader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,)),
                           transforms.Lambda(
                               lambda x: torch.cat([x, x, x], dim=0))
                       ])),
        batch_size=16, shuffle=True)

In [141]:
for img, label in train_dataloader:
    print(img.shape)
    print(label.shape)
    break

torch.Size([16, 3, 28, 28])
torch.Size([16])


In [129]:
model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10, norm_layer=CheckpointBN)

In [132]:
x = torch.rand((2, 3, 28, 28), requires_grad=True)

In [133]:
out = model(x)

In [134]:
out.shape

torch.Size([2, 10])

In [115]:
bn = CheckpointABNFN.apply

In [116]:
x = torch.rand((1, 2, 4, 4), requires_grad=True)
gamma = torch.ones((1, 2, 1, 1), requires_grad=True)
beta = torch.zeros((1, 2, 1, 1), requires_grad=True)

In [117]:
out = bn(x, gamma, beta)

In [118]:
temp = out.sum()
temp.backward()

In [119]:
x.mean(), x.var()

(tensor(0.3716, grad_fn=<MeanBackward0>),
 tensor(0.0811, grad_fn=<VarBackward0>))

In [120]:
out.mean(), out.var()

(tensor(0.3954, grad_fn=<MeanBackward0>),
 tensor(0.4804, grad_fn=<VarBackward0>))

In [121]:
from torch.autograd import gradcheck

x = torch.rand((1, 2, 4, 4), dtype=torch.double, requires_grad=True)
gamma = torch.ones((1, 2, 1, 1), dtype=torch.double, requires_grad=True)
beta = torch.zeros((1, 2, 1, 1), dtype=torch.double, requires_grad=True)

input = (x, gamma, beta)
test = gradcheck(bn, input, eps=1e-6, atol=1e-4)
print(test)

True


In [122]:
mod = ActivatedBatchNormAutograd(2)

In [99]:
x = torch.rand((1, 2, 4, 4), dtype=torch.double, requires_grad=True)

In [100]:
out = mod(x)

In [102]:
out.mean(), out.var()

(tensor(0.4186, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.2750, dtype=torch.float64, grad_fn=<VarBackward0>))