Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 108 additions & 66 deletions netbox_diode_plugin/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def _get_index_class_fields(object_type):
# Get the model class dynamically
model = apps.get_model(app_label, model_name)

# TagIndex registered in the netbox_diode_plugin
if app_label == "extras" and model_name == "tag":
app_label = "netbox_diode_plugin"

# Import the module where index classes are defined (adjust if needed)
index_module = dynamic_import(f"{app_label}.search.{model.__name__}Index")

Expand Down Expand Up @@ -246,7 +250,7 @@ def _get_object_type_model(object_type: str):
object_content_type = NetBoxType.objects.get_by_natural_key(
app_label, model_name
)
return object_content_type.model_class()
return object_content_type, object_content_type.model_class()

def _get_assigned_object_type(self, model_name: str):
"""Get the object type model from applied IPAddress assigned object."""
Expand All @@ -255,79 +259,114 @@ def _get_assigned_object_type(self, model_name: str):
}
return assignable_object_types.get(model_name.lower(), None)

def _add_nested_opts(self, fields, key, value):
if isinstance(value, dict):
for nested_key, nested_value in value.items():
self._add_nested_opts(fields, f"{key}__{nested_key}", nested_value)
elif not isinstance(value, list):
fields[key] = value

def _get_serializer(
self,
change_type: str,
object_id: int,
object_type: str,
object_data: dict,
change_set_id: str,
):
"""Get the serializer for the object type."""
object_type_model = self._get_object_type_model(object_type)
object_type_model, object_type_model_class = self._get_object_type_model(object_type)

if change_type == "create":
serializer = get_serializer_for_model(object_type_model)(
data=object_data, context={"request": self.request}
)
elif change_type == "update":
lookups = ()
args = {}
return self._get_serializer_to_create(object_data, object_type, object_type_model, object_type_model_class)

primary_ip_to_set: Optional[dict] = None
if change_type == "update":
return self._get_serializer_to_update(object_data, object_id, object_type, object_type_model_class)

raise ValidationError("Invalid change_type")

def _get_serializer_to_create(self, object_data, object_type, object_type_model, object_type_model_class):
# Get object data fields that are not dictionaries or lists
fields = self._get_fields_to_find_existing_objects(object_data, object_type, object_type_model)
# Check if the object already exists
try:
instance = object_type_model_class.objects.get(**fields)
return get_serializer_for_model(object_type_model_class)(
instance, data=object_data, context={"request": self.request, "pk": instance.pk}
)
except object_type_model_class.DoesNotExist:
pass
serializer = get_serializer_for_model(object_type_model_class)(
data=object_data, context={"request": self.request}
)
return serializer

if object_id:
args["id"] = object_id
elif object_type == "dcim.device" and any(
def _get_serializer_to_update(self, object_data, object_id, object_type, object_type_model_class):
lookups = ()
fields = {}
primary_ip_to_set: Optional[dict] = None
if object_id:
fields["id"] = object_id
elif object_type == "dcim.device" and any(
object_data.get(attr) for attr in ("primary_ip4", "primary_ip6")
):
):
ip_address = self._retrieve_primary_ip_address(
"primary_ip4", object_data
)

if ip_address is None:
ip_address = self._retrieve_primary_ip_address(
"primary_ip4", object_data
"primary_ip6", object_data
)

if ip_address is None:
ip_address = self._retrieve_primary_ip_address(
"primary_ip6", object_data
)

if ip_address is None:
raise ValidationError("primary IP not found")

if ip_address:
primary_ip_to_set = {
"id": ip_address.id,
"family": ip_address.family,
}
if ip_address is None:
raise ValidationError("primary IP not found")

lookups = ("site",)
args["name"] = object_data.get("name")
args["site__name"] = object_data.get("site").get("name")
else:
raise ValidationError("object_id parameter is required")
if ip_address:
primary_ip_to_set = {
"id": ip_address.id,
"family": ip_address.family,
}

try:
instance = object_type_model.objects.prefetch_related(*lookups).get(
**args
)
if object_type == "dcim.device" and primary_ip_to_set:
object_data = {
"id": instance.id,
"device_type": instance.device_type.id,
"role": instance.role.id,
"site": instance.site.id,
f'primary_ip{primary_ip_to_set.get("family")}': primary_ip_to_set.get(
"id"
),
}
except object_type_model.DoesNotExist:
raise ValidationError(f"object with id {object_id} does not exist")

serializer = get_serializer_for_model(object_type_model)(
instance, data=object_data, context={"request": self.request}
)
lookups = ("site",)
fields["name"] = object_data.get("name")
fields["site__name"] = object_data.get("site").get("name")
else:
raise ValidationError("Invalid change_type")
raise ValidationError("object_id parameter is required")
try:
instance = object_type_model_class.objects.prefetch_related(*lookups).get(**fields)
if object_type == "dcim.device" and primary_ip_to_set:
object_data = {
"id": instance.id,
"device_type": instance.device_type.id,
"role": instance.role.id,
"site": instance.site.id,
f'primary_ip{primary_ip_to_set.get("family")}': primary_ip_to_set.get(
"id"
),
}
except object_type_model_class.DoesNotExist:
raise ValidationError(f"object with id {object_id} does not exist")
serializer = get_serializer_for_model(object_type_model_class)(
instance, data=object_data, context={"request": self.request}
)
return serializer

def _get_fields_to_find_existing_objects(self, object_data, object_type, object_type_model):
fields = {}
for key, value in object_data.items():
self._add_nested_opts(fields, key, value)
match object_type:
case "dcim.interface" | "virtualization.vminterface":
mac_address = fields.pop("mac_address", None)
if mac_address is not None:
fields["primary_mac_address__mac_address"] = mac_address
case "ipam.ipaddress":
fields.pop("assigned_object_type")
fields["assigned_object_type_id"] = fields.pop("assigned_object_id")
case "ipam.prefix" | "virtualization.cluster":
fields["scope_type"] = object_type_model
return fields

def _retrieve_primary_ip_address(self, primary_ip_attr: str, object_data: dict):
"""Retrieve the primary IP address object."""
ip_address = object_data.get(primary_ip_attr)
Expand All @@ -347,8 +386,8 @@ def _retrieve_primary_ip_address(self, primary_ip_attr: str, object_data: dict):
interface_device = interface.get("device")
if interface_device is None:
return None

ip_address_object = self._get_object_type_model("ipam.ipaddress").objects.get(
object_type_mode, object_type_model_class = self._get_object_type_model("ipam.ipaddress")
ip_address_object = object_type_model_class.objects.get(
address=ip_address.get("address"),
interface__name=interface.get("name"),
interface__device__name=interface_device.get("name"),
Expand Down Expand Up @@ -418,7 +457,7 @@ def _handle_ipaddress_assigned_object(self, object_data: dict) -> Optional[Dict[
assigned_object_keys = list(ipaddress_assigned_object.keys())
model_name = assigned_object_keys[0]
assigned_object_type = self._get_assigned_object_type(model_name)
assigned_object_model = self._get_object_type_model(assigned_object_type)
assigned_object_model, object_type_model_class = self._get_object_type_model(assigned_object_type)
assigned_object_properties_dict = dict(
ipaddress_assigned_object[model_name].items()
)
Expand Down Expand Up @@ -449,9 +488,9 @@ def _handle_ipaddress_assigned_object(self, object_data: dict) -> Optional[Dict[
return {"assigned_object": error}

assigned_object_instance = (
assigned_object_model.objects.prefetch_related(*lookups).get(**args)
object_type_model_class.objects.prefetch_related(*lookups).get(**args)
)
except assigned_object_model.DoesNotExist:
except object_type_model_class.DoesNotExist:
return {
"assigned_object": f"Assigned object with name {ipaddress_assigned_object[model_name]} does not exist"
}
Expand Down Expand Up @@ -480,16 +519,17 @@ def _handle_scope(self, object_data: dict) -> Optional[Dict[str, Any]]:
"""Handle scope object."""
if object_data.get("site"):
site = object_data.pop("site")
object_data["scope_type"] = "dcim.site"
scope_type_model = self._get_object_type_model("dcim.site")
scope_type = "dcim.site"
_, object_type_model_class = self._get_object_type_model(scope_type)
object_data["scope_type"] = scope_type
site_id = site.get("id", None)
if site_id is None:
try:
site = scope_type_model.objects.get(
site = object_type_model_class.objects.get(
name=site.get("name")
)
site_id = site.id
except scope_type_model.DoesNotExist:
except object_type_model_class.DoesNotExist:
return {"site": f"site with name {site.get('name')} does not exist"}

object_data["scope_id"] = site_id
Expand Down Expand Up @@ -549,9 +589,11 @@ def post(self, request, *args, **kwargs):
serializer_errors.append({"change_id": change_id, **errors})
continue

serializer = self._get_serializer(
change_type, object_id, object_type, object_data, change_set_id
)
serializer = self._get_serializer(change_type, object_id, object_type, object_data)

# Skip creating an object if it already exists
if change_type == "create" and serializer.context.get("pk"):
continue

if serializer.is_valid():
serializer.save()
Expand Down