# Basic model queries

Just being able to query the structure of a model already gives us useful
information. We'll look at two examples: describing the complexity of a
network and finding nodes that can likely be pruned. We'll then run them
against multiple models.

In [1]:
import duckdb as db

# Note: we assume the app models are available. If not, run the
# `../preparation.ipynb` notebook
con_single = db.connect()
con_single.execute(f"IMPORT DATABASE '../mnist-showcase-app/dbs/cnn_single.db'")

con_multi = db.connect()
con_multi.execute(f"IMPORT DATABASE '../mnist-showcase-app/dbs/cnn_multimodel_size.db'")

query_layers_single = """WITH RECURSIVE input_nodes AS (
    SELECT id
    FROM node
    WHERE id NOT IN
    (SELECT dst FROM edge)
),
nodes_with_layer AS (
    SELECT
        id,
        0 as layer
    FROM input_nodes

    UNION ALL

    SELECT
        n.id,
        nodes_with_layer.layer + 1 as layer
    FROM edge e
    JOIN nodes_with_layer ON nodes_with_layer.id = e.src
    JOIN node n ON e.dst = n.id
    GROUP BY e.dst, n.id, layer
)
SELECT layer, COUNT(id) AS number_of_nodes
FROM nodes_with_layer
GROUP BY layer
ORDER BY layer;
"""

query_layers_multi = """WITH RECURSIVE input_nodes AS (
    SELECT
        model_id,
        id
    FROM node
    WHERE id NOT IN
    (SELECT dst FROM edge)
),
nodes_with_layer AS (
    SELECT
        model_id,
        id,
        0 as layer
    FROM input_nodes

    UNION ALL

    SELECT
        n.model_id,
        n.id,
        nodes_with_layer.layer + 1 as layer
    FROM edge e
    JOIN nodes_with_layer
      ON nodes_with_layer.id = e.src
      AND nodes_with_layer.model_id = e.model_id
    JOIN node n ON e.dst = n.id
    GROUP BY n.model_id, e.dst, n.id, layer
)
SELECT
  m.id,
  m.name,
  n.layer,
  COUNT(n.id) AS number_of_nodes
FROM model m
JOIN nodes_with_layer n ON n.model_id = m.id
GROUP BY m.id, n.layer, m.name
ORDER BY m.id, n.layer;
"""

query_parameters_single = """WITH input_nodes AS (
    SELECT id
    FROM node
    WHERE id NOT IN
    (SELECT dst FROM edge)
),
num_biases AS (
    SELECT COUNT(bias) AS num_biases
    FROM node
    WHERE id NOT IN (SELECT id FROM input_nodes)
),
num_weights AS (
    SELECT COUNT(weight) AS num_weights FROM edge
)
SELECT
    (SELECT num_biases FROM num_biases)
    + (SELECT num_weights FROM num_weights)
AS learnable_parameters"""

query_parameters_multi = """WITH input_nodes AS (
    SELECT
      model_id,
      id
    FROM node
    WHERE id NOT IN
    (SELECT dst FROM edge)
),
num_biases AS (
    SELECT model_id, COUNT(bias) AS num_biases
    FROM node
    WHERE id NOT IN (SELECT id FROM input_nodes)
    GROUP BY model_id
),
num_weights AS (
    SELECT model_id, COUNT(weight) AS num_weights
    FROM edge
    GROUP BY model_id
)

SELECT
  m.id,
  m.name,
  nb.num_biases + nw.num_weights
FROM model m
JOIN num_biases nb ON m.id = nb.model_id
JOIN num_weights nw ON m.id = nw.model_id
"""

query_pruning_single = """
SELECT src
FROM edge
GROUP BY src
HAVING MAX(ABS(weight)) <= 0.01
"""

query_pruning_multi = """
WITH num_input_nodes AS (
  SELECT model_id, COUNT(id) AS num_input_nodes
  FROM node
  WHERE id NOT IN (SELECT dst FROM edge)
  GROUP BY model_id
),
num_output_nodes AS (
  SELECT model_id, COUNT(id) AS num_output_nodes
  FROM node
  WHERE id NOT IN (SELECT src FROM edge)
  GROUP BY model_id
),
num_total_nodes AS (
  SELECT model_id, COUNT(id) AS num_nodes
  FROM node
  GROUP BY model_id
),
num_hidden_nodes AS (
  SELECT
    t.model_id,
    t.num_nodes - i.num_input_nodes - o.num_output_nodes AS num_hidden_nodes
  FROM num_total_nodes t
  JOIN num_input_nodes i ON i.model_id = t.model_id
  JOIN num_output_nodes o ON o.model_id = t.model_id
  GROUP BY t.model_id, t.num_nodes, i.num_input_nodes, o.num_output_nodes
),
prunable_nodes AS (
  SELECT
    model_id,
    src AS id_to_prune
  FROM edge
  GROUP BY model_id, src
  HAVING MAX(ABS(weight)) <= ?
),
num_prunable_nodes AS (
  SELECT
    m.id AS model_id,
    m.name,
    COUNT(p.id_to_prune) AS num_prunable_nodes
  FROM model m
  JOIN prunable_nodes p ON p.model_id = m.id
  GROUP BY m.id, m.name
  ORDER BY m.id
)
SELECT
  m.id,
  m.name,
  h.num_hidden_nodes,
  p.num_prunable_nodes,
  ROUND(p.num_prunable_nodes * 100 / h.num_hidden_nodes, 2) AS percentage_prunable
FROM model m
JOIN num_hidden_nodes h ON h.model_id = m.id
JOIN num_prunable_nodes p ON p.model_id = m.id
ORDER BY m.id
"""

## Model complexity

A first basic set of queries we can perform is to describe the model
complexity. Some questions are:

- How many learnable parameters does the model have?
- How many layers does the model have, with how many nodes each?

The queries are as follows; for the learnable parameters:

```sql
WITH input_nodes AS (
    SELECT id
    FROM node
    WHERE id NOT IN
    (SELECT dst FROM edge)
),
num_biases AS (
    SELECT COUNT(bias) AS num_biases
    FROM node
    WHERE id NOT IN (SELECT id FROM input_nodes)
),
num_weights AS (
    SELECT COUNT(weight) AS num_weights FROM edge
)
SELECT
    (SELECT num_biases FROM num_biases)
    + (SELECT num_weights FROM num_weights)
AS learnable_parameters
```

And for the layer info:

```sql
WITH RECURSIVE input_nodes AS (
    SELECT id
    FROM node
    WHERE id NOT IN
    (SELECT dst FROM edge)
),
nodes_with_layer AS (
    SELECT
        id,
        0 as layer
    FROM input_nodes

    UNION ALL

    SELECT
        n.id,
        nodes_with_layer.layer + 1 as layer
    FROM edge e
    JOIN nodes_with_layer ON nodes_with_layer.id = e.src
    JOIN node n ON e.dst = n.id
    GROUP BY e.dst, n.id, layer
)
SELECT layer, COUNT(id) AS number_of_nodes
FROM nodes_with_layer
GROUP BY layer
ORDER BY layer;
```

Running these queries agains the MNIST CNN results in the following:

In [2]:
con_single.execute(query_parameters_single).df()

Unnamed: 0,learnable_parameters
0,4060810


In [3]:
con_single.execute(query_layers_single).df()

Unnamed: 0,layer,number_of_nodes
0,0,784
1,1,21632
2,2,9216
3,3,128
4,4,10


## Pruning

Another interesting query revolves around model pruning: determining which
nodes contribute little to the network and can be removed, without loss of
accuracy.

One way to do this is to find nodes with low outgoing weights. We can do
this with the following query:

```sql
SELECT src
FROM edge
GROUP BY src
HAVING MAX(ABS(weight)) <= 0.01
```

In [4]:
pruning_result = con_single.execute(query_pruning_single).df()
pruning_result


Unnamed: 0,src
0,23574
1,23638


These are the outgoing weights for these nodes (limited to 100 weights):

In [5]:
src_ids = map(str, pruning_result["src"].tolist())

query = f"""
SELECT src, weight
FROM edge
WHERE src IN ({','.join(src_ids)})
"""

con_single.execute(query).df()

Unnamed: 0,src,weight
0,23574,0.005189
1,23574,0.006463
2,23574,-0.007742
3,23574,-0.001924
4,23574,0.009078
...,...,...
251,23638,-0.009493
252,23638,0.007660
253,23638,0.009144
254,23638,-0.002942


For reference, this is the average weight:

In [6]:
avg_query = """
SELECT AVG(ABS(weight)) AS avg_weight FROM edge
"""

con_single.execute(avg_query).df()

Unnamed: 0,avg_weight
0,0.055364


## Multiple models

We can also use these queries to compare multiple models. In this example,
we'll use a set of MNIST classifiers with the same structure, but with
a different number of hidden units per layer.

We can compare the number of hidden units per layer:

In [7]:
con_multi.execute(query_layers_multi).df()

Unnamed: 0,id,name,layer,number_of_nodes
0,1,Regular,0,784
1,1,Regular,1,21632
2,1,Regular,2,9216
3,1,Regular,3,128
4,1,Regular,4,10
5,2,2x smaller,0,784
6,2,2x smaller,1,10816
7,2,2x smaller,2,4608
8,2,2x smaller,3,64
9,2,2x smaller,4,10


The number of learnable parameters:

In [8]:
con_multi.execute(query_parameters_multi).df()

Unnamed: 0,id,name,(nb.num_biases + nw.num_weights)
0,1,Regular,4060810
1,2,2x smaller,1071946
2,3,4x smaller,296362
3,4,8x smaller,88282
4,5,16x smaller,29170


And the number of prunable nodes:

In [9]:
for max_value in [0.01, 0.02, 0.03, 0.04, 0.05]:
    print(f"Prunable nodes for max weight value of {max_value}")
    display(con_multi.execute(query_pruning_multi, [max_value]).df())

Prunable nodes for max weight value of 0.01


Unnamed: 0,id,name,num_hidden_nodes,num_prunable_nodes,percentage_prunable
0,1,Regular,30976,2,0.01


Prunable nodes for max weight value of 0.02


Unnamed: 0,id,name,num_hidden_nodes,num_prunable_nodes,percentage_prunable
0,1,Regular,30976,2459,7.94
1,2,2x smaller,15488,246,1.59
2,3,4x smaller,7744,22,0.28


Prunable nodes for max weight value of 0.03


Unnamed: 0,id,name,num_hidden_nodes,num_prunable_nodes,percentage_prunable
0,1,Regular,30976,3444,11.12
1,2,2x smaller,15488,692,4.47
2,3,4x smaller,7744,265,3.42
3,4,8x smaller,3872,9,0.23
4,5,16x smaller,1936,2,0.1


Prunable nodes for max weight value of 0.04


Unnamed: 0,id,name,num_hidden_nodes,num_prunable_nodes,percentage_prunable
0,1,Regular,30976,4378,14.13
1,2,2x smaller,15488,945,6.1
2,3,4x smaller,7744,385,4.97
3,4,8x smaller,3872,20,0.52
4,5,16x smaller,1936,8,0.41


Prunable nodes for max weight value of 0.05


Unnamed: 0,id,name,num_hidden_nodes,num_prunable_nodes,percentage_prunable
0,1,Regular,30976,5145,16.61
1,2,2x smaller,15488,1200,7.75
2,3,4x smaller,7744,466,6.02
3,4,8x smaller,3872,28,0.72
4,5,16x smaller,1936,19,0.98
