diff --git a/.travis.yml b/.travis.yml index ff0c2ed..6a29afc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,7 @@ language: python services: - postgresql - + addons: postgresql: "9.4" @@ -24,6 +24,8 @@ matrix: env: TOXENV=py36-django111 - python: "3.6" env: TOXENV=py36-djangomaster + - python: "3.5" + env: TOXENV="flake8" exclude: - python: "3.5" env: TOXENV=py36-django111 @@ -41,13 +43,13 @@ cache: install: - pip install coveralls tox - + before_script: - psql -c 'create database travis_ci_test;' -U postgres script: - tox - + after_script: - coveralls diff --git a/api_bouncer/middlewares/key_auth.py b/api_bouncer/middlewares/key_auth.py index 253399a..537c108 100644 --- a/api_bouncer/middlewares/key_auth.py +++ b/api_bouncer/middlewares/key_auth.py @@ -1,3 +1,5 @@ +import json + from django.http import JsonResponse from ..models import ( @@ -19,37 +21,50 @@ def __call__(self, request): if plugin_conf: config = plugin_conf.config - apikey = self.get_key(request, config['key_names']) - - if not self.verify_key(request, config, apikey): + apikey = self.get_key( + request, + config['key_names'], + key_in_body=config['key_in_body'] + ) + consumer_key = self.verify_key(request, config, apikey) + if not consumer_key: return JsonResponse( data={'error': 'Unauthorized'}, status=403 ) + if not config['hide_credentials']: + request.META.update({ + 'HTTP_X_CONSUMER_USERNAME': consumer_key.consumer.username, + 'HTTP_X_CONSUMER_ID': str(consumer_key.consumer.id), + }) + else: + # Remove apikey from headers + for k in config['key_names']: + request.META.pop(k, None) response = self.get_response(request) return response def verify_key(self, request, config, key): - if not key and config.get('key_in_body'): - key = request.body.get('key') - - c_key = ( + apikey = ( ConsumerKey.objects .select_related('consumer') .filter(key=key).first() ) + return apikey - if c_key: - request.META['X-Consumer-Username'] = c_key.consumer.username - request.META['X-Consumer-Id'] = c_key.consumer.id - - return True - - return False + def get_key(self, request, key_names, key_in_body=False): + if key_in_body: + try: + body = json.loads(request.body.decode('utf-8')) + for k in key_names: + if k in body: + return body[k] + return None + except json.JSONDecodeError: + return None - def get_key(self, request, key_names): for n in key_names: name = n.upper().replace('-', '_') key_name = 'HTTP_{0}'.format(name) diff --git a/api_bouncer/models.py b/api_bouncer/models.py index 67c38dd..afbb932 100644 --- a/api_bouncer/models.py +++ b/api_bouncer/models.py @@ -4,7 +4,10 @@ ArrayField, JSONField, ) -from django.core.validators import RegexValidator +from django.core.validators import ( + MinLengthValidator, + RegexValidator, +) from django.db import models @@ -30,7 +33,10 @@ class Api(models.Model): validators=[ RegexValidator(regex=FQDN_REGEX), ] - ) + ), + validators=[ + MinLengthValidator(1, message='At least one is required'), + ] ) upstream_url = models.URLField(null=False) @@ -59,7 +65,10 @@ class ConsumerKey(models.Model): key = models.CharField( max_length=200, blank=False, - null=False + null=False, + validators=[ + MinLengthValidator(8), + ] ) class Meta: diff --git a/api_bouncer/serializers.py b/api_bouncer/serializers.py index 473e657..7a86b2e 100644 --- a/api_bouncer/serializers.py +++ b/api_bouncer/serializers.py @@ -1,4 +1,5 @@ import uuid + import jsonschema from rest_framework import serializers @@ -40,7 +41,7 @@ class Meta: def validate_key(self, value): """Verify if no key is given and generate one""" if not value: - value = str(uuid.uuid4()).replace('-', '') + value = str(uuid.uuid4().int) return value @@ -74,9 +75,6 @@ def validate(self, data): return data - def process_headers(self, headers={}): - return headers - class ApiSerializer(serializers.ModelSerializer): plugins = PluginSerializer( @@ -94,11 +92,8 @@ class BouncerSerializer(serializers.Serializer): api = serializers.CharField(allow_blank=False, allow_null=False) headers = serializers.DictField(child=serializers.CharField()) - def validate(self, data): - api = Api.objects.get(name=data['api']) + def validate_api(self, value): + api = Api.objects.get(name=value) if not api: - raise serializers.ValidationError({ - 'api': 'Unknown API', - }) - - return data + raise serializers.ValidationError('Unknown API') + return value diff --git a/api_bouncer/views.py b/api_bouncer/views.py index d320747..41e338d 100644 --- a/api_bouncer/views.py +++ b/api_bouncer/views.py @@ -1,13 +1,13 @@ import re -from requests import Request, Session from django.http import HttpResponse, JsonResponse +from requests import Request, Session from rest_framework import ( mixins, permissions, response, - viewsets, status, + viewsets, ) from rest_framework.decorators import detail_route @@ -24,19 +24,20 @@ from .serializers import ( ApiSerializer, BouncerSerializer, - ConsumerSerializer, ConsumerKeySerializer, + ConsumerSerializer, PluginSerializer, ) def api_bouncer(request): def get_headers(meta): + """Get all headers beginning with HTTP_ and that have a value""" regex = re.compile(r'^HTTP_') return { (regex.sub('', k)).replace('_', '-'): v for k, v in meta.items() - if k.startswith('HTTP_') + if k.startswith('HTTP_') and v } dest_host = request.META.get('HTTP_HOST') @@ -57,11 +58,12 @@ def get_headers(meta): request.method, url, params=request.GET, - data=request.POST + data=request.POST, + headers=serializer.data['headers'] ) prepped = session.prepare_request(req) resp = session.send(prepped) - content_type = resp.headers['content-type'] + content_type = resp.headers.get('content-type', 'text/html') return HttpResponse( content=resp.content, @@ -85,12 +87,13 @@ class ApiViewSet(viewsets.ModelViewSet): lookup_field = 'name' @detail_route( - methods=['patch', 'put'], + methods=['post'], permission_classes=[permissions.IsAdminUser], url_path='plugins' ) def add_plugin(self, request, name=None): api = self.get_object() + plugin_name = request.data.get('name') plugin_conf = request.data.get('config') @@ -108,21 +111,20 @@ def add_plugin(self, request, name=None): api_plugin_conf.update(plugin_conf) data = { - 'api': api.id, + 'api': api, 'name': plugin_name, 'config': api_plugin_conf, } - if api_plugin: - serializer = PluginSerializer(api_plugin, data=data) - else: - serializer = PluginSerializer(data=data) + if not api_plugin: + api_plugin = Plugin(**data) + serializer = PluginSerializer(api_plugin, data=data) if serializer.is_valid(): serializer.save() return response.Response( serializer.data, - status=status.HTTP_201_CREATED + status=status.HTTP_200_OK ) return response.Response( diff --git a/tests/__init__.py b/tests/__init__.py index 8b13789..4ede8e6 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ - +# noqa diff --git a/tests/settings.py b/tests/settings.py index 3e9273a..2badbe5 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -1,35 +1,13 @@ -""" -Django settings for bouncer project. - -Generated by 'django-admin startproject' using Django 1.11.2. - -For more information on this file, see -https://docs.djangoproject.com/en/1.11/topics/settings/ - -For the full list of settings and their values, see -https://docs.djangoproject.com/en/1.11/ref/settings/ -""" - import os -# Build paths inside the project like this: os.path.join(BASE_DIR, ...) BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - -# Quick-start development settings - unsuitable for production -# See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ - -# SECURITY WARNING: keep the secret key used in production secret! SECRET_KEY = 'supersecret12345menecio67890' -# SECURITY WARNING: don't run with debug turned on in production! DEBUG = True ALLOWED_HOSTS = ['*'] - -# Application definition - INSTALLED_APPS = [ 'django.contrib.admin', 'django.contrib.auth', @@ -45,7 +23,6 @@ 'django.middleware.security.SecurityMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware', 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware', @@ -72,10 +49,6 @@ WSGI_APPLICATION = 'bouncer.wsgi.application' - -# Database -# https://docs.djangoproject.com/en/1.11/ref/settings/#databases - DATABASES = { 'default': { 'ENGINE': 'django.db.backends.postgresql', @@ -86,28 +59,7 @@ } } - -# Password validation -# https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators - -AUTH_PASSWORD_VALIDATORS = [ - { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', - }, -] - - -# Internationalization -# https://docs.djangoproject.com/en/1.11/topics/i18n/ +AUTH_PASSWORD_VALIDATORS = [] LANGUAGE_CODE = 'en-us' @@ -119,8 +71,4 @@ USE_TZ = True - -# Static files (CSS, JavaScript, Images) -# https://docs.djangoproject.com/en/1.11/howto/static-files/ - STATIC_URL = '/static/' diff --git a/tests/test_apis.py b/tests/test_apis.py index 873193d..5c38a05 100644 --- a/tests/test_apis.py +++ b/tests/test_apis.py @@ -1,6 +1,6 @@ +from django.contrib.auth import get_user_model from rest_framework import status from rest_framework.test import APITestCase -from django.contrib.auth import get_user_model from api_bouncer.models import Api @@ -9,24 +9,163 @@ class ApiTests(APITestCase): def setUp(self): - self.user = User.objects.create_superuser( + self.superuser = User.objects.create_superuser( 'john', 'john@localhost.local', 'john123john' ) + self.user = User.objects.create_user( + 'jane', + 'jane@localhost.local', + 'jane123jane' + ) + self.url = '/apis/' - def test_create_api(self): + def test_create_api_ok(self): """ Ensure we can create a new api object. """ - url = '/apis/' data = { 'name': 'example-api', 'hosts': ['example.com'], 'upstream_url': 'https://httpbin.org' } self.client.login(username='john', password='john123john') - response = self.client.post(url, data, format='json') + response = self.client.post(self.url, data, format='json') self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(Api.objects.count(), 1) self.assertEqual(Api.objects.get().name, 'example-api') + + def test_create_api_regular_user_403(self): + """ + Ensure only admin user can create a new api object. + """ + data = { + 'name': 'example-api', + 'hosts': ['example.com'], + 'upstream_url': 'https://httpbin.org' + } + self.client.login(username='jane', password='jane123jane') + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_create_api_unauthenticated_403(self): + """ + Ensure unauthenticated requests can't create a new api. + """ + data = { + 'name': 'example-api', + 'hosts': ['example.com'], + 'upstream_url': 'https://httpbin.org' + } + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_create_api_missing_name(self): + """ + Ensure name is required for api creation + """ + data = { + 'hosts': ['example.com'], + 'upstream_url': 'https://httpbin.org' + } + self.client.login(username='john', password='john123john') + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(Api.objects.count(), 0) + + def test_create_api_empty_hosts(self): + """ + Ensure at least one host is required for api creation + """ + data = { + 'name': 'example-api', + 'hosts': [], + 'upstream_url': 'https://httpbin.org' + } + self.client.login(username='john', password='john123john') + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(Api.objects.count(), 0) + + def test_create_api_empty_string_hosts(self): + """ + Ensure that empty strings are not valid hostnames + """ + data = { + 'name': 'example-api', + 'hosts': ['example.com', ''], + 'upstream_url': 'https://httpbin.org' + } + self.client.login(username='john', password='john123john') + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(Api.objects.count(), 0) + + def test_create_ip4_hosts_not_ok(self): + """ + Ensure we cant use IPv4 addresses as valid hosts. Only FQDN. + """ + data = { + 'name': 'example-api', + 'hosts': ['172.10.0.13'], + 'upstream_url': 'https://httpbin.org' + } + self.client.login(username='john', password='john123john') + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(Api.objects.count(), 0) + + def test_create_empty_upstream_not_ok(self): + """ + Ensure we require upstream_url. + """ + data = { + 'name': 'example-api', + 'hosts': ['172.10.0.13'], + } + self.client.login(username='john', password='john123john') + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(Api.objects.count(), 0) + + def test_details_api_ok(self): + """ + Ensure we can get the details of a api object. + """ + data = { + 'name': 'example-api', + 'hosts': ['example.com'], + 'upstream_url': 'https://httpbin.org' + } + self.client.login(username='john', password='john123john') + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + url_get = '{}{}'.format(self.url, 'example-api/') + response_get = self.client.get(url_get) + self.assertEqual(response_get.status_code, status.HTTP_200_OK) + self.assertEqual( + response_get.data['id'], + str(Api.objects.all().first().pk) + ) + + def test_details_api_unauthenticated_403(self): + """ + Ensure we can get the details of an api object only if unauthenticated. + """ + data = { + 'name': 'example-api', + 'hosts': ['example.com'], + 'upstream_url': 'https://httpbin.org' + } + + # Log as superuser and create an api and then logout + self.client.login(username='john', password='john123john') + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.client.logout() + + url_get = '{}{}/'.format(self.url, 'example-api') + response_get = self.client.get(url_get) + self.assertEqual(response_get.status_code, status.HTTP_403_FORBIDDEN) diff --git a/tests/test_bouncer.py b/tests/test_bouncer.py new file mode 100644 index 0000000..26c135d --- /dev/null +++ b/tests/test_bouncer.py @@ -0,0 +1,42 @@ +from django.contrib.auth import get_user_model +from rest_framework import status +from rest_framework.test import APITestCase + +from api_bouncer.models import Api + +User = get_user_model() + + +class BouncerTests(APITestCase): + def setUp(self): + self.superuser = User.objects.create_superuser( + 'john', + 'john@localhost.local', + 'john123john' + ) + self.example_api = Api.objects.create( + name='httpbin', + hosts=['httpbin.org'], + upstream_url='https://httpbin.org' + ) + + def test_bounce_api_request(self): + """ + Ensure we can bouncer a request to an api and get the same response. + """ + url = '/status/418' # teapot + self.client.credentials(HTTP_HOST='httpbin.org') + response = self.client.get(url) + self.assertEqual(response.status_code, 418) + self.assertIn('teapot', response.content.decode('utf-8')) + + def test_bounce_api_request_unknown_host(self): + """ + Ensure we send a response when the hosts making the request is not + trying to call an api. + """ + url = '/test' + self.client.credentials(HTTP_HOST='the-unknown.com') + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json(), {}) diff --git a/tests/test_consumer.py b/tests/test_consumer.py new file mode 100644 index 0000000..cb3cccc --- /dev/null +++ b/tests/test_consumer.py @@ -0,0 +1,62 @@ +from django.contrib.auth import get_user_model +from rest_framework import status +from rest_framework.test import APITestCase + +from api_bouncer.models import Consumer + +User = get_user_model() + + +class ConsumerTests(APITestCase): + def setUp(self): + self.superuser = User.objects.create_superuser( + 'john', + 'john@localhost.local', + 'john123john' + ) + self.user = User.objects.create_user( + 'jane', + 'jane@localhost.local', + 'jane123jane' + ) + self.url = '/consumers/' + + def test_create_consumer_ok(self): + """ + Ensure we can create a new consumer object. + """ + data = { + 'username': 'django', + } + self.client.login(username='john', password='john123john') + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(Consumer.objects.count(), 1) + self.assertEqual(Consumer.objects.get().username, 'django') + + def test_create_consumer_403(self): + """ + Ensure we can create a new consumer object only as superuser. + """ + data = { + 'username': 'django', + } + self.client.login(username='jane', password='jane123jane') + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_duplicate_consumer(self): + """ + Ensure we can't duplicate consumer object. + """ + data = { + 'username': 'django', + } + self.client.login(username='john', password='john123john') + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(Consumer.objects.count(), 1) + self.assertEqual(Consumer.objects.get().username, 'django') + + response = self.client.post(self.url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/tests/test_consumer_keys.py b/tests/test_consumer_keys.py new file mode 100644 index 0000000..f46a32e --- /dev/null +++ b/tests/test_consumer_keys.py @@ -0,0 +1,84 @@ +from django.contrib.auth import get_user_model +from rest_framework import status +from rest_framework.test import APITestCase + +from api_bouncer.models import Consumer, ConsumerKey + +User = get_user_model() + + +class ConsumerKeyTests(APITestCase): + def setUp(self): + self.superuser = User.objects.create_superuser( + 'john', + 'john@localhost.local', + 'john123john' + ) + self.user = User.objects.create_user( + 'jane', + 'jane@localhost.local', + 'jane123jane' + ) + self.consumer = Consumer.objects.create(username='django') + + self.url = '/consumers/{}/key-auth/' + + def test_create_consumer_key_auto(self): + """ + Ensure we can create a new consumer key object, with a default value. + """ + self.client.login(username='john', password='john123john') + url = self.url.format(self.consumer.username) + + response = self.client.post(url, {}, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(ConsumerKey.objects.count(), 1) + self.assertEqual(ConsumerKey.objects.get().consumer.username, 'django') + + def test_create_consumer_key_auto_403(self): + """ + Ensure we can create a new consumer key object with a default value, + only as superusers. + """ + url = self.url.format(self.consumer.username) + + response = self.client.post(url, {}, format='json') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_create_consumer_key_given_key(self): + """ + Ensure we can create a new consumer key object, with a default value. + """ + data = {'key': 'abc123456'} + self.client.login(username='john', password='john123john') + url = self.url.format(self.consumer.username) + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(ConsumerKey.objects.count(), 1) + self.assertEqual(ConsumerKey.objects.get().consumer.username, 'django') + self.assertEqual(ConsumerKey.objects.get().key, data['key']) + + def test_create_consumer_key_given_key_too_short(self): + """ + Ensure we can create a new consumer key object, with a default value. + """ + data = {'key': 'abc123'} + self.client.login(username='john', password='john123john') + url = self.url.format(self.consumer.username) + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_create_consumer_key_given_empty_key(self): + """ + Ensure we can't create an empty consumer key object, if an empty key is + given, we must generate a hash for the key + """ + data = {'key': ''} + self.client.login(username='john', password='john123john') + url = self.url.format(self.consumer.username) + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertNotEqual(response.data['key'], '') diff --git a/tests/test_middleware_key_auth.py b/tests/test_middleware_key_auth.py new file mode 100644 index 0000000..592eba6 --- /dev/null +++ b/tests/test_middleware_key_auth.py @@ -0,0 +1,97 @@ +import json + +from django.contrib.auth import get_user_model +from rest_framework import status +from rest_framework.test import APITestCase + +from api_bouncer.models import Api, Consumer + +User = get_user_model() + + +class KeyAuthMiddlewareTests(APITestCase): + def setUp(self): + self.superuser = User.objects.create_superuser( + 'john', + 'john@localhost.local', + 'john123john' + ) + self.example_api = Api.objects.create( + name='httpbin', + hosts=['httpbin.org'], + upstream_url='https://httpbin.org' + ) + self.key_auth_url = '/apis/{}/plugins/'.format(self.example_api.name) + + self.consumer = Consumer.objects.create(username='django') + self.consumer_key_url = ( + '/consumers/{}/key-auth/'.format(self.consumer.username) + ) + + def test_bounce_api_authorization_ok(self): + """ + Ensure we can perform requests on an api using a valid key. + """ + self.client.login(username='john', password='john123john') + self.client.post(self.key_auth_url) + response = self.client.post(self.consumer_key_url) + self.client.logout() + apikey = response.data['key'] + + url = '/get?msg=Bounce' + self.client.credentials(HTTP_HOST='httpbin.org', HTTP_APIKEY=apikey) + response = self.client.get(url) + content = response.content.decode('utf-8') + data = json.loads(content) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(data['args']['msg'], 'Bounce') + + def test_bounce_api_key_in_body(self): + """ + Ensure we can perform requests on an api using a valid key sent on + request body. + """ + self.client.login(username='john', password='john123john') + data = { + 'name': 'key-auth', + 'config': { + 'anonymous': '', + 'key_names': ['apikey'], + 'key_in_body': True, + 'hide_credentials': False, + } + } + self.client.post(self.key_auth_url, data, format='json') + response = self.client.post(self.consumer_key_url) + self.client.logout() + apikey = response.data['key'] + + url = '/post' + self.client.credentials(HTTP_HOST='httpbin.org') + response = self.client.post( + url, + data={'apikey': apikey}, + format='json' + ) + content = response.content.decode('utf-8') + data = json.loads(content) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + data['headers']['X-Consumer-Username'], + self.consumer.username + ) + + def test_bounce_api_authorization_invalid(self): + """ + Ensure we can't perform requests on an api without using a valid key. + """ + self.client.login(username='john', password='john123john') + self.client.post(self.key_auth_url, {'name': 'key-auth'}) + response = self.client.post(self.consumer_key_url) + self.client.logout() + apikey = 'you_know_nothing' + + url = '/get' + self.client.credentials(HTTP_HOST='httpbin.org', HTTP_APIKEY=apikey) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) diff --git a/tests/test_plugin.py b/tests/test_plugin.py new file mode 100644 index 0000000..d6fd07c --- /dev/null +++ b/tests/test_plugin.py @@ -0,0 +1,113 @@ +from django.contrib.auth import get_user_model +from rest_framework import status +from rest_framework.test import APITestCase + +from api_bouncer.models import Api + +User = get_user_model() + + +class PluginTests(APITestCase): + def setUp(self): + self.superuser = User.objects.create_superuser( + 'john', + 'john@localhost.local', + 'john123john' + ) + self.user = User.objects.create_user( + 'jane', + 'jane@localhost.local', + 'jane123jane' + ) + self.example_api = Api.objects.create( + name='example-api', + hosts=['example.com'], + upstream_url='https://httpbin.org' + ) + + self.url = '/apis/{}/plugins/' + + def test_api_add_plugin(self): + """ + Ensure we can add a plugin to an api as superusers. + """ + self.client.login(username='john', password='john123john') + url = self.url.format(self.example_api.name) + + data = { + 'name': 'key-auth', + } + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(self.example_api.plugins.count(), 1) + self.assertEqual(self.example_api.plugins.first().name, data['name']) + + def test_api_add_plugin_403(self): + """ + Ensure we can add a plugin to an api only as superusers. + """ + self.client.login(username='jane', password='jane123jane') + url = self.url.format(self.example_api.name) + + data = { + 'name': 'key-auth', + } + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_api_add_plugin_wrong_name(self): + """ + Ensure we can't add a plugin to an api that doesn't exist. + """ + self.client.login(username='john', password='john123john') + url = self.url.format(self.example_api.name) + + data = { + 'name': 'na-ah', + } + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data['errors'], 'Invalid plugin name') + + def test_api_add_plugin_modify_partially_config(self): + """ + Ensure we can partially modify a plugin configuration. + """ + self.client.login(username='john', password='john123john') + url = self.url.format(self.example_api.name) + + data = { + 'name': 'key-auth', + } + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(self.example_api.plugins.count(), 1) + self.assertEqual(self.example_api.plugins.first().name, data['name']) + + expected_res = response.data + expected_res['config'].update({'anonymous': 'citizen-four'}) + + data.update({'config': {'anonymous': 'citizen-four'}}) + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(self.example_api.plugins.count(), 1) + self.assertEqual(response.data, expected_res) + + def test_api_add_plugin_no_extra_keys(self): + """ + Ensure we can't add arguments not defined on plugin's schema. + """ + self.client.login(username='john', password='john123john') + url = self.url.format(self.example_api.name) + + data = { + 'name': 'key-auth', + } + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(self.example_api.plugins.count(), 1) + self.assertEqual(self.example_api.plugins.first().name, data['name']) + + data.update({'config': {'you_shall_not_pass': True}}) + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/tests/urls.py b/tests/urls.py index e3c411f..1429278 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,8 +1,11 @@ -from django.conf.urls import url, include +from django.conf.urls import include, url from django.contrib import admin urlpatterns = [ url(r'^admin/', admin.site.urls), - url(r'^api-auth/', include('rest_framework.urls', namespace='rest_framework')), + url( + r'^api-auth/', + include('rest_framework.urls', namespace='rest_framework') + ), url(r'^', include('api_bouncer.urls')), ] diff --git a/tox.ini b/tox.ini index 7504e96..b566cfe 100644 --- a/tox.ini +++ b/tox.ini @@ -40,4 +40,4 @@ max-line-length = 79 exclude = docs/, migrations/, .tox/ import-order-style = smarkets application-import-names = api_bouncer -inline-quotes = " +inline-quotes = '