Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions netbox_diode_plugin/api/applier.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist
from django.db import models
from django.db.utils import IntegrityError
from rest_framework.exceptions import ValidationError as ValidationError

from .common import NON_FIELD_ERRORS, Change, ChangeSet, ChangeSetException, ChangeSetResult, ChangeType, error_from_validation_error
Expand Down Expand Up @@ -41,10 +42,14 @@ def apply_changeset(change_set: ChangeSet, request) -> ChangeSetResult:
except TypeError as e:
# this indicates a problem in model validation (should raise ValidationError)
# but raised non-validation error (TypeError) -- we don't know which field trigged it.
logger.error(f"invalid data type for unspecified field (validation raised non-validation error): {data}: {e}")
raise _err("invalid data type for field", object_type, "__all__")
# ConstraintViolationError ?
# ...
import traceback
traceback.print_exc()
logger.error(f"validation raised TypeError error on unspecified field of {object_type}: {data}: {e}")
logger.error(traceback.format_exc())
raise _err("invalid data type for field (TypeError)", object_type, "__all__")
except IntegrityError as e:
logger.error(f"Integrity error {object_type}: {e} {data}")
raise _err(f"created a conflict with an existing {object_type}", object_type, "__all__")

return ChangeSetResult(
id=change_set.id,
Expand Down
34 changes: 33 additions & 1 deletion netbox_diode_plugin/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,30 @@
# Copyright 2025 NetBox Labs Inc
"""Diode NetBox Plugin - API - Common types and utilities."""

import datetime
import decimal
import logging
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum

import netaddr
from django.apps import apps
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.db import models
from django.db.backends.postgresql.psycopg_any import NumericRange
from extras.models import CustomField
from netaddr.eui import EUI
from rest_framework import status
from zoneinfo import ZoneInfo

logger = logging.getLogger("netbox.diode_data")

NON_FIELD_ERRORS = "__all__"
_TRACE = False

@dataclass
class UnresolvedReference:
Expand All @@ -43,6 +50,8 @@ def __hash__(self):

def __lt__(self, other):
"""Less than operator."""
if not isinstance(other, UnresolvedReference):
return False
return self.object_type < other.object_type or (self.object_type == other.object_type and self.uuid < other.uuid)


Expand Down Expand Up @@ -238,7 +247,7 @@ class AutoSlug:


def error_from_validation_error(e, object_name):
"""Convert a from rest_framework.exceptions.ValidationError to a ChangeSetException."""
"""Convert a from DRF ValidationError to a ChangeSetException."""
errors = {}
if e.detail:
if isinstance(e.detail, dict):
Expand All @@ -252,3 +261,26 @@ def error_from_validation_error(e, object_name):
NON_FIELD_ERRORS: [e.detail]
}
return ChangeSetException("validation error", errors=errors)

def harmonize_formats(data):
"""Puts all data in a format that can be serialized and compared."""
match data:
case None:
return None
case str() | int() | float() | bool() | decimal.Decimal() | UnresolvedReference():
return data
case dict():
return {k: harmonize_formats(v) if not k.startswith("_") else v for k, v in data.items()}
case list() | tuple():
return [harmonize_formats(v) for v in data]
case datetime.datetime():
return data.strftime("%Y-%m-%dT%H:%M:%SZ")
case datetime.date():
return data.strftime("%Y-%m-%d")
case NumericRange():
return (data.lower, data.upper-1)
case netaddr.IPNetwork() | EUI() | ZoneInfo():
return str(data)
case _:
logger.warning(f"Unknown type in harmonize_formats: {type(data)}")
return data
56 changes: 21 additions & 35 deletions netbox_diode_plugin/api/differ.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,21 @@

import netaddr
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.db.backends.postgresql.psycopg_any import NumericRange
from netaddr.eui import EUI
from rest_framework import serializers
from utilities.data import shallow_compare_dict

from .common import Change, ChangeSet, ChangeSetException, ChangeSetResult, ChangeType, error_from_validation_error
from .common import (
NON_FIELD_ERRORS,
Change,
ChangeSet,
ChangeSetException,
ChangeSetResult,
ChangeType,
error_from_validation_error,
harmonize_formats,
)
from .plugin_utils import get_primary_value, legal_fields
from .supported_models import extract_supported_models
from .transformer import cleanup_unresolved_references, set_custom_field_defaults, transform_proto_json
Expand All @@ -36,11 +45,15 @@ def prechange_data_from_instance(instance) -> dict: # noqa: C901

model = SUPPORTED_MODELS.get(object_type)
if not model:
raise ValidationError(f"Model {model_class.__name__} is not supported")
raise serializers.ValidationError({
NON_FIELD_ERRORS: [f"Model {model_class.__name__} is not supported"]
})

fields = model.get("fields", {})
if not fields:
raise ValidationError(f"Model {model_class.__name__} has no fields")
raise serializers.ValidationError({
NON_FIELD_ERRORS: [f"Model {model_class.__name__} has no fields"]
})

diode_fields = legal_fields(model_class)

Expand All @@ -52,9 +65,6 @@ def prechange_data_from_instance(instance) -> dict: # noqa: C901
if not hasattr(instance, field_name):
continue

if field_info["type"] == "ForeignKey" and field_info.get("is_many_to_one_rel", False):
continue

value = getattr(instance, field_name)
if hasattr(value, "all"): # Handle many-to-many and many-to-one relationships
# For any relationship that has an 'all' method, get all related objects' primary keys
Expand Down Expand Up @@ -82,33 +92,11 @@ def prechange_data_from_instance(instance) -> dict: # noqa: C901
else:
cfmap[cf.name] = cf.serialize(value)
prechange_data["custom_fields"] = cfmap
prechange_data = _harmonize_formats(prechange_data)
return prechange_data

prechange_data = harmonize_formats(prechange_data)

def _harmonize_formats(prechange_data):
if prechange_data is None:
return None
if isinstance(prechange_data, (str, int, float, bool, decimal.Decimal)):
return prechange_data
if isinstance(prechange_data, dict):
return {k: _harmonize_formats(v) for k, v in prechange_data.items()}
if isinstance(prechange_data, (list, tuple)):
return [_harmonize_formats(v) for v in prechange_data]
if isinstance(prechange_data, datetime.datetime):
return prechange_data.strftime("%Y-%m-%dT%H:%M:%SZ")
if isinstance(prechange_data, datetime.date):
return prechange_data.strftime("%Y-%m-%d")
if isinstance(prechange_data, NumericRange):
return (prechange_data.lower, prechange_data.upper-1)
if isinstance(prechange_data, netaddr.IPNetwork):
return str(prechange_data)
if isinstance(prechange_data, EUI):
return str(prechange_data)

logger.warning(f"Unknown type in prechange_data: {type(prechange_data)}")
return prechange_data


def clean_diff_data(data: dict, exclude_empty_values: bool = True) -> dict:
"""Clean diff data by removing null values."""
result = {}
Expand Down Expand Up @@ -139,7 +127,6 @@ def diff_to_change(
change_type = ChangeType.UPDATE if len(prechange_data) > 0 else ChangeType.CREATE
if change_type == ChangeType.UPDATE and not len(changed_attrs) > 0:
change_type = ChangeType.NOOP

primary_value = str(get_primary_value(prechange_data | postchange_data, object_type))
if primary_value is None:
primary_value = "(unnamed)"
Expand Down Expand Up @@ -173,8 +160,7 @@ def sort_dict_recursively(d):
if isinstance(d, dict):
return {k: sort_dict_recursively(v) for k, v in sorted(d.items())}
if isinstance(d, list):
# Convert all items to strings for comparison
return sorted([sort_dict_recursively(item) for item in d], key=str)
return [sort_dict_recursively(item) for item in d]
return d

def generate_changeset(entity: dict, object_type: str) -> ChangeSetResult:
Expand All @@ -183,7 +169,7 @@ def generate_changeset(entity: dict, object_type: str) -> ChangeSetResult:
return _generate_changeset(entity, object_type)
except ChangeSetException:
raise
except ValidationError as e:
except serializers.ValidationError as e:
raise error_from_validation_error(e, object_type)
except Exception as e:
logger.error(f"Unexpected error generating changeset: {e}")
Expand Down
Loading