# Reimplement Deep Taylor Decomposition (DTD)

Goals:

- Compute higher-order relevances.
- Follow the theory exactly.


In [None]:
%env CUDA_VISIBLE_DEVICES=""


from typing import Callable

import torch 
from torch import nn
import matplotlib.pyplot as plt

from lrp_relations import dtd

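Equation 4 in the appendix of the DTD paper: the root point $\tilde x$ for neuron $j$ satisfies

$$
x_i - \tilde x_i = \frac
    {\sum_k x_k w_{kj} + b_j}
    {\sum_l v_l w_{lj}}
    v_i
$$

where $v_i = x_i \, \mathbb{1}_{w_{ij} \ge 0}$ is the $z^+$ search direction.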


In [None]:
# Seed torch's RNG before building the network (presumably so the weight
# initialisation inside TwoLayerMLP is reproducible).
torch.manual_seed(2)
net = dtd.TwoLayerMLP(output_size=1, hidden_size=5, input_size=3)

# Overwrite the second-layer weights with 10 * |W|: every weight of
# net.layer2.linear becomes non-negative and ten times larger in magnitude.
net.layer2.linear.weight.data = net.layer2.linear.weight.data.abs().mul(10)

In [None]:
def get_relevance_hidden(
    net: nn.Module,
    x: torch.Tensor,
) -> torch.Tensor:
    """Return DTD relevances of the hidden units of ``net`` for input ``x``.

    The hidden activations ``h`` are captured with a forward hook on
    ``net.layer1``. A root point ``h_root`` is obtained from
    ``dtd.root_point(h, net.layer2, 0, rule="z+")``. The relevance of each
    hidden unit is the first-order Taylor term::

        d net.layer2.linear(h) / d h  (evaluated at h_root)  *  (h - h_root)

    Args:
        net: Network exposing ``layer1`` and ``layer2`` submodules, where
            ``layer2`` has a ``linear`` attribute.
        x: Input batch passed to ``net``.

    Returns:
        The per-hidden-unit relevances, computed elementwise from the hidden
        activations and their root point (presumably the same shape as the
        output of ``net.layer1`` -- confirm against ``dtd.root_point``).
    """
    captured: dict[str, torch.Tensor] = {}

    def save_hidden(
        module: nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
    ) -> None:
        captured["hidden"] = output

    # Only the output of ``layer1`` is needed, so hook just that module.
    handle = net.layer1.register_forward_hook(save_hidden)
    try:
        net(x)  # forward pass is run only to populate ``captured``
    finally:
        # Always detach the hook, even if the forward pass raises; otherwise
        # it would stay registered on ``net.layer1`` for every later call.
        handle.remove()

    hidden = captured["hidden"]
    # Root point of the hidden activations w.r.t. index 0 of ``layer2``
    # (presumably output unit 0 -- confirm against ``dtd.root_point``).
    hidden_root = dtd.root_point(hidden, net.layer2, 0, rule="z+")

    output_root = net.layer2.linear(hidden_root)
    (rel_grad,) = torch.autograd.grad(
        output_root,
        hidden_root,
        grad_outputs=torch.ones_like(output_root),
        retain_graph=True,
    )
    # Gradient at the root point times the displacement from the root.
    return rel_grad * (hidden - hidden_root)



# x_rel_grad = torch.autograd.grad(
#     rel_hidden,
#     x,
#     grad_outputs=torch.ones_like(rel_hidden),
#     retain_graph=True,
# )

def find_root_point(
    net: nn.Module,
    x: torch.Tensor,
    j: int,
    relevance_fn: Callable[[nn.Module, torch.Tensor], torch.Tensor],
    *,
    n_samples: int = 1000,
    n_steps: int = 100,
    tol: float = 1e-5,
    verbose: bool = True,
) -> torch.Tensor:
    """Numerically search for an input at which relevance column ``j`` vanishes.

    This is a brute-force check for an analytically computed root point:

    1. Coarse search: add ``n_samples`` draws of standard Gaussian noise to
       ``x``. Keep the sample whose ``relevance_fn(net, .)[:, j]`` has the
       smallest magnitude.
    2. Line search: walk a straight line from that sample (t=0) back to ``x``
       (t=1) in ``n_steps`` evenly spaced steps. Return the point closest to
       ``x`` whose relevance magnitude is at most ``tol``.

    Args:
        net: Network, passed through unchanged to ``relevance_fn``.
        x: A single input of shape ``(1, n_features)``.
        j: Column of the relevance output that should become zero.
        relevance_fn: Maps ``(net, batch)`` to relevances indexed as
            ``[batch_index, unit]``.
        n_samples: Number of random perturbations in the coarse search.
        n_steps: Number of points on the line search.
        tol: Absolute tolerance below which a relevance counts as zero.
        verbose: If true, print the number of coarse samples with exactly
            zero relevance and the selected line-search index.

    Returns:
        The found root point, shape ``(n_features,)``.

    Raises:
        ValueError: If ``x`` does not contain exactly one row.
        RuntimeError: If no point on the line search has ``|relevance| <= tol``.
    """
    if x.size(0) != 1:
        raise ValueError(
            f"find_root_point expects a single input row, got shape {tuple(x.shape)}"
        )

    # Coarse search: random Gaussian perturbations around x.
    x_search = x + torch.randn(n_samples, x.size(1))
    rel_search = relevance_fn(net, x_search)[:, j]
    if verbose:
        print((rel_search == 0.0).float().sum())
    start = x_search[rel_search.abs().argmin()]

    # Line search: straight line from the best sample (t=0) back to x (t=1).
    steps = torch.linspace(0, 1, n_steps).view(-1, 1)
    x_line = start + steps * (x[0] - start)
    rel_line = relevance_fn(net, x_line)[:, j]

    # Use |R| as in the coarse search; a signed "<= tol" test would accept
    # points with negative relevance as roots.
    candidates = (rel_line.abs() <= tol).nonzero()
    if candidates.numel() == 0:
        raise RuntimeError(
            f"no point on the line search has |relevance[:, {j}]| <= {tol}"
        )
    # Largest index = the qualifying point closest to x.
    root_idx = candidates.max()
    if verbose:
        print(root_idx)
    return x_line[root_idx]

In [None]:
# Fixed seed. Both the input draw below and the random search inside
# find_root_point consume torch's RNG, so statement order matters here.
torch.manual_seed(3)
# One input row of shape (1, 3): 3 + 0.25 * standard-normal noise, clamped at 0.
x = (0.25 * torch.randn(1, 3, requires_grad=True) + 3).clamp(min=0)

# Hidden-unit relevances for x.
rel_hidden = get_relevance_hidden(net, x)
# Root point of x w.r.t. net.layer1 at index 0
# (presumably hidden unit 0 -- confirm against dtd.root_point).
x_root = dtd.root_point(x, net.layer1, 0)

# Numerical search for an input where hidden unit 0's relevance vanishes.
x_root_search = find_root_point(net, x, 0, get_relevance_hidden)

In [None]:
# Display the numerically searched root point.
x_root_search

In [None]:
# Display the input next to the root point returned by dtd.root_point.
x, x_root