diff --git a/netbox_diode_plugin/api/views.py b/netbox_diode_plugin/api/views.py index d2fdd15..acfee6e 100644 --- a/netbox_diode_plugin/api/views.py +++ b/netbox_diode_plugin/api/views.py @@ -49,6 +49,10 @@ def _get_index_class_fields(object_type): # Get the model class dynamically model = apps.get_model(app_label, model_name) + # TagIndex registered in the netbox_diode_plugin + if app_label == "extras" and model_name == "tag": + app_label = "netbox_diode_plugin" + # Import the module where index classes are defined (adjust if needed) index_module = dynamic_import(f"{app_label}.search.{model.__name__}Index") @@ -246,7 +250,7 @@ def _get_object_type_model(object_type: str): object_content_type = NetBoxType.objects.get_by_natural_key( app_label, model_name ) - return object_content_type.model_class() + return object_content_type, object_content_type.model_class() def _get_assigned_object_type(self, model_name: str): """Get the object type model from applied IPAddress assigned object.""" @@ -255,79 +259,114 @@ def _get_assigned_object_type(self, model_name: str): } return assignable_object_types.get(model_name.lower(), None) + def _add_nested_opts(self, fields, key, value): + if isinstance(value, dict): + for nested_key, nested_value in value.items(): + self._add_nested_opts(fields, f"{key}__{nested_key}", nested_value) + elif not isinstance(value, list): + fields[key] = value + def _get_serializer( self, change_type: str, object_id: int, object_type: str, object_data: dict, - change_set_id: str, ): """Get the serializer for the object type.""" - object_type_model = self._get_object_type_model(object_type) + object_type_model, object_type_model_class = self._get_object_type_model(object_type) + if change_type == "create": - serializer = get_serializer_for_model(object_type_model)( - data=object_data, context={"request": self.request} - ) - elif change_type == "update": - lookups = () - args = {} + return self._get_serializer_to_create(object_data, object_type, object_type_model, object_type_model_class) - primary_ip_to_set: Optional[dict] = None + if change_type == "update": + return self._get_serializer_to_update(object_data, object_id, object_type, object_type_model_class) + + raise ValidationError("Invalid change_type") + + def _get_serializer_to_create(self, object_data, object_type, object_type_model, object_type_model_class): + # Get object data fields that are not dictionaries or lists + fields = self._get_fields_to_find_existing_objects(object_data, object_type, object_type_model) + # Check if the object already exists + try: + instance = object_type_model_class.objects.get(**fields) + return get_serializer_for_model(object_type_model_class)( + instance, data=object_data, context={"request": self.request, "pk": instance.pk} + ) + except object_type_model_class.DoesNotExist: + pass + serializer = get_serializer_for_model(object_type_model_class)( + data=object_data, context={"request": self.request} + ) + return serializer - if object_id: - args["id"] = object_id - elif object_type == "dcim.device" and any( + def _get_serializer_to_update(self, object_data, object_id, object_type, object_type_model_class): + lookups = () + fields = {} + primary_ip_to_set: Optional[dict] = None + if object_id: + fields["id"] = object_id + elif object_type == "dcim.device" and any( object_data.get(attr) for attr in ("primary_ip4", "primary_ip6") - ): + ): + ip_address = self._retrieve_primary_ip_address( + "primary_ip4", object_data + ) + + if ip_address is None: ip_address = self._retrieve_primary_ip_address( - "primary_ip4", object_data + "primary_ip6", object_data ) - if ip_address is None: - ip_address = self._retrieve_primary_ip_address( - "primary_ip6", object_data - ) - - if ip_address is None: - raise ValidationError("primary IP not found") - - if ip_address: - primary_ip_to_set = { - "id": ip_address.id, - "family": ip_address.family, - } + if ip_address is None: + raise ValidationError("primary IP not found") - lookups = ("site",) - args["name"] = object_data.get("name") - args["site__name"] = object_data.get("site").get("name") - else: - raise ValidationError("object_id parameter is required") + if ip_address: + primary_ip_to_set = { + "id": ip_address.id, + "family": ip_address.family, + } - try: - instance = object_type_model.objects.prefetch_related(*lookups).get( - **args - ) - if object_type == "dcim.device" and primary_ip_to_set: - object_data = { - "id": instance.id, - "device_type": instance.device_type.id, - "role": instance.role.id, - "site": instance.site.id, - f'primary_ip{primary_ip_to_set.get("family")}': primary_ip_to_set.get( - "id" - ), - } - except object_type_model.DoesNotExist: - raise ValidationError(f"object with id {object_id} does not exist") - - serializer = get_serializer_for_model(object_type_model)( - instance, data=object_data, context={"request": self.request} - ) + lookups = ("site",) + fields["name"] = object_data.get("name") + fields["site__name"] = object_data.get("site").get("name") else: - raise ValidationError("Invalid change_type") + raise ValidationError("object_id parameter is required") + try: + instance = object_type_model_class.objects.prefetch_related(*lookups).get(**fields) + if object_type == "dcim.device" and primary_ip_to_set: + object_data = { + "id": instance.id, + "device_type": instance.device_type.id, + "role": instance.role.id, + "site": instance.site.id, + f'primary_ip{primary_ip_to_set.get("family")}': primary_ip_to_set.get( + "id" + ), + } + except object_type_model_class.DoesNotExist: + raise ValidationError(f"object with id {object_id} does not exist") + serializer = get_serializer_for_model(object_type_model_class)( + instance, data=object_data, context={"request": self.request} + ) return serializer + def _get_fields_to_find_existing_objects(self, object_data, object_type, object_type_model): + fields = {} + for key, value in object_data.items(): + self._add_nested_opts(fields, key, value) + match object_type: + case "dcim.interface" | "virtualization.vminterface": + mac_address = fields.pop("mac_address", None) + if mac_address is not None: + fields["primary_mac_address__mac_address"] = mac_address + case "ipam.ipaddress": + fields.pop("assigned_object_type") + fields["assigned_object_type_id"] = fields.pop("assigned_object_id") + case "ipam.prefix" | "virtualization.cluster": + fields["scope_type"] = object_type_model + return fields + def _retrieve_primary_ip_address(self, primary_ip_attr: str, object_data: dict): """Retrieve the primary IP address object.""" ip_address = object_data.get(primary_ip_attr) @@ -347,8 +386,8 @@ def _retrieve_primary_ip_address(self, primary_ip_attr: str, object_data: dict): interface_device = interface.get("device") if interface_device is None: return None - - ip_address_object = self._get_object_type_model("ipam.ipaddress").objects.get( + object_type_mode, object_type_model_class = self._get_object_type_model("ipam.ipaddress") + ip_address_object = object_type_model_class.objects.get( address=ip_address.get("address"), interface__name=interface.get("name"), interface__device__name=interface_device.get("name"), @@ -418,7 +457,7 @@ def _handle_ipaddress_assigned_object(self, object_data: dict) -> Optional[Dict[ assigned_object_keys = list(ipaddress_assigned_object.keys()) model_name = assigned_object_keys[0] assigned_object_type = self._get_assigned_object_type(model_name) - assigned_object_model = self._get_object_type_model(assigned_object_type) + assigned_object_model, object_type_model_class = self._get_object_type_model(assigned_object_type) assigned_object_properties_dict = dict( ipaddress_assigned_object[model_name].items() ) @@ -449,9 +488,9 @@ def _handle_ipaddress_assigned_object(self, object_data: dict) -> Optional[Dict[ return {"assigned_object": error} assigned_object_instance = ( - assigned_object_model.objects.prefetch_related(*lookups).get(**args) + object_type_model_class.objects.prefetch_related(*lookups).get(**args) ) - except assigned_object_model.DoesNotExist: + except object_type_model_class.DoesNotExist: return { "assigned_object": f"Assigned object with name {ipaddress_assigned_object[model_name]} does not exist" } @@ -480,16 +519,17 @@ def _handle_scope(self, object_data: dict) -> Optional[Dict[str, Any]]: """Handle scope object.""" if object_data.get("site"): site = object_data.pop("site") - object_data["scope_type"] = "dcim.site" - scope_type_model = self._get_object_type_model("dcim.site") + scope_type = "dcim.site" + _, object_type_model_class = self._get_object_type_model(scope_type) + object_data["scope_type"] = scope_type site_id = site.get("id", None) if site_id is None: try: - site = scope_type_model.objects.get( + site = object_type_model_class.objects.get( name=site.get("name") ) site_id = site.id - except scope_type_model.DoesNotExist: + except object_type_model_class.DoesNotExist: return {"site": f"site with name {site.get('name')} does not exist"} object_data["scope_id"] = site_id @@ -549,9 +589,11 @@ def post(self, request, *args, **kwargs): serializer_errors.append({"change_id": change_id, **errors}) continue - serializer = self._get_serializer( - change_type, object_id, object_type, object_data, change_set_id - ) + serializer = self._get_serializer(change_type, object_id, object_type, object_data) + + # Skip creating an object if it already exists + if change_type == "create" and serializer.context.get("pk"): + continue if serializer.is_valid(): serializer.save()