# Execute DTD-full-backward

1. Sample points from the local linear segment and store them as $X_L$.
2. For the last layer, find which sample in $X_L$ is a good root point.
3. Recurse: derive and find the roots of the remaining layers.

In [None]:
%env CUDA_VISIBLE_DEVICES=""


from typing import Union, Callable

import torch 
import numpy as np
import tqdm.auto
from torch import nn
import matplotlib.pyplot as plt
import pandas as pd

from lrp_relations import dtd
from lrp_relations.utils import to_np


In [None]:
# Seed torch's global RNG before building the network so that any random
# initialisation inside dtd.MLP is reproducible from run to run.
torch.manual_seed(seed=0)
mlp = dtd.MLP(5, 10, 10, 2)

def weight_scale(m: nn.Module, scale: float = 1.4) -> nn.Module:
    """Scale up positive parameters of ``m`` and make LinearReLU biases non-positive.

    Intended for ``mlp.apply(weight_scale)``. ``nn.Module.apply`` calls the
    function on every submodule (children first) and then on the module
    itself. Only the parameters owned directly by ``m`` are touched
    (``recurse=False``). With the default ``recurse=True``, a parameter would
    be rescaled once by its owner and again by every enclosing module, which
    compounds the factor.

    Args:
        m: module visited by ``nn.Module.apply``; modified in place.
        scale: factor applied to every strictly positive parameter entry.

    Returns:
        ``m`` itself.
    """
    with torch.no_grad():
        for p in m.parameters(recurse=False):
            # to compensate for negative biases, we scale the weight
            p[p > 0] *= scale
        if isinstance(m, dtd.LinearReLU):
            # Force every bias entry of the wrapped linear layer to be <= 0.
            # ``apply`` visits ``m.linear`` before ``m``, so this runs after
            # the positive bias entries were scaled above.
            m.linear.bias.copy_(-m.linear.bias.abs())
    return m

# Module.apply walks every submodule of mlp (children first, then mlp itself)
# and calls weight_scale on each one; the parameters are changed in place.
mlp.apply(weight_scale)

n_parameters = sum(param.numel() for param in mlp.parameters())
print(f"the network has {n_parameters} parameters")

In [None]:
# Forward a random batch of 11 inputs through the network. The feature width
# is taken from the model (as in the sampling cell below) rather than
# hard-coding 10, so the cell stays in sync with the model definition.
x = torch.randn(11, mlp.input_size)

logits = mlp(x)
logits

In [None]:
# NOTE(review): `rule` is not referenced in the cells shown here; confirm it
# is used further down.
rule = "z+"
explained_output = 0

# Draw random inputs until the explained logit is strictly positive.
# Fail loudly instead of silently continuing with the last sample, whose
# explained logit would then be <= 0.
torch.manual_seed(0)
max_tries = 100
for _ in range(max_tries):
    x = torch.randn(1, mlp.input_size)
    with torch.no_grad():
        # The probe forward pass needs no autograd graph.
        is_positive = bool(mlp(x)[:, explained_output] > 0)
    if is_positive:
        break
else:
    raise RuntimeError(
        f"no input with a positive value for output {explained_output} "
        f"was found in {max_tries} random samples"
    )


root_finder = dtd.LocalSegmentRoots(
    mlp, n_steps=200, n_warmup=10_000, show_progress=True
)

# Relevance function built from the network output `explained_output`,
# passed to the root search for `mlp.first_layer` at the sample `x`.
network_output_fn = dtd.NetworkOutputRelevanceFn(
    mlp, mlp.first_layer, explained_output
)

roots = root_finder.get_root_points_for_layer(
    mlp.first_layer, x, relevance_fn=network_output_fn
)

In [None]:
def get_grad(x: torch.Tensor) -> torch.Tensor:
    """Return the gradient of ``mlp(x)[:, explained_output]`` w.r.t. ``x``.

    Reads the notebook globals ``mlp`` and ``explained_output``. Gradients
    are taken w.r.t. a detached copy of ``x``, so the caller's tensor keeps
    its original ``requires_grad`` flag.

    Args:
        x: input batch with the batch dimension first.

    Returns:
        A tensor with the same shape as ``x`` holding the gradient.
    """
    x = x.detach().requires_grad_(True)
    out = mlp(x)
    # ``autograd.grad`` needs a scalar output. Summing over the batch keeps
    # per-row gradients intact as long as each output row depends only on its
    # own input row (assumed for this MLP — no batch-coupling layers).
    (grad_x,) = torch.autograd.grad(out[:, explained_output].sum(), x)
    return grad_x


root = roots[0]

# Assuming mlp is piecewise linear (stacked LinearReLU layers), a root in the
# same linear region as the input must have exactly the same gradient.
root_grad = get_grad(root.root)[0]
input_grad = get_grad(root.input)[0]
assert torch.allclose(root_grad, input_grad), (
    "root and input have different gradients (not in the same linear region)"
)

# Inside one linear region the first-order Taylor expansion around the root
# is exact: f(input) = f(root) + grad . (input - root).
mlp_root = mlp(root.root)[:, explained_output]
mlp_input = mlp(root.input)[:, explained_output]

taylor_approx = mlp_root + root_grad @ (root.input - root.root)[0]

assert torch.allclose(taylor_approx, mlp_input, atol=1e-5), (
    "first-order Taylor expansion around the root does not reproduce the output"
)


In [None]:
import functools

# Partially apply dtd.DecomposeRelevanceFn to the network, its own
# LocalSegmentRoots root finder and stabilize_grad=None; any remaining
# arguments are supplied later when rel_fn_builder(...) is called.
rel_fn_builder = functools.partial(
    dtd.DecomposeRelevanceFn,
    mlp,
    root_finder=dtd.LocalSegmentRoots(
        mlp,
        n_steps=2_000,
        n_chains=50,
        n_warmup=20_000,
        show_progress=True,
    ),
    stabilize_grad=None,
)

rel_fn_builder

In [None]:
# Relevance functions for the "full" DTD decomposition of `explained_output`,
# all sharing one LocalSegmentRoots root finder.
segment_root_finder = dtd.LocalSegmentRoots(
    mlp,
    n_steps=2_000,
    n_chains=50,
    n_warmup=20_000,
    show_progress=True,
)
rel_fns = dtd.get_decompose_relevance_fns(
    mlp,
    explained_output=explained_output,
    root_finder=segment_root_finder,
    decomposition="full",
)

# NOTE(review): the ordering of `rel_fns` (and so what index -2 selects) is
# defined by dtd.get_decompose_relevance_fns; presumably one entry per
# layer — confirm.
rel_fns[-2](x)

In [None]:
# run decomposition