Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions netbox_diode_plugin/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,21 @@ def _handle_ipaddress_assigned_object(
object_data["assigned_object_id"] = assigned_object_instance.id
return None

def _handle_interface_mac_address_compat(self, instance, object_type: str, object_data: dict) -> Optional[Dict[str, Any]]:
"""Handle interface mac address backward compatibility."""
# TODO(ltucker): deprecate.
if object_type != "dcim.interface" and object_type != "virtualization.vminterface":
return None

if object_data.get("mac_address"):
mac_address_value = object_data.pop("mac_address")
mac_address_instance, _ = instance.mac_addresses.get_or_create(
mac_address=mac_address_value,
)
instance.primary_mac_address = mac_address_instance
instance.save()
return None

def post(self, request, *args, **kwargs):
"""
Create a new change set and apply it to the current state.
Expand Down Expand Up @@ -533,6 +548,12 @@ def post(self, request, *args, **kwargs):
serializer_errors.append(
{"change_id": change_id, **errors_dict}
)
continue

errors = self._handle_interface_mac_address_compat(serializer.instance, object_type, object_data)
if errors is not None:
serializer_errors.append({"change_id": change_id, **errors})
continue
if len(serializer_errors) > 0:
raise ApplyChangeSetException
except ApplyChangeSetException:
Expand Down
120 changes: 119 additions & 1 deletion netbox_diode_plugin/tests/test_api_apply_change_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from rest_framework import status
from users.models import Token
from utilities.testing import APITestCase
from virtualization.models import Cluster, ClusterType
from virtualization.models import (
Cluster,
ClusterType,
VirtualMachine,
VMInterface,
)

User = get_user_model()

Expand Down Expand Up @@ -145,6 +150,12 @@ def setUp(self):
)
IPAddress.objects.bulk_create(self.ip_addresses)

self.virtual_machines = (
VirtualMachine(name="Virtual Machine 1"),
VirtualMachine(name="Virtual Machine 2"),
)
VirtualMachine.objects.bulk_create(self.virtual_machines)

self.url = "/netbox/api/plugins/diode/apply-change-set/"

def send_request(self, payload, status_code=status.HTTP_200_OK):
Expand Down Expand Up @@ -982,3 +993,110 @@ def test_add_primary_ip_address_to_device(self):
self.assertEqual(response.json().get("result"), "success")
self.assertEqual(device_updated.name, self.devices[0].name)
self.assertEqual(device_updated.primary_ip4, self.ip_addresses[0])

def test_create_and_update_interface_with_compat_mac_address_field(self):
"""Test create interface using backward compatible mac_address field."""
payload = {
"change_set_id": str(uuid.uuid4()),
"change_set": [
{
"change_id": str(uuid.uuid4()),
"change_type": "create",
"object_version": None,
"object_type": "dcim.interface",
"object_id": None,
"data": {
"name": "Interface 6",
"type": "virtual",
"mac_address": "00:00:00:00:00:01",
"device": {
"id": self.devices[1].pk,
},
},
},
],
}

response = self.send_request(payload)
self.assertEqual(response.json().get("result"), "success")
self.assertEqual(Interface.objects.count(), 6)
interface_id = Interface.objects.order_by('-id').first().id
self.assertEqual(Interface.objects.get(id=interface_id).mac_address, "00:00:00:00:00:01")

payload = {
"change_set_id": str(uuid.uuid4()),
"change_set": [
{
"change_id": str(uuid.uuid4()),
"change_type": "update",
"object_version": None,
"object_type": "dcim.interface",
"object_id": interface_id,
"data": {
"name": "Interface 6",
"mac_address": "00:00:00:00:00:02",
"type": "virtual",
"device": {
"id": self.devices[1].pk,
},
},
},
],
}
response = self.send_request(payload)
self.assertEqual(response.json().get("result"), "success")
self.assertEqual(response.json().get("result"), "success")
self.assertEqual(Interface.objects.count(), 6)
self.assertEqual(Interface.objects.get(id=interface_id).mac_address, "00:00:00:00:00:02")

def test_create_and_update_vminterface_with_compat_mac_address_field(self):
"""Test create vminterface using backward compatible mac_address field."""
payload = {
"change_set_id": str(uuid.uuid4()),
"change_set": [
{
"change_id": str(uuid.uuid4()),
"change_type": "create",
"object_version": None,
"object_type": "virtualization.vminterface",
"object_id": None,
"data": {
"name": "VM Interface 1",
"mac_address": "00:00:00:00:00:01",
"virtual_machine": {
"id": self.virtual_machines[0].pk,
},
},
},
],
}

response = self.send_request(payload)
self.assertEqual(response.json().get("result"), "success")
self.assertEqual(VMInterface.objects.count(), 1)
interface_id = VMInterface.objects.order_by('-id').first().id
self.assertEqual(VMInterface.objects.get(id=interface_id).mac_address, "00:00:00:00:00:01")

payload = {
"change_set_id": str(uuid.uuid4()),
"change_set": [
{
"change_id": str(uuid.uuid4()),
"change_type": "update",
"object_version": None,
"object_type": "virtualization.vminterface",
"object_id": interface_id,
"data": {
"name": "VM Interface 1",
"mac_address": "00:00:00:00:00:02",
"virtual_machine": {
"id": self.virtual_machines[0].pk,
},
},
},
],
}
response = self.send_request(payload)
self.assertEqual(response.json().get("result"), "success")
self.assertEqual(VMInterface.objects.count(), 1)
self.assertEqual(VMInterface.objects.get(id=interface_id).mac_address, "00:00:00:00:00:02")