In [125]:
from torchvision.models import resnet18
import imp
import torch
import numpy as np
utils = imp.load_source('util', '../utils.py')
flatten_model = utils.flatten_model

model = resnet18()

In [220]:
class LRP_ZRule(torch.autograd.Function):

    @staticmethod
    def forward(ctx, func, input, *args):
        '''
        forawd pass perform usual func forward pass
        '''
        ctx.func = func
        ctx.input =input.clone().detach()
        ctx.args = [*args]
        return func(input, *args)

    @staticmethod
    def backward(ctx, R=1):
        '''
        substitute backward pass with z rule propagation 
        because forward pass given with 3 arguments, backward return 3 outputs
        '''
        ctx.input.requires_grad_(True)
        with torch.enable_grad():
            Z = ctx.func(ctx.input, *ctx.args)
            S = R /(Z + (Z==0).float()*np.finfo(np.float32).eps)
            Z.backward(S)
            C = ctx.input.grad
            R = ctx.input * C
        return None, R, None

In [237]:
x = torch.ones(4,4, requires_grad=True).view(1,1,4,4)
print('input: \n', x)
x.retain_grad()
w = torch.rand(1,1, 2,2)
print('weights:\n ', w)
w.requires_grad_(True)
f = LRP_ZRule.apply(torch.conv2d, x, w)
print('output: \n', f)

f.backward(torch.ones(1,1,3,3))

print('x grad: \n', x.grad)

print('weights grad: \n', w.grad)

input: 
 tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]], grad_fn=<ViewBackward>)
weights:
  tensor([[[[0.9557, 0.2066],
          [0.8205, 0.0062]]]])
output: 
 tensor([[[[1.9890, 1.9890, 1.9890],
          [1.9890, 1.9890, 1.9890],
          [1.9890, 1.9890, 1.9890]]]], grad_fn=<LRP_ZRuleBackward>)
x grad: 
 tensor([[[[0.4805, 0.5844, 0.5844, 0.1039],
          [0.8930, 1.0000, 1.0000, 0.1070],
          [0.8930, 1.0000, 1.0000, 0.1070],
          [0.4125, 0.4156, 0.4156, 0.0031]]]])
weights grad: 
 tensor([[[[4.5249, 4.5249],
          [4.5249, 4.5249]]]])


In [241]:
x = torch.ones(4,4, requires_grad=True).view(1,1,4,4)
x.retain_grad()
x.register_hook(lambda x: print('x grad: \n',x))
w = torch.rand(1,1,2,2)
w.requires_grad_(True)
w.register_hook(lambda x: print('w grad: \n',x))
c = torch.conv2d(x, w)
c.backward(torch.ones(1,1,3,3))
print(x.grad)

w grad: 
 tensor([[[[9., 9.],
          [9., 9.]]]])
x grad: 
 tensor([[[[0.8868, 1.4965, 1.4965, 0.6097],
          [1.7728, 2.4240, 2.4240, 0.6512],
          [1.7728, 2.4240, 2.4240, 0.6512],
          [0.8860, 0.9275, 0.9275, 0.0414]]]])
tensor([[[[0.8868, 1.4965, 1.4965, 0.6097],
          [1.7728, 2.4240, 2.4240, 0.6512],
          [1.7728, 2.4240, 2.4240, 0.6512],
          [0.8860, 0.9275, 0.9275, 0.0414]]]])


In [11]:
from  torch.nn import modules
import torch.nn.functional as F
_ConvNd = modules.conv._ConvNd
class MyLayer(_ConvNd):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

    def conv2d_forward(self, input, weight):
        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        return self.conv2d_forward(input, self.weight)


In [12]:
for module in flatten_model(model):
    if module.__class__.__name__ == 'Conv2d':
#         print(module)
        module.__class__ = MyLayer
print(model)

ResNet(
  (conv1): MyLayer(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): MyLayer(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): MyLayer(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): MyLayer(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True

In [14]:
out = model(torch.rand(1,3,256,256))

In [16]:
out.sum().backward()

# skip grad

In [81]:
a = torch.tensor([1.], requires_grad=True)
a.register_hook(lambda x: print(x))
b = 2*a
c = b + a**2
c.register_hook(lambda x: torch.tensor([5.]))

<torch.utils.hooks.RemovableHandle at 0x7f7a005c11d0>

In [82]:
c.backward()

tensor([20.])


In [84]:
c.grad_fn.next_functions

((<MulBackward0 at 0x7f7a005c1160>, 0), (<PowBackward0 at 0x7f7a005c13c8>, 0))

In [88]:
d = torch.ones(4,4, dtype=torch.float32).view(1,1,4,4)
r = torch.nn.Conv2d(1,1,(2,2))(d)
r

tensor([[[[0.9797, 0.9797, 0.9797],
          [0.9797, 0.9797, 0.9797],
          [0.9797, 0.9797, 0.9797]]]], grad_fn=<MkldnnConvolutionBackward>)

In [90]:
r.grad_fn.next_functions

((None, 0),
 (<AccumulateGrad at 0x7f7a005c1be0>, 0),
 (<AccumulateGrad at 0x7f7a005c1a58>, 0))