In [331]:
import imp
import torch
import numpy as np

# Gradient in conv

In [332]:
x = torch.ones(4,4, requires_grad=True).view(1,1,4,4)
x.retain_grad()
x.register_hook(lambda x: print('x grad: \n',x))
w = torch.rand(1,1,3,3)
w.requires_grad_(True)
w.register_hook(lambda x: print('w grad: \n',x))
c = torch.conv2d(x, weight=w, padding=1, stride=1, dilation=1)
c.backward(torch.ones(1,1,4,4))

w grad: 
 tensor([[[[ 9., 12.,  9.],
          [12., 16., 12.],
          [ 9., 12.,  9.]]]])
x grad: 
 tensor([[[[1.6835, 2.8811, 2.8811, 1.7759],
          [2.8740, 4.8261, 4.8261, 2.7601],
          [2.8740, 4.8261, 4.8261, 2.7601],
          [2.3029, 3.6182, 3.6182, 1.6654]]]])


In [260]:
c.shape

torch.Size([1, 1, 5, 5])

In [329]:
NotImplemented

NotImplemented

# LRP in backward

In [302]:
class LRP_ZRule_func(torch.autograd.Function):

    @staticmethod
    def forward(ctx, func, input, *args):
        '''
        forawd pass perform usual func forward pass
        '''
        ctx.func = func
        ctx.input =input.clone().detach()
        ctx.args = [*args]
        return func(input, *args)

    @staticmethod
    def backward(ctx, R):
        '''
        substitute backward pass with z rule propagation 
        backward pass must return same namber of ouputs as number of inputs
        '''
        ctx.input.requires_grad_(True)
        with torch.enable_grad():
            Z = ctx.func(ctx.input, *ctx.args)
            S = R /(Z + (Z==0).float()*np.finfo(np.float32).eps)
            Z.backward(S)
            C = ctx.input.grad
            R = ctx.input * C
        return tuple([None, R] + len(ctx.args)*[None])

In [303]:
x = torch.tensor([[-1.,-1.,1.,1.]]*4, requires_grad=True).view(1,1,4,4)
print('input: \n', x)
x.retain_grad()
w = torch.tensor([[1.,1.,1.]]*3).view(1,1,3,3)
print('weights:\n ', w)
w.requires_grad_(True)
f = LRP_ZRule_func.apply(torch.conv2d, x, w, None, 1, 1)
f.register_hook(lambda x: print('f backward inpt: \n', x))
print('f: \n', f)

out = f.sum()

print('out: \n', out)
out.backward()

print('x relevance: \n', x.grad)

print('weights grad: \n', w.grad)

input: 
 tensor([[[[-1., -1.,  1.,  1.],
          [-1., -1.,  1.,  1.],
          [-1., -1.,  1.,  1.],
          [-1., -1.,  1.,  1.]]]], grad_fn=<ViewBackward>)
weights:
  tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
f: 
 tensor([[[[-4., -2.,  2.,  4.],
          [-6., -3.,  3.,  6.],
          [-6., -3.,  3.,  6.],
          [-4., -2.,  2.,  4.]]]], grad_fn=<LRP_ZRule_funcBackward>)
out: 
 tensor(0., grad_fn=<SumBackward0>)
f backward inpt: 
 tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])
x relevance: 
 tensor([[[[1.2500, 0.4167, 0.4167, 1.2500],
          [1.7500, 0.5833, 0.5833, 1.7500],
          [1.7500, 0.5833, 0.5833, 1.7500],
          [1.2500, 0.4167, 0.4167, 1.2500]]]])
weights grad: 
 tensor([[[[0.5833, 3.5000, 0.5833],
          [0.8333, 5.0000, 0.8333],
          [0.5833, 3.5000, 0.5833]]]])


# Skip connection emulation

In [320]:
x = torch.ones(4,4,requires_grad=True).view(1,1,4,4)
x.register_hook(lambda x: print('x input grad: \n', x))
w = torch.ones(1,1,3,3)

f = LRP_ZRule.apply(torch.conv2d, x, w, None, 1, 1)

c = f + x
c.backward(torch.ones(1,1,4,4))


x input grad: 
 tensor([[[[1.6944, 1.9722, 1.9722, 1.6944],
          [1.9722, 2.3611, 2.3611, 1.9722],
          [1.9722, 2.3611, 2.3611, 1.9722],
          [1.6944, 1.9722, 1.9722, 1.6944]]]])


In [322]:
x = torch.ones(4,4,requires_grad=True).view(1,1,4,4)
x.register_hook(lambda x: print('x input grad: \n', x))
w = torch.ones(1,1,3,3)

f = LRP_ZRule.apply(torch.conv2d, x, w, None, 1, 1)

c = f
c.backward(torch.ones(1,1,4,4))


x input grad: 
 tensor([[[[0.6944, 0.9722, 0.9722, 0.6944],
          [0.9722, 1.3611, 1.3611, 0.9722],
          [0.9722, 1.3611, 1.3611, 0.9722],
          [0.6944, 0.9722, 0.9722, 0.6944]]]])


In [317]:
-torch.tensor([[[[0.7597, 1.0883, 0.9689, 0.3407],
          [1.2799, 1.3725, 1.2638, 0.5784],
          [1.2476, 1.4671, 1.3128, 0.5307],
          [1.1303, 1.1453, 1.0129, 0.5011]]]]) + 
torch.tensor([[[[0.3512, 0.9608, 0.9809, 0.5729],
          [0.7860, 1.6011, 1.7153, 1.0738],
          [0.7341, 1.5913, 1.7625, 1.0757],
          [0.3870, 0.8800, 0.9579, 0.5696]]]])

tensor([[[[-0.4085, -0.1275,  0.0120,  0.2322],
          [-0.4939,  0.2286,  0.4515,  0.4954],
          [-0.5135,  0.1242,  0.4497,  0.5450],
          [-0.7433, -0.2653, -0.0550,  0.0685]]]])

In [316]:
w

tensor([[[[0.5382, 0.2435, 0.8232],
          [0.2217, 0.0189, 0.4675],
          [0.5245, 0.3569, 0.6152]]]])