# Querying useless neurons

In the previous part, we constructed a SQL query that can evaluate a neural
network. We can use `eval` to detect *useless* neurons for a given input, i.e.
neurons that do not contribute anything to the end result. This can be useful to
reduce the complexity of a neural network.

The idea behind it is simple: evaluate the model on a given input. Now, remove a
neuron and evaluate the model again. If the two outputs are sufficiently close
to each other, the removed neuron can be considered useless. Keep in mind that
this is always for a specific input.

We can define this a bit more formally as follows: take $eval'$, the version of
$eval$ with some node $z$ left out. If $|eval - eval'| < \epsilon$ holds for the
given input and some small threshold $\epsilon$, then $z$ is useless.
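
The definition can be sketched in plain Python, independent of the SQL and
PyTorch code used in the rest of this chapter. The tiny network below, its
weights, and the helper names `evaluate` and `is_useless` are made up for
illustration; only the rule $|eval - eval'| < \epsilon$ comes from the
definition above.

```python
def relu(v):
    return max(0.0, v)


# Hypothetical toy network: 2 inputs -> 2 hidden ReLU neurons -> 1 linear output.
HIDDEN = {
    "h1": {"bias": 0.0, "weights": [1.0, 2.0]},
    "h2": {"bias": -100.0, "weights": [1.0, 1.0]},  # inactive for small inputs
}
OUTPUT = {"bias": 0.5, "weights": {"h1": 1.0, "h2": 3.0}}


def evaluate(inputs, removed=None):
    """Evaluate the toy network, optionally leaving out one hidden neuron."""
    total = OUTPUT["bias"]
    for name, neuron in HIDDEN.items():
        if name == removed:
            continue  # eval': the removed neuron contributes nothing
        pre = neuron["bias"] + sum(w * x for w, x in zip(neuron["weights"], inputs))
        total += OUTPUT["weights"][name] * relu(pre)
    return total


def is_useless(inputs, neuron, epsilon=1e-6):
    """A neuron is useless for `inputs` if |eval - eval'| < epsilon."""
    return abs(evaluate(inputs) - evaluate(inputs, removed=neuron)) < epsilon


x = [5.0, 20.0]
print(evaluate(x))          # output of the full network
print(is_useless(x, "h1"))  # h1 is active, so removing it changes the output
print(is_useless(x, "h2"))  # h2 outputs 0 for this input
```

In this sketch `h2` is useless for `[5, 20]` but becomes useful for an input
such as `[100, 100]`, which illustrates why the result always depends on the
specific input.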

## Creating the model

We create a simple model, restricting it to 1 output value to more easily
compare the `eval` results.

In [18]:
import torch
import numpy as np
import utils.duckdb as db
import pandas as pd
import utils.nn as nn
import copy

torch.manual_seed(223)

def f(x, y):
    return [x + 2*y]

num_samples = 100
x_train = torch.randn(num_samples, 2) * 100
y_train = [f(x,y) for [x,y] in x_train]

input_size = 2
output_size = 1
hidden_size = 3
num_hidden_layers = 1

model = nn.ReLUFNN(
    input_size=input_size,
    output_size=output_size,
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers
)
nn.train(model, x_train, y_train, save_path="models/useless.pt")

db.load_pytorch_model_into_db(model)

## PyTorch implementation

Let's first try disabling each node in PyTorch and running the model, so we have
something to compare to.

In [19]:
def get_layer_and_node_index(node_id_to_disable):
    node_id_zero_indexed = node_id_to_disable - 1

    if node_id_zero_indexed < input_size:
        return (0, node_id_zero_indexed)

    return (
        ((node_id_zero_indexed - input_size) // hidden_size) + 1,
        ((node_id_zero_indexed - input_size) % hidden_size)
    )

# Preserve the original state dict so we can reset it.
original_state_dict = copy.deepcopy(model.state_dict())

def disable_node(node_id_to_disable):
    # Reset the model to cancel out any previously disabled nodes.
    model.load_state_dict(original_state_dict)

    state_dict = model.state_dict()
    (layer_idx, node_idx) = get_layer_and_node_index(node_id_to_disable)

    # The state dict alternates between weights and biases, so the index of our
    # bias tensor is the following:
    layer_state_dict_idx = layer_idx * 2 - 1

    # Disable the bias.
    biases = list(state_dict.values())[layer_state_dict_idx]
    biases[node_idx] = 0

    # Disable the weights.
    weights = list(state_dict.values())[layer_state_dict_idx - 1]
    weights[node_idx].zero_()
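
To make the index arithmetic concrete, the following standalone sketch
re-implements the mapping with explicit parameters and prints where each node
ID of the 2-3-1 network lands. The names `layer_and_node_index` and
`state_dict_indices` are ours, not part of the notebook's helpers. Like the
function above, the sketch assumes that node IDs are 1-indexed, assigned layer
by layer, and that all hidden layers have the same size.

```python
def layer_and_node_index(node_id, input_size, hidden_size):
    """Map a 1-indexed node ID to (layer index, index within that layer)."""
    zero_indexed = node_id - 1
    if zero_indexed < input_size:
        return (0, zero_indexed)
    offset = zero_indexed - input_size
    return (offset // hidden_size + 1, offset % hidden_size)


def state_dict_indices(layer_idx):
    """Positions of (weights, biases) for a layer in a state dict laid out as
    [weights_1, biases_1, weights_2, biases_2, ...]. Only valid for layer_idx >= 1."""
    bias_idx = layer_idx * 2 - 1
    return (bias_idx - 1, bias_idx)


# Node IDs 1-2 are inputs, 3-5 are the hidden neurons of the 2-3-1 network.
for node_id in range(1, 6):
    print(node_id, layer_and_node_index(node_id, input_size=2, hidden_size=3))
```

For an input node the layer index is 0, so `layer_idx * 2 - 1` evaluates to
`-1`, which in Python indexes the last tensor of the state dict. `disable_node`
as written should therefore only be called with hidden node IDs, which is what
the loop in the next cell does.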

In [20]:
nn_output = model(torch.tensor([5, 20], dtype=torch.float32))
print(f"Regular output: {nn_output[0]}")

for node_id_to_disable in range(input_size + 1, input_size + 1 + hidden_size * num_hidden_layers):
    disable_node(node_id_to_disable)
    output = model(torch.tensor([5, 20], dtype=torch.float32))

    print(f"With neuron {node_id_to_disable} disabled: {output[0]}")

Regular output: 47.10837936401367
With neuron 3 disabled: 1.9035104513168335
With neuron 4 disabled: 47.10837936401367
With neuron 5 disabled: 47.10837936401367


From this output we can observe that nodes 4 and 5 are both useless for this
input: disabling either of them still yields the same output.

## Creating the SQL query

We build upon our recursive `eval` query to produce a query that calculates the
difference $|eval - eval'|$. After that, it is trivial to select the neurons
whose difference is smaller than some $\epsilon$ using a `WHERE` clause.

The [full query](./queries/useless_neurons_from_input.sql) goes as follows:

```sql
WITH RECURSIVE inputs AS MATERIALIZED (
    -- Fetch inputs together
    SELECT i.input_set_id, i.input_node_idx, i.input_value, n.id
    FROM input i
    JOIN (
        SELECT
            id,
            bias,
            ROW_NUMBER() OVER (ORDER BY id) AS input_node_idx
        FROM node n
        WHERE NOT EXISTS
        (SELECT 1 FROM edge WHERE dst = n.id)
    ) n ON n.input_node_idx = i.input_node_idx
),
output_nodes AS MATERIALIZED (
    SELECT id, bias
    FROM node n
    WHERE NOT EXISTS
    (SELECT 1 FROM edge WHERE src=n.id)
),
hidden_nodes AS MATERIALIZED (
    SELECT id, bias
    FROM node n
    WHERE EXISTS (SELECT 1 FROM edge WHERE src = n.id)
    AND EXISTS (SELECT 1 FROM edge WHERE dst = n.id)
),
tx AS (
    -- Base case (t1)
    SELECT
        i.input_set_id AS input_set_id,
        GREATEST(
            0,
            n.bias + SUM(e.weight * i.input_value)
        ) AS value,
        e.dst AS id
    FROM inputs i
    JOIN edge e ON i.id = e.src
    JOIN node n ON e.dst = n.id
    GROUP BY e.dst, n.bias, i.input_set_id

    UNION ALL

    SELECT
        tx.input_set_id AS input_set_id,
        GREATEST(
            0,
            n.bias + SUM(e.weight * tx.value)
        ) AS value,
        e.dst AS id
    FROM tx
    JOIN edge e ON tx.id = e.src
    JOIN node n ON e.dst = n.id
    GROUP BY e.dst, n.bias, tx.input_set_id
),
t_out AS MATERIALIZED (
    SELECT
        tx.input_set_id AS input_set_id,
        o.bias + SUM(e.weight * tx.value) AS value,
        e.dst AS id
    FROM output_nodes o
    JOIN edge e ON e.dst = o.id
    JOIN tx ON tx.id = e.src
    GROUP BY e.dst, o.bias, tx.input_set_id
),
tx_prime AS (
    SELECT
        h_to_remove.id AS h_to_remove,
        i.input_set_id AS input_set_id,
        GREATEST(
            0,
            n.bias + SUM(e.weight * i.input_value)
        ) AS value,
        e.dst AS id
    FROM inputs i
    CROSS JOIN hidden_nodes h_to_remove
    JOIN edge e ON i.id = e.src AND e.dst <> h_to_remove.id
    JOIN node n ON e.dst = n.id
    GROUP BY h_to_remove.id, i.input_set_id, e.dst, n.bias

    UNION ALL

    SELECT
        tx_prime.h_to_remove,
        tx_prime.input_set_id AS input_set_id,
        GREATEST(
            0,
            n.bias + SUM(e.weight * tx_prime.value)
        ) AS value,
        e.dst AS id
    FROM tx_prime
    JOIN edge e ON tx_prime.id = e.src
        AND e.dst <> tx_prime.h_to_remove
        AND e.src <> tx_prime.h_to_remove
    JOIN node n ON e.dst = n.id
    GROUP BY tx_prime.h_to_remove, tx_prime.input_set_id, e.dst, n.bias
),
t_out_prime AS (
    SELECT
        tx_prime.h_to_remove,
        tx_prime.input_set_id AS input_set_id,
        o.bias + SUM(e.weight * tx_prime.value) AS value,
        e.dst AS id
    FROM output_nodes o
    JOIN edge e ON e.dst = o.id
    JOIN tx_prime ON tx_prime.id = e.src
    GROUP BY tx_prime.h_to_remove, e.dst, o.bias, tx_prime.input_set_id
),
useless_neurons_overview AS (
    SELECT
        t_out_prime.h_to_remove,
        t_out_prime.input_set_id,
        t_out_prime.id AS output_id,
        t_out.value AS eval,
        t_out_prime.value as eval_prime,
        ABS(t_out.value - t_out_prime.value) AS delta
    FROM t_out_prime
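    -- Compare outputs of the same output node and the same input set.
    JOIN t_out ON t_out.id = t_out_prime.id
        AND t_out.input_set_id = t_out_prime.input_set_id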
)
SELECT * FROM useless_neurons_overview ORDER BY h_to_remove;
```

This query is very similar to the final recursive query from [the earlier
chapter](./3.5%20generic%20eval%20-%20recursive.ipynb). We'll highlight the new
parts.

```sql
hidden_nodes AS MATERIALIZED (
    SELECT id, bias
    FROM node n
    WHERE EXISTS (SELECT 1 FROM edge WHERE src = n.id)
    AND EXISTS (SELECT 1 FROM edge WHERE dst = n.id)
),
```

We add a basic CTE to select the hidden nodes, because we need to calculate
$eval'$ for each one of them later on.
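
The same classification can be sketched in plain Python. The edge list below is
a hypothetical stand-in for the `edge` table of the 2-3-1 network (weights
omitted), with node IDs assigned as in the PyTorch section.

```python
# Hypothetical (src, dst) pairs for a fully connected 2-3-1 network.
edges = [(src, dst) for src in (1, 2) for dst in (3, 4, 5)]
edges += [(src, 6) for src in (3, 4, 5)]

nodes = {n for edge in edges for n in edge}
has_outgoing = {src for src, _ in edges}
has_incoming = {dst for _, dst in edges}

input_nodes = sorted(nodes - has_incoming)          # no incoming edges
output_nodes = sorted(nodes - has_outgoing)         # no outgoing edges
hidden_nodes = sorted(has_incoming & has_outgoing)  # both, as in the CTE

print(input_nodes, hidden_nodes, output_nodes)
```

A node counts as hidden exactly when it appears both as a source and as a
destination, which mirrors the two `EXISTS` conditions in the CTE.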

The main addition to the query essentially runs `eval` once for each hidden
neuron, with that neuron removed. The intermediate values end up in `tx_prime`
and the final outputs in `t_out_prime`. We achieve this as follows:

```sql
tx_prime AS (
    -- ...
    CROSS JOIN hidden_nodes h_to_remove
    JOIN edge e ON i.id = e.src AND e.dst <> h_to_remove.id
    -- ...

    UNION ALL

    -- ...
    FROM tx_prime
    JOIN edge e ON tx_prime.id = e.src
        AND e.dst <> tx_prime.h_to_remove
        AND e.src <> tx_prime.h_to_remove
    -- ...
),
```

- The `CROSS JOIN` essentially "loops" over the hidden nodes we want to remove.
- For each node to remove, we add join conditions that exclude every edge
  leading into or out of that node, so it takes no part in the evaluation.
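
In procedural terms, the base case of `tx_prime` behaves roughly like the nested
loop below: one full pass over the first layer per candidate node, skipping
every edge that leads into that candidate. This is a plain-Python sketch; the
weights, biases, and the dictionary layout are invented for illustration and do
not come from the trained model.

```python
# Hypothetical first layer of a 2-3-1 network: (src, dst) -> weight.
weights = {
    (1, 3): 1.0, (2, 3): 2.0,
    (1, 4): -1.0, (2, 4): 0.5,
    (1, 5): 0.0, (2, 5): -2.0,
}
biases = {3: 0.0, 4: 1.0, 5: 3.0}
inputs = {1: 5.0, 2: 20.0}
hidden_nodes = [3, 4, 5]

tx_prime = {}  # (h_to_remove, node id) -> activation
for h_to_remove in hidden_nodes:        # CROSS JOIN hidden_nodes h_to_remove
    sums = {}
    for (src, dst), weight in weights.items():
        if dst == h_to_remove:          # e.dst <> h_to_remove.id
            continue
        sums[dst] = sums.get(dst, 0.0) + weight * inputs[src]
    for dst, total in sums.items():     # GROUP BY and GREATEST(0, ...)
        tx_prime[(h_to_remove, dst)] = max(0.0, biases[dst] + total)

for key in sorted(tx_prime):
    print(key, tx_prime[key])
```

Every candidate gets its own copy of the remaining activations, labelled by
`h_to_remove`, so the intermediate result grows with the number of hidden
nodes.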

The resulting overview simply displays the difference between $eval$ and
$eval'$ for each hidden neuron, but the query can easily be altered to only
show those neurons where the difference is smaller than some $\epsilon$.

## Comparison

Earlier, we found that neurons 4 and 5 were useless for our particular input.
Let's run the query now:

In [21]:
db.con.execute("TRUNCATE input")
db.con.execute("INSERT INTO input VALUES (0, 1, 5), (0, 2, 20)")

with open('queries/useless_neurons_from_input.sql') as f:
    query = f.read()

db.con.execute(query).df()

|   | h_to_remove | input_set_id | output_id | eval | eval_prime | delta |
|---|---|---|---|---|---|---|
| 0 | 3 | 0 | 6 | 47.108378 | 1.90351 | 45.204868 |
| 1 | 4 | 0 | 6 | 47.108378 | 47.108378 | 0.0 |
| 2 | 5 | 0 | 6 | 47.108378 | 47.108378 | 0.0 |


Indeed, the SQL query finds the same result. We can conclude that neurons 4 and
5 are useless for this input and can be removed.
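
Turning the overview into a list of useless neurons only requires a threshold on
`delta`. As a small sketch, the snippet below applies that filter to the three
`(h_to_remove, delta)` pairs from the query result above; the value of `epsilon`
is an arbitrary choice for illustration.

```python
# (h_to_remove, delta) pairs copied from the query result above.
overview = [(3, 45.204868), (4, 0.0), (5, 0.0)]

epsilon = 0.01  # arbitrary threshold, chosen for illustration
useless = [h for h, delta in overview if delta < epsilon]
print(useless)
```

In the SQL query, the equivalent would be a condition such as
`WHERE delta < 0.01` on `useless_neurons_overview`.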

## Larger example

The previous network only had 1 hidden layer with 3 nodes. Let's create a bigger
one and see how our query performs.

In [22]:
model = nn.ReLUFNN(
    input_size=2,
    output_size=1,
    hidden_size=10,
    num_hidden_layers=10
)
nn.train(model, x_train, y_train, save_path="models/useless_bigger.pt")

db.load_pytorch_model_into_db(model)

In [23]:
db.con.execute("TRUNCATE input")
db.con.execute("INSERT INTO input VALUES (0, 1, 5), (0, 2, 20)")

with open('queries/useless_neurons_from_input.sql') as f:
    query = f.read()

df = db.con.execute(query).df()
df

|   | h_to_remove | input_set_id | output_id | eval | eval_prime | delta |
|---|---|---|---|---|---|---|
| 0 | 3 | 0 | 103 | 44.925945 | 39.348491 | 5.577453 |
| 1 | 4 | 0 | 103 | 44.925945 | 30.891371 | 14.034574 |
| 2 | 5 | 0 | 103 | 44.925945 | 41.924393 | 3.001552 |
| 3 | 6 | 0 | 103 | 44.925945 | 43.958574 | 0.967370 |
| 4 | 7 | 0 | 103 | 44.925945 | 35.874110 | 9.051835 |
| ... | ... | ... | ... | ... | ... | ... |
| 95 | 98 | 0 | 103 | 44.925945 | 44.925945 | 0.000000 |
| 96 | 99 | 0 | 103 | 44.925945 | 44.925945 | 0.000000 |
| 97 | 100 | 0 | 103 | 44.925945 | 44.925945 | 0.000000 |
| 98 | 101 | 0 | 103 | 44.925945 | 5.448792 | 39.477153 |


We ran the same overview query again. We use `pandas` to filter its output
because that is easier in this notebook, but a simple `WHERE` condition in the
SQL query would yield the same result.

Let's first look at the useful neurons.

In [24]:
epsilon = 0.01
df[df['delta'] >= epsilon]

|   | h_to_remove | input_set_id | output_id | eval | eval_prime | delta |
|---|---|---|---|---|---|---|
| 0 | 3 | 0 | 103 | 44.925945 | 39.348491 | 5.577453 |
| 1 | 4 | 0 | 103 | 44.925945 | 30.891371 | 14.034574 |
| 2 | 5 | 0 | 103 | 44.925945 | 41.924393 | 3.001552 |
| 3 | 6 | 0 | 103 | 44.925945 | 43.958574 | 0.96737 |
| 4 | 7 | 0 | 103 | 44.925945 | 35.87411 | 9.051835 |
| 6 | 9 | 0 | 103 | 44.925945 | 32.669875 | 12.25607 |
| 12 | 15 | 0 | 103 | 44.925945 | 44.741128 | 0.184817 |
| 13 | 16 | 0 | 103 | 44.925945 | 32.749066 | 12.176879 |
| 14 | 17 | 0 | 103 | 44.925945 | 30.422121 | 14.503824 |
| 15 | 18 | 0 | 103 | 44.925945 | 37.24885 | 7.677094 |


As we can see, roughly half of the neurons actually provide meaningful
contributions to the final output - for this specific input.

## Conclusion

We have shown how we can adapt the `eval` query to show us which neurons do not
contribute anything to the output of a neural network, for a given input. This
can be useful to reduce the complexity of a network.

The next section explores how we can use a SQL query to calculate the integral
of a function represented by a neural network.

> TODO: do we want to compare the SQL result of the bigger example to the PyTorch
> version as well?