Skip to content

Commit

Permalink
create rels on array properties
Browse files Browse the repository at this point in the history
  • Loading branch information
mpreusse committed Oct 25, 2022
1 parent 8e6fba4 commit 8646849
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 17 deletions.
3 changes: 2 additions & 1 deletion graphio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from graphio.objects.nodeset import NodeSet
from graphio.objects.relationshipset import RelationshipSet
from graphio.objects.datacontainer import Container
from graphio.model import ModelNode, ModelRelationship, MergeKey, Label, NodeDescriptor
from graphio.model import ModelNode, ModelRelationship, MergeKey, Label, NodeDescriptor
from graphio.objects.properties import ArrayProperty
11 changes: 11 additions & 0 deletions graphio/objects/properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class Property:
def __init__(self, key: str):
self.key = key

def __str__(self):
return self.key


class ArrayProperty(Property):
def __init__(self, key: str):
super(ArrayProperty, self).__init__(key)
12 changes: 4 additions & 8 deletions graphio/objects/relationshipset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,11 @@ def __init__(self, rel_type, start_node_labels, end_node_labels, start_node_prop
"""
:param rel_type: Realtionship type.
:type rel_type: str
:param start_node_labels: Labels of the start node.
:type start_node_labels: list[str]
:param end_node_labels: Labels of the end node.
:type end_node_labels: list[str]
:param start_node_properties: Property keys to identify the start node.
:type start_node_properties: list[str]
:param end_node_properties: Properties to identify the end node.
:type end_node_properties: list[str]
:param batch_size: Batch size for Neo4j operations.
:type batch_size: int
"""

self.rel_type = rel_type
Expand All @@ -77,8 +71,8 @@ def __init__(self, rel_type, start_node_labels, end_node_labels, start_node_prop
self.combined = '{0}_{1}_{2}_{3}_{4}'.format(self.rel_type,
'_'.join(sorted(self.start_node_labels)),
'_'.join(sorted(self.end_node_labels)),
'_'.join(sorted(self.start_node_properties)),
'_'.join(sorted(self.end_node_properties))
'_'.join(sorted([str(x) for x in self.start_node_properties])),
'_'.join(sorted([str(x) for x in self.end_node_properties]))
)

if batch_size:
Expand Down Expand Up @@ -472,6 +466,7 @@ def create(self, graph, database=None, batch_size=None):

# iterate over chunks of rels
q = rels_create_unwind(self.start_node_labels, self.end_node_labels, self.start_node_properties, self.end_node_properties, self.rel_type)
print(q)
for batch in chunks(self.relationships, size=batch_size):
query_parameters = rels_params_from_objects(batch)
run_query_return_results(graph, q, database=database, **query_parameters)
Expand All @@ -489,6 +484,7 @@ def merge(self, graph, database=None, batch_size=None):
# iterate over chunks of rels
q = rels_merge_unwind(self.start_node_labels, self.end_node_labels, self.start_node_properties,
self.end_node_properties, self.rel_type)
print(q)
for batch in chunks(self.relationships, size=batch_size):
query_parameters = rels_params_from_objects(batch)
run_query_return_results(graph, q, database=database, **query_parameters)
Expand Down
22 changes: 17 additions & 5 deletions graphio/queries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List

from graphio.objects.properties import ArrayProperty

class CypherQuery:

Expand Down Expand Up @@ -287,9 +287,15 @@ def rels_create_unwind(start_node_labels, end_node_labels, start_node_properties
# collect WHERE clauses
where_clauses = []
for property in start_node_properties:
where_clauses.append('a.{0} = rel.start_{0}'.format(property))
if isinstance(property, ArrayProperty):
where_clauses.append(f'rel.start_{property} IN a.{property}')
else:
where_clauses.append('a.{0} = rel.start_{0}'.format(property))
for property in end_node_properties:
where_clauses.append('b.{0} = rel.end_{0}'.format(property))
if isinstance(property, ArrayProperty):
where_clauses.append(f'rel.end_{property} IN b.{property}')
else:
where_clauses.append('b.{0} = rel.end_{0}'.format(property))

q.append("WHERE " + ' AND '.join(where_clauses))

Expand Down Expand Up @@ -336,9 +342,15 @@ def rels_merge_unwind(start_node_labels, end_node_labels, start_node_properties,
# collect WHERE clauses
where_clauses = []
for property in start_node_properties:
where_clauses.append('a.{0} = rel.start_{0}'.format(property))
if isinstance(property, ArrayProperty):
where_clauses.append(f'rel.start_{property} IN a.{property}')
else:
where_clauses.append('a.{0} = rel.start_{0}'.format(property))
for property in end_node_properties:
where_clauses.append('b.{0} = rel.end_{0}'.format(property))
if isinstance(property, ArrayProperty):
where_clauses.append(f'rel.end_{property} IN b.{property}')
else:
where_clauses.append('b.{0} = rel.end_{0}'.format(property))

q.append("WHERE " + ' AND '.join(where_clauses))

Expand Down
72 changes: 69 additions & 3 deletions test/test_objects/test_relationshipset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from graphio.objects.nodeset import NodeSet
from graphio.objects.relationshipset import RelationshipSet, tuplify_json_list
from graphio.objects.properties import ArrayProperty
from graphio.graph import run_query_return_results


Expand Down Expand Up @@ -52,9 +53,9 @@ def create_nodes_test(graph, clear_graph):
ns3 = NodeSet(['Bar'], merge_keys=['uuid', 'key'])

for i in range(100):
ns1.add_node({'uuid': i})
ns2.add_node({'uuid': i})
ns3.add_node({'uuid': i, 'key': i})
ns1.add_node({'uuid': i, 'array_key': [i, 9999, 99999]})
ns2.add_node({'uuid': i, 'array_key': [i, 7777, 77777]})
ns3.add_node({'uuid': i, 'key': i, 'array_key': [i, 6666, 66666]})

ns1.create(graph)
ns2.create(graph)
Expand Down Expand Up @@ -185,6 +186,32 @@ def test_relationshipset_create_mulitple_node_props(self, graph, create_nodes_te

assert result[0][0] == 100

def test_relationshipset_create_array_props(self, graph, create_nodes_test):

rs = RelationshipSet('TEST_ARRAY', ['Test'], ['Foo'], [ArrayProperty('array_key')], [ArrayProperty('array_key')])

for i in range(100):
rs.add_relationship({'array_key': i}, {'array_key': i})

rs.create(graph)

result = run_query_return_results(graph, "MATCH (t:Test)-[r:TEST_ARRAY]->(f:Foo) RETURN count(r)")

assert result[0][0] == 100

def test_relationshipset_create_string_and_array_props(self, graph, create_nodes_test):

rs = RelationshipSet('TEST_ARRAY', ['Test'], ['Foo'], [ArrayProperty('array_key')], [ArrayProperty('array_key')])

for i in range(100):
rs.add_relationship({'uuid': i, 'array_key': i}, {'uuid': i, 'array_key': i})

rs.create(graph)

result = run_query_return_results(graph, "MATCH (t:Test)-[r:TEST_ARRAY]->(f:Foo) RETURN count(r)")

assert result[0][0] == 100


class TestRelationshipSetIndex:
def test_relationship_create_single_index(self, graph, clear_graph, small_relationshipset):
Expand Down Expand Up @@ -221,6 +248,45 @@ def test_relationshipset_merge_number(self, graph, create_nodes_test, small_rela

assert result[0][0] == 100

def test_relationshipset_merge_array_props(self, graph, create_nodes_test):

rs = RelationshipSet('TEST_ARRAY', ['Test'], ['Foo'], [ArrayProperty('array_key')], [ArrayProperty('array_key')])

for i in range(100):
rs.add_relationship({'array_key': i}, {'array_key': i})

rs.merge(graph)

result = run_query_return_results(graph, "MATCH (t:Test)-[r:TEST_ARRAY]->(f:Foo) RETURN count(r)")

assert result[0][0] == 100

# merge again
rs.merge(graph)

result = run_query_return_results(graph, "MATCH (t:Test)-[r:TEST_ARRAY]->(f:Foo) RETURN count(r)")

assert result[0][0] == 100

def test_relationshipset_merge_string_and_array_props(self, graph, create_nodes_test):

rs = RelationshipSet('TEST_ARRAY', ['Test'], ['Foo'], [ArrayProperty('array_key')], [ArrayProperty('array_key')])

for i in range(100):
rs.add_relationship({'uuid': i, 'array_key': i}, {'uuid': i, 'array_key': i})

rs.merge(graph)

result = run_query_return_results(graph, "MATCH (t:Test)-[r:TEST_ARRAY]->(f:Foo) RETURN count(r)")

assert result[0][0] == 100

# run again
rs.merge(graph)

result = run_query_return_results(graph, "MATCH (t:Test)-[r:TEST_ARRAY]->(f:Foo) RETURN count(r)")

assert result[0][0] == 100

class TestRelationshipSetSerialize:

Expand Down

0 comments on commit 8646849

Please sign in to comment.