diff --git a/netbox_diode_plugin/api/applier.py b/netbox_diode_plugin/api/applier.py index 3f1b041..0267302 100644 --- a/netbox_diode_plugin/api/applier.py +++ b/netbox_diode_plugin/api/applier.py @@ -11,7 +11,7 @@ from django.db import models from rest_framework.exceptions import ValidationError as ValidationError -from .common import NON_FIELD_ERRORS, Change, ChangeSet, ChangeSetException, ChangeSetResult, ChangeType +from .common import NON_FIELD_ERRORS, Change, ChangeSet, ChangeSetException, ChangeSetResult, ChangeType, error_from_validation_error from .plugin_utils import get_object_type_model, legal_fields from .supported_models import get_serializer_for_model @@ -35,7 +35,7 @@ def apply_changeset(change_set: ChangeSet, request) -> ChangeSetResult: data = _pre_apply(model_class, change, created) _apply_change(data, model_class, change, created, request) except ValidationError as e: - raise _err_from_validation_error(e, object_type) + raise error_from_validation_error(e, object_type) except ObjectDoesNotExist: raise _err(f"{object_type} with id {change.object_id} does not exist", object_type, "object_id") except TypeError as e: @@ -129,17 +129,3 @@ def _err(message, object_name, field): object_name = "__all__" return ChangeSetException(message, errors={object_name: {field: [message]}}) -def _err_from_validation_error(e, object_name): - errors = {} - if e.detail: - if isinstance(e.detail, dict): - errors[object_name] = e.detail - elif isinstance(e.detail, (list, tuple)): - errors[object_name] = { - NON_FIELD_ERRORS: e.detail - } - else: - errors[object_name] = { - NON_FIELD_ERRORS: [e.detail] - } - return ChangeSetException("validation error", errors=errors) diff --git a/netbox_diode_plugin/api/common.py b/netbox_diode_plugin/api/common.py index 65a9a1f..8c735a9 100644 --- a/netbox_diode_plugin/api/common.py +++ b/netbox_diode_plugin/api/common.py @@ -235,3 +235,20 @@ class AutoSlug: field_name: str value: str + + +def error_from_validation_error(e, object_name): + """Convert a drf ValidationError to a ChangeSetException.""" + errors = {} + if e.detail: + if isinstance(e.detail, dict): + errors[object_name] = e.detail + elif isinstance(e.detail, (list, tuple)): + errors[object_name] = { + NON_FIELD_ERRORS: e.detail + } + else: + errors[object_name] = { + NON_FIELD_ERRORS: [e.detail] + } + return ChangeSetException("validation error", errors=errors) diff --git a/netbox_diode_plugin/api/differ.py b/netbox_diode_plugin/api/differ.py index 026f3c9..a3121c0 100644 --- a/netbox_diode_plugin/api/differ.py +++ b/netbox_diode_plugin/api/differ.py @@ -9,8 +9,9 @@ from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ValidationError from utilities.data import shallow_compare_dict +from django.db.backends.postgresql.psycopg_any import NumericRange -from .common import Change, ChangeSet, ChangeSetException, ChangeSetResult, ChangeType, UnresolvedReference +from .common import Change, ChangeSet, ChangeSetException, ChangeSetResult, ChangeType, error_from_validation_error from .plugin_utils import get_primary_value, legal_fields from .supported_models import extract_supported_models from .transformer import cleanup_unresolved_references, set_custom_field_defaults, transform_proto_json @@ -78,29 +79,23 @@ def prechange_data_from_instance(instance) -> dict: # noqa: C901 else: cfmap[cf.name] = cf.serialize(value) prechange_data["custom_fields"] = cfmap - + prechange_data = _harmonize_formats(prechange_data) return prechange_data -def _harmonize_formats(prechange_data: dict, postchange_data: dict): - for k, v in prechange_data.items(): - if k.startswith('_'): - continue - if isinstance(v, datetime.datetime): - prechange_data[k] = v.strftime("%Y-%m-%dT%H:%M:%SZ") - elif isinstance(v, datetime.date): - prechange_data[k] = v.strftime("%Y-%m-%d") - elif isinstance(v, int) and k in postchange_data: - val = postchange_data[k] - if isinstance(val, UnresolvedReference): - continue - try: - postchange_data[k] = int(val) - except Exception: - continue - elif isinstance(v, dict): - _harmonize_formats(v, postchange_data.get(k, {})) +def _harmonize_formats(prechange_data): + if isinstance(prechange_data, dict): + return {k: _harmonize_formats(v) for k, v in prechange_data.items()} + if isinstance(prechange_data, (list, tuple)): + return [_harmonize_formats(v) for v in prechange_data] + if isinstance(prechange_data, datetime.datetime): + return prechange_data.strftime("%Y-%m-%dT%H:%M:%SZ") + if isinstance(prechange_data, datetime.date): + return prechange_data.strftime("%Y-%m-%d") + if isinstance(prechange_data, NumericRange): + return (prechange_data.lower, prechange_data.upper-1) + return prechange_data def clean_diff_data(data: dict, exclude_empty_values: bool = True) -> dict: """Clean diff data by removing null values.""" @@ -170,8 +165,19 @@ def sort_dict_recursively(d): return sorted([sort_dict_recursively(item) for item in d], key=str) return d - def generate_changeset(entity: dict, object_type: str) -> ChangeSetResult: + """Generate a changeset for an entity.""" + try: + return _generate_changeset(entity, object_type) + except ChangeSetException: + raise + except ValidationError as e: + raise error_from_validation_error(e, object_type) + except Exception as e: + logger.error(f"Unexpected error generating changeset: {e}") + raise + +def _generate_changeset(entity: dict, object_type: str) -> ChangeSetResult: """Generate a changeset for an entity.""" change_set = ChangeSet() @@ -196,7 +202,6 @@ def generate_changeset(entity: dict, object_type: str) -> ChangeSetResult: # this is also important for custom fields because they do not appear to # respsect paritial update serialization. entity = _partially_merge(prechange_data, entity, instance) - _harmonize_formats(prechange_data, entity) changed_data = shallow_compare_dict( prechange_data, entity, ) diff --git a/netbox_diode_plugin/api/plugin_utils.py b/netbox_diode_plugin/api/plugin_utils.py index 9a08f33..d85038c 100644 --- a/netbox_diode_plugin/api/plugin_utils.py +++ b/netbox_diode_plugin/api/plugin_utils.py @@ -1,16 +1,20 @@ """Diode plugin helpers.""" # Generated code. DO NOT EDIT. -# Timestamp: 2025-04-12 15:25:46Z +# Timestamp: 2025-04-13 13:20:10Z from dataclasses import dataclass +import datetime +import decimal from functools import lru_cache +import logging from typing import Type from core.models import ObjectType as NetBoxType from django.contrib.contenttypes.models import ContentType from django.db import models +logger = logging.getLogger(__name__) @lru_cache(maxsize=256) def get_object_type_model(object_type: str) -> Type[models.Model]: @@ -995,4 +999,197 @@ def legal_fields(object_type: str|Type[models.Model]) -> frozenset[str]: def get_primary_value(data: dict, object_type: str) -> str|None: field = _OBJECT_TYPE_PRIMARY_VALUE_FIELD_MAP.get(object_type, 'name') - return data.get(field) \ No newline at end of file + return data.get(field) + + +def transform_timestamp_to_date_only(value: str) -> str: + return datetime.datetime.fromisoformat(value).strftime('%Y-%m-%d') + +def transform_float_to_decimal(value: float) -> decimal.Decimal: + try: + return decimal.Decimal(str(value)) + except decimal.InvalidOperation: + raise ValueError(f'Invalid decimal value: {value}') + +def int_from_int64string(value: str) -> int: + return int(value) + +def collect_integer_pairs(value: list[int]) -> list[tuple[int, int]]: + if len(value) % 2 != 0: + raise ValueError('Array must have an even number of elements') + return [(value[i], value[i+1]) for i in range(0, len(value), 2)] + +def for_all(transform): + def wrapper(value): + if isinstance(value, list): + return [transform(v) for v in value] + return transform(value) + return wrapper + +_FORMAT_TRANSFORMATIONS = { + 'circuits.circuit': { + 'commit_rate': int_from_int64string, + 'distance': transform_float_to_decimal, + 'install_date': transform_timestamp_to_date_only, + 'termination_date': transform_timestamp_to_date_only, + }, + 'circuits.circuittermination': { + 'port_speed': int_from_int64string, + 'upstream_speed': int_from_int64string, + }, + 'dcim.cable': { + 'length': transform_float_to_decimal, + }, + 'dcim.consoleport': { + 'speed': int_from_int64string, + }, + 'dcim.consoleserverport': { + 'speed': int_from_int64string, + }, + 'dcim.device': { + 'latitude': transform_float_to_decimal, + 'longitude': transform_float_to_decimal, + 'position': transform_float_to_decimal, + 'vc_position': int_from_int64string, + 'vc_priority': int_from_int64string, + }, + 'dcim.devicetype': { + 'u_height': transform_float_to_decimal, + 'weight': transform_float_to_decimal, + }, + 'dcim.frontport': { + 'rear_port_position': int_from_int64string, + }, + 'dcim.interface': { + 'mtu': int_from_int64string, + 'rf_channel_frequency': transform_float_to_decimal, + 'rf_channel_width': transform_float_to_decimal, + 'speed': int_from_int64string, + 'tx_power': int_from_int64string, + }, + 'dcim.moduletype': { + 'weight': transform_float_to_decimal, + }, + 'dcim.powerfeed': { + 'amperage': int_from_int64string, + 'max_utilization': int_from_int64string, + 'voltage': int_from_int64string, + }, + 'dcim.powerport': { + 'allocated_draw': int_from_int64string, + 'maximum_draw': int_from_int64string, + }, + 'dcim.rack': { + 'max_weight': int_from_int64string, + 'mounting_depth': int_from_int64string, + 'outer_depth': int_from_int64string, + 'outer_width': int_from_int64string, + 'starting_unit': int_from_int64string, + 'u_height': int_from_int64string, + 'weight': transform_float_to_decimal, + 'width': int_from_int64string, + }, + 'dcim.rackreservation': { + 'units': for_all(int_from_int64string), + }, + 'dcim.racktype': { + 'max_weight': int_from_int64string, + 'mounting_depth': int_from_int64string, + 'outer_depth': int_from_int64string, + 'outer_width': int_from_int64string, + 'starting_unit': int_from_int64string, + 'u_height': int_from_int64string, + 'weight': transform_float_to_decimal, + 'width': int_from_int64string, + }, + 'dcim.rearport': { + 'positions': int_from_int64string, + }, + 'dcim.site': { + 'latitude': transform_float_to_decimal, + 'longitude': transform_float_to_decimal, + }, + 'dcim.virtualdevicecontext': { + 'identifier': int_from_int64string, + }, + 'ipam.aggregate': { + 'date_added': transform_timestamp_to_date_only, + }, + 'ipam.asn': { + 'asn': int_from_int64string, + }, + 'ipam.asnrange': { + 'end': int_from_int64string, + 'start': int_from_int64string, + }, + 'ipam.fhrpgroup': { + 'group_id': int_from_int64string, + }, + 'ipam.fhrpgroupassignment': { + 'priority': int_from_int64string, + }, + 'ipam.role': { + 'weight': int_from_int64string, + }, + 'ipam.service': { + 'ports': for_all(int_from_int64string), + }, + 'ipam.vlan': { + 'vid': int_from_int64string, + }, + 'ipam.vlangroup': { + 'vid_ranges': collect_integer_pairs, + }, + 'ipam.vlantranslationrule': { + 'local_vid': int_from_int64string, + 'remote_vid': int_from_int64string, + }, + 'virtualization.virtualdisk': { + 'size': int_from_int64string, + }, + 'virtualization.virtualmachine': { + 'disk': int_from_int64string, + 'memory': int_from_int64string, + 'vcpus': transform_float_to_decimal, + }, + 'virtualization.vminterface': { + 'mtu': int_from_int64string, + }, + 'vpn.ikepolicy': { + 'version': int_from_int64string, + }, + 'vpn.ikeproposal': { + 'group': int_from_int64string, + 'sa_lifetime': int_from_int64string, + }, + 'vpn.ipsecpolicy': { + 'pfs_group': int_from_int64string, + }, + 'vpn.ipsecproposal': { + 'sa_lifetime_data': int_from_int64string, + 'sa_lifetime_seconds': int_from_int64string, + }, + 'vpn.l2vpn': { + 'identifier': int_from_int64string, + }, + 'vpn.tunnel': { + 'tunnel_id': int_from_int64string, + }, + 'wireless.wirelesslink': { + 'distance': transform_float_to_decimal, + }, +} + +def apply_format_transformations(data: dict, object_type: str): + for key, transform in _FORMAT_TRANSFORMATIONS.get(object_type, {}).items(): + val = data.get(key, None) + if val is None: + continue + try: + data[key] = transform(val) + except ValidationError: + raise + except ValueError as e: + raise ValidationError(f'Invalid value {val} for field {key} in {object_type}: {e}') + except Exception as e: + raise ValidationError(f'Invalid value {val} for field {key} in {object_type}') \ No newline at end of file diff --git a/netbox_diode_plugin/api/transformer.py b/netbox_diode_plugin/api/transformer.py index 10234f0..830e5a7 100644 --- a/netbox_diode_plugin/api/transformer.py +++ b/netbox_diode_plugin/api/transformer.py @@ -18,7 +18,13 @@ from .common import AutoSlug, ChangeSetException, UnresolvedReference from .matcher import find_existing_object, fingerprint -from .plugin_utils import CUSTOM_FIELD_OBJECT_REFERENCE_TYPE, get_json_ref_info, get_primary_value, legal_fields +from .plugin_utils import ( + CUSTOM_FIELD_OBJECT_REFERENCE_TYPE, + apply_format_transformations, + get_json_ref_info, + get_primary_value, + legal_fields, +) logger = logging.getLogger("netbox.diode_data") @@ -72,6 +78,7 @@ def transform_proto_json(proto_json: dict, object_type: str, supported_models: d """ entities = _transform_proto_json_1(proto_json, object_type) logger.debug(f"_transform_proto_json_1 entities: {json.dumps(entities, default=lambda o: str(o), indent=4)}") + entities = _topo_sort(entities) logger.debug(f"_topo_sort: {json.dumps(entities, default=lambda o: str(o), indent=4)}") deduplicated = _fingerprint_dedupe(entities) @@ -105,6 +112,7 @@ def _transform_proto_json_1(proto_json: dict, object_type: str, context=None) -> # handle camelCase protoJSON if provided... proto_json = _ensure_snake_case(proto_json, object_type) + apply_format_transformations(proto_json, object_type) # context pushed down from parent nodes if context is not None: diff --git a/netbox_diode_plugin/tests/test_api_diff_and_apply.py b/netbox_diode_plugin/tests/test_api_diff_and_apply.py index 6303735..2367309 100644 --- a/netbox_diode_plugin/tests/test_api_diff_and_apply.py +++ b/netbox_diode_plugin/tests/test_api_diff_and_apply.py @@ -3,11 +3,14 @@ """Diode NetBox Plugin - Tests.""" import datetime +import decimal import logging from uuid import uuid4 from core.models import ObjectType from dcim.models import Device, Interface, Site +from ipam.models import VLANGroup +from circuits.models import Circuit from django.contrib.auth import get_user_model from extras.models import CustomField from extras.models.customfields import CustomFieldTypeChoices @@ -72,6 +75,15 @@ def setUp(self): self.date_field.object_types.set([self.object_type]) self.date_field.save() + self.decimal_field = CustomField.objects.create( + name='mydecimal', + type=CustomFieldTypeChoices.TYPE_DECIMAL, + required=False, + unique=False, + ) + self.decimal_field.object_types.set([self.object_type]) + self.decimal_field.save() + def test_generate_diff_and_apply_create_interface_with_tags(self): """Test generate diff and apply create interface with tags.""" interface_uuid = str(uuid4()) @@ -416,6 +428,9 @@ def test_generate_diff_and_apply_create_and_update_site_with_custom_field(self): "myuuid": { "text": site_uuid, }, + "mydecimal": { + "decimal": 1234.567, + }, "some_json": { "json": '{"some_key": 9876543210}', }, @@ -428,6 +443,7 @@ def test_generate_diff_and_apply_create_and_update_site_with_custom_field(self): new_site = Site.objects.get(name="A New Custom Site") self.assertEqual(new_site.custom_field_data[self.uuid_field.name], site_uuid) self.assertEqual(new_site.custom_field_data[self.json_field.name], {"some_key": 9876543210}) + self.assertEqual(new_site.custom_field_data[self.decimal_field.name], 1234.567) payload = { "timestamp": 1, @@ -513,8 +529,52 @@ def test_generate_diff_and_apply_create_and_update_site_with_custom_field(self): diff = response1.json().get("change_set", {}) self.assertEqual(diff.get("changes", []), []) - def test_generate_diff_wrong_type_date(self): - """Test generate diff wrong type date.""" + def test_generate_diff_and_apply_circuit_with_install_date(self): + """Test generate diff and apply circuit with date.""" + circuit_uuid = str(uuid4()) + payload = { + "timestamp": 1, + "object_type": "circuits.circuit", + "entity": { + "circuit": { + "cid": f"Circuit {circuit_uuid}", + "install_date": "2026-01-01T00:00:00Z", + "provider": { + "name": f"Provider {uuid4()}", + }, + "type": { + "name": f"Ciruit Type {uuid4()}", + }, + }, + }, + } + + _, response = self.diff_and_apply(payload) + new_circuit = Circuit.objects.get(cid=f"Circuit {circuit_uuid}") + self.assertEqual(new_circuit.install_date, datetime.date(2026, 1, 1)) + + def test_generate_diff_and_apply_site_with_lat_lon(self): + """Test generate diff and apply site with lat and lon.""" + site_uuid = str(uuid4()) + payload = { + "timestamp": 1, + "object_type": "dcim.site", + "entity": { + "site": { + "name": f"Site {site_uuid}", + "latitude": 23.456, + "longitude": 78.910, + }, + }, + } + + _, response = self.diff_and_apply(payload) + new_site = Site.objects.get(name=f"Site {site_uuid}") + self.assertEqual(new_site.latitude, decimal.Decimal("23.456")) + self.assertEqual(new_site.longitude, decimal.Decimal("78.910")) + + def test_generate_diff_and_apply_wrong_type_date(self): + """Test generate diff and apply wrong type date.""" payload = { "timestamp": 1, "object_type": "dcim.site", @@ -542,6 +602,42 @@ def test_generate_diff_wrong_type_date(self): ) self.assertEqual(response2.status_code, status.HTTP_400_BAD_REQUEST) + def test_generate_diff_and_apply_vlan_group_with_vid_ranges(self): + """Test generate diff and apply vlan group vid ranges.""" + payload = { + "timestamp": 1, + "object_type": "ipam.vlangroup", + "entity": { + "vlan_group": { + "name": "VLAN Group 1", + "vid_ranges": [1,5,10,15], + }, + }, + } + _, response = self.diff_and_apply(payload) + new_vlan_group = VLANGroup.objects.get(name="VLAN Group 1") + self.assertEqual(new_vlan_group.vid_ranges[0].lower, 1) + self.assertEqual(new_vlan_group.vid_ranges[0].upper, 6) + self.assertEqual(new_vlan_group.vid_ranges[1].lower, 10) + self.assertEqual(new_vlan_group.vid_ranges[1].upper, 16) + + payload = { + "timestamp": 1, + "object_type": "ipam.vlangroup", + "entity": { + "vlan_group": { + "name": "VLAN Group 1", + "vid_ranges": [3,9,12,20], + }, + }, + } + _, response = self.diff_and_apply(payload) + new_vlan_group = VLANGroup.objects.get(name="VLAN Group 1") + self.assertEqual(new_vlan_group.vid_ranges[0].lower, 3) + self.assertEqual(new_vlan_group.vid_ranges[0].upper, 10) + self.assertEqual(new_vlan_group.vid_ranges[1].lower, 12) + self.assertEqual(new_vlan_group.vid_ranges[1].upper, 21) + def diff_and_apply(self, payload): """Diff and apply the payload."""