# Execute DTD-full-backward

1. Get samples from the local linear segment (store them as $X_L$).
2. For the last layer, find which sample of $X_L$ would be a good root.
3. Recursively derive and find the roots for the remaining layers.

In [None]:
%env CUDA_VISIBLE_DEVICES=""


from typing import Union, Callable

import dataclasses
import torch 
import numpy as np
import tqdm.auto
from torch import nn
import matplotlib.pyplot as plt
import pandas as pd

from lrp_relations import dtd, local_linear


from lrp_relations.utils import to_np


In [None]:
@dataclasses.dataclass
class NotebookArgs:
    """Configuration switches for this notebook."""

    # Selects the root finder built further down in the notebook. The values
    # handled there are 'interpolation' and 'metropolis_hasting'; anything
    # else raises a ValueError at that point.
    root_finder: str = 'interpolation'

args = NotebookArgs()

# Seed torch's RNG before the model is created.
torch.manual_seed(1)
# NOTE(review): the four positional arguments are presumably the layer count
# and the layer widths -- confirm against dtd.MLP's signature.
mlp = dtd.MLP(3, 10, 10, 2)

mlp.init_weights()

# Total number of scalar parameters over all parameter tensors of the model.
print(f"the network has {sum(p.numel() for p in mlp.parameters())} parameters")


In [None]:
# NOTE(review): `rule` is not referenced again in the visible part of this
# notebook.
rule = "z+"
explained_output_neuron = 0
# A one-element slice: indexing with it keeps the output dimension instead of
# dropping it.
explained_output: dtd.NEURON = slice(
    explained_output_neuron, explained_output_neuron + 1
)

# NOTE(review): judging by the name, this yields a non-negative input whose
# explained output is greater than 0.5 -- confirm against dtd.MLP.
x = mlp.get_input_with_output_greater(0.5, explained_output, non_negative=True)

# NOTE(review): presumably the network restricted to the explained output --
# confirm against dtd.MLP.slice.
mlp_output = mlp.slice(output=explained_output)

# Uncached interpolation root finder operating on the sliced network.
root_finder = dtd.InterpolationRootFinder(
    mlp_output,
    use_cache=False,
    args=local_linear.InterpolationArgs(
        batch_size=50,
        show_progress=True,
        enforce_non_negative=True,
    ),
)

network_output_fn = dtd.NetworkOutputRelevanceFn(
    mlp_output, mlp.first_layer, explained_output
)

# Root points for the first layer at input x; the first of them is inspected
# in the following cell.
roots = root_finder.get_root_points_for_layer(
    mlp.first_layer, x, relevance_fn=network_output_fn
)


In [None]:
def get_grad(x: torch.Tensor) -> torch.Tensor:
    """Return the gradient of the explained output with respect to ``x``.

    The gradient is computed on a detached copy of ``x``, so the caller's
    tensor is not modified (its ``requires_grad`` flag stays as it was).

    Args:
        x: Batch of inputs fed to the module-level ``mlp``.

    Returns:
        A tensor with the same shape as ``x`` holding the gradient of
        ``mlp(x)[:, explained_output]`` (summed over all selected entries)
        with respect to each input entry.
    """
    # Differentiate w.r.t. a private leaf tensor instead of flipping
    # requires_grad on the tensor owned by the caller.
    x = x.detach().clone().requires_grad_(True)
    out = mlp(x)
    # autograd.grad needs a scalar target; summing allows batches of any
    # size. Assuming mlp treats batch rows independently, row i of the result
    # is the gradient for sample i.
    (grad_x,) = torch.autograd.grad(out[:, explained_output].sum(), x)
    return grad_x


# Sanity checks on the first root point found above.
root = roots[0]

root_grad = get_grad(root.root)[0]
input_grad = get_grad(root.input)[0]
# root.root - root.input
# The gradient at the root must match the gradient at the input.
# NOTE(review): this is what one expects when both points lie on the same
# linear segment of the network -- confirm that this is the intent.
assert torch.allclose(root_grad, input_grad)

mlp(root.root), mlp(root.input)


mlp_root = mlp(root.root)[:, explained_output]
mlp_input = mlp(root.input)[:, explained_output]

# First-order Taylor expansion of the explained output around the root,
# evaluated at the input.
taylor_approx = mlp_root + root_grad @ (root.input - root.root)[0]

# The expansion has to reproduce the actual output at the input (up to atol).
assert torch.allclose(taylor_approx, mlp_input, atol=1e-5)


In [None]:
# Build the root finder selected in NotebookArgs. Unlike the earlier cell,
# both variants receive the full `mlp`, not the sliced network.
if args.root_finder == 'metropolis_hasting':
    root_finder = dtd.MetropolisHastingRootFinder(
        mlp,
        args=local_linear.MetropolisHastingArgs(
            n_steps=50,
            n_warmup=0,
            # n_warmup=1000,
            enforce_positive=True,
            show_progress=True,
        ),
    )
elif args.root_finder == 'interpolation':
    # Both caches are enabled here; they are printed in later cells.
    root_finder = dtd.InterpolationRootFinder(
        mlp,
        use_cache=True,
        use_candidates_cache=True,
        args=local_linear.InterpolationArgs(
            batch_size=50,
            n_refinement_steps=10,
            n_batches=1,
            show_progress=True,
            enforce_non_negative=True,
        ),
    )
else:
    raise ValueError(f"unknown root_finder: {args.root_finder}")

# Builder handed to dtd.get_decompose_relevance_fns in the next cell.
rel_fn_builder = dtd.FullBackwardFn.get_fn_builder(
    mlp,
    root_finder=root_finder,
    stabilize_grad=None,
)


In [None]:
%load_ext line_profiler

rel_fns = dtd.get_decompose_relevance_fns(mlp, explained_output, rel_fn_builder)

# Network output for x at the explained neuron, printed for reference.
print(mlp(x)[:, explained_output])

# Filled in by benchmark(); later cells assert that it is not None.
rel_result =  None

def benchmark():
    """Run the last relevance function on ``x`` and store the result.

    The result is written to the module-level ``rel_result`` rather than
    returned, so the call can be wrapped by the (commented-out) ``%lprun``
    invocation below without losing the result.
    """
    global rel_result
    rel_result = rel_fns[-1](x)


benchmark()

# %lprun -f dtd.LocalSegmentRoots.get_root_points_for_layer \
#     -f dtd.LocalSegmentRoots.get_cache_key \
#     -f local_linear.sample \
#         benchmark()

In [None]:
# List every entry of the root finder's cache together with the number of
# points stored under it.
for idx, points in root_finder.cache.items():
    print(f"{idx} {len(points)}")

In [None]:
# Same listing for the candidates cache, followed by a peek at the values:
# the first ten entries and the minimum / maximum taken along dim 0.
# NOTE(review): `use_candidates_cache=True` is only passed in the
# 'interpolation' branch -- confirm the attribute exists for other finders.
for idx, points in root_finder.candidates_cache.items():
    print(f"{idx} {len(points)}")

    print(points[:10])
    print(points.min(0).values.tolist())
    print(points.max(0).values.tolist())

In [None]:

# benchmark() must have run before this cell.
assert rel_result is not None
# NOTE(review): `rels` is not used again in the visible part of this notebook.
rels = rel_result.collect_relevances()

# Rows appended by collect_info below; turned into a DataFrame afterwards.
data = []


def collect_info(rel: dtd.Relevance, callgraph: list[str]):
    """Append one row per root point of ``rel`` to the module-level ``data``.

    Walks the relevance tree depth-first. Only ``dtd.FullBackwardRel`` nodes
    contribute rows; any other relevance object is ignored after its layer
    name has been resolved.

    Args:
        rel: Relevance node to record.
        callgraph: Names of the layers / roots visited on the way to ``rel``.
            A new list is built for every level; the argument is not mutated.
    """
    input_layer = rel.computed_with_fn.get_input_layer()
    layer_name = (
        mlp.get_layer_name(input_layer)
        if isinstance(input_layer, dtd.LinearReLU)
        else "output"
    )
    layer_path = callgraph + [layer_name]

    if not isinstance(rel, dtd.FullBackwardRel):
        return

    def explained_logit(point, first_layer):
        # Output of the explained neuron when the network is entered at
        # `first_layer` with `point` as activation.
        return mlp(point, first=first_layer)[:, explained_output]

    per_root = zip(
        rel.roots,
        rel.relevance_upper_layers,
        rel.unresolved_relevance,
        rel.roots_relevance,
    )
    for root_point, upper_rel, unresolved, decomposed in per_root:
        root_path = layer_path + [f"root_{root_point.explained_neuron}"]

        assert root_point.relevance is not None

        root_logit = explained_logit(root_point.root, root_point.layer)
        input_logit = explained_logit(root_point.input, root_point.layer)

        record = {
            "layer": layer_name,
            "unresolved_relevance": unresolved.detach().numpy(),
            "relevance": decomposed.detach().numpy(),
            "callgraph": root_path,
            "root": root_point.root.detach().numpy(),
            "input": root_point.input.detach().numpy(),
            "root_logit": root_logit.detach().item(),
            "input_logit": input_logit.detach().item(),
            "explained_neuron": root_point.explained_neuron,
            "root_relevance": root_point.relevance.detach().numpy(),
        }
        data.append(record)

        # Descend only into nodes that themselves carry root points.
        if isinstance(upper_rel, dtd.FullBackwardRel):
            collect_info(upper_rel, root_path)


# Fill `data` starting from the top-level relevance with an empty call path.
collect_info(rel_result, [])

# One DataFrame row per recorded root point.
df = pd.DataFrame(data)


In [None]:
# Exploratory expressions; in a notebook only the value of the last one is
# displayed.
# NOTE(review): the `.layer1` / `.layer3` attribute lookups assume rows with
# these layer names exist -- confirm against mlp.get_layer_name.
df.groupby("layer").unresolved_relevance.mean()
df.groupby("layer").relevance.apply(lambda x: np.stack(x).sum(axis=-1)).layer3

df.groupby("layer").relevance.apply(lambda x: np.stack(x)).layer1.mean(axis=0)
# Absolute difference between input and root, grouped by layer.
df.groupby("layer").apply(lambda x: x.input - x.root).abs()

# df.groupby('layer').unresolved_relevance.apply(lambda x: np.stack(x)).layer1.mean(axis=0)

# df.groupby('layer').unresolved_relevance

# Per layer: the stacked relevance arrays summed over their last axis.
df.groupby("layer").relevance.apply(lambda x: np.stack(x).sum(axis=-1))


In [None]:
# Reduce each per-row relevance array to a single total, stored in a new
# column (target column -> source column).
summed_columns = {
    "sum_root_relevance": "root_relevance",
    "sum_relevance": "relevance",
    "sum_unresolved_relevance": "unresolved_relevance",
}
for target_column, source_column in summed_columns.items():
    df[target_column] = df[source_column].apply(lambda arr: arr.sum())

keys = [
    "layer",
    "callgraph",
    "sum_root_relevance",
    "sum_relevance",
    "sum_unresolved_relevance",
]
# Show the layer2 rows ordered by their unresolved relevance total.
is_layer2 = df.layer == "layer2"
df[keys][is_layer2].sort_values("sum_unresolved_relevance")  # type: ignore


In [None]:
# For the layer3 rows, report whether every recorded input / root array is
# entirely non-negative.
layer3_rows = df[df.layer == "layer3"]
for name in ("input", "root"):
    data_points = layer3_rows[name]
    is_non_negative = data_points.apply(lambda values: (values >= 0).all())
    print(name, is_non_negative.all())

In [None]:
# Mean absolute difference between the logit at the input and the logit at
# the root, computed for each layer group.
df.groupby("layer").apply(lambda x: x.input_logit - x.root_logit).apply(
    lambda x: np.abs(x).mean()
)


In [None]:
# (df.root - df.input)

df.input_logit
# Histogram (10 bins) of input logit minus root logit over all recorded rows.
(df.input_logit - df.root_logit).plot.hist(bins=10)  



In [None]:
# Explained-output logits at the root points recorded for layer1.
df[df.layer == 'layer1'].root_logit

In [None]:
def _layer_totals(column):
    # Per layer: stack the arrays stored in `column` and sum over the last
    # axis, giving one total per recorded root point.
    grouped = df.groupby("layer")[column]
    return grouped.apply(lambda rows: np.stack(rows).sum(axis=-1))


unresolved_rel = _layer_totals("unresolved_relevance")
decomposed_rel = _layer_totals("relevance")
root_rel = _layer_totals("root_relevance")
# unresolved_rel + decomposed_rel, root_rel

decomposed_rel.layer3





In [None]:
# Unresolved relevance totals of the layer3 group (computed in the cell above).
unresolved_rel.layer3

In [None]:
import networkx as nx

# Undirected graph that visit_rel fills with layer and root nodes.
cg = nx.Graph()


def visit_rel(rel: dtd.Relevance, prefix: str):
    """Add ``rel`` and everything below it to the module-level graph ``cg``.

    Node names encode the path from the start node: each level appends a
    layer tag (``L<index>`` or ``o``) and, for every root point, a
    ``_r<neuron>@`` suffix.

    Args:
        rel: Relevance node to add.
        prefix: Name of the parent node; the new layer node is linked to it.
    """
    input_layer = rel.computed_with_fn.get_input_layer()
    layer_tag = (
        "L" + str(mlp.get_layer_index(input_layer))
        if isinstance(input_layer, dtd.LinearReLU)
        else "o"
    )

    layer_node = prefix + layer_tag
    cg.add_node(layer_node)
    cg.add_edge(prefix, layer_node)

    # Only full-backward relevances carry root points to descend into.
    if not isinstance(rel, dtd.FullBackwardRel):
        return

    for root_point, upper_rel in zip(rel.roots, rel.relevance_upper_layers):
        root_node = layer_node + f"_r{root_point.explained_neuron}@"

        cg.add_node(root_node)
        cg.add_edge(layer_node, root_node)
        visit_rel(upper_rel, root_node)


assert rel_result is not None
# "s" is only a placeholder parent for the top-level call; it is dropped
# again once the graph has been built.
visit_rel(rel_result, "s")
cg.remove_node("s")


In [None]:
plt.figure(figsize=(10, 10))

# pos = nx.spring_layout(cg, scale=20)

# Node positions are computed by Graphviz' "neato" program through networkx'
# agraph interface.
pos = nx.nx_agraph.graphviz_layout(cg, prog="neato")
nx.draw(
    cg, pos, with_labels=True, node_size=100, font_size=3, node_color="white"
)


# NOTE(review): hard-coded output location.
plt.savefig("/tmp/callgraph.svg", dpi=300)
