diff --git a/docker/Dockerfile-diode-netbox-plugin b/docker/Dockerfile-diode-netbox-plugin index ec3f9c6..24a73fd 100644 --- a/docker/Dockerfile-diode-netbox-plugin +++ b/docker/Dockerfile-diode-netbox-plugin @@ -1,4 +1,4 @@ -FROM netboxcommunity/netbox:v4.1.11-3.0.2 +FROM netboxcommunity/netbox:v4.2.3-3.1.1 COPY ./netbox/configuration/ /etc/netbox/config/ RUN chmod 755 /etc/netbox/config/* && \ diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index c092668..c1112ab 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -1,7 +1,7 @@ name: diode-netbox-plugin services: netbox: &netbox - image: netboxcommunity/netbox:v4.1.11-3.0.2-diode-netbox-plugin + image: netboxcommunity/netbox:v4.2.3-3.1.1-diode-netbox-plugin build: context: . dockerfile: Dockerfile-diode-netbox-plugin diff --git a/netbox-plugin.yaml b/netbox-plugin.yaml index 3046a2a..916c77d 100644 --- a/netbox-plugin.yaml +++ b/netbox-plugin.yaml @@ -1,6 +1,9 @@ version: 0.1 package_name: netboxlabs-diode-netbox-plugin compatibility: + - release: 0.7.0 + netbox_min: 4.2.3 + netbox_max: 4.2.3 - release: 0.6.0 netbox_min: 4.1.0 netbox_max: 4.1.3 diff --git a/netbox_diode_plugin/__init__.py b/netbox_diode_plugin/__init__.py index fa1d860..c5da907 100644 --- a/netbox_diode_plugin/__init__.py +++ b/netbox_diode_plugin/__init__.py @@ -15,7 +15,7 @@ class NetBoxDiodePluginConfig(PluginConfig): description = "Diode plugin for NetBox." version = version_semver() base_url = "diode" - min_version = "3.7.2" + min_version = "4.2.3" default_settings = { # Auto-provision users for Diode plugin "auto_provision_users": False, diff --git a/netbox_diode_plugin/api/serializers.py b/netbox_diode_plugin/api/serializers.py index df0c5fb..838f8d3 100644 --- a/netbox_diode_plugin/api/serializers.py +++ b/netbox_diode_plugin/api/serializers.py @@ -275,14 +275,27 @@ class Meta: class DiodePrefixSerializer(PrefixSerializer): """Diode Prefix Serializer.""" - site = DiodeSiteSerializer() status = serializers.CharField() + site = serializers.SerializerMethodField(read_only=True) class Meta: """Meta class.""" model = PrefixSerializer.Meta.model - fields = PrefixSerializer.Meta.fields + fields = PrefixSerializer.Meta.fields + ["site"] + + def get_site(self, obj): + """Get the site from the instance scope.""" + if obj.scope is None: + return None + + scope_model_meta = obj.scope_type.model_class()._meta + if scope_model_meta.app_label == "dcim" and scope_model_meta.model_name == "site": + serializer = get_serializer_for_model(obj.scope) + context = {'request': self.context['request']} + return serializer(obj.scope, nested=True, context=context).data + + return None class DiodeClusterGroupSerializer(ClusterGroupSerializer): @@ -311,13 +324,26 @@ class DiodeClusterSerializer(ClusterSerializer): type = DiodeClusterTypeSerializer() group = DiodeClusterGroupSerializer() status = serializers.CharField() - site = DiodeSiteSerializer() + site = serializers.SerializerMethodField(read_only=True) class Meta: """Meta class.""" model = ClusterSerializer.Meta.model - fields = ClusterSerializer.Meta.fields + fields = ClusterSerializer.Meta.fields + ["site"] + + def get_site(self, obj): + """Get the site from the instance scope.""" + if obj.scope is None: + return None + + scope_model_meta = obj.scope_type.model_class()._meta + if scope_model_meta.app_label == "dcim" and scope_model_meta.model_name == "site": + serializer = get_serializer_for_model(obj.scope) + context = {'request': self.context['request']} + return serializer(obj.scope, nested=True, context=context).data + + return None class DiodeVirtualMachineSerializer(VirtualMachineSerializer): diff --git a/netbox_diode_plugin/api/views.py b/netbox_diode_plugin/api/views.py index 1768e5a..d2fdd15 100644 --- a/netbox_diode_plugin/api/views.py +++ b/netbox_diode_plugin/api/views.py @@ -1,9 +1,9 @@ #!/usr/bin/env python # Copyright 2024 NetBox Labs Inc """Diode NetBox Plugin - API Views.""" - from typing import Any, Dict, Optional +from django.apps import apps from django.conf import settings from packaging import version @@ -11,11 +11,11 @@ from core.models import ObjectType as NetBoxType else: from django.contrib.contenttypes.models import ContentType as NetBoxType + from django.core.exceptions import FieldError -from django.db import transaction +from django.core.exceptions import ValidationError as DjangoValidationError +from django.db import models, transaction from django.db.models import Q -from extras.models import CachedValue -from netbox.search import LookupTypes from rest_framework import status, views from rest_framework.exceptions import ValidationError from rest_framework.permissions import IsAuthenticated @@ -23,10 +23,87 @@ from utilities.api import get_serializer_for_model from netbox_diode_plugin.api.permissions import IsDiodeReader, IsDiodeWriter -from netbox_diode_plugin.api.serializers import ( - ApplyChangeSetRequestSerializer, - ObjectStateSerializer, -) +from netbox_diode_plugin.api.serializers import ApplyChangeSetRequestSerializer, ObjectStateSerializer + + +def dynamic_import(name): + """Dynamically import a class from an absolute path string.""" + components = name.split(".") + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def _get_index_class_fields(object_type): + """ + Given an object type name (e.g., 'dcim.site'), dynamically find and return the corresponding Index class fields. + + :param object_type: Object type name in the format 'app_label.model_name' + :return: The corresponding model and its Index class (e.g., SiteIndex) field names or None. + """ + try: + # Extract app_label and model_name from 'dcim.site' + app_label, model_name = object_type.split('.') + + # Get the model class dynamically + model = apps.get_model(app_label, model_name) + + # Import the module where index classes are defined (adjust if needed) + index_module = dynamic_import(f"{app_label}.search.{model.__name__}Index") + + # Retrieve the index class fields tuple + fields = getattr(index_module, "fields", None) + + # Extract the field names list from the tuple + field_names = [field[0] for field in fields] + + return model, field_names + + except (LookupError, ModuleNotFoundError, AttributeError, ValueError): + return None, None + +def _validate_model_instance_fields(instance, fields, value): + """ + Validate the model instance fields against the value. + + :param instance: The model instance. + :param fields: The fields of the model instance. + :param value: The value to validate against the model instance fields. + :return: fields list passed validation + """ + errors = {} + + # Set provided values to the instance fields + for field in fields: + if hasattr(instance, field): + # get the field type + field_cls = instance._meta.get_field(field).__class__ + + field_value = _convert_field_value(field_cls, value) + setattr(instance, field, field_value) + + # Attempt to validate the instance + try: + instance.clean_fields() + except DjangoValidationError as e: + errors = e.message_dict + return errors + +def _convert_field_value(field_cls, value): + """Return the converted field value based on the field type.""" + if value is None: + return value + + try: + if issubclass(field_cls, (models.FloatField, models.DecimalField)): + return float(value) + if issubclass(field_cls, models.IntegerField): + return int(value) + except (ValueError, TypeError): + pass + + return value class ObjectStateView(views.APIView): @@ -60,49 +137,39 @@ def _get_lookups(self, object_type_model: str) -> tuple: return ("site",) return () - def get(self, request, *args, **kwargs): - """ - Return a JSON with object_type, object_change_id, and object. - - Search for objects according to object type. - If the obj_type parameter is not in the parameters, raise a ValidationError. - When object ID is provided in the request, search using it in the model specified by object type. - If ID is not provided, use the q parameter for searching. - Lookup is iexact - """ - object_type = self.request.query_params.get("object_type", None) + def _search_queryset(self, request): + """Search for objects according to object type using search index classes.""" + object_type = request.GET.get("object_type", None) + object_id = request.GET.get("id", None) + query = request.GET.get("q", None) if not object_type: raise ValidationError("object_type parameter is required") - app_label, model_name = object_type.split(".") - object_content_type = NetBoxType.objects.get_by_natural_key( - app_label, model_name - ) - object_type_model = object_content_type.model_class() + if not object_id and not query: + raise ValidationError("id or q parameter is required") - object_id = self.request.query_params.get("id", None) + model, fields = _get_index_class_fields(object_type) if object_id: - queryset = object_type_model.objects.filter(id=object_id) + queryset = model.objects.filter(id=object_id) else: - lookup = LookupTypes.EXACT - search_value = self.request.query_params.get("q", None) - if not search_value: - raise ValidationError("id or q parameter is required") + q = Q() - query_filter = Q(**{f"value__{lookup}": search_value}) - query_filter &= Q(object_type__in=[object_content_type]) + invalid_fields = _validate_model_instance_fields(model(), fields, query) - object_id_in_cached_value = CachedValue.objects.filter( - query_filter - ).values_list("object_id", flat=True) + fields = [field for field in fields if field not in invalid_fields] - queryset = object_type_model.objects.filter( - id__in=object_id_in_cached_value - ) + for field in fields: + q |= Q(**{f"{field}__exact": query}) # Exact match - lookups = self._get_lookups(str(object_type_model).lower()) + try: + queryset = model.objects.filter(q) + except DjangoValidationError: + queryset = model.objects.none() + pass + + lookups = self._get_lookups(str(model).lower()) if lookups: queryset = queryset.prefetch_related(*lookups) @@ -112,16 +179,32 @@ def get(self, request, *args, **kwargs): ) if additional_attributes_query_filter: - try: - queryset = queryset.filter(**additional_attributes_query_filter) - except (FieldError, ValueError): - return Response( - {"errors": ["invalid additional attributes provided"]}, - status=status.HTTP_400_BAD_REQUEST, - ) + queryset = queryset.filter(**additional_attributes_query_filter) + + return queryset + + def get(self, request, *args, **kwargs): + """ + Return a JSON with object_type, object_change_id, and object. + + Search for objects according to object type. + If the obj_type parameter is not in the parameters, raise a ValidationError. + When object ID is provided in the request, search using it in the model specified by object type. + If ID is not provided, use the q parameter for searching. + Lookup is iexact + """ + try: + queryset = self._search_queryset(request) + except (FieldError, ValueError): + return Response( + {"errors": ["invalid additional attributes provided"]}, + status=status.HTTP_400_BAD_REQUEST, + ) self.check_object_permissions(request, queryset) + object_type = request.GET.get("object_type", None) + serializer = ObjectStateSerializer( queryset, many=True, @@ -285,17 +368,6 @@ def _get_error_response(change_set_id, error): status=status.HTTP_400_BAD_REQUEST, ) - def _ipaddress_assigned_object(self, change_set: list) -> list: - """Retrieve the IP address assigned object from the change set.""" - ipaddress_assigned_object = [ - change.get("data").get("assigned_object", None) - for change in change_set - if change.get("object_type") == "ipam.ipaddress" - and change.get("data", {}).get("assigned_object", None) - ] - - return ipaddress_assigned_object - def _retrieve_assigned_object_interface_device_lookup_args( self, device: dict ) -> dict: @@ -338,17 +410,17 @@ def _retrieve_assigned_object_interface_device_lookup_args( ) return args - def _handle_ipaddress_assigned_object( - self, object_data: dict, ipaddress_assigned_object: list - ) -> Optional[Dict[str, Any]]: + def _handle_ipaddress_assigned_object(self, object_data: dict) -> Optional[Dict[str, Any]]: """Handle IPAM IP address assigned object.""" - if any(ipaddress_assigned_object): - assigned_object_keys = list(ipaddress_assigned_object[0].keys()) + ipaddress_assigned_object = object_data.get("assigned_object", None) + + if ipaddress_assigned_object is not None: + assigned_object_keys = list(ipaddress_assigned_object.keys()) model_name = assigned_object_keys[0] assigned_object_type = self._get_assigned_object_type(model_name) assigned_object_model = self._get_object_type_model(assigned_object_type) assigned_object_properties_dict = dict( - ipaddress_assigned_object[0][model_name].items() + ipaddress_assigned_object[model_name].items() ) if len(assigned_object_properties_dict) == 0: @@ -381,7 +453,7 @@ def _handle_ipaddress_assigned_object( ) except assigned_object_model.DoesNotExist: return { - "assigned_object": f"Assigned object with name {ipaddress_assigned_object[0][model_name]} does not exist" + "assigned_object": f"Assigned object with name {ipaddress_assigned_object[model_name]} does not exist" } object_data.pop("assigned_object") @@ -389,6 +461,57 @@ def _handle_ipaddress_assigned_object( object_data["assigned_object_id"] = assigned_object_instance.id return None + def _handle_interface_mac_address_compat(self, instance, object_type: str, object_data: dict) -> Optional[Dict[str, Any]]: + """Handle interface mac address backward compatibility.""" + # TODO(ltucker): deprecate. + if object_type != "dcim.interface" and object_type != "virtualization.vminterface": + return None + + if object_data.get("mac_address"): + mac_address_value = object_data.pop("mac_address") + mac_address_instance, _ = instance.mac_addresses.get_or_create( + mac_address=mac_address_value, + ) + instance.primary_mac_address = mac_address_instance + instance.save() + return None + + def _handle_scope(self, object_data: dict) -> Optional[Dict[str, Any]]: + """Handle scope object.""" + if object_data.get("site"): + site = object_data.pop("site") + object_data["scope_type"] = "dcim.site" + scope_type_model = self._get_object_type_model("dcim.site") + site_id = site.get("id", None) + if site_id is None: + try: + site = scope_type_model.objects.get( + name=site.get("name") + ) + site_id = site.id + except scope_type_model.DoesNotExist: + return {"site": f"site with name {site.get('name')} does not exist"} + + object_data["scope_id"] = site_id + + return None + + def _transform_object_data(self, object_type: str, object_data: dict) -> Optional[Dict[str, Any]]: + """Transform object data.""" + errors = None + + match object_type: + case "ipam.ipaddress": + errors = self._handle_ipaddress_assigned_object(object_data) + case "ipam.prefix": + errors = self._handle_scope(object_data) + case "virtualization.cluster": + errors = self._handle_scope(object_data) + case _: + pass + + return errors + def post(self, request, *args, **kwargs): """ Create a new change set and apply it to the current state. @@ -411,8 +534,6 @@ def post(self, request, *args, **kwargs): change_set = request_serializer.data.get("change_set", None) - ipaddress_assigned_object = self._ipaddress_assigned_object(change_set) - try: with transaction.atomic(): for change in change_set: @@ -422,14 +543,7 @@ def post(self, request, *args, **kwargs): object_data = change.get("data", None) object_id = change.get("object_id", None) - errors = None - if ( - any(ipaddress_assigned_object) - and object_type == "ipam.ipaddress" - ): - errors = self._handle_ipaddress_assigned_object( - object_data, ipaddress_assigned_object - ) + errors = self._transform_object_data(object_type, object_data) if errors is not None: serializer_errors.append({"change_id": change_id, **errors}) @@ -450,6 +564,12 @@ def post(self, request, *args, **kwargs): serializer_errors.append( {"change_id": change_id, **errors_dict} ) + continue + + errors = self._handle_interface_mac_address_compat(serializer.instance, object_type, object_data) + if errors is not None: + serializer_errors.append({"change_id": change_id, **errors}) + continue if len(serializer_errors) > 0: raise ApplyChangeSetException except ApplyChangeSetException: diff --git a/netbox_diode_plugin/tests/test_api_apply_change_set.py b/netbox_diode_plugin/tests/test_api_apply_change_set.py index c8e1232..6bef32d 100644 --- a/netbox_diode_plugin/tests/test_api_apply_change_set.py +++ b/netbox_diode_plugin/tests/test_api_apply_change_set.py @@ -14,12 +14,17 @@ Site, ) from django.contrib.auth import get_user_model -from ipam.models import ASN, RIR, IPAddress +from ipam.models import ASN, RIR, IPAddress, Prefix from netaddr import IPNetwork from rest_framework import status from users.models import Token from utilities.testing import APITestCase -from virtualization.models import Cluster, ClusterType +from virtualization.models import ( + Cluster, + ClusterType, + VirtualMachine, + VMInterface, +) User = get_user_model() @@ -145,6 +150,12 @@ def setUp(self): ) IPAddress.objects.bulk_create(self.ip_addresses) + self.virtual_machines = ( + VirtualMachine(name="Virtual Machine 1"), + VirtualMachine(name="Virtual Machine 2"), + ) + VirtualMachine.objects.bulk_create(self.virtual_machines) + self.url = "/netbox/api/plugins/diode/apply-change-set/" def send_request(self, payload, status_code=status.HTTP_200_OK): @@ -982,3 +993,164 @@ def test_add_primary_ip_address_to_device(self): self.assertEqual(response.json().get("result"), "success") self.assertEqual(device_updated.name, self.devices[0].name) self.assertEqual(device_updated.primary_ip4, self.ip_addresses[0]) + + def test_create_and_update_interface_with_compat_mac_address_field(self): + """Test create interface using backward compatible mac_address field.""" + payload = { + "change_set_id": str(uuid.uuid4()), + "change_set": [ + { + "change_id": str(uuid.uuid4()), + "change_type": "create", + "object_version": None, + "object_type": "dcim.interface", + "object_id": None, + "data": { + "name": "Interface 6", + "type": "virtual", + "mac_address": "00:00:00:00:00:01", + "device": { + "id": self.devices[1].pk, + }, + }, + }, + ], + } + + response = self.send_request(payload) + self.assertEqual(response.json().get("result"), "success") + self.assertEqual(Interface.objects.count(), 6) + interface_id = Interface.objects.order_by('-id').first().id + self.assertEqual(Interface.objects.get(id=interface_id).mac_address, "00:00:00:00:00:01") + + payload = { + "change_set_id": str(uuid.uuid4()), + "change_set": [ + { + "change_id": str(uuid.uuid4()), + "change_type": "update", + "object_version": None, + "object_type": "dcim.interface", + "object_id": interface_id, + "data": { + "name": "Interface 6", + "mac_address": "00:00:00:00:00:02", + "type": "virtual", + "device": { + "id": self.devices[1].pk, + }, + }, + }, + ], + } + response = self.send_request(payload) + self.assertEqual(response.json().get("result"), "success") + self.assertEqual(response.json().get("result"), "success") + self.assertEqual(Interface.objects.count(), 6) + self.assertEqual(Interface.objects.get(id=interface_id).mac_address, "00:00:00:00:00:02") + + def test_create_and_update_vminterface_with_compat_mac_address_field(self): + """Test create vminterface using backward compatible mac_address field.""" + payload = { + "change_set_id": str(uuid.uuid4()), + "change_set": [ + { + "change_id": str(uuid.uuid4()), + "change_type": "create", + "object_version": None, + "object_type": "virtualization.vminterface", + "object_id": None, + "data": { + "name": "VM Interface 1", + "mac_address": "00:00:00:00:00:01", + "virtual_machine": { + "id": self.virtual_machines[0].pk, + }, + }, + }, + ], + } + + response = self.send_request(payload) + self.assertEqual(response.json().get("result"), "success") + self.assertEqual(VMInterface.objects.count(), 1) + interface_id = VMInterface.objects.order_by('-id').first().id + self.assertEqual(VMInterface.objects.get(id=interface_id).mac_address, "00:00:00:00:00:01") + + payload = { + "change_set_id": str(uuid.uuid4()), + "change_set": [ + { + "change_id": str(uuid.uuid4()), + "change_type": "update", + "object_version": None, + "object_type": "virtualization.vminterface", + "object_id": interface_id, + "data": { + "name": "VM Interface 1", + "mac_address": "00:00:00:00:00:02", + "virtual_machine": { + "id": self.virtual_machines[0].pk, + }, + }, + }, + ], + } + response = self.send_request(payload) + self.assertEqual(response.json().get("result"), "success") + self.assertEqual(VMInterface.objects.count(), 1) + self.assertEqual(VMInterface.objects.get(id=interface_id).mac_address, "00:00:00:00:00:02") + + def test_create_prefix_with_site_stored_as_scope(self): + """Test create prefix with site stored as scope.""" + payload = { + "change_set_id": str(uuid.uuid4()), + "change_set": [ + { + "change_id": str(uuid.uuid4()), + "change_type": "create", + "object_version": None, + "object_type": "ipam.prefix", + "object_id": None, + "data": { + "prefix": "192.168.0.0/24", + "site": { + "name": self.sites[0].name, + }, + }, + }, + ], + } + response = self.send_request(payload) + + self.assertEqual(response.json().get("result"), "success") + self.assertEqual(Prefix.objects.get(prefix="192.168.0.0/24").scope, self.sites[0]) + + def test_create_prefix_with_unknown_site_fails(self): + """Test create prefix with unknown site fails.""" + payload = { + "change_set_id": str(uuid.uuid4()), + "change_set": [ + { + "change_id": str(uuid.uuid4()), + "change_type": "create", + "object_version": None, + "object_type": "ipam.prefix", + "object_id": None, + "data": { + "prefix": "192.168.0.0/24", + "site": { + "name": "unknown site" + }, + }, + }, + ], + } + response = self.send_request(payload, status_code=status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json().get("result"), "failed") + self.assertIn( + 'site with name unknown site does not exist', + response.json().get("errors")[0].get("site"), + ) + self.assertFalse(Prefix.objects.filter(prefix="192.168.0.0/24").exists()) diff --git a/netbox_diode_plugin/tests/test_api_object_state.py b/netbox_diode_plugin/tests/test_api_object_state.py index 7031549..d13ef35 100644 --- a/netbox_diode_plugin/tests/test_api_object_state.py +++ b/netbox_diode_plugin/tests/test_api_object_state.py @@ -12,7 +12,6 @@ Site, ) from django.contrib.auth import get_user_model -from django.core.management import call_command from ipam.models import IPAddress from netaddr import IPNetwork from rest_framework import status @@ -154,9 +153,6 @@ def setUpClass(cls): ) IPAddress.objects.bulk_create(cls.ip_addresses) - # call_command is because the searching using q parameter uses CachedValue to get the object ID - call_command("reindex") - def setUp(self): """Set up test.""" self.root_user = User.objects.create_user( @@ -182,7 +178,7 @@ def setUp(self): def test_return_object_state_using_id(self): """Test searching using id parameter - Root User.""" - site_id = Site.objects.get(name=self.sites[0]).id + site_id = Site.objects.get(name=self.sites[0].name).id query_parameters = {"id": site_id, "object_type": "dcim.site"} response = self.client.get(self.url, query_parameters, **self.root_header)