diff --git a/netbox_diode_plugin/api/common.py b/netbox_diode_plugin/api/common.py
index e0152d5..9bcb6b2 100644
--- a/netbox_diode_plugin/api/common.py
+++ b/netbox_diode_plugin/api/common.py
@@ -2,16 +2,17 @@
 # Copyright 2025 NetBox Labs Inc
 """Diode NetBox Plugin - API - Common types and utilities."""
 
-from collections import defaultdict
 import logging
 import uuid
+from collections import defaultdict
 from dataclasses import dataclass, field
 from enum import Enum
 
 from django.apps import apps
-from django.contrib.contenttypes.fields import GenericRelation, GenericForeignKey
+from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
 from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import ValidationError
+from django.db import models
 from rest_framework import status
 
 logger = logging.getLogger("netbox.diode_data")
@@ -108,31 +109,7 @@ def validate(self) -> dict[str, list[str]]:
             if change.before:
                 change_data.update(change.before)
 
-            # check that there is some value for every required
-            # reference field, but don't validate the actual reference.
-            excluded_relation_fields = []
-            rel_errors = defaultdict(list)
-            for f in model._meta.get_fields():
-                if isinstance(f, (GenericRelation, GenericForeignKey)):
-                    excluded_relation_fields.append(f.name)
-                    continue
-                if not f.is_relation:
-                    continue
-                field_name = f.name
-                excluded_relation_fields.append(field_name)
-
-                if hasattr(f, "related_model") and f.related_model == ContentType:
-                    change_data.pop(field_name, None)
-                    base_field = field_name[:-5]
-                    excluded_relation_fields.append(base_field + "_id")
-                    value = change_data.pop(base_field + "_id", None)
-                else:
-                    value = change_data.pop(field_name, None)
-
-                if not f.null and not f.blank and not f.many_to_many:
-                    # this field is a required relation...
-                    if value is None:
-                        rel_errors[f.name].append(f"Field {f.name} is required")
+            excluded_relation_fields, rel_errors = self._validate_relations(change_data, model)
 
             if rel_errors:
                 errors[change.object_type] = rel_errors
@@ -144,6 +121,36 @@ def validate(self) -> dict[str, list[str]]:
 
         return errors or None
 
+    def _validate_relations(self, change_data: dict, model: models.Model) -> tuple[list[str], dict]:
+        # check that there is some value for every required
+        # reference field, but don't validate the actual reference.
+        # the fields are removed from the change_data so that other
+        # fields can be validated by instantiating the model.
+        excluded_relation_fields = []
+        rel_errors = defaultdict(list)
+        for f in model._meta.get_fields():
+            if isinstance(f, (GenericRelation, GenericForeignKey)):
+                excluded_relation_fields.append(f.name)
+                continue
+            if not f.is_relation:
+                continue
+            field_name = f.name
+            excluded_relation_fields.append(field_name)
+
+            if hasattr(f, "related_model") and f.related_model == ContentType:
+                change_data.pop(field_name, None)
+                base_field = field_name[:-5]
+                excluded_relation_fields.append(base_field + "_id")
+                value = change_data.pop(base_field + "_id", None)
+            else:
+                value = change_data.pop(field_name, None)
+
+            if not f.null and not f.blank and not f.many_to_many:
+                # this field is a required relation...
+                if value is None:
+                    rel_errors[f.name].append(f"Field {f.name} is required")
+        return excluded_relation_fields, rel_errors
+
 
 @dataclass
 class ChangeSetResult:
diff --git a/netbox_diode_plugin/api/matcher.py b/netbox_diode_plugin/api/matcher.py
index 5f098c0..e4c8e62 100644
--- a/netbox_diode_plugin/api/matcher.py
+++ b/netbox_diode_plugin/api/matcher.py
@@ -9,6 +9,7 @@
 from typing import Type
 
 from core.models import ObjectType as NetBoxType
+from django.conf import settings
 from django.contrib.contenttypes.fields import ContentType
 from django.core.exceptions import FieldDoesNotExist
 from django.db import models
@@ -30,11 +31,44 @@
 _LOGICAL_MATCHERS = {
     "dcim.macaddress": lambda: [
         ObjectMatchCriteria(
-            # consider a matching mac address within the same parent object
-            # to be the same object although not technically required to be.
             fields=("mac_address", "assigned_object_type", "assigned_object_id"),
             name="logical_mac_address_within_parent",
             model_class=get_object_type_model("dcim.macaddress"),
+            condition=Q(assigned_object_id__isnull=False),
+        ),
+        ObjectMatchCriteria(
+            fields=("mac_address", "assigned_object_type", "assigned_object_id"),
+            name="logical_mac_address_within_parent",
+            model_class=get_object_type_model("dcim.macaddress"),
+            condition=Q(assigned_object_id__isnull=True),
+        ),
+    ],
+    "ipam.ipaddress": lambda: [
+        ObjectMatchCriteria(
+            fields=("address", ),
+            name="logical_ip_address_global_no_vrf",
+            model_class=get_object_type_model("ipam.ipaddress"),
+            condition=Q(vrf__isnull=True),
+        ),
+        ObjectMatchCriteria(
+            fields=("address", "assigned_object_type", "assigned_object_id"),
+            name="logical_ip_address_within_vrf",
+            model_class=get_object_type_model("ipam.ipaddress"),
+            condition=Q(vrf__isnull=False)
+        ),
+    ],
+    "ipam.prefix": lambda: [
+        ObjectMatchCriteria(
+            fields=("prefix",),
+            name="logical_prefix_global_no_vrf",
+            model_class=get_object_type_model("ipam.prefix"),
+            condition=Q(vrf__isnull=True),
+        ),
+        ObjectMatchCriteria(
+            fields=("prefix", "vrf_id"),
+            name="logical_prefix_within_vrf",
+            model_class=get_object_type_model("ipam.prefix"),
+            condition=Q(vrf__isnull=False),
         ),
     ],
 }
@@ -404,22 +438,3 @@ def find_existing_object(data: dict, object_type: str):
         logger.error(f" -> No object found for matcher {matcher.name}")
     logger.error(" * No matchers found an existing object")
     return None
-
-def merge_data(a: dict, b: dict) -> dict:
-    """
-    Merges two structures.
-
-    If there are any conflicts, an error is raised.
-    Ignores conflicts in fields that start with an underscore,
-    preferring a's value.
-    """
-    if a is None or b is None:
-        raise ValueError("Cannot merge None values")
-    merged = a.copy()
-    for k, v in b.items():
-        if k.startswith("_"):
-            continue
-        if k in merged and merged[k] != v:
-            raise ValueError(f"Conflict merging {a} and {b} on {k}: {merged[k]} and {v}")
-        merged[k] = v
-    return merged
diff --git a/netbox_diode_plugin/api/transformer.py b/netbox_diode_plugin/api/transformer.py
index 5f0e699..12e3518 100644
--- a/netbox_diode_plugin/api/transformer.py
+++ b/netbox_diode_plugin/api/transformer.py
@@ -10,11 +10,12 @@
 from functools import lru_cache
 from uuid import uuid4
 
+import graphlib
 from django.core.exceptions import ValidationError
 from django.utils.text import slugify
 
-from .common import UnresolvedReference
-from .matcher import find_existing_object, fingerprint, merge_data
+from .common import ChangeSetException, UnresolvedReference
+from .matcher import find_existing_object, fingerprint
 from .plugin_utils import get_json_ref_info, get_primary_value
 
 logger = logging.getLogger("netbox.diode_data")
@@ -53,6 +54,9 @@ def _nested_context(object_type, uuid, field_name):
 _IS_CIRCULAR_REFERENCE = {
     "dcim.interface": frozenset(["primary_mac_address"]),
     "virtualization.vminterface": frozenset(["primary_mac_address"]),
+    "dcim.device": frozenset(["primary_ip4", "primary_ip6"]),
+    "dcim.virtualdevicecontext": frozenset(["primary_ip4", "primary_ip6"]),
+    "virtualization.virtualmachine": frozenset(["primary_ip4", "primary_ip6"]),
 }
 
 def _is_circular_reference(object_type, field_name):
@@ -66,38 +70,52 @@ def transform_proto_json(proto_json: dict, object_type: str, supported_models: d
     a certain form of deduplication and resolution of existing objects.
     """
     entities = _transform_proto_json_1(proto_json, object_type)
-    logger.error(f"_transform_proto_json_1: {json.dumps(entities, default=lambda o: str(o), indent=4)}")
+    logger.error(f"_transform_proto_json_1 entities: {json.dumps(entities, default=lambda o: str(o), indent=4)}")
+    entities = _topo_sort(entities)
+    logger.error(f"_topo_sort: {json.dumps(entities, default=lambda o: str(o), indent=4)}")
     deduplicated = _fingerprint_dedupe(entities)
     logger.error(f"_fingerprint_dedupe: {json.dumps(deduplicated, default=lambda o: str(o), indent=4)}")
+    deduplicated = _topo_sort(deduplicated)
+    logger.error(f"_topo_sort: {json.dumps(deduplicated, default=lambda o: str(o), indent=4)}")
     _set_slugs(deduplicated, supported_models)
     logger.error(f"_set_slugs: {json.dumps(deduplicated, default=lambda o: str(o), indent=4)}")
     resolved = _resolve_existing_references(deduplicated)
     logger.error(f"_resolve_references: {json.dumps(resolved, default=lambda o: str(o), indent=4)}")
     _set_defaults(resolved, supported_models)
     logger.error(f"_set_defaults: {json.dumps(resolved, default=lambda o: str(o), indent=4)}")
+
+    # handle post-create steps
     output = _handle_post_creates(resolved)
-    logger.error(f"_merge_post_creates: {json.dumps(output, default=lambda o: str(o), indent=4)}")
+    logger.error(f"_handle_post_creates: {json.dumps(output, default=lambda o: str(o), indent=4)}")
     _check_unresolved_refs(output)
+    for entity in output:
+        entity.pop('_refs', None)
+
     return output
 
 
-def _transform_proto_json_1(proto_json: dict, object_type: str, context=None, existing=None) -> list[dict]:
+def _transform_proto_json_1(proto_json: dict, object_type: str, context=None) -> list[dict]:  # noqa: C901
     uuid = str(uuid4())
-    transformed = {
+    node = {
         "_object_type": object_type,
         "_uuid": uuid,
+        "_refs": set(),
     }
+
+    # context pushed down from parent nodes
     if context is not None:
-        transformed.update(context)
-        existing = existing or {}
-        entities = [transformed]
+        for k, v in context.items():
+            node[k] = v
+            if isinstance(v, UnresolvedReference):
+                node['_refs'].add(v.uuid)
 
-    post_create = {}
+    nodes = [node]
+    post_create = None
     for key, value in proto_json.items():
         ref_info = get_json_ref_info(object_type, key)
         if ref_info is None:
-            transformed[_camel_to_snake_case(key)] = copy.deepcopy(value)
+            node[_camel_to_snake_case(key)] = copy.deepcopy(value)
             continue
 
         nested_context = _nested_context(object_type, uuid, ref_info.field_name)
@@ -105,50 +123,74 @@ def _transform_proto_json_1(proto_json: dict, object_type: str, context=None, ex
         is_circular = _is_circular_reference(object_type, field_name)
 
         if ref_info.is_generic:
-            transformed[field_name + "_type"] = ref_info.object_type
+            node[field_name + "_type"] = ref_info.object_type
             field_name = field_name + "_id"
 
-        nested_refs = []
+        refs = []
         ref_value = None
         if isinstance(value, list):
            ref_value = []
            for item in value:
                nested = _transform_proto_json_1(item, ref_info.object_type, nested_context)
-                nested_refs += nested
-                ref = nested[-1]
+                nodes += nested
+                ref_uuid = nested[0]['_uuid']
                ref_value.append(UnresolvedReference(
                    object_type=ref_info.object_type,
-                    uuid=ref['_uuid'],
+                    uuid=ref_uuid,
                ))
+                refs.append(ref_uuid)
         else:
-            nested_refs = _transform_proto_json_1(value, ref_info.object_type, nested_context)
-            ref = nested_refs[-1]
+            nested = _transform_proto_json_1(value, ref_info.object_type, nested_context)
+            nodes += nested
+            ref_uuid = nested[0]['_uuid']
            ref_value = UnresolvedReference(
                object_type=ref_info.object_type,
-                uuid=ref['_uuid'],
+                uuid=ref_uuid,
            )
+            refs.append(ref_uuid)
+
         if is_circular:
+            if post_create is None:
+                post_create = {
+                    "_uuid": str(uuid4()),
+                    "_object_type": object_type,
+                    "_refs": set(),
+                    "_instance": node['_uuid'],
+                    "_is_post_create": True,
+                }
             post_create[field_name] = ref_value
-            entities = entities + nested_refs
-        else:
-            transformed[field_name] = ref_value
-            entities = nested_refs + entities
-
-    # if there are fields that must be deferred until after the object is created,
-    # add a new entity with the post-create data. eg a child object that references
-    # this object and is also referenced by this object such as primary mac address
-    # on an interface.
-    # if this object already exists, two steps are not needed, and this will be
-    # simplified in a later pass.
-    if len(post_create) > 0:
-        post_create_uuid = str(uuid4())
-        post_create['_uuid'] = post_create_uuid
-        post_create['_instance'] = uuid
-        post_create['_object_type'] = object_type
-        transformed['_post_create'] = post_create_uuid
-        entities.append(post_create)
-
-    return entities
+            post_create['_refs'].update(refs)
+            post_create['_refs'].add(node['_uuid'])
+            continue
+
+        node[field_name] = ref_value
+        node['_refs'].update(refs)
+
+    if post_create:
+        nodes.append(post_create)
+
+    return nodes
+
+
+def _topo_sort(entities: list[dict]) -> list[dict]:
+    """Topologically sort entities by reference."""
+    by_uuid = {e['_uuid']: e for e in entities}
+    graph = defaultdict(set)
+    for entity in entities:
+        graph[entity['_uuid']] = entity['_refs'].copy()
+
+    try:
+        ts = graphlib.TopologicalSorter(graph)
+        order = tuple(ts.static_order())
+        return [by_uuid[uuid] for uuid in order]
+    except graphlib.CycleError as e:
+        # TODO the cycle error references the cycle here ...
+        raise ChangeSetException(f"Circular reference in entities: {e}", errors={
+            "__all__": {
+                "message": "Unable to resolve circular reference in entities",
+            }
+        })
+
 
 def _set_defaults(entities: list[dict], supported_models: dict):
     for entity in entities:
@@ -178,13 +220,23 @@ def _generate_slug(object_type, data):
     return None
 
 
 def _fingerprint_dedupe(entities: list[dict]) -> list[dict]:
+    """
+    Deduplicates/merges entities by fingerprint.
+
+    *list must be in topo order by reference already*
+    """
     by_fp = {}
     deduplicated = []
     new_refs = {} # uuid -> uuid
     for entity in entities:
-        fp = fingerprint(entity, entity['_object_type'])
-        existing = by_fp.get(fp)
+        if entity.get('_is_post_create'):
+            fp = entity['_uuid']
+            existing = None
+        else:
+            fp = fingerprint(entity, entity['_object_type'])
+            existing = by_fp.get(fp)
+
         if existing is None:
             logger.debug(" * entity is new.")
             new_entity = copy.deepcopy(entity)
@@ -194,13 +246,39 @@ def _fingerprint_dedupe(entities: list[dict]) -> list[dict]:
         else:
             logger.debug(" * entity already exists.")
             new_refs[entity['_uuid']] = existing['_uuid']
-            merged = merge_data(existing, entity)
+            merged = _merge_nodes(existing, entity)
             _update_unresolved_refs(merged, new_refs)
             by_fp[fp] = merged
 
     return [by_fp[fp] for fp in deduplicated]
 
 
+def _merge_nodes(a: dict, b: dict) -> dict:
+    """
+    Merges two nodes.
+
+    If there are any conflicts, an error is raised.
+    Ignores conflicts in fields that start with an underscore,
+    preferring a's value.
+    """
+    merged = copy.deepcopy(a)
+    merged['_refs'] = a['_refs'] | b['_refs']
+
+    for k, v in b.items():
+        if k.startswith("_"):
+            continue
+        if k in merged and merged[k] != v:
+            raise ValueError(f"Conflict merging {a} and {b} on {k}: {merged[k]} and {v}")
+        merged[k] = v
+    return merged
+
+
 def _update_unresolved_refs(entity, new_refs):
+    if entity.get('_is_post_create'):
+        instance_uuid = entity['_instance']
+        entity['_instance'] = new_refs.get(instance_uuid, instance_uuid)
+
+    entity['_refs'] = {new_refs.get(r,r) for r in entity['_refs']}
+
     for k, v in entity.items():
         if isinstance(v, UnresolvedReference) and v.uuid in new_refs:
             v.uuid = new_refs[v.uuid]
@@ -274,27 +352,31 @@ def cleanup_unresolved_references(data: dict) -> list[str]:
 def _handle_post_creates(entities: list[dict]) -> list[str]:
     """Merges any unnecessary post-create steps for existing objects."""
-    by_uuid = {x['_uuid']: x for x in entities}
+    by_uuid = {e['_uuid']: (i, e) for i, e in enumerate(entities)}
     out = []
     for entity in entities:
-        post_create = entity.pop('_post_create', None)
-        if post_create is None:
+        is_post_create = entity.pop('_is_post_create', False)
+        if not is_post_create:
             out.append(entity)
             continue
 
-        post_create = by_uuid[post_create]
-        if entity.get('_instance') is not None:
-            # this entity has a post-create, but it has already been
-            # created. in this case we can just merge this entity into
-            # the post-create entity and skip it without worrying about
-            # references to it.
-            post_create.update(entity)
+        instance = entity.get('_instance')
+        prior_index, prior_entity = by_uuid[instance]
+
+        # a post create can be merged whenever the entities it relies on
+        # already exist (were resolved) or there are no dependencies between
+        # the object being updated and the post-create.
+        can_merge = all(
+            by_uuid[r][1].get('_instance') is not None
+            for r in entity['_refs']
+        ) or sorted(by_uuid[r][0] for r in entity['_refs'])[-1] == prior_index
+
+        if can_merge:
+            prior_entity.update([x for x in entity.items() if not x[0].startswith('_')])
         else:
-            # this entity will be created.
-            # in this case we need to fix up the identifier in the post-create
-            # to refer to the created object.
-            post_create['id'] = entity['id']
+            entity['id'] = prior_entity['id']
             out.append(entity)
+
     return out
 
 
 def _check_unresolved_refs(entities: list[dict]) -> list[str]:
@@ -304,4 +386,11 @@
         for k, v in e.items():
             if isinstance(v, UnresolvedReference):
                 if (v.object_type, v.uuid) not in seen:
-                    raise ValueError(f"Unresolved reference {v} in {e} does not refer to a prior created object (circular reference?)")
+                    raise ChangeSetException(
+                        f"Unresolved reference {v} in {e} does not refer to a prior created object (circular reference?)",
+                        errors={
+                            e['_object_type']: {
+                                k: ["unable to resolve reference"],
+                            }
+                        }
+                    )
diff --git a/netbox_diode_plugin/tests/test_api_diff_and_apply.py b/netbox_diode_plugin/tests/test_api_diff_and_apply.py
index 19793c5..75bec8a 100644
--- a/netbox_diode_plugin/tests/test_api_diff_and_apply.py
+++ b/netbox_diode_plugin/tests/test_api_diff_and_apply.py
@@ -4,8 +4,9 @@
 import logging
 
-from dcim.models import Interface, Site
+from dcim.models import Device, Interface, Site
 from django.contrib.auth import get_user_model
+from ipam.models import IPAddress
 from rest_framework import status
 from users.models import Token
 from utilities.testing import APITestCase
@@ -81,6 +82,45 @@ def test_generate_diff_and_apply_create_interface_with_primay_mac_address(self):
         new_interface = Interface.objects.get(name="Interface 1x")
         self.assertEqual(new_interface.primary_mac_address.mac_address, "00:00:00:00:00:01")
 
+    def test_generate_diff_and_apply_create_device_with_primary_ip4(self):
+        """Test generate diff and apply create device with primary ip4."""
+        payload = {
+            "timestamp": 1,
+            "object_type": "ipam.ipaddress",
+            "entity": {
+                "ipAddress": {
+                    "address": "192.168.1.1",
+                    "assignedObjectInterface": {
+                        "name": "Interface 2x",
+                        "type": "1000base-t",
+                        "device": {
+                            "name": "Device 2x",
+                            "role": {
+                                "name": "Role ABC",
+                            },
+                            "site": {
+                                "name": "Site ABC",
+                            },
+                            "deviceType": {
+                                "manufacturer": {
+                                    "name": "Manufacturer A",
+                                },
+                                "model": "Device Type A",
+                            },
+                            "primaryIp4": {
+                                "address": "192.168.1.1",
+                            },
+                        },
+                    },
+                },
+            },
+        }
+
+        _, response = self.diff_and_apply(payload)
+        new_ipaddress = IPAddress.objects.get(address="192.168.1.1")
+        self.assertEqual(new_ipaddress.assigned_object.name, "Interface 2x")
+        device = Device.objects.get(name="Device 2x")
+        self.assertEqual(device.primary_ip4.pk, new_ipaddress.pk)
+
     def diff_and_apply(self, payload):
         """Diff and apply the payload."""
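Aside for reviewers (not part of the patch): the ordering that _fingerprint_dedupe and _handle_post_creates now depend on comes from the new _topo_sort, which feeds each entity's _refs set into the standard-library graphlib. A minimal standalone sketch of that behaviour, assuming only the entity-dict shape used in transformer.py (_uuid, _object_type, _refs) and hypothetical UUID values:

    import graphlib

    # Three hypothetical entities: an IP that references an interface,
    # which in turn references a device.
    entities = [
        {"_uuid": "ip-1", "_object_type": "ipam.ipaddress", "_refs": {"iface-1"}},
        {"_uuid": "iface-1", "_object_type": "dcim.interface", "_refs": {"dev-1"}},
        {"_uuid": "dev-1", "_object_type": "dcim.device", "_refs": set()},
    ]

    # The graph maps each node to its predecessors (its references), so
    # static_order() yields dependencies before the objects that use them.
    graph = {e["_uuid"]: e["_refs"] for e in entities}
    order = tuple(graphlib.TopologicalSorter(graph).static_order())
    print(order)  # ('dev-1', 'iface-1', 'ip-1')

A genuine cycle would surface as graphlib.CycleError, which _topo_sort converts into a ChangeSetException; fields listed in _IS_CIRCULAR_REFERENCE (such as a device's primary_ip4) are instead split out into the post-create entities so that the create-order graph stays acyclic.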