# Hijacking autograd for LRP

In [1]:
import warnings
import numpy as np

import torch
import torch_scatter
import torchgraphs as tg

np.set_printoptions(formatter={'float_kind': '{:5.2f}'.format, 'int_kind': '{:5d}'.format}, linewidth=150)

## Add

In [2]:
# Baseline: plain autograd through an element-wise addition, for comparison with the LRP version below.
a = torch.tensor([1, 2, 3, 1], dtype=torch.float, requires_grad=True)
b = torch.tensor([0, 1, 2, 4], dtype=torch.float, requires_grad=True)
c = a + b

# Upstream gradient: 1 wherever the output is non-zero, 0 elsewhere.
grad_out = torch.ones_like(c) * (c != 0).float()
c.backward(grad_out)

# First row: forward values; second row: what flowed back to each operand.
print(a.detach().numpy(), '+', b.detach().numpy(), '---->', c.detach().numpy())
print(a.grad.numpy(), ' ', b.grad.numpy(), '<----', grad_out.detach().numpy())

[ 1.00  2.00  3.00  1.00] + [ 0.00  1.00  2.00  4.00] ----> [ 1.00  3.00  5.00  5.00]
[ 1.00  1.00  1.00  1.00]   [ 1.00  1.00  1.00  1.00] <---- [ 1.00  1.00  1.00  1.00]


In [3]:
class AddRelevance(torch.autograd.Function):
    """Element-wise addition whose backward pass propagates LRP relevance instead of gradients.

    The relevance arriving at each output element is split between the two
    operands proportionally to their contribution to that output:
    ``rel_a = rel_out * a / (a + b)`` and ``rel_b = rel_out * b / (a + b)``.
    """

    @staticmethod
    def forward(ctx, a, b):
        """Return ``a + b`` and keep both operands and the result for the backward pass."""
        out = a + b
        ctx.save_for_backward(a, b, out)
        return out

    @staticmethod
    def backward(ctx, rel_out):
        """Split ``rel_out`` between ``a`` and ``b``.

        Where the output is exactly 0 the proportional split is undefined, so
        both operands receive 0 relevance there. A warning is emitted when any
        non-zero relevance (positive or negative) is dropped this way.
        """
        a, b, out = ctx.saved_tensors
        # Negative relevance is lost just like positive relevance, hence `!= 0` rather than `> 0`.
        if ((out == 0) & (rel_out != 0)).any():
            warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
        nonzero = out != 0
        zero = out.new_tensor(0)
        # The division yields inf/nan where out == 0; torch.where discards those entries.
        rel_a = torch.where(nonzero, rel_out * a / out, zero)
        rel_b = torch.where(nonzero, rel_out * b / out, zero)
        return rel_a, rel_b


def add(a, b):
    """Add two tensors, propagating relevance instead of gradients on ``backward``."""
    return AddRelevance.apply(a, b)

In [4]:
# LRP through the custom addition: relevance is split proportionally to each operand's contribution.
a = torch.tensor([1, 2, 3, 1], dtype=torch.float, requires_grad=True)
b = torch.tensor([0, 1, 2, 4], dtype=torch.float, requires_grad=True)
c = add(a, b)

# Initial relevance: 1 for every non-zero output, 0 elsewhere.
rel_out = torch.ones_like(c) * (c != 0).float()
# Propagate the masked relevance itself (not an all-ones tensor), so the value
# printed below is exactly what was fed to backward.
c.backward(rel_out)

print(a.detach().numpy(), '+', b.detach().numpy(), '---->', c.detach().numpy())
print(a.grad.numpy(), ' ', b.grad.numpy(), '<----', rel_out.numpy())

[ 1.00  2.00  3.00  1.00] + [ 0.00  1.00  2.00  4.00] ----> [ 1.00  3.00  5.00  5.00]
[ 1.00  0.67  0.60  0.20]   [ 0.00  0.33  0.40  0.80] <---- [ 1.00  1.00  1.00  1.00]


## Sum

In [5]:
# Baseline: plain autograd through torch.sum over dim 0, for comparison with the LRP version below.
a = torch.tensor([[1, 2, 3, 1], [-6, -2, -1, -1]], dtype=torch.float, requires_grad=True)
b = torch.sum(a, dim=0)

# Upstream gradient: 1 where the sum is non-zero, 0 elsewhere.
grad_out = torch.ones_like(b) * (b != 0).float()
b.backward(grad_out)

print(a.detach().numpy(), '---->', b.detach().numpy())
print(a.grad.numpy(), '<----', grad_out.detach().numpy())

[[ 1.00  2.00  3.00  1.00]
 [-6.00 -2.00 -1.00 -1.00]] ----> [-5.00  0.00  2.00  0.00]
[[ 1.00  0.00  1.00  0.00]
 [ 1.00  0.00  1.00  0.00]] <---- [ 1.00  0.00  1.00  0.00]


In [6]:
class SumPooling(torch.autograd.Function):
    """``torch.sum`` whose backward pass propagates LRP relevance instead of gradients.

    Every summed element receives a share of the output relevance proportional
    to its contribution to the sum: ``rel_src = rel_out * src / sum(src)``.
    """

    @staticmethod
    def forward(ctx, src, dim, keepdim):
        """Sum ``src`` over ``dim`` (an int, a tuple of ints, or None for all dims)."""
        out = torch.sum(src, dim=dim, keepdim=keepdim)
        ctx.dim = dim
        ctx.keepdim = keepdim
        ctx.save_for_backward(src, out)
        return out

    @staticmethod
    def backward(ctx, rel_out):
        """Distribute ``rel_out`` over the elements of ``src``.

        Relevance arriving at an output that is exactly 0 cannot be split
        proportionally and is dropped; a warning is emitted when that relevance
        is non-zero (positive or negative).
        """
        src, out = ctx.saved_tensors
        # Negative relevance is lost just like positive relevance, hence `!= 0` rather than `> 0`.
        if ((out == 0) & (rel_out != 0)).any():
            warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
        scale = torch.where(out != 0, rel_out / out, out.new_tensor(0))
        if not ctx.keepdim and ctx.dim is not None:
            # Re-insert the reduced dimensions so that `scale` broadcasts against `src`.
            # Dims are normalised to non-negative values and restored in ascending
            # order, which reproduces the keepdim=True shape for tuples as well.
            dims = (ctx.dim,) if isinstance(ctx.dim, int) else tuple(ctx.dim)
            ndim = max(src.dim(), 1)
            for d in sorted(d % ndim for d in dims):
                scale = scale.unsqueeze(d)
        return scale * src, None, None


def sum(tensor, dim=None, keepdim=False):
    """Sum ``tensor`` like ``torch.sum``, propagating relevance on ``backward``.

    Note: this name shadows the built-in ``sum`` in the notebook namespace.
    """
    return SumPooling.apply(tensor, dim, keepdim)

In [7]:
# LRP through the custom sum over dim 0 (keepdim defaults to False).
a = torch.tensor([[1, 2, 3, 1], [-6, -2, -1, -1]], dtype=torch.float, requires_grad=True)
b = sum(a, dim=0)

# Initial relevance: 1 where the sum is non-zero, 0 elsewhere.
rel_out = torch.ones_like(b) * (b != 0).float()
b.backward(rel_out)

print(a.detach().numpy(), '---->', b.detach().numpy())
print(a.grad.numpy(), '<----', rel_out.detach().numpy())

[[ 1.00  2.00  3.00  1.00]
 [-6.00 -2.00 -1.00 -1.00]] ----> [-5.00  0.00  2.00  0.00]
[[-0.20  0.00  1.50  0.00]
 [ 1.20 -0.00 -0.50 -0.00]] <---- [ 1.00  0.00  1.00  0.00]


In [8]:
# Same as the previous cell but with keepdim=True: the output keeps a leading dimension of size 1.
a = torch.tensor([[1, 2, 3, 1], [-6, -2, -1, -1]], dtype=torch.float, requires_grad=True)
b = sum(a, dim=0, keepdim=True)

# Initial relevance: 1 where the sum is non-zero, 0 elsewhere.
rel_out = torch.ones_like(b) * (b != 0).float()
b.backward(rel_out)

print(a.detach().numpy(), '---->', b.detach().numpy())
print(a.grad.numpy(), '<----', rel_out.detach().numpy())

[[ 1.00  2.00  3.00  1.00]
 [-6.00 -2.00 -1.00 -1.00]] ----> [[-5.00  0.00  2.00  0.00]]
[[-0.20  0.00  1.50  0.00]
 [ 1.20 -0.00 -0.50 -0.00]] <---- [[ 1.00  0.00  1.00  0.00]]


In [9]:
# LRP through the custom sum over dim 1 (one output per row).
a = torch.tensor([[1, 2, 3, 1], [-6, -2, -1, -1]], dtype=torch.float, requires_grad=True)
b = sum(a, dim=1)

# Initial relevance: 1 where the sum is non-zero, 0 elsewhere.
rel_out = torch.ones_like(b) * (b != 0).float()
b.backward(rel_out)

print(a.detach().numpy(), '---->', b.detach().numpy())
print(a.grad.numpy(), '<----', rel_out.detach().numpy())

[[ 1.00  2.00  3.00  1.00]
 [-6.00 -2.00 -1.00 -1.00]] ----> [ 7.00 -10.00]
[[ 0.14  0.29  0.43  0.14]
 [ 0.60  0.20  0.10  0.10]] <---- [ 1.00  1.00]


## Scatter Add

In [10]:
# Baseline: plain autograd through torch_scatter.scatter_add, for comparison with the LRP version below.
a = torch.tensor(
    [1, 1, 1, 1, 2, 6, 1, 1, -2], dtype=torch.float, requires_grad=True)
# a_idx[i] is the output slot that a[i] is added to.
a_idx = torch.tensor(
    [0, 0, 0, 0, 1, 1, 2, 2, 2])
a_new = torch_scatter.scatter_add(a, a_idx, dim=0)

# Upstream gradient: 1 where the scattered sum is non-zero, 0 elsewhere.
grad_out = torch.ones_like(a_new) * (a_new != 0).float()
a_new.backward(grad_out)

print(a_idx.numpy())
print(a.detach().numpy(), '---->', a_new.detach().numpy())
print(a.grad.numpy(), '<----', grad_out.detach().numpy())

[    0     0     0     0     1     1     2     2     2]
[ 1.00  1.00  1.00  1.00  2.00  6.00  1.00  1.00 -2.00] ----> [ 4.00  8.00  0.00]
[ 1.00  1.00  1.00  1.00  1.00  1.00  0.00  0.00  0.00] <---- [ 1.00  1.00  0.00]


In [11]:
class ScatterAddRelevance(torch.autograd.Function):
    """``torch_scatter.scatter_add`` whose backward pass propagates LRP relevance.

    Every source element receives a share of the relevance of the output slot
    it was added to, proportional to its contribution to that slot.
    """

    @staticmethod
    def forward(ctx, src, idx, dim, dim_size, fill_value):
        """Scatter-add ``src`` along ``dim`` and keep what the backward pass needs."""
        out = torch_scatter.scatter_add(src, idx, dim, None, dim_size, fill_value)
        ctx.dim = dim
        ctx.save_for_backward(src, idx, out)
        return out

    @staticmethod
    def backward(ctx, rel_out):
        """Distribute ``rel_out`` over the elements of ``src``.

        Relevance arriving at an output slot that is exactly 0 is dropped; a
        warning is emitted when that relevance is non-zero (positive or negative).
        """
        src, idx, out = ctx.saved_tensors
        # Negative relevance is lost just like positive relevance, hence `!= 0` rather than `> 0`.
        if ((out == 0) & (rel_out != 0)).any():
            warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
        scale = torch.where(out != 0, rel_out / out, out.new_tensor(0))
        # Gather each slot's relevance-per-unit back to the source positions.
        rel_src = torch.index_select(scale, ctx.dim, idx) * src
        return rel_src, None, None, None, None


def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    """Scatter-add with relevance propagation on ``backward``.

    The ``out`` argument is accepted for signature compatibility but is not
    used: the result is always written to a new tensor. A warning makes this
    visible when a caller passes one.
    """
    if out is not None:
        warnings.warn('The `out` argument is ignored: the result is written to a new tensor')
    return ScatterAddRelevance.apply(src, index, dim, dim_size, fill_value)

In [12]:
# LRP through the custom scatter_add on the same data as the autograd baseline above.
b = torch.tensor(
    [1, 1, 1, 1, 2, 6, 1, 1, -2], dtype=torch.float, requires_grad=True)
b_idx = torch.tensor(
    [0, 0, 0, 0, 1, 1, 2, 2, 2])
b_new = scatter_add(b, b_idx, dim=0)

# Initial relevance: 1 where the scattered sum is non-zero, 0 elsewhere.
rel_out = torch.ones_like(b_new) * (b_new != 0).float()
b_new.backward(rel_out)

print(b.detach().numpy(), '---->', b_new.detach().numpy())
print(b.grad.numpy(), '<----', rel_out.numpy())

[ 1.00  1.00  1.00  1.00  2.00  6.00  1.00  1.00 -2.00] ----> [ 4.00  8.00  0.00]
[ 0.25  0.25  0.25  0.25  0.25  0.75  0.00  0.00 -0.00] <---- [ 1.00  1.00  0.00]


## Scatter Mean

In [13]:
# Baseline: plain autograd through torch_scatter.scatter_mean, for comparison with the LRP version below.
a = torch.tensor(
    [1, 1, 1, 1, 2, 8, 1, 1, -2], dtype=torch.float, requires_grad=True)
# a_idx[i] is the output slot that a[i] is averaged into.
a_idx = torch.tensor(
    [0, 0, 0, 0, 1, 1, 2, 2, 2])
a_new = torch_scatter.scatter_mean(a, a_idx, dim=0)

# Upstream gradient: 1 where the mean is non-zero, 0 elsewhere.
grad_out = torch.full_like(a_new, 1) * (a_new != 0).float()
a_new.backward(grad_out)

print(a_idx.numpy())
print(a.detach().numpy(), '---->', a_new.detach().numpy())
print(a.grad.numpy(), '<----', grad_out.detach().numpy())

[    0     0     0     0     1     1     2     2     2]
[ 1.00  1.00  1.00  1.00  2.00  8.00  1.00  1.00 -2.00] ----> [ 1.00  5.00  0.00]
[ 0.25  0.25  0.25  0.25  0.50  0.50  0.00  0.00  0.00] <---- [ 1.00  1.00  0.00]


In [14]:
class ScatterMeanRelevance(torch.autograd.Function):
    """Scatter-mean whose backward pass propagates LRP relevance instead of gradients.

    The mean is computed as a scatter-add of the values divided by a
    scatter-add of ones (the per-slot element count, clamped to at least 1).
    Relevance is shared among the elements of a slot proportionally to
    ``src / sums``, where ``sums`` is the per-slot sum.
    """

    @staticmethod
    def forward(ctx, src, idx, dim, dim_size, fill_value):
        """Return the per-slot mean of ``src`` and keep the per-slot sums for backward."""
        sums = torch_scatter.scatter_add(src, idx, dim, None, dim_size, fill_value)
        count = torch_scatter.scatter_add(torch.ones_like(src), idx, dim, None, dim_size, fill_value=0)
        # Clamp so that slots with no elements do not divide by zero.
        out = sums / count.clamp(min=1)
        ctx.dim = dim
        ctx.save_for_backward(src, idx, sums)
        return out

    @staticmethod
    def backward(ctx, rel_out):
        """Distribute ``rel_out`` over the elements of ``src``.

        Relevance arriving at a slot whose sum is exactly 0 is dropped; a
        warning is emitted when that relevance is non-zero (positive or negative).
        """
        src, idx, sums = ctx.saved_tensors
        # Negative relevance is lost just like positive relevance, hence `!= 0` rather than `> 0`.
        if ((sums == 0) & (rel_out != 0)).any():
            warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
        scale = torch.where(sums != 0, rel_out / sums, sums.new_tensor(0))
        # Gather each slot's relevance-per-unit back to the source positions.
        rel_src = torch.index_select(scale, ctx.dim, idx) * src
        return rel_src, None, None, None, None


def scatter_mean(src, index, dim=-1, dim_size=None, fill_value=0):
    """Scatter-mean with relevance propagation on ``backward``."""
    return ScatterMeanRelevance.apply(src, index, dim, dim_size, fill_value)

In [15]:
# LRP through the custom scatter_mean on the same data as the autograd baseline above.
b = torch.tensor(
    [1, 1, 1, 1, 2, 8, 1, 1, -2], dtype=torch.float, requires_grad=True)
b_idx = torch.tensor(
    [0, 0, 0, 0, 1, 1, 2, 2, 2])
b_new = scatter_mean(b, b_idx, dim=0)

# Initial relevance: 1 where the mean is non-zero, 0 elsewhere.
rel_out = torch.ones_like(b_new) * (b_new != 0).float()
b_new.backward(rel_out)

print(b.detach().numpy(), '---->', b_new.detach().numpy())
print(b.grad.numpy(), '<----', rel_out.numpy())

[ 1.00  1.00  1.00  1.00  2.00  8.00  1.00  1.00 -2.00] ----> [ 1.00  5.00  0.00]
[ 0.25  0.25  0.25  0.25  0.20  0.80  0.00  0.00 -0.00] <---- [ 1.00  1.00  0.00]


## Scatter Max

In [16]:
# Baseline: plain autograd through torch_scatter.scatter_max, for comparison with the LRP version below.
a = torch.tensor(
    [1, 1, 1, 1, 2, 8, 1, 1, -2], dtype=torch.float, requires_grad=True)
a_idx = torch.tensor(
    [0, 0, 0, 0, 1, 1, 2, 2, 2])
# scatter_max returns a pair; element [0] holds the maxima.
a_new = torch_scatter.scatter_max(a, a_idx, dim=0)[0]

# Upstream gradient: 1 where the maximum is non-zero, 0 elsewhere.
grad_out = torch.full_like(a_new, 1) * (a_new != 0).float()
a_new.backward(grad_out)

print(a_idx.numpy())
print(a.detach().numpy(), '---->', a_new.detach().numpy())
print(a.grad.numpy(), '<----', grad_out.detach().numpy())

[    0     0     0     0     1     1     2     2     2]
[ 1.00  1.00  1.00  1.00  2.00  8.00  1.00  1.00 -2.00] ----> [ 1.00  8.00  1.00]
[ 0.00  0.00  0.00  1.00  0.00  1.00  0.00  1.00  0.00] <---- [ 1.00  1.00  1.00]


In [17]:
class ScatterMaxRelevance(torch.autograd.Function):
    """``torch_scatter.scatter_max`` whose backward pass propagates LRP relevance.

    All the relevance of an output slot is routed to the source element that
    produced the maximum (as reported by ``scatter_max``'s second output).
    """

    @staticmethod
    def forward(ctx, src, idx, dim, dim_size, fill_value):
        """Return ``(maxima, positions of the maxima)`` as given by ``scatter_max``."""
        out, idx_maxes = torch_scatter.scatter_max(src, idx, dim=dim, dim_size=dim_size, fill_value=fill_value)
        ctx.dim = dim
        # Size of the source along `dim`, needed to scatter relevance back to it.
        ctx.dim_size = src.shape[dim]
        # `idx` itself is not needed in backward: `idx_maxes` already addresses `src`.
        ctx.save_for_backward(out, idx_maxes)
        # The positions are an integer index, not a differentiable quantity.
        ctx.mark_non_differentiable(idx_maxes)
        return out, idx_maxes

    @staticmethod
    def backward(ctx, rel_out, rel_idx_maxes):
        """Route ``rel_out`` to the elements of ``src`` that attained the maximum.

        ``rel_idx_maxes`` is ignored. Relevance arriving at an output that is
        exactly 0 is dropped; a warning is emitted when that relevance is
        non-zero (positive or negative).
        """
        out, idx_maxes = ctx.saved_tensors
        # Negative relevance is lost just like positive relevance, hence `!= 0` rather than `> 0`.
        if ((out == 0) & (rel_out != 0)).any():
            warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
        rel_out = torch.where(out != 0, rel_out, out.new_tensor(0))

        # Where idx_maxes==-1 set idx=0 so that the indexes are valid for scatter_add
        # The corresponding relevance should already be 0, but set relevance=0 to be sure
        rel_out = torch.where(idx_maxes != -1, rel_out, torch.zeros_like(rel_out))
        idx_maxes = torch.where(idx_maxes != -1, idx_maxes, torch.zeros_like(idx_maxes))

        rel_src = torch_scatter.scatter_add(rel_out, idx_maxes, dim=ctx.dim, dim_size=ctx.dim_size)
        return rel_src, None, None, None, None


def scatter_max(src, index, dim=-1, dim_size=None, fill_value=0):
    """Scatter-max with relevance propagation on ``backward``; returns ``(maxima, positions)``."""
    return ScatterMaxRelevance.apply(src, index, dim, dim_size, fill_value)

In [18]:
# LRP through the custom scatter_max.
# Slot 2 receives no elements and slot 3 only zeros, so both exercise the zero-output path.
b = torch.tensor(
    [1, 1, 1, 1, 2, 8, 0, 0, 0], dtype=torch.float, requires_grad=True)
b_idx = torch.tensor(
    [0, 0, 0, 0, 1, 1, 3, 3, 3])
b_new = scatter_max(b, b_idx, dim=0)[0]

# Initial relevance: 3 where the maximum is non-zero, 0 elsewhere.
rel_out = torch.full_like(b_new, 3) * (b_new != 0).float()
b_new.backward(rel_out)

print(b.detach().numpy(), '---->', b_new.detach().numpy())
print(b.grad.numpy(), '<----', rel_out.numpy())

[ 1.00  1.00  1.00  1.00  2.00  8.00  0.00  0.00  0.00] ----> [ 1.00  8.00  0.00  0.00]
[ 0.00  0.00  0.00  3.00  0.00  3.00  0.00  0.00  0.00] <---- [ 3.00  3.00  0.00  0.00]


## Linear layer with epsilon rule

In [19]:
# Baseline: plain autograd through a linear layer y = x @ W^T + b, for comparison with the epsilon rule below.
weight = torch.tensor([
    [1, 2, -1],
    [0, 0, +1],
], dtype=torch.float)
bias = torch.tensor([0, 1], dtype=torch.float)
x = torch.tensor([[4, 8, 0]], dtype=torch.float, requires_grad=True)
y = x @ weight.t() + bias

# Upstream gradient: 1 where the output is non-zero, 0 elsewhere.
grad_out = torch.ones_like(y) * (y != 0).float()
y.backward(grad_out)

print(x.detach().numpy(), '---->', y.detach().numpy())
print(x.grad.numpy(), '<----', grad_out.detach().numpy())

[[ 4.00  8.00  0.00]] ----> [[20.00  1.00]]
[[ 1.00  2.00  0.00]] <---- [[ 1.00  1.00]]


In [20]:
class LinearEpsilonRelevance(torch.autograd.Function):
    """Linear layer whose backward pass applies the LRP epsilon rule instead of computing gradients.

    Forward computes ``input @ weight.t() + bias`` for a 2-D ``input``. Backward
    gives input ``i`` the share ``Z[n, i, j] / (Zs[n, j] +/- eps)`` of the relevance
    of output ``j``, where ``Z[n, i, j] = input[n, i] * weight[j, i]`` and
    ``Zs`` is the layer output. No relevance is returned for ``weight`` or ``bias``.
    """

    # Stabiliser added to the denominator with the sign of the output (+eps where Zs >= 0).
    eps = 1e-16

    @staticmethod
    def forward(ctx, input, weight, bias):
        """Return the layer output and keep the per-connection contributions for backward."""
        # Z[n, i, j]: contribution of input i to output j for sample n.
        Z = weight.t()[None, :, :] * input[:, :, None]
        Zs = Z.sum(dim=1, keepdim=True)
        if bias is not None:
            Zs += bias[None, None, :]
        ctx.save_for_backward(Z, Zs)
        return Zs.squeeze(dim=1)

    @staticmethod
    def backward(ctx, rel_out):
        """Redistribute ``rel_out`` to the inputs according to the epsilon rule."""
        Z, Zs = ctx.saved_tensors
        eps = rel_out.new_tensor(LinearEpsilonRelevance.eps)
        # Build the stabilised denominator out of place. `Zs` is a saved tensor and
        # the tensor returned by forward is a view of it, so an in-place update
        # would alter the layer output after backward and break a second backward.
        denominator = Zs + torch.where(Zs >= 0, eps, -eps)
        return (rel_out[:, None, :] * Z / denominator).sum(dim=2), None, None


def linear_eps(input, weight, bias=None):
    """Apply a linear layer that propagates relevance with the epsilon rule on ``backward``."""
    return LinearEpsilonRelevance.apply(input, weight, bias)

No bias -> conservation

In [21]:
# Epsilon rule without a bias term.
weight = torch.tensor([
    [1, 2, -1],
    [0, 0, +1],
], dtype=torch.float)
bias = None
print('W', weight.numpy(), sep='\n', end='\n\n')

x = torch.tensor([
    [4, 8, 0]
], dtype=torch.float, requires_grad=True)

y = linear_eps(x, weight, bias)
# Initial relevance: 1 where the output is non-zero, 0 elsewhere.
rel_out = torch.ones_like(y) * (y != 0).float()
y.backward(rel_out)

print(x.detach().numpy(), '---->', y.detach().numpy())
print(x.grad.numpy().round(2), '<----', rel_out.numpy())

W
[[ 1.00  2.00 -1.00]
 [ 0.00  0.00  1.00]]

[[ 4.00  8.00  0.00]] ----> [[20.00  0.00]]
[[ 0.20  0.80  0.00]] <---- [[ 1.00  0.00]]


Bias absorbs relevance

In [22]:
# Epsilon rule with a bias of 1 on the second output.
weight = torch.tensor([
    [1, 2, -1],
    [0, 0, +1],
], dtype=torch.float)
bias = torch.tensor([0, 1], dtype=torch.float)
print('W', weight.numpy(), 'b', bias.numpy(), sep='\n', end='\n\n')

x = torch.tensor([
    [4, 8, 0]
], dtype=torch.float, requires_grad=True)

y = linear_eps(x, weight, bias)
# Initial relevance: 1 where the output is non-zero, 0 elsewhere.
rel_out = torch.ones_like(y) * (y != 0).float()
y.backward(rel_out)

print(x.detach().numpy(), '---->', y.detach().numpy())
print(x.grad.numpy().round(2), '<----', rel_out.numpy())

W
[[ 1.00  2.00 -1.00]
 [ 0.00  0.00  1.00]]
b
[ 0.00  1.00]

[[ 4.00  8.00  0.00]] ----> [[20.00  1.00]]
[[ 0.20  0.80  0.00]] <---- [[ 1.00  1.00]]


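But in a weird way? Raising the bias of the second output from 1 to 100 leaves the relevance of the inputs unchanged: the weighted inputs of that output are all zero here, so its relevance never reaches the inputs, whatever the bias value.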

In [23]:
# Same setup as the previous cell, but with a bias of 100 on the second output.
weight = torch.tensor([
    [1, 2, -1],
    [0, 0, +1],
], dtype=torch.float)
bias = torch.tensor([0, 100], dtype=torch.float)
print('W', weight.numpy(), 'b', bias.numpy(), sep='\n', end='\n\n')

x = torch.tensor([
    [4, 8, 0]
], dtype=torch.float, requires_grad=True)

y = linear_eps(x, weight, bias)
# Initial relevance: 1 where the output is non-zero, 0 elsewhere.
rel_out = torch.ones_like(y) * (y != 0).float()
y.backward(rel_out)

print(x.detach().numpy(), '---->', y.detach().numpy())
print(x.grad.numpy().round(2), '<----', rel_out.numpy())

W
[[ 1.00  2.00 -1.00]
 [ 0.00  0.00  1.00]]
b
[ 0.00 100.00]

[[ 4.00  8.00  0.00]] ----> [[20.00 100.00]]
[[ 0.20  0.80  0.00]] <---- [[ 1.00  1.00]]


Zeros in the input

In [24]:
# Epsilon rule when one of the inputs is 0: that input contributes nothing to the output.
weight = torch.tensor([
    [+1, -1],
], dtype=torch.float)
bias = None
print('W', weight.numpy(), sep='\n', end='\n\n')

x = torch.tensor([
    [4, 0]
], dtype=torch.float, requires_grad=True)

y = linear_eps(x, weight, bias)
# Initial relevance: 1 where the output is non-zero, 0 elsewhere.
rel_out = torch.ones_like(y) * (y != 0).float()
y.backward(rel_out)

print(x.detach().numpy(), '---->', y.detach().numpy())
print(x.grad.numpy().round(2), '<----', rel_out.numpy())

W
[[ 1.00 -1.00]]

[[ 4.00  0.00]] ----> [[ 4.00]]
[[ 1.00  0.00]] <---- [[ 1.00]]


Zeros in the weights

In [25]:
# Epsilon rule when one of the weights is 0: the corresponding input contributes nothing to the output.
weight = torch.tensor([
    [1, 0, 2],
], dtype=torch.float)
bias = None
print('W', weight.numpy(), sep='\n', end='\n\n')

x = torch.tensor([
    [4, 1, 3]
], dtype=torch.float, requires_grad=True)

y = linear_eps(x, weight, bias)
# Initial relevance: 1 where the output is non-zero, 0 elsewhere.
rel_out = torch.ones_like(y) * (y != 0).float()
y.backward(rel_out)

print(x.detach().numpy(), '---->', y.detach().numpy())
print(x.grad.numpy().round(2), '<----', rel_out.numpy())

W
[[ 1.00  0.00  2.00]]

[[ 4.00  1.00  3.00]] ----> [[10.00]]
[[ 0.40  0.00  0.60]] <---- [[ 1.00]]


## Putting things together

In [26]:
# End-to-end example chaining the custom ops: add -> scatter_add -> two linear layers -> add.
a = torch.tensor([[1, 2, 3, 1, 0]], dtype=torch.float, requires_grad=True)
b = torch.tensor([[0, 1, 2, 4, 5]], dtype=torch.float, requires_grad=True)
c = add(a, b)

print(a.detach().numpy(), '+', b.detach().numpy(), '---->', c.detach().numpy(), end='\n\n')

# Pool the 5 columns of c into 3 slots along dim 1.
idx = torch.tensor([0, 0, 1, 1, 2])
d = scatter_add(c, idx, dim=1)

print(idx.numpy())
print(c.detach().numpy(), '---->', d.detach().numpy(), end='\n\n')

# First bias-free linear branch applied to d.
weight_u = torch.tensor([
    [1, 0, 2],
    [0, 5, 1],
], dtype=torch.float)
u = linear_eps(d, weight_u, None)

print(weight_u.numpy())
print(d.detach().numpy(), '---->', u.detach().numpy(), end='\n\n')

# Second bias-free linear branch applied to the same d, so d is used twice in the graph.
weight_v = torch.tensor([
    [ 0,  1, -1],
    [-4, -2, -3],
], dtype=torch.float)
v = linear_eps(d, weight_v, None)

print(weight_v.numpy())
print(d.detach().numpy(), '---->', v.detach().numpy(), end='\n\n')

# Merge the two branches.
z = add(u, v)
print(u.detach().numpy(), '+', v.detach().numpy(), '---->', z.detach().numpy(), end='\n\n')

# Initial relevance: 1 where the final output is non-zero, 0 elsewhere.
rel_out = torch.ones_like(z) * (z != 0).float()
z.backward(rel_out)
print(a.grad.numpy(), ' ', b.grad.numpy(), '<----', rel_out.numpy())
# Conservation check: total relevance on a plus total on b, next to the total injected at z.
print('  ', a.grad.numpy().sum(), '\t\t\t  +   ', b.grad.numpy().sum(), '\t\t\t==    ', rel_out.numpy().sum())

[[ 1.00  2.00  3.00  1.00  0.00]] + [[ 0.00  1.00  2.00  4.00  5.00]] ----> [[ 1.00  3.00  5.00  5.00  5.00]]

[    0     0     1     1     2]
[[ 1.00  3.00  5.00  5.00  5.00]] ----> [[ 4.00 10.00  5.00]]

[[ 1.00  0.00  2.00]
 [ 0.00  5.00  1.00]]
[[ 4.00 10.00  5.00]] ----> [[14.00 55.00]]

[[ 0.00  1.00 -1.00]
 [-4.00 -2.00 -3.00]]
[[ 4.00 10.00  5.00]] ----> [[ 5.00 -51.00]]

[[14.00 55.00]] + [[ 5.00 -51.00]] ----> [[19.00  4.00]]

[[-0.95 -1.89  2.41  0.80 -0.00]]   [[-0.00 -0.95  1.61  3.21 -2.24]] <---- [[ 1.00  1.00]]
   0.36842078 			  +    1.6315787 			==     2.0


In [27]:
def computational_graph(op, prefix=''):
    """Recursively print the autograd graph rooted at ``op``, one node per line.

    Each level of the graph is indented by two more spaces. ``AccumulateGrad``
    nodes additionally show the value and the ``id`` of the variable they
    accumulate into. A ``None`` node is printed and not descended into.
    """
    if op.__class__.__name__ == 'AccumulateGrad':
        details = f'({op.variable.detach().numpy()} at {hex(id(op.variable))})'
    else:
        details = ''
    print(f'{prefix}- {op} ' + details)

    if op is None:
        return
    child_prefix = prefix + '  '
    for edge in op.next_functions:
        # Each entry of next_functions is indexed at [0] to get the child node.
        computational_graph(edge[0], child_prefix)
# Print the backward graph of z, the final output of the end-to-end example above.
computational_graph(z.grad_fn)

- <torch.autograd.function.AddRelevanceBackward object at 0x55ee07c718c8> 
  - <torch.autograd.function.LinearEpsilonRelevanceBackward object at 0x55ee08db8b08> 
    - <torch.autograd.function.ScatterAddRelevanceBackward object at 0x55ee08acc198> 
      - <torch.autograd.function.AddRelevanceBackward object at 0x55ee08e00c48> 
        - <AccumulateGrad object at 0x7f8bfd8bac18> ([[ 1.00  2.00  3.00  1.00  0.00]] at 0x7f8bfc8909d8)
        - <AccumulateGrad object at 0x7f8bfd8babe0> ([[ 0.00  1.00  2.00  4.00  5.00]] at 0x7f8bfd8d9678)
      - None 
    - None 
  - <torch.autograd.function.LinearEpsilonRelevanceBackward object at 0x55ee0783b618> 
    - <torch.autograd.function.ScatterAddRelevanceBackward object at 0x55ee08acc198> 
      - <torch.autograd.function.AddRelevanceBackward object at 0x55ee08e00c48> 
        - <AccumulateGrad object at 0x7f8bfd8babe0> ([[ 1.00  2.00  3.00  1.00  0.00]] at 0x7f8bfc8909d8)
        - <AccumulateGrad object at 0x7f8bfd8bafd0> ([[ 0.00  1.00  2.00 

## Index select (not needed)

In [28]:
# Baseline: plain autograd through torch.index_select; index 8 is selected twice.
a = torch.tensor(
    [1, 1, 1, 1, 2, 6, 0, 1, -2], dtype=torch.float, requires_grad=True)
a_idx = torch.tensor(
    [2, 4, 6, 8, 8])
a_new = torch.index_select(a, index=a_idx, dim=0)

# Upstream gradient: 3 where the selected value is non-zero, 0 elsewhere.
grad_out = torch.full_like(a_new, 3) * (a_new != 0).float()
a_new.backward(grad_out)

# The leading spaces align the printed indices with the selected values on the next line.
print(' ' * 61, a_idx.numpy())
print(a.detach().numpy(), '---->', a_new.detach().numpy())
print(a.grad.numpy(), '<----', grad_out.detach().numpy())

                                                              [    2     4     6     8     8]
[ 1.00  1.00  1.00  1.00  2.00  6.00  0.00  1.00 -2.00] ----> [ 1.00  2.00  0.00 -2.00 -2.00]
[ 0.00  0.00  3.00  0.00  3.00  0.00  0.00  0.00  6.00] <---- [ 3.00  3.00  0.00  3.00  3.00]


In [29]:
class IndexSelectRelevance(torch.autograd.Function):
    """``torch.index_select`` with an explicit relevance-propagating backward pass.

    Relevance arriving at a selected element is sent back, unchanged, to the
    source position it was read from; positions selected several times
    accumulate the relevance of all their copies.
    """

    @staticmethod
    def forward(ctx, src, dim, idx):
        """Select the entries ``idx`` of ``src`` along ``dim``."""
        ctx.dim = dim
        # Only the shape of `src` is needed in backward, so neither `src` nor the
        # output is saved and neither is kept alive by the graph.
        ctx.src_shape = src.shape
        ctx.save_for_backward(idx)
        return torch.index_select(src, dim, idx)

    @staticmethod
    def backward(ctx, rel_out):
        """Accumulate ``rel_out`` into a zero tensor shaped like ``src`` at the selected positions."""
        idx, = ctx.saved_tensors
        rel_src = rel_out.new_zeros(ctx.src_shape).index_add_(ctx.dim, idx, rel_out)
        return rel_src, None, None


def index_select(src, dim, index):
    """Index-select with relevance propagation on ``backward``."""
    return IndexSelectRelevance.apply(src, dim, index)

In [30]:
# Same experiment through the custom index_select.
a = torch.tensor(
    [1, 1, 1, 1, 2, 6, 0, 1, -2], dtype=torch.float, requires_grad=True)
a_idx = torch.tensor(
    [2, 4, 6, 8, 8])
a_new = index_select(a, index=a_idx, dim=0)

# Initial relevance: 3 where the selected value is non-zero, 0 elsewhere.
rel_out = torch.full_like(a_new, 3) * (a_new != 0).float()
a_new.backward(rel_out)

print(' ' * 61, a_idx.numpy())
print(a.detach().numpy(), '---->', a_new.detach().numpy())
print(a.grad.numpy(), '<----', rel_out.detach().numpy())

                                                              [    2     4     6     8     8]
[ 1.00  1.00  1.00  1.00  2.00  6.00  0.00  1.00 -2.00] ----> [ 1.00  2.00  0.00 -2.00 -2.00]
[ 0.00  0.00  3.00  0.00  3.00  0.00  0.00  0.00  6.00] <---- [ 3.00  3.00  0.00  3.00  3.00]


## Cat (not needed)

In [31]:
# Baseline: plain autograd through torch.cat of two vectors.
a = torch.tensor([ 1,  2, 0], dtype=torch.float, requires_grad=True)
b = torch.tensor([-1, -2, 0], dtype=torch.float, requires_grad=True)
c = torch.cat([a, b], dim=0)

# Upstream gradient: 3 where the concatenated value is non-zero, 0 elsewhere.
grad_out = torch.full_like(c, 3) * (c != 0).float()
c.backward(grad_out)

print(a.detach().numpy(), ':', b.detach().numpy(), '---->', c.detach().numpy())
print(a.grad.numpy(), ' ', b.grad.numpy(), '<----', grad_out.detach().numpy())

[ 1.00  2.00  0.00] : [-1.00 -2.00  0.00] ----> [ 1.00  2.00  0.00 -1.00 -2.00  0.00]
[ 3.00  3.00  0.00]   [ 3.00  3.00  0.00] <---- [ 3.00  3.00  0.00  3.00  3.00  0.00]


In [32]:
class CatRelevance(torch.autograd.Function):
    """``torch.cat`` with an explicit backward pass that hands each input its own slice of the relevance."""

    @staticmethod
    def forward(ctx, dim, *tensors):
        """Concatenate ``tensors`` along ``dim``, remembering the extent of each one along that dim."""
        ctx.dim = dim
        ctx.sizes = [piece.shape[dim] for piece in tensors]
        return torch.cat(tensors, dim)

    @staticmethod
    def backward(ctx, rel_out):
        """Cut ``rel_out`` back into per-input pieces; ``dim`` itself gets no relevance."""
        pieces = torch.split(rel_out, ctx.sizes, dim=ctx.dim)
        return (None,) + tuple(pieces)


def cat(tensors, dim=0):
    """Concatenate a sequence of tensors along ``dim`` through ``CatRelevance``."""
    return CatRelevance.apply(dim, *tensors)

In [33]:
# Same experiment through the custom cat.
a = torch.tensor([ 1,  2, 0], dtype=torch.float, requires_grad=True)
b = torch.tensor([-1, -2, 0], dtype=torch.float, requires_grad=True)
c = cat((a, b), dim=0)

# Initial relevance: 3 where the concatenated value is non-zero, 0 elsewhere.
rel_out = torch.full_like(c, 3) * (c != 0).float()
c.backward(rel_out)

print(a.detach().numpy(), ':', b.detach().numpy(), '---->', c.detach().numpy())
print(a.grad.numpy(), ' ', b.grad.numpy(), '<----', rel_out.detach().numpy())

[ 1.00  2.00  0.00] : [-1.00 -2.00  0.00] ----> [ 1.00  2.00  0.00 -1.00 -2.00  0.00]
[ 3.00  3.00  0.00]   [ 3.00  3.00  0.00] <---- [ 3.00  3.00  0.00  3.00  3.00  0.00]


## Repeat tensor (not needed)

In [34]:
# Baseline: plain autograd through torchgraphs' repeat_tensor.
a = torch.tensor([ 7, 2, 0], dtype=torch.float, requires_grad=True)
# Per the printed output below, a_repeats[i] is how many times a[i] appears in the result.
a_repeats = torch.tensor([ 3, 0, 2])
a_new = tg.utils.repeat_tensor(a, a_repeats, dim=0)

# Upstream gradient: 1 where the repeated value is non-zero, 0 elsewhere.
grad_out = torch.ones_like(a_new) * (a_new != 0).float()
a_new.backward(grad_out)

print(a_repeats.numpy())
print(a.detach().numpy(), '---->', a_new.detach().numpy())
print(a.grad.numpy(), '<----', grad_out.numpy())

[    3     0     2]
[ 7.00  2.00  0.00] ----> [ 7.00  7.00  7.00  0.00  0.00]
[ 3.00  0.00  0.00] <---- [ 1.00  1.00  1.00  0.00  0.00]


In [35]:
def repeat_tensor(src, repeats, dim=0):
    """Repeat the i-th entry of ``src`` along ``dim`` ``repeats[i]`` times.

    Implemented with ``index_select``: position ``i`` is listed ``repeats[i]``
    times in the index (entries with a count of 0 are left out).
    """
    counts = repeats.cpu().numpy()
    positions = np.repeat(np.arange(len(repeats)), counts)
    gather_idx = src.new_tensor(positions, dtype=torch.long)
    return torch.index_select(src, dim, gather_idx)

In [36]:
# Same experiment through the notebook's own repeat_tensor.
a = torch.tensor([ 7, 2, 0], dtype=torch.float, requires_grad=True)
a_repeats = torch.tensor([ 3, 0, 2])
a_new = repeat_tensor(a, a_repeats, dim=0)

# Initial relevance: 1 where the repeated value is non-zero, 0 elsewhere.
rel_out = torch.ones_like(a_new) * (a_new != 0).float()
a_new.backward(rel_out)

print(a_repeats.numpy())
print(a.detach().numpy(), '---->', a_new.detach().numpy())
print(a.grad.numpy(), '<----', rel_out.numpy())

[    3     0     2]
[ 7.00  2.00  0.00] ----> [ 7.00  7.00  7.00  0.00  0.00]
[ 3.00  0.00  0.00] <---- [ 1.00  1.00  1.00  0.00  0.00]
