diff --git a/README.rst b/README.rst index b9e973f0b..fbf851a0d 100644 --- a/README.rst +++ b/README.rst @@ -1427,6 +1427,70 @@ List devices GET /api/v1/controller/device/ +**Available filters** + +You can filter a list of devices based on their configuration +status using the ``status`` (e.g modified, applied, or error). + +.. code-block:: text + + GET /api/v1/controller/device/?config__status={status} + +You can filter a list of devices based on their configuration backend +using the ``backend`` (e.g netjsonconfig.OpenWrt or netjsonconfig.OpenWisp). + +.. code-block:: text + + GET /api/v1/controller/device/?config__backend={backend} + +You can filter a list of devices based on their +organization using the ``organization_id`` or ``organization_slug``. + +.. code-block:: text + + GET /api/v1/controller/device/?organization={organization_id} + +.. code-block:: text + + GET /api/v1/controller/device/?organization_slug={organization_slug} + +You can filter a list of devices based on their +configuration templates using the ``template_id``. + +.. code-block:: text + + GET /api/v1/controller/device/?config__templates={template_id} + +You can filter a list of devices based on +their device group using the ``group_id``. + +.. code-block:: text + + GET /api/v1/controller/device/?group={group_id} + +You can filter a list of devices that have a device +location object using the ``with_geo`` (eg. true or false). + +.. code-block:: text + + GET /api/v1/controller/device/?with_geo={with_geo} + +You can filter a list of devices based on +their creation time using the ``creation_time``. + +.. code-block:: text + + # Created exact + GET /api/v1/controller/device/?created={creation_time} + + # Created greater than or equal to + GET /api/v1/controller/device/?created__gte={creation_time} + + # Created is less than + GET /api/v1/controller/device/?created__lt={creation_time} + + + Create device ############# @@ -1637,6 +1701,27 @@ List device groups GET /api/v1/controller/group/ +**Available filters** + +You can filter a list of device groups based on their +organization using the ``organization_id`` or ``organization_slug``. + +.. code-block:: text + + GET /api/v1/controller/group/?organization={organization_id} + +.. code-block:: text + + GET /api/v1/controller/group/?organization_slug={organization_slug} + +You can filter a list of device groups that have a +device object using the ``empty`` (eg. true or false). + +.. code-block:: text + + GET /api/v1/controller/group/?empty={empty} + + Create device group ################### @@ -1913,12 +1998,18 @@ List locations GET /api/v1/controller/location/ -You can filter using ``organization_slug`` to get list locations that -belongs to an organization. +**Available filters** + +You can filter using ``organization_id`` or ``organization_slug`` +to get list locations that belongs to an organization. + +.. code-block:: text + + GET /api/v1/controller/location/?organization={organization_id} .. code-block:: text - GET /api/v1/controller/location/?organization_slug= + GET /api/v1/controller/location/?organization_slug={organization_slug} Create location ############### @@ -2023,12 +2114,18 @@ List locations with devices deployed (in GeoJSON format) GET /api/v1/controller/location/geojson/ -You can filter using ``organization_slug`` to get list location of -devices from that organization. +**Available filters** + +You can filter using ``organization_id`` or ``organization_slug`` +to get list location of devices from that organization. + +.. code-block:: text + + GET /api/v1/controller/location/geojson/?organization_id={organization_id} .. code-block:: text - GET /api/v1/controller/location/geojson/?organization_slug= + GET /api/v1/controller/location/geojson/?organization_slug={organization_slug} List floorplans ############### @@ -2037,12 +2134,18 @@ List floorplans GET /api/v1/controller/floorplan/ -You can filter using ``organization_slug`` to get list floorplans that -belongs to an organization. +**Available filters** + +You can filter using ``organization_id`` or ``organization_slug`` +to get list floorplans that belongs to an organization. .. code-block:: text - GET /api/v1/controller/floorplan/?organization_slug= + GET /api/v1/controller/floorplan/?organization={organization_id} + +.. code-block:: text + + GET /api/v1/controller/floorplan/?organization_slug={organization_slug} Create floorplan ################ @@ -2079,6 +2182,64 @@ List templates GET /api/v1/controller/template/ +**Available filters** + +You can filter a list of templates based on their organization +using the ``organization_id`` or ``organization_slug``. + +.. code-block:: text + + GET /api/v1/controller/template/?organization={organization_id} + +.. code-block:: text + + GET /api/v1/controller/template/?organization_slug={organization_slug} + +You can filter a list of templates based on their backend using +the ``backend`` (e.g netjsonconfig.OpenWrt or netjsonconfig.OpenWisp). + +.. code-block:: text + + GET /api/v1/controller/template/?backend={backend} + +You can filter a list of templates based on their +type using the ``type`` (eg. vpn or generic). + +.. code-block:: text + + GET /api/v1/controller/template/?type={type} + +You can filter a list of templates that are enabled +by default or not using the ``default`` (eg. true or false). + +.. code-block:: text + + GET /api/v1/controller/template/?default={default} + +You can filter a list of templates that are required +or not using the ``required`` (eg. true or false). + +.. code-block:: text + + GET /api/v1/controller/template/?required={required} + +You can filter a list of templates based on +their creation time using the ``creation_time``. + +.. code-block:: text + + # Created exact + + GET /api/v1/controller/template/?created={creation_time} + + # Created greater than or equal to + + GET /api/v1/controller/template/?created__gte={creation_time} + + # Created is less than + + GET /api/v1/controller/template/?created__lt={creation_time} + Create template ############### @@ -2131,6 +2292,34 @@ List VPNs GET /api/v1/controller/vpn/ +**Available filters** + +You can filter a list of vpns based +on their backend using the ``backend`` +(e.g openwisp_controller.vpn_backends.OpenVpn +or openwisp_controller.vpn_backends.Wireguard). + +.. code-block:: text + + GET /api/v1/controller/vpn/?backend={backend} + +You can filter a list of vpns based on their subnet using the ``subnet_id``. + +.. code-block:: text + + GET /api/v1/controller/vpn/?subnet={subnet_id} + +You can filter a list of vpns based on their organization +using the ``organization_id`` or ``organization_slug``. + +.. code-block:: text + + GET /api/v1/controller/vpn/?organization={organization_id} + +.. code-block:: text + + GET /api/v1/controller/vpn/?organization_slug={organization_slug} + Create VPN ########## diff --git a/openwisp_controller/config/api/filters.py b/openwisp_controller/config/api/filters.py new file mode 100644 index 000000000..9831fc9e8 --- /dev/null +++ b/openwisp_controller/config/api/filters.py @@ -0,0 +1,134 @@ +from uuid import UUID + +from django.utils.translation import gettext_lazy as _ +from django_filters import rest_framework as filters +from django_filters.rest_framework import DjangoFilterBackend +from rest_framework.exceptions import ValidationError +from swapper import load_model + +from openwisp_users.api.filters import OrganizationManagedFilter + +Template = load_model('config', 'Template') +Vpn = load_model('config', 'Vpn') +Device = load_model('config', 'Device') +DeviceGroup = load_model('config', 'DeviceGroup') + + +class BaseConfigAPIFilter(OrganizationManagedFilter): + def _set_valid_filterform_lables(self): + # When not filtering on a model field, an error message + # with the label "[invalid_name]" will be displayed in filter form. + # To avoid this error, we need to provide the label explicitly. + raise NotImplementedError + + +class TemplateListFilter(BaseConfigAPIFilter): + created__gte = filters.DateTimeFilter( + field_name='created', + lookup_expr='gte', + ) + created__lt = filters.DateTimeFilter( + field_name='created', + lookup_expr='lt', + ) + + def _set_valid_filterform_lables(self): + self.filters['backend'].label = _('Template backend') + self.filters['type'].label = _('Template type') + + def __init__(self, *args, **kwargs): + super(TemplateListFilter, self).__init__(*args, **kwargs) + self._set_valid_filterform_lables() + + class Meta(BaseConfigAPIFilter.Meta): + model = Template + fields = BaseConfigAPIFilter.Meta.fields + [ + 'backend', + 'type', + 'default', + 'required', + 'created', + ] + + +class VPNListFilter(BaseConfigAPIFilter): + def _set_valid_filterform_lables(self): + self.filters['backend'].label = _('VPN Backend') + self.filters['subnet'].label = _('VPN Subnet') + + def __init__(self, *args, **kwargs): + super(VPNListFilter, self).__init__(*args, **kwargs) + self._set_valid_filterform_lables() + + class Meta(BaseConfigAPIFilter.Meta): + model = Vpn + fields = BaseConfigAPIFilter.Meta.fields + ['backend', 'subnet'] + + +class DeviceListFilterBackend(DjangoFilterBackend): + def filter_queryset(self, request, queryset, view): + """ + Validate that the request parameters contain + a valid configuration template uuid format + """ + config_template_uuid = request.query_params.get('config__templates') + if config_template_uuid: + try: + # Attempt to convert the uuid string to a UUID object + config_template_uuid_obj = UUID(config_template_uuid) + except ValueError: + raise ValidationError({'config__templates': _('Invalid UUID format')}) + # Add the config__templates filter to the queryset + return queryset.filter(config__templates=config_template_uuid_obj) + return super().filter_queryset(request, queryset, view) + + +class DeviceListFilter(BaseConfigAPIFilter): + created__gte = filters.DateTimeFilter( + field_name='created', + lookup_expr='gte', + ) + created__lt = filters.DateTimeFilter( + field_name='created', + lookup_expr='lt', + ) + + def _set_valid_filterform_lables(self): + self.filters['group'].label = _('Device group') + self.filters['config__templates'].label = _('Config template') + self.filters['config__status'].label = _('Config status') + self.filters['config__backend'].label = _('Config backend') + + def __init__(self, *args, **kwargs): + super(DeviceListFilter, self).__init__(*args, **kwargs) + self._set_valid_filterform_lables() + + class Meta(BaseConfigAPIFilter.Meta): + model = Device + fields = BaseConfigAPIFilter.Meta.fields + [ + 'config__status', + 'config__backend', + 'config__templates', + 'group', + 'created', + ] + + +class DeviceGroupListFilter(BaseConfigAPIFilter): + # Using filter query param name `empty` + # which is similar to admin filter + empty = filters.BooleanFilter(field_name='device', method='filter_device') + + def filter_device(self, queryset, name, value): + # Returns list of device groups that have devicelocation objects + return queryset.exclude(device__isnull=value).distinct() + + def _set_valid_filterform_lables(self): + self.filters['empty'].label = _('Has devices?') + + def __init__(self, *args, **kwargs): + super(DeviceGroupListFilter, self).__init__(*args, **kwargs) + self._set_valid_filterform_lables() + + class Meta(BaseConfigAPIFilter.Meta): + model = DeviceGroup diff --git a/openwisp_controller/config/api/views.py b/openwisp_controller/config/api/views.py index 79be9cd3b..dd06aaf72 100644 --- a/openwisp_controller/config/api/views.py +++ b/openwisp_controller/config/api/views.py @@ -3,6 +3,7 @@ from django.db.models import F, Q from django.http import Http404 from django.urls.base import reverse +from django_filters.rest_framework import DjangoFilterBackend from rest_framework import pagination from rest_framework.generics import ( ListCreateAPIView, @@ -12,6 +13,13 @@ from swapper import load_model from ...mixins import ProtectedAPIMixin +from .filters import ( + DeviceGroupListFilter, + DeviceListFilter, + DeviceListFilterBackend, + TemplateListFilter, + VPNListFilter, +) from .serializers import ( DeviceDetailSerializer, DeviceGroupSerializer, @@ -38,8 +46,10 @@ class ListViewPagination(pagination.PageNumberPagination): class TemplateListCreateView(ProtectedAPIMixin, ListCreateAPIView): serializer_class = TemplateSerializer - queryset = Template.objects.order_by('-created') + queryset = Template.objects.prefetch_related('tags').order_by('-created') pagination_class = ListViewPagination + filter_backends = [DjangoFilterBackend] + filterset_class = TemplateListFilter class TemplateDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): @@ -49,8 +59,10 @@ class TemplateDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): class VpnListCreateView(ProtectedAPIMixin, ListCreateAPIView): serializer_class = VpnSerializer - queryset = Vpn.objects.order_by('-created') + queryset = Vpn.objects.select_related('subnet').order_by('-created') pagination_class = ListViewPagination + filter_backends = [DjangoFilterBackend] + filterset_class = VPNListFilter class VpnDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): @@ -66,9 +78,11 @@ class DeviceListCreateView(ProtectedAPIMixin, ListCreateAPIView): serializer_class = DeviceListSerializer queryset = Device.objects.select_related( - 'config', 'group', 'organization' + 'config', 'group', 'organization', 'devicelocation' ).order_by('-created') pagination_class = ListViewPagination + filter_backends = [DeviceListFilterBackend] + filterset_class = DeviceListFilter class DeviceDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): @@ -83,8 +97,10 @@ class DeviceDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): class DeviceGroupListCreateView(ProtectedAPIMixin, ListCreateAPIView): serializer_class = DeviceGroupSerializer - queryset = DeviceGroup.objects.select_related('organization').order_by('-created') + queryset = DeviceGroup.objects.prefetch_related('templates').order_by('-created') pagination_class = ListViewPagination + filter_backends = [DjangoFilterBackend] + filterset_class = DeviceGroupListFilter class DeviceGroupDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): diff --git a/openwisp_controller/config/tests/test_api.py b/openwisp_controller/config/tests/test_api.py index f03eaecee..df3081f4e 100644 --- a/openwisp_controller/config/tests/test_api.py +++ b/openwisp_controller/config/tests/test_api.py @@ -2,6 +2,7 @@ from django.test import TestCase from django.test.testcases import TransactionTestCase from django.urls import reverse +from openwisp_ipam.tests import CreateModelsMixin as CreateIpamModelsMixin from swapper import load_model from openwisp_controller.tests.utils import TestAdminMixin @@ -10,7 +11,12 @@ from openwisp_utils.tests import capture_any_output, catch_signal from ..signals import group_templates_changed -from .utils import CreateConfigTemplateMixin, CreateDeviceGroupMixin, TestVpnX509Mixin +from .utils import ( + CreateConfigTemplateMixin, + CreateDeviceGroupMixin, + TestVpnX509Mixin, + TestWireguardVpnMixin, +) Template = load_model('config', 'Template') Vpn = load_model('config', 'Vpn') @@ -18,6 +24,8 @@ Device = load_model('config', 'Device') Config = load_model('config', 'Config') DeviceGroup = load_model('config', 'DeviceGroup') +DeviceLocation = load_model('geo', 'DeviceLocation') +Location = load_model('geo', 'Location') OrganizationUser = load_model('openwisp_users', 'OrganizationUser') @@ -78,6 +86,8 @@ class ApiTestMixin: class TestConfigApi( ApiTestMixin, TestAdminMixin, + CreateIpamModelsMixin, + TestWireguardVpnMixin, TestOrganizationMixin, CreateConfigTemplateMixin, TestVpnX509Mixin, @@ -194,6 +204,92 @@ def test_device_list_api(self): r = self.client.get(path) self.assertEqual(r.status_code, 200) + def test_device_list_api_filter(self): + org1 = self._create_org() + d1, _, t1 = self._create_wireguard_vpn_template(organization=org1) + org2 = self._create_org(name='test org 2') + dg2 = self._create_device_group(organization=org2) + d2 = self._create_device( + mac_address='00:11:22:33:44:66', group=dg2, organization=org2 + ) + t2 = self._create_template(name='t2', organization=org2) + c2 = self._create_config(device=d2, backend='netjsonconfig.OpenWisp') + c2.templates.add(t2) + c2.status = 'applied' + c2.full_clean() + c2.save() + path = reverse('config_api:device_list') + + def _assert_device_list_filter(response=None, device=None): + self.assertEqual(response.status_code, 200) + data = response.data + self.assertEqual(data['count'], 1) + self.assertEqual(len(data['results'][0]), 15) + self.assertEqual(data['results'][0]['id'], str(device.pk)) + self.assertEqual(data['results'][0]['name'], str(device.name)) + self.assertEqual(data['results'][0]['organization'], device.organization.pk) + self.assertEqual( + data['results'][0]['config']['status'], device.config.status + ) + self.assertEqual( + data['results'][0]['config']['backend'], device.config.backend + ) + self.assertIn('created', data['results'][0].keys()) + if device.group: + self.assertEqual(data['results'][0]['group'], device.group.pk) + + with self.subTest('Test filtering using config backend'): + r1 = self.client.get(f'{path}?config__backend=netjsonconfig.OpenWrt') + _assert_device_list_filter(response=r1, device=d1) + r2 = self.client.get(f'{path}?config__backend=netjsonconfig.OpenWisp') + _assert_device_list_filter(response=r2, device=d2) + + with self.subTest('Test filtering using config status'): + r1 = self.client.get(f'{path}?config__status=modified') + _assert_device_list_filter(response=r1, device=d1) + r2 = self.client.get(f'{path}?config__status=applied') + _assert_device_list_filter(response=r2, device=d2) + + with self.subTest('Test filtering using organization id'): + r1 = self.client.get(f'{path}?organization={org1.pk}') + _assert_device_list_filter(response=r1, device=d1) + r2 = self.client.get(f'{path}?organization={org2.pk}') + _assert_device_list_filter(response=r2, device=d2) + + with self.subTest('Test filtering using organization slug'): + r1 = self.client.get(f'{path}?organization_slug={org1.slug}') + _assert_device_list_filter(response=r1, device=d1) + r2 = self.client.get(f'{path}?organization_slug={org2.slug}') + _assert_device_list_filter(response=r2, device=d2) + + with self.subTest('Test filtering using config templates'): + r1 = self.client.get(f'{path}?config__templates={t1.pk}') + _assert_device_list_filter(response=r1, device=d1) + r2 = self.client.get(f'{path}?config__templates={t2.pk}') + _assert_device_list_filter(response=r2, device=d2) + + with self.subTest('Test filtering using config templates invalid uuid'): + # test with invalid uuid string + r1 = self.client.get(f'{path}?config__templates={t1.pk}invalid_uuid') + self.assertEqual(r1.status_code, 400) + self.assertIn('Invalid UUID format', str(r1.content)) + # test with comma seperated uuid's string + r2 = self.client.get(f'{path}?config__templates={t1.pk},{t2.pk}') + self.assertEqual(r2.status_code, 400) + self.assertIn('Invalid UUID format', str(r2.content)) + + with self.subTest('Test filtering using device groups'): + r2 = self.client.get(f'{path}?group={dg2.pk}') + _assert_device_list_filter(response=r2, device=d2) + + with self.subTest('Test filtering using device created'): + r1 = self.client.get(path, {'created': d1.created}) + _assert_device_list_filter(response=r1, device=d1) + r2 = self.client.get(path, {'created__gte': d2.created}) + _assert_device_list_filter(response=r2, device=d2) + r2 = self.client.get(path, {'created__lt': d2.created}) + _assert_device_list_filter(response=r2, device=d1) + def test_device_filter_templates(self): org1 = self._create_org(name='org1') org2 = self._create_org(name='org2') @@ -444,6 +540,80 @@ def test_template_list_api(self): self.assertEqual(r.status_code, 200) self.assertEqual(Template.objects.count(), 1) + def test_template_list_api_filter(self): + org1 = self._create_org() + _, _, t1 = self._create_wireguard_vpn_template(organization=org1) + org2 = self._create_org(name='test org 2') + t2 = self._create_template( + name='t2', + organization=org2, + backend='netjsonconfig.OpenWisp', + default=True, + required=True, + ) + path = reverse('config_api:template_list') + + def _assert_template_list_filter(response=None, template=None): + self.assertEqual(response.status_code, 200) + data = response.data + self.assertEqual(data['count'], 1) + self.assertEqual(len(data['results'][0]), 13) + self.assertEqual(data['results'][0]['id'], str(template.pk)) + self.assertEqual(data['results'][0]['name'], str(template.name)) + self.assertEqual( + data['results'][0]['organization'], template.organization.pk + ) + self.assertEqual(data['results'][0]['type'], template.type) + self.assertEqual(data['results'][0]['backend'], template.backend) + self.assertEqual(data['results'][0]['default'], template.default) + self.assertEqual(data['results'][0]['required'], template.required) + if template.vpn: + self.assertEqual(data['results'][0]['vpn'], template.vpn.pk) + + with self.subTest('Test filtering using organization id'): + r1 = self.client.get(f'{path}?organization={org1.pk}') + _assert_template_list_filter(response=r1, template=t1) + r2 = self.client.get(f'{path}?organization={org2.pk}') + _assert_template_list_filter(response=r2, template=t2) + + with self.subTest('Test filtering using organization slug'): + r1 = self.client.get(f'{path}?organization_slug={org1.slug}') + _assert_template_list_filter(response=r1, template=t1) + r2 = self.client.get(f'{path}?organization_slug={org2.slug}') + _assert_template_list_filter(response=r2, template=t2) + + with self.subTest('Test filtering using template backend'): + r1 = self.client.get(f'{path}?backend=netjsonconfig.OpenWrt') + _assert_template_list_filter(response=r1, template=t1) + r2 = self.client.get(f'{path}?backend=netjsonconfig.OpenWisp') + _assert_template_list_filter(response=r2, template=t2) + + with self.subTest('Test filtering using template type'): + r1 = self.client.get(f'{path}?type=vpn') + _assert_template_list_filter(response=r1, template=t1) + r2 = self.client.get(f'{path}?type=generic') + _assert_template_list_filter(response=r2, template=t2) + + with self.subTest('Test filtering using template default'): + r1 = self.client.get(f'{path}?default=false') + _assert_template_list_filter(response=r1, template=t1) + r2 = self.client.get(f'{path}?default=true') + _assert_template_list_filter(response=r2, template=t2) + + with self.subTest('Test filtering using template required'): + r1 = self.client.get(f'{path}?required=false') + _assert_template_list_filter(response=r1, template=t1) + r2 = self.client.get(f'{path}?required=true') + _assert_template_list_filter(response=r2, template=t2) + + with self.subTest('Test filtering using template created'): + r1 = self.client.get(path, {'created': t1.created}) + _assert_template_list_filter(response=r1, template=t1) + r2 = self.client.get(path, {'created__gte': t2.created}) + _assert_template_list_filter(response=r2, template=t2) + r2 = self.client.get(path, {'created__lt': t2.created}) + _assert_template_list_filter(response=r2, template=t1) + def test_template_list_for_shared_objects(self): org1 = self._get_org() self._create_vpn(name='shared-vpn', organization=None) @@ -534,6 +704,48 @@ def test_vpn_list_api(self): r = self.client.get(path) self.assertEqual(r.status_code, 200) + def test_vpn_list_api_filter(self): + org1 = self._create_org() + org2 = self._create_org(name='test org 2') + vpn1 = self._create_vpn(organization=org1) + vpn2 = self._create_wireguard_vpn(organization=org2) + path = reverse('config_api:vpn_list') + + def _assert_vpn_list_filter(response=None, vpn=None): + self.assertEqual(response.status_code, 200) + data = response.data + self.assertEqual(data['count'], 1) + self.assertEqual(len(data['results'][0]), 13) + self.assertEqual(data['results'][0]['id'], str(vpn.pk)) + self.assertEqual(data['results'][0]['name'], str(vpn.name)) + self.assertEqual(data['results'][0]['organization'], vpn.organization.pk) + + with self.subTest('Test filtering using VPN backend'): + r1 = self.client.get( + f'{path}?backend=openwisp_controller.vpn_backends.OpenVpn' + ) + _assert_vpn_list_filter(response=r1, vpn=vpn1) + r2 = self.client.get( + f'{path}?backend=openwisp_controller.vpn_backends.Wireguard' + ) + _assert_vpn_list_filter(response=r2, vpn=vpn2) + + with self.subTest('Test filtering using VPN subnet'): + r2 = self.client.get(f'{path}?subnet={vpn2.subnet.pk}') + _assert_vpn_list_filter(response=r2, vpn=vpn2) + + with self.subTest('Test filtering using organization id'): + r1 = self.client.get(f'{path}?organization={org1.pk}') + _assert_vpn_list_filter(response=r1, vpn=vpn1) + r2 = self.client.get(f'{path}?organization={org2.pk}') + _assert_vpn_list_filter(response=r2, vpn=vpn2) + + with self.subTest('Test filtering using organization slug'): + r1 = self.client.get(f'{path}?organization_slug={org1.slug}') + _assert_vpn_list_filter(response=r1, vpn=vpn1) + r2 = self.client.get(f'{path}?organization_slug={org2.slug}') + _assert_vpn_list_filter(response=r2, vpn=vpn2) + def test_vpn_list_for_shared_objects(self): ca = self._create_ca(name='shared_ca', organization=None) self._create_cert(ca=ca, name='shared_cert', organization=None) @@ -682,6 +894,46 @@ def test_devicegroup_list_api(self): r, f'', html=True ) + def test_devicegroup_list_api_filter(self): + org1 = self._create_org() + dg1 = self._create_device_group() + self._create_device(organization=org1) + org2 = self._create_org(name='test org 2') + dg2 = self._create_device_group(organization=org2) + self._create_device( + mac_address='00:11:22:33:44:66', group=dg2, organization=org2 + ) + path = reverse('config_api:devicegroup_list') + + def _assert_devicegroup_list_filter(response=None, device_group=None): + self.assertEqual(response.status_code, 200) + data = response.data + self.assertEqual(data['count'], 1) + self.assertEqual(len(data['results'][0]), 8) + self.assertEqual(data['results'][0]['id'], str(device_group.pk)) + self.assertEqual(data['results'][0]['name'], str(device_group.name)) + self.assertEqual( + data['results'][0]['organization'], device_group.organization.pk + ) + + with self.subTest('Test filtering using organization id'): + r1 = self.client.get(f'{path}?organization={org1.pk}') + _assert_devicegroup_list_filter(response=r1, device_group=dg1) + r2 = self.client.get(f'{path}?organization={org2.pk}') + _assert_devicegroup_list_filter(response=r2, device_group=dg2) + + with self.subTest('Test filtering using organization slug'): + r1 = self.client.get(f'{path}?organization_slug={org1.slug}') + _assert_devicegroup_list_filter(response=r1, device_group=dg1) + r2 = self.client.get(f'{path}?organization_slug={org2.slug}') + _assert_devicegroup_list_filter(response=r2, device_group=dg2) + + with self.subTest('Test filtering using device'): + r1 = self.client.get(f'{path}?empty=false') + _assert_devicegroup_list_filter(response=r1, device_group=dg1) + r2 = self.client.get(f'{path}?empty=true') + _assert_devicegroup_list_filter(response=r2, device_group=dg2) + def test_devicegroup_detail_api(self): device_group = self._create_device_group() path = reverse('config_api:devicegroup_detail', args=[device_group.pk]) diff --git a/openwisp_controller/config/tests/utils.py b/openwisp_controller/config/tests/utils.py index 3fa1fa74e..581e049f0 100644 --- a/openwisp_controller/config/tests/utils.py +++ b/openwisp_controller/config/tests/utils.py @@ -155,9 +155,9 @@ def _create_wireguard_vpn(self, config=None, **kwargs): vpn.save() return vpn - def _create_wireguard_vpn_template(self, auto_cert=True): + def _create_wireguard_vpn_template(self, auto_cert=True, **kwargs): vpn = self._create_wireguard_vpn() - org1 = vpn.organization + org1 = kwargs.get('organization', vpn.organization) template = self._create_template( name='wireguard', type='vpn', diff --git a/openwisp_controller/geo/api/filters.py b/openwisp_controller/geo/api/filters.py new file mode 100644 index 000000000..444664ade --- /dev/null +++ b/openwisp_controller/geo/api/filters.py @@ -0,0 +1,27 @@ +from django.utils.translation import gettext_lazy as _ +from django_filters import rest_framework as filters + +from openwisp_controller.config.api.filters import ( + DeviceListFilter as BaseDeviceListFilter, +) + + +class DeviceListFilter(BaseDeviceListFilter): + # Using filter query param name `with_geo` + # which is similar to admin filter + with_geo = filters.BooleanFilter( + field_name='devicelocation', method='filter_devicelocation' + ) + + def _set_valid_filterform_lables(self): + super()._set_valid_filterform_lables() + self.filters['with_geo'].label = _('Has geographic location set?') + + def filter_devicelocation(self, queryset, name, value): + # Returns list of device that have devicelocation objects + return queryset.exclude(devicelocation__isnull=value) + + class Meta: + model = BaseDeviceListFilter.Meta.model + fields = BaseDeviceListFilter.Meta.fields[:] + fields.insert(fields.index('created'), 'with_geo') diff --git a/openwisp_controller/geo/api/views.py b/openwisp_controller/geo/api/views.py index 1d7b48b2e..9d65d1fad 100644 --- a/openwisp_controller/geo/api/views.py +++ b/openwisp_controller/geo/api/views.py @@ -10,9 +10,12 @@ from rest_framework_gis.pagination import GeoJsonPagination from swapper import load_model +from openwisp_controller.config.api.views import DeviceListCreateView +from openwisp_users.api.filters import OrganizationManagedFilter from openwisp_users.api.mixins import FilterByOrganizationManaged, FilterByParentManaged from ...mixins import ProtectedAPIMixin +from .filters import DeviceListFilter from .serializers import ( DeviceCoordinatesSerializer, DeviceLocationSerializer, @@ -37,21 +40,14 @@ def has_object_permission(self, request, view, obj): return hasattr(obj, 'key') and request.query_params.get('key') == obj.key -class BaseOrganizationSlugFilter(filters.FilterSet): - organization_slug = filters.CharFilter(field_name='organization__slug') - - class Meta: - fields = ['organization_slug'] - - -class LocationOrganizationSlugFilter(BaseOrganizationSlugFilter): - class Meta(BaseOrganizationSlugFilter.Meta): +class LocationOrganizationFilter(OrganizationManagedFilter): + class Meta(OrganizationManagedFilter.Meta): model = Location - fields = BaseOrganizationSlugFilter.Meta.fields + ['is_mobile', 'type'] + fields = OrganizationManagedFilter.Meta.fields + ['is_mobile', 'type'] -class FloorPlanOrganizationSlugFilter(BaseOrganizationSlugFilter): - class Meta(BaseOrganizationSlugFilter.Meta): +class FloorPlanOrganizationFilter(OrganizationManagedFilter): + class Meta(OrganizationManagedFilter.Meta): model = FloorPlan @@ -180,7 +176,7 @@ class GeoJsonLocationList( serializer_class = GeoJsonLocationSerializer pagination_class = GeoJsonLocationListPagination filter_backends = [filters.DjangoFilterBackend] - filterset_class = LocationOrganizationSlugFilter + filterset_class = LocationOrganizationFilter class LocationDeviceList( @@ -205,7 +201,7 @@ class FloorPlanListCreateView(ProtectedAPIMixin, generics.ListCreateAPIView): queryset = FloorPlan.objects.select_related().order_by('-created') pagination_class = ListViewPagination filter_backends = [filters.DjangoFilterBackend] - filterset_class = FloorPlanOrganizationSlugFilter + filterset_class = FloorPlanOrganizationFilter class FloorPlanDetailView( @@ -221,7 +217,7 @@ class LocationListCreateView(ProtectedAPIMixin, generics.ListCreateAPIView): queryset = Location.objects.order_by('-created') pagination_class = ListViewPagination filter_backends = [filters.DjangoFilterBackend] - filterset_class = LocationOrganizationSlugFilter + filterset_class = LocationOrganizationFilter class LocationDetailView( @@ -232,6 +228,9 @@ class LocationDetailView( queryset = Location.objects.all() +# add with_geo filter to device API +DeviceListCreateView.filterset_class = DeviceListFilter + device_coordinates = DeviceCoordinatesView.as_view() device_location = DeviceLocationView.as_view() geojson = GeoJsonLocationList.as_view() diff --git a/openwisp_controller/geo/tests/test_api.py b/openwisp_controller/geo/tests/test_api.py index 361e3277f..2e982d306 100644 --- a/openwisp_controller/geo/tests/test_api.py +++ b/openwisp_controller/geo/tests/test_api.py @@ -11,7 +11,10 @@ from rest_framework.authtoken.models import Token from swapper import load_model -from openwisp_controller.config.tests.utils import CreateConfigTemplateMixin +from openwisp_controller.config.tests.utils import ( + CreateConfigTemplateMixin, + CreateDeviceMixin, +) from openwisp_controller.tests.utils import TestAdminMixin from openwisp_users.tests.utils import TestOrganizationMixin from openwisp_utils.tests import AssertNumQueriesSubTestMixin, capture_any_output @@ -251,6 +254,14 @@ def test_geojson_list(self): self.assertEqual( response_data['features'][0]['properties']['organization'], org_a.id ) + with self.subTest('Test filtering using organization id'): + self.client.login(username='admin', password='tester') + response = self.client.get(reverse(url), data={'organization': org_a.id}) + response_data = response.data + self.assertEqual(response_data['count'], 1) + self.assertEqual( + response_data['features'][0]['properties']['organization'], org_a.id + ) with self.subTest('Test geojson list unauthenticated user'): self.client.logout() @@ -263,6 +274,7 @@ class TestGeoApi( TestOrganizationMixin, TestGeoMixin, TestAdminMixin, + CreateDeviceMixin, TestCase, ): object_model = Device @@ -315,6 +327,14 @@ def test_filter_floorplan_list(self): self.assertContains(response, org1_floorplan.id) self.assertNotContains(response, org2_floorplan.id) + with self.subTest('Test filtering with organization id'): + with self.assertNumQueries(5): + response = self.client.get(path, {'organization': org1.id}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data['count'], 1) + self.assertContains(response, org1_floorplan.id) + self.assertNotContains(response, org2_floorplan.id) + with self.subTest('Test multi-tenancy filtering'): self.client.logout() user = self._create_administrator([org1]) @@ -413,6 +433,14 @@ def test_filter_location_list(self): self.assertContains(response, org1_location.id) self.assertNotContains(response, org2_location.id) + with self.subTest('Test filtering with organization id'): + with self.assertNumQueries(6): + response = self.client.get(path, {'organization': org1.id}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data['count'], 1) + self.assertContains(response, org1_location.id) + self.assertNotContains(response, org2_location.id) + with self.subTest('Test filtering with location type'): with self.assertNumQueries(5): response = self.client.get(path, {'type': 'indoor'}) @@ -899,3 +927,32 @@ def test_retrieve_devicelocation(self): self.assertIn('image', response.data['floorplan'].keys()) self.assertIn('created', response.data['floorplan'].keys()) self.assertIn('modified', response.data['floorplan'].keys()) + + def test_device_list_api_with_geo_filter(self): + org_a = self._create_org() + org_b = self._create_org(name='test org b') + device_a = self._create_device(organization=org_a) + device_b = self._create_device(organization=org_b) + location_b = self._create_location(organization=org_b) + # create device location for device_b + self._create_device_location(content_object=device_b, location=location_b) + path = reverse('config_api:device_list') + + def _assert_device_list_with_geo_filter(response=None, device=None): + self.assertEqual(response.status_code, 200) + data = response.data + self.assertEqual(data['count'], 1) + self.assertEqual(len(data['results'][0]), 15) + self.assertEqual(data['results'][0]['id'], str(device.pk)) + self.assertEqual(data['results'][0]['name'], str(device.name)) + self.assertEqual(data['results'][0]['organization'], device.organization.pk) + self.assertEqual(data['results'][0]['config'], None) + self.assertIn('created', data['results'][0].keys()) + + with self.subTest('Test filtering using device location'): + # make sure device_a is in the api response + r1 = self.client.get(f'{path}?with_geo=false') + _assert_device_list_with_geo_filter(response=r1, device=device_a) + # make sure device_b is in the api response + r2 = self.client.get(f'{path}?with_geo=true') + _assert_device_list_with_geo_filter(response=r2, device=device_b)