Skip to content

Commit

Permalink
Format and fix build
Browse files Browse the repository at this point in the history
  • Loading branch information
nickw444 committed Oct 27, 2018
1 parent cc4bb62 commit 00bff09
Show file tree
Hide file tree
Showing 16 changed files with 129 additions and 100 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[flake8]
ignore = E302, E121, W291, W391, W293, E261, W292, E401, E303, E701
ignore = E261, W503, W504
exclude =
.git,
venv,
Expand Down
29 changes: 15 additions & 14 deletions examples/token_auth/simple.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from flask import current_app, Flask, request, abort, jsonify, g
from flask_utils.token_auth import (
ShortlivedTokenMixin, parse_auth_header, auth_required)
from marshmallow import fields
import uuid
import datetime
import logging
import random
import string
import uuid

import pytz
import datetime
import logging
from flask import current_app, Flask, request, abort, jsonify, g
from marshmallow import fields

from flask_utils.token_auth import (
ShortlivedTokenMixin, parse_auth_header, auth_required)

log = logging.getLogger(__name__)


class ShortlivedToken(ShortlivedTokenMixin):

class TokenSchema(ShortlivedTokenMixin.TokenSchema):
rfid = fields.String(attribute='refresh_token_id')
user_id = fields.String(attribute='user_id')
Expand Down Expand Up @@ -41,6 +42,7 @@ def from_refresh_token(Cls, refresh_token):

random_alpha = string.digits + string.ascii_letters


class RefreshToken():
def __init__(self, user_id, scopes):
self.id = uuid.uuid4().hex
Expand All @@ -64,7 +66,7 @@ def login():
# Check if the credentials were correct
if request.form.get('username') != 'test' or \
request.form.get('password') != 'test':
abort(401)
abort(401)

# Create a new refresh token
refresh_token = RefreshToken(user_id=request.form.get('username'),
Expand All @@ -86,13 +88,12 @@ def login():
@parse_auth_header(ShortlivedToken)
@auth_required()
def logout():

# Find the associated refresh token
for refresh_token in refresh_tokens:
if refresh_token.id == g.token.refresh_token_id:
break # Found the associated token
break # Found the associated token

else: # nobreak
else: # nobreak
# Couldn't find the token. Maybe it has been revoked.
abort(401)

Expand All @@ -111,9 +112,9 @@ def renew():
# Find the refresh token in the store
for refresh_token in refresh_tokens:
if refresh_token.token == token_string:
break # Found the token that we need.
break # Found the token that we need.

else: # nobreak
else: # nobreak
# Couldn't find the token. Oops
abort(401)

Expand Down
2 changes: 2 additions & 0 deletions flask_utils/celery/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from celery import Celery

from flask_utils.deployment_release import get_release


def create_celery(name, config_obj, inject_version=True, **kwargs):
"""Creates a celery app.
Expand Down
33 changes: 20 additions & 13 deletions flask_utils/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,37 @@
except ImportError:
from urlparse import urlparse

def build_url(url, scheme=None, username=None, password=None, hostname=None,

def build_url(url, scheme=None, username=None, password=None, hostname=None,
port=None, path=None):
"""
Parse a URL and override specific segments of it.
:param url: The url to parse/build upon
:param scheme:
:param username:
:param password:
:param hostname:
:param scheme:
:param username:
:param password:
:param hostname:
:param port:
:param path:
:return: A URL with overridden components
"""
dsn = urlparse(url)

if scheme is None: scheme = dsn.scheme
if username is None: username = dsn.username
if password is None: password = dsn.password
if hostname is None: hostname = dsn.hostname
if port is None: port = dsn.port
if path is None: path = dsn.path
if scheme is None:
scheme = dsn.scheme
if username is None:
username = dsn.username
if password is None:
password = dsn.password
if hostname is None:
hostname = dsn.hostname
if port is None:
port = dsn.port
if path is None:
path = dsn.path

def build_auth():
if username is not None or password is not None:
Expand Down
2 changes: 1 addition & 1 deletion flask_utils/deployment_release/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def get_release():
"""
Opens a file ``version.txt`` and returns it's stripped contents.
:returns: The stripped file contents
"""
try:
Expand Down
9 changes: 5 additions & 4 deletions flask_utils/pagination/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@
'limit': wfields.Integer(missing=20)
}


def paginated(basequery, schema_type, offset=None, limit=None):
"""
Paginate a sqlalchemy query
:param basequery: The base query to be iterated upon
:param schema_type: The ``Marshmallow`` schema to dump data with
:param offset: (Optional) The offset into the data. If omitted it will
:param offset: (Optional) The offset into the data. If omitted it will
be read from the query string in the ``?offset=`` argument. If
not query string, defaults to 0.
:param limit: (Optional) The maximum results per page. If omitted it will
:param limit: (Optional) The maximum results per page. If omitted it will
be read from the query string in the ``?limit=`` argument. If
not query string, defaults to 20.
:returns: The page's data in a namedtuple form ``(data=, errors=)``
"""

Expand Down
22 changes: 11 additions & 11 deletions flask_utils/restful/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from flask import jsonify, request


def format_errors(*errors):
return {
'_errors': errors
}


def format_error(error):
return format_errors(error)


def output_json(data, code, headers=None):
"""
.. code-block:: python
api.representations['application/json'] = output_json
Generates better formatted responses for RESTFul APIs.
Generates better formatted responses for RESTFul APIs.
If the restful resource responds with a string, with a non 200 error,
the response will look like
Expand All @@ -36,8 +38,8 @@ def output_json(data, code, headers=None):
"message": "String the user returned with."
}
If a Non-200 response occured, and flask-restful added it's own error
If a Non-200 response occured, and flask-restful added it's own error
message in the "message" field of the response data, this data is moved into
"_errors":
Expand All @@ -49,7 +51,7 @@ def output_json(data, code, headers=None):
}
All data is returned using flask's jsonify method. This means you can
All data is returned using flask's jsonify method. This means you can
use simplejson to return decimal objects from your flask restful resources.
"""

Expand All @@ -75,7 +77,6 @@ def output_json(data, code, headers=None):
return resp



class ExpectedJSONException(Exception):
"""
Thrown when JSON was expected in a flask request but was not
Expand All @@ -91,7 +92,7 @@ def handle(cls, exc):
Usage:
.. code-block:: python
app.errorhandler(ExpectedJSONException)(ExpectedJSONException.handle)
Expand All @@ -110,6 +111,5 @@ def get_and_expect_json():
data = request.get_json()
if data is None:
raise ExpectedJSONException()

return data

return data
16 changes: 9 additions & 7 deletions flask_utils/sentry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from raven.contrib.celery import register_signal, register_logger_signal
from raven.contrib.flask import Sentry


def create_client(conf, app_version='__UNKNOWN__', ignore_common_http=True):
"""Creates a sentry client.
Expand All @@ -13,13 +14,13 @@ def create_client(conf, app_version='__UNKNOWN__', ignore_common_http=True):
ignore_exceptions = []
if ignore_common_http:
ignore_exceptions = [
'werkzeug.exceptions.BadRequest', # 400
'werkzeug.exceptions.Unauthorized', # 401
'werkzeug.exceptions.Forbidden', # 403
'werkzeug.exceptions.NotFound', # 404
'werkzeug.exceptions.MethodNotAllowed', # 405
'marshmallow.exceptions.ValidationError', # Marshmallow Validation Error.
'webargs.core.ValidationError', # Webargs Validation Error
'werkzeug.exceptions.BadRequest', # 400
'werkzeug.exceptions.Unauthorized', # 401
'werkzeug.exceptions.Forbidden', # 403
'werkzeug.exceptions.NotFound', # 404
'werkzeug.exceptions.MethodNotAllowed', # 405
'marshmallow.exceptions.ValidationError', # Marshmallow Validation Error.
'webargs.core.ValidationError', # Webargs Validation Error
]

client = Client(
Expand All @@ -30,6 +31,7 @@ def create_client(conf, app_version='__UNKNOWN__', ignore_common_http=True):
)
return client


def inject_sentry(app, ignore_common_http=True):
"""Injects sentry into a Flask Application
Expand Down
34 changes: 18 additions & 16 deletions flask_utils/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import json


class AppReqTestHelper(object):
"""
Adds convenience request methods on the testcase object.
Assumes a flask app client is defined on ``self.client``
"""

def _req(self, meth, *args, **kwargs):
if kwargs.get('content_type') is None and meth != 'get':
kwargs['content_type'] = 'application/json'
Expand Down Expand Up @@ -88,37 +90,38 @@ def do_test_privileges(self, endpoint, data, object_id, expected_codes):
_meth = 'get'

rv = self._req(_meth, endp, **kwargs)
self.assertEqual(rv.status_code, expected_code,
self.assertEqual(rv.status_code, expected_code,
"Expected {} for method {} but got {}. {}".format(
expected_code, meth, rv.status_code, rv.get_json()))
expected_code, meth, rv.status_code, rv.get_json()))


class CRUDTestHelper(AppReqTestHelper):
"""
A helper to test generic CRUD operations on an endpoint.
"""
def do_crud_test(self, endpoint, data_1=None, data_2=None, key='id',
check_keys=[], keys_from_prev=[], create=True, delete=True,

def do_crud_test(self, endpoint, data_1=None, data_2=None, key='id',
check_keys=[], keys_from_prev=[], create=True, delete=True,
update=True, read=True, initial_count=0):
"""
Begins the CRUD test.
Begins the CRUD test.
:param endpoint: ``string``: The endpoint to test
:param data1: ``dict``: Data to create the initial entity with (POST)
:param data2: ``dict``: Data to update the entity with (PUT)
:param key: ``string``: The key field in the response returned when performing
a create.
:param check_keys: ``list``: A list of keys to compare ``data_1`` and
``data_2`` to returned API responses. (To
:param check_keys: ``list``: A list of keys to compare ``data_1`` and
``data_2`` to returned API responses. (To
ensure expected response data)
:param keys_from_prev: ``list``: A list of keys to check that they persisted
:param keys_from_prev: ``list``: A list of keys to check that they persisted
after a create/update.
:param create: ``bool``: Should create a new object and test it's existence
:param delete: ``bool``: Should delete the newly created object and test
:param delete: ``bool``: Should delete the newly created object and test
that it has been deleted.
:param update: ``bool``: Should performs PUT (update)
:param read: ``bool``: Should perform a plural read
:param initial_count: ``int``: The initial number of entities in the endpoint's
:param initial_count: ``int``: The initial number of entities in the endpoint's
dataset
"""

Expand All @@ -145,7 +148,7 @@ def do_crud_test(self, endpoint, data_1=None, data_2=None, key='id',
for item in rv.get_json():
if self.equalDicts(check_keys, item, data_1):
break
else: # nobreak
else: # nobreak
self.fail("Could not find the object that was created in the response.")

else:
Expand All @@ -156,7 +159,7 @@ def do_crud_test(self, endpoint, data_1=None, data_2=None, key='id',
# Singular Read
rv = self.get(endpoint + '/' + str(key_id))
self.assertEqual(rv.status_code, 200)
prev_data = rv.get_json() # Keep this data so we can use it after update.
prev_data = rv.get_json() # Keep this data so we can use it after update.
self.assertEqualDicts(check_keys, prev_data, data_1)

if update and create:
Expand All @@ -165,7 +168,7 @@ def do_crud_test(self, endpoint, data_1=None, data_2=None, key='id',
self.assertEqual(rv.status_code, 200)
self.assertEqualDicts(check_keys, rv.get_json(), data_2)
self.assertEqualDicts(keys_from_prev, rv.get_json(), prev_data)

if read and create:
# Singular Read to confirm persisted.
rv = self.get(endpoint + '/' + str(key_id))
Expand All @@ -191,11 +194,11 @@ def do_crud_test(self, endpoint, data_1=None, data_2=None, key='id',

def filteredDicts(self, keys, *dicts):
ret = []
for d in dicts:
for d in dicts:
d_filtered = dict((k, v) for k, v in d.items()
if k in keys)
ret.append(d_filtered)

return ret

def assertEqualDicts(self, keys, d1, d2):
Expand All @@ -204,4 +207,3 @@ def assertEqualDicts(self, keys, d1, d2):
def equalDicts(self, keys, d1, d2):
d1, d2 = self.filteredDicts(keys, d1, d2)
return d1 == d2

Loading

0 comments on commit 00bff09

Please sign in to comment.