diff --git a/dsync/diff.py b/dsync/diff.py index 48f4dd3c..1061f3b5 100644 --- a/dsync/diff.py +++ b/dsync/diff.py @@ -68,10 +68,30 @@ def has_diffs(self) -> bool: return False def get_children(self) -> Iterator["DiffElement"]: - """Iterate over all child elements in all groups in self.children.""" + """Iterate over all child elements in all groups in self.children. + + For each group of children, check if an order method is defined, + Otherwise use the default method. + """ + order_default = "order_children_default" + for group in self.groups(): - for child in self.children[group].values(): - yield child + order_method_name = f"order_children_{group}" + if hasattr(self, order_method_name): + order_method = getattr(self, order_method_name) + else: + order_method = getattr(self, order_default) + + yield from order_method(self.children[group]) + + @classmethod + def order_children_default(cls, children: dict) -> Iterator["DiffElement"]: + """Default method to an Iterator for children. + + Since children is already an OrderedDefaultDict, this method is not doing anything special. + """ + for child in children.values(): + yield child def print_detailed(self, indent: int = 0): """Print all diffs to screen for all child elements. diff --git a/examples/example1/main.py b/examples/example1/main.py index 9e57d6af..8fabf85f 100644 --- a/examples/example1/main.py +++ b/examples/example1/main.py @@ -5,6 +5,20 @@ from backend_c import BackendC +from dsync import Diff + + +class MyDiff(Diff): + """Custom Diff class to control the order of the site objects.""" + + @classmethod + def order_children_site(cls, children): + """Return the site children ordered in alphabetical order.""" + keys = sorted(children.keys(), reverse=False) + for key in keys: + yield children[key] + + def main(): """Demonstrate DSync behavior using the example backends provided.""" # pylint: disable=invalid-name @@ -18,7 +32,7 @@ def main(): c = BackendC() c.load() - diff_a_b = a.diff_to(b) + diff_a_b = a.diff_to(b, diff_class=MyDiff) diff_a_b.print_detailed() a.sync_to(b) diff --git a/tests/unit/test_diff.py b/tests/unit/test_diff.py index a9b889d5..31649574 100644 --- a/tests/unit/test_diff.py +++ b/tests/unit/test_diff.py @@ -49,3 +49,42 @@ def test_diff_children(): assert diff.has_diffs() # TODO: test print_detailed + + +def test_order_children_default(backend_a, backend_b): + """Test that order_children_default is properly called when calling get_children.""" + + class MyDiff(Diff): + """custom diff class to test order_children_default.""" + + @classmethod + def order_children_default(cls, children): + """Return the children ordered in alphabetical order.""" + keys = sorted(children.keys(), reverse=False) + for key in keys: + yield children[key] + + # Validating default order method + diff_a_b = backend_a.diff_from(backend_b, diff_class=MyDiff) + children = diff_a_b.get_children() + children_names = [child.name for child in children] + assert children_names == ["atl", "nyc", "rdu", "sfo"] + + +def test_order_children_custom(backend_a, backend_b): + """Test that a custom order_children method is properly called when calling get_children.""" + + class MyDiff(Diff): + """custom diff class to test order_children_site.""" + + @classmethod + def order_children_site(cls, children): + """Return the site children ordered in reverse-alphabetical order.""" + keys = sorted(children.keys(), reverse=True) + for key in keys: + yield children[key] + + diff_a_b = backend_a.diff_from(backend_b, diff_class=MyDiff) + children = diff_a_b.get_children() + children_names = [child.name for child in children] + assert children_names == ["sfo", "rdu", "nyc", "atl"]