In [None]:
# Make sure you have installed the custom GDS Client distributed with this notebook
from graphdatascience import GraphDataScience

In [None]:
# From the Aura Console, copy the connection URI of your Neo4j instance and paste it here
URI = "neo4j+s://<dbid>-mlruntimedev.databases.neo4j-dev.io"
# Paste the database password here
PASSWORD = ""
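
Pasting the password into the notebook makes it easy to share by accident. An alternative is to read it from an environment variable and fall back to an interactive prompt. This is a minimal sketch. `resolve_password` and the variable name `NEO4J_PASSWORD` are hypothetical choices, not something the notebook or the GDS client requires.

```python
import getpass
import os


def resolve_password(env_var: str = "NEO4J_PASSWORD") -> str:
    """Return the password from `env_var` if it is set, otherwise prompt for it.

    NEO4J_PASSWORD is a hypothetical variable name chosen for this sketch.
    """
    value = os.environ.get(env_var)
    if value:
        return value
    # getpass does not echo the typed characters
    return getpass.getpass(f"Neo4j password ({env_var} not set): ")


# Use it in place of the literal string above:
# PASSWORD = resolve_password()
```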

In [None]:
# The usual GDS client initialization
gds = GraphDataScience(URI, auth=("neo4j", PASSWORD))
gds.set_database("neo4j")

In [None]:
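# Load the Cora dataset into the GDS Graph Catalog
# The progress bar can behave erratically; don't worry about it
try:
    gds.graph.load_cora()
except Exception as e:
    # Don't abort the notebook on a progress-bar glitch, but do surface the error
    print(f"load_cora() raised {e!r}; check gds.graph.list() in the next cell")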

In [None]:
# The graph import is completed when this command returns a non-empty list
gds.graph.list()
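
You can poll `gds.graph.list()` instead of re-running the cell by hand until the import shows up. This is a minimal sketch. `wait_until` is a hypothetical helper, and the timeout and polling interval are arbitrary.

```python
import time
from typing import Callable


def wait_until(condition: Callable[[], bool], timeout_s: float = 300.0, interval_s: float = 5.0) -> None:
    """Call `condition` every `interval_s` seconds until it returns True.

    Raises TimeoutError if it is still False after `timeout_s` seconds.
    """
    deadline = time.monotonic() + timeout_s
    while not condition():
        if time.monotonic() >= deadline:
            raise TimeoutError(f"condition not met within {timeout_s} s")
        time.sleep(interval_s)


# Usage in this notebook (needs the live `gds` connection from above):
# wait_until(lambda: len(gds.graph.list()) > 0)
```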

# GNN training!

And now for the exciting stuff!
In the next cell, you will start a GNN training job.
Under the hood, the model being trained is a PyTorch Geometric (PyG) GraphSAGE model.
The call is asynchronous: it returns a job ID immediately (unless there's an unexpected error 😱).
The training itself takes a while, so you will have to wait for it to finish.

## Observing the training progress

You can observe the training progress by watching the logs, which the cell after the training cell does.
The log watcher doesn't stop on its own, so you will have to stop it manually.
Once you see the message 'Training Done', interrupt the cell and continue.
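
Interrupting a cell raises `KeyboardInterrupt`, and Jupyter reports it with a traceback. For a cleaner stop, you can wrap the watcher as in this sketch. `watch_quietly` is a hypothetical helper. It relies only on what is described above: that `watch_logs` keeps running until it is interrupted.

```python
from typing import Any, Callable


def watch_quietly(watch_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> bool:
    """Run a blocking log watcher and turn a manual interrupt into a clean return.

    Returns True if the watcher was interrupted, False if it returned on its own.
    """
    try:
        watch_fn(*args, **kwargs)
    except KeyboardInterrupt:
        print("Stopped watching logs.")
        return True
    return False


# Usage in this notebook:
# watch_quietly(gds.gnn.nodeClassification.watch_logs, job_id)
```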

## Graph and training parameters

`gds.gnn.nodeClassification.train` accepts the following parameters.

| Parameter          | Default  | Type      | Description                                                                                                                                                     |
|--------------------|----------|-----------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------|
| graph_name         | required | str       | The name of the graph in the GDS Graph Catalog to train on.                                                                                                     |
| model_name         | required | str       | The name of the model. Must be unique per database and username combination. Models cannot be deleted at this time.                                             |
| feature_properties | required | List[str] | The node properties to use as model features.                                                                                                                   |
| target_property    | required | str       | The node property that holds the target class values.                                                                                                           |
| node_labels        | None     | List[str] | The node labels to use for training. By default, all labels are used.                                                                                          |
| relationship_types | None     | List[str] | The relationship types to use for training. By default, all types are used.                                                                                     |
| target_node_label  | None     | str       | The label of the nodes to train on. Only nodes with this label need `target_property` defined; other nodes provide context. By default, all nodes are considered. |
| graph_sage_config  | None     | dict      | Configuration for the GraphSAGE training. See the GraphSAGE parameters below.                                                                                   |
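
Model names must be unique per database and user, and models cannot be deleted yet. Re-running training with the same `model_name` is therefore expected to fail. One workaround is to derive a fresh name for each run and pass the parameters by keyword. This is a sketch. `unique_model_name` is a hypothetical helper. The keyword names are taken from the table, on the assumption that `train` accepts them as keyword arguments.

```python
import uuid


def unique_model_name(prefix: str = "myModel") -> str:
    """Append a short random suffix so each training run gets a new model name."""
    return f"{prefix}-{uuid.uuid4().hex[:8]}"


model_name = unique_model_name()

# Parameters for the Cora node classification run, passed by keyword
train_kwargs = {
    "graph_name": "cora",
    "model_name": model_name,
    "feature_properties": ["features"],
    "target_property": "subject",
    "node_labels": ["Paper"],
    "relationship_types": ["CITES"],
    "target_node_label": "Paper",
    "graph_sage_config": {"hidden_channels": 128, "dropout": 0.3},
}

# job_id = gds.gnn.nodeClassification.train(**train_kwargs)
```

If you adopt this, pass `model_name` instead of `"myModel"` to the later `predict` and `getTrainResult` calls.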


## GraphSAGE parameters

The following parameters of the PyG GraphSAGE model are exposed through `graph_sage_config`.

| Parameter       | Default  | Description                                                                                                                                                                                                                                                                                                                                                                                                                                                   |
|-----------------|----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| layer_config    | {}       | Configuration of the GraphSAGE layers. It supports `aggr`, `normalize`, `root_weight`, `project`, `bias` from [this link](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SAGEConv.html). Additionally, you can provide message passing configuration from [this link](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.MessagePassing.html#torch_geometric.nn.conv.MessagePassing). |
| num_neighbors   | [25, 10] | Sample sizes for each layer. The length of this list is the number of layers used. All numbers must be >0.                                                                                                                                                                                                                                                                                                                                                    |
| dropout         | 0.5      | Probability of dropping out neurons during training. Must be between 0 and 1.                                                                                                                                                                                                                                                                                                                                                                                 |
| hidden_channels | 256      | The dimension of each hidden layer. Higher value means more expensive training, but higher level of representation. Must be >0.                                                                                                                                                                                                                                                                                                                               |
| learning_rate   | 0.003    | The learning rate. Must be >0.                                                                                                                                                                                                                                                                                                                                                                                                                                |

Feel free to experiment with any of these parameters.
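
You can check the constraints in the table locally before submitting a job, which catches out-of-range values without waiting for the server. This is a sketch. `make_graph_sage_config` is a hypothetical helper that fills in the defaults listed in the table. It does not validate the contents of `layer_config`.

```python
import copy
from typing import Any, Dict

# Defaults as listed in the GraphSAGE parameter table
GRAPH_SAGE_DEFAULTS: Dict[str, Any] = {
    "layer_config": {},
    "num_neighbors": [25, 10],
    "dropout": 0.5,
    "hidden_channels": 256,
    "learning_rate": 0.003,
}


def make_graph_sage_config(**overrides: Any) -> Dict[str, Any]:
    """Merge overrides into the documented defaults and check the documented ranges."""
    unknown = set(overrides) - set(GRAPH_SAGE_DEFAULTS)
    if unknown:
        raise ValueError(f"unknown GraphSAGE parameters: {sorted(unknown)}")
    # Deep copy so callers can't mutate the shared default list/dict
    config = copy.deepcopy(GRAPH_SAGE_DEFAULTS)
    config.update(overrides)

    neighbors = config["num_neighbors"]
    if len(neighbors) == 0 or any(n <= 0 for n in neighbors):
        raise ValueError("num_neighbors must be a non-empty list of values > 0")
    # The table says "between 0 and 1"; this sketch treats both ends as allowed
    if not 0 <= config["dropout"] <= 1:
        raise ValueError("dropout must be between 0 and 1")
    if config["hidden_channels"] <= 0:
        raise ValueError("hidden_channels must be > 0")
    if config["learning_rate"] <= 0:
        raise ValueError("learning_rate must be > 0")
    return config
```

For example, pass `graph_sage_config=make_graph_sage_config(num_neighbors=[10, 5], dropout=0.3)` to the training call.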


In [None]:
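# Let's train!
# relationship_types is passed by keyword, matching the parameter table, so ["CITES"]
# is not taken positionally as node_labels
job_id = gds.gnn.nodeClassification.train(
    "cora", "myModel", ["features"], "subject",
    relationship_types=["CITES"], target_node_label="Paper", node_labels=["Paper"],
)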

In [None]:
# And let's follow the progress by watching the logs
gds.gnn.nodeClassification.watch_logs(job_id)

In [None]:
# Once training has completed, we can retrieve the training result (metrics)
train_result = gds.run_cypher("RETURN gds.remoteml.getTrainResult('myModel')")

In [None]:
# And display it
train_result

# GNN prediction!

Wow, that was cool.
But training a model is only half the picture.
We also have to use it for something.
In this case, we will use it to predict the subject of papers in the Cora dataset.

Again, this call is asynchronous, so it returns immediately.
Observe the progress by watching the logs, as before.

Once prediction has completed, the predicted classes are added as a node property to the graph in the GDS Graph Catalog, as usual.
We can then retrieve the predictions by streaming that property from the graph.


In [None]:
# Let's trigger prediction!
job_id = gds.gnn.nodeClassification.predict("cora", "myModel", "myPredictions")

In [None]:
# And let's follow progress by watching the logs
gds.gnn.nodeClassification.watch_logs(job_id)

In [None]:
# Now that prediction is done, get a handle to the "cora" graph from the catalog
cora = gds.graph.get("cora")

In [None]:
# Now for some standard GDS stuff: streaming node properties from the graph
predictions = gds.graph.nodeProperties.stream(
    cora, node_properties=["features", "myPredictions"], separate_property_columns=True
)

In [None]:
# And displaying them
predictions
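
For a rough sense of prediction quality, you can also stream the target property and compare it with the predictions. This sketch makes two assumptions. First, `myPredictions` uses the same class encoding as `subject`. Second, you stream both properties, for example with `node_properties=["subject", "myPredictions"]` and `separate_property_columns=True`. The comparison covers every node, including those the model was trained on, so the resulting figure is optimistic rather than a held-out test score.

```python
from typing import Hashable, Sequence


def accuracy(targets: Sequence[Hashable], predictions: Sequence[Hashable]) -> float:
    """Fraction of positions where the predicted class equals the target class."""
    if len(targets) != len(predictions):
        raise ValueError("targets and predictions must have the same length")
    if len(targets) == 0:
        raise ValueError("cannot compute accuracy of an empty sequence")
    correct = sum(t == p for t, p in zip(targets, predictions))
    return correct / len(targets)


# Usage with a DataFrame `df` streamed as described above (hypothetical name):
# accuracy(list(df["subject"]), list(df["myPredictions"]))
```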

# And that's it!

Thank you very much for participating in the testing.
We hope you enjoyed it.
If you've run the notebook for the first time, now's the time to experiment with changing the graph, the training parameters, and so on.
For example, try a heterogeneous graph problem, check whether changing some parameter improves performance, or run training jobs in parallel on multiple databases.
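
Each `train` call returns a job ID immediately, so submitting several in a row starts several jobs. A small grid of GraphSAGE settings can therefore become one job per combination. This is a sketch. `build_sweep` is a hypothetical helper, and the grid values are arbitrary. Because model names cannot be reused, change `base_name` between sweeps.

```python
import itertools
from typing import Any, Dict, List


def build_sweep(base_name: str, grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Expand a parameter grid into job specs with a model name and graph_sage_config.

    Each combination gets its own model name, since names must be unique.
    """
    keys = sorted(grid)
    specs = []
    for i, values in enumerate(itertools.product(*(grid[k] for k in keys))):
        specs.append({
            "model_name": f"{base_name}-{i}",
            "graph_sage_config": dict(zip(keys, values)),
        })
    return specs


sweep = build_sweep("coraSweep", {"dropout": [0.3, 0.5], "hidden_channels": [128, 256]})

# Submitting the jobs (needs the live `gds` connection):
# job_ids = [
#     gds.gnn.nodeClassification.train(
#         "cora", spec["model_name"], ["features"], "subject",
#         relationship_types=["CITES"], target_node_label="Paper",
#         node_labels=["Paper"], graph_sage_config=spec["graph_sage_config"],
#     )
#     for spec in sweep
# ]
```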
If you feel you're done, please return to the Google Document and fill in our feedback form.

Thank you!