Skip to content

Commit

Permalink
[change] Add reusable FilterByParent class
Browse files Browse the repository at this point in the history
  • Loading branch information
purhan committed Dec 19, 2020
1 parent 07af296 commit 29ae7f5
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 36 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,6 @@ ENV/

# Generated CSV files
data_*.csv

# IDE settings
.vscode/
28 changes: 28 additions & 0 deletions openwisp_ipam/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from django.core.exceptions import ValidationError
from rest_framework.exceptions import NotFound


class FilterByOrganization:
def get_queryset(self):
qs = super().get_queryset()
Expand All @@ -18,3 +22,27 @@ class FilterByOrganizationManaged(FilterByOrganization):

def get_organization_queryset(self, qs):
return qs.filter(organization__in=self.request.user.organizations_managed)


class FilterByParent:
def get_queryset(self):
qs = super().get_queryset()
self.assert_parent_exists()
return qs

def assert_parent_exists(self):
parent_queryset = self.get_parent_queryset()
if not self.request.user.is_superuser:
parent_queryset = self.get_organization_queryset(parent_queryset)
try:
assert parent_queryset.exists()
except (AssertionError, ValidationError):
raise NotFound(detail='No relevant data found.')

def get_parent_queryset(self):
raise NotImplementedError()


class FilterByParentManaged(FilterByParent):
def get_organization_queryset(self, qs):
return qs.filter(organization__in=self.request.user.organizations_managed)
45 changes: 15 additions & 30 deletions openwisp_ipam/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from openwisp_users.api.permissions import IsOrganizationManager
from rest_framework import pagination, serializers, status
from rest_framework.authentication import SessionAuthentication
from rest_framework.exceptions import PermissionDenied
from rest_framework.generics import (
CreateAPIView,
ListAPIView,
Expand All @@ -30,28 +29,17 @@
IpRequestSerializer,
SubnetSerializer,
)
from .utils import FilterByOrganizationManaged
from .utils import FilterByOrganizationManaged, FilterByParentManaged

IpAddress = swapper.load_model('openwisp_ipam', 'IpAddress')
Subnet = swapper.load_model('openwisp_ipam', 'Subnet')
Organization = swapper.load_model('openwisp_users', 'Organization')


class DispatchOrgMixin(object):
def dispatch(self, *args, **kwargs):
self.organization = get_object_or_404(
self.subnet_model, pk=self.kwargs['subnet_id']
).organization
return super().dispatch(*args, **kwargs)

def validate_membership(self, user):
if not (
user.is_superuser
or IsOrganizationManager.validate_membership(self, user, self.organization)
):
raise PermissionDenied(
"User does not have access to the specified organization"
)
class IpAddressOrgMixin(FilterByParentManaged):
def get_parent_queryset(self):
qs = Subnet.objects.filter(pk=self.kwargs['subnet_id'])
return qs


class ListViewPagination(pagination.PageNumberPagination):
Expand Down Expand Up @@ -158,19 +146,18 @@ def index_of(self, address):
return index


class AvailableIpView(DispatchOrgMixin, RetrieveAPIView):
subnet_model = Subnet
queryset = IpAddress.objects.none()
class AvailableIpView(IpAddressOrgMixin, RetrieveAPIView):
queryset = IpAddress.objects.all()
authentication_classes = (BearerAuthentication, SessionAuthentication)
permission_classes = (DjangoModelPermissions,)

def get(self, request, *args, **kwargs):
subnet = get_object_or_404(self.subnet_model, pk=self.kwargs['subnet_id'])
self.validate_membership(self.request.user)
subnet = get_object_or_404(Subnet, pk=self.kwargs['subnet_id'])
return Response(subnet.get_next_available_ip())


class IpAddressListCreateView(DispatchOrgMixin, ListCreateAPIView):
class IpAddressListCreateView(IpAddressOrgMixin, ListCreateAPIView):
queryset = IpAddress.objects.none()
subnet_model = Subnet
serializer_class = IpAddressSerializer
authentication_classes = (BearerAuthentication, SessionAuthentication)
Expand All @@ -179,7 +166,7 @@ class IpAddressListCreateView(DispatchOrgMixin, ListCreateAPIView):

def get_queryset(self):
subnet = get_object_or_404(self.subnet_model, pk=self.kwargs['subnet_id'])
self.validate_membership(self.request.user)
super().get_queryset()
return subnet.ipaddress_set.all().order_by('ip_address')


Expand Down Expand Up @@ -212,7 +199,7 @@ class IpAddressView(RetrieveUpdateDestroyAPIView):
organization_field = 'subnet__organization'


class RequestIPView(DispatchOrgMixin, CreateAPIView):
class RequestIPView(IpAddressOrgMixin, CreateAPIView):
subnet_model = Subnet
queryset = IpAddress.objects.none()
serializer_class = IpRequestSerializer
Expand All @@ -222,7 +209,6 @@ class RequestIPView(DispatchOrgMixin, CreateAPIView):
def post(self, request, *args, **kwargs):
options = {'description': request.data.get('description')}
subnet = get_object_or_404(self.subnet_model, pk=kwargs['subnet_id'])
self.validate_membership(self.request.user)
ip_address = subnet.request_ip(options)
if ip_address:
serializer = IpAddressSerializer(ip_address)
Expand Down Expand Up @@ -251,23 +237,22 @@ def post(self, request, *args, **kwargs):
return Response({'detail': _('Data imported successfully.')})


class ExportSubnetView(DispatchOrgMixin, CreateAPIView):
class ExportSubnetView(IpAddressOrgMixin, CreateAPIView):
subnet_model = Subnet
queryset = Subnet.objects.none()
serializer_class = serializers.Serializer
authentication_classes = (BearerAuthentication, SessionAuthentication)
permission_classes = (DjangoModelPermissions,)

def post(self, request, *args, **kwargs):
self.validate_membership(self.request.user)
response = HttpResponse(content_type='text/csv')
response['Content-Disposition'] = 'attachment; filename="ip_address.csv"'
writer = csv.writer(response)
self.subnet_model().export_csv(kwargs['subnet_id'], writer)
return response


class SubnetHostsView(DispatchOrgMixin, ListAPIView):
class SubnetHostsView(IpAddressOrgMixin, ListAPIView):
subnet_model = Subnet
queryset = Subnet.objects.none()
serializer_class = HostsResponseSerializer
Expand All @@ -276,9 +261,9 @@ class SubnetHostsView(DispatchOrgMixin, ListAPIView):
pagination_class = HostsListPagination

def get_queryset(self):
super().get_queryset()
subnet = get_object_or_404(self.subnet_model, pk=self.kwargs['subnet_id'])
qs = HostsSet(subnet)
self.validate_membership(self.request.user)
return qs


Expand Down
12 changes: 6 additions & 6 deletions openwisp_ipam/tests/test_multitenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class TestMultitenantApi(
):
def setUp(self):
super().setUp()
# Creates a user for each of org_a and org_b
# Creates a manager for each of org_a and org_b
org_a = self._create_org(name='org_a', slug='org_a')
org_b = self._create_org(name='org_b', slug='org_b')
user_a = self._create_operator(
Expand Down Expand Up @@ -116,7 +116,7 @@ def test_subnet_hosts(self):
self.assertEqual(response.status_code, 200)
self._login(username='user_b', password='tester')
response = self.client.get(reverse('ipam:hosts', args=(subnet.id,)))
self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 404)

def test_subnet_list_ipaddress(self):
org_a = self._get_org(org_name='org_a')
Expand All @@ -143,7 +143,7 @@ def test_subnet_list_ipaddress(self):
response = self.client.get(
reverse('ipam:list_create_ip_address', args=(subnet.id,))
)
self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 404)

def test_ipaddress(self):
org_a = self._get_org(org_name='org_a')
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_next_available_ip(self):
response = self.client.get(
reverse('ipam:get_next_available_ip', args=(subnet.id,))
)
self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 404)

def test_subnet_list(self):
org_a = self._get_org(org_name='org_a')
Expand Down Expand Up @@ -220,7 +220,7 @@ def test_request_ip(self):
data=post_data,
content_type='application/json',
)
self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 404)

def test_import_subnet(self):
csv_data = """Monachers - Matera,
Expand Down Expand Up @@ -268,4 +268,4 @@ def test_export_subnet_api(self):
self.assertEqual(response.content, csv_data)
self._login(username='user_b', password='tester')
response = self.client.post(reverse('ipam:export-subnet', args=(subnet.id,)))
self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 404)

0 comments on commit 29ae7f5

Please sign in to comment.