Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
## Bug fixes

* Fixed a bug which caused the auth token returned from the GDS Arrow Server to not be received correctly.
* Fixed a bug which didn't allow the user to specify `relationship_types` as a string in `gds.graph.relationshipProperties.stream()`.
* Fixed a bug in `kge-predict-transe-pyg-train.ipynb`: the notebook now uses the `gds.graph.relationshipProperty.stream()` call and correctly handles multiple relationships between the same pair of nodes. Issue ref: [#554](https://github.com/neo4j/graph-data-science-client/issues/554)

## Improvements

Expand Down
27 changes: 12 additions & 15 deletions doc/modules/ROOT/pages/tutorials/kge-predict-transe-pyg-train.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ version (2.5+ or later) installed.

Additionally, the following Python libraries are required:

* `graphdatascience`
(https://neo4j.com/docs/graph-data-science-client/current/installation/[see
documentation for installation instructions])
* PyG
(https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html[see
PyG documentation for installation instructions])
* `graphdatascience`,
https://neo4j.com/docs/graph-data-science-client/current/installation/[see
documentation for installation instructions]
* `pytorch-geometric` version >= 2.5.0,
https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html[see
PyG documentation for installation instructions]

== Setup

Expand Down Expand Up @@ -233,15 +233,13 @@ format it into a `Data` structure suitable for training with PyG.
[source, python, role=no-test]
----
def create_data_from_graph(relationship_type):
rels_tmp = gds.graph.relationshipProperties.stream(
ttv_G, ["rel_id"], relationship_type, separate_property_columns=True
)
rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type)
topology = [
rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),
rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),
]
edge_index = torch.tensor(topology, dtype=torch.long)
edge_type = torch.tensor(rels_tmp.rel_id.astype(int), dtype=torch.long)
edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long)
data = Data(edge_index=edge_index, edge_type=edge_type)
data.num_nodes = len(nodeId_to_id)
display(data)
Expand Down Expand Up @@ -303,7 +301,7 @@ def train_model_with_pyg():
head_index=data.edge_index[0],
rel_type=data.edge_type,
tail_index=data.edge_index[1],
batch_size=20000,
batch_size=1000,
k=10,
)

Expand All @@ -316,12 +314,11 @@ def train_model_with_pyg():
rank, hits = test(val_tensor_data)
print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, " f"Val Hits@10: {hits:.4f}")

print(model)
rank, hits_at_10 = test(test_tensor_data)
print(f"Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}")

torch.save(model, f"./model_{epoch_count}.pt")

mean_rank, mrr, hits_at_k = test(test_tensor_data)
print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}")

return model
----

Expand Down
19 changes: 8 additions & 11 deletions examples/kge-predict-transe-pyg-train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
"\n",
"Additionally, the following Python libraries are required:\n",
"\n",
"- `graphdatascience` ([see documentation for installation instructions](https://neo4j.com/docs/graph-data-science-client/current/installation/))\n",
"- PyG ([see PyG documentation for installation instructions](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html))\n",
"- `graphdatascience`, [see documentation for installation instructions](https://neo4j.com/docs/graph-data-science-client/current/installation/)\n",
"- `pytorch-geometric` version >= 2.5.0, [see PyG documentation for installation instructions](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html)\n",
"\n",
"## Setup\n",
"\n",
Expand Down Expand Up @@ -307,15 +307,13 @@
"outputs": [],
"source": [
"def create_data_from_graph(relationship_type):\n",
" rels_tmp = gds.graph.relationshipProperties.stream(\n",
" ttv_G, [\"rel_id\"], relationship_type, separate_property_columns=True\n",
" )\n",
" rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, \"rel_id\", relationship_type)\n",
" topology = [\n",
" rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),\n",
" rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),\n",
" ]\n",
" edge_index = torch.tensor(topology, dtype=torch.long)\n",
" edge_type = torch.tensor(rels_tmp.rel_id.astype(int), dtype=torch.long)\n",
" edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long)\n",
" data = Data(edge_index=edge_index, edge_type=edge_type)\n",
" data.num_nodes = len(nodeId_to_id)\n",
" display(data)\n",
Expand Down Expand Up @@ -398,7 +396,7 @@
" head_index=data.edge_index[0],\n",
" rel_type=data.edge_type,\n",
" tail_index=data.edge_index[1],\n",
" batch_size=20000,\n",
" batch_size=1000,\n",
" k=10,\n",
" )\n",
"\n",
Expand All @@ -411,12 +409,11 @@
" rank, hits = test(val_tensor_data)\n",
" print(f\"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, \" f\"Val Hits@10: {hits:.4f}\")\n",
"\n",
" print(model)\n",
" rank, hits_at_10 = test(test_tensor_data)\n",
" print(f\"Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}\")\n",
"\n",
" torch.save(model, f\"./model_{epoch_count}.pt\")\n",
"\n",
" mean_rank, mrr, hits_at_k = test(test_tensor_data)\n",
" print(f\"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}\")\n",
"\n",
" return model"
]
},
Expand Down
6 changes: 3 additions & 3 deletions graphdatascience/graph/base_graph_proc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from ..server_version.compatible_with import compatible_with
from ..server_version.server_version import ServerVersion
from .graph_entity_ops_runner import (
GraphElementPropertyRunner,
GraphLabelRunner,
GraphNodePropertiesRunner,
GraphNodePropertyRunner,
GraphPropertyRunner,
GraphRelationshipPropertiesRunner,
GraphRelationshipPropertyRunner,
GraphRelationshipRunner,
GraphRelationshipsRunner,
)
Expand Down Expand Up @@ -390,9 +390,9 @@ def nodeProperties(self) -> GraphNodePropertiesRunner:
return GraphNodePropertiesRunner(self._query_runner, self._namespace, self._server_version)

@property
def relationshipProperty(self) -> GraphElementPropertyRunner:
def relationshipProperty(self) -> GraphRelationshipPropertyRunner:
self._namespace += ".relationshipProperty"
return GraphElementPropertyRunner(self._query_runner, self._namespace, self._server_version)
return GraphRelationshipPropertyRunner(self._query_runner, self._namespace, self._server_version)

@property
def relationshipProperties(self) -> GraphRelationshipPropertiesRunner:
Expand Down
19 changes: 12 additions & 7 deletions graphdatascience/graph/graph_entity_ops_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,6 @@ def _handle_properties(
)


class GraphElementPropertyRunner(GraphEntityOpsBaseRunner):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `def relationshipProperty` is the only one that uses this runner..? 😮

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, so it was renamed, moved closer to `GraphRelationshipPropertiesRunner`, and now wraps the value into a list if a string is passed.

@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
def stream(self, G: Graph, node_properties: str, node_labels: Strings = ["*"], **config: Any) -> DataFrame:
self._namespace += ".stream"
return self._handle_properties(G, node_properties, node_labels, config)


class GraphNodePropertyRunner(GraphEntityOpsBaseRunner):
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
@filter_id_func_deprecation_warning()
Expand Down Expand Up @@ -197,6 +190,16 @@ def drop(self, G: Graph, node_properties: List[str], **config: Any) -> "Series[A
).squeeze()


class GraphRelationshipPropertyRunner(GraphEntityOpsBaseRunner):
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
def stream(
self, G: Graph, relationship_property: str, relationship_types: Strings = ["*"], **config: Any
) -> DataFrame:
self._namespace += ".stream"
relationship_types = [relationship_types] if isinstance(relationship_types, str) else relationship_types
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest doing this transformation as part of `_handle_properties`.
In our signature, we have Strings here, but these procedures do not always handle both str and list of str.
Such as https://neo4j.com/docs/graph-data-science/current/management-ops/graph-reads/graph-stream-relationships/#_syntax.
For nodeLabels we do support both str and list of str

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created a card to fix this inconsistency in GDS as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I ran an experiment where I moved this list-wrapping into `_handle_properties`, and it breaks the `nodeProperty.stream` and `relationshipProperty.stream` cases. I also think we won't need this wrapping once the APIs are aligned after that card is done.
So I prefer to leave it as it is now and merge.

return self._handle_properties(G, relationship_property, relationship_types, config)


class GraphRelationshipPropertiesRunner(GraphEntityOpsBaseRunner):
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
def stream(
Expand All @@ -209,6 +212,8 @@ def stream(
) -> DataFrame:
self._namespace += ".stream"

relationship_types = [relationship_types] if isinstance(relationship_types, str) else relationship_types

result = self._handle_properties(G, relationship_properties, relationship_types, config)

# new format was requested, but the query was run via Cypher
Expand Down
40 changes: 40 additions & 0 deletions graphdatascience/tests/integration/test_graph_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,46 @@ def test_graph_relationshipProperties_stream_with_arrow_separate_property_column
assert {e for e in result["relY"]} == {5, 6, 7}


@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
def test_graph_relationshipProperties_stream_with_arrow_rel_as_str(gds: GraphDataScience) -> None:
    """Stream relationship properties with the relationship type given as a plain string.

    The relationship type is passed as ``"REL"`` rather than ``["REL"]``. The
    result is expected to expose the ``relationshipProperty`` /
    ``propertyValue`` columns, with both projected properties present.
    """
    G, _ = gds.graph.project(GRAPH_NAME, "*", {"REL": {"properties": ["relX", "relY"]}})

    streamed = gds.graph.relationshipProperties.stream(G, ["relX", "relY"], "REL", concurrency=2)

    expected_columns = [
        "sourceNodeId",
        "targetNodeId",
        "relationshipType",
        "relationshipProperty",
        "propertyValue",
    ]
    assert list(streamed.keys()) == expected_columns

    # Each property is selected by name from the relationshipProperty column.
    rel_x_rows = streamed[streamed.relationshipProperty == "relX"]
    assert set(rel_x_rows["propertyValue"]) == {4, 5, 6}
    rel_y_rows = streamed[streamed.relationshipProperty == "relY"]
    assert set(rel_y_rows["propertyValue"]) == {5, 6, 7}


@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
def test_graph_relationshipProperties_stream_with_arrow_rel_as_str_sep(gds: GraphDataScience) -> None:
    """Stream relationship properties with a string relationship type and separate columns.

    The relationship type is passed as ``"REL"`` rather than ``["REL"]``, together
    with ``separate_property_columns=True``. The result is expected to have one
    column per requested property.
    """
    G, _ = gds.graph.project(GRAPH_NAME, "*", {"REL": {"properties": ["relX", "relY"]}})

    streamed = gds.graph.relationshipProperties.stream(
        G,
        ["relX", "relY"],
        "REL",
        separate_property_columns=True,
        concurrency=2,
    )

    expected_columns = [
        "sourceNodeId",
        "targetNodeId",
        "relationshipType",
        "relX",
        "relY",
    ]
    assert list(streamed.keys()) == expected_columns

    assert set(streamed["relX"]) == {4, 5, 6}
    assert set(streamed["relY"]) == {5, 6, 7}


def test_graph_streamRelationshipProperties_without_arrow(gds_without_arrow: GraphDataScience) -> None:
G, _ = gds_without_arrow.graph.project(GRAPH_NAME, "*", {"REL": {"properties": ["relX", "relY"]}})

Expand Down
4 changes: 2 additions & 2 deletions graphdatascience/tests/unit/test_graph_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def test_graph_relationshipProperty_stream(runner: CollectingQueryRunner, gds: G
assert runner.last_params() == {
"graph_name": "g",
"properties": "dummyProp",
"entities": "dummyType",
"entities": ["dummyType"],
"config": {"concurrency": 2},
}

Expand Down Expand Up @@ -390,7 +390,7 @@ def test_graph_relationshipProperties_stream(runner: CollectingQueryRunner, gds:
assert runner.last_params() == {
"graph_name": "g",
"properties": ["dummyProp"],
"entities": "dummyType",
"entities": ["dummyType"],
"config": {"concurrency": 2},
}

Expand Down
2 changes: 1 addition & 1 deletion requirements/dev/notebook-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ scipy == 1.10.1
torch==2.1.0
torch-scatter==2.1.1
torch-sparse==0.6.17
torch-geometric==2.3.1
torch-geometric>=2.5.0