In [None]:
import numpy as np

import os
import dataclasses

In [None]:
@dataclasses.dataclass(frozen=True)
class LinearReLU:
    """Immutable single linear layer followed by a ReLU.

    Calling an instance computes ``max(0, x @ w + b)`` elementwise.
    """

    # Weight array; right-multiplies the input (``x @ w``).
    w: np.ndarray
    # Bias added after the matmul. It is broadcast against ``x @ w``, so a plain
    # scalar such as ``0`` works as well as an array.
    b: np.ndarray

    def __call__(self, x):
        """Apply the affine map to ``x`` and clamp negative values to zero."""
        pre_activation = x @ self.w + self.b
        return np.maximum(0, pre_activation)

In [None]:
# Probe a bias-free LinearReLU at 1000 random weight vectors w ~ U(-1, 1)^2.
# For each w, feed 1000 inputs x ~ U(0, 1)^2 and record the mean ReLU output.
# NOTE(review): despite its name, `all_zero` holds that mean activation, not the
# fraction of all-zero outputs -- confirm which quantity was intended.
np.random.seed(0)  # make the weight/input draws (and hence the figure) reproducible

W = np.random.uniform(-1, 1, size=(1000, 2))
all_zero = []
for w in W:
    relu = LinearReLU(w=w, b=0)

    x = np.random.uniform(0, 1, size=(1000, 2))
    y = relu(x)
    all_zero.append(y.mean())

all_zero = np.array(all_zero)

In [None]:
import matplotlib.pyplot as plt
from lrp_relations import figures

# Scatter the sampled weight vectors W, coloured by the per-weight value in `all_zero`.
with figures.latexify():
    plt.figure(figsize=figures.get_figure_size(0.3))
    plt.scatter(
        W[:, 0],
        W[:, 1],
        c=all_zero,
        s=5,
        cmap="coolwarm",
    )
    # Label axes and colour scale so the figure can be read on its own.
    plt.xlabel("$w_1$")
    plt.ylabel("$w_2$")
    plt.colorbar(label="mean ReLU output")
    plt.gcf().set_dpi(300)


In [None]:
from lrp_relations import dtd
import torch

# Flag selecting between the two MLP setups used below: True -> seed 3 followed by
# `mlp.init_weights()`, False -> seed 1 with the constructor's own initialisation.
# It also selects the output figure's file name.
with_neg_biases: bool = False

In [None]:

# Build a small torch MLP on 2-D inputs, evaluate it on 100k random points in
# [-20, 20]^2 and colour each point by the network output. The input gradient
# `grad` computed here is reused by the region-finding cells below.
# `with_neg_biases` comes from the configuration cell above; it is not
# re-assigned here so that toggling it there actually takes effect.

torch.manual_seed(3 if with_neg_biases else 1)
# NOTE(review): the meaning of the positional args to dtd.MLP is defined in
# lrp_relations.dtd; the inputs below have 2 features per point.
mlp = dtd.MLP(3, 2, 10, 1)
if with_neg_biases:
    # Project-specific re-initialisation (see dtd.MLP.init_weights).
    mlp.init_weights()

# Sample inputs first, then mark the resulting tensor as requiring grad so it is
# a leaf; the drawn values are the same as sampling with requires_grad upfront.
x = (20.0 * (2 * torch.rand(100000, 2) - 1)).requires_grad_()
y = mlp(x)
# Gradient of the summed output w.r.t. every input point; this is the per-point
# input gradient assuming the MLP treats rows independently.
(grad,) = torch.autograd.grad(y.sum(), x)

with figures.latexify():
    plt.figure(figsize=figures.get_figure_size(0.3))
    plt.scatter(
        x[:, 0].detach(),
        x[:, 1].detach(),
        c=y.detach().numpy(),
        s=5,
        cmap="coolwarm",
    )
    plt.xlabel("$x_1$")
    plt.ylabel("$x_2$")
    plt.colorbar(label="MLP output")
    plt.gcf().set_dpi(300)


In [None]:
from lrp_relations.utils import to_np
from lrp_relations import utils 
from sklearn.neighbors import NearestNeighbors
# Assuming the MLP is piecewise linear (ReLU activations), its input gradient is
# constant on each linear region. Collect the distinct gradients (up to GRAD_TOL)
# so that every region gets one label; `nn` then maps any gradient to the index
# of its distinct gradient in `nn_grads`.
GRAD_TOL = 1e-4  # Euclidean distance below which two gradients count as equal


def _unique_gradients(grads: np.ndarray, tol: float = GRAD_TOL) -> np.ndarray:
    """Return the distinct rows of ``grads``, merging rows closer than ``tol``.

    Rows are scanned in order. A row is kept when its Euclidean distance to
    every previously kept row is at least ``tol``. The first row is always kept,
    so every input row has a representative in the result. An empty input yields
    an empty ``(0, d)`` array.
    """
    grads = np.asarray(grads)
    kept = grads[:1].copy()
    for g in grads[1:]:
        if np.linalg.norm(kept - g, axis=1).min() >= tol:
            kept = np.vstack([kept, g])
    return kept


grad_np = to_np(grad)
nn_grads = list(_unique_gradients(grad_np))

# A single fit on the final set; `kneighbors` index == region label (used below).
nn = NearestNeighbors(n_neighbors=1, leaf_size=10)
nn.fit(np.stack(nn_grads))

# Number of distinct linear regions found.
len(nn_grads)

In [None]:
# Colour every input point by its region label (index of the nearest distinct
# gradient). Labels are randomly permuted so colours do not follow the order in
# which regions were discovered.
idx = np.arange(len(nn_grads))
np.random.seed(0)  # reproducible permutation; same seed as the saved-figure cell
np.random.shuffle(idx)
region = nn.kneighbors(grad_np, return_distance=False)[:, 0]

with figures.latexify():
    plt.figure(figsize=figures.get_figure_size(0.3, ratio=1))

    plt.scatter(
        x[:, 0].detach(),
        x[:, 1].detach(),
        c=idx[region],
        s=0.5,
        cmap="rainbow",
    )
    plt.xlabel("$x_1$")
    plt.ylabel("$x_2$")
    plt.gcf().set_dpi(300)


In [None]:
# Saved figure: the regions (coloured) overlaid with input-gradient arrows at
# 400 randomly chosen points, written to figures/.
idx = np.arange(len(nn_grads))
np.random.seed(0)  # fixes both the colour permutation and the arrow subset below
np.random.shuffle(idx)
region = nn.kneighbors(grad_np, return_distance=False)[:, 0]

with figures.latexify():
    plt.figure(figsize=figures.get_figure_size(0.24, ratio=1))

    plt.scatter(
        x[:, 0].detach(),
        x[:, 1].detach(),
        c=idx[region],
        s=0.5,
        cmap="rainbow",
    )
    subset = np.random.choice(len(x), size=400, replace=False)
    plt.quiver(
        x[subset, 0].detach(),
        x[subset, 1].detach(),
        grad_np[subset, 0],
        grad_np[subset, 1],
        # Gradient field: measure arrow directions in data coordinates as well as
        # lengths; the default angles="uv" would measure them in screen space.
        angles="xy",
        scale_units="xy",
        # Arrow length is |gradient| / scale in data units.
        scale=0.25 if with_neg_biases else 0.05,
        color="k",
        width=0.002,
    )
    plt.gcf().set_dpi(300)

    os.makedirs("figures", exist_ok=True)
    figname = (
        "figures/mlp_2d_neg_biases.png"
        if with_neg_biases
        else "figures/mlp_2d_random_biases.png"
    )

    plt.savefig(figname, bbox_inches="tight", pad_inches=0.01, dpi=900)
    print(f"Saved {figname}")
    print(f"cp {os.path.abspath(figname)} ./figures")

In [None]:
# Every bias parameter of the MLP, as (name, tensor) pairs, shown for inspection.
all_biases = list(
    filter(lambda named: "bias" in named[0], mlp.named_parameters())
)

all_biases