diff --git a/.travis.yml b/.travis.yml index 5a7c0e5c..f155ed93 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: python env: - - DJANGO="django>=1.3,<1.4" - DJANGO="django>=1.4,<1.5" - DJANGO="django>=1.5,<1.6" + - DJANGO="django>=1.6,<1.7" python: - "2.6" - "2.7" @@ -12,4 +12,3 @@ install: - pip install -q $DJANGO --use-mirrors -U - python setup.py develop script: ./test.sh - diff --git a/MANIFEST.in b/MANIFEST.in index 35ad499f..ff1a14c8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,4 +2,4 @@ include LICENSE include README.rst recursive-include provider/templates *.html recursive-include provider/templates *.txt -recursive-include provider/ *json \ No newline at end of file +recursive-include provider *json diff --git a/README.rst b/README.rst index e374b586..7c5ee74c 100644 --- a/README.rst +++ b/README.rst @@ -4,8 +4,13 @@ django-oauth2-provider .. image:: https://travis-ci.org/caffeinehit/django-oauth2-provider.png?branch=master *django-oauth2-provider* is a Django application that provides -customizable OAuth2\_ authentication for your Django projects. +customizable OAuth2\-authentication for your Django projects. `Documentation `_ +`Help `_ +License +======= + +*django-oauth2-provider* is released under the MIT License. Please see the LICENSE file for details. diff --git a/docs/api.rst b/docs/api.rst index ff353693..3d10d5e4 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -28,7 +28,7 @@ :settings: `OAUTH_EXPIRE_DELTA` :default: `datetime.timedelta(days=365)` - + The time to expiry for access tokens as outlined in :rfc:`4.2.2` and :rfc:`5.1`. @@ -36,9 +36,17 @@ :settings: `OAUTH_EXPIRE_CODE_DELTA` :default: `datetime.timedelta(seconds=10*60)` - + The time to expiry for an authorization code grant as outlined in :rfc:`4.1.2`. - + +.. attribute:: DELETE_EXPIRED + + :settings: `OAUTH_DELETE_EXPIRED` + :default: `False` + + To remove expired tokens immediately instead of letting them persist, set + to `True`. + .. attribute:: ENFORCE_SECURE :settings: `OAUTH_ENFORCE_SECURE` diff --git a/provider/__init__.py b/provider/__init__.py index 06e1bcae..1c19d463 100644 --- a/provider/__init__.py +++ b/provider/__init__.py @@ -1,2 +1,2 @@ -__version__ = "0.2.6.epyx.4" +__version__ = "0.2.7.epyx.5" diff --git a/provider/constants.py b/provider/constants.py index cf92ed42..7f94b8e5 100644 --- a/provider/constants.py +++ b/provider/constants.py @@ -11,6 +11,8 @@ RESPONSE_TYPE_CHOICES = getattr(settings, 'OAUTH_RESPONSE_TYPE_CHOICES', ("code", "token")) +TOKEN_TYPE = 'Bearer' + READ = 1 << 1 WRITE = 1 << 2 READ_WRITE = READ | WRITE @@ -25,8 +27,14 @@ EXPIRE_DELTA = getattr(settings, 'OAUTH_EXPIRE_DELTA', timedelta(days=365)) +# Expiry delta for public clients (which typically have shorter lived tokens) +EXPIRE_DELTA_PUBLIC = getattr(settings, 'OAUTH_EXPIRE_DELTA_PUBLIC', timedelta(days=30)) + EXPIRE_CODE_DELTA = getattr(settings, 'OAUTH_EXPIRE_CODE_DELTA', timedelta(seconds=10 * 60)) +# Remove expired tokens immediately instead of letting them persist. +DELETE_EXPIRED = getattr(settings, 'OAUTH_DELETE_EXPIRED', False) + ENFORCE_SECURE = getattr(settings, 'OAUTH_ENFORCE_SECURE', False) ENFORCE_CLIENT_SECURE = getattr(settings, 'OAUTH_ENFORCE_CLIENT_SECURE', True) diff --git a/provider/oauth2/__init__.py b/provider/oauth2/__init__.py index e69de29b..34796220 100644 --- a/provider/oauth2/__init__.py +++ b/provider/oauth2/__init__.py @@ -0,0 +1,6 @@ +import backends +import forms +import managers +import models +import urls +import views \ No newline at end of file diff --git a/provider/oauth2/backends.py b/provider/oauth2/backends.py index 5da5118a..db0fb853 100644 --- a/provider/oauth2/backends.py +++ b/provider/oauth2/backends.py @@ -1,5 +1,5 @@ from ..utils import now -from .forms import ClientAuthForm +from .forms import ClientAuthForm, PublicPasswordGrantForm from .models import AccessToken @@ -61,6 +61,27 @@ def authenticate(self, request=None): return None +class PublicPasswordBackend(object): + """ + Backend that tries to authenticate a client using username, password + and client ID. This is only available in specific circumstances: + + - grant_type is "password" + - client.client_type is 'public' + """ + + def authenticate(self, request=None): + if request is None: + return None + + form = PublicPasswordGrantForm(request.REQUEST) + + if form.is_valid(): + return form.cleaned_data.get('client') + + return None + + class AccessTokenBackend(object): """ Authenticate a user via access token and client object. diff --git a/provider/oauth2/forms.py b/provider/oauth2/forms.py index 92bb7c2a..4e14bbde 100644 --- a/provider/oauth2/forms.py +++ b/provider/oauth2/forms.py @@ -56,8 +56,13 @@ def to_python(self, value): if not value: return [] + # New in Django 1.6: value may come in as a string. + # Instead of raising an `OAuthValidationError`, try to parse and + # ultimately return an empty list if nothing remains -- this will + # eventually raise an `OAuthValidationError` in `validate` where + # it should be anyways. if not isinstance(value, (list, tuple)): - raise OAuthValidationError({'error': 'invalid_request'}) + value = value.split(' ') # Split values into list return u' '.join([smart_unicode(val) for val in value]).split(u' ') @@ -301,3 +306,30 @@ def clean(self): data['user'] = user return data + + +class PublicPasswordGrantForm(PasswordGrantForm): + client_id = forms.CharField(required=True) + grant_type = forms.CharField(required=True) + + def clean_grant_type(self): + grant_type = self.cleaned_data.get('grant_type') + + if grant_type != 'password': + raise OAuthValidationError({'error': 'invalid_grant'}) + + return grant_type + + def clean(self): + data = super(PublicPasswordGrantForm, self).clean() + + try: + client = Client.objects.get(client_id=data.get('client_id')) + except Client.DoesNotExist: + raise OAuthValidationError({'error': 'invalid_client'}) + + if client.client_type != 1: # public + raise OAuthValidationError({'error': 'invalid_client'}) + + data['client'] = client + return data diff --git a/provider/oauth2/migrations/0004_auto__add_index_accesstoken_token.py b/provider/oauth2/migrations/0004_auto__add_index_accesstoken_token.py new file mode 100644 index 00000000..9eb6ff64 --- /dev/null +++ b/provider/oauth2/migrations/0004_auto__add_index_accesstoken_token.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +import datetime +from south.db import db +from south.v2 import SchemaMigration +from django.db import models + + +class Migration(SchemaMigration): + + def forwards(self, orm): + # Adding index on 'AccessToken', fields ['token'] + db.create_index('oauth2_accesstoken', ['token']) + + + def backwards(self, orm): + # Removing index on 'AccessToken', fields ['token'] + db.delete_index('oauth2_accesstoken', ['token']) + + + models = { + 'auth.group': { + 'Meta': {'object_name': 'Group'}, + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'name': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '80'}), + 'permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'}) + }, + 'auth.permission': { + 'Meta': {'ordering': "('content_type__app_label', 'content_type__model', 'codename')", 'unique_together': "(('content_type', 'codename'),)", 'object_name': 'Permission'}, + 'codename': ('django.db.models.fields.CharField', [], {'max_length': '100'}), + 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['contenttypes.ContentType']"}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'name': ('django.db.models.fields.CharField', [], {'max_length': '50'}) + }, + 'auth.user': { + 'Meta': {'object_name': 'User'}, + 'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}), + 'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}), + 'first_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}), + 'groups': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Group']", 'symmetrical': 'False', 'blank': 'True'}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'is_active': ('django.db.models.fields.BooleanField', [], {'default': 'True'}), + 'is_staff': ('django.db.models.fields.BooleanField', [], {'default': 'False'}), + 'is_superuser': ('django.db.models.fields.BooleanField', [], {'default': 'False'}), + 'last_login': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}), + 'last_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}), + 'password': ('django.db.models.fields.CharField', [], {'max_length': '128'}), + 'user_permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'}), + 'username': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '30'}) + }, + 'contenttypes.contenttype': { + 'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"}, + 'app_label': ('django.db.models.fields.CharField', [], {'max_length': '100'}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'model': ('django.db.models.fields.CharField', [], {'max_length': '100'}), + 'name': ('django.db.models.fields.CharField', [], {'max_length': '100'}) + }, + 'oauth2.accesstoken': { + 'Meta': {'object_name': 'AccessToken'}, + 'client': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['oauth2.Client']"}), + 'expires': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime(2014, 8, 8, 0, 0)'}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'scope': ('django.db.models.fields.IntegerField', [], {'default': '2'}), + 'token': ('django.db.models.fields.CharField', [], {'default': "'ab8c8bcd91e8750462b631516b60b0b95dffe1f4'", 'max_length': '255', 'db_index': 'True'}), + 'user': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['auth.User']"}) + }, + 'oauth2.client': { + 'Meta': {'object_name': 'Client'}, + 'client_id': ('django.db.models.fields.CharField', [], {'default': "'0a8e54e38c024606ba0a'", 'max_length': '255'}), + 'client_secret': ('django.db.models.fields.CharField', [], {'default': "'e53ddb9736f9eea65100885a1b20fb5f2bb0fb4d'", 'max_length': '255'}), + 'client_type': ('django.db.models.fields.IntegerField', [], {}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'name': ('django.db.models.fields.CharField', [], {'max_length': '255', 'blank': 'True'}), + 'redirect_uri': ('django.db.models.fields.URLField', [], {'max_length': '200'}), + 'url': ('django.db.models.fields.URLField', [], {'max_length': '200'}), + 'user': ('django.db.models.fields.related.ForeignKey', [], {'blank': 'True', 'related_name': "'oauth2_client'", 'null': 'True', 'to': "orm['auth.User']"}) + }, + 'oauth2.grant': { + 'Meta': {'object_name': 'Grant'}, + 'client': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['oauth2.Client']"}), + 'code': ('django.db.models.fields.CharField', [], {'default': "'5e0ca84e98678a3b55b8901e85a20f995672aea2'", 'max_length': '255'}), + 'expires': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime(2013, 8, 8, 0, 0)'}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'redirect_uri': ('django.db.models.fields.CharField', [], {'max_length': '255', 'blank': 'True'}), + 'scope': ('django.db.models.fields.IntegerField', [], {'default': '0'}), + 'user': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['auth.User']"}) + }, + 'oauth2.refreshtoken': { + 'Meta': {'object_name': 'RefreshToken'}, + 'access_token': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'refresh_token'", 'unique': 'True', 'to': "orm['oauth2.AccessToken']"}), + 'client': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['oauth2.Client']"}), + 'expired': ('django.db.models.fields.BooleanField', [], {'default': 'False'}), + 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'token': ('django.db.models.fields.CharField', [], {'default': "'32e9eb7edda764ba8f752ae49223d42acea7cb88'", 'max_length': '255'}), + 'user': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['auth.User']"}) + } + } + + complete_apps = ['oauth2'] \ No newline at end of file diff --git a/provider/oauth2/models.py b/provider/oauth2/models.py index 55e8763d..18c86909 100644 --- a/provider/oauth2/models.py +++ b/provider/oauth2/models.py @@ -8,9 +8,8 @@ from django.conf import settings from .. import constants from ..constants import CLIENT_TYPES -from ..utils import short_token, long_token, get_token_expiry -from ..utils import get_code_expiry -from ..utils import now +from ..utils import now, short_token, long_token, get_code_expiry +from ..utils import get_token_expiry, serialize_instance, deserialize_instance from .managers import AccessTokenManager try: @@ -49,6 +48,39 @@ class Client(models.Model): def __unicode__(self): return self.name + def get_default_token_expiry(self): + public = (self.client_type == 1) + return get_token_expiry(public) + + def serialize(self): + return dict(user=serialize_instance(self.user), + name=self.name, + url=self.url, + redirect_uri=self.redirect_uri, + client_id=self.client_id, + client_secret=self.client_secret, + client_type=self.client_type) + + @classmethod + def deserialize(cls, data): + if not data: + return None + + kwargs = {} + + # extract values that we care about + for field in cls._meta.fields: + name = field.name + val = data.get(field.name, None) + + # handle relations + if val and field.rel: + val = deserialize_instance(field.rel.to, val) + + kwargs[name] = val + + return cls(**kwargs) + class Grant(models.Model): """ @@ -98,9 +130,9 @@ class AccessToken(models.Model): expiry """ user = models.ForeignKey(AUTH_USER_MODEL) - token = models.CharField(max_length=255, default=long_token) + token = models.CharField(max_length=255, default=long_token, db_index=True) client = models.ForeignKey(Client) - expires = models.DateTimeField(default=get_token_expiry) + expires = models.DateTimeField() scope = models.IntegerField(default=constants.SCOPES[0][0], choices=constants.SCOPES) @@ -109,6 +141,11 @@ class AccessToken(models.Model): def __unicode__(self): return self.token + def save(self, *args, **kwargs): + if not self.expires: + self.expires = self.client.get_default_token_expiry() + super(AccessToken, self).save(*args, **kwargs) + def get_expire_delta(self, reference=None): """ Return the number of seconds until this token expires. diff --git a/provider/oauth2/tests.py b/provider/oauth2/tests.py index 96b92bf1..1ed57f79 100644 --- a/provider/oauth2/tests.py +++ b/provider/oauth2/tests.py @@ -12,7 +12,7 @@ from ..templatetags.scope import scopes from ..utils import now as date_now from .forms import ClientForm -from .models import Client, Grant, AccessToken +from .models import Client, Grant, AccessToken, RefreshToken from .backends import BasicClientBackend, RequestParamsClientBackend from .backends import AccessTokenBackend @@ -195,6 +195,11 @@ def test_preserving_the_state_variable(self): self.assertTrue('code' in response['Location']) self.assertTrue('state=abc' in response['Location']) + def test_redirect_requires_valid_data(self): + self.login() + response = self.client.get(self.redirect_url()) + self.assertEqual(400, response.status_code) + class AccessTokenTest(BaseOAuth2TestCase): fixtures = ['test_oauth2.json'] @@ -234,6 +239,8 @@ def test_fetching_access_token_with_invalid_grant(self): self.assertEqual('invalid_grant', json.loads(response.content)['error']) def _login_authorize_get_token(self): + required_props = ['access_token', 'token_type'] + self.login() self._login_and_authorize() @@ -249,7 +256,13 @@ def _login_authorize_get_token(self): self.assertEqual(200, response.status_code, response.content) - return json.loads(response.content) + token = json.loads(response.content) + + for prop in required_props: + self.assertIn(prop, token, "Access token response missing " + "required property: %s" % prop) + + return token def test_fetching_access_token_with_valid_grant(self): self._login_authorize_get_token() @@ -283,6 +296,23 @@ def test_fetching_single_access_token(self): constants.SINGLE_ACCESS_TOKEN = False + def test_fetching_single_access_token_after_refresh(self): + constants.SINGLE_ACCESS_TOKEN = True + + token = self._login_authorize_get_token() + + self.client.post(self.access_token_url(), { + 'grant_type': 'refresh_token', + 'refresh_token': token['refresh_token'], + 'client_id': self.get_client().client_id, + 'client_secret': self.get_client().client_secret, + }) + + new_token = self._login_authorize_get_token() + self.assertNotEqual(token['access_token'], new_token['access_token']) + + constants.SINGLE_ACCESS_TOKEN = False + def test_fetching_access_token_multiple_times(self): self._login_authorize_get_token() code = self.get_grant().code @@ -334,21 +364,79 @@ def test_refreshing_an_access_token(self): self.assertEqual('invalid_grant', json.loads(response.content)['error'], response.content) - def test_password_grant(self): + def test_password_grant_public(self): + c = self.get_client() + c.client_type = 1 # public + c.save() + response = self.client.post(self.access_token_url(), { 'grant_type': 'password', - 'client_id': self.get_client().client_id, - 'client_secret': self.get_client().client_secret, + 'client_id': c.client_id, + # No secret needed 'username': self.get_user().username, 'password': self.get_password(), }) self.assertEqual(200, response.status_code, response.content) + self.assertNotIn('refresh_token', json.loads(response.content)) + expires_in = json.loads(response.content)['expires_in'] + expires_in_days = round(expires_in / (60.0 * 60.0 * 24.0)) + self.assertEqual(expires_in_days, constants.EXPIRE_DELTA_PUBLIC.days) + + def test_password_grant_confidential(self): + c = self.get_client() + c.client_type = 0 # confidential + c.save() response = self.client.post(self.access_token_url(), { 'grant_type': 'password', - 'client_id': self.get_client().client_id, - 'client_secret': self.get_client().client_secret, + 'client_id': c.client_id, + 'client_secret': c.client_secret, + 'username': self.get_user().username, + 'password': self.get_password(), + }) + + self.assertEqual(200, response.status_code, response.content) + self.assertTrue(json.loads(response.content)['refresh_token']) + + def test_password_grant_confidential_no_secret(self): + c = self.get_client() + c.client_type = 0 # confidential + c.save() + + response = self.client.post(self.access_token_url(), { + 'grant_type': 'password', + 'client_id': c.client_id, + 'username': self.get_user().username, + 'password': self.get_password(), + }) + + self.assertEqual('invalid_client', json.loads(response.content)['error']) + + def test_password_grant_invalid_password_public(self): + c = self.get_client() + c.client_type = 1 # public + c.save() + + response = self.client.post(self.access_token_url(), { + 'grant_type': 'password', + 'client_id': c.client_id, + 'username': self.get_user().username, + 'password': self.get_password() + 'invalid', + }) + + self.assertEqual(400, response.status_code, response.content) + self.assertEqual('invalid_client', json.loads(response.content)['error']) + + def test_password_grant_invalid_password_confidential(self): + c = self.get_client() + c.client_type = 0 # confidential + c.save() + + response = self.client.post(self.access_token_url(), { + 'grant_type': 'password', + 'client_id': c.client_id, + 'client_secret': c.client_secret, 'username': self.get_user().username, 'password': self.get_password() + 'invalid', }) @@ -356,6 +444,10 @@ def test_password_grant(self): self.assertEqual(400, response.status_code, response.content) self.assertEqual('invalid_grant', json.loads(response.content)['error']) + def test_access_token_response_valid_token_type(self): + token = self._login_authorize_get_token() + self.assertEqual(token['token_type'], constants.TOKEN_TYPE, token) + class AuthBackendTest(BaseOAuth2TestCase): fixtures = ['test_oauth2'] @@ -461,3 +553,71 @@ def test_template_filter(self): names.sort() self.assertEqual('read read+write write', ' '.join(names)) + + +class DeleteExpiredTest(BaseOAuth2TestCase): + fixtures = ['test_oauth2'] + + def setUp(self): + self._delete_expired = constants.DELETE_EXPIRED + constants.DELETE_EXPIRED = True + + def tearDown(self): + constants.DELETE_EXPIRED = self._delete_expired + + def test_clear_expired(self): + self.login() + + self._login_and_authorize() + + response = self.client.get(self.redirect_url()) + + self.assertEqual(302, response.status_code) + location = response['Location'] + self.assertFalse('error' in location) + self.assertTrue('code' in location) + + # verify that Grant with code exists + code = urlparse.parse_qs(location)['code'][0] + self.assertTrue(Grant.objects.filter(code=code).exists()) + + # use the code/grant + response = self.client.post(self.access_token_url(), { + 'grant_type': 'authorization_code', + 'client_id': self.get_client().client_id, + 'client_secret': self.get_client().client_secret, + 'code': code}) + self.assertEquals(200, response.status_code) + token = json.loads(response.content) + self.assertTrue('access_token' in token) + access_token = token['access_token'] + self.assertTrue('refresh_token' in token) + refresh_token = token['refresh_token'] + + # make sure the grant is gone + self.assertFalse(Grant.objects.filter(code=code).exists()) + # and verify that the AccessToken and RefreshToken exist + self.assertTrue(AccessToken.objects.filter(token=access_token) + .exists()) + self.assertTrue(RefreshToken.objects.filter(token=refresh_token) + .exists()) + + # refresh the token + response = self.client.post(self.access_token_url(), { + 'grant_type': 'refresh_token', + 'refresh_token': token['refresh_token'], + 'client_id': self.get_client().client_id, + 'client_secret': self.get_client().client_secret, + }) + self.assertEqual(200, response.status_code) + token = json.loads(response.content) + self.assertTrue('access_token' in token) + self.assertNotEquals(access_token, token['access_token']) + self.assertTrue('refresh_token' in token) + self.assertNotEquals(refresh_token, token['refresh_token']) + + # make sure the orig AccessToken and RefreshToken are gone + self.assertFalse(AccessToken.objects.filter(token=access_token) + .exists()) + self.assertFalse(RefreshToken.objects.filter(token=refresh_token) + .exists()) diff --git a/provider/oauth2/views.py b/provider/oauth2/views.py index 3e599292..c8ecb701 100644 --- a/provider/oauth2/views.py +++ b/provider/oauth2/views.py @@ -1,5 +1,6 @@ from datetime import timedelta from django.core.urlresolvers import reverse +from .. import constants from ..views import Capture, Authorize, Redirect from ..views import AccessToken as AccessTokenView, OAuthError from ..utils import now @@ -7,7 +8,7 @@ from .forms import PasswordGrantForm, RefreshTokenGrantForm from .forms import AuthorizationCodeGrantForm from .models import Client, RefreshToken, AccessToken -from .backends import BasicClientBackend, RequestParamsClientBackend +from .backends import BasicClientBackend, RequestParamsClientBackend, PublicPasswordBackend class Capture(Capture): @@ -70,6 +71,7 @@ class AccessTokenView(AccessTokenView): authentication = ( BasicClientBackend, RequestParamsClientBackend, + PublicPasswordBackend, ) def get_authorization_code_grant(self, request, data, client): @@ -93,8 +95,8 @@ def get_password_grant(self, request, data, client): def get_access_token(self, request, user, scope, client): try: # Attempt to fetch an existing valid access token. - at = AccessToken.objects.get(user=user, client=client, scope=scope, - expires__gt=now()) + at = AccessToken.objects.get(user=user, client=client, + scope=scope, expires__gt=now()) except AccessToken.DoesNotExist: # None found... make a new one! at = self.create_access_token(request, user, scope, client) @@ -116,13 +118,22 @@ def create_refresh_token(self, request, user, scope, access_token, client): ) def invalidate_grant(self, grant): - grant.expires = now() - timedelta(days=1) - grant.save() + if constants.DELETE_EXPIRED: + grant.delete() + else: + grant.expires = now() - timedelta(days=1) + grant.save() def invalidate_refresh_token(self, rt): - rt.expired = True - rt.save() + if constants.DELETE_EXPIRED: + rt.delete() + else: + rt.expired = True + rt.save() def invalidate_access_token(self, at): - at.expires = now() - timedelta(days=1) - at.save() + if constants.DELETE_EXPIRED: + at.delete() + else: + at.expires = now() - timedelta(days=1) + at.save() diff --git a/provider/tests/__init__.py b/provider/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/provider/tests/test_utils.py b/provider/tests/test_utils.py new file mode 100644 index 00000000..72aea0be --- /dev/null +++ b/provider/tests/test_utils.py @@ -0,0 +1,35 @@ +""" +Test cases for functionality provided by the provider.utils module +""" + +from datetime import datetime, time, date +from django.test import TestCase +from django.db import models +from .. import utils + + +class UtilsTestCase(TestCase): + def test_serialization(self): + class SomeModel(models.Model): + dt = models.DateTimeField() + t = models.TimeField() + d = models.DateField() + instance = SomeModel(dt=datetime.now(), + d=date.today(), + t=datetime.now().time()) + instance.nonfield = 'hello' + data = utils.serialize_instance(instance) + instance2 = utils.deserialize_instance(SomeModel, data) + self.assertEqual(instance.nonfield, instance2.nonfield) + self.assertEqual(instance.d, instance2.d) + self.assertEqual(instance.dt.date(), instance2.dt.date()) + for t1, t2 in [(instance.t, instance2.t), + (instance.dt.time(), instance2.dt.time())]: + self.assertEqual(t1.hour, t2.hour) + self.assertEqual(t1.minute, t2.minute) + self.assertEqual(t1.second, t2.second) + # AssertionError: + # datetime.time(10, 6, 28, 705776) != + # datetime.time(10, 6, 28, 705000) + self.assertEqual(int(t1.microsecond/1000), + int(t2.microsecond/1000)) diff --git a/provider/utils.py b/provider/utils.py index 981a1795..957a5c75 100644 --- a/provider/utils.py +++ b/provider/utils.py @@ -2,7 +2,17 @@ import shortuuid from datetime import datetime, tzinfo from django.conf import settings -from .constants import EXPIRE_DELTA, EXPIRE_CODE_DELTA +from django.utils import dateparse +from django.db.models.fields import (DateTimeField, DateField, + EmailField, TimeField, + FieldDoesNotExist) +from django.core.serializers.json import DjangoJSONEncoder +from .constants import EXPIRE_DELTA, EXPIRE_DELTA_PUBLIC, EXPIRE_CODE_DELTA + +try: + import json +except ImporError: + import simplejson as json try: from django.utils import timezone @@ -35,13 +45,16 @@ def long_token(): return hash.hexdigest() -def get_token_expiry(): +def get_token_expiry(public=True): """ Return a datetime object indicating when an access token should expire. Can be customized by setting :attr:`settings.OAUTH_EXPIRE_DELTA` to a :attr:`datetime.timedelta` object. """ - return now() + EXPIRE_DELTA + if public: + return now() + EXPIRE_DELTA_PUBLIC + else: + return now() + EXPIRE_DELTA def get_code_expiry(): @@ -52,3 +65,36 @@ def get_code_expiry(): :attr:`datetime.timedelta` object. """ return now() + EXPIRE_CODE_DELTA + + +def serialize_instance(instance): + """ + Since Django 1.6 items added to the session are no longer pickled, + but JSON encoded by default. We are storing partially complete models + in the session (user, account, token, ...). We cannot use standard + Django serialization, as these are models are not "complete" yet. + Serialization will start complaining about missing relations et al. + """ + ret = dict([(k, v) + for k, v in instance.__dict__.items() + if not k.startswith('_')]) + return json.loads(json.dumps(ret, cls=DjangoJSONEncoder)) + + +def deserialize_instance(model, data={}): + "Translate raw data into a model instance." + ret = model() + for k, v in data.items(): + if v is not None: + try: + f = model._meta.get_field(k) + if isinstance(f, DateTimeField): + v = dateparse.parse_datetime(v) + elif isinstance(f, TimeField): + v = dateparse.parse_time(v) + elif isinstance(f, DateField): + v = dateparse.parse_date(v) + except FieldDoesNotExist: + pass + setattr(ret, k, v) + return ret diff --git a/provider/views.py b/provider/views.py index a774d4df..dd1200df 100644 --- a/provider/views.py +++ b/provider/views.py @@ -4,6 +4,8 @@ from django.http import HttpResponseRedirect, QueryDict from django.utils.translation import ugettext as _ from django.views.generic.base import TemplateView +from django.core.exceptions import ObjectDoesNotExist +from oauth2.models import Client from . import constants, scope @@ -259,7 +261,6 @@ def handle(self, request, post_data=None): authorization_form = self.get_authorization_form(request, client, post_data, data) - if not authorization_form.is_bound or not authorization_form.is_valid(): return self.render_to_response({ 'client': client, @@ -269,9 +270,11 @@ def handle(self, request, post_data=None): code = self.save_authorization(request, client, authorization_form, data) + # be sure to serialize any objects that aren't natively json + # serializable because these values are stored as session data self.cache_data(request, data) self.cache_data(request, code, "code") - self.cache_data(request, client, "client") + self.cache_data(request, client.serialize(), "client") return HttpResponseRedirect(self.get_redirect_url(request)) @@ -288,12 +291,33 @@ class Redirect(OAuthView, Mixin): This can be either parameters indicating success or parameters indicating an error. """ + + def error_response(self, error, mimetype='application/json', status=400, + **kwargs): + """ + Return an error response to the client with default status code of + *400* stating the error as outlined in :rfc:`5.2`. + """ + return HttpResponse(json.dumps(error), mimetype=mimetype, + status=status, **kwargs) + def get(self, request): data = self.get_data(request) code = self.get_data(request, "code") error = self.get_data(request, "error") client = self.get_data(request, "client") + # client must be properly deserialized to become a valid instance + client = Client.deserialize(client) + + # this is an edge case that is caused by making a request with no data + # it should only happen if this view is called manually, out of the + # normal capture-authorize-redirect flow. + if data is None or client is None: + return self.error_response({ + 'error': 'invalid_data', + 'error_description': _('Data has not been captured')}) + redirect_uri = data.get('redirect_uri', None) or client.redirect_uri parsed = urlparse.urlparse(redirect_uri) @@ -447,13 +471,24 @@ def access_token_response(self, access_token): Returns a successful response after creating the access token as defined in :rfc:`5.1`. """ + + response_data = { + 'access_token': access_token.token, + 'token_type': constants.TOKEN_TYPE, + 'expires_in': access_token.get_expire_delta(), + 'scope': ' '.join(scope.names(access_token.scope)), + } + + # Not all access_tokens are given a refresh_token + # (for example, public clients doing password auth) + try: + rt = access_token.refresh_token + response_data['refresh_token'] = rt.token + except ObjectDoesNotExist: + pass + return HttpResponse( - json.dumps({ - 'access_token': access_token.token, - 'expires_in': access_token.get_expire_delta(), - 'refresh_token': access_token.refresh_token.token, - 'scope': ' '.join(scope.names(access_token.scope)), - }), mimetype='application/json' + json.dumps(response_data), mimetype='application/json' ) def authorization_code(self, request, data, client): @@ -480,6 +515,7 @@ def refresh_token(self, request, data, client): """ rt = self.get_refresh_token_grant(request, data, client) + # this must be called first in case we need to purge expired tokens self.invalidate_refresh_token(rt) self.invalidate_access_token(rt.access_token) @@ -502,7 +538,9 @@ def password(self, request, data, client): at = self.get_access_token(request, user, scope, client) else: at = self.create_access_token(request, user, scope, client) - rt = self.create_refresh_token(request, user, scope, at, client) + # Public clients don't get refresh tokens + if client.client_type != 1: + rt = self.create_refresh_token(request, user, scope, at, client) return self.access_token_response(at) diff --git a/requirements.txt b/requirements.txt index cf610e53..a79a5d9f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -Django>=1.3 +Django>=1.4 shortuuid>=0.3 diff --git a/setup.py b/setup.py index 50f4627a..f9f55b85 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ classifiers=[ 'Environment :: Web Environment', 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', + 'License :: OSI Approved :: MIT License', 'Operating System :: OS Independent', 'Programming Language :: Python', 'Framework :: Django', diff --git a/test.sh b/test.sh index 5fb7179b..094e309b 100755 --- a/test.sh +++ b/test.sh @@ -1,5 +1,18 @@ #!/bin/bash -python manage.py test provider oauth2 --traceback --failfast +DJ_VERSION=$(django-admin.py --version) +# exit if fail +[[ "$?" -ne "0" ]] && exit; +IS_16=$(echo $DJ_VERSION | grep "1.6") + +# if django version is not 1.6 (non-0 exit) we have to pass different +# app names to test runner +if [ "$?" -ne "1" ]; then + app_names=( provider provider.oauth2 ) +else + app_names=( provider oauth2 ) +fi + +python manage.py test ${app_names[@]} --traceback --failfast diff --git a/tox.ini b/tox.ini index 07fad6b2..41860909 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] downloadcache = {toxworkdir}/cache/ -envlist = py2.7-django.dev,py2.7-django1.5,py2.7-django1.4,py2.7-django1.3,py2.6-django.dev,py2.6-django1.5,py2.6-django1.4,py2.6-django1.3 +envlist = py2.7-django.dev,py2.7-django1.6,py2.7-django1.5,py2.7-django1.4,py2.6-django.dev,py2.6-django1.6,py2.6-django1.5,py2.6-django1.4 [testenv] setenv = @@ -14,6 +14,11 @@ basepython = python2.7 deps = https://github.com/django/django/zipball/master {[testenv]deps} +[testenv:py2.7-django1.6] +basepython = python2.7 +deps = django>=1.6,<1.7 + {[testenv]deps} + [testenv:py2.7-django1.5] basepython = python2.7 deps = django>=1.5,<1.6 @@ -24,16 +29,16 @@ basepython = python2.7 deps = django>=1.4,<1.5 {[testenv]deps} -[testenv:py2.7-django1.3] -basepython = python2.7 -deps = django>=1.3,<1.4 - {[testenv]deps} - [testenv:py2.6-django.dev] basepython = python2.6 deps = https://github.com/django/django/zipball/master {[testenv]deps} +[testenv:py2.6-django1.6] +basepython = python2.6 +deps = django>=1.6,<1.7 + {[testenv]deps} + [testenv:py2.6-django1.5] basepython = python2.6 deps = django>=1.5,<1.6 @@ -43,9 +48,3 @@ deps = django>=1.5,<1.6 basepython = python2.6 deps = django>=1.4,<1.5 {[testenv]deps} - -[testenv:py2.6-django1.3] -basepython = python2.6 -deps = django>=1.3,<1.4 - {[testenv]deps} -