Skip to content

Commit

Permalink
[fix] Fixed command APIs permissions #754
Browse files Browse the repository at this point in the history
- Command APIs inherits "ProtectedAPIMixin" for consistency with
  other API endpoints
- If an organization admin tries to access commands for a device
  of another organization, then the API would return 404 response.
  Earlier, the API returned 200 reponse with an empty list.

Fixes #754
  • Loading branch information
pandafy authored Jun 20, 2023
1 parent b8169d9 commit 6f5cf3f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 40 deletions.
39 changes: 13 additions & 26 deletions openwisp_controller/connection/api/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from django.core.exceptions import ValidationError
from rest_framework import pagination
from rest_framework.authentication import SessionAuthentication
from rest_framework.exceptions import NotFound
from rest_framework.generics import (
GenericAPIView,
Expand All @@ -11,7 +10,8 @@
)
from swapper import load_model

from openwisp_users.api.authentication import BearerAuthentication
from openwisp_users.api.mixins import FilterByParentManaged
from openwisp_users.api.mixins import ProtectedAPIMixin as BaseProtectedAPIMixin

from ...mixins import ProtectedAPIMixin
from .serializer import (
Expand All @@ -32,32 +32,18 @@ class ListViewPagination(pagination.PageNumberPagination):
max_page_size = 100


class BaseCommandView(GenericAPIView):
class BaseCommandView(FilterByParentManaged, BaseProtectedAPIMixin):
model = Command
queryset = Command.objects.prefetch_related('device')
serializer_class = CommandSerializer
authentication_classes = [BearerAuthentication, SessionAuthentication]

def get_queryset(self):
qs = Command.objects.prefetch_related('device')
if not self.request.user.is_superuser:
qs = qs.filter(
device__organization__in=self.request.user.organizations_managed
)
return qs

def initial(self, *args, **kwargs):
super().initial(*args, **kwargs)
self.assert_parent_exists()

def assert_parent_exists(self):
try:
assert self.get_parent_queryset().exists()
except (AssertionError, ValidationError):
device_id = self.kwargs['id']
raise NotFound(detail=f'Device with ID "{device_id}" not found.')

def get_parent_queryset(self):
return Device.objects.filter(pk=self.kwargs['id'])
return Device.objects.filter(
pk=self.kwargs['id'],
)

def get_queryset(self):
return super().get_queryset().filter(device_id=self.kwargs['id'])

def get_serializer_context(self):
context = super().get_serializer_context()
Expand All @@ -68,8 +54,9 @@ def get_serializer_context(self):
class CommandListCreateView(BaseCommandView, ListCreateAPIView):
pagination_class = ListViewPagination

def get_queryset(self):
return super().get_queryset().filter(device_id=self.kwargs['id'])
def create(self, request, *args, **kwargs):
self.assert_parent_exists()
return super().create(request, *args, **kwargs)


class CommandDetailsView(BaseCommandView, RetrieveAPIView):
Expand Down
31 changes: 17 additions & 14 deletions openwisp_controller/connection/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Command = load_model('connection', 'Command')
command_qs = Command.objects.order_by('-created')
OrganizationUser = load_model('openwisp_users', 'OrganizationUser')
Group = load_model('openwisp_users', 'Group')


class TestCommandsAPI(TestCase, AuthenticationMixin, CreateCommandMixin):
Expand All @@ -41,11 +42,7 @@ def _get_path(self, url_name, *args, **kwargs):
return f'{path}?{query_string}'

def _get_device_not_found_error(self, device_id):
return {
'detail': ErrorDetail(
f'Device with ID "{device_id}" not found.', code='not_found'
)
}
return {'detail': ErrorDetail('Not found.', code='not_found')}

@patch.object(ListViewPagination, 'page_size', 3)
def test_command_list_api(self):
Expand Down Expand Up @@ -249,8 +246,7 @@ def test_endpoints_for_non_existent_device(self):
'input': {'command': 'echo test'},
}
response = self.client.post(
url,
data=payload,
url, data=payload, content_type='application/json'
)
self.assertEqual(response.status_code, 404)
self.assertDictEqual(response.data, device_not_found)
Expand All @@ -268,24 +264,31 @@ def test_non_superuser(self):
command = self._create_command(device_conn=self.device_conn)
device = command.device

with self.subTest('Test non organization member'):
operator = self._create_operator()
self.client.force_login(operator)
self.assertNotIn(device.organization, operator.organizations_managed)

with self.subTest('Test with unauthenticated user'):
self.client.logout()
response = self.client.get(list_url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data['count'], 0)
self.assertEqual(response.status_code, 401)

with self.subTest('Test with organization member'):
org_user = self._create_org_user(is_admin=True)
org_user.user.groups.add(Group.objects.get(name='Operator'))
self.client.force_login(org_user.user)
self.assertEqual(device.organization, org_user.organization)

response = self.client.get(list_url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data['count'], 1)

with self.subTest('Test with org member of different org'):
org2 = self._create_org(name='org2', slug='org2')
org2_user = self._create_user(username='org2user', email='user@org2.com')
self._create_org_user(organization=org2, user=org2_user, is_admin=True)
self.client.force_login(org2_user)
org2_user.groups.add(Group.objects.get(name='Operator'))

response = self.client.get(list_url)
self.assertEqual(response.status_code, 404)

def test_non_existent_command(self):
url = self._get_path('device_command_list', self.device_id)
with patch.dict(
Expand Down

0 comments on commit 6f5cf3f

Please sign in to comment.