# Graph databases

So far we've been using a relational database with SQL to construct our `eval`
queries. Since we're storing a directed acyclic graph, it might make more sense
to use a graph database. We'll take a look at [Neo4j](https://neo4j.com/) with
the [Cypher](https://neo4j.com/docs/cypher-manual/current/introduction/) query
language to see how they compare to our relational query.

## Database setup

To run this notebook, we'll need a Neo4j instance. Use the following Docker
command to start one:

```bash
docker run \
  --rm \
  -p 7474:7474 -p 7687:7687 \
  -e NEO4J_dbms_security_auth__enabled=false \
  neo4j:5.22.0
```

Once the container is running, the graphical interface is available at
[http://localhost:7474](http://localhost:7474).

This command does not persist any state, which is what we want for our
experimentation. Authentication is disabled for ease of use as well.

## Creating the model

As always, we first need a neural network to store and query. We'll reuse our
model with multiple input and output nodes.

In [7]:
import torch
import numpy as np
import utils.sqlite as db
import pandas as pd
import utils.nn as nn
from neo4j import GraphDatabase

torch.manual_seed(223)

def f(x, y):
    return [2*x, 4*y]

num_samples = 100
x_train = torch.randn(num_samples, 2) * 100
y_train = [f(x,y) for [x,y] in x_train]

model = nn.ReLUFNN(input_size=2, output_size=2, hidden_size=2, num_hidden_layers=1)
nn.train(model, x_train, y_train, save_path="models/eval_graph.pt")

## Creating the database

Now let's insert the model into Neo4j. The code below is a slight adaptation
of the [SQLite](./utils/sqlite.py) version.

In [8]:
driver = GraphDatabase.driver("neo4j://localhost")
driver.verify_connectivity()

# Delete any preexisting data.
driver.execute_query("MATCH (n) DETACH DELETE n;")

state_dict = model.state_dict()
input_weights = list(state_dict.items())[0][1].tolist()
num_input_nodes = len(input_weights[0])

node_ids = [[]]

def create_node(tx, label, bias):
    query = """
        CREATE (n:Node {bias: $bias, label: $label})
        RETURN elementId(n) AS nodeId
    """

    result = tx.run(query, bias=bias, label=label)
    return result.single()

def create_edge(tx, from_id, to_id, weight):
    query = """
        MATCH (a) WHERE elementId(a) = $from_id
        MATCH (b) WHERE elementId(b) = $to_id
        CREATE (a)-[r:EDGE {weight: $weight}]->(b)
    """
    result = tx.run(query, from_id=from_id, to_id=to_id, weight=weight)
    return result.single()

with driver.session() as session:
    # Insert input nodes.
    for i in range(0, num_input_nodes):
        result = session.execute_write(create_node, f"input.{i}", 0)
        node_ids[0].append(result['nodeId'])

    layer = 0
    # In the first pass, insert all nodes with their biases
    for name, values in state_dict.items():
        if "bias" not in name:
            continue

        node_ids.append([])
        layer += 1
        for i, bias in enumerate(values.tolist()):
            result = session.execute_write(create_node, f"{name}.{i}", bias)
            node_ids[layer].append(result['nodeId'])

    # In the second pass, insert all edges and their weights. This assumes a fully
    # connected network.
    layer = 0
    for name, values in state_dict.items():
        if "weight" not in name:
            continue

        # Each weight tensor has a list for each node in the next layer. The
        # elements of this list correspond to the nodes of the current layer.
        weight_tensor = values.tolist()
        for from_index, from_node in enumerate(node_ids[layer]):
            for to_index, to_node in enumerate(node_ids[layer + 1]):
                weight = weight_tensor[to_index][from_index]
                session.execute_write(create_edge, from_node, to_node, weight)

        layer += 1
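
The indexing convention used in the second pass can be illustrated in isolation. The following is a minimal sketch, assuming a hypothetical helper `edges_from_weights` and made-up weights (neither is part of the notebook's utilities). It lists the `(from_index, to_index, weight)` triples that the loop above would turn into `EDGE` relationships.

```python
def edges_from_weights(weight_tensor):
    """Enumerate (from_index, to_index, weight) for a fully connected layer.

    weight_tensor[to_index][from_index] holds the weight of the edge from
    node `from_index` in the current layer to node `to_index` in the next.
    """
    num_to = len(weight_tensor)
    num_from = len(weight_tensor[0]) if weight_tensor else 0
    return [
        (from_index, to_index, weight_tensor[to_index][from_index])
        for from_index in range(num_from)
        for to_index in range(num_to)
    ]


# Made-up weights for a layer with 2 source nodes and 3 target nodes.
example_weights = [
    [0.1, 0.2],
    [0.3, 0.4],
    [0.5, 0.6],
]
example_edges = edges_from_weights(example_weights)
print(example_edges)
```

Note that the row index selects the target node and the column index selects the source node, which is why the lookup reads `weight_tensor[to_index][from_index]`.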

If you head over to [localhost:7474](http://localhost:7474), you can run the
following query to fetch the network:

```cypher
MATCH (n:Node) RETURN n
```

The result will look something like this:

![Neo4j database as an image](./assets/neo4j_graph.png)

## Writing the Cypher query

To match the FO(SUM) language, our first SQL queries did not use any recursion.
Instead, we assumed the number of hidden layers was known beforehand and
explicitly added each layer to the query. We'll do the same for Cypher. This is
probably even necessary: Cypher offers variable-length path patterns, but
without extensions it seems to lack general recursive queries.

The full Cypher query to evaluate the neural network is given as follows:

```cypher
WITH [5, 20] AS inputValues

MATCH (inputNode:Node)
WHERE NOT (:Node)-[:EDGE]->(inputNode)

WITH inputValues, collect(inputNode) AS inputNodes
UNWIND range(0, size(inputValues) - 1) AS idx
WITH inputValues[idx] AS inputValue, inputNodes[idx] AS inputNode

MATCH (inputNode:Node)-[e:EDGE]->(t1:Node)
WITH
    t1,
    CASE 
        WHEN SUM(e.weight * inputValue) + t1.bias > 0
        THEN SUM(e.weight * inputValue) + t1.bias
        ELSE 0
    END AS value

MATCH (t1:Node)-[e:EDGE]->(t2:Node)
WITH
    t2,
    SUM(e.weight * value) + t2.bias AS value

RETURN t2, value
ORDER BY elementId(t2)
```

Let's break that down.

```cypher
WITH [5, 20] AS inputValues
```

This clause simply specifies the input values as a list.

```cypher
MATCH (inputNode:Node)
WHERE NOT (:Node)-[:EDGE]->(inputNode)
```

Here we define the set of input nodes as the nodes that do not have an incoming
edge. This shows how Cypher queries are visually expressive, especially compared
to their SQL counterparts.
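
For comparison, the same selection can be written out in plain Python. This is only an illustrative sketch on a made-up edge list; the node names and the helper `find_input_nodes` are hypothetical and do not appear elsewhere in this notebook.

```python
def find_input_nodes(nodes, edges):
    """Return the nodes that never appear as the target of an edge."""
    targets = {to_node for _, to_node in edges}
    return [node for node in nodes if node not in targets]


# A tiny made-up network: two inputs, two hidden nodes, one output.
toy_nodes = ["in0", "in1", "h0", "h1", "out0"]
toy_edges = [
    ("in0", "h0"), ("in0", "h1"),
    ("in1", "h0"), ("in1", "h1"),
    ("h0", "out0"), ("h1", "out0"),
]
toy_inputs = find_input_nodes(toy_nodes, toy_edges)
print(toy_inputs)
```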

```cypher
WITH inputValues, collect(inputNode) AS inputNodes
UNWIND range(0, size(inputValues) - 1) AS idx
WITH inputValues[idx] AS inputValue, inputNodes[idx] AS inputNode
```

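This part associates each input node with an index (cf. SQL's `ROW_NUMBER`).
With this index we can connect each input node to its respective input value,
which was defined at the start of the query. Note that the pairing relies on
the order in which `collect` receives the nodes, which Cypher does not
guarantee without an explicit `ORDER BY`.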

```cypher
MATCH (inputNode:Node)-[e:EDGE]->(t1:Node)
WITH
    t1,
    CASE 
        WHEN SUM(e.weight * inputValue) + t1.bias > 0
        THEN SUM(e.weight * inputValue) + t1.bias
        ELSE 0
    END AS value
```

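Now we can calculate the term $t_1$. Since `t1` acts as the grouping key of the
`WITH` clause, `SUM` aggregates over the incoming edges of each `t1` node. The
`CASE` expression represents the ReLU function.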
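
As a point of reference, the computation for a single layer can be sketched in plain Python. The weights, biases, and the helper `relu_layer` below are made up for illustration and are not the trained model's values.

```python
def relu_layer(inputs, weights, biases):
    """Compute ReLU(sum_i weights[j][i] * inputs[i] + biases[j]) per node j.

    weights[j][i] is the weight of the edge from input i to node j, matching
    the layout of the weight tensors stored in the database.
    """
    outputs = []
    for node_weights, bias in zip(weights, biases):
        total = sum(w * x for w, x in zip(node_weights, inputs)) + bias
        # Same rule as the CASE expression: keep positive values, else 0.
        outputs.append(total if total > 0 else 0)
    return outputs


# Made-up layer with two inputs and two nodes.
hidden_values = relu_layer([5, 20], [[1.0, -1.0], [0.5, 0.25]], [2.0, -1.0])
print(hidden_values)
```

Here the first node's weighted sum is negative and is clamped to zero, while the second node passes through unchanged.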

If the neural network has more layers, we have to repeatedly add similar
expressions to the query.

```cypher
MATCH (t1:Node)-[e:EDGE]->(t2:Node)
WITH
    t2,
    SUM(e.weight * value) + t2.bias AS value
```

Finally, by using the same expression with the ReLU omitted, we calculate the
output values.
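
Since every hidden layer adds the same pair of clauses, the query text could also be generated instead of written by hand. The following is a rough sketch of that idea, not part of the original notebook: the helper `build_eval_query` is hypothetical, it renames the input variables to `value` and `t0` so that all layers share one template, and it has not been run against a live database here.

```python
def build_eval_query(inputs, num_hidden_layers):
    """Build the layer-by-layer Cypher query for a fully connected ReLU network."""
    values = ", ".join(str(v) for v in inputs)
    lines = [
        f"WITH [{values}] AS inputValues",
        "MATCH (inputNode:Node)",
        "WHERE NOT (:Node)-[:EDGE]->(inputNode)",
        "WITH inputValues, collect(inputNode) AS inputNodes",
        "UNWIND range(0, size(inputValues) - 1) AS idx",
        # Assumption: renamed so the first layer fits the per-layer template.
        "WITH inputValues[idx] AS value, inputNodes[idx] AS t0",
    ]

    # One MATCH/WITH pair with ReLU for every hidden layer.
    for k in range(1, num_hidden_layers + 1):
        pre_activation = f"SUM(e.weight * value) + t{k}.bias"
        lines += [
            f"MATCH (t{k - 1}:Node)-[e:EDGE]->(t{k}:Node)",
            f"WITH t{k}, CASE WHEN {pre_activation} > 0 "
            f"THEN {pre_activation} ELSE 0 END AS value",
        ]

    # The output layer uses the same expression without the ReLU.
    out = num_hidden_layers + 1
    lines += [
        f"MATCH (t{out - 1}:Node)-[e:EDGE]->(t{out}:Node)",
        f"WITH t{out}, SUM(e.weight * value) + t{out}.bias AS value",
        f"RETURN t{out}, value",
        f"ORDER BY elementId(t{out})",
    ]
    return "\n".join(lines)


print(build_eval_query([5, 20], num_hidden_layers=1))
```

The number of layers still has to be known when the query is built, so this keeps the non-recursive structure described above.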

## Evaluation

Let's implement this in Python and compare the results.

In [11]:
def eval_cypher(inputs):
    query = f"""
        WITH [{', '.join([str(input) for input in inputs])}] AS inputValues

        MATCH (inputNode:Node)
        WHERE NOT (:Node)-[:EDGE]->(inputNode)

        WITH inputValues, collect(inputNode) AS inputNodes
        UNWIND range(0, size(inputValues) - 1) AS idx
        WITH inputValues[idx] AS inputValue, inputNodes[idx] AS inputNode

        MATCH (inputNode:Node)-[e:EDGE]->(t1:Node)
        WITH
            t1,
            CASE
                WHEN SUM(e.weight * inputValue) + t1.bias > 0
                THEN SUM(e.weight * inputValue) + t1.bias
                ELSE 0
            END AS value

        // If there are more layers, we need to manually add them here.

        MATCH (t1:Node)-[e:EDGE]->(t2:Node)
        WITH
            t2,
            SUM(e.weight * value) + t2.bias AS value

        RETURN t2, value
        ORDER BY elementId(t2)
    """

    results = driver.execute_query(query)[0]

    return [result['value'] for result in results]
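
Interpolating the inputs into the query string works for numeric values, but passing them as a query parameter is generally the safer habit. Below is a possible variant, assuming the `parameters_` argument of the Neo4j Python driver's `execute_query`. The names `PARAM_QUERY` and `eval_cypher_params` are hypothetical, and the driver is passed in explicitly so the function does not depend on a global.

```python
PARAM_QUERY = """
    WITH $inputValues AS inputValues

    MATCH (inputNode:Node)
    WHERE NOT (:Node)-[:EDGE]->(inputNode)

    WITH inputValues, collect(inputNode) AS inputNodes
    UNWIND range(0, size(inputValues) - 1) AS idx
    WITH inputValues[idx] AS inputValue, inputNodes[idx] AS inputNode

    MATCH (inputNode:Node)-[e:EDGE]->(t1:Node)
    WITH
        t1,
        CASE
            WHEN SUM(e.weight * inputValue) + t1.bias > 0
            THEN SUM(e.weight * inputValue) + t1.bias
            ELSE 0
        END AS value

    MATCH (t1:Node)-[e:EDGE]->(t2:Node)
    WITH
        t2,
        SUM(e.weight * value) + t2.bias AS value

    RETURN t2, value
    ORDER BY elementId(t2)
"""


def eval_cypher_params(driver, inputs):
    """Evaluate the network, sending the inputs as the $inputValues parameter."""
    records = driver.execute_query(
        PARAM_QUERY,
        parameters_={"inputValues": list(inputs)},
    )[0]
    return [record["value"] for record in records]
```

With the driver created earlier, this would be called as `eval_cypher_params(driver, [5, 20])`.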

In [12]:
nn_output = model(torch.tensor([5,20], dtype=torch.float32)).detach().numpy()
cypher_output = eval_cypher([5, 20])

print(f"The neural network predicted {nn_output}")
print(f"The Cypher query calculated {cypher_output}")

The neural network predicted [-24.508732  54.89358 ]
The Cypher query calculated [-24.508734581785873, 54.89358546175163]


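We can see that the Cypher query returns the same values as the model, up to
floating-point precision: the model computes in 32-bit floats, whereas Neo4j
stores and sums 64-bit floats.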
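
Because the two outputs agree only up to rounding, an exact equality check would fail. A tolerance-based comparison is one way to make the check explicit. The sketch below uses the values printed above; the helper `outputs_match` and the relative tolerance of `1e-6` are arbitrary choices for illustration.

```python
import math

# Values copied from the output printed above.
nn_values = [-24.508732, 54.89358]
cypher_values = [-24.508734581785873, 54.89358546175163]


def outputs_match(expected, actual, rel_tol=1e-6):
    """Check that two output vectors agree element-wise within rel_tol."""
    return len(expected) == len(actual) and all(
        math.isclose(e, a, rel_tol=rel_tol) for e, a in zip(expected, actual)
    )


print(outputs_match(nn_values, cypher_values))
```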

TODO: try out on larger model?

## Conclusion

As we've shown, it is possible to construct a similar `eval`-query in Cypher.
Where SQL needs a lot of joins and aggregations, Cypher queries are a lot more
elegant for graph structures.