# Recursive queries

Up until now we stayed true to the paper and added a CTE for each layer in the
neural network. In this chapter, we'll leverage SQL's recursive capabilities to
create a query that works for every FNN, without needing to know the number of
hidden layers beforehand. In fact, we know next to nothing about the networks we
query. They can have:

- An unknown number of input neurons
- An unknown number of output neurons
- An unknown number of hidden layers
- An unknown number of neurons in each hidden layer
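Everything in this list can be recovered from the graph itself. The sketch below assumes the network is stored as a plain list of `(src, dst)` edges, like the `edge` table the query uses. Both helpers, `describe_network` and `dense_edges`, are hypothetical and exist only for illustration. Input nodes are those with no incoming edge, and output nodes are those with no outgoing edge. For a strictly layered FNN, the hidden layers then follow from a layer-by-layer walk starting at the inputs. The SQL query does not compute these sizes explicitly. The sketch only shows that the graph already contains them.

```python
from collections import defaultdict


def describe_network(edges):
    """Recover layer sizes of a strictly layered FNN from its (src, dst) edges."""
    srcs = {s for s, _ in edges}
    dsts = {d for _, d in edges}
    nodes = srcs | dsts
    successors = defaultdict(set)
    for s, d in edges:
        successors[s].add(d)

    # Input nodes have no incoming edge; output nodes have no outgoing edge.
    inputs = sorted(nodes - dsts)
    outputs = sorted(nodes - srcs)

    # Walk layer by layer: the next layer is every node one edge away.
    layers = [inputs]
    while True:
        nxt = sorted({d for n in layers[-1] for d in successors[n]})
        if not nxt:
            break
        layers.append(nxt)

    return {
        "num_inputs": len(inputs),
        "num_outputs": len(outputs),
        "hidden_layer_sizes": [len(layer) for layer in layers[1:-1]],
    }


def dense_edges(layer_sizes):
    """Hypothetical helper: fully connected edges, nodes numbered 0..N-1 by layer."""
    edges, offset = [], 0
    for size_in, size_out in zip(layer_sizes, layer_sizes[1:]):
        for i in range(size_in):
            for j in range(size_out):
                edges.append((offset + i, offset + size_in + j))
        offset += size_in
    return edges


print(describe_network(dense_edges([3, 5, 4, 2])))
# → {'num_inputs': 3, 'num_outputs': 2, 'hidden_layer_sizes': [5, 4]}
```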

## The network

To show we don't need to know anything about the neural network, we'll create
one with randomized properties.

```python
import torch
import numpy as np
import utils.duckdb as db
import pandas as pd
import utils.nn as nn
import random

torch.manual_seed(223)
random.seed(223)

num_samples = 100

# Randomize the network's shape so the query cannot rely on it
input_size = random.randint(1, 10)
output_size = random.randint(1, 10)

x_train = torch.randn(num_samples, input_size) * 100
y_train = torch.randn(num_samples, output_size) * 10

model = nn.ReLUFNN(
    input_size=input_size,
    output_size=output_size,
    hidden_size=random.randint(3, 10),
    num_hidden_layers=random.randint(2, 20)
)
nn.train(model, x_train, y_train, save_path="models/eval_recursive.pt")

# Store the trained network as node/edge tables in DuckDB
db.load_pytorch_model_into_db(model)
```

If you look closely at the code above, you'll see that we traded in SQLite for
DuckDB. The corresponding module can be found [here](./utils/duckdb.py).

Previously, we used SQLite because DuckDB's performance suffers when
constructing large queries like the ones we've built so far. We now opt for
DuckDB for two reasons. First, on the recursive variant DuckDB performs
comparably to SQLite. Second, and more importantly, SQLite does not allow
aggregate functions in the recursive part of a recursive CTE, and our query
relies on them.
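To see the limitation concretely, here is a minimal sketch that uses Python's built-in `sqlite3` module. The toy query and the `sqlite_accepts` helper are illustrative only. The query's recursive step uses `SUM ... GROUP BY`, which has the same shape as the `tx` CTE below, and SQLite rejects it when the statement is prepared. A plain recursive CTE without an aggregate runs fine. The sketch only exercises SQLite and does not check DuckDB.

```python
import sqlite3

# Toy recursive CTE whose recursive step aggregates, like `tx` in the query below
RECURSIVE_AGGREGATE = """
WITH RECURSIVE t(step, total) AS (
    SELECT 1, 1
    UNION ALL
    SELECT step + 1, SUM(total) FROM t WHERE step < 3 GROUP BY step
)
SELECT * FROM t
"""

# The same recursion without an aggregate, for comparison
RECURSIVE_PLAIN = """
WITH RECURSIVE t(n) AS (
    SELECT 1
    UNION ALL
    SELECT n + 1 FROM t WHERE n < 3
)
SELECT * FROM t
"""


def sqlite_accepts(query):
    """Return True if SQLite runs the query, False if it rejects it."""
    con = sqlite3.connect(":memory:")
    try:
        con.execute(query).fetchall()
        return True
    except sqlite3.OperationalError as exc:
        print(f"SQLite rejected the query: {exc}")
        return False
    finally:
        con.close()


print(sqlite_accepts(RECURSIVE_PLAIN))
print(sqlite_accepts(RECURSIVE_AGGREGATE))
```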

Some more discussion on this topic can be found in a [separate
chapter](./A.1%20Aside%20-%20DuckDB%20bugreport).

## The query

The full query is given as follows:

```sql
WITH RECURSIVE input_values AS (
    SELECT ? AS input_set_id, ? AS input_node_idx, ? AS input_value
    UNION
    SELECT ? AS input_set_id, ? AS input_node_idx, ? AS input_value
),
input_nodes AS (
    SELECT
        id,
        bias,
        ROW_NUMBER() OVER (ORDER BY id) AS input_node_idx
    FROM node
    WHERE id NOT IN
    (SELECT dst FROM edge)
),
output_nodes AS (
    SELECT id
    FROM node
    WHERE id NOT IN
    (SELECT src FROM edge)
),
tx AS (
    -- Base case (t1)
    SELECT
        v.input_set_id AS input_set_id,
        GREATEST(
            0,
            n.bias + SUM(e.weight * v.input_value)
        ) AS value,
        e.dst AS id
    FROM edge e
    JOIN input_nodes i ON i.id = e.src
    JOIN node n ON e.dst = n.id
    JOIN input_values v ON i.input_node_idx = v.input_node_idx
    GROUP BY e.dst, n.bias, v.input_set_id

    UNION ALL

    -- Recursive case
    SELECT
        tx.input_set_id AS input_set_id,
        GREATEST(
            0,
            n.bias + SUM(e.weight * tx.value)
        ) AS value,
        e.dst AS id
    FROM edge e
    JOIN tx ON tx.id = e.src
    JOIN node n ON e.dst = n.id
    GROUP BY e.dst, n.bias, tx.input_set_id
),
-- As the last step, repeat the calculation for the output nodes, but omit the
-- ReLU this time (per definition)
t_out AS (
    SELECT
        tx.input_set_id AS input_set_id,
        n.bias + SUM(e.weight * tx.value) AS value,
        e.dst AS id
    FROM edge e
    JOIN output_nodes o ON e.dst = o.id
    JOIN node n ON o.id = n.id
    JOIN tx ON tx.id = e.src
    GROUP BY e.dst, n.bias, tx.input_set_id
)
SELECT * FROM t_out ORDER BY input_set_id, id;
```

This query is very similar to what we ended up with in the previous chapter. The
key differences are:

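- We use `WITH RECURSIVE` to indicate that some CTEs use recursion.
- We add a CTE, `output_nodes`, that fetches the set of output nodes: all nodes
  that have no outgoing edge.
- The recursive CTE in question is `tx`, which is split into a base case and a
  recursion step.
- The base case corresponds to the FO(SUM) term $t_1$.
- The recursion step corresponds to the term $t_l$: each iteration takes the
  rows produced by the previous iteration and computes the next layer from them.
- The recursion stops by itself. The last iteration produces values for the
  output nodes, and since those have no outgoing edges, the next iteration
  yields no rows. These output rows have the ReLU applied, but they are never
  used.
- When we calculate $t_{out}$, we join on the set of output nodes, since these
  are the only values we are ultimately interested in. $t_{out}$ recomputes them
  from the last hidden layer, this time without the ReLU.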
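The evaluation order of `tx` can be mimicked in plain Python to check the reasoning above. This is a minimal sketch, not the actual loader or database engine. The helpers `network_to_tables`, `eval_recursive` and `forward` are hypothetical. The sketch assumes a strictly layered network stored as `node(id, bias)` and `edge(src, dst, weight)`, the columns the query reads. It also assumes input-node biases can be set to 0, since the query never uses them. The base case aggregates over edges leaving the inputs. Each further step only sees the rows from the previous step, and the loop ends once a step produces nothing. The outputs are then recomputed without the ReLU and compared with a NumPy forward pass.

```python
from collections import defaultdict

import numpy as np


def network_to_tables(weights, biases):
    """Hypothetical helper: flatten layer matrices into node/edge tables.

    weights[l] has shape (out_features, in_features), like torch.nn.Linear.
    Returns node as {id: bias} and edge as a list of (src, dst, weight).
    """
    sizes = [weights[0].shape[1]] + [w.shape[0] for w in weights]
    offsets = np.cumsum([0] + sizes)
    node = {i: 0.0 for i in range(sizes[0])}  # assumption: input biases unused
    edge = []
    for l, (w, b) in enumerate(zip(weights, biases)):
        for j in range(w.shape[0]):
            dst = int(offsets[l + 1] + j)
            node[dst] = float(b[j])
            for i in range(w.shape[1]):
                edge.append((int(offsets[l] + i), dst, float(w[j, i])))
    return node, edge


def eval_recursive(node, edge, inputs):
    """Mimic the query: base case, then repeat the step on the newest rows only."""
    input_ids = sorted(set(node) - {d for _, d, _ in edge})   # input_nodes CTE
    output_ids = sorted(set(node) - {s for s, _, _ in edge})  # output_nodes CTE

    def step(values):
        # JOIN edge ON e.src, GROUP BY e.dst: ReLU(bias + SUM(weight * value))
        sums = defaultdict(float)
        for s, d, w in edge:
            if s in values:
                sums[d] += w * values[s]
        return {d: max(0.0, node[d] + total) for d, total in sums.items()}

    working = step(dict(zip(input_ids, inputs)))  # base case (t1)
    tx = dict(working)
    while working:  # recursive case; ends when no edges leave the newest rows
        working = step(working)
        tx.update(working)

    # t_out: output nodes recomputed from tx without the ReLU
    out = {o: node[o] for o in output_ids}
    for s, d, w in edge:
        if d in out and s in tx:
            out[d] += w * tx[s]
    return [out[o] for o in output_ids]


def forward(weights, biases, x):
    """Plain NumPy forward pass: ReLU on every layer except the last."""
    for l, (w, b) in enumerate(zip(weights, biases)):
        x = w @ x + b
        if l < len(weights) - 1:
            x = np.maximum(x, 0.0)
    return x


rng = np.random.default_rng(0)
sizes = [3, 5, 4, 2]
weights = [rng.normal(size=(o, i)) for i, o in zip(sizes, sizes[1:])]
biases = [rng.normal(size=o) for o in sizes[1:]]
node, edge = network_to_tables(weights, biases)
x = rng.normal(size=3)
print(np.allclose(eval_recursive(node, edge, x), forward(weights, biases, x)))  # → True
```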

Translating this query to code is straightforward. We copy it almost verbatim,
adding only the logic that generates the input values and naming the output
column `output_node_id` instead of `id`.

```python
def eval_nn(input_value):
    # One SELECT per input value: the set id and node index are inlined,
    # the value itself is passed as a query parameter.
    input_clauses = []
    for input_set_id, input_set in enumerate(input_value):
        for i, _ in enumerate(input_set):
            input_clauses.append(f"""
                SELECT
                    {input_set_id} AS input_set_id,
                    {i + 1} AS input_node_idx,
                    ? AS input_value
            """)

    query = f"""
        WITH RECURSIVE input_values AS (
            {' UNION '.join(input_clauses)}
        ),
        input_nodes AS (
            SELECT
                id,
                bias,
                ROW_NUMBER() OVER (ORDER BY id) AS input_node_idx
            FROM node
            WHERE id NOT IN
            (SELECT dst FROM edge)
        ),
        output_nodes AS (
            SELECT id
            FROM node
            WHERE id NOT IN
            (SELECT src FROM edge)
        ),
        tx AS (
            -- Base case (t1)
            SELECT
                v.input_set_id AS input_set_id,
                GREATEST(
                    0,
                    n.bias + SUM(e.weight * v.input_value)
                ) AS value,
                e.dst AS id
            FROM edge e
            JOIN input_nodes i ON i.id = e.src
            JOIN node n ON e.dst = n.id
            JOIN input_values v ON i.input_node_idx = v.input_node_idx
            GROUP BY e.dst, n.bias, v.input_set_id

            UNION ALL

            -- Recursive case
            SELECT
                tx.input_set_id AS input_set_id,
                GREATEST(
                    0,
                    n.bias + SUM(e.weight * tx.value)
                ) AS value,
                e.dst AS id
            FROM edge e
            JOIN tx ON tx.id = e.src
            JOIN node n ON e.dst = n.id
            GROUP BY e.dst, n.bias, tx.input_set_id
        ),
        -- As the last step, repeat the calculation for the output nodes, but omit the
        -- ReLU this time (per definition)
        t_out AS (
            SELECT
                tx.input_set_id AS input_set_id,
                n.bias + SUM(e.weight * tx.value) AS value,
                e.dst AS output_node_id
            FROM edge e
            JOIN output_nodes o ON e.dst = o.id
            JOIN node n ON o.id = n.id
            JOIN tx ON tx.id = e.src
            GROUP BY e.dst, n.bias, tx.input_set_id
        )
        SELECT * FROM t_out ORDER BY input_set_id, output_node_id;
    """

    # Flatten the input sets in the same order as the placeholders above
    args = [value for input_set in input_value for value in input_set]

    # Collect outputs per input set; rows arrive ordered by output node id
    results = [[] for _ in input_value]
    for input_set_id, output, _output_node_id in db.con.execute(query, args).fetchall():
        results[input_set_id].append(output)

    return np.array(results)
```
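Building one `SELECT` per input value works, but a single `VALUES` list is a more compact way to produce the same `input_values` rows. Below is a minimal sketch of that alternative. The helper `input_values_cte` is hypothetical, and the sketch is exercised against SQLite only because `sqlite3` ships with Python. It binds the set id and node index as parameters too, so no value is ever formatted into the SQL string. The statement's shape then depends only on the shape of the batch.

```python
import sqlite3


def input_values_cte(batch):
    """Hypothetical helper: build an input_values CTE and its flat parameter list.

    batch is a list of input sets (lists of floats). Node indices start at 1
    to line up with ROW_NUMBER() in the input_nodes CTE.
    """
    rows, args = [], []
    for set_id, values in enumerate(batch):
        for idx, value in enumerate(values, start=1):
            rows.append("(?, ?, ?)")
            args.extend([set_id, idx, value])
    sql = (
        "input_values(input_set_id, input_node_idx, input_value) AS "
        f"(VALUES {', '.join(rows)})"
    )
    return sql, args


cte, args = input_values_cte([[0.5, -1.0], [2.0, 3.0]])
con = sqlite3.connect(":memory:")
rows = con.execute(f"WITH {cte} SELECT * FROM input_values ORDER BY 1, 2", args).fetchall()
print(rows)
# → [(0, 1, 0.5), (0, 2, -1.0), (1, 1, 2.0), (1, 2, 3.0)]
```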

## The results

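Running the query on a random test input gives the same result as evaluating
the model directly in PyTorch. The two outputs below agree to roughly five
significant digits. The small remaining differences are presumably
floating-point rounding, since PyTorch evaluates the model in 32-bit floats.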

```python
test_input = torch.randn(1, input_size) * 100
nn_output = model(test_input).detach().numpy()
sql_output = eval_nn(test_input.tolist())

print(f"The neural network predicted {nn_output}")
print(f"The SQL query calculated {sql_output}")
```

```
The neural network predicted [[-0.73033047  4.7305093   0.5392649  -0.9155848  -1.2634848   3.1685772 ]]
The SQL query calculated [[-0.73033228  4.73051359  0.53926785 -0.91558373 -1.26348692  3.16857577]]
```
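Comparing the two outputs by eye works for six numbers, but a tolerance-based check is more reliable. This sketch copies the printed values verbatim and compares them with NumPy. An exact comparison fails because of the rounding differences. `np.allclose`, with its default tolerances (`rtol=1e-5`, `atol=1e-8`), accepts them.

```python
import numpy as np

# The two rows printed above, copied verbatim
nn_output = np.array([[-0.73033047, 4.7305093, 0.5392649, -0.9155848, -1.2634848, 3.1685772]])
sql_output = np.array([[-0.73033228, 4.73051359, 0.53926785, -0.91558373, -1.26348692, 3.16857577]])

print(np.array_equal(nn_output, sql_output))   # → False (not bit-for-bit equal)
print(np.allclose(nn_output, sql_output))      # → True (equal within tolerance)
print(np.max(np.abs(nn_output - sql_output)))  # largest absolute difference
```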


## Conclusion

We now have a way to evaluate a neural network in SQL without any prior
knowledge of its structure. Note that this chapter is mostly illustrative:
FO(SUM) does not allow recursion, so in this respect SQL is more expressive than
the logic we have been following.

In the following chapter we'll try to replicate this result using a graph
database.