Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 67 additions & 31 deletions json_merger/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ def __init__(self, lst, root):

def move_to_result(self, lst_idx, match_uid):
self.in_result_idx.add(lst_idx)
if lst_idx in self.not_in_result_idx:
self.not_in_result_idx.remove(lst_idx)
self.not_in_result_idx.remove(lst_idx)

# TODO unless we actually do a same list comparator which is hard
# and takes a lot we need to drop this unused feature since:
Expand Down Expand Up @@ -152,10 +151,12 @@ def __init__(self, root, head, update, sources,
self._node_src_indices = {}
self._head_idx_to_node = {}
self._update_idx_to_node = {}
self._dirty_nodes = set()

self._next_node_id = 0
self.match_uids = {}

self.multiple_match_choice_idx = set()
self.multiple_match_choices = []

def _new_node_id(self):
Expand All @@ -177,6 +178,30 @@ def _push_node(self, root_elem, head_elem, update_elem):
if update_idx >= 0:
self._update_idx_to_node[update_idx] = node_id

def _get_nodes(self, head_elem, update_elem):
"""Get nodes to which either head_elem or update_elem point to."""
head_idx, head_obj = head_elem
update_idx, update_obj = update_elem
res = set()

if head_idx in self._head_idx_to_node:
res.add(self._head_idx_to_node[head_idx])
if update_idx in self._update_idx_to_node:
res.add(self._update_idx_to_node[update_idx])

return res

def _pop_node(self, node_id):
"""Remove a node from the graph."""
root_idx, head_idx, update_idx = self._node_src_indices[node_id]
del self._node_src_indices[node_id]
del self.node_data[node_id]

if head_idx in self._head_idx_to_node:
del self._head_idx_to_node[head_idx]
if update_idx in self._update_idx_to_node:
del self._update_idx_to_node[update_idx]

def _get_matches(self, target, source, source_idx):
comparator, src_list = self.comparators[(target, source)]
matches = comparator.get_matches(src_list, source_idx)
Expand All @@ -188,11 +213,21 @@ def _add_matches(self, root_elems, head_elems, update_elems):
for h in head_elems
for u in update_elems]
if len(matches) == 1:
self._push_node(*matches[0])
root_elem, head_elem, update_elem = matches[0]
node_ids = self._get_nodes(head_elem, update_elem)
# If this single match overrides a previous node entry we remove
# add this match as a multiple_match_choice and mark the node
# for removal. We will later remove the node from the graph so
# that future collisions with this node will be caught.
if not node_ids:
self._push_node(root_elem, head_elem, update_elem)
else:
self._dirty_nodes.update(node_ids)
self.multiple_match_choice_idx.add(
(root_elem[0], head_elem[0], update_elem[0]))
else:
match_objs = [(r[1] or None, h[1] or None, u[1] or None)
for r, h, u in matches]
self.multiple_match_choices.extend(match_objs)
match_indices = [(r[0], h[0], u[0]) for r, h, u in matches]
self.multiple_match_choice_idx.update(match_indices)

def _populate_nodes(self):
if 'head' in self.sources:
Expand All @@ -213,6 +248,15 @@ def _populate_nodes(self):
head_elems = self._get_matches('head', 'update', update_idx)
self._add_matches(root_elems, head_elems, update_elems)

for node_id in self._dirty_nodes:
self.multiple_match_choice_idx.add(self._node_src_indices[node_id])
self._pop_node(node_id)
for r_idx, h_idx, u_idx in self.multiple_match_choice_idx:
r_obj = self.root[r_idx] if r_idx >= 0 else None
h_obj = self.head[h_idx] if h_idx >= 0 else None
u_obj = self.update[u_idx] if u_idx >= 0 else None
self.multiple_match_choices.append((r_obj, h_obj, u_obj))

def _build_stats(self):
match_uid = 0
for node_id, indices in self._node_src_indices.items():
Expand All @@ -236,6 +280,18 @@ def _build_stats(self):
if len(matches) == 1 and matches[0][0] >= 0:
self.update_stats.add_root_match(idx, matches[0][0])

def _get_next_node(self, source, indices):
if source not in self.sources:
return None
idx_to_node = {
'head': self._head_idx_to_node,
'update': self._update_idx_to_node
}[source]
for idx in indices:
if idx in idx_to_node:
return idx_to_node[idx]
return None

def build_graph(self):
self._populate_nodes()
self._build_stats()
Expand All @@ -244,20 +300,10 @@ def build_graph(self):
# lists.
self.node_data[FIRST] = (NOTHING, NOTHING, NOTHING)
self.graph[FIRST] = BeforeNodes()
next_head_node = None
next_update_node = None

if 'head' in self.sources:
for idx in range(len(self.head)):
if idx in self._head_idx_to_node:
next_head_node = self._head_idx_to_node[idx]
break
if 'update' in self.sources:
for idx in range(len(self.update)):
if idx in self._update_idx_to_node:
next_update_node = self._update_idx_to_node[idx]
break

next_head_node = self._get_next_node('head', range(len(self.head)))
next_update_node = self._get_next_node('update',
range(len(self.update)))
self.graph[FIRST].head_node = next_head_node
self.graph[FIRST].update_node = next_update_node

Expand All @@ -272,18 +318,8 @@ def build_graph(self):
if update_idx >= 0:
update_next_l = range(update_idx + 1, len(self.update))

next_head_node = None
next_update_node = None
for head_next in head_next_l:
if (head_next in self._head_idx_to_node and
'head' in self.sources):
next_head_node = self._head_idx_to_node[head_next]
break
for update_next in update_next_l:
if (update_next in self._update_idx_to_node and
'update' in self.sources):
next_update_node = self._update_idx_to_node[update_next]
break
next_head_node = self._get_next_node('head', head_next_l)
next_update_node = self._get_next_node('update', update_next_l)
self.graph[node_id] = BeforeNodes(next_head_node, next_update_node)

return self.graph, self.node_data
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/test_list_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,33 @@ def test_error_on_multiple_match():
assert u.unified == [(2, 2, 2)]


def test_multiple_match_symmetry():
root = []
l1 = [1, 2, 3, 3, 3]
l2 = [1, 2, 3]

u1 = ListUnifier(root, l1, l2,
UnifierOps.KEEP_UPDATE_AND_HEAD_ENTITIES_HEAD_FIRST)
u2 = ListUnifier(root, l2, l1,
UnifierOps.KEEP_UPDATE_AND_HEAD_ENTITIES_UPDATE_FIRST)

with pytest.raises(MergeError) as u1_excinfo:
u1.unify()
with pytest.raises(MergeError) as u2_excinfo:
u2.unify()

assert len(u1_excinfo.value.content) == 3
assert len(u2_excinfo.value.content) == 3

for conflict in u1_excinfo.value.content + u2_excinfo.value.content:
assert conflict.conflict_type == ConflictType.MANUAL_MERGE
assert conflict.path == ()
assert conflict.body == (None, 3, 3)

assert u1.unified == [(NOTHING, 1, 1), (NOTHING, 2, 2)]
assert u2.unified == [(NOTHING, 1, 1), (NOTHING, 2, 2)]


def test_stats():
root = [1, 2, 10]
head = [1, 3, 4, 2]
Expand Down