diff --git a/netbox_diode_plugin/api/applier.py b/netbox_diode_plugin/api/applier.py index eed793b..101f30f 100644 --- a/netbox_diode_plugin/api/applier.py +++ b/netbox_diode_plugin/api/applier.py @@ -4,7 +4,6 @@ import logging -from dataclasses import dataclass, field from django.apps import apps from django.contrib.contenttypes.models import ContentType @@ -12,47 +11,14 @@ from django.db import models from rest_framework.exceptions import ValidationError as ValidationError -from .differ import Change, ChangeSet, ChangeType +from .common import NON_FIELD_ERRORS, Change, ChangeSet, ChangeSetException, ChangeSetResult, ChangeType from .plugin_utils import get_object_type_model, legal_fields from .supported_models import get_serializer_for_model logger = logging.getLogger(__name__) -@dataclass -class ApplyChangeSetResult: - """A result of applying a change set.""" - - id: str - success: bool - errors: dict | None = field(default=None) - - def to_dict(self) -> dict: - """Convert the result to a dictionary.""" - return { - "id": self.id, - "success": self.success, - "errors": self.errors, - } - - -class ApplyChangeSetException(Exception): - """ApplyChangeSetException is raised when an error occurs while applying a change set.""" - - def __init__(self, message, errors=None): - """Initialize the exception.""" - super().__init__(message) - self.message = message - self.errors = errors or {} - - def __str__(self): - """Return the string representation of the exception.""" - if self.errors: - return f"{self.message}: {self.errors}" - return self.message - - -def apply_changeset(change_set: ChangeSet) -> ApplyChangeSetResult: +def apply_changeset(change_set: ChangeSet) -> ChangeSetResult: """Apply a change set.""" _validate_change_set(change_set) @@ -71,14 +37,12 @@ def apply_changeset(change_set: ChangeSet) -> ApplyChangeSetResult: except ValidationError as e: raise _err_from_validation_error(e, f"changes[{i}]") except ObjectDoesNotExist: - raise _err(f"{object_type} with id {change.object_id} does not exist", f"changes[{i}].object_id") + raise _err(f"{object_type} with id {change.object_id} does not exist", f"changes[{i}]", "object_id") # ConstraintViolationError ? # ... - return ApplyChangeSetResult( + return ChangeSetResult( id=change_set.id, - success=True, - errors=None, ) def _apply_change(data: dict, model_class: models.Model, change: Change, created: dict): @@ -129,27 +93,30 @@ def _pre_apply(model_class: models.Model, change: Change, created: dict): def _validate_change_set(change_set: ChangeSet): if not change_set.id: - raise _err("Change set ID is required", "id") + raise _err("Change set ID is required", "changeset","id") if not change_set.changes: - raise _err("Changes are required", "changes") + raise _err("Changes are required", "changeset", "changes") for i, change in enumerate(change_set.changes): if change.object_id is None and change.ref_id is None: - raise _err("Object ID or Ref ID must be provided", f"changes[{i}]") + raise _err("Object ID or Ref ID must be provided", f"changes[{i}]", NON_FIELD_ERRORS) if change.change_type not in ChangeType: - raise _err(f"Unsupported change type '{change.change_type}'", f"changes[{i}].change_type") + raise _err(f"Unsupported change type '{change.change_type}'", f"changes[{i}]", "change_type") -def _err(message, field): - return ApplyChangeSetException(message, errors={field: [message]}) +def _err(message, object_name, field): + return ChangeSetException(message, errors={object_name: {field: [message]}}) -def _err_from_validation_error(e, prefix): +def _err_from_validation_error(e, object_name): errors = {} if e.detail: if isinstance(e.detail, dict): - for k, v in e.detail.items(): - errors[f"{prefix}.{k}"] = v + errors[object_name] = e.detail elif isinstance(e.detail, (list, tuple)): - errors[prefix] = e.detail + errors[object_name] = { + NON_FIELD_ERRORS: e.detail + } else: - errors[prefix] = [e.detail] - return ApplyChangeSetException("validation error", errors=errors) + errors[object_name] = { + NON_FIELD_ERRORS: [e.detail] + } + return ChangeSetException("validation error", errors=errors) diff --git a/netbox_diode_plugin/api/common.py b/netbox_diode_plugin/api/common.py index 41011b4..e0152d5 100644 --- a/netbox_diode_plugin/api/common.py +++ b/netbox_diode_plugin/api/common.py @@ -1,9 +1,22 @@ #!/usr/bin/env python -# Copyright 2024 NetBox Labs Inc +# Copyright 2025 NetBox Labs Inc """Diode NetBox Plugin - API - Common types and utilities.""" -from dataclasses import dataclass +from collections import defaultdict +import logging +import uuid +from dataclasses import dataclass, field +from enum import Enum +from django.apps import apps +from django.contrib.contenttypes.fields import GenericRelation, GenericForeignKey +from django.contrib.contenttypes.models import ContentType +from django.core.exceptions import ValidationError +from rest_framework import status + +logger = logging.getLogger("netbox.diode_data") + +NON_FIELD_ERRORS = "__all__" @dataclass class UnresolvedReference: @@ -29,3 +42,143 @@ def __hash__(self): def __lt__(self, other): """Less than operator.""" return self.object_type < other.object_type or (self.object_type == other.object_type and self.uuid < other.uuid) + + +class ChangeType(Enum): + """Change type enum.""" + + CREATE = "create" + UPDATE = "update" + NOOP = "noop" + + +@dataclass +class Change: + """A change to a model instance.""" + + change_type: ChangeType + object_type: str + object_id: int | None = field(default=None) + object_primary_value: str | None = field(default=None) + ref_id: str | None = field(default=None) + id: str = field(default_factory=lambda: str(uuid.uuid4())) + before: dict | None = field(default=None) + data: dict | None = field(default=None) + new_refs: list[str] = field(default_factory=list) + + def to_dict(self) -> dict: + """Convert the change to a dictionary.""" + return { + "id": self.id, + "change_type": self.change_type.value, + "object_type": self.object_type, + "object_id": self.object_id, + "ref_id": self.ref_id, + "object_primary_value": self.object_primary_value, + "before": self.before, + "data": self.data, + "new_refs": self.new_refs, + } + + +@dataclass +class ChangeSet: + """A set of changes to a model instance.""" + + id: str = field(default_factory=lambda: str(uuid.uuid4())) + changes: list[Change] = field(default_factory=list) + branch: dict[str, str] | None = field(default=None) # {"id": str, "name": str} + + def to_dict(self) -> dict: + """Convert the change set to a dictionary.""" + return { + "id": self.id, + "changes": [change.to_dict() for change in self.changes], + "branch": self.branch, + } + + def validate(self) -> dict[str, list[str]]: + """Validate basics of the change set data.""" + errors = defaultdict(dict) + + for change in self.changes: + model = apps.get_model(change.object_type) + + change_data = change.data.copy() + if change.before: + change_data.update(change.before) + + # check that there is some value for every required + # reference field, but don't validate the actual reference. + excluded_relation_fields = [] + rel_errors = defaultdict(list) + for f in model._meta.get_fields(): + if isinstance(f, (GenericRelation, GenericForeignKey)): + excluded_relation_fields.append(f.name) + continue + if not f.is_relation: + continue + field_name = f.name + excluded_relation_fields.append(field_name) + + if hasattr(f, "related_model") and f.related_model == ContentType: + change_data.pop(field_name, None) + base_field = field_name[:-5] + excluded_relation_fields.append(base_field + "_id") + value = change_data.pop(base_field + "_id", None) + else: + value = change_data.pop(field_name, None) + + if not f.null and not f.blank and not f.many_to_many: + # this field is a required relation... + if value is None: + rel_errors[f.name].append(f"Field {f.name} is required") + if rel_errors: + errors[change.object_type] = rel_errors + + try: + instance = model(**change_data) + instance.clean_fields(exclude=excluded_relation_fields) + except ValidationError as e: + errors[change.object_type].update(e.error_dict) + + return errors or None + + +@dataclass +class ChangeSetResult: + """A result of applying a change set.""" + + id: str | None = field(default_factory=lambda: str(uuid.uuid4())) + change_set: ChangeSet | None = field(default=None) + errors: dict | None = field(default=None) + + def to_dict(self) -> dict: + """Convert the result to a dictionary.""" + if self.change_set: + return self.change_set.to_dict() + + return { + "id": self.id, + "errors": self.errors, + } + + def get_status_code(self) -> int: + """Get the status code for the result.""" + return status.HTTP_200_OK if not self.errors else status.HTTP_400_BAD_REQUEST + + +class ChangeSetException(Exception): + """ChangeSetException is raised when an error occurs while generating or applying a change set.""" + + def __init__(self, message, errors=None): + """Initialize the exception.""" + super().__init__(message) + self.message = message + self.errors = errors or {} + + def __str__(self): + """Return the string representation of the exception.""" + if self.errors: + return f"{self.message}: {self.errors}" + return self.message diff --git a/netbox_diode_plugin/api/differ.py b/netbox_diode_plugin/api/differ.py index 84b6848..e44ecab 100644 --- a/netbox_diode_plugin/api/differ.py +++ b/netbox_diode_plugin/api/differ.py @@ -1,18 +1,15 @@ #!/usr/bin/env python -# Copyright 2024 NetBox Labs Inc +# Copyright 2025 NetBox Labs Inc """Diode NetBox Plugin - API - Differ.""" import copy -import json import logging -import uuid -from dataclasses import dataclass, field -from enum import Enum from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ValidationError from utilities.data import shallow_compare_dict +from .common import Change, ChangeSet, ChangeSetException, ChangeSetResult, ChangeType from .plugin_utils import get_primary_value, legal_fields from .supported_models import extract_supported_models from .transformer import cleanup_unresolved_references, transform_proto_json @@ -21,58 +18,6 @@ SUPPORTED_MODELS = extract_supported_models() -class ChangeType(Enum): - """Change type enum.""" - - CREATE = "create" - UPDATE = "update" - NOOP = "noop" - - -@dataclass -class Change: - """A change to a model instance.""" - - change_type: ChangeType - object_type: str - object_id: int | None = field(default=None) - object_primary_value: str | None = field(default=None) - ref_id: str | None = field(default=None) - id: str = field(default_factory=lambda: str(uuid.uuid4())) - before: dict | None = field(default=None) - data: dict | None = field(default=None) - new_refs: list[str] = field(default_factory=list) - - def to_dict(self) -> dict: - """Convert the change to a dictionary.""" - return { - "id": self.id, - "change_type": self.change_type.value, - "object_type": self.object_type, - "object_id": self.object_id, - "ref_id": self.ref_id, - "object_primary_value": self.object_primary_value, - "before": self.before, - "data": self.data, - "new_refs": self.new_refs, - } - - -@dataclass -class ChangeSet: - """A set of changes to a model instance.""" - - id: str = field(default_factory=lambda: str(uuid.uuid4())) - changes: list[Change] = field(default_factory=list) - branch: dict[str, str] | None = field(default=None) # {"id": str, "name": str} - - def to_dict(self) -> dict: - """Convert the change set to a dictionary.""" - return { - "id": self.id, - "changes": [change.to_dict() for change in self.changes], - "branch": self.branch, - } def prechange_data_from_instance(instance) -> dict: # noqa: C901 """Convert model instance data to a dictionary format for comparison.""" @@ -193,7 +138,7 @@ def sort_dict_recursively(d): return d -def generate_changeset(entity: dict, object_type: str) -> ChangeSet: +def generate_changeset(entity: dict, object_type: str) -> ChangeSetResult: """Generate a changeset for an entity.""" change_set = ChangeSet() @@ -227,5 +172,11 @@ def generate_changeset(entity: dict, object_type: str) -> ChangeSet: new_refs, ) change_set.changes.append(change) - logger.error(f"change_set: {json.dumps(change_set.to_dict(), default=str, indent=4)}") - return change_set + + if errors := change_set.validate(): + raise ChangeSetException("Invalid change set", errors) + + return ChangeSetResult( + id=change_set.id, + change_set=change_set, + ) diff --git a/netbox_diode_plugin/api/matcher.py b/netbox_diode_plugin/api/matcher.py index 7d6973d..5f098c0 100644 --- a/netbox_diode_plugin/api/matcher.py +++ b/netbox_diode_plugin/api/matcher.py @@ -120,7 +120,7 @@ def fingerprint(self, data: dict) -> str|None: values = [] for field in sorted_fields: value = data[field] - if isinstance(value, (dict, UnresolvedReference)): + if isinstance(value, dict): logger.warning(f"unexpected value type for fingerprinting: {value}") return None if field in insensitive: @@ -232,13 +232,11 @@ def _prepare_data(self, data: dict) -> dict: if field.is_relation and hasattr(field, "related_model") and field.related_model == ContentType: prepared[field_name] = content_type_id(value) else: - logger.error("no.") prepared[field_name] = value - logger.error(f"field: {field_name} -> {value}") except FieldDoesNotExist: continue - logger.error(f"prepared data: {data} -> {prepared}") + # logger.error(f"prepared data: {data} -> {prepared}") return prepared @lru_cache(maxsize=256) diff --git a/netbox_diode_plugin/api/serializers.py b/netbox_diode_plugin/api/serializers.py index 838f8d3..60e2860 100644 --- a/netbox_diode_plugin/api/serializers.py +++ b/netbox_diode_plugin/api/serializers.py @@ -2,131 +2,10 @@ # Copyright 2024 NetBox Labs Inc """Diode NetBox Plugin - Serializers.""" -import logging - -from dcim.api.serializers import ( - DeviceRoleSerializer, - DeviceSerializer, - DeviceTypeSerializer, - InterfaceSerializer, - ManufacturerSerializer, - PlatformSerializer, - SiteSerializer, -) -from django.conf import settings from netbox.api.serializers import NetBoxModelSerializer -from packaging import version from netbox_diode_plugin.models import Setting -if version.parse(version.parse(settings.VERSION).base_version) >= version.parse("4.1"): - from core.models import ObjectChange -else: - from extras.models import ObjectChange -from ipam.api.serializers import IPAddressSerializer, PrefixSerializer -from rest_framework import serializers -from utilities.api import get_serializer_for_model -from virtualization.api.serializers import ( - ClusterGroupSerializer, - ClusterSerializer, - ClusterTypeSerializer, - VirtualDiskSerializer, - VirtualMachineSerializer, - VMInterfaceSerializer, -) - -logger = logging.getLogger("netbox.netbox_diode_plugin.api.serializers") - - -def dynamic_import(name): - """Dynamically import a class from an absolute path string.""" - components = name.split(".") - mod = __import__(components[0]) - for comp in components[1:]: - mod = getattr(mod, comp) - return mod - - -def get_diode_serializer(instance): - """Get the Diode serializer based on instance model.""" - serializer = get_serializer_for_model(instance) - - serializer_name = f"netbox_diode_plugin.api.serializers.Diode{serializer.__name__}" - - try: - serializer = dynamic_import(serializer_name) - except AttributeError: - logger.warning(f"Could not find serializer for {serializer_name}") - pass - - return serializer - - -class ObjectStateSerializer(serializers.Serializer): - """Object State Serializer.""" - - object_type = serializers.SerializerMethodField(read_only=True) - object_change_id = serializers.SerializerMethodField(read_only=True) - object = serializers.SerializerMethodField(read_only=True) - - def get_object_type(self, instance): - """ - Get the object type from context sent from view. - - Return a string with the format "app.model". - """ - return self.context.get("object_type") - - def get_object_change_id(self, instance): - """ - Get the object changed based on instance ID. - - Return the ID of last change. - """ - object_changed = ( - ObjectChange.objects.filter(changed_object_id=instance.id) - .order_by("-id") - .values_list("id", flat=True) - ) - return object_changed[0] if len(object_changed) > 0 else None - - def get_object(self, instance): - """ - Get the serializer based on instance model. - - Get the data from the model according to its ID. - Return the object according to serializer defined in the NetBox. - """ - serializer = get_diode_serializer(instance) - - object_data = instance.__class__.objects.filter(id=instance.id) - - context = {"request": self.context.get("request")} - - data = serializer(object_data, context=context, many=True).data[0] - - return data - - -class ChangeSerialiazer(serializers.Serializer): - """ChangeSet Serializer.""" - - change_id = serializers.UUIDField(required=True) - change_type = serializers.CharField(required=True) - object_version = serializers.IntegerField(required=False, allow_null=True) - object_type = serializers.CharField(required=True) - object_id = serializers.IntegerField(required=False, allow_null=True) - data = serializers.DictField(required=True) - - -class ApplyChangeSetRequestSerializer(serializers.Serializer): - """ApplyChangeSet request Serializer.""" - - change_set_id = serializers.UUIDField(required=True) - change_set = serializers.ListField( - child=ChangeSerialiazer(), required=True, allow_empty=False - ) - class SettingSerializer(NetBoxModelSerializer): """Setting Serializer.""" @@ -142,250 +21,3 @@ class Meta: "created", "last_updated", ) - - -class DiodeIPAddressSerializer(IPAddressSerializer): - """Diode IP Address Serializer.""" - - class Meta: - """Meta class.""" - - model = IPAddressSerializer.Meta.model - fields = IPAddressSerializer.Meta.fields - - def get_assigned_object(self, obj): - """Get the assigned object based on the instance model.""" - if obj.assigned_object is None: - return None - - serializer = get_diode_serializer(obj.assigned_object) - - context = {"request": self.context["request"]} - assigned_object = serializer(obj.assigned_object, context=context).data - - if assigned_object.get("device"): - device_serializer = get_diode_serializer(obj.assigned_object.device) - device = device_serializer(obj.assigned_object.device, context=context).data - assigned_object["device"] = device - - if serializer.__name__.endswith("InterfaceSerializer"): - assigned_object = {"interface": assigned_object} - - return assigned_object - - -class DiodeSiteSerializer(SiteSerializer): - """Diode Site Serializer.""" - - status = serializers.CharField() - - class Meta: - """Meta class.""" - - model = SiteSerializer.Meta.model - fields = SiteSerializer.Meta.fields - - -class DiodeDeviceRoleSerializer(DeviceRoleSerializer): - """Diode Device Role Serializer.""" - - class Meta: - """Meta class.""" - - model = DeviceRoleSerializer.Meta.model - fields = DeviceRoleSerializer.Meta.fields - - -class DiodeManufacturerSerializer(ManufacturerSerializer): - """Diode Manufacturer Serializer.""" - - class Meta: - """Meta class.""" - - model = ManufacturerSerializer.Meta.model - fields = ManufacturerSerializer.Meta.fields - - -class DiodePlatformSerializer(PlatformSerializer): - """Diode Platform Serializer.""" - - manufacturer = DiodeManufacturerSerializer(required=False, allow_null=True) - - class Meta: - """Meta class.""" - - model = PlatformSerializer.Meta.model - fields = PlatformSerializer.Meta.fields - - -class DiodeDeviceTypeSerializer(DeviceTypeSerializer): - """Diode Device Type Serializer.""" - - default_platform = DiodePlatformSerializer(required=False, allow_null=True) - manufacturer = DiodeManufacturerSerializer(required=False, allow_null=True) - - class Meta: - """Meta class.""" - - model = DeviceTypeSerializer.Meta.model - fields = DeviceTypeSerializer.Meta.fields - - -class DiodeDeviceSerializer(DeviceSerializer): - """Diode Device Serializer.""" - - site = DiodeSiteSerializer() - device_type = DiodeDeviceTypeSerializer() - role = DiodeDeviceRoleSerializer() - platform = DiodePlatformSerializer(required=False, allow_null=True) - status = serializers.CharField() - - class Meta: - """Meta class.""" - - model = DeviceSerializer.Meta.model - fields = DeviceSerializer.Meta.fields - - -class DiodeNestedInterfaceSerializer(InterfaceSerializer): - """Diode Nested Interface Serializer.""" - - class Meta: - """Meta class.""" - - model = InterfaceSerializer.Meta.model - fields = InterfaceSerializer.Meta.fields - - -class DiodeInterfaceSerializer(InterfaceSerializer): - """Diode Interface Serializer.""" - - device = DiodeDeviceSerializer() - parent = DiodeNestedInterfaceSerializer() - type = serializers.CharField() - mode = serializers.CharField() - - class Meta: - """Meta class.""" - - model = InterfaceSerializer.Meta.model - fields = InterfaceSerializer.Meta.fields - - -class DiodePrefixSerializer(PrefixSerializer): - """Diode Prefix Serializer.""" - - status = serializers.CharField() - site = serializers.SerializerMethodField(read_only=True) - - class Meta: - """Meta class.""" - - model = PrefixSerializer.Meta.model - fields = PrefixSerializer.Meta.fields + ["site"] - - def get_site(self, obj): - """Get the site from the instance scope.""" - if obj.scope is None: - return None - - scope_model_meta = obj.scope_type.model_class()._meta - if scope_model_meta.app_label == "dcim" and scope_model_meta.model_name == "site": - serializer = get_serializer_for_model(obj.scope) - context = {'request': self.context['request']} - return serializer(obj.scope, nested=True, context=context).data - - return None - - -class DiodeClusterGroupSerializer(ClusterGroupSerializer): - """Diode Cluster Group Serializer.""" - - class Meta: - """Meta class.""" - - model = ClusterGroupSerializer.Meta.model - fields = ClusterGroupSerializer.Meta.fields - - -class DiodeClusterTypeSerializer(ClusterTypeSerializer): - """Diode Cluster Type Serializer.""" - - class Meta: - """Meta class.""" - - model = ClusterTypeSerializer.Meta.model - fields = ClusterTypeSerializer.Meta.fields - - -class DiodeClusterSerializer(ClusterSerializer): - """Diode Cluster Serializer.""" - - type = DiodeClusterTypeSerializer() - group = DiodeClusterGroupSerializer() - status = serializers.CharField() - site = serializers.SerializerMethodField(read_only=True) - - class Meta: - """Meta class.""" - - model = ClusterSerializer.Meta.model - fields = ClusterSerializer.Meta.fields + ["site"] - - def get_site(self, obj): - """Get the site from the instance scope.""" - if obj.scope is None: - return None - - scope_model_meta = obj.scope_type.model_class()._meta - if scope_model_meta.app_label == "dcim" and scope_model_meta.model_name == "site": - serializer = get_serializer_for_model(obj.scope) - context = {'request': self.context['request']} - return serializer(obj.scope, nested=True, context=context).data - - return None - - -class DiodeVirtualMachineSerializer(VirtualMachineSerializer): - """Diode Virtual Machine Serializer.""" - - status = serializers.CharField() - site = DiodeSiteSerializer() - cluster = DiodeClusterSerializer() - device = DiodeDeviceSerializer() - role = DiodeDeviceRoleSerializer() - tenant = serializers.CharField() - platform = DiodePlatformSerializer() - primary_ip = DiodeIPAddressSerializer() - primary_ip4 = DiodeIPAddressSerializer() - primary_ip6 = DiodeIPAddressSerializer() - - class Meta: - """Meta class.""" - - model = VirtualMachineSerializer.Meta.model - fields = VirtualMachineSerializer.Meta.fields - - -class DiodeVirtualDiskSerializer(VirtualDiskSerializer): - """Diode Virtual Disk Serializer.""" - - virtual_machine = DiodeVirtualMachineSerializer() - - class Meta: - """Meta class.""" - - model = VirtualDiskSerializer.Meta.model - fields = VirtualDiskSerializer.Meta.fields - - -class DiodeVMInterfaceSerializer(VMInterfaceSerializer): - """Diode VM Interface Serializer.""" - - virtual_machine = DiodeVirtualMachineSerializer() - - class Meta: - """Meta class.""" - - model = VMInterfaceSerializer.Meta.model - fields = VMInterfaceSerializer.Meta.fields diff --git a/netbox_diode_plugin/api/views.py b/netbox_diode_plugin/api/views.py index 5539db0..5f6d004 100644 --- a/netbox_diode_plugin/api/views.py +++ b/netbox_diode_plugin/api/views.py @@ -7,17 +7,14 @@ from django.apps import apps from django.db import transaction -from rest_framework import status, views +from rest_framework import views from rest_framework.exceptions import ValidationError from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response -from netbox_diode_plugin.api.applier import ( - ApplyChangeSetException, - ApplyChangeSetResult, - apply_changeset, -) -from netbox_diode_plugin.api.differ import Change, ChangeSet, ChangeType, generate_changeset +from netbox_diode_plugin.api.applier import apply_changeset +from netbox_diode_plugin.api.common import Change, ChangeSet, ChangeSetException, ChangeSetResult +from netbox_diode_plugin.api.differ import generate_changeset from netbox_diode_plugin.api.permissions import IsDiodeWriter logger = logging.getLogger("netbox.diode_data") @@ -77,7 +74,14 @@ def _post(self, request, *args, **kwargs): f"No data found for {entity_key} in entity got: {entity.keys()}" ) - change_set = generate_changeset(original_entity_data, object_type) + try: + result = generate_changeset(original_entity_data, object_type) + except ChangeSetException as e: + logger.error(f"Error generating change set: {e}") + result = ChangeSetResult( + errors=e.errors, + ) + return Response(result.to_dict(), status=result.get_status_code()) branch_id = request.headers.get("X-NetBox-Branch") @@ -85,12 +89,11 @@ def _post(self, request, *args, **kwargs): if branch_id and Branch is not None: try: branch = Branch.objects.get(id=branch_id) - change_set.branch = {"id": branch.id, "name": branch.name} + result.branch = {"id": branch.id, "name": branch.name} except Branch.DoesNotExist: logger.warning(f"Branch with ID {branch_id} does not exist") - logger.info(f"change_set: {json.dumps(change_set.to_dict(), default=str)}") - return Response(change_set.to_dict(), status=status.HTTP_200_OK) + return Response(result.to_dict(), status=result.get_status_code()) class ApplyChangeSetView(views.APIView): @@ -131,25 +134,11 @@ def _post(self, request, *args, **kwargs): try: with transaction.atomic(): result = apply_changeset(change_set) - except ApplyChangeSetException as e: + except ChangeSetException as e: logger.error(f"Error applying change set: {e}") - result = ApplyChangeSetResult( + result = ChangeSetResult( id=change_set.id, - success=False, errors=e.errors, ) - return Response(result.to_dict(), status=status.HTTP_400_BAD_REQUEST) - - return Response(result.to_dict(), status=status.HTTP_200_OK) - - @staticmethod - def _get_error_response(change_set_id, errors): - """Get the error response.""" - return Response( - { - "change_set_id": change_set_id, - "result": "failed", - "errors": errors, - }, - status=status.HTTP_400_BAD_REQUEST, - ) + + return Response(result.to_dict(), status=result.get_status_code()) diff --git a/netbox_diode_plugin/tests/test_api_apply_change_set.py b/netbox_diode_plugin/tests/test_api_apply_change_set.py index 8fe4d15..b2d27c0 100644 --- a/netbox_diode_plugin/tests/test_api_apply_change_set.py +++ b/netbox_diode_plugin/tests/test_api_apply_change_set.py @@ -29,6 +29,8 @@ User = get_user_model() +def _get_error(response, object_name, field): + return response.json().get("errors", {}).get(object_name, {}).get(field, []) class BaseApplyChangeSet(APITestCase): """Base ApplyChangeSet test case.""" @@ -232,9 +234,7 @@ def test_change_type_create_return_200(self): ], } - response = self.send_request(payload) - - self.assertEqual(response.json().get("success"), True) + _ = self.send_request(payload) def test_change_type_update_return_200(self): """Test update change_type with successful.""" @@ -261,13 +261,12 @@ def test_change_type_update_return_200(self): ], } - response = self.client.post( + _ = self.client.post( self.url, payload, format="json", **self.user_header ) site_updated = Site.objects.get(id=20) - self.assertEqual(response.json().get("success"), True) self.assertEqual(site_updated.name, "Site A") def test_change_type_create_with_error_return_400(self): @@ -297,13 +296,11 @@ def test_change_type_create_with_error_return_400(self): } response = self.send_request(payload, status_code=status.HTTP_400_BAD_REQUEST) - site_created = Site.objects.filter(name="Site A") - self.assertEqual(response.json().get("success"), False) self.assertIn( 'Expected a list of items but got type "int".', - response.json().get("errors", {}).get("changes[0].asns", []), + _get_error(response, "changes[0]", "asns"), ) self.assertFalse(site_created.exists()) @@ -335,11 +332,9 @@ def test_change_type_update_with_error_return_400(self): response = self.send_request(payload, status_code=status.HTTP_400_BAD_REQUEST) site_updated = Site.objects.get(id=20) - - self.assertEqual(response.json().get("success"), False) self.assertIn( 'Expected a list of items but got type "int".', - response.json().get("errors", {}).get("changes[0].asns", []), + _get_error(response, "changes[0]", "asns") ) self.assertEqual(site_updated.name, "Site 2") @@ -385,9 +380,7 @@ def test_change_type_create_with_multiples_objects_return_200(self): ], } - response = self.send_request(payload) - - self.assertEqual(response.json().get("success"), True) + _ = self.send_request(payload) def test_change_type_update_with_multiples_objects_return_200(self): """Test update change type with two objects.""" @@ -429,12 +422,11 @@ def test_change_type_update_with_multiples_objects_return_200(self): ], } - response = self.send_request(payload) + _ = self.send_request(payload) site_updated = Site.objects.get(id=20) device_updated = Device.objects.get(id=10) - self.assertEqual(response.json().get("success"), True) self.assertEqual(site_updated.name, "Site A") self.assertEqual(device_updated.name, "Test Device 3") @@ -484,10 +476,9 @@ def test_change_type_create_and_update_with_error_in_one_object_return_400(self) site_created = Site.objects.filter(name="Site Z") device_created = Device.objects.filter(name="Test Device 4") - self.assertEqual(response.json().get("success"), False) self.assertIn( "Related object not found using the provided numeric ID: 3", - response.json().get("errors", {}).get("changes[1].device_type", []), + _get_error(response, "changes[1]", "device_type"), ) self.assertFalse(site_created.exists()) self.assertFalse(device_created.exists()) @@ -555,11 +546,9 @@ def test_multiples_create_type_error_in_two_objects_return_400(self): site_created = Site.objects.filter(name="Site Z") device_created = Device.objects.filter(name="Test Device 4") - self.assertEqual(response.json().get("success"), False) - self.assertIn( "Related object not found using the provided numeric ID: 3", - response.json().get("errors", {}).get("changes[1].device_type", []), + _get_error(response, "changes[1]", "device_type"), ) self.assertFalse(site_created.exists()) @@ -598,7 +587,7 @@ def test_change_type_update_with_object_id_not_exist_return_400(self): self.assertIn( "dcim.site with id 30 does not exist", - response.json().get("errors", {}).get("changes[0].object_id", []), + _get_error(response, "changes[0]", "object_id"), ) self.assertEqual(site_updated.name, "Site 2") @@ -630,9 +619,9 @@ def test_change_set_id_field_not_provided_return_400(self): response = self.send_request(payload, status_code=status.HTTP_400_BAD_REQUEST) self.assertIsNone(response.json().get("errors", {}).get("change_id", None)) - self.assertEqual( - response.json().get("errors", {}).get("id", []), - ["Change set ID is required"], + self.assertIn( + "Change set ID is required", + _get_error(response, "changeset", "id"), ) def test_change_type_field_not_provided_return_400( @@ -666,7 +655,7 @@ def test_change_type_field_not_provided_return_400( self.assertIn( "Unsupported change type ''", - response.json().get("errors", {}).get("changes[0].change_type", []), + _get_error(response, "changes[0]", "change_type"), ) def test_change_set_id_field_and_change_set_not_provided_return_400(self): @@ -680,7 +669,7 @@ def test_change_set_id_field_and_change_set_not_provided_return_400(self): self.assertIn( "Change set ID is required", - response.json().get("errors", {}).get("id", []), + _get_error(response, "changeset", "id"), ) def test_change_type_and_object_type_provided_return_400( @@ -731,7 +720,7 @@ def test_change_type_and_object_type_provided_return_400( self.assertIn( "Unsupported change type 'None'", - response.json().get("errors", {}).get("changes[0].change_type", []), + _get_error(response, "changes[0]", "change_type"), ) # self.assertEqual( # response.json().get("errors")[0].get("change_type"), @@ -772,9 +761,7 @@ def test_create_ip_address_return_200(self): }, ], } - response = self.send_request(payload) - - self.assertEqual(response.json().get("success"), True) + _ = self.send_request(payload) # def test_create_ip_address_return_400(self): # """Test create ip_address with missing interface name.""" @@ -953,11 +940,9 @@ def test_add_primary_ip_address_to_device(self): ], } - response = self.send_request(payload) - + _ = self.send_request(payload) device_updated = Device.objects.get(id=10) - self.assertEqual(response.json().get("success"), True) self.assertEqual(device_updated.name, self.devices[0].name) self.assertEqual(device_updated.primary_ip4, self.ip_addresses[0]) @@ -981,9 +966,7 @@ def test_create_prefix_with_site_stored_as_scope(self): }, ], } - response = self.send_request(payload) - - self.assertEqual(response.json().get("success"), True) + _ = self.send_request(payload) self.assertEqual(Prefix.objects.get(prefix="192.168.0.0/24").scope, self.sites[0]) def test_create_prefix_with_unknown_site_fails(self): @@ -1007,11 +990,9 @@ def test_create_prefix_with_unknown_site_fails(self): ], } response = self.send_request(payload, status_code=status.HTTP_400_BAD_REQUEST) - - self.assertEqual(response.json().get("success"), False) self.assertIn( 'Please select a site.', - response.json().get("errors", {}).get("changes[0].scope", []), + _get_error(response, "changes[0]", "scope"), ) self.assertFalse(Prefix.objects.filter(prefix="192.168.0.0/24").exists()) @@ -1038,9 +1019,7 @@ def test_create_virtualization_cluster_with_site_stored_as_scope(self): }, ], } - response = self.send_request(payload) - - self.assertEqual(response.json().get("success"), True) + _ = self.send_request(payload) self.assertEqual(Cluster.objects.get(name="Cluster 3").scope, self.sites[0]) def test_create_virtualmachine_with_cluster_site_stored_as_scope(self): @@ -1074,7 +1053,5 @@ def test_create_virtualmachine_with_cluster_site_stored_as_scope(self): }, ], } - response = self.send_request(payload) - - self.assertEqual(response.json().get("success"), True) + _ = self.send_request(payload) self.assertEqual(VirtualMachine.objects.get(name="VM foobar", site_id=self.sites[0].id).cluster.scope, self.sites[0]) diff --git a/netbox_diode_plugin/tests/test_api_diff_and_apply.py b/netbox_diode_plugin/tests/test_api_diff_and_apply.py index 6aae051..19793c5 100644 --- a/netbox_diode_plugin/tests/test_api_diff_and_apply.py +++ b/netbox_diode_plugin/tests/test_api_diff_and_apply.py @@ -2,12 +2,16 @@ # Copyright 2024 NetBox Labs Inc """Diode NetBox Plugin - Tests.""" +import logging + from dcim.models import Interface, Site from django.contrib.auth import get_user_model from rest_framework import status from users.models import Token from utilities.testing import APITestCase +logger = logging.getLogger(__name__) + User = get_user_model() @@ -39,8 +43,6 @@ def test_generate_diff_and_apply_create_site(self): } _, response = self.diff_and_apply(payload) - self.assertEqual(response.json().get("success"), True) - new_site = Site.objects.get(name="Generate Diff and Apply Site") self.assertEqual(new_site.slug, "generate-diff-and-apply-site") @@ -76,8 +78,6 @@ def test_generate_diff_and_apply_create_interface_with_primay_mac_address(self): } _, response = self.diff_and_apply(payload) - self.assertEqual(response.json().get("success"), True) - new_interface = Interface.objects.get(name="Interface 1x") self.assertEqual(new_interface.primary_mac_address.mac_address, "00:00:00:00:00:01") diff --git a/netbox_diode_plugin/tests/test_api_serializers.py b/netbox_diode_plugin/tests/test_api_serializers.py deleted file mode 100644 index 00e9547..0000000 --- a/netbox_diode_plugin/tests/test_api_serializers.py +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env python -# Copyright 2024 NetBox Labs Inc -"""Diode NetBox Plugin - Tests.""" -from unittest.mock import MagicMock - -from dcim.models import Site -from django.test import TestCase -from extras.api.serializers import TagSerializer -from extras.models import Tag - -from netbox_diode_plugin.api.serializers import DiodeIPAddressSerializer, DiodeSiteSerializer, get_diode_serializer - - -class SerializersTestCase(TestCase): - """Test case for the serializers.""" - - def test_get_diode_serializer(self): - """Check the diode serializer is found.""" - site = Site.objects.create(name="test") - assert get_diode_serializer(site) == DiodeSiteSerializer - - tag = Tag.objects.create(name="test") - assert get_diode_serializer(tag) == TagSerializer - - - def test_get_assigned_object_returns_none_if_no_assigned_object(self): - """Check the assigned object is None if not provided.""" - obj = MagicMock() - obj.assigned_object = None - serializer = DiodeIPAddressSerializer() - result = serializer.get_assigned_object(obj) - self.assertIsNone(result)