# Reimplementing Deep Taylor Decomposition (DTD)

Goal:

- Compute higher order relevances
- Follow the theory 100%


In [None]:
%env CUDA_VISIBLE_DEVICES=""


from typing import Union, Callable

import torch 
from torch import nn
import matplotlib.pyplot as plt
import pandas as pd

from lrp_relations import dtd

Equation 4, Appendix DTD
$
x_i - \tilde x_i = \frac
    {\sum_k x_k w_{kj} + b_j}
    {\sum_l v_l w_{lj}}
    v_i
$

where $ v_i = x_i 1_{w_{ij} \ge 0} $ 


In [None]:
# Seed the RNG before building the model so its random initialisation is
# reproducible.
torch.manual_seed(2)
net = dtd.TwoLayerMLP(
    input_size=3,
    hidden_size=5,
    output_size=1,
)

# Replace the first two entries along dim 1 of the layer-2 weight with five
# times their absolute value; all other entries keep their initial values.
data = net.layer2.linear.weight.data.clone()
data[:, :2] = data[:, :2].abs() * 5
net.layer2.linear.weight.data = data

# Make the biases of both layers non-negative, in place. The last expression
# is what the notebook displays for this cell.
net.layer2.linear.bias.data.abs_()
net.layer1.linear.bias.data.abs_()

In [None]:
# For every propagation rule: compute the hidden-layer relevance at `x`, look
# up a hidden-layer root point and an input-space root point, recompute the
# same quantities at the input root, and collect everything in `df_roots`.
torch.manual_seed(3)
# One sample of 3 features, 0.25 * N(0, 1) + 3, clamped to be non-negative.
# `requires_grad` is set on the raw randn sample, so `x` is a non-leaf tensor
# derived from it; `torch.autograd.grad` below differentiates w.r.t. `x`.
x = (0.25 * torch.randn(1, 3, requires_grad=True) + 3).clamp(min=0)


rules = ["0", "x", "z+", "w2", "gamma"]
# "sum" collapses the relevance over dim 1 into a single column.
mode = "sum"
info = []
for rule in rules:
    print("-" * 80)
    print("Rule:", rule)

    def relevance_fn(net: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Return the hidden relevance of ``x`` for the current ``rule``.

        Reads the notebook globals ``rule`` and ``mode`` at call time. When
        ``mode == "sum"`` the relevance is summed over dim 1 (keeping that
        dimension); otherwise it is returned as produced by
        ``dtd.get_relevance_hidden``.
        """
        rel = dtd.get_relevance_hidden(net, x, rule=rule, gamma=1000)

        if mode == "sum":
            return rel.sum(dim=1, keepdim=True)
        else:
            return rel

    # Forward pass at x while recording module outputs; `x_outs` is indexed
    # by module below (presumably a mapping module -> outputs; confirm in dtd).
    with dtd.record_all_outputs(net) as x_outs:
        logit_x = net(x)

    # Relevance at x and its gradient w.r.t. x (column 0 only).
    rel_hidden = relevance_fn(net, x)
    (grad_rel_hidden,) = torch.autograd.grad(
        [rel_hidden[:, 0]],
        [x],
    )
    # Root point in hidden space, computed from the recorded layer-1 output
    # for index 0 of layer 2.
    hidden_root = dtd.root_point(
        x_outs[net.layer1][0], net.layer2, 0, rule=rule, gamma=1000
    )
    # Root point in input space for relevance index 0, using 20,000 samples.
    x_root = dtd.find_input_root_point(
        net, x, 0, relevance_fn, n_samples=20_000, plot=True
    )

    # Same recorded forward pass, now at the input root point.
    with dtd.record_all_outputs(net) as x_root_outs:
        logit_x_root = net(x_root)

    print(x_root.shape)
    # Relevance and hidden root recomputed at x_root; a leading dimension is
    # added here, unlike in the `net(x_root)` call above.
    (
        rel_hidden_for_x_root,
        hidden_root_for_x_root,
    ) = dtd.get_relevance_hidden_and_root(
        net, x_root.unsqueeze(0), rule=rule, gamma=1000
    )
    # Mirror the reduction that relevance_fn applies in "sum" mode.
    if mode == "sum":
        rel_hidden_for_x_root = rel_hidden_for_x_root.sum(dim=1, keepdim=True)

    # NOTE(review): this needs x_root to be part of the autograd graph
    # (requires_grad=True) -- confirm that find_input_root_point returns such
    # a tensor.
    (grad_rel_hidden_for_x_root,) = torch.autograd.grad(
        [rel_hidden_for_x_root[:, 0]],
        [x_root],
    )

    # One record per rule; tensors are stored as nested Python lists.
    info.append(
        dict(
            rule=rule,
            x=x.tolist(),
            hidden=x_outs[net.layer1][0].tolist(),
            x_root=x_root.tolist(),
            hidden_for_x_root=x_root_outs[net.layer1][0].tolist(),
            rel_hidden=rel_hidden.tolist(),
            hidden_root=hidden_root.tolist(),
            logit_x=logit_x.tolist(),
            logit_x_root=logit_x_root.tolist(),
            hidden_root_for_x_root=hidden_root_for_x_root.tolist(),
            rel_hidden_for_x_root=rel_hidden_for_x_root.tolist(),
            grad_rel_hidden=grad_rel_hidden.tolist(),
            grad_rel_hidden_for_x_root=grad_rel_hidden_for_x_root.tolist(),
        )
    )


# One row per rule, one column per recorded quantity.
df_roots = pd.DataFrame(info)


In [None]:
# Display a subset of the recorded columns, in this order, so that values at
# x and at the input root point sit next to each other.
df_roots.loc[
    :,
    [
        "rule",
        "x_root",
        "hidden",
        "hidden_for_x_root",
        "rel_hidden",
        "rel_hidden_for_x_root",
        "logit_x",
        "logit_x_root",
        "grad_rel_hidden",
        "grad_rel_hidden_for_x_root",
    ],
]

In [None]:
# Hidden-layer root point under the gamma rule for increasingly large gamma,
# printed together with the layer-2 output evaluated at that root.
hidden_at_x = x_outs[net.layer1][0]
for gamma in (10, 1000, 1_000_000):
    root_gamma = dtd.root_point(
        hidden_at_x,
        net.layer2,
        j=0,
        rule="gamma",
        gamma=gamma,
    )
    layer2_at_root = net.layer2(root_gamma)
    print(gamma, root_gamma, layer2_at_root)

# rel_gamma = get_relevance_hidden(net, x, rule="gamma", gamma=10)

In [None]:
# Search an input-space root point for every hidden unit under the gamma rule
# and print the hidden activation of that unit plus the logits at the root.
roots = []
for hidden_idx in range(net.hidden_size):
    # The helpers are referenced through the `dtd` module: the bare names
    # `find_root_point` / `get_relevance_hidden` are not defined or imported
    # anywhere in this notebook and raised NameError.
    # NOTE(review): `dtd.find_input_root_point` is the search helper used
    # earlier in this notebook with the same (net, x, index, relevance_fn)
    # argument order -- confirm it is the intended replacement.
    x_root_search = dtd.find_input_root_point(
        net,
        x,
        hidden_idx,
        lambda m, t: dtd.get_relevance_hidden(m, t, rule="gamma", gamma=1000),
    )
    with dtd.record_all_outputs(net) as root_outs:
        logits = net(x_root_search.unsqueeze(0))
        print(root_outs[net.layer1][0][:, hidden_idx], logits[0])

    roots.append(x_root_search)

In [None]:
# Elementwise average of the per-hidden-unit root points.
root_mean = torch.stack(roots).mean(dim=0)

# Compare the layer-1 output and logits at the averaged root with the
# layer-1 output recorded earlier at x.
with dtd.record_all_outputs(net) as root_outs:
    logits = net(root_mean.unsqueeze(0))
    hidden_at_root_mean = root_outs[net.layer1][0]
    print(hidden_at_root_mean, logits[0])
    print(x_outs[net.layer1][0])


In [None]:
# Seed the RNG before construction so the model's random initialisation is
# reproducible.
torch.manual_seed(2)

# Replaces the earlier `net`: 2 inputs and 2 outputs, so the input space can
# be visualised on a 2-D grid and both logits plotted.
net = dtd.NLayerMLP(
    n_layers=3,
    input_size=2,
    hidden_size=30,
    output_size=2,
)


def weight_scale(m: nn.Module) -> nn.Module:
    """Scale positive parameters by 1.2 and make LinearReLU biases non-positive.

    Meant to be used as ``net.apply(weight_scale)``. ``Module.apply`` calls
    the function on every submodule as well as on the root module, so only
    the parameters owned directly by ``m`` are scaled here
    (``recurse=False``). Iterating recursively would scale a nested parameter
    once for each of its ancestor modules, i.e. by a power of 1.2 instead of
    1.2.

    Args:
        m: Module whose own parameters are modified in place.

    Returns:
        The same module ``m``.
    """
    for p in m.parameters(recurse=False):
        positive = p.data > 0
        p.data[positive] = 1.2 * p.data[positive]
    if isinstance(m, dtd.LinearReLU):
        # Overrides any scaling applied to the bias: every bias entry becomes
        # the negative of its absolute value.
        m.linear.bias.data = -m.linear.bias.data.abs()
    return m


# Run weight_scale through `apply`, which (for a torch nn.Module) calls it on
# every submodule and on `net` itself; the returned value is displayed as the
# cell output.
net.apply(weight_scale)


In [None]:
# Number of samples along each grid axis.
n_grid_points = 500

torch.manual_seed(0)
# A single random input with net.input_size features.
point = torch.randn(1, net.input_size)

# Evenly spaced coordinates in [-1, 1], shared by both axes.
grid_line = torch.linspace(-1, 1, n_grid_points)

# Pair of 2-D coordinate tensors covering the square [-1, 1] x [-1, 1].
grid = torch.meshgrid(grid_line, grid_line, indexing="xy")

In [None]:
# Flatten both coordinate grids into 1-D vectors, one entry per grid point.
x_a, x_b = (axis.flatten() for axis in grid)

# One row per grid point: start from copies of `point`, then overwrite the
# first two features with the grid coordinates.
inputs = point.repeat(x_a.numel(), 1)
inputs[:, :2] = torch.stack((x_a, x_b), dim=1)

inputs.shape

In [None]:
# Track gradients w.r.t. the grid inputs and run a single forward pass.
inputs.requires_grad_(True)
logits = net(inputs)

# Backward seed of ones, one entry per row. Assuming the network treats rows
# independently, this gives each row the gradient of its own logit.
unit_seed = torch.ones_like(logits[:, 0])
grads_logit_0 = torch.autograd.grad(
    logits[:, 0],
    inputs,
    grad_outputs=unit_seed,
    retain_graph=True,  # the graph is needed again for the second logit
)[0]
grads_logit_1 = torch.autograd.grad(
    logits[:, 1],
    inputs,
    grad_outputs=unit_seed,
)[0]

# One scatter plot per logit, coloured by the logit value at each grid point.
coords_a = x_a.numpy()
coords_b = x_b.numpy()
for i in range(2):
    logit_values = logits[:, i].detach().numpy()
    plt.scatter(coords_a, coords_b, c=logit_values)
    plt.colorbar()
    plt.title(f"Logit {i}")
    plt.show()


In [None]:
# `seaborn.apionly` is a deprecated entry point that recent seaborn releases
# no longer ship, so import the package itself. Since seaborn 0.8 a plain
# import no longer changes matplotlib's defaults, which is what `apionly`
# was for.
import seaborn as sns
import numpy as np

# 100 RGB rows from the "colorblind" palette, used as a lookup table for
# integer region labels.
colors = np.array(sns.color_palette("colorblind", 100))
print(colors)

In [None]:
# For each logit, label grid points by their (approximately equal) input
# gradient and colour the grid by that label.
for grad in (grads_logit_0, grads_logit_1):
    grad_colors = dtd.almost_unique(grad, atol=1e-4)
    print(grad_colors.unique().shape)

    # Wrap the labels around the palette when there are more labels than
    # palette entries.
    palette_rows = grad_colors.numpy() % len(colors)
    plt.scatter(
        x_a.numpy(),
        x_b.numpy(),
        c=colors[palette_rows],
        marker=".",
        s=1,
    )
    plt.show()