Skip to content

Commit

Permalink
relationship source
Browse files Browse the repository at this point in the history
  • Loading branch information
mpreusse committed Apr 4, 2023
1 parent 8b9f6c1 commit fbbd5a6
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 20 deletions.
17 changes: 9 additions & 8 deletions graphio/objects/relationshipset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from graphio import defaults
from graphio.helper import chunks, create_single_index, create_composite_index
from graphio.queries import rels_create_unwind, rels_merge_unwind, rels_params_from_objects
from graphio.queries import rels_create_factory, rels_merge_factory, rels_params_from_objects
from graphio.graph import run_query_return_results

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -80,7 +80,7 @@ class RelationshipSet:
"""

def __init__(self, rel_type, start_node_labels, end_node_labels, start_node_properties, end_node_properties,
batch_size=None, default_props=None):
batch_size=None, default_props=None, source=False):
"""
:param rel_type: Realtionship type.
Expand All @@ -101,6 +101,7 @@ def __init__(self, rel_type, start_node_labels, end_node_labels, start_node_prop
self.start_node_properties = start_node_properties
self.end_node_properties = end_node_properties
self.default_props = default_props
self.source = source

self.fixed_order_start_node_properties = tuple(self.start_node_properties)
self.fixed_order_end_node_properties = tuple(self.end_node_properties)
Expand Down Expand Up @@ -478,11 +479,11 @@ def create(self, graph, database=None, batch_size=None):
log.debug('Batch Size: {}'.format(batch_size))

# iterate over chunks of rels
q = rels_create_unwind(self.start_node_labels, self.end_node_labels, self.start_node_properties,
self.end_node_properties, self.rel_type)
q = rels_create_factory(self.start_node_labels, self.end_node_labels, self.start_node_properties,
self.end_node_properties, self.rel_type, source=self.source)
for batch in chunks(self.relationships, size=batch_size):
query_parameters = rels_params_from_objects(batch)
run_query_return_results(graph, q, database=database, **query_parameters)
run_query_return_results(graph, q, database=database, source=self.uuid, **query_parameters)

def merge(self, graph, database=None, batch_size=None):
"""
Expand All @@ -494,11 +495,11 @@ def merge(self, graph, database=None, batch_size=None):
log.debug('Batch Size: {}'.format(batch_size))

# iterate over chunks of rels
q = rels_merge_unwind(self.start_node_labels, self.end_node_labels, self.start_node_properties,
self.end_node_properties, self.rel_type)
q = rels_merge_factory(self.start_node_labels, self.end_node_labels, self.start_node_properties,
self.end_node_properties, self.rel_type, source=self.source)
for batch in chunks(self.relationships, size=batch_size):
query_parameters = rels_params_from_objects(batch)
run_query_return_results(graph, q, database=database, **query_parameters)
run_query_return_results(graph, q, database=database, source=self.uuid, **query_parameters)

def create_index(self, graph, database=None):
"""
Expand Down
20 changes: 14 additions & 6 deletions graphio/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ def rels_params_from_objects(relationships, property_identifier=None):
return {property_identifier: output}


def rels_create_unwind(start_node_labels, end_node_labels, start_node_properties,
end_node_properties, rel_type, property_identifier=None):
def rels_create_factory(start_node_labels, end_node_labels, start_node_properties,
end_node_properties, rel_type, property_identifier=None, source=False):
"""
Create relationship query with explicit arguments.
Expand Down Expand Up @@ -244,13 +244,16 @@ def rels_create_unwind(start_node_labels, end_node_labels, start_node_properties
q.append("WHERE " + ' AND '.join(where_clauses))

q.append(f"CREATE (a)-[r:{rel_type}]->(b)")
q.append("SET r = rel.properties RETURN count(r)")
q.append("SET r = rel.properties")

if source:
q.append(f"SET r._source = [$source]")

return q.query()


def rels_merge_unwind(start_node_labels, end_node_labels, start_node_properties,
end_node_properties, rel_type, property_identifier=None):
def rels_merge_factory(start_node_labels, end_node_labels, start_node_properties,
end_node_properties, rel_type, property_identifier=None, source=False):
"""
Merge relationship query with explicit arguments.
Expand Down Expand Up @@ -299,6 +302,11 @@ def rels_merge_unwind(start_node_labels, end_node_labels, start_node_properties,
q.append("WHERE " + ' AND '.join(where_clauses))

q.append(f"MERGE (a)-[r:{rel_type}]->(b)")
q.append("SET r = rel.properties RETURN count(r)")
q.append("ON CREATE SET r = rel.properties")
q.append("ON MATCH SET r += rel.properties")

if source:
q.append(f"ON CREATE SET r._source = [$source]")
q.append(f"ON MATCH SET r._source = r._source + [$source]")

return q.query()
24 changes: 24 additions & 0 deletions test/test_objects/test_relationshipset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# however, NodeSets are also tested separately
import os
import json
from uuid import uuid4
import pytest
from graphio.objects.nodeset import NodeSet
from graphio.objects.relationshipset import RelationshipSet, tuplify_json_list
Expand Down Expand Up @@ -210,6 +211,13 @@ def test_relationshipset_create_number(self, graph, create_nodes_test, small_rel

assert result[0][0] == 100

def test_relationshipset_create_source(self, graph, create_nodes_test, small_relationshipset):
small_relationshipset.source = True
small_relationshipset.create(graph)

assert run_query_return_results(graph, f"MATCH (:Test)-[r:TEST]->(:Foo) WHERE '{small_relationshipset.uuid}' in r._source RETURN count(r)")[0][0] == 100
assert run_query_return_results(graph, f"MATCH (:Test)-[r:TEST]->(:Foo) WHERE size(r._source) = 1 RETURN count(r)")[0][0] == 100

def test_relationshipset_create_mulitple_node_props(self, graph, create_nodes_test):

rs = RelationshipSet('TEST', ['Test'], ['Bar'], ['uuid'], ['uuid', 'key'])
Expand Down Expand Up @@ -291,6 +299,22 @@ def test_relationshipset_merge_number(self, graph, create_nodes_test, small_rela

assert result[0][0] == 100

def test_relationshipset_merge_source(self, graph, create_nodes_test, small_relationshipset):
small_relationshipset.source = True
small_relationshipset.merge(graph)

assert run_query_return_results(graph, f"MATCH (:Test)-[r:TEST]->(:Foo) WHERE '{small_relationshipset.uuid}' in r._source RETURN count(r)")[0][0] == 100
assert run_query_return_results(graph, f"MATCH (:Test)-[r:TEST]->(:Foo) WHERE size(r._source) = 1 RETURN count(r)")[0][0] == 100

# change uuid of relationshipset
small_relationshipset.uuid = str(uuid4())
small_relationshipset.merge(graph)

assert run_query_return_results(graph,f"MATCH (:Test)-[r:TEST]->(:Foo) WHERE '{small_relationshipset.uuid}' in r._source RETURN count(r)")[0][0] == 100

assert run_query_return_results(graph, f"MATCH (:Test)-[r:TEST]->(:Foo) WHERE size(r._source) = 2 RETURN count(r)")[0][0] == 100
assert run_query_return_results(graph, f"MATCH (:Test)-[r:TEST]->(:Foo) WHERE size(r._source) <> 2 RETURN count(r)")[0][0] == 0

def test_relationshipset_merge_no_labels(self, graph, create_nodes_test, small_relationshipset_no_labels):

small_relationshipset_no_labels.merge(graph)
Expand Down
33 changes: 27 additions & 6 deletions test/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from graphio.queries import rels_create_unwind, rels_merge_unwind, CypherQuery, get_label_string_from_list_of_labels, \
from graphio.queries import rels_create_factory, rels_merge_factory, CypherQuery, get_label_string_from_list_of_labels, \
match_clause_with_properties, merge_clause_with_properties, match_properties_as_string, nodes_merge_factory, \
nodes_create_factory

Expand Down Expand Up @@ -107,23 +107,44 @@ def test_nodes_merge_factory_preserve_array_props_with_source(self):
ON MATCH SET n._source = n._source + [$source]"""


class TestRelsCreate:
class TestRelationshipsCreateFactory:

def test_rels_create(self):
q = rels_create_unwind(['Person'], ['Movie'], ['name'], ['title'], "LIKES")
q = rels_create_factory(['Person'], ['Movie'], ['name'], ['title'], "LIKES")
assert q == """UNWIND $rels AS rel
MATCH (a:Person), (b:Movie)
WHERE a.name = rel.start_name AND b.title = rel.end_title
CREATE (a)-[r:LIKES]->(b)
SET r = rel.properties RETURN count(r)"""
SET r = rel.properties"""

def test_rels_create_source(self):
q = rels_create_factory(['Person'], ['Movie'], ['name'], ['title'], "LIKES", source=True)
assert q == """UNWIND $rels AS rel
MATCH (a:Person), (b:Movie)
WHERE a.name = rel.start_name AND b.title = rel.end_title
CREATE (a)-[r:LIKES]->(b)
SET r = rel.properties
SET r._source = [$source]"""


class TestRelsMerge:

def test_rels_merge_unwind(self):
q = rels_merge_unwind(['Person'], ['Movie'], ['name'], ['title'], "LIKES")
q = rels_merge_factory(['Person'], ['Movie'], ['name'], ['title'], "LIKES")
assert q == """UNWIND $rels AS rel
MATCH (a:Person), (b:Movie)
WHERE a.name = rel.start_name AND b.title = rel.end_title
MERGE (a)-[r:LIKES]->(b)
ON CREATE SET r = rel.properties
ON MATCH SET r += rel.properties"""

def test_rels_merge_source(self):
q = rels_merge_factory(['Person'], ['Movie'], ['name'], ['title'], "LIKES", source=True)
assert q == """UNWIND $rels AS rel
MATCH (a:Person), (b:Movie)
WHERE a.name = rel.start_name AND b.title = rel.end_title
MERGE (a)-[r:LIKES]->(b)
SET r = rel.properties RETURN count(r)"""
ON CREATE SET r = rel.properties
ON MATCH SET r += rel.properties
ON CREATE SET r._source = [$source]
ON MATCH SET r._source = r._source + [$source]"""

0 comments on commit fbbd5a6

Please sign in to comment.