In [None]:
%env CUDA_VISIBLE_DEVICES=""

import torch 
from torch import nn
from lrp_relations import dtd

In [None]:
torch.manual_seed(0)  # fix the random initial weights of the MLP below

# MLP with 4 layers: 5 input features, hidden width 10, 2 outputs.
net = dtd.NLayerMLP(input_size=5, hidden_size=10, output_size=2, n_layers=4)

def weight_scale(m: nn.Module, factor: float = 1.2) -> nn.Module:
    """Scale up positive parameters of ``m`` and make LinearReLU biases negative.

    Intended to be used as ``net.apply(weight_scale)``. ``nn.Module.apply``
    already calls this function on every submodule, so only the parameters
    registered directly on ``m`` are touched (``recurse=False``). Recursing
    here as well would compound the scaling once per ancestor module
    (``factor ** depth`` instead of ``factor``).

    Args:
        m: Module visited by ``apply``; its parameters are modified in place.
        factor: Multiplier applied to the strictly positive entries of every
            parameter (weights and biases) registered directly on ``m``.

    Returns:
        ``m`` itself, mutated in place.
    """
    with torch.no_grad():
        for p in m.parameters(recurse=False):
            # To compensate for the negative biases set below, scale the
            # positive entries up; zero and negative entries are unchanged.
            p[p > 0] *= factor
        if isinstance(m, dtd.LinearReLU) and m.linear.bias is not None:
            # Force every bias of the wrapped linear layer to be non-positive.
            m.linear.bias.copy_(-m.linear.bias.abs())
    return m

# Rescale every submodule in place; ``apply`` returns ``net`` itself, which is
# shown as this cell's output.
net.apply(fn=weight_scale)

In [None]:
explained_class = 0

torch.manual_seed(1)
x = torch.rand(1, 5)  # a single input sample with 5 features drawn from [0, 1)

# Only explain the class if the network output for it is positive.
assert net(x)[:, explained_class] > 0
precise_dtd = dtd.PreciseDTD(
    net,
    x,
    rule="z+",
    explained_class=explained_class,
    root_max_relevance=0.01,
)
rel_input = precise_dtd.explain()

In [None]:
print("Logit:", precise_dtd.get_output(net.layers[-1]))
for layer, res in precise_dtd.results.items():
    # Input relevance for this layer's root point (as computed by PreciseDTD).
    rel = precise_dtd.compute_relevance_of_input(layer, res.root)
    # Separator line followed by the layer index, one per line.
    print("-" * 80, res.layer_idx, sep="\n")
    print("root", res.root)
    print("rel. vec", rel)
    print("rel. sum", rel.sum())