Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Implement multitenancy in API #61 #76

Merged
merged 12 commits into from
Jan 28, 2021
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,6 @@ ENV/

# Generated CSV files
data_*.csv

# IDE settings
.vscode/
85 changes: 85 additions & 0 deletions openwisp_ipam/api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import swapper
from django.contrib.auth.models import Permission
from django.core.exceptions import ValidationError
from rest_framework.exceptions import NotFound, PermissionDenied

Organization = swapper.load_model('openwisp_users', 'Organization')

class FilterByOrganization:
def get_queryset(self):
qs = super().get_queryset()
# superuser has access to every organization
if self.request.user.is_superuser:
return qs
# non superuser has access only to some organizations
return self.get_organization_queryset(qs)

def get_organization_queryset(self):
raise NotImplementedError()


class FilterByOrganizationManaged(FilterByOrganization):
"""
Allows to filter only organizations which the current user manages
"""

def get_organization_queryset(self, qs):
return qs.filter(organization__in=self.request.user.organizations_managed)


class FilterByParent:
purhan marked this conversation as resolved.
Show resolved Hide resolved
def get_queryset(self):
qs = super().get_queryset()
self.assert_parent_exists()
return qs

def assert_parent_exists(self):
parent_queryset = self.get_parent_queryset()
if not self.request.user.is_superuser:
parent_queryset = self.get_organization_queryset(parent_queryset)
try:
assert parent_queryset.exists()
except (AssertionError, ValidationError):
raise NotFound(detail='No relevant data found.')

def get_parent_queryset(self):
raise NotImplementedError()

def get_organization_queryset(self):
raise NotImplementedError()


class FilterByParentManaged(FilterByParent):
def get_organization_queryset(self, qs):
return qs.filter(organization__in=self.request.user.organizations_managed)
purhan marked this conversation as resolved.
Show resolved Hide resolved


class AuthorizeCSVImport:
def post(self, request):
self.assert_organization_permissions(request)

def assert_organization_permissions(self, request):
if request.user.is_superuser:
return
try:
organization = self.get_csv_organization()
if str(organization.pk) in self.get_user_organizations():
return
except Organization.DoesNotExist:
# if organization in CSV doesn't exist, then check if
# user can create new organizations
permission = Permission.objects.filter(user=request.user).filter(codename='add_organization')
if permission.exists():
return
raise PermissionDenied()

def get_csv_organization(self):
raise NotImplementedError()

def get_user_organizations(self):
raise NotImplementedError()


class AuthorizeCSVOrgManaged(AuthorizeCSVImport):
def get_user_organizations(self):
return self.request.user.organizations_managed
52 changes: 43 additions & 9 deletions openwisp_ipam/api/views.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import csv
from collections import OrderedDict
from copy import deepcopy

import swapper
from django.http import HttpResponse
from django.utils.translation import gettext_lazy as _
from openwisp_users.api.authentication import BearerAuthentication
from openwisp_users.api.permissions import IsOrganizationManager
from rest_framework import pagination, serializers, status
from rest_framework.authentication import SessionAuthentication
from rest_framework.generics import (
Expand All @@ -28,9 +30,30 @@
IpRequestSerializer,
SubnetSerializer,
)
from .utils import (
AuthorizeCSVOrgManaged,
FilterByOrganizationManaged,
FilterByParentManaged,
)

IpAddress = swapper.load_model('openwisp_ipam', 'IpAddress')
Subnet = swapper.load_model('openwisp_ipam', 'Subnet')
Organization = swapper.load_model('openwisp_users', 'Organization')


class IpAddressOrgMixin(FilterByParentManaged):
def get_parent_queryset(self):
qs = Subnet.objects.filter(pk=self.kwargs['subnet_id'])
return qs


class ImportSubnetCSVMixin(AuthorizeCSVOrgManaged):
def get_csv_organization(self):
data = self.subnet_model._get_csv_reader(
self, deepcopy(self.request.FILES['csvfile'])
)
org = Organization.objects.get(name=list(data)[2][0].strip())
return org
purhan marked this conversation as resolved.
Show resolved Hide resolved


class ListViewPagination(pagination.PageNumberPagination):
Expand Down Expand Up @@ -137,7 +160,7 @@ def index_of(self, address):
return index


class AvailableIpView(RetrieveAPIView):
class AvailableIpView(IpAddressOrgMixin, RetrieveAPIView):
subnet_model = Subnet
queryset = IpAddress.objects.none()
authentication_classes = (BearerAuthentication, SessionAuthentication)
Expand All @@ -148,7 +171,8 @@ def get(self, request, *args, **kwargs):
return Response(subnet.get_next_available_ip())


class IpAddressListCreateView(ListCreateAPIView):
class IpAddressListCreateView(IpAddressOrgMixin, ListCreateAPIView):
queryset = IpAddress.objects.none()
subnet_model = Subnet
serializer_class = IpAddressSerializer
authentication_classes = (BearerAuthentication, SessionAuthentication)
Expand All @@ -157,10 +181,11 @@ class IpAddressListCreateView(ListCreateAPIView):

def get_queryset(self):
subnet = get_object_or_404(self.subnet_model, pk=self.kwargs['subnet_id'])
super().get_queryset()
return subnet.ipaddress_set.all().order_by('ip_address')


class SubnetListCreateView(ListCreateAPIView):
class SubnetListCreateView(FilterByOrganizationManaged, ListCreateAPIView):
serializer_class = SubnetSerializer
authentication_classes = (BearerAuthentication, SessionAuthentication)
permission_classes = (DjangoModelPermissions,)
Expand All @@ -171,18 +196,25 @@ class SubnetListCreateView(ListCreateAPIView):
class SubnetView(RetrieveUpdateDestroyAPIView):
serializer_class = SubnetSerializer
authentication_classes = (BearerAuthentication, SessionAuthentication)
permission_classes = (DjangoModelPermissions,)
permission_classes = (
IsOrganizationManager,
DjangoModelPermissions,
)
queryset = Subnet.objects.all()


class IpAddressView(RetrieveUpdateDestroyAPIView):
serializer_class = IpAddressSerializer
authentication_classes = (BearerAuthentication, SessionAuthentication)
permission_classes = (DjangoModelPermissions,)
permission_classes = (
IsOrganizationManager,
DjangoModelPermissions,
)
queryset = IpAddress.objects.all()
organization_field = 'subnet__organization'


class RequestIPView(CreateAPIView):
class RequestIPView(IpAddressOrgMixin, CreateAPIView):
subnet_model = Subnet
queryset = IpAddress.objects.none()
serializer_class = IpRequestSerializer
Expand All @@ -202,14 +234,15 @@ def post(self, request, *args, **kwargs):
return Response(None)


class ImportSubnetView(CreateAPIView):
class ImportSubnetView(ImportSubnetCSVMixin, CreateAPIView):
subnet_model = Subnet
queryset = Subnet.objects.none()
serializer_class = ImportSubnetSerializer
authentication_classes = (BearerAuthentication, SessionAuthentication)
permission_classes = (DjangoModelPermissions,)

def post(self, request, *args, **kwargs):
super().post(request)
purhan marked this conversation as resolved.
Show resolved Hide resolved
file = request.FILES['csvfile']
if not file.name.endswith(('.csv', '.xls', '.xlsx')):
return Response({'error': _('File type not supported.')}, status=400)
Expand All @@ -220,7 +253,7 @@ def post(self, request, *args, **kwargs):
return Response({'detail': _('Data imported successfully.')})


class ExportSubnetView(CreateAPIView):
class ExportSubnetView(IpAddressOrgMixin, CreateAPIView):
subnet_model = Subnet
queryset = Subnet.objects.none()
serializer_class = serializers.Serializer
Expand All @@ -235,7 +268,7 @@ def post(self, request, *args, **kwargs):
return response


class SubnetHostsView(ListAPIView):
class SubnetHostsView(IpAddressOrgMixin, ListAPIView):
subnet_model = Subnet
queryset = Subnet.objects.none()
serializer_class = HostsResponseSerializer
Expand All @@ -244,6 +277,7 @@ class SubnetHostsView(ListAPIView):
pagination_class = HostsListPagination

def get_queryset(self):
super().get_queryset()
subnet = get_object_or_404(self.subnet_model, pk=self.kwargs['subnet_id'])
qs = HostsSet(subnet)
return qs
Expand Down
6 changes: 5 additions & 1 deletion openwisp_ipam/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def _read_ipaddress_data(self, reader, subnet):
for ip in ipaddress_list:
ip.save()

def import_csv(self, file):
def _get_csv_reader(self, file):
if file.name.endswith(('.xls', '.xlsx')):
book = xlrd.open_workbook(file_contents=file.read())
sheet = book.sheet_by_index(0)
Expand All @@ -188,6 +188,10 @@ def import_csv(self, file):
reader = iter(row)
else:
reader = csv.reader(StringIO(file.read().decode('utf-8')), delimiter=',')
return reader
purhan marked this conversation as resolved.
Show resolved Hide resolved

def import_csv(self, file):
reader = self._get_csv_reader(file)
subnet = self._read_subnet_data(reader)
next(reader)
next(reader)
Expand Down
Loading