# Notebook to visualize the root points of DTD (deep Taylor decomposition)



In [None]:
%env CUDA_VISIBLE_DEVICES=""


from typing import Callable, Sequence, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import tqdm.auto
from torch import nn

from lrp_relations import dtd
from lrp_relations.utils import to_np


In [None]:
# Fix the global torch RNG so the randomly initialised MLP and every later
# torch.randn draw in this notebook are reproducible.
torch.manual_seed(0)
# NOTE(review): the meaning of dtd.MLP's positional arguments is not visible
# here; the shape asserts further down imply an input width of 10 — confirm
# against dtd.MLP's signature.
mlp = dtd.MLP(5, 10, 10, 2)

def weight_scale(m: nn.Module, factor: float = 1.4) -> nn.Module:
    """Scale up the positive parameters of ``m`` and make LinearReLU biases non-positive.

    Meant to be passed to ``nn.Module.apply``, which calls it once on every
    submodule and once on the root module. Only the module's *own*
    parameters are touched (``recurse=False``): with the default recursive
    ``parameters()`` each weight would be rescaled once per ancestor module,
    compounding the factor (e.g. ``factor**3`` instead of ``factor``).

    Args:
        m: Module to modify in place.
        factor: Multiplier applied to every strictly positive parameter entry.

    Returns:
        ``m`` itself, modified in place (as ``nn.Module.apply`` expects).
    """
    with torch.no_grad():
        for p in m.parameters(recurse=False):
            # to compensate for negative biases, we scale up the positive weights
            p[p > 0] *= factor
        if isinstance(m, dtd.LinearReLU):
            # force all biases of the ReLU layers to be <= 0
            m.linear.bias.copy_(-m.linear.bias.abs())
    return m

# Run weight_scale in place on every submodule of mlp and on mlp itself
# (nn.Module.apply visits the children first, then the module); apply returns
# mlp, which is displayed as the cell output.
mlp.apply(weight_scale)

In [None]:
# Forward a batch of 11 random inputs through the MLP. The input width is
# taken from the model (as in the sampling loop) instead of being hard-coded.
x = torch.randn(11, mlp.input_size)

logits = mlp(x)
logits

In [None]:
# Display the MLP's input width; the sampling loop below draws inputs of
# shape (1, mlp.input_size).
mlp.input_size

In [None]:
def show_summary(
    df: pd.DataFrame,
    keys: Sequence[str] = (
        "last_layer_for_input",
        "layer3_for_input",
        "last_layer_for_root",
        "layer3_for_root",
    ),
) -> None:
    """Print the mean absolute value of each array-valued column in ``keys``.

    The header line is the ``rule`` of the first row, so ``df`` is expected to
    contain rows of a single rule (e.g. one group of ``df.groupby('rule')``).
    All arrays in a column are stacked with ``np.stack``, so they must share
    one shape.

    Args:
        df: Frame with a ``rule`` column and the columns named in ``keys``.
        keys: Names of the array-valued columns to summarise.
    """
    if df.empty:
        # Nothing to summarise; also avoids IndexError on ``df.rule.iloc[0]``.
        return
    print('-' * 80)
    print(df.rule.iloc[0])
    for key in keys:
        print(key, np.abs(np.stack(df[key])).mean())

# For each DTD rule, run RecursiveRoots at mlp.layer1 on 200 random inputs,
# keeping every returned root object (all_roots) and a flat table of the
# last-layer / layer-3 outputs at the input and at the root (data -> df).
all_roots = []
data = []
for rule in ['z+', 'pinv']:
    rr = dtd.RecursiveRoots(mlp, explained_class=0, rule=rule)

    for _ in tqdm.auto.trange(200):
        x = torch.randn(1, mlp.input_size)
        roots = rr.run(x, start_at=mlp.layer1, end_at=mlp.layer1)
        all_roots.extend(roots)
        for root in roots:
            data.append(
                dict(
                    rule=rule,
                    explained_neuron=root.explained_neuron,
                    last_layer_for_input=to_np(root.outputs_of_input[mlp.layers[-1]]),
                    layer3_for_input=to_np(root.outputs_of_input[mlp.layers[3]]),
                    last_layer_for_root=to_np(root.outputs_of_root[mlp.layers[-1]]),
                    layer3_for_root=to_np(root.outputs_of_root[mlp.layers[3]]),
                )
            )

df = pd.DataFrame(data)
# show_summary only prints, so iterate the groups explicitly rather than
# using groupby(...).apply, which would also build and display a result
# object of Nones.
for _, rule_df in df.groupby('rule'):
    show_summary(rule_df)

In [None]:
def _logit0_grad(point: torch.Tensor) -> torch.Tensor:
    """Return the gradient of ``mlp(point)[:, 0]`` w.r.t. ``point``.

    Works on a detached copy so the caller's tensor (e.g. one stored in
    ``all_roots``) is not mutated by ``requires_grad_``.
    """
    point = point.detach().clone().requires_grad_(True)
    logit = mlp(point)[:, 0]
    grad, = torch.autograd.grad(
        logit,
        point,
        grad_outputs=torch.ones_like(logit),
    )
    return grad


# Compare the input-gradient of logit 0 at each original input with the one
# at its root point, and summarise the mean absolute difference per rule.
data = []
for root in all_roots:
    assert root.input.shape == (1, mlp.input_size), root.rule.name
    assert root.root.shape == (1, mlp.input_size), root.rule.name

    grad_input = _logit0_grad(root.input)
    grad_root = _logit0_grad(root.root)

    data.append(dict(
        rule=root.rule.name,
        explained_neuron=root.explained_neuron,
        grad_input=to_np(grad_input),
        grad_root=to_np(grad_root),
        grad_diff=(grad_input - grad_root).abs().mean().item(),
    ))

df = pd.DataFrame(data)
df.groupby('rule').grad_diff.describe()

In [None]:

# Pick the first root produced by the 'pinv' rule and show it next to the
# input it was computed for (IndexError if no such root exists).
pinv_roots = [candidate for candidate in all_roots if candidate.rule.name == 'pinv']
root = pinv_roots[0]
root.input, root.root

In [None]:
# Per-rule descriptive statistics of the input-vs-root gradient difference.
df.groupby('rule')['grad_diff'].describe()

In [None]:

# Mean |output| over all collected roots (both rules) of the last layer and
# layer 3, evaluated at the original inputs and at the roots.
# Computed from `all_roots` directly: `df` has been rebound to the gradient
# table by the gradient-comparison cell and has no layer-output columns.
layer_outputs = {
    "last_layer_for_input": lambda r: r.outputs_of_input[mlp.layers[-1]],
    "layer3_for_input": lambda r: r.outputs_of_input[mlp.layers[3]],
    "last_layer_for_root": lambda r: r.outputs_of_root[mlp.layers[-1]],
    "layer3_for_root": lambda r: r.outputs_of_root[mlp.layers[3]],
}
for key, get_output in layer_outputs.items():
    stacked = np.stack([to_np(get_output(r)) for r in all_roots])
    print(key, np.abs(stacked).mean())


In [None]:

# Show the root-side and input-side activations at mlp.layer4, reached by
# following `upper_layers[0]` twice from the first element of `roots`.
# NOTE(review): relies on `roots` still being the list returned by `rr.run`
# in the last sampling-loop iteration; the next cell rebinds `roots` to a
# tensor, so re-running this cell after it will fail.
# NOTE(review): `upper_layers` / `activations_for_root` /
# `activations_for_input` are not used anywhere else here (other cells use
# `outputs_of_root` / `outputs_of_input`) — confirm they still exist on the
# dtd root objects.
(roots[0].upper_layers[0].upper_layers[0].activations_for_root[mlp.layer4],
roots[0].upper_layers[0].upper_layers[0].activations_for_input[mlp.layer4])

In [None]:
# Root point of mlp.layer1 for each of its output neurons, all computed for
# the first row of whatever `x` is currently bound to (the sampling loop
# rebinds `x` to a single random input).
# NOTE(review): this rebinds `roots`, which previously held the list of root
# objects from the last `rr.run` call.
first_row = x[:1]
per_neuron_roots = []
for neuron in range(mlp.layer1.out_features):
    per_neuron_roots.append(dtd.root_point_linear(first_row, mlp.layer1, j=neuron))
roots = torch.cat(per_neuron_roots)

list(zip(roots, mlp.layer1(roots)))

In [None]:

# Minimum-norm point where mlp.layer1's pre-activation vanishes:
#   W @ x + b = 0  =>  x = -pinv(W) @ b
# (an exact root when W has full row rank, otherwise the least-squares point).
# The previous `pinv(W) @ b` had the wrong sign and gave W @ x + b = 2 * b.
with torch.no_grad():
    weight = mlp.layer1.linear.weight
    bias = mlp.layer1.linear.bias
    layer_root = -torch.linalg.pinv(weight) @ bias
mlp.layer1(layer_root)

In [None]:
# mlp.layer1(roots.mean(0, keepdim=True))