diff --git a/changelog.md b/changelog.md index f8ca9d876..beaf8a2f0 100644 --- a/changelog.md +++ b/changelog.md @@ -17,6 +17,8 @@ ## Bug fixes * Fixed a bug which caused the auth token returned from the GDS Arrow Server was not correctly received. +* Fixed a bug which didn't allow the user to specify `relationship_types` as a string in `gds.graph.relationshipProperties.stream()`. +* Fixed a bug in `kge-predict-transe-pyg-train.ipynb`: the notebook now uses the `gds.graph.relationshipProperty.stream()` call and correctly handles multiple relationships between the same pair of nodes. Issue ref: [#554](https://github.com/neo4j/graph-data-science-client/issues/554) ## Improvements diff --git a/doc/modules/ROOT/pages/tutorials/kge-predict-transe-pyg-train.adoc b/doc/modules/ROOT/pages/tutorials/kge-predict-transe-pyg-train.adoc index 564c2ba5a..08e80ea1a 100644 --- a/doc/modules/ROOT/pages/tutorials/kge-predict-transe-pyg-train.adoc +++ b/doc/modules/ROOT/pages/tutorials/kge-predict-transe-pyg-train.adoc @@ -29,12 +29,12 @@ version (2.5+ or later) installed. Additionally, the following Python libraries are required: -* `graphdatascience` -(https://neo4j.com/docs/graph-data-science-client/current/installation/[see -documentation for installation instructions]) -* PyG -(https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html[see -PyG documentation for installation instructions]) +* `graphdatascience`, +https://neo4j.com/docs/graph-data-science-client/current/installation/[see +documentation for installation instructions] +* `pytorch-geometric` version >= 2.5.0, +https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html[see +PyG documentation for installation instructions] == Setup @@ -233,15 +233,13 @@ format it into a `Data` structure suitable for training with PyG. 
[source, python, role=no-test] ---- def create_data_from_graph(relationship_type): - rels_tmp = gds.graph.relationshipProperties.stream( - ttv_G, ["rel_id"], relationship_type, separate_property_columns=True - ) + rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type) topology = [ rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]), rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]), ] edge_index = torch.tensor(topology, dtype=torch.long) - edge_type = torch.tensor(rels_tmp.rel_id.astype(int), dtype=torch.long) + edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long) data = Data(edge_index=edge_index, edge_type=edge_type) data.num_nodes = len(nodeId_to_id) display(data) @@ -303,7 +301,7 @@ def train_model_with_pyg(): head_index=data.edge_index[0], rel_type=data.edge_type, tail_index=data.edge_index[1], - batch_size=20000, + batch_size=1000, k=10, ) @@ -316,12 +314,11 @@ def train_model_with_pyg(): rank, hits = test(val_tensor_data) print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, " f"Val Hits@10: {hits:.4f}") - print(model) - rank, hits_at_10 = test(test_tensor_data) - print(f"Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}") - torch.save(model, f"./model_{epoch_count}.pt") + mean_rank, mrr, hits_at_k = test(test_tensor_data) + print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}") + return model ---- diff --git a/examples/kge-predict-transe-pyg-train.ipynb b/examples/kge-predict-transe-pyg-train.ipynb index ebfb23189..2733425be 100644 --- a/examples/kge-predict-transe-pyg-train.ipynb +++ b/examples/kge-predict-transe-pyg-train.ipynb @@ -37,8 +37,8 @@ "\n", "Additionally, the following Python libraries are required:\n", "\n", - "- `graphdatascience` ([see documentation for installation instructions](https://neo4j.com/docs/graph-data-science-client/current/installation/))\n", - "- PyG ([see PyG documentation for installation 
instructions](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html))\n", + "- `graphdatascience`, [see documentation for installation instructions](https://neo4j.com/docs/graph-data-science-client/current/installation/)\n", + "- `pytorch-geometric` version >= 2.5.0, [see PyG documentation for installation instructions](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html)\n", "\n", "## Setup\n", "\n", @@ -307,15 +307,13 @@ "outputs": [], "source": [ "def create_data_from_graph(relationship_type):\n", - " rels_tmp = gds.graph.relationshipProperties.stream(\n", - " ttv_G, [\"rel_id\"], relationship_type, separate_property_columns=True\n", - " )\n", + " rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, \"rel_id\", relationship_type)\n", " topology = [\n", " rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),\n", " rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),\n", " ]\n", " edge_index = torch.tensor(topology, dtype=torch.long)\n", - " edge_type = torch.tensor(rels_tmp.rel_id.astype(int), dtype=torch.long)\n", + " edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long)\n", " data = Data(edge_index=edge_index, edge_type=edge_type)\n", " data.num_nodes = len(nodeId_to_id)\n", " display(data)\n", @@ -398,7 +396,7 @@ " head_index=data.edge_index[0],\n", " rel_type=data.edge_type,\n", " tail_index=data.edge_index[1],\n", - " batch_size=20000,\n", + " batch_size=1000,\n", " k=10,\n", " )\n", "\n", @@ -411,12 +409,11 @@ " rank, hits = test(val_tensor_data)\n", " print(f\"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, \" f\"Val Hits@10: {hits:.4f}\")\n", "\n", - " print(model)\n", - " rank, hits_at_10 = test(test_tensor_data)\n", - " print(f\"Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}\")\n", - "\n", " torch.save(model, f\"./model_{epoch_count}.pt\")\n", "\n", + " mean_rank, mrr, hits_at_k = test(test_tensor_data)\n", + " print(f\"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: 
{hits_at_k:.4f}, MRR: {mrr:.4f}\")\n", + "\n", " return model" ] }, diff --git a/graphdatascience/graph/base_graph_proc_runner.py b/graphdatascience/graph/base_graph_proc_runner.py index 6d5531580..131540a2b 100644 --- a/graphdatascience/graph/base_graph_proc_runner.py +++ b/graphdatascience/graph/base_graph_proc_runner.py @@ -15,12 +15,12 @@ from ..server_version.compatible_with import compatible_with from ..server_version.server_version import ServerVersion from .graph_entity_ops_runner import ( - GraphElementPropertyRunner, GraphLabelRunner, GraphNodePropertiesRunner, GraphNodePropertyRunner, GraphPropertyRunner, GraphRelationshipPropertiesRunner, + GraphRelationshipPropertyRunner, GraphRelationshipRunner, GraphRelationshipsRunner, ) @@ -390,9 +390,9 @@ def nodeProperties(self) -> GraphNodePropertiesRunner: return GraphNodePropertiesRunner(self._query_runner, self._namespace, self._server_version) @property - def relationshipProperty(self) -> GraphElementPropertyRunner: + def relationshipProperty(self) -> GraphRelationshipPropertyRunner: self._namespace += ".relationshipProperty" - return GraphElementPropertyRunner(self._query_runner, self._namespace, self._server_version) + return GraphRelationshipPropertyRunner(self._query_runner, self._namespace, self._server_version) @property def relationshipProperties(self) -> GraphRelationshipPropertiesRunner: diff --git a/graphdatascience/graph/graph_entity_ops_runner.py b/graphdatascience/graph/graph_entity_ops_runner.py index fd3a79716..254973a72 100644 --- a/graphdatascience/graph/graph_entity_ops_runner.py +++ b/graphdatascience/graph/graph_entity_ops_runner.py @@ -70,13 +70,6 @@ def _handle_properties( ) -class GraphElementPropertyRunner(GraphEntityOpsBaseRunner): - @compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0)) - def stream(self, G: Graph, node_properties: str, node_labels: Strings = ["*"], **config: Any) -> DataFrame: - self._namespace += ".stream" - return self._handle_properties(G, 
node_properties, node_labels, config) - - class GraphNodePropertyRunner(GraphEntityOpsBaseRunner): @compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0)) @filter_id_func_deprecation_warning() @@ -197,6 +190,16 @@ def drop(self, G: Graph, node_properties: List[str], **config: Any) -> "Series[A ).squeeze() +class GraphRelationshipPropertyRunner(GraphEntityOpsBaseRunner): + @compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0)) + def stream( + self, G: Graph, relationship_property: str, relationship_types: Strings = ["*"], **config: Any + ) -> DataFrame: + self._namespace += ".stream" + relationship_types = [relationship_types] if isinstance(relationship_types, str) else relationship_types + return self._handle_properties(G, relationship_property, relationship_types, config) + + class GraphRelationshipPropertiesRunner(GraphEntityOpsBaseRunner): @compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0)) def stream( @@ -209,6 +212,8 @@ def stream( ) -> DataFrame: self._namespace += ".stream" + relationship_types = [relationship_types] if isinstance(relationship_types, str) else relationship_types + result = self._handle_properties(G, relationship_properties, relationship_types, config) # new format was requested, but the query was run via Cypher diff --git a/graphdatascience/tests/integration/test_graph_ops.py b/graphdatascience/tests/integration/test_graph_ops.py index beb0c86ed..858c4c4a3 100644 --- a/graphdatascience/tests/integration/test_graph_ops.py +++ b/graphdatascience/tests/integration/test_graph_ops.py @@ -695,6 +695,46 @@ def test_graph_relationshipProperties_stream_with_arrow_separate_property_column assert {e for e in result["relY"]} == {5, 6, 7} +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0)) +def test_graph_relationshipProperties_stream_with_arrow_rel_as_str(gds: GraphDataScience) -> None: + G, _ = gds.graph.project(GRAPH_NAME, "*", {"REL": {"properties": ["relX", "relY"]}}) + + result = 
gds.graph.relationshipProperties.stream(G, ["relX", "relY"], "REL", concurrency=2) + + assert list(result.keys()) == [ + "sourceNodeId", + "targetNodeId", + "relationshipType", + "relationshipProperty", + "propertyValue", + ] + + x_values = result[result.relationshipProperty == "relX"] + assert {e for e in x_values["propertyValue"]} == {4, 5, 6} + y_values = result[result.relationshipProperty == "relY"] + assert {e for e in y_values["propertyValue"]} == {5, 6, 7} + + +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0)) +def test_graph_relationshipProperties_stream_with_arrow_rel_as_str_sep(gds: GraphDataScience) -> None: + G, _ = gds.graph.project(GRAPH_NAME, "*", {"REL": {"properties": ["relX", "relY"]}}) + + result = gds.graph.relationshipProperties.stream( + G, ["relX", "relY"], "REL", separate_property_columns=True, concurrency=2 + ) + + assert list(result.keys()) == [ + "sourceNodeId", + "targetNodeId", + "relationshipType", + "relX", + "relY", + ] + + assert {e for e in result["relX"]} == {4, 5, 6} + assert {e for e in result["relY"]} == {5, 6, 7} + + def test_graph_streamRelationshipProperties_without_arrow(gds_without_arrow: GraphDataScience) -> None: G, _ = gds_without_arrow.graph.project(GRAPH_NAME, "*", {"REL": {"properties": ["relX", "relY"]}}) diff --git a/graphdatascience/tests/unit/test_graph_ops.py b/graphdatascience/tests/unit/test_graph_ops.py index eaeb7bf57..62a033a59 100644 --- a/graphdatascience/tests/unit/test_graph_ops.py +++ b/graphdatascience/tests/unit/test_graph_ops.py @@ -305,7 +305,7 @@ def test_graph_relationshipProperty_stream(runner: CollectingQueryRunner, gds: G assert runner.last_params() == { "graph_name": "g", "properties": "dummyProp", - "entities": "dummyType", + "entities": ["dummyType"], "config": {"concurrency": 2}, } @@ -390,7 +390,7 @@ def test_graph_relationshipProperties_stream(runner: CollectingQueryRunner, gds: assert runner.last_params() == { "graph_name": "g", "properties": ["dummyProp"], - 
"entities": "dummyType", + "entities": ["dummyType"], "config": {"concurrency": 2}, } diff --git a/requirements/dev/notebook-ci.txt b/requirements/dev/notebook-ci.txt index eca3d8cad..71dd39079 100644 --- a/requirements/dev/notebook-ci.txt +++ b/requirements/dev/notebook-ci.txt @@ -6,4 +6,4 @@ scipy == 1.10.1 torch==2.1.0 torch-scatter==2.1.1 torch-sparse==0.6.17 -torch-geometric==2.3.1 +torch-geometric>=2.5.0