Skip to content

Commit

Permalink
Merge pull request #68 from briehl/cache_bad_tokens
Browse files Browse the repository at this point in the history
cache bad tokens, better cache testing
  • Loading branch information
briehl committed Apr 8, 2019
2 parents 315ad52 + 1ee1879 commit a2ed4cb
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 9 deletions.
3 changes: 3 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
KBase Feeds Service

### Version 1.0.2
- Add a cache for bad tokens so they aren't looked up over and over. Maxes out at 10000, then throws out the oldest bad token.

### Version 1.0.1
- Fix issue where groups notifications were being seen by users in the target field as well as the users list.

Expand Down
18 changes: 14 additions & 4 deletions feeds/external_api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
List
)
from requests.models import Response
from collections import (
OrderedDict
)

config = get_config()
AUTH_URL = config.auth_url
AUTH_API_PATH = '/api/V2/'
CACHE_EXPIRE_TIME = 300 # seconds
MAX_BAD_TOKENS = 10000


class TokenCache(TTLCache):
Expand All @@ -37,14 +41,15 @@ class TokenCache(TTLCache):
"""
def __getitem__(self, key: str, cache_getitem: Any=Cache.__getitem__):
token = super(TokenCache, self).__getitem__(key, cache_getitem=cache_getitem)
if token.get('expires', 0) <= epoch_ms():
if token is not None and token.get('expires', 0) <= epoch_ms():
return self.__missing__(key)
else:
return token


__token_cache = TokenCache(1000, CACHE_EXPIRE_TIME)
__user_cache = TTLCache(1000, CACHE_EXPIRE_TIME)
__bad_token_cache = OrderedDict()


def validate_service_token(token: str) -> str:
Expand Down Expand Up @@ -136,9 +141,11 @@ def __fetch_token(token: str) -> dict:
If the token is invalid or there's any other auth problems, either
an InvalidTokenError or TokenLookupError gets raised.
"""
if token in __bad_token_cache:
raise InvalidTokenError(msg="Invalid token")
try:
fetched = __token_cache.get(token)
except KeyError: # this wants to throw a KeyError in some tests. Don't know why.
except KeyError: # extending the TTLCache is annoying.
fetched = None
if fetched is not None:
return fetched
Expand All @@ -153,7 +160,7 @@ def __fetch_token(token: str) -> dict:
__token_cache[token] = token_info
return token_info
except requests.HTTPError as e:
_handle_errors(e)
_handle_errors(e, token)


def __auth_request(path: str, token: str) -> Response:
Expand All @@ -170,7 +177,7 @@ def __auth_request(path: str, token: str) -> Response:
return r


def _handle_errors(err: Response) -> None:
def _handle_errors(err: Response, token=None) -> None:
"""
Wrapper to handle errors for Auth requests.
Raises either an InvalidTokenError (on a 403) or TokenLookupError as needed.
Expand All @@ -179,6 +186,9 @@ def _handle_errors(err: Response) -> None:
if err.response.status_code == 403:
err_content = json.loads(err.response.content)
err_msg = err_content.get('error', {}).get('apperror', 'Invalid token')
__bad_token_cache[token] = 1
if len(__bad_token_cache) > MAX_BAD_TOKENS:
__bad_token_cache.popitem(last=False)
raise InvalidTokenError(msg=err_msg, http_error=err)
else:
raise TokenLookupError(http_error=err)
2 changes: 1 addition & 1 deletion feeds/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
log_error
)

VERSION = "1.0.1"
VERSION = "1.0.2"

try:
from feeds import gitcommit
Expand Down
13 changes: 9 additions & 4 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tempfile
import json
import feeds
from feeds.util import epoch_ms
import test.util as test_util
from .util import test_config
from .mongo_controller import MongoController
Expand Down Expand Up @@ -126,7 +127,8 @@ def test_something(mock_valid_user_token):
mock_valid_user_token('someuser', 'Some User')
... continue test ...
"""
def auth_valid_user_token(user_id: str, user_name: str, group_membership: List[Dict[str, str]]=[]):
def auth_valid_user_token(user_id: str, user_name: str, group_membership: List[Dict[str, str]]=[],
expires: int=10000):
"""
group_membership, if present should be a list of dicts, where each dict has the id and name of
a group the "user" is in.
Expand All @@ -144,7 +146,8 @@ def auth_valid_user_token(user_id: str, user_name: str, group_membership: List[D
requests_mock.get('{}/api/V2/token'.format(auth_url), json={
'user': user_id,
'type': 'Login',
'name': None
'name': None,
'expires': epoch_ms() + expires
})
requests_mock.get('{}/api/V2/me'.format(auth_url), json={
'customroles': [],
Expand All @@ -170,7 +173,8 @@ def auth_valid_service_token(user_id, user_name, service_name):
requests_mock.get('{}/api/V2/token'.format(auth_url), json={
'user': user_id,
'type': 'Service',
'name': service_name
'name': service_name,
'expires': epoch_ms() + 10000
})
requests_mock.get('{}/api/V2/me'.format(auth_url), json={
'customroles': [],
Expand All @@ -194,7 +198,8 @@ def auth_valid_admin_token(user_id, user_name):
requests_mock.get('{}/api/V2/token'.format(auth_url), json={
'user': user_id,
'type': 'Login',
'name': None
'name': None,
'expires': epoch_ms() + 10000
})
requests_mock.get('{}/api/V2/me'.format(auth_url), json={
'customroles': ['FEEDS_ADMIN'],
Expand Down
68 changes: 68 additions & 0 deletions test/external_api/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
get_auth_token,
is_feeds_admin
)
import feeds.external_api.auth as auth
from ..util import test_config
from feeds.exceptions import InvalidTokenError

cfg = test_config()




def test_validate_service_token_ok(requests_mock):

# test service token
Expand Down Expand Up @@ -68,3 +71,68 @@ def test_is_feeds_admin_ok(requests_mock):
requests_mock.get('{}/api/V2/me'.format(cfg.get('feeds', 'auth-url')),
text=json.dumps({'customroles': []}))
assert not is_feeds_admin(token)


def test_valid_token_cache(mock_valid_user_token):
user = "some_user"
name = "Some User"
token = "some_token" + str(uuid.uuid4())
mock_valid_user_token(user, name, [])

assert token not in auth.__token_cache
info = validate_user_token(token)
assert info == user

assert token in auth.__token_cache
assert auth.__token_cache.get(token)['user'] == user
info = validate_user_token(token)
assert info == user


def test_expired_token_cache(mock_valid_user_token):
user = "some_user"
name = "Some User"
token = "some_token" + str(uuid.uuid4())
mock_valid_user_token(user, name, [], expires=0)

assert token not in auth.__token_cache
info = validate_user_token(token)
assert info == user

with pytest.raises(KeyError):
info = auth.__token_cache[token]


def test_bad_token_cache(mock_invalid_user_token):
user = "some_user"
token = "bad_token" + str(uuid.uuid4())
mock_invalid_user_token(user)

assert token not in auth.__bad_token_cache
with pytest.raises(InvalidTokenError):
validate_user_token(token)
assert token in auth.__bad_token_cache

with pytest.raises(InvalidTokenError) as e:
validate_user_token(token)
assert "Invalid token" in str(e.value)


def test_bad_token_cache_size(mock_invalid_user_token):
user = "some_user"
token = "bad_token" + str(uuid.uuid4())
mock_invalid_user_token(user)

assert token not in auth.__bad_token_cache
with pytest.raises(InvalidTokenError):
validate_user_token(token)
assert token in auth.__bad_token_cache

for i in range(10000):
t = "bad" + str(i)
try:
validate_user_token(t)
except:
pass

assert token not in auth.__bad_token_cache

0 comments on commit a2ed4cb

Please sign in to comment.