# Check how the network output changes for the input roots

1. Compute the roots with the train-free DTD root finder.
2. For each root, compute the network output.
3. Is the network output zero at the root? No!
4. Is the gradient at the roots the same as the gradient at the input? No!

In [None]:
%env CUDA_VISIBLE_DEVICES=""


from typing import Union, Callable, cast

import dataclasses
import torch 
import numpy as np
import tqdm.auto
from torch import nn
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display

from lrp_relations import dtd, local_linear, figures
from lrp_relations.utils import to_np


In [None]:
@dataclasses.dataclass
class NotebookArgs:
    """Settings shared by the cells of this notebook."""

    # Root finder to build; the notebook handles "interpolation" and
    # "linear_dtd" and raises ValueError for anything else.
    root_finder: str = "linear_dtd"
    # Slice selecting the explained network output. A default_factory is used
    # because Python 3.11 rejects unhashable dataclass defaults and slices only
    # became hashable in Python 3.12.
    explained_output: slice = dataclasses.field(
        default_factory=lambda: slice(0, 1)
    )
    # Not annotated, hence a plain class attribute and not a dataclass field
    # (it cannot be passed to the constructor).
    rule = dtd.rules.z_plus


args = NotebookArgs()

# Seed before building the network so the weights are reproducible.
torch.manual_seed(1)
mlp = dtd.MLP(3, 10, 10, 2)
mlp.init_weights()

n_params = sum(p.numel() for p in mlp.parameters())
print(f"the network has {n_params} parameters")


# Re-seed before sampling the input point.
torch.manual_seed(0)
x = mlp.get_input_with_output_greater(
    0.25,
    args.explained_output,
    non_negative=True,
)

# Network restricted to the explained output.
mlp_output = mlp.slice(output=args.explained_output)

# Sanity check: one sample, one explained output.
assert mlp_output(x).shape == (1, 1)


x[:, args.explained_output].shape


In [None]:
# Build the root finder selected by ``args.root_finder``.
if args.root_finder == "linear_dtd":
    root_finder = dtd.LinearDTDRootFinder(
        mlp,
        args.explained_output.start,
        args.rule,
    )
elif args.root_finder == "interpolation":
    interpolation_args = local_linear.InterpolationArgs(
        batch_size=50,
        n_refinement_steps=10,
        n_batches=1,
        show_progress=True,
        enforce_non_negative=True,
    )
    root_finder = dtd.InterpolationRootFinder(
        mlp,
        use_cache=True,
        use_candidates_cache=True,
        args=interpolation_args,
    )
else:
    raise ValueError(f"unknown root_finder: {args.root_finder}")


In [None]:
# For every rule, sample `n_points` inputs, compute their roots and record the
# network output and input gradient at the input and at each root.
data = []
torch.manual_seed(0)
n_errors = 0  # total number of failed samples over all rules
n_points = 1000

for rule in [
    dtd.rules.zero,
    dtd.rules.z_plus,
    # dtd.GammaRule(0.0),
    # dtd.GammaRule(0.25),
    # dtd.GammaRule(0.5),
    dtd.GammaRule(1.0),
    # dtd.GammaRule(100),
    dtd.rules.w2,
]:
    # Use the rule of the current iteration; building the root finder from
    # `args.rule` would make every rule produce the same roots.
    root_finder = dtd.LinearDTDRootFinder(
        mlp,
        args.explained_output.start,
        rule,
    )
    rel_fn_builder = dtd.TrainFreeFn.get_fn_builder(
        mlp,
        root_finder=root_finder,
        check_consistent=False,
    )

    rel_fns = dtd.get_decompose_relevance_fns(
        mlp, args.explained_output, rel_fn_builder
    )

    # Failures are counted per rule so that the displayed percentage relates
    # to the attempts of this rule only.
    n_rule_errors = 0

    # The context manager closes the progress bar once the rule is done.
    with tqdm.auto.tqdm(total=n_points, desc=rule.name) as pbar:
        for _ in range(n_points):
            try:
                x = mlp.get_input_with_output_greater(
                    0.1,
                    args.explained_output,
                    non_negative=True,
                    seed=int(torch.randint(0, 2**32, (1,)).item()),
                )

                rel_result = cast(dtd.TrainFreeRel, rel_fns[-1](x))

                outputs = torch.cat(
                    [mlp_output(r.root) for r in rel_result.roots]
                )

                grads = torch.cat(
                    [
                        mlp_output.compute_input_grad(r.root)
                        for r in rel_result.roots
                    ]
                )
                data.append(
                    dict(
                        rule=rule.name,
                        gamma=getattr(rule, "gamma", ""),
                        x=to_np(x),
                        roots=to_np(
                            torch.cat([r.root for r in rel_result.roots])
                        ),
                        output_x=mlp_output(x).item(),
                        grad_input=to_np(mlp_output.compute_input_grad(x)),
                        output_roots=to_np(outputs),
                        grad_roots=to_np(grads),
                    )
                )
            except AssertionError:
                # A failed sample still counts as an attempt; it is skipped.
                n_rule_errors += 1

            pbar.update(1)
            pbar.set_postfix(error_percentage=n_rule_errors / pbar.n)
            pbar.refresh()

    n_errors += n_rule_errors

df = pd.DataFrame(data)


In [None]:
# Rule names present in the collected data.
df["rule"].unique()


In [None]:
# Difference between the output at the input and the outputs at its roots,
# stacked into one array with a row per sample.
diffs = np.stack(df["output_x"] - df["output_roots"])

bins = 20
hist_kwargs = dict(bins=bins, density=True)
# Histogram of the differences, overlaid with the outputs at the inputs.
plt.hist(diffs.ravel(), **hist_kwargs)
plt.hist(df["output_x"], alpha=0.5, **hist_kwargs)
plt.show()


In [None]:
# Compare the magnitude of the difference against the tolerance; without the
# absolute value every negative difference would count as "almost zero".
percentage_almost_zero = (np.abs(diffs) < 1e-6).mean()
print(
    "Difference between output and roots is almost "
    f"zero for: {percentage_almost_zero:.2%}"
)


In [None]:
# Fraction of almost-zero differences along axis 0 of `diffs`; the absolute
# value keeps negative differences from being counted as "almost zero".
(np.abs(diffs) < 1e-6).mean(axis=0)


In [None]:
# Histogram of the mean absolute gradient difference (roots vs. input) per row.
mean_abs_grad_diff = df.apply(
    lambda row: np.mean(np.abs(row.grad_roots - row.grad_input)),
    axis=1,
)
mean_abs_grad_diff.plot.hist()


In [None]:
# Per (rule, gamma) group: report how often the gradients at the roots differ
# from the gradient at the input, how often the output at the roots equals the
# output at the input, and collect both numbers in a LaTeX table.
atol = 1e-5

keys = df.groupby(['rule', 'gamma'], dropna=False).first().index
table_data = []
for key in keys:

    rule, gamma = key

    print(rule, gamma)
    df_sub = df[np.logical_and(df.rule == rule, df.gamma == gamma)]

    # True for a row when any gradient entry of any root deviates from the
    # input gradient by more than `atol`.
    grad_diff = df_sub.apply(
        lambda r: (np.abs(r.grad_roots - r.grad_input) > atol).any(), axis=1
    )

    perc_grad_diff = np.stack(grad_diff).mean()  # type: ignore
    # Complement: rows whose root gradients all match the input gradient.
    perc_grad_same = 1 - perc_grad_diff
    print(key)
    print(f"{perc_grad_diff:.2%} of the roots have a gradient difference (>{atol:.1e})")

    print(f"{perc_grad_same:.2%} of the roots have a gradient difference (<{atol:.1e})")


    diffs = np.stack(df_sub.output_roots - df_sub.output_x)

    plt.hist(diffs.flatten(), bins=33, density=True)
    plt.title(f"{rule} {gamma}")
    plt.show()

    perc_zero_diff = (np.abs(diffs) < 1e-6).mean()
    print(f"{rule} {gamma}: {perc_zero_diff:.2%} of the roots have a difference of zero")



    if rule == 'gamma':
        rule_obj = dtd.GammaRule(gamma)
    else:
        rule_obj = dtd.Rule(rule)
    rule_latex = dtd.get_latex_rule_name(rule_obj)
    table_data.append(
        {
            'Rule': rule_latex,
            # Matching gradients mean the root stays in the local linear
            # region of the input, so the matching fraction is reported here
            # (not the fraction with differing gradients).
            'In Local Linear Region': f'{perc_grad_same:.2%}',
            'Zero Difference in Output': f'{perc_zero_diff:.2%}',
        }
    )

# Raw string: '\%' in a normal string literal is an invalid escape sequence.
print(pd.DataFrame(table_data).set_index('Rule').T.to_latex(escape=False).replace('%', r'\%'))

In [None]:
# Only the last expression of a cell is displayed, so the bare `grad_diff`
# below has no visible effect; the cell shows the column index of `df`.
grad_diff  # .groupby(['rule', 'gamma']).apply(print)
df.columns


In [None]:
# Per row: do all root gradients agree with the input gradient up to 1e-8?
grad_diff = df.apply(
    lambda row: np.all(np.abs(row.grad_roots - row.grad_input) <= 1e-8),
    axis=1,
)
# Fraction of rows for which they agree.
np.stack(grad_diff).mean()  # type: ignore

In [None]:
# Largest absolute gradient difference between roots and input, per row.
diff_max = df.apply(
    lambda row: np.max(np.abs(row.grad_roots - row.grad_input)),
    axis=1,
)

# Average of that maximum over the rows where it exceeds 1e-8.
diff_max[diff_max > 1e-8].mean()

In [None]:
# grad_diff[(grad_diff < 1e-8)].mean()

In [None]:
# Mean absolute gradient difference between roots and input, per row.
diff_mean = df.apply(
    lambda row: np.mean(np.abs(row.grad_roots - row.grad_input)),
    axis=1,
)
# Averaged over the rows whose maximal difference (`diff_max`) exceeds 1e-8.
diff_mean[diff_max > 1e-8].mean()

In [None]:
# Evaluate the full network once per row. Each row's own entry is used;
# indexing with `iloc[0]` would store the value of the first row everywhere.
# NOTE(review): the column is called 'root_outputs' but the network is fed
# `df.x`, not `df.roots` - confirm which input is intended.
roots_output = [to_np(mlp(torch.from_numpy(x_np))) for x_np in df.x]

df['root_outputs'] = roots_output

In [None]:
# One histogram per (rule, gamma) group of the output difference between the
# roots and the input, plus the fraction of differences that are almost zero.
keys = df.groupby(['rule', 'gamma'], dropna=False).first().index

for key in keys:
    rule, gamma = key
    print(rule, gamma)
    selected = (df.rule == rule) & (df.gamma == gamma)
    df_sub = df[selected]

    diffs = np.stack(df_sub.output_roots - df_sub.output_x)

    plt.hist(diffs.ravel(), bins=33, density=True)
    plt.title(f"{rule} {gamma}")
    plt.show()

    is_zero = np.abs(diffs) < 1e-6
    perc_zero_diff = is_zero.mean()
    print(f"{rule} {gamma}: {perc_zero_diff:.2%} of the roots have a difference of zero")


In [None]:

# Ratio of the network output at the scaled input t * x to the output at x,
# for t = 0, 1, ..., 10.
c_factors = pd.DataFrame(
    [
        {"t": t, "c": to_np(mlp(t * x) / mlp(x))}
        for t in np.linspace(0, 10, 11)
    ]
)

In [None]:
# Plot, for 20 sampled inputs, the ratio mlp(t * x) / mlp(x) of the first
# network output against the scaling factor t in [0, 1], together with the
# identity line as reference.
with figures.latexify():
    plt.figure(figsize=figures.get_figure_size(0.3))
    scales = np.linspace(0, 1, 300)
    for i in range(20):
        x = mlp.get_input_with_output_greater(
            0.1,
            args.explained_output,
            non_negative=True,
            seed=i,
        )

        c_factors = pd.DataFrame(
            [{"t": t, "c": to_np(mlp(t * x) / mlp(x))} for t in scales]
        )
        mlp(x)

        first_output_ratio = np.concatenate(c_factors.c)[:, 0]
        plt.plot(
            c_factors.t,
            first_output_ratio,
            c='black',
            linewidth=0.2,
        )
        np.stack(c_factors.c).shape

    # Identity line, drawn from the scaling factors of the last sample.
    plt.plot(
        c_factors.t,
        c_factors.t,
        linestyle='--',
    )
# plt.vlines(
#     1,
#     0,
#     plt.ylim()[1],
#     linestyle='--',
#     color='gray',
#     linewidth=0.5,
# )
plt.xlabel("Scaling: c")
plt.ylabel("Network Output")
plt.gcf().set_dpi(200)

In [None]:
# Forward pass up to `mlp.layer1`, then print one relevance column (index 6)
# for several rescalings of the result.
a1 = mlp(x, last=mlp.layer1)

for c in torch.linspace(0.05, 2, 10):
    scaled_a1 = c * a1
    print(scaled_a1.shape)
    rel = rel_fns[-2](scaled_a1)
    rel.relevance
    print(rel.relevance[:, 6])

In [None]:
# All parameters of `mlp` whose name contains 'bias'.
[
    (name, param)
    for name, param in mlp.named_parameters()
    if 'bias' in name
]

In [None]:
# Second network with the same layer sizes; `init_weights` is not called here.
mlp_pos_bias = dtd.MLP(3, 10, 10, 2)


# Sample an input whose explained output exceeds 0.1.
x = mlp_pos_bias.get_input_with_output_greater(
    0.1,
    args.explained_output,
    n_tries=10000,
)
# Show all parameters of the network whose name contains 'bias'.
[
    (name, param)
    for name, param in mlp_pos_bias.named_parameters()
    if 'bias' in name
]

In [None]:

# Relevance functions for `mlp_pos_bias`, built the same way as for `mlp`:
# root finder -> relevance-function builder -> decomposed relevance functions.
root_finder_bias = dtd.LinearDTDRootFinder(
    mlp_pos_bias, args.explained_output.start, args.rule
)
rel_fn_builder_bias = dtd.TrainFreeFn.get_fn_builder(
    mlp_pos_bias, root_finder=root_finder_bias, check_consistent=False
)

rel_fns_bias = dtd.get_decompose_relevance_fns(
    mlp_pos_bias,
    args.explained_output,
    rel_fn_builder_bias,
)


In [None]:
# Display the current input together with the output of `mlp_pos_bias` for it.
x, mlp_pos_bias(x)

In [None]:
# Forward pass of `mlp_pos_bias` up to its `layer1`, then print one relevance
# column (index 0) for several rescalings of the result, negative ones included.
print(x.shape)
a1 = mlp_pos_bias(x, last=mlp_pos_bias.layer1)

for c in torch.linspace(-20, 20, 10):
    scaled_a1 = c * a1
    print(scaled_a1.shape)
    rel = rel_fns_bias[-2](scaled_a1)
    rel.relevance
    print(rel.relevance[:, 0])