diff --git a/dsync/__init__.py b/dsync/__init__.py index 44a6d004..36a89031 100644 --- a/dsync/__init__.py +++ b/dsync/__init__.py @@ -33,21 +33,39 @@ ) +class DSyncModelFlags(enum.Flag): + """Flags that can be set on a DSyncModel class or instance to affect its usage.""" + + NONE = 0 + + IGNORE = 0b1 + """Do not render diffs containing this model; do not make any changes to this model when synchronizing. + + Can be used to indicate a model instance that exists but should not be changed by DSync. + """ + + SKIP_CHILDREN_ON_DELETE = 0b10 + """When deleting this model, do not recursively delete its children. + + Can be used for the case where deletion of a model results in the automatic deletion of all its children. + """ + + class DSyncFlags(enum.Flag): """Flags that can be passed to a sync_* or diff_* call to affect its behavior.""" NONE = 0 - CONTINUE_ON_FAILURE = enum.auto() + CONTINUE_ON_FAILURE = 0b1 """Continue synchronizing even if failures are encountered when syncing individual models.""" - SKIP_UNMATCHED_SRC = enum.auto() + SKIP_UNMATCHED_SRC = 0b10 """Ignore objects that only exist in the source/"from" DSync when determining diffs and syncing. If this flag is set, no new objects will be created in the target/"to" DSync. """ - SKIP_UNMATCHED_DST = enum.auto() + SKIP_UNMATCHED_DST = 0b100 """Ignore objects that only exist in the target/"to" DSync when determining diffs and syncing. If this flag is set, no objects will be deleted from the target/"to" DSync. @@ -55,7 +73,7 @@ class DSyncFlags(enum.Flag): SKIP_UNMATCHED_BOTH = SKIP_UNMATCHED_SRC | SKIP_UNMATCHED_DST - LOG_UNCHANGED_RECORDS = enum.auto() + LOG_UNCHANGED_RECORDS = 0b1000 """If this flag is set, a log message will be generated during synchronization for each model, even unchanged ones. By default, when this flag is unset, only models that have actual changes to synchronize will be logged. @@ -66,7 +84,7 @@ class DSyncFlags(enum.Flag): class DSyncModel(BaseModel): """Base class for all DSync object models. - Note that APIs of this class are implemented as `get_*()` functions rather than as properties; + Note that read-only APIs of this class are implemented as `get_*()` functions rather than as properties; this is intentional as specific model classes may want to use these names (`type`, `keys`, `attrs`, etc.) as model attributes and we want to avoid any ambiguity or collisions. @@ -110,6 +128,14 @@ class DSyncModel(BaseModel): """Optional: dict of `{_modelname: field_name}` entries describing how to store "child" models in this model. When calculating a Diff or performing a sync, DSync will automatically recurse into these child models. + + Note: inclusion in `_children` is mutually exclusive from inclusion in `_identifiers` or `_attributes`. + """ + + model_flags: DSyncModelFlags = DSyncModelFlags.NONE + """Optional: any non-default behavioral flags for this DSyncModel. + + Can be set as a class attribute or an instance attribute as needed. """ dsync: Optional["DSync"] = None @@ -207,7 +233,7 @@ def delete(self) -> Optional["DSyncModel"]: return self @classmethod - def get_type(cls): + def get_type(cls) -> str: """Return the type AKA modelname of the object or the class Returns: @@ -229,7 +255,7 @@ def get_children_mapping(cls) -> Mapping[str, str]: """Get the mapping of types to fieldnames for child models of this model.""" return cls._children - def get_identifiers(self): + def get_identifiers(self) -> dict: """Get a dict of all identifiers (primary keys) and their values for this object. Returns: @@ -237,7 +263,7 @@ def get_identifiers(self): """ return self.dict(include=set(self._identifiers)) - def get_attrs(self): + def get_attrs(self) -> dict: """Get all the non-primary-key attributes or parameters for this object. Similar to Pydantic's `BaseModel.dict()` method, with the following key differences: @@ -250,7 +276,7 @@ def get_attrs(self): """ return self.dict(include=set(self._attributes)) - def get_unique_id(self): + def get_unique_id(self) -> str: """Get the unique ID of an object. By default the unique ID is built based on all the primary keys defined in `_identifiers`. @@ -260,7 +286,7 @@ def get_unique_id(self): """ return self.create_unique_id(**self.get_identifiers()) - def get_shortname(self): + def get_shortname(self) -> str: """Get the (not guaranteed-unique) shortname of an object, if any. By default the shortname is built based on all the keys defined in `_shortname`. @@ -475,13 +501,17 @@ def _sync_from_diff_element( return if element.action == "create": - self.add(obj) if parent_model: parent_model.add_child(obj) + self.add(obj) elif element.action == "delete": - self.remove(obj) if parent_model: parent_model.remove_child(obj) + if obj.model_flags & DSyncModelFlags.SKIP_CHILDREN_ON_DELETE: + # We don't need to process the child objects, but we do need to discard them + self.remove(obj, remove_children=True) + return + self.remove(obj) for child in element.get_children(): self._sync_from_diff_element(child, flags=flags, parent_model=obj, logger=logger) @@ -498,27 +528,8 @@ def diff_from(self, source: "DSync", diff_class: Type[Diff] = Diff, flags: DSync diff_class (class): Diff or subclass thereof to use for diff calculation and storage. flags (DSyncFlags): Flags influencing the behavior of this diff operation. """ - log = self._log.bind(src=source, dst=self, flags=flags).unbind("dsync") - log.info("Beginning diff") - diff = diff_class() - - for obj_type in intersection(self.top_level, source.top_level): - - diff_elements = self._diff_objects( - source=source.get_all(obj_type), - dest=self.get_all(obj_type), - source_root=source, - flags=flags, - logger=log, - ) - - for diff_element in diff_elements: - diff.add(diff_element) - - # Notify the diff that it has been fully populated, in case it wishes to print, save to a file, etc. - log.info("Diff complete") - diff.complete() - return diff + differ = DSyncDiffer(src_dsync=source, dst_dsync=self, flags=flags, diff_class=diff_class) + return differ.calculate_diffs() def diff_to(self, target: "DSync", diff_class: Type[Diff] = Diff, flags: DSyncFlags = DSyncFlags.NONE) -> Diff: """Generate a Diff describing the difference from this DSync to another one. @@ -530,154 +541,6 @@ def diff_to(self, target: "DSync", diff_class: Type[Diff] = Diff, flags: DSyncFl """ return target.diff_from(self, diff_class=diff_class, flags=flags) - def _diff_objects( # pylint: disable=too-many-arguments - self, - source: Iterable[DSyncModel], - dest: Iterable[DSyncModel], - source_root: "DSync", - flags: DSyncFlags = DSyncFlags.NONE, - logger: structlog.BoundLogger = None, - ) -> List[DiffElement]: - """Generate a list of DiffElement between the given lists of objects. - - Helper method for `diff_from`/`diff_to`; this generally shouldn't be called on its own. - - Args: - source: DSyncModel instances retrieved from another DSync instance - dest: DSyncModel instances retrieved from this DSync instance - source_root (DSync): The other DSync object being diffed against (owner of the `source` models, if any) - flags (DSyncFlags): Flags influencing the behavior of this diff operation. - logger: Parent logging context - - Raises: - TypeError: if the source and dest args are not the same type, or if that type is unsupported - """ - diffs = [] - - if isinstance(source, ABCIterable) and isinstance(dest, ABCIterable): - # Convert a list of DSyncModels into a dict using the unique_ids as keys - dict_src = {item.get_unique_id(): item for item in source} if not isinstance(source, ABCMapping) else source - dict_dst = {item.get_unique_id(): item for item in dest} if not isinstance(dest, ABCMapping) else dest - - combined_dict = {} - for uid in dict_src: - combined_dict[uid] = (dict_src.get(uid), dict_dst.get(uid)) - for uid in dict_dst: - combined_dict[uid] = (dict_src.get(uid), dict_dst.get(uid)) - else: - # In the future we might support set, etc... - raise TypeError(f"Type combination {type(source)}/{type(dest)} is not supported... for now") - - self._validate_objects_for_diff(combined_dict) - - for uid in combined_dict: - log = logger or self._log - src_obj, dst_obj = combined_dict[uid] - if not src_obj and not dst_obj: - # Should never happen - raise RuntimeError(f"UID {uid} is in combined_dict but has neither src_obj nor dst_obj??") - if src_obj: - log = log.bind(model=src_obj.get_type(), unique_id=src_obj.get_unique_id()) - if flags & DSyncFlags.SKIP_UNMATCHED_SRC and not dst_obj: - log.debug("Skipping unmatched source object") - continue - diff_element = DiffElement( - obj_type=src_obj.get_type(), - name=src_obj.get_shortname(), - keys=src_obj.get_identifiers(), - source_name=source_root.name, - dest_name=self.name, - ) - elif dst_obj: - log = log.bind(model=dst_obj.get_type(), unique_id=dst_obj.get_unique_id()) - if flags & DSyncFlags.SKIP_UNMATCHED_DST and not src_obj: - log.debug("Skipping unmatched dest object") - continue - diff_element = DiffElement( - obj_type=dst_obj.get_type(), - name=dst_obj.get_shortname(), - keys=dst_obj.get_identifiers(), - source_name=source_root.name, - dest_name=self.name, - ) - - if src_obj: - diff_element.add_attrs(source=src_obj.get_attrs(), dest=None) - if dst_obj: - diff_element.add_attrs(source=None, dest=dst_obj.get_attrs()) - - # Recursively diff the children of src_obj and dst_obj and attach the resulting diffs to the diff_element - self._diff_child_objects(diff_element, src_obj, dst_obj, source_root, flags=flags, logger=logger) - - diffs.append(diff_element) - - return diffs - - @staticmethod - def _validate_objects_for_diff(combined_dict: Mapping[str, Tuple[Optional[DSyncModel], Optional[DSyncModel]]]): - """Check whether all DSyncModels in the given dictionary are valid for comparison to one another. - - Helper method for `_diff_objects`. - - Raises: - TypeError: If any pair of objects in the dict have differing get_type() values. - ValueError: If any pair of objects in the dict have differing get_shortname() or get_identifiers() values. - """ - for uid in combined_dict: - # TODO: should we check/enforce whether all source models have the same DSync, whether all dest likewise? - # TODO: should we check/enforce whether ALL DSyncModels in this dict have the same get_type() output? - src_obj, dst_obj = combined_dict[uid] - if src_obj and dst_obj: - if src_obj.get_type() != dst_obj.get_type(): - raise TypeError(f"Type mismatch: {src_obj.get_type()} vs {dst_obj.get_type()}") - if src_obj.get_shortname() != dst_obj.get_shortname(): - raise ValueError(f"Shortname mismatch: {src_obj.get_shortname()} vs {dst_obj.get_shortname()}") - if src_obj.get_identifiers() != dst_obj.get_identifiers(): - raise ValueError(f"Keys mismatch: {src_obj.get_identifiers()} vs {dst_obj.get_identifiers()}") - - def _diff_child_objects( # pylint: disable=too-many-arguments - self, - diff_element: DiffElement, - src_obj: Optional[DSyncModel], - dst_obj: Optional[DSyncModel], - source_root: "DSync", - flags: DSyncFlags, - logger: structlog.BoundLogger, - ): - """For all children of the given DSyncModel pair, diff them recursively, adding diffs to the given diff_element. - - Helper method for `_diff_objects`. - """ - children_mapping: Mapping[str, str] - if src_obj and dst_obj: - # Get the subset of child types common to both src_obj and dst_obj - src_mapping = src_obj.get_children_mapping() - dst_mapping = dst_obj.get_children_mapping() - children_mapping = {} - for child_type, child_fieldname in src_mapping.items(): - if child_type in dst_mapping: - children_mapping[child_type] = child_fieldname - elif src_obj: - children_mapping = src_obj.get_children_mapping() - elif dst_obj: - children_mapping = dst_obj.get_children_mapping() - else: - # Should be unreachable - raise RuntimeError("Called with neither src_obj nor dest_obj??") - - for child_type, child_fieldname in children_mapping.items(): - # for example, child_type == "device" and child_fieldname == "devices" - - # for example, getattr(src_obj, "devices") --> list of device uids - # --> src_dsync.get_by_uids(, "device") --> list of device instances - src_objs = source_root.get_by_uids(getattr(src_obj, child_fieldname), child_type) if src_obj else [] - dst_objs = self.get_by_uids(getattr(dst_obj, child_fieldname), child_type) if dst_obj else [] - - for child_diff_element in self._diff_objects( - source=src_objs, dest=dst_objs, source_root=source_root, flags=flags, logger=logger, - ): - diff_element.add_child(child_diff_element) - # ------------------------------------------------------------------------------ # Object Storage Management # ------------------------------------------------------------------------------ @@ -760,11 +623,12 @@ def add(self, obj: DSyncModel): self._data[modelname][uid] = obj - def remove(self, obj: DSyncModel): + def remove(self, obj: DSyncModel, remove_children: bool = False): """Remove a DSyncModel object from the store. Args: - obj (DSyncModel): object to delete + obj (DSyncModel): object to remove + remove_children (bool): If True, also recursively remove any children of this object Raises: ObjectNotFound: if the object is not present @@ -780,6 +644,193 @@ def remove(self, obj: DSyncModel): del self._data[modelname][uid] + if remove_children: + for child_type, child_fieldname in obj.get_children_mapping().items(): + for child_id in getattr(obj, child_fieldname): + child_obj = self.get(child_type, child_id) + if child_obj: + self.remove(child_obj, remove_children=remove_children) + # DSyncModel references DSync and DSync references DSyncModel. Break the typing loop: DSyncModel.update_forward_refs() + + +class DSyncDiffer: + """Helper class implementing diff calculation logic for DSync. + + Independent from Diff and DiffElement as those classes are purely data objects, while this stores some state. + """ + + def __init__(self, src_dsync: DSync, dst_dsync: DSync, flags: DSyncFlags, diff_class: Type[Diff] = Diff): + """Create a DSyncDiffer for calculating diffs between the provided DSync instances.""" + self.src_dsync = src_dsync + self.dst_dsync = dst_dsync + self.flags = flags + + self.logger = structlog.get_logger().new(src=src_dsync, dst=dst_dsync, flags=flags) + self.diff_class = diff_class + self.diff: Optional[Diff] = None + + def calculate_diffs(self) -> Diff: + """Calculate diffs between the src and dst DSync objects and return the resulting Diff.""" + if self.diff is not None: + return self.diff + + self.logger.info("Beginning diff calculation") + self.diff = self.diff_class() + for obj_type in intersection(self.dst_dsync.top_level, self.src_dsync.top_level): + diff_elements = self.diff_object_list( + src=self.src_dsync.get_all(obj_type), dst=self.dst_dsync.get_all(obj_type), + ) + + for diff_element in diff_elements: + self.diff.add(diff_element) + + self.logger.info("Diff calculation complete") + self.diff.complete() + return self.diff + + def diff_object_list(self, src: Iterable[DSyncModel], dst: Iterable[DSyncModel]) -> List[DiffElement]: + """Calculate diffs between two lists of like objects. + + Helper method to `calculate_diffs`, usually doesn't need to be called directly. + + These helper methods work in a recursive cycle: + diff_object_list -> diff_object_pair -> diff_child_objects -> diff_object_list -> etc. + """ + diff_elements = [] + + if isinstance(src, ABCIterable) and isinstance(dst, ABCIterable): + # Convert a list of DSyncModels into a dict using the unique_ids as keys + dict_src = {item.get_unique_id(): item for item in src} if not isinstance(src, ABCMapping) else src + dict_dst = {item.get_unique_id(): item for item in dst} if not isinstance(dst, ABCMapping) else dst + + combined_dict = {} + for uid in dict_src: + combined_dict[uid] = (dict_src.get(uid), dict_dst.get(uid)) + for uid in dict_dst: + combined_dict[uid] = (dict_src.get(uid), dict_dst.get(uid)) + else: + # In the future we might support set, etc... + raise TypeError(f"Type combination {type(src)}/{type(dst)} is not supported... for now") + + self.validate_objects_for_diff(combined_dict.values()) + + for uid in combined_dict: + src_obj, dst_obj = combined_dict[uid] + diff_element = self.diff_object_pair(src_obj, dst_obj) + + if diff_element: + diff_elements.append(diff_element) + + return diff_elements + + @staticmethod + def validate_objects_for_diff(object_pairs: Iterable[Tuple[Optional[DSyncModel], Optional[DSyncModel]]]): + """Check whether all DSyncModels in the given dictionary are valid for comparison to one another. + + Helper method for `diff_object_list`. + + Raises: + TypeError: If any pair of objects in the dict have differing get_type() values. + ValueError: If any pair of objects in the dict have differing get_shortname() or get_identifiers() values. + """ + for src_obj, dst_obj in object_pairs: + # TODO: should we check/enforce whether all source models have the same DSync, whether all dest likewise? + # TODO: should we check/enforce whether ALL DSyncModels in this dict have the same get_type() output? + if src_obj and dst_obj: + if src_obj.get_type() != dst_obj.get_type(): + raise TypeError(f"Type mismatch: {src_obj.get_type()} vs {dst_obj.get_type()}") + if src_obj.get_shortname() != dst_obj.get_shortname(): + raise ValueError(f"Shortname mismatch: {src_obj.get_shortname()} vs {dst_obj.get_shortname()}") + if src_obj.get_identifiers() != dst_obj.get_identifiers(): + raise ValueError(f"Keys mismatch: {src_obj.get_identifiers()} vs {dst_obj.get_identifiers()}") + + def diff_object_pair(self, src_obj: Optional[DSyncModel], dst_obj: Optional[DSyncModel]) -> Optional["DiffElement"]: + """Diff the two provided DSyncModel objects and return a DiffElement or None. + + Helper method to `calculate_diffs`, usually doesn't need to be called directly. + + These helper methods work in a recursive cycle: + diff_object_list -> diff_object_pair -> diff_child_objects -> diff_object_list -> etc. + """ + if src_obj: + model = src_obj.get_type() + unique_id = src_obj.get_unique_id() + shortname = src_obj.get_shortname() + keys = src_obj.get_identifiers() + elif dst_obj: + model = dst_obj.get_type() + unique_id = dst_obj.get_unique_id() + shortname = dst_obj.get_shortname() + keys = dst_obj.get_identifiers() + else: + raise RuntimeError("diff_object_pair() called with neither src_obj nor dst_obj??") + + log = self.logger.bind(model=model, unique_id=unique_id) + if self.flags & DSyncFlags.SKIP_UNMATCHED_SRC and not dst_obj: + log.debug("Skipping unmatched source object") + return None + if self.flags & DSyncFlags.SKIP_UNMATCHED_DST and not src_obj: + log.debug("Skipping unmatched dest object") + return None + if src_obj and src_obj.model_flags & DSyncModelFlags.IGNORE: + log.debug("Skipping due to IGNORE flag on source object") + return None + if dst_obj and dst_obj.model_flags & DSyncModelFlags.IGNORE: + log.debug("Skipping due to IGNORE flag on dest object") + return None + + diff_element = DiffElement( + obj_type=model, name=shortname, keys=keys, source_name=self.src_dsync.name, dest_name=self.dst_dsync.name, + ) + + if src_obj: + diff_element.add_attrs(source=src_obj.get_attrs(), dest=None) + if dst_obj: + diff_element.add_attrs(source=None, dest=dst_obj.get_attrs()) + + # Recursively diff the children of src_obj and dst_obj and attach the resulting diffs to the diff_element + self.diff_child_objects(diff_element, src_obj, dst_obj) + + return diff_element + + def diff_child_objects( + self, diff_element: DiffElement, src_obj: Optional[DSyncModel], dst_obj: Optional[DSyncModel], + ): + """For all children of the given DSyncModel pair, diff them recursively, adding diffs to the given diff_element. + + Helper method to `calculate_diffs`, usually doesn't need to be called directly. + + These helper methods work in a recursive cycle: + diff_object_list -> diff_object_pair -> diff_child_objects -> diff_object_list -> etc. + """ + children_mapping: Mapping[str, str] + if src_obj and dst_obj: + # Get the subset of child types common to both src_obj and dst_obj + src_mapping = src_obj.get_children_mapping() + dst_mapping = dst_obj.get_children_mapping() + children_mapping = {} + for child_type, child_fieldname in src_mapping.items(): + if child_type in dst_mapping: + children_mapping[child_type] = child_fieldname + elif src_obj: + children_mapping = src_obj.get_children_mapping() + elif dst_obj: + children_mapping = dst_obj.get_children_mapping() + else: + raise RuntimeError("Called with neither src_obj nor dest_obj??") + + for child_type, child_fieldname in children_mapping.items(): + # for example, child_type == "device" and child_fieldname == "devices" + + # for example, getattr(src_obj, "devices") --> list of device uids + # --> src_dsync.get_by_uids(, "device") --> list of device instances + src_objs = self.src_dsync.get_by_uids(getattr(src_obj, child_fieldname), child_type) if src_obj else [] + dst_objs = self.dst_dsync.get_by_uids(getattr(dst_obj, child_fieldname), child_type) if dst_obj else [] + + for child_diff_element in self.diff_object_list(src=src_objs, dst=dst_objs): + diff_element.add_child(child_diff_element) + + return diff_element diff --git a/pyproject.toml b/pyproject.toml index 21eda1a2..1b96e84c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,10 @@ extension-pkg-whitelist = [ "pydantic" ] +[tool.pylint.basic] +# No docstrings required for private methods (Pylint default), or for test_ functions. +no-docstring-rgx="^(_|test_)" + [tool.pylint.messages_control] # Line length is enforced by Black, so pylint doesn't need to check it. # Pylint and Black disagree about how to format multi-line arrays; Black wins. diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 17410373..e149fcea 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -227,6 +227,31 @@ def backend_a(): return dsync +@pytest.fixture +def backend_a_with_extra_models(): + """Provide an instance of BackendA subclass of DSync with some extra sites and devices.""" + extra_models = BackendA() + extra_models.load() + extra_site = extra_models.site(name="lax") + extra_models.add(extra_site) + extra_device = extra_models.device(name="nyc-spine3", site_name="nyc", role="spine") + extra_models.get(extra_models.site, "nyc").add_child(extra_device) + extra_models.add(extra_device) + return extra_models + + +@pytest.fixture +def backend_a_minus_some_models(): + """Provide an instance of BackendA subclass of DSync with fewer models than the default.""" + missing_models = BackendA() + missing_models.load() + missing_models.remove(missing_models.get(missing_models.site, "rdu")) + missing_device = missing_models.get(missing_models.device, "sfo-spine2") + missing_models.get(missing_models.site, "sfo").remove_child(missing_device) + missing_models.remove(missing_device) + return missing_models + + class ErrorProneSiteA(ErrorProneModel, SiteA): """A Site that sometimes throws exceptions.""" diff --git a/tests/unit/test_dsync.py b/tests/unit/test_dsync.py index 4c7036ff..81c87999 100644 --- a/tests/unit/test_dsync.py +++ b/tests/unit/test_dsync.py @@ -2,39 +2,56 @@ import pytest -from dsync import DSync, DSyncModel, DSyncFlags +from dsync import DSync, DSyncModel, DSyncFlags, DSyncModelFlags from dsync.exceptions import ObjectAlreadyExists, ObjectNotFound, ObjectCrudException from .conftest import Site, Device, Interface, TrackedDiff, BackendA -def test_generic_dsync_methods(generic_dsync, generic_dsync_model): - """Test the standard DSync APIs on a generic DSync instance and DSyncModel instance.""" - generic_dsync.load() # no-op +def test_dsync_default_name_type(generic_dsync): + assert generic_dsync.type == "DSync" + assert generic_dsync.name == "DSync" + + +def test_dsync_generic_load_is_noop(generic_dsync): + generic_dsync.load() assert len(generic_dsync._data) == 0 # pylint: disable=protected-access - diff = generic_dsync.diff_from(generic_dsync) - assert diff.has_diffs() is False - diff = generic_dsync.diff_to(generic_dsync) - assert diff.has_diffs() is False - generic_dsync.sync_from(generic_dsync) # no-op - generic_dsync.sync_to(generic_dsync) # no-op +def test_dsync_diff_self_with_no_data_has_no_diffs(generic_dsync): + assert generic_dsync.diff_from(generic_dsync).has_diffs() is False + assert generic_dsync.diff_to(generic_dsync).has_diffs() is False + + +def test_dsync_sync_self_with_no_data_is_noop(generic_dsync): + generic_dsync.sync_from(generic_dsync) + generic_dsync.sync_to(generic_dsync) + +def test_dsync_get_with_no_data_is_none(generic_dsync): assert generic_dsync.get("anything", "myname") is None assert generic_dsync.get(DSyncModel, "") is None + +def test_dsync_get_all_with_no_data_is_empty_list(generic_dsync): assert list(generic_dsync.get_all("anything")) == [] assert list(generic_dsync.get_all(DSyncModel)) == [] + +def test_dsync_get_by_uids_with_no_data_is_empty_list(generic_dsync): assert generic_dsync.get_by_uids(["any", "another"], "anything") == [] assert generic_dsync.get_by_uids(["any", "another"], DSyncModel) == [] + +def test_dsync_add(generic_dsync, generic_dsync_model): # A DSync can store arbitrary DSyncModel objects, even if it doesn't know about them at definition time. generic_dsync.add(generic_dsync_model) with pytest.raises(ObjectAlreadyExists): generic_dsync.add(generic_dsync_model) + +def test_dsync_get_with_generic_model(generic_dsync, generic_dsync_model): + generic_dsync.add(generic_dsync_model) # The generic_dsync_model has an empty identifier/unique-id assert generic_dsync.get(DSyncModel, "") == generic_dsync_model # DSync doesn't know what a "dsyncmodel" is @@ -44,11 +61,17 @@ def test_generic_dsync_methods(generic_dsync, generic_dsync_model): # Wrong unique-id - no match assert generic_dsync.get(DSyncModel, "myname") is None + +def test_dsync_get_all_with_generic_model(generic_dsync, generic_dsync_model): + generic_dsync.add(generic_dsync_model) assert list(generic_dsync.get_all(DSyncModel)) == [generic_dsync_model] assert list(generic_dsync.get_all(DSyncModel.get_type())) == [generic_dsync_model] # Wrong object-type - no match assert list(generic_dsync.get_all("anything")) == [] + +def test_dsync_get_by_uids_with_generic_model(generic_dsync, generic_dsync_model): + generic_dsync.add(generic_dsync_model) assert generic_dsync.get_by_uids([""], DSyncModel) == [generic_dsync_model] assert generic_dsync.get_by_uids([""], DSyncModel.get_type()) == [generic_dsync_model] # Wrong unique-id - no match @@ -56,6 +79,9 @@ def test_generic_dsync_methods(generic_dsync, generic_dsync_model): # Valid unique-id mixed in with unknown ones - return the successful matches? assert generic_dsync.get_by_uids(["aname", "", "anothername"], DSyncModel) == [generic_dsync_model] + +def test_dsync_remove_with_generic_model(generic_dsync, generic_dsync_model): + generic_dsync.add(generic_dsync_model) generic_dsync.remove(generic_dsync_model) with pytest.raises(ObjectNotFound): generic_dsync.remove(generic_dsync_model) @@ -64,14 +90,6 @@ def test_generic_dsync_methods(generic_dsync, generic_dsync_model): assert list(generic_dsync.get_all(DSyncModel)) == [] assert generic_dsync.get_by_uids([""], DSyncModel) == [] - diff_elements = generic_dsync._diff_objects( # pylint: disable=protected-access - [generic_dsync_model], [generic_dsync_model], generic_dsync, - ) - assert len(diff_elements) == 1 - assert not diff_elements[0].has_diffs() - assert diff_elements[0].source_name == "DSync" - assert diff_elements[0].dest_name == "DSync" - def test_dsync_subclass_validation(): """Test the declaration-time checks on a DSync subclass.""" @@ -88,69 +106,49 @@ class BadElementName(DSync): assert "dev_class" in str(excinfo.value) -def check_diff_symmetry(diff1, diff2): - """Recursively compare two Diffs to make sure they are equal and opposite to one another.""" - assert len(list(diff1.get_children())) == len(list(diff2.get_children())) - for elem1, elem2 in zip(sorted(diff1.get_children()), sorted(diff2.get_children())): - # Same basic properties - assert elem1.type == elem2.type - assert elem1.name == elem2.name - assert elem1.keys == elem2.keys - assert elem1.has_diffs() == elem2.has_diffs() - # Opposite diffs, if any - assert elem1.source_attrs == elem2.dest_attrs - assert elem1.dest_attrs == elem2.source_attrs - check_diff_symmetry(elem1.child_diff, elem2.child_diff) - - -def test_dsync_subclass_methods_diff_sync(backend_a, backend_b): - """Test DSync diff/sync APIs on an actual concrete subclass.""" - diff_elements = backend_a._diff_objects( # pylint: disable=protected-access - source=backend_b.get_all("site"), dest=backend_a.get_all("site"), source_root=backend_b - ) - assert len(diff_elements) == 4 # atl, nyc, sfo, rdu - for diff_element in diff_elements: - diff_element.print_detailed() - assert diff_element.has_diffs() - # We don't inspect the contents of the diff elements in detail here - see test_diff_element.py for that - +def test_dsync_diff_self_with_data_has_no_diffs(backend_a): # Self diff should always show no diffs! assert backend_a.diff_from(backend_a).has_diffs() is False assert backend_a.diff_to(backend_a).has_diffs() is False - diff_ab = backend_a.diff_to(backend_b) - assert diff_ab.has_diffs() is True - diff_ba = backend_a.diff_from(backend_b) - assert diff_ba.has_diffs() is True + +def test_dsync_diff_other_with_data_has_diffs(backend_a, backend_b): + assert backend_a.diff_to(backend_b).has_diffs() is True + assert backend_a.diff_from(backend_b).has_diffs() is True + + +def test_dsync_diff_to_and_diff_from_are_symmetric(backend_a, backend_b): + diff_ab = backend_a.diff_from(backend_b) + diff_ba = backend_a.diff_to(backend_b) + + def check_diff_symmetry(diff1, diff2): + """Recursively compare two Diffs to make sure they are equal and opposite to one another.""" + assert len(list(diff1.get_children())) == len(list(diff2.get_children())) + for elem1, elem2 in zip(sorted(diff1.get_children()), sorted(diff2.get_children())): + # Same basic properties + assert elem1.type == elem2.type + assert elem1.name == elem2.name + assert elem1.keys == elem2.keys + assert elem1.has_diffs() == elem2.has_diffs() + # Opposite diffs, if any + assert elem1.source_attrs == elem2.dest_attrs + assert elem1.dest_attrs == elem2.source_attrs + check_diff_symmetry(elem1.child_diff, elem2.child_diff) check_diff_symmetry(diff_ab, diff_ba) - # Perform sync of one subtree of diffs - backend_a._sync_from_diff_element(diff_elements[0]) # pylint: disable=protected-access - # Make sure the sync descended through the diff element all the way to the leafs - assert backend_a.get(Interface, "nyc-spine1__eth0").description == "Interface 0/0" # was initially Interface 0 - # Recheck diffs - diff_elements = backend_a._diff_objects( # pylint: disable=protected-access - source=backend_a.get_all("site"), dest=backend_b.get_all("site"), source_root=backend_b - ) - for diff_element in diff_elements: - diff_element.print_detailed() - assert len(diff_elements) == 4 # atl, nyc, sfo, rdu - assert not diff_elements[0].has_diffs() # sync completed, no diffs - assert diff_elements[1].has_diffs() - assert diff_elements[2].has_diffs() - assert diff_elements[3].has_diffs() +def test_dsync_diff_from_with_custom_diff_class(backend_a, backend_b): + diff_ba = backend_a.diff_from(backend_b, diff_class=TrackedDiff) + assert isinstance(diff_ba, TrackedDiff) + assert diff_ba.is_complete is True + + +def test_dsync_sync_from(backend_a, backend_b): # Perform full sync backend_a.sync_from(backend_b) # Make sure the sync descended through the diff elements to their children assert backend_a.get(Device, "sfo-spine1").role == "leaf" # was initially "spine" - # Recheck diffs, using a custom Diff subclass this time. - diff_ba = backend_a.diff_from(backend_b, diff_class=TrackedDiff) - assert isinstance(diff_ba, TrackedDiff) - assert diff_ba.is_complete is True - diff_ba.print_detailed() - assert diff_ba.has_diffs() is False # site_nyc and site_sfo should be updated, site_atl should be created, site_rdu should be deleted site_nyc_a = backend_a.get(Site, "nyc") @@ -175,21 +173,17 @@ def test_dsync_subclass_methods_diff_sync(backend_a, backend_b): assert backend_a.get_by_uids(["nyc", "sfo"], "device") == [] -def test_dsync_subclass_methods_name_type(backend_a, backend_b): - """Test DSync name and type an actual concrete subclass. - - backend_a is using the default name and type - backend_b is using a user defined name and type - """ +def test_dsync_subclass_default_name_type(backend_a): assert backend_a.name == "BackendA" assert backend_a.type == "BackendA" + +def test_dsync_subclass_custom_name_type(backend_b): assert backend_b.name == "backend-b" assert backend_b.type == "Backend_B" -def test_dsync_subclass_methods_crud(backend_a): - """Test DSync CRUD APIs against a concrete subclass.""" +def test_dsync_add_get_remove_with_subclass_and_data(backend_a): site_nyc_a = backend_a.get(Site, "nyc") site_sfo_a = backend_a.get("site", "sfo") site_rdu_a = backend_a.get(Site, "rdu") @@ -212,11 +206,12 @@ def test_dsync_subclass_methods_crud(backend_a): backend_a.remove(site_atl_a) -def test_dsync_subclass_methods_sync_exceptions(log, error_prone_backend_a, backend_b): - """Test handling of exceptions during a sync.""" +def test_dsync_sync_from_exceptions_are_not_caught_by_default(error_prone_backend_a, backend_b): with pytest.raises(ObjectCrudException): error_prone_backend_a.sync_from(backend_b) + +def test_dsync_sync_from_with_continue_on_failure_flag(log, error_prone_backend_a, backend_b): error_prone_backend_a.sync_from(backend_b, flags=DSyncFlags.CONTINUE_ON_FAILURE) # Not all sync operations succeeded on the first try remaining_diffs = error_prone_backend_a.diff_from(backend_b) @@ -255,45 +250,95 @@ def test_dsync_subclass_methods_sync_exceptions(log, error_prone_backend_a, back pytest.fail("Sync was still incomplete after 10 retries") -def test_dsync_subclass_methods_diff_sync_skip_flags(): - """Test diff and sync behavior when using the SKIP_UNMATCHED_* flags.""" - baseline = BackendA() - baseline.load() +def test_dsync_diff_with_skip_unmatched_src_flag(backend_a, backend_a_with_extra_models, backend_a_minus_some_models): + assert backend_a.diff_from(backend_a_with_extra_models).has_diffs() + # SKIP_UNMATCHED_SRC should mean that extra models in the src are not flagged for creation in the dest + assert not backend_a.diff_from(backend_a_with_extra_models, flags=DSyncFlags.SKIP_UNMATCHED_SRC).has_diffs() + # SKIP_UNMATCHED_SRC should NOT mean that extra models in the dst are not flagged for deletion in the src + assert backend_a.diff_from(backend_a_minus_some_models, flags=DSyncFlags.SKIP_UNMATCHED_SRC).has_diffs() - extra_models = BackendA() - extra_models.load() - extra_site = extra_models.site(name="lax") - extra_models.add(extra_site) - extra_device = extra_models.device(name="nyc-spine3", site_name="nyc", role="spine") - extra_models.get(extra_models.site, "nyc").add_child(extra_device) - extra_models.add(extra_device) - missing_models = BackendA() - missing_models.load() - missing_models.remove(missing_models.get(missing_models.site, "rdu")) - missing_device = missing_models.get(missing_models.device, "sfo-spine2") - missing_models.get(missing_models.site, "sfo").remove_child(missing_device) - missing_models.remove(missing_device) +def test_dsync_diff_with_skip_unmatched_dst_flag(backend_a, backend_a_with_extra_models, backend_a_minus_some_models): + assert backend_a.diff_from(backend_a_minus_some_models).has_diffs() + # SKIP_UNMATCHED_DST should mean that missing models in the src are not flagged for deletion from the dest + assert not backend_a.diff_from(backend_a_minus_some_models, flags=DSyncFlags.SKIP_UNMATCHED_DST).has_diffs() + # SKIP_UNMATCHED_DST should NOT mean that extra models in the src are not flagged for creation in the dest + assert backend_a.diff_from(backend_a_with_extra_models, flags=DSyncFlags.SKIP_UNMATCHED_DST).has_diffs() - assert baseline.diff_from(extra_models).has_diffs() - assert baseline.diff_to(missing_models).has_diffs() - # SKIP_UNMATCHED_SRC should mean that extra models in the src are not flagged for creation in the dest - assert not baseline.diff_from(extra_models, flags=DSyncFlags.SKIP_UNMATCHED_SRC).has_diffs() - # SKIP_UNMATCHED_DST should mean that missing models in the src are not flagged for deletion from the dest - assert not baseline.diff_from(missing_models, flags=DSyncFlags.SKIP_UNMATCHED_DST).has_diffs() - # SKIP_UNMATCHED_BOTH means, well, both - assert not extra_models.diff_from(missing_models, flags=DSyncFlags.SKIP_UNMATCHED_BOTH).has_diffs() - assert not extra_models.diff_to(missing_models, flags=DSyncFlags.SKIP_UNMATCHED_BOTH).has_diffs() +def test_dsync_diff_with_skip_unmatched_both_flag(backend_a, backend_a_with_extra_models, backend_a_minus_some_models): + # SKIP_UNMATCHED_BOTH should mean that extra models in the src are not flagged for creation in the dest + assert not backend_a.diff_from(backend_a_with_extra_models, flags=DSyncFlags.SKIP_UNMATCHED_BOTH).has_diffs() + # SKIP_UNMATCHED_BOTH should mean that missing models in the src are not flagged for deletion from the dest + assert not backend_a.diff_from(backend_a_minus_some_models, flags=DSyncFlags.SKIP_UNMATCHED_BOTH).has_diffs() + - baseline.sync_from(extra_models, flags=DSyncFlags.SKIP_UNMATCHED_SRC) +def test_dsync_sync_with_skip_unmatched_src_flag(backend_a, backend_a_with_extra_models): + backend_a.sync_from(backend_a_with_extra_models, flags=DSyncFlags.SKIP_UNMATCHED_SRC) # New objects should not have been created - assert baseline.get(baseline.site, "lax") is None - assert baseline.get(baseline.device, "nyc-spine3") is None - assert "nyc-spine3" not in baseline.get(baseline.site, "nyc").devices + assert backend_a.get(backend_a.site, "lax") is None + assert backend_a.get(backend_a.device, "nyc-spine3") is None + assert "nyc-spine3" not in backend_a.get(backend_a.site, "nyc").devices - baseline.sync_from(missing_models, flags=DSyncFlags.SKIP_UNMATCHED_DST) + +def test_dsync_sync_with_skip_unmatched_dst_flag(backend_a, backend_a_minus_some_models): + backend_a.sync_from(backend_a_minus_some_models, flags=DSyncFlags.SKIP_UNMATCHED_DST) # Objects should not have been deleted - assert baseline.get(baseline.site, "rdu") is not None - assert baseline.get(baseline.device, "sfo-spine2") is not None - assert "sfo-spine2" in baseline.get(baseline.site, "sfo").devices + assert backend_a.get(backend_a.site, "rdu") is not None + assert backend_a.get(backend_a.device, "sfo-spine2") is not None + assert "sfo-spine2" in backend_a.get(backend_a.site, "sfo").devices + + +def test_dsync_diff_with_ignore_flag_on_source_models(backend_a, backend_a_with_extra_models): + # Directly ignore the extra source site + backend_a_with_extra_models.get(backend_a_with_extra_models.site, "lax").model_flags |= DSyncModelFlags.IGNORE + # Ignore any diffs on source site NYC, which should extend to its child nyc-spine3 device + backend_a_with_extra_models.get(backend_a_with_extra_models.site, "nyc").model_flags |= DSyncModelFlags.IGNORE + + diff = backend_a.diff_from(backend_a_with_extra_models) + diff.print_detailed() + assert not diff.has_diffs() + + +def test_dsync_diff_with_ignore_flag_on_target_models(backend_a, backend_a_minus_some_models): + # Directly ignore the extra target site + backend_a.get(backend_a.site, "rdu").model_flags |= DSyncModelFlags.IGNORE + # Ignore any diffs on target site SFO, which should extend to its child sfo-spine2 device + backend_a.get(backend_a.site, "sfo").model_flags |= DSyncModelFlags.IGNORE + + diff = backend_a.diff_from(backend_a_minus_some_models) + diff.print_detailed() + assert not diff.has_diffs() + + +def test_dsync_sync_skip_children_on_delete(backend_a): + class NoDeleteInterface(Interface): + """Interface that shouldn't be deleted directly.""" + + def delete(self): + raise RuntimeError("Don't delete me, bro!") + + class NoDeleteInterfaceDSync(BackendA): + """BackendA, but using NoDeleteInterface.""" + + interface = NoDeleteInterface + + extra_models = NoDeleteInterfaceDSync() + extra_models.load() + extra_device = extra_models.device(name="nyc-spine3", site_name="nyc", role="spine") + extra_device.model_flags |= DSyncModelFlags.SKIP_CHILDREN_ON_DELETE + extra_models.get(extra_models.site, "nyc").add_child(extra_device) + extra_models.add(extra_device) + extra_interface = extra_models.interface(name="eth0", device_name="nyc-spine3") + extra_device.add_child(extra_interface) + extra_models.add(extra_interface) + assert extra_models.get(extra_models.interface, "nyc-spine3__eth0") is not None + + # NoDeleteInterface.delete() should not be called since we're deleting its parent only + extra_models.sync_from(backend_a) + # The extra interface should have been removed from the DSync without calling its delete() method + assert extra_models.get(extra_models.interface, extra_interface.get_unique_id()) is None + # The sync should be complete, regardless + diff = extra_models.diff_from(backend_a) + diff.print_detailed() + assert not diff.has_diffs()