# Local linear notebook


* [local-linear-visualize]: visualize the linear segments over a 2D space
* [local-linear-size]: compute the typical size of local linear segments
* [local-linear-satisfy]: do the points of the local linear segments satisfy the conditions of root points?

In [None]:
%env CUDA_VISIBLE_DEVICES=""
%load_ext autoreload
%autoreload 2

from typing import Union, Callable

import seaborn.apionly as sns
import numpy as np
import torch 
from torch import nn
import matplotlib.pyplot as plt
import pandas as pd
import savethat

from lrp_relations import local_linear, dtd, utils

## [local-linear-visualize]: visualize the linear segments over a 2D space

In [None]:
# run sampling

# NOTE(review): presumably configures the project/working directory used by
# the helpers below — confirm in lrp_relations.utils.
utils.set_project_dir()
# Build a savethat node around the SampleGrid task, using the default
# SampleGridArgs and the credentials returned by utils.get_credentials().
sample_grid = savethat.create_node(
    local_linear.SampleGrid,
    args=local_linear.SampleGridArgs(),
    credentials=utils.get_credentials(),
)
# Execute the node. The following cells read `grid.grid`, `grid.logits`,
# `grid.grads_logit_0` and `grid.grads_logit_1` from the result.
grid = sample_grid.run()


In [None]:
# Display the shape of `grads_logit_0` (presumably the gradient of logit 0
# at every grid point — confirm against local_linear.SampleGrid).
grid.grads_logit_0.shape

In [None]:
# Scatter the first two coordinates of the grid once per logit (index 0 and
# 1), colouring every point by the value of that logit.
grid_np = grid.grid.detach().numpy()
for i in (0, 1):
    logit_values = grid.logits[:, i].detach().numpy()
    plt.scatter(grid_np[:, 0], grid_np[:, 1], c=logit_values)
    plt.colorbar()
    plt.title(f"Logit {i}")
    plt.show()

In [None]:

# Colour lookup table with 100 rows.
# NOTE(review): seaborn's "colorblind" palette has 10 base colours, so asking
# for 100 repeats them cyclically — distinct segments may share a colour.
colors = np.array(sns.color_palette("colorblind", 100))
n_colors = len(colors)

# One scatter plot per logit: points whose gradients agree up to atol=1e-4
# receive the same id from dtd.almost_unique (presumably an integer id per
# grid point — confirm) and therefore the same colour.
for grad in (grid.grads_logit_0, grid.grads_logit_1):
    grad_colors = dtd.almost_unique(grad, atol=1e-4)
    # Wrap ids around so they always index into the lookup table.
    color_index = grad_colors.numpy() % n_colors
    plt.scatter(
        grid_np[:, 0],
        grid_np[:, 1],
        c=colors[color_index],
        marker=".",
        s=1,
    )
    plt.show()

# Local Sampling

In [None]:
# Seed torch's global RNG before building the network.
# NOTE(review): presumably this makes the weight initialisation reproducible —
# confirm that NLayerMLP draws its initial weights from the global RNG.
torch.manual_seed(0)
# Number of input features; also used below as the width of the start point.
in_channels = 10
# NOTE(review): the meaning of the positional arguments (5, in_channels, 25, 1)
# is defined in dtd.NLayerMLP and not visible here — presumably number of
# layers, input width, hidden width and output width; confirm.
mlp = dtd.NLayerMLP(5, in_channels, 25, 1)


def weight_scale(m: nn.Module) -> None:
    """Scale positive parameter entries of ``m`` by 1.2, in place.

    Every parameter returned by ``m.parameters()`` has its strictly positive
    entries multiplied by 1.2. If ``m`` is a ``dtd.LinearReLU``, the bias of
    its ``linear`` layer is additionally replaced by its negated absolute
    value, so every bias entry ends up non-positive.

    NOTE(review): ``m.parameters()`` recurses into child modules. When this
    function is passed to ``nn.Module.apply`` (which itself visits every
    submodule), a parameter is rescaled once per enclosing module, so the
    effective factor is 1.2 ** depth rather than 1.2 — confirm this
    compounding is intended.
    """
    for param in m.parameters():
        positive = param.data > 0
        param.data[positive] = 1.2 * param.data[positive]

    if not isinstance(m, dtd.LinearReLU):
        return
    bias = m.linear.bias
    bias.data = -bias.data.abs()

# Call weight_scale on mlp and on each of its submodules (nn.Module.apply
# semantics, assuming NLayerMLP is a regular nn.Module). The return value is
# bound to `_` so the notebook does not echo it.
_ = mlp.apply(weight_scale)

In [None]:

n_chains = 1000

# Search seeds 0..999 for a start point with a strictly positive network
# output. A single point is drawn uniformly from [-1, 1) per coordinate and
# copied n_chains times, so all chains start from the same location.
for i in range(1000):
    torch.manual_seed(i)
    x = (2 * torch.rand(1, in_channels) - 1).repeat(n_chains, 1)
    if (mlp(x) > 0).all():
        break

# Fails when none of the 1000 seeds produced a positive output.
assert (mlp(x) > 0).all()

# NOTE(review): the semantics of n_steps / n_warmup / grad_rtol / grad_atol
# live in local_linear.sample and are not visible here — confirm there.
sample_results = local_linear.sample(
    mlp,
    x,
    n_steps=4_000,
    n_warmup=3000,
    grad_rtol=1e-5,
    grad_atol=1e-5,
)

# Last entry along the first axis of the chain (presumably the final step of
# every chain), with gradient tracking enabled so the input gradient can be
# recomputed.
samples = sample_results.chain[-1].requires_grad_(True)
logits = mlp(samples)
grad = torch.autograd.grad(
    logits, samples, grad_outputs=torch.ones_like(logits)
)[0]

print(grad.shape)

# The gradient at the final samples must equal the recorded start gradient
# within tolerance (presumably: the chains never left the start point's
# linear segment).
assert torch.allclose(grad, sample_results.start_grad, rtol=1e-5, atol=1e-5)


In [None]:
# Histogram (50 bins) of all entries of `logits`, flattened to 1D; `_ =`
# keeps the notebook from echoing the return value of plt.hist.
_ = plt.hist(logits.detach().numpy().flatten(), bins=50)

In [None]:
# Convert the full chain to NumPy; it is indexed with three axes below
# (presumably step, chain, input coordinate — confirm in local_linear.sample).
chain_np = sample_results.chain.detach().numpy()
# Alternative rendering as faint blue dots, kept for reference:
# _ = plt.plot(chain_np[:, :, 0], chain_np[:, :, 1], '.', c='b', alpha=0.01)
# Plot coordinate 1 against coordinate 0; with 2D inputs matplotlib draws one
# line per column, i.e. one line per index of the second axis.
_ = plt.plot(chain_np[:, :, 0], chain_np[:, :, 1], alpha=1)

In [None]:
# Plot the sampler's accept ratio on the left y-axis and its scaling on a
# second, log-scaled y-axis (presumably both are recorded per sampling step —
# confirm in local_linear.sample).
plt.plot(sample_results.accept_ratio.numpy())
# Twin axis sharing the x-axis with the accept-ratio plot.
ax = plt.twinx()
ax.plot(sample_results.scaling.detach().numpy())
ax.set_yscale('log')
plt.show()

In [None]:
# Histogram (100 bins) over every value in the chain, pooled across all axes;
# `_ =` keeps the notebook from echoing the return value of plt.hist.
_ = plt.hist(chain_np.flatten(), bins=100)