# Layer-wise Relevance Propagation (LRP) for Linear + ReLU $\def\bm#1{{\bf #1}}$

In [1]:
import torch
import numpy as np
import torch_scatter

import sympy
from IPython.display import display, Latex
sympy.init_printing(use_latex=True)


def _tensor_repr_latex(self):
    """Render a tensor as a LaTeX matrix (values rounded to 2 decimals) for Jupyter.

    ``detach().cpu()`` makes this work for tensors that require grad or live on a
    GPU, where a bare ``.numpy()`` raises an error instead of displaying.
    """
    values = self.detach().cpu().numpy().round(2)
    return f'${sympy.latex(sympy.Matrix(values))}$'


torch.Tensor._repr_latex_ = _tensor_repr_latex

## Vectorize a simple matrix product
$$
y_j = \sum_i w_{ji} x_i \qquad \bm{y} = \bm{W} \bm{x}
$$

In [2]:
# Toy 2x3 weight matrix and an all-ones 3x1 input column vector.
W = torch.tensor(
    [[1, 1, 0],
     [0, 1, 2]],
    dtype=torch.float,
)
x = torch.tensor([[1.], [1.], [1.]])

In [3]:
# Naive version: two explicit loops, one scalar multiply-add per (j, i) pair.
y = torch.zeros(W.shape[0], 1)
for row in range(y.shape[0]):
    for col in range(x.shape[0]):
        y[row] += W[row, col] * x[col]

In [4]:
# Inner loop replaced by a generator: y_j = sum_i w_ji * x_i.
y = torch.zeros(len(W), 1)
n_inputs = len(x)
for j in range(len(y)):
    terms = (W[j, i] * x[i] for i in range(n_inputs))
    y[j] = sum(terms)

In [5]:
# Generator replaced by a dot product of row j of W with x.
y = torch.zeros(len(W), 1)
for j, w_row in enumerate(W):
    y[j] = w_row @ x

In [6]:
y = torch.matmul(W, x)  # fully vectorized: y = W x

In [7]:
# Render W . x = y as LaTeX, each matrix labelled with an underbrace.
W_tex, x_tex, y_tex = (sympy.latex(sympy.Matrix(t.int())) for t in (W, x, y))
Latex(fr'''$$ 
\underbrace{{{W_tex}}}_{{W}} \cdot
\underbrace{{{x_tex}}}_{{x}} = 
\underbrace{{{y_tex}}}_{{y}} $$''')

<IPython.core.display.Latex object>

## Vectorize LRP
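$$
R_i = x_i \sum_j \frac{w_{ji}}{\sum_{i'} w_{ji'} x_{i'}} Q_j
$$

where $Q_j$ is the relevance arriving at output $j$ (all ones in the example below), and $w_{ji}$ uses the same indexing as $y_j = \sum_i w_{ji} x_i$.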

In [8]:
# Same toy layer as before; Q is the quantity redistributed from each output (all ones).
W = torch.tensor([[1., 1., 0.],
                  [0., 1., 2.]], dtype=torch.float)
x = torch.tensor([[1.], [1.], [1.]])
y = torch.mm(W, x)
Q = torch.full_like(y, 1.0)

In [9]:
# Literal transcription of the formula: for every input i, sum over outputs j,
# recomputing the denominator z_j = sum_i' w_ji' x_i' every time.
R = torch.zeros_like(x)
n_in, n_out = len(R), len(Q)
for i in range(n_in):
    for j in range(n_out):
        z_j = sum(W[j, k] * x[k] for k in range(len(x)))
        R[i] += W[j, i] * Q[j] / z_j
    R[i] *= x[i]

In [10]:
# Denominator computed as a dot product instead of a generator sum.
R = torch.zeros_like(x)
for i in range(len(R)):
    for j in range(len(Q)):
        z_j = torch.matmul(W[j], x)
        R[i] += W[j, i] * Q[j] / z_j
    R[i] *= x[i]

In [11]:
# The denominators do not depend on i, so compute all of them once up front.
R = torch.zeros_like(x)
z = torch.matmul(W, x)
for i in range(len(R)):
    for j in range(len(Q)):
        R[i] += W[j, i] * Q[j] / z[j]
    R[i] *= x[i]

In [12]:
# Fold Q_j / z_j into one vector s, so the inner loop is a plain weighted sum.
R = torch.zeros_like(x)
s = Q / (W @ x)
for i in range(len(R)):
    for j, s_j in enumerate(s):
        R[i] += W[j, i] * s_j
    R[i] *= x[i]

In [13]:
# Inner loop collapsed into a generator sum over outputs j.
R = torch.zeros_like(x)
s = Q / (W @ x)
for i in range(len(R)):
    R[i] = x[i] * sum(W[j, i] * s[j] for j in range(len(Q)))

In [14]:
# The sum over j is a matrix-vector product with W^T; only the elementwise
# scaling by x_i remains an explicit loop.
den = W @ x
prod = Q / den
R = W.t() @ prod  # fresh tensor, so no zero-initialisation is needed
for i in range(len(R)):
    R[i] *= x[i]

In [15]:
# The remaining loop is just an elementwise (broadcast) product with x.
den = W @ x
prod = Q / den
R = x * (W.t() @ prod)

In [16]:
# The whole LRP step for one linear layer as a single expression.
R = x * torch.matmul(W.T, Q / torch.matmul(W, x))

In [17]:
# Show the relevance R next to the forward pass W . x = y and the incoming Q.
R_tex = sympy.latex(sympy.Matrix(R.numpy().round(2)))
W_tex, x_tex, y_tex, Q_tex = (sympy.latex(sympy.Matrix(t.int())) for t in (W, x, y, Q))
Latex(fr'''$$ 
\underbrace{{{R_tex}}}_{{R}} \qquad
\underbrace{{{W_tex}}}_{{W}} \cdot
\underbrace{{{x_tex}}}_{{x}} = 
\underbrace{{{y_tex}}}_{{y}} \qquad
\underbrace{{{Q_tex}}}_{{Q}} $$''')

<IPython.core.display.Latex object>

## LRP batch version

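Instead of a single column vector $x$ of size 3, we now have a batch $X$ of shape $N \times 3$ with one sample per row, so the layer computes $Y = \mathrm{ReLU}(X W^T + b)$ of shape $N \times 2$.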

In [18]:
# A 2x3 layer (note the negative weight) with zero bias, applied to a batch
# of N = 7 non-negative inputs, one sample per row of X.
W = torch.tensor([[-1., 1., 0.],
                  [ 0., 1., 2.]], dtype=torch.float)
b = torch.zeros(2, dtype=torch.float)

X = torch.tensor(
    [[1, 1, 1],
     [1, 2, 3],
     [0, 3, 2],
     [3, 0, 2],
     [1, 2, 0],
     [1, 0, 0],
     [0, 0, 1]],
    dtype=torch.float,
)

# Linear layer followed by ReLU.
Y = X.matmul(W.t()).add(b).clamp(min=0)

In [19]:
# Forward pass ReLU(X . W^T) = Y. The values are integers, so render them with
# .int() like the later display cells do for the same matrices (not as "1.0").
Latex(fr'''$$ 
\text{{ReLU}} \Big(
\underbrace{{{sympy.latex(sympy.Matrix(X.int()))}}}_{{X}} \cdot
\underbrace{{{sympy.latex(sympy.Matrix(W.t().int()))}}}_{{W^T}} \Big) = 
\underbrace{{{sympy.latex(sympy.Matrix(Y.int()))}}}_{{Y}}$$''')

<IPython.core.display.Latex object>

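$w^2$ rule for inputs in $\mathbb{R}^D$: each active output's relevance is split among the inputs in proportion to $w_{ji}^2$, independently of the input values.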

In [20]:
# Only active outputs (Y != 0) pass relevance down.
Q = torch.ones_like(Y).mul((Y != 0).float())
# w^2 rule: input i receives the share w_ji^2 / sum_i' w_ji'^2 of each Q_j.
eps = 10e-6  # small stabiliser (note: 10e-6 == 1e-5)
W_sq = W ** 2
R = Q @ (W_sq / (W_sq.sum(dim=1, keepdim=True) + eps))

In [21]:
# Relevance R, the forward pass ReLU(X . W^T) = Y, and the incoming Q side by side.
R_tex = sympy.latex(sympy.Matrix(R.numpy().round(2)))
X_tex, Wt_tex, Y_tex, Q_tex = (sympy.latex(sympy.Matrix(t.int())) for t in (X, W.t(), Y, Q))
Latex(fr'''$$ 
\underbrace{{{R_tex}}}_{{R}} \qquad
\text{{ReLU}} \Big(
\underbrace{{{X_tex}}}_{{X}} \cdot
\underbrace{{{Wt_tex}}}_{{W^T}} \Big) = 
\underbrace{{{Y_tex}}}_{{Y}} \qquad
\underbrace{{{Q_tex}}}_{{Q}} $$''')

<IPython.core.display.Latex object>

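$z$ rule for $\mathbb{R}_+^D$: each active output's relevance is split among the inputs in proportion to their contributions $z_{ji} = x_i w_{ji}$.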

In [22]:
# Only active outputs (Y != 0) pass relevance down.
Q = torch.ones_like(Y) * (Y != 0).float()
# z rule (LRP-0): input i receives x_i * w_ji / z_j of each Q_j. The denominator
# z_j must be the same pre-activation as in the forward pass, including the bias b.
# If b is dropped, a unit kept active only by a positive b_j can get a near-zero
# or negative denominator, and its relevance blows up or flips sign.
eps = 10e-6  # small stabiliser (note: 10e-6 == 1e-5)
Z = X @ W.t() + b
R = X * ((Q / (Z + eps)) @ W)

In [23]:
# Relevance R, the forward pass ReLU(X . W^T) = Y, and the incoming Q side by side.
R_tex = sympy.latex(sympy.Matrix(R.numpy().round(2)))
X_tex, Wt_tex, Y_tex, Q_tex = (sympy.latex(sympy.Matrix(t.int())) for t in (X, W.t(), Y, Q))
Latex(fr'''$$ 
\underbrace{{{R_tex}}}_{{R}} \qquad
\text{{ReLU}} \Big(
\underbrace{{{X_tex}}}_{{X}} \cdot
\underbrace{{{Wt_tex}}}_{{W^T}} \Big) = 
\underbrace{{{Y_tex}}}_{{Y}} \qquad
\underbrace{{{Q_tex}}}_{{Q}} $$''')

<IPython.core.display.Latex object>

$z^+$ rule for $\mathbb{R}_+^D$: like the $z$ rule, but only the positive weights $w_{ji}^+ = \max(0, w_{ji})$ are used.

In [24]:
# Only active outputs (Y != 0) pass relevance down.
Q = torch.ones_like(Y).mul((Y != 0).float())
# z+ rule: same as the z rule, but only with the positive part of the weights.
eps = 10e-6  # small stabiliser (note: 10e-6 == 1e-5)
W_pos = W.clamp(min=0)
R = X * ((Q / (X @ W_pos.t() + eps)) @ W_pos)

In [25]:
# Relevance R, the forward pass ReLU(X . W^T) = Y, and the incoming Q side by side.
R_tex = sympy.latex(sympy.Matrix(R.numpy().round(2)))
X_tex, Wt_tex, Y_tex, Q_tex = (sympy.latex(sympy.Matrix(t.int())) for t in (X, W.t(), Y, Q))
Latex(fr'''$$ 
\underbrace{{{R_tex}}}_{{R}} \qquad
\text{{ReLU}} \Big(
\underbrace{{{X_tex}}}_{{X}} \cdot
\underbrace{{{Wt_tex}}}_{{W^T}} \Big) = 
\underbrace{{{Y_tex}}}_{{Y}} \qquad
\underbrace{{{Q_tex}}}_{{Q}} $$''')

<IPython.core.display.Latex object>