# LRP for Index and Scatter Operations

In [1]:
from ipywidgets import Output, HBox, VBox
from IPython.display import HTML
import numpy as np
import torch
import torch_scatter
import warnings

In [12]:
class RelevanceFunction:
    """Base class for operations that carry an LRP relevance rule.

    Subclasses store whatever their ``relevance`` method needs in
    ``saved_stuff`` during ``forward``.
    """

    def __init__(self):
        # Nothing has been saved until a subclass's forward() runs.
        self.saved_stuff = None

In [13]:
def table(x):
    """Wrap *x* in a ``<table>`` element."""
    return '<table>{}</table>'.format(x)
def tr(x):
    """Wrap *x* in a table row element."""
    return '<tr>{}</tr>'.format(x)
def td(x):
    """Wrap *x* in a left-aligned, vertically centred table cell."""
    return '<td style="text-align: left; vertical-align: middle;">{}</td>'.format(x)
def html_repr(x):
    """Render *x* as a ``<pre>`` HTML snippet for the tables in this notebook.

    Tensors are detached and moved to the CPU before conversion to NumPy, so
    tensors that require grad or live on another device can be displayed too.
    Single-element arrays are shown as a plain scalar. Larger arrays are
    reshaped into one column, so each value lands on its own line. Other
    objects are rendered with ``str``.
    """
    if isinstance(x, torch.Tensor):
        # Tensor.numpy() raises for tensors that require grad or are not on the CPU.
        x = x.detach().cpu().numpy()
    if isinstance(x, np.ndarray):
        x = x.item() if x.size == 1 else x.reshape(-1, 1)
    return '<pre>' + str(x).replace('\n', '<br/>') + '</pre>'
def make_table(items):
    """Build an IPython ``HTML`` table from a list of rows.

    Each row becomes a ``<tr>``; every cell in it is rendered with
    ``html_repr`` and wrapped in a ``<td>``.
    """
    body = ''.join(
        tr(''.join(td(html_repr(cell)) for cell in row))
        for row in items
    )
    return HTML(table(body))

In [14]:
import sys

def simple_format(message, category, filename, lineno, line=None):
    """Format a warning as ``Category: message``, without file or line info.

    Installed below as ``warnings.formatwarning``. The trailing newline
    matches the stdlib formatter; without it, consecutive warnings are
    written back-to-back on a single line.
    """
    return f'{category.__name__}: {message}\n'
warnings.formatwarning = simple_format

## Index select

In [15]:
def index_select(input, idx):
    """Gather entries of *input* at the positions in *idx* (explicit loop).

    Entry ``k`` of the result is ``input[idx[k]]``. Repeated indices copy the
    same entry several times. The result is allocated with ``torch.empty``
    without a dtype, so it uses torch's default dtype.
    """
    gathered = torch.empty((idx.shape[0], *input.shape[1:]))
    for dst, src in enumerate(idx):
        gathered[dst] = input[src]
    return gathered

In [16]:
# Forward pass of the loop-based index_select; idx selects input 5 twice.
input = torch.tensor([.7, .2, .0, .5, .3, .1])
idx = torch.tensor([0, 2, 5, 5])
output = index_select(input, idx=idx)

make_table([
    ['input', 'idx', 'output'],
    [input, idx, output]
])

0,1,2
input,idx,output
[[0.7]  [0.2]  [0. ]  [0.5]  [0.3]  [0.1]],[[0]  [2]  [5]  [5]],[[0.7]  [0. ]  [0.1]  [0.1]]


In [17]:
def index_select_relevance(input, idx, rel_out):
    """Send each output's relevance back to the input entry it was copied from.

    Output ``k`` came from ``input[idx[k]]``, so ``rel_out[k]`` is added to
    that input position. Inputs selected several times accumulate the
    relevance of every copy; unselected inputs get 0.
    """
    rel_in = torch.zeros(input.shape)
    for src, relevance in zip(idx, rel_out):
        rel_in[src] += relevance
    return rel_in

In [18]:
# Backward pass: each output relevance goes to the input it was copied from.
# NOTE: `output` is not recomputed here; it is the value left by the previous cell.
input = torch.tensor([.7, .2, .0, .5, .3, .1])
idx = torch.tensor([0, 2, 5, 5])
rel_out = torch.tensor([5, 0, 2, 1], dtype=torch.float)
rel_in = index_select_relevance(input, idx, rel_out)

make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output,   rel_out,   idx,   rel_in],
    ['Σ',        '',      '',   rel_out.sum(), '', rel_in.sum()],
])

0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.7]  [0.2]  [0. ]  [0.5]  [0.3]  [0.1]],[[0]  [2]  [5]  [5]],[[0.7]  [0. ]  [0.1]  [0.1]],[[5.]  [0.]  [2.]  [1.]],[[0]  [2]  [5]  [5]],[[5.]  [0.]  [0.]  [0.]  [0.]  [3.]]
Σ,,,8.0,,8.0


In [26]:
class IndexSelect(RelevanceFunction):
    """``torch.index_select`` paired with an LRP relevance rule.

    Each output entry is a copy of one input entry, so relevance flows back
    unchanged and is summed at inputs that were selected more than once.
    """

    def forward(self, input, idx, dim):
        input_dim = input.shape[dim]
        # Keep only what relevance() needs: the indices, the axis and its original size.
        self.saved_stuff = idx, dim, input_dim
        return torch.index_select(input, dim=dim, index=idx)

    def relevance(self, rel_out):
        idx, dim, input_dim = self.saved_stuff
        return torch_scatter.scatter_add(
            rel_out, idx, dim=dim, dim_size=input_dim
        )

In [27]:
# Same example through the IndexSelect class: torch.index_select forward,
# torch_scatter.scatter_add for the relevance.
op = IndexSelect()
input = torch.tensor([.7, .2, .0, .5, .3, .1])
idx = torch.tensor([0, 2, 5, 5])
output = op.forward(input, idx, dim=0)
rel_out = torch.tensor([5, 0, 2, 1], dtype=torch.float)
rel_in = op.relevance(rel_out)
make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output,   rel_out,   idx,   rel_in],
    ['Σ',        '',      '',   rel_out.sum(), '', rel_in.sum()],
])

0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.7]  [0.2]  [0. ]  [0.5]  [0.3]  [0.1]],[[0]  [2]  [5]  [5]],[[0.7]  [0. ]  [0.1]  [0.1]],[[5.]  [0.]  [2.]  [1.]],[[0]  [2]  [5]  [5]],[[5.]  [0.]  [0.]  [0.]  [0.]  [3.]]
Σ,,,8.0,,8.0


## Sum Pooling

In [28]:
def sum_pooling(input):
    """Pool every entry of *input* into one sum."""
    pooled = input.sum()
    return pooled

def sum_pooling_relevance(input, output, rel_out):
    """Split *rel_out* among the inputs in proportion to ``input / output``.

    No guard for ``output == 0``; the SumPooling class handles that case.
    """
    weighted = rel_out * input
    return weighted / output

In [29]:
# Sum pooling: the output relevance is split in proportion to each input's share of the sum.
input = torch.tensor([.7, .2, .0, .5, .5, .1])
output = sum_pooling(input)
rel_out = 10.0
rel_in = sum_pooling_relevance(input, output, rel_out)

make_table([
    ['input', 'output', 'rel_out', 'rel_in'],
    [input,    output,   rel_out,   rel_in],
    ['Σ',          '',   rel_out,   rel_in.sum()],
])

0,1,2,3
input,output,rel_out,rel_in
[[0.7]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],2.0,10.0,[[3.5]  [1. ]  [0. ]  [2.5]  [2.5]  [0.5]]
Σ,,10.0,10.0


In [34]:
class SumPooling(RelevanceFunction):
    """Sum of all entries, paired with a proportional LRP relevance rule.

    The output's relevance is shared among the inputs as
    ``rel_out * input / output``.
    """

    def forward(self, input):
        output = input.sum()
        self.saved_stuff = input, output
        return output

    def relevance(self, rel_out):
        input, output = self.saved_stuff
        if output == 0:
            # The proportional rule divides by the output, so nothing can be
            # redistributed. Warn only when there is actually relevance to lose.
            if torch.as_tensor(rel_out).ne(0).any():
                warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
            return torch.zeros_like(input)
        return rel_out * input / output

In [36]:
# Same example through the SumPooling class.
op = SumPooling()
input = torch.tensor([.7, .2, .0, .5, .5, .1])
output = op.forward(input)
rel_out = 10.0
rel_in = op.relevance(rel_out)
make_table([
    ['input', 'output', 'rel_out', 'rel_in'],
    [input,    output,   rel_out,   rel_in],
    ['Σ',          '',   rel_out,   rel_in.sum()],
])

0,1,2,3
input,output,rel_out,rel_in
[[0.7]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],2.0,10.0,[[3.5]  [1. ]  [0. ]  [2.5]  [2.5]  [0.5]]
Σ,,10.0,10.0


In [39]:
# Corner case: these inputs sum to 0, so SumPooling.relevance warns and
# returns zero relevance for every input (the relevance is lost).
op = SumPooling()
input = torch.tensor([.7, .2, .1, -.7, -.3, .0])
output = op.forward(input)
rel_out = 10.0
rel_in = op.relevance(rel_out)
make_table([
    ['input', 'output', 'rel_out', 'rel_in'],
    [input,    output,   rel_out,   rel_in],
    ['Σ',          '',   rel_out,   rel_in.sum()],
])



0,1,2,3
input,output,rel_out,rel_in
[[ 0.7]  [ 0.2]  [ 0.1]  [-0.7]  [-0.3]  [ 0. ]],0.0,10.0,[[0.]  [0.]  [0.]  [0.]  [0.]  [0.]]
Σ,,10.0,0.0


## Scatter Add

In [40]:
def scatter_add(input, idx, size):
    """Add each ``input[i]`` into output position ``idx[i]`` (explicit loop).

    Output positions that no index points to stay 0.
    """
    totals = torch.zeros(size)
    for value, dst in zip(input, idx):
        totals[dst] += value
    return totals

In [41]:
# Forward pass of the loop-based scatter_add; no index points to output 2, so it stays 0.
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
output = scatter_add(input, idx, 4)

make_table([
    ['input', 'idx', 'output'],
    [input,    idx,   output]
])

0,1,2
input,idx,output
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.7]  [1. ]  [0. ]  [0.1]]


In [42]:
def scatter_add_relevance(input, idx, output, rel_out):
    """Redistribute relevance through a scatter-add (explicit loop).

    Input ``i``, which was added into output ``j = idx[i]``, receives
    ``rel_out[j] * input[i] / output[j]``. An output equal to 0 cannot be
    split proportionally. Inputs feeding such an output get 0, so that
    output's relevance is lost instead of turning into inf/NaN.

    Args:
        input: values that were scattered.
        idx: for each input position, the output position it was added into.
        output: result of the scatter-add.
        rel_out: relevance of each output position.

    Returns:
        Tensor shaped like ``input`` holding each input's relevance.
    """
    rel_in = torch.zeros_like(input)
    for src, dst in zip(range(len(input)), idx):
        total = output[dst]
        if total != 0:  # dividing by a zero output would give inf or 0/0 = NaN
            rel_in[src] = rel_out[dst] * input[src] / total
    return rel_in

In [43]:
# Relevance through scatter_add; rel_out[2] is 0, so the total relevance is conserved.
# NOTE: `output` comes from the previous cell.
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
rel_out = torch.tensor([5, 3, 0, 1], dtype=torch.float)
rel_in = scatter_add_relevance(input, idx, output, rel_out)

make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output,   rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',   rel_out.sum(), '', rel_in.sum()],
])

0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.7]  [1. ]  [0. ]  [0.1]],[[5.]  [3.]  [0.]  [1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[3.57]  [1.43]  [0. ]  [1.5 ]  [1.5 ]  [1. ]]
Σ,,,9.0,,9.0


Corner case: if an output neuron with 0 activation (either because no index was pointing at it or because all of its inputs were 0) receives relevance, that relevance is _lost_.

In [44]:
# Corner case: rel_out[2]=10 is assigned to an output no index points to,
# so it is lost (the Σ row drops from 19 to 9).
# NOTE: `output` comes from the scatter_add forward cell.
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
rel_out = torch.tensor([5, 3, 10, 1], dtype=torch.float)
rel_in = scatter_add_relevance(input, idx, output, rel_out)

make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output,   rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',   rel_out.sum(), '', rel_in.sum()],
])

0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.7]  [1. ]  [0. ]  [0.1]],[[ 5.]  [ 3.]  [10.]  [ 1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[3.57]  [1.43]  [0. ]  [1.5 ]  [1.5 ]  [1. ]]
Σ,,,19.0,,9.0


In [53]:
class ScatterAdd(RelevanceFunction):
    """``torch_scatter.scatter_add`` paired with a proportional LRP relevance rule.

    Input ``i``, added into output ``j``, receives
    ``rel_out[j] * input[i] / output[j]``.
    """

    def forward(self, input, idx, dim, dim_size):
        output = torch_scatter.scatter_add(input, idx, dim=dim, dim_size=dim_size)
        self.saved_stuff = input, idx, dim, output
        return output

    def relevance(self, rel_out):
        input, idx, dim, output = self.saved_stuff
        zero_out = output == 0
        # Any nonzero relevance (positive or negative) on a zero output cannot be redistributed.
        if (zero_out & (rel_out != 0)).any():
            warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
        # rel_out / output is inf or NaN (0/0) where the output is 0. After the
        # gather, inf * 0 is NaN and would poison rel_in, so use a ratio of 0
        # there: that relevance is dropped, as the warning says.
        ratio = rel_out / output
        ratio = torch.where(zero_out, torch.zeros_like(ratio), ratio)
        return torch.index_select(ratio, dim, idx) * input

In [55]:
# Same example through the ScatterAdd class.
op = ScatterAdd()
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
output = op.forward(input, idx, dim=0, dim_size=4)
rel_out = torch.tensor([5, 3, 0, 1], dtype=torch.float)
rel_in = op.relevance(rel_out)
make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output,   rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',   rel_out.sum(), '', rel_in.sum()],
])

0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.7]  [1. ]  [0. ]  [0.1]],[[5.]  [3.]  [0.]  [1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[3.57]  [1.43]  [0. ]  [1.5 ]  [1.5 ]  [1. ]]
Σ,,,9.0,,9.0


In [56]:
# ScatterAdd corner case: the empty output 2 receives relevance 10, so a
# warning is raised and that relevance is lost.
op = ScatterAdd()
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
output = op.forward(input, idx, dim=0, dim_size=4)
rel_out = torch.tensor([5, 3, 10, 1], dtype=torch.float)
rel_in = op.relevance(rel_out)
make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output,   rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',   rel_out.sum(), '', rel_in.sum()],
])



0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.7]  [1. ]  [0. ]  [0.1]],[[ 5.]  [ 3.]  [10.]  [ 1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[3.57]  [1.43]  [0. ]  [1.5 ]  [1.5 ]  [1. ]]
Σ,,,19.0,,9.0


## Scatter Mean

In [64]:
def scatter_mean(input, idx, size):
    """Average the inputs that share an output position (explicit loop).

    Returns ``(means, counts)``. ``counts`` is an int tensor holding how many
    inputs landed in each position. Empty positions divide by 1, so their
    mean is 0.
    """
    totals = torch.zeros(size)
    counts = torch.zeros(size, dtype=torch.int)
    for value, dst in zip(input, idx):
        totals[dst] += value
        counts[dst] += 1
    means = totals / counts.float().clamp(min=1)
    return means, counts

In [65]:
# Forward pass of the loop-based scatter_mean; it also returns the per-output counts.
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
output, counts = scatter_mean(input, idx, 4)

make_table([
    ['input', 'idx', 'output', 'counts'],
    [input,    idx,   output.numpy().round(2),   counts],
])

0,1,2,3
input,idx,output,counts
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.35]  [0.33]  [0. ]  [0.1 ]],[[2]  [3]  [0]  [1]]


In [66]:
def scatter_mean_relevance(input, idx, output, counts, rel_out):
    """Redistribute relevance through a scatter-mean (explicit loop).

    The per-output sums are recovered as ``output * counts``. Input ``i`` in
    group ``j = idx[i]`` receives ``rel_out[j] * input[i] / sums[j]``; the
    count is a constant factor and cancels. A group whose sum is 0 cannot be
    split proportionally. Its inputs get 0, so its relevance is lost instead
    of turning into inf/NaN.

    Args:
        input: values that were averaged.
        idx: for each input position, the output position it belongs to.
        output: per-output means as returned by ``scatter_mean``.
        counts: per-output number of inputs as returned by ``scatter_mean``.
        rel_out: relevance of each output position.

    Returns:
        Tensor shaped like ``input`` holding each input's relevance.
    """
    sums = output * counts.float()
    rel_in = torch.zeros_like(input)
    for src, dst in zip(range(len(input)), idx):
        total = sums[dst]
        if total != 0:  # dividing by a zero sum would give inf or 0/0 = NaN
            rel_in[src] = rel_out[dst] * input[src] / total
    return rel_in

In [67]:
# Relevance through scatter_mean.
# NOTE: `output` and `counts` come from the previous cell.
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
rel_out = torch.tensor([5, 3, 0, 1], dtype=torch.float)
rel_in = scatter_mean_relevance(input, idx, output, counts, rel_out)

make_table([
    ['input', 'idx', 'output', 'counts', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output.numpy().round(2),   counts,   rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',       '',   rel_out.sum(), '', rel_in.sum()],
])

0,1,2,3,4,5,6
input,idx,output,counts,rel_out,idx,rel_in
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.35]  [0.33]  [0. ]  [0.1 ]],[[2]  [3]  [0]  [1]],[[5.]  [3.]  [0.]  [1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[3.57]  [1.43]  [0. ]  [1.5 ]  [1.5 ]  [1. ]]
Σ,,,,9.0,,9.0


Corner case: if an output neuron with 0 activation (either because no index was pointing at it or because all of its inputs were 0) receives relevance, that relevance is _lost_.

In [68]:
# Corner case: relevance assigned to the empty output 2 is lost (Σ drops from 19 to 9).
# NOTE: `output` and `counts` come from the scatter_mean forward cell.
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
rel_out = torch.tensor([5, 3, 10, 1], dtype=torch.float)
rel_in = scatter_mean_relevance(input, idx, output, counts, rel_out)

make_table([
    ['input', 'idx', 'output', 'counts', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output.numpy().round(2),   counts,   rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',       '',   rel_out.sum(), '', rel_in.sum()],
])

0,1,2,3,4,5,6
input,idx,output,counts,rel_out,idx,rel_in
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.35]  [0.33]  [0. ]  [0.1 ]],[[2]  [3]  [0]  [1]],[[ 5.]  [ 3.]  [10.]  [ 1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[3.57]  [1.43]  [0. ]  [1.5 ]  [1.5 ]  [1. ]]
Σ,,,,19.0,,9.0


In [69]:
class ScatterMean(RelevanceFunction):
    """Scatter-mean built on ``torch_scatter.scatter_add``, with a proportional LRP rule.

    The mean of group ``j`` is ``sums[j] / counts[j]``. The count is a constant
    factor, so input ``i`` in group ``j`` receives
    ``rel_out[j] * input[i] / sums[j]``.
    """

    def forward(self, input, idx, dim_size, dim=0):
        """Average the entries of ``input`` that share an index along ``dim``.

        ``dim`` defaults to 0, the axis this op was previously hard-coded to,
        so positional calls ``forward(input, idx, dim_size)`` behave as before.
        Empty groups have their count clamped to 1, so their mean is 0.
        NOTE(review): assumes ``idx`` is 1-D, as in the other scatter ops here.
        """
        sums = torch_scatter.scatter_add(input, idx, dim=dim, dim_size=dim_size)
        counts = torch_scatter.scatter_add(torch.ones_like(input), idx, dim=dim, dim_size=dim_size)
        self.saved_stuff = input, idx, dim, sums
        return sums / counts.clamp(min=1)

    def relevance(self, rel_out):
        input, idx, dim, sums = self.saved_stuff
        zero_sum = sums == 0
        # Any nonzero relevance (positive or negative) on a zero-sum group cannot be redistributed.
        if (zero_sum & (rel_out != 0)).any():
            warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
        # rel_out / sums is inf or NaN (0/0) for zero sums. After the gather it
        # would turn rel_in into NaN, so use a ratio of 0 there instead.
        ratio = rel_out / sums
        ratio = torch.where(zero_sum, torch.zeros_like(ratio), ratio)
        return torch.index_select(ratio, dim, idx) * input

In [72]:
# Same example through the ScatterMean class.
op = ScatterMean()
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
output = op.forward(input, idx, 4)
rel_out = torch.tensor([5, 3, 0, 1], dtype=torch.float)
rel_in = op.relevance(rel_out)
make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output.numpy().round(2), rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',     rel_out.sum(), '', rel_in.sum()],
])

0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.35]  [0.33]  [0. ]  [0.1 ]],[[5.]  [3.]  [0.]  [1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[3.57]  [1.43]  [0. ]  [1.5 ]  [1.5 ]  [1. ]]
Σ,,,9.0,,9.0


In [73]:
# ScatterMean corner case: the empty output 2 receives relevance 10, so a
# warning is raised and that relevance is lost.
op = ScatterMean()
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
output = op.forward(input, idx, 4)
rel_out = torch.tensor([5, 3, 10, 1], dtype=torch.float)
rel_in = op.relevance(rel_out)
make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output.numpy().round(2), rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',     rel_out.sum(), '', rel_in.sum()],
])



0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.35]  [0.33]  [0. ]  [0.1 ]],[[ 5.]  [ 3.]  [10.]  [ 1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[3.57]  [1.43]  [0. ]  [1.5 ]  [1.5 ]  [1. ]]
Σ,,,19.0,,9.0


## Scatter Max

In [31]:
def scatter_max(input, idx, size):
    """Per-position maximum of the inputs (explicit loop), floored at 0.

    The running maximum starts at 0 and is replaced only by strictly larger
    inputs. Positions with no positive input therefore report 0 and an
    argmax of -1.

    Returns:
        ``(maxima, idx_maxes)``. ``idx_maxes`` is an int tensor holding the
        input position of each maximum, or -1.
    """
    maxima = torch.zeros(size)
    idx_maxes = torch.ones(size, dtype=torch.int).neg_()
    for pos, (value, dst) in enumerate(zip(input, idx)):
        if value > maxima[dst]:
            maxima[dst] = value
            idx_maxes[dst] = pos
    return maxima, idx_maxes

In [32]:
# Forward pass of the loop-based scatter_max; idx_maxes is -1 where no positive input arrived.
input = torch.tensor([.5, .1, .0, .7, .3, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
output, idx_maxes = scatter_max(input, idx, 4)

make_table([
    ['input', 'idx', 'output', 'idx_maxes'],
    [input,    idx,   output,   idx_maxes],
])

0,1,2,3
input,idx,output,idx_maxes
[[0.5]  [0.1]  [0. ]  [0.7]  [0.3]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.5]  [0.7]  [0. ]  [0.1]],[[ 0]  [ 3]  [-1]  [ 5]]


In [33]:
def scatter_max_relevance(input, idx_maxes, rel_out):
    """Winner-take-all relevance for a scatter-max (explicit loop).

    Output ``j`` passes all of ``rel_out[j]`` to the input at ``idx_maxes[j]``.
    Outputs without a winner (``idx_maxes[j] == -1``) are skipped, so their
    relevance is lost.

    The -1 check must test the argmax, not the loop counter. Otherwise
    ``rel_in[-1]`` would silently hand that relevance to the last input.
    """
    rel_in = torch.zeros_like(input)
    for dst, src in enumerate(idx_maxes):
        if src != -1:
            rel_in[src] = rel_out[dst]
    return rel_in

In [34]:
# Winner-take-all relevance through scatter_max.
# NOTE: `output` and `idx_maxes` come from the previous cell.
input = torch.tensor([.5, .1, .0, .7, .3, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
rel_out = torch.tensor([5, 3, 0, 1], dtype=torch.float)
rel_in = scatter_max_relevance(input, idx_maxes, rel_out)

make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output,   rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',   rel_out.sum(), '', rel_in.sum()],
])

0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.5]  [0.1]  [0. ]  [0.7]  [0.3]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.5]  [0.7]  [0. ]  [0.1]],[[5.]  [3.]  [0.]  [1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[5.]  [0.]  [0.]  [3.]  [0.]  [1.]]
Σ,,,9.0,,9.0


Corner case: if an output neuron with 0 activation (either because no index points at it or because none of its inputs is greater than 0) receives relevance, that relevance is _lost_.

In [35]:
# Corner case: relevance assigned to output 2, which has no winner, is lost (Σ drops from 19 to 9).
# NOTE: `output` and `idx_maxes` come from the scatter_max forward cell.
input = torch.tensor([.5, .1, .0, .7, .3, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
rel_out = torch.tensor([5, 3, 10, 1], dtype=torch.float)
rel_in = scatter_max_relevance(input, idx_maxes, rel_out)

make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output,   rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',   rel_out.sum(), '', rel_in.sum()],
])

0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.5]  [0.1]  [0. ]  [0.7]  [0.3]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.5]  [0.7]  [0. ]  [0.1]],[[ 5.]  [ 3.]  [10.]  [ 1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[5.]  [0.]  [0.]  [3.]  [0.]  [1.]]
Σ,,,19.0,,9.0


In [74]:
class ScatterMax(RelevanceFunction):
    """``torch_scatter.scatter_max`` paired with a winner-take-all LRP relevance rule.

    Each output passes all of its relevance to the input that was its argmax.
    """

    def forward(self, input, idx, dim, dim_size):
        output, idx_maxes = torch_scatter.scatter_max(input, idx, dim=dim, dim_size=dim_size)
        self.saved_stuff = idx, dim, input.shape[dim], output, idx_maxes
        return output, idx_maxes

    def relevance(self, rel_out):
        idx, dim, input_dim, output, idx_maxes = self.saved_stuff
        # Outputs that received no input have no valid argmax. torch_scatter
        # marks them with -1 in older releases and with input_dim in newer ones
        # (NOTE(review): confirm for the installed version), so accept only
        # indices in [0, input_dim).
        valid = (idx_maxes >= 0) & (idx_maxes < input_dim)
        if (~valid & (rel_out != 0)).any():
            warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
        # Drop the relevance of outputs without a valid argmax. Their indices
        # are redirected to 0 only so that scatter_add sees in-range indices;
        # they add 0 there.
        rel_out = torch.where(valid, rel_out, torch.zeros_like(rel_out))
        idx_maxes = torch.where(valid, idx_maxes, torch.zeros_like(idx_maxes))

        return torch_scatter.scatter_add(rel_out, idx_maxes, dim=dim, dim_size=input_dim)

In [78]:
# Same example through the ScatterMax class; the returned argmax is discarded here.
op = ScatterMax()
input = torch.tensor([.5, .1, .0, .7, .3, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
output, _ = op.forward(input, idx, dim=0, dim_size=4)
rel_out = torch.tensor([5, 3, 0, 1], dtype=torch.float)
rel_in = op.relevance(rel_out)
make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output,   rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',   rel_out.sum(), '', rel_in.sum()],
])

0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.5]  [0.1]  [0. ]  [0.7]  [0.3]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.5]  [0.7]  [0. ]  [0.1]],[[5.]  [3.]  [0.]  [1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[5.]  [0.]  [0.]  [3.]  [0.]  [1.]]
Σ,,,9.0,,9.0


In [79]:
# ScatterMax corner case: relevance assigned to output 2, which received no input, is lost.
op = ScatterMax()
input = torch.tensor([.5, .1, .0, .7, .3, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
output, _ = op.forward(input, idx, dim=0, dim_size=4)
rel_out = torch.tensor([5, 3, 10, 1], dtype=torch.float)
rel_in = op.relevance(rel_out)
make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output,   rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',   rel_out.sum(), '', rel_in.sum()],
])



0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.5]  [0.1]  [0. ]  [0.7]  [0.3]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.5]  [0.7]  [0. ]  [0.1]],[[ 5.]  [ 3.]  [10.]  [ 1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[5.]  [0.]  [0.]  [3.]  [0.]  [1.]]
Σ,,,19.0,,9.0


## Scatter Max as $\ell_p$ pooling

In [80]:
def scatter_lp(input, idx, size, p=1000):
    """Per-position l_p pooling: ``(sum of input**p) ** (1/p)`` over each index group.

    With p=1 this equals ``scatter_add``. For non-negative inputs, larger p
    moves each group's result towards its maximum.

    NOTE(review): with float32 inputs and large p, ``input ** p`` underflows
    to 0 for |x| < 1 and overflows for |x| > 1. The default p=1000 is likely
    numerically unusable; confirm before relying on it.
    """
    powered = input ** p
    pooled = scatter_add(powered, idx, size)
    return pooled ** (1 / p)

With `p=1` it's the same as `sum`:

In [81]:
# p=1: l_p pooling reduces to the plain scatter sum.
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
output = scatter_lp(input, idx, 4, p=1)

make_table([
    ['input', 'idx', 'output'],
    [input,    idx,   output],
])

0,1,2
input,idx,output
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.7]  [1. ]  [0. ]  [0.1]]


The higher `p`, the closer we get to `max`:

In [82]:
# p=10: each group's pooled value moves towards its largest input.
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
output = scatter_lp(input, idx, 4, p=10)

make_table([
    ['input', 'idx', 'output'],
    [input,    idx,   output],
])

0,1,2
input,idx,output
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.50000525]  [0.5358867 ]  [0. ]  [0.1 ]]


How should relevance be defined for powers like $x^p$? The operation is no longer linear, so the Taylor decomposition fails: reusing the scatter-add rule below no longer conserves relevance.

In [83]:
def scatter_lp_relevance(input, idx, output, rel_out):
    """Reuse the proportional scatter-add rule with the l_p pooled ``output``.

    l_p pooling is not linear in its inputs, so this rule does not conserve
    relevance in general: the input relevances need not sum to
    ``rel_out.sum()``.
    """
    rel_in = scatter_add_relevance(input, idx, output, rel_out)
    return rel_in

In [85]:
input = torch.tensor([.5, .2, .0, .5, .5, .1])
idx = torch.tensor([0, 0, 1, 1, 1, 3])
# Recompute the p=10 pooled output here rather than relying on whatever
# `output` the most recently executed cell left behind (e.g. the p=1 result).
output = scatter_lp(input, idx, 4, p=10)
rel_out = torch.tensor([5, 3, 0, 1], dtype=torch.float)
# Use the l_p rule defined above. It reuses the proportional scatter-add rule,
# so relevance is not conserved (compare the two Σ entries).
rel_in = scatter_lp_relevance(input, idx, output, rel_out)

make_table([
    ['input', 'idx', 'output', 'rel_out', 'idx', 'rel_in'],
    [input,    idx,   output,   rel_out,   idx,   rel_in.numpy().round(2)],
    ['Σ',        '',      '',   rel_out.sum(), '', rel_in.sum().numpy().round(2)],
])

0,1,2,3,4,5
input,idx,output,rel_out,idx,rel_in
[[0.5]  [0.2]  [0. ]  [0.5]  [0.5]  [0.1]],[[0]  [0]  [1]  [1]  [1]  [3]],[[0.50000525]  [0.5358867 ]  [0. ]  [0.1 ]],[[5.]  [3.]  [0.]  [1.]],[[0]  [0]  [1]  [1]  [1]  [3]],[[5. ]  [2. ]  [0. ]  [2.8]  [2.8]  [1. ]]
Σ,,,9.0,,13.6
