Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions dsync/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,30 @@ def has_diffs(self) -> bool:
return False

def get_children(self) -> Iterator["DiffElement"]:
"""Iterate over all child elements in all groups in self.children."""
"""Iterate over all child elements in all groups in self.children.

For each group of children, check if an order method is defined,
Otherwise use the default method.
"""
order_default = "order_children_default"

for group in self.groups():
for child in self.children[group].values():
yield child
order_method_name = f"order_children_{group}"
if hasattr(self, order_method_name):
order_method = getattr(self, order_method_name)
else:
order_method = getattr(self, order_default)

yield from order_method(self.children[group])

@classmethod
def order_children_default(cls, children: dict) -> Iterator["DiffElement"]:
"""Default method to an Iterator for children.

Since children is already an OrderedDefaultDict, this method is not doing anything special.
"""
for child in children.values():
yield child

def print_detailed(self, indent: int = 0):
"""Print all diffs to screen for all child elements.
Expand Down
16 changes: 15 additions & 1 deletion examples/example1/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,20 @@
from backend_c import BackendC


from dsync import Diff


class MyDiff(Diff):
"""Custom Diff class to control the order of the site objects."""

@classmethod
def order_children_site(cls, children):
"""Return the site children ordered in alphabetical order."""
keys = sorted(children.keys(), reverse=False)
for key in keys:
yield children[key]


def main():
"""Demonstrate DSync behavior using the example backends provided."""
# pylint: disable=invalid-name
Expand All @@ -18,7 +32,7 @@ def main():
c = BackendC()
c.load()

diff_a_b = a.diff_to(b)
diff_a_b = a.diff_to(b, diff_class=MyDiff)
diff_a_b.print_detailed()

a.sync_to(b)
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,42 @@ def test_diff_children():
assert diff.has_diffs()

# TODO: test print_detailed


def test_order_children_default(backend_a, backend_b):
"""Test that order_children_default is properly called when calling get_children."""

class MyDiff(Diff):
"""custom diff class to test order_children_default."""

@classmethod
def order_children_default(cls, children):
"""Return the children ordered in alphabetical order."""
keys = sorted(children.keys(), reverse=False)
for key in keys:
yield children[key]

# Validating default order method
diff_a_b = backend_a.diff_from(backend_b, diff_class=MyDiff)
children = diff_a_b.get_children()
children_names = [child.name for child in children]
assert children_names == ["atl", "nyc", "rdu", "sfo"]


def test_order_children_custom(backend_a, backend_b):
"""Test that a custom order_children method is properly called when calling get_children."""

class MyDiff(Diff):
"""custom diff class to test order_children_site."""

@classmethod
def order_children_site(cls, children):
"""Return the site children ordered in reverse-alphabetical order."""
keys = sorted(children.keys(), reverse=True)
for key in keys:
yield children[key]

diff_a_b = backend_a.diff_from(backend_b, diff_class=MyDiff)
children = diff_a_b.get_children()
children_names = [child.name for child in children]
assert children_names == ["sfo", "rdu", "nyc", "atl"]