# Check how the network output changes for the input roots

1. Compute the roots of the train-free DTD.
2. For each root of an input, compute the network output and its input gradient.
3. Is the network output set to zero by the root? No!
4. Do the roots have the same gradient as the input? No!

In [None]:
%env CUDA_VISIBLE_DEVICES=""


from typing import Union, Callable, cast

import dataclasses
import torch 
import numpy as np
import tqdm.auto
from torch import nn
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display

from lrp_relations import dtd, local_linear
from lrp_relations.utils import to_np


In [None]:
@dataclasses.dataclass
class NotebookArgs:
    """Configuration knobs for this notebook.

    Attributes:
        root_finder: Name of the root finder to build; ``"interpolation"``
            and ``"linear_dtd"`` are the names handled by the notebook.
        explained_output: Slice selecting the network output(s) to explain.
        rule: DTD rule handed to ``dtd.LinearDTDRootFinder``.
    """

    root_finder: str = "linear_dtd"
    # ``slice`` objects are unhashable before Python 3.12, and since 3.11
    # ``@dataclass`` rejects unhashable defaults ("mutable default ... use
    # default_factory"). A factory works on every Python version.
    explained_output: slice = dataclasses.field(
        default_factory=lambda: slice(0, 1)
    )
    # Annotated so it is a real, constructor-overridable field rather than a
    # class attribute (a plain function stored as a class attribute would be
    # bound as a method on access). The factory also avoids the hashability
    # check for whatever type ``dtd.rules`` members have.
    # NOTE(review): concrete type of ``dtd.rules.z_plus`` is not visible here.
    rule: object = dataclasses.field(default_factory=lambda: dtd.rules.z_plus)


args = NotebookArgs()

# Seeded so the randomly initialised network weights are reproducible.
torch.manual_seed(2)
# NOTE(review): the meaning of the positional sizes is defined by dtd.MLP
# (not visible here) — presumably layer widths; confirm.
mlp = dtd.MLP(4, 10, 10, 2)
mlp.init_weights()

print(f"the network has {sum(p.numel() for p in mlp.parameters())} parameters")


# Presumably samples a non-negative input whose explained output exceeds
# 0.25 (inferred from the helper's name/arguments — confirm in dtd.MLP).
torch.manual_seed(1)
x = mlp.get_input_with_output_greater(
    0.25, args.explained_output, non_negative=True
)

# Sub-network restricted to the explained output(s).
mlp_output = mlp.slice(output=args.explained_output)

assert mlp_output(x).shape == (1, 1)


# Show the explained output for the sampled input. ``args.explained_output``
# indexes the network *outputs*, so it must not be used to slice the input
# ``x`` (that would just select an input feature).
mlp_output(x)


In [None]:
# Build the root finder used by the train-free DTD relevance functions.
# ``args.root_finder`` selects the implementation; unknown names are rejected.
if args.root_finder == "linear_dtd":
    # Roots for the output index at the start of the explained slice, using
    # the configured rule.
    root_finder = dtd.LinearDTDRootFinder(
        mlp, args.explained_output.start, args.rule
    )
elif args.root_finder == "interpolation":
    root_finder = dtd.InterpolationRootFinder(
        mlp,
        use_cache=True,
        use_candidates_cache=True,
        args=local_linear.InterpolationArgs(
            batch_size=50,
            n_refinement_steps=10,
            n_batches=1,
            show_progress=True,
            enforce_non_negative=True,
        ),
    )
else:
    raise ValueError(f"unknown root_finder: {args.root_finder}")

# Relevance functions for decomposing the explained output; the result is a
# sequence (it is indexed with ``rel_fns[-1]`` when sampling). Consistency
# checks are switched off.
rel_fn_builder = dtd.TrainFreeFn.get_fn_builder(
    mlp, root_finder=root_finder, check_consistent=False
)
rel_fns = dtd.get_decompose_relevance_fns(
    mlp, args.explained_output, rel_fn_builder
)


In [None]:
# Sample inputs, compute their train-free DTD roots, and record the explained
# output and its input gradient both at the input and at every root.
# Samples whose computation raises an AssertionError are counted in
# ``n_errors`` and skipped, so ``data`` may hold fewer than ``n_points`` rows.
data = []
torch.manual_seed(0)
n_errors = 0
n_points = 1000
# The context manager closes the progress bar even if a non-assertion error
# escapes the loop.
with tqdm.auto.tqdm(total=n_points) as pbar:
    for _ in range(n_points):
        try:
            x = mlp.get_input_with_output_greater(
                0.1,
                args.explained_output,
                non_negative=True,
                # Per-sample seed drawn from torch's global RNG (seeded above),
                # so the whole run is reproducible.
                seed=int(torch.randint(0, 2**32, (1,)).item()),
            )

            rel_result = cast(dtd.TrainFreeRel, rel_fns[-1](x))
            roots = [r.root for r in rel_result.roots]

            outputs = torch.cat([mlp_output(root) for root in roots])
            grads = torch.cat(
                [mlp_output.compute_input_grad(root) for root in roots]
            )
            data.append(
                dict(
                    x=to_np(x),
                    output_x=mlp_output(x).item(),
                    grad_input=to_np(mlp_output.compute_input_grad(x)),
                    output_roots=to_np(outputs),
                    grad_roots=to_np(grads),
                )
            )
        except AssertionError:
            n_errors += 1

        pbar.update(1)
        # set_postfix refreshes the bar itself (refresh=True by default).
        pbar.set_postfix(error_percentage=f"{n_errors / pbar.n:.1%}")


In [None]:
# One row per successfully processed sample; the array-valued columns hold
# the numpy arrays produced by ``to_np``.
df = pd.DataFrame(data=data)
display(df)


In [None]:
# Shape of the per-sample differences (root outputs minus the input's output)
# stacked into one array; axis 0 indexes samples.
np.stack(list(df.output_roots - df.output_x)).shape


In [None]:
# Per-root difference between the explained output at the input and at the
# root (axis 0 indexes samples). If a root drove the output to zero, its
# difference would equal output(x), so the two histograms would coincide.
diffs = np.stack(df.output_x - df.output_roots)

bins = 20
plt.hist(
    diffs.flatten(),
    bins=bins,
    density=True,
    label="output(x) - output(root)",
)
plt.hist(df.output_x, bins=bins, density=True, alpha=0.5, label="output(x)")
plt.xlabel("explained network output")
plt.ylabel("density")
plt.legend()
plt.show()


In [None]:
# Fraction of root outputs whose difference to the input's output is within
# an absolute tolerance. ``diffs`` is signed, so compare its magnitude;
# otherwise every negative difference would be counted as "almost zero".
percentage_almost_zero = (np.abs(diffs) < 1e-6).mean()
print(
    "Difference between output and roots is almost "
    f"zero for: {percentage_almost_zero:.2%}"
)


In [None]:
# Same almost-zero test as above, averaged over samples (axis 0) so the result
# has one fraction per root position. Use the magnitude: ``diffs`` is signed.
(np.abs(diffs) < 1e-6).mean(axis=0)


In [None]:
# Distribution over samples of the mean absolute difference between the
# gradients at the roots and the gradient at the input.
mean_abs_grad_diff = df.apply(
    lambda row: np.abs(row.grad_roots - row.grad_input).mean(), axis=1
)
mean_abs_grad_diff.plot.hist()


In [None]:
atol = 1e-6
# One flag per sample: does any gradient component at any of its roots differ
# from the gradient at the input by more than ``atol``?
grad_diff = df.apply(
    lambda r: (np.abs(r.grad_roots - r.grad_input) > atol).any(), axis=1
)

# ``.any()`` above collapses all roots of a sample, so this is the fraction
# of *samples* with at least one differing root (not a fraction of roots).
perc_grad_diff = np.asarray(grad_diff, dtype=bool).mean()
print(
    f"{perc_grad_diff:.2%} of the samples have at least one root with a "
    f"gradient difference (>{atol:.1e})"
)