Permalink
Browse files

Work in progress

  • Loading branch information...
1 parent d95a8c9 commit f1447b2adcbc19c2172060f797ecda0f51674cac @mattupstate committed Dec 19, 2013
View
@@ -19,7 +19,7 @@
from werkzeug.datastructures import ImmutableList
from werkzeug.local import LocalProxy
-from .utils import config_value as cv, get_config, md5, url_for_security
+from .utils import config_value as cv, get_config, md5, url_for_security, string_types
from .views import create_blueprint
from .forms import LoginForm, ConfirmRegisterForm, RegisterForm, \
ForgotPasswordForm, ChangePasswordForm, ResetPasswordForm, \
@@ -249,6 +249,7 @@ def _context_processor():
class RoleMixin(object):
"""Mixin for `Role` model definitions"""
+
def __eq__(self, other):
return (self.name == other or
self.name == getattr(other, 'name', None))
@@ -273,7 +274,7 @@ def has_role(self, role):
"""Returns `True` if the user identifies with the specified role.
:param role: A role name or `Role` instance"""
- if isinstance(role, basestring):
+ if isinstance(role, string_types):
return role in (role.name for role in self.roles)
else:
return role in self.roles
@@ -9,7 +9,7 @@
:license: MIT, see LICENSE for more details.
"""
-from .utils import get_identity_attributes
+from .utils import get_identity_attributes, string_types
class Datastore(object):
@@ -68,9 +68,9 @@ def __init__(self, user_model, role_model):
self.role_model = role_model
def _prepare_role_modify_args(self, user, role):
- if isinstance(user, basestring):
+ if isinstance(user, string_types):
user = self.find_user(email=user)
- if isinstance(role, basestring):
+ if isinstance(role, string_types):
role = self.find_role(role)
return user, role
@@ -105,6 +105,7 @@ def add_role_to_user(self, user, role):
user, role = self._prepare_role_modify_args(user, role)
if role not in user.roles:
user.roles.append(role)
+ self.put(user)
return True
return False
@@ -161,8 +162,8 @@ def find_or_create_role(self, name, **kwargs):
def create_user(self, **kwargs):
"""Creates and returns a new user from the given parameters."""
-
- user = self.user_model(**self._prepare_create_user_args(**kwargs))
+ kwargs = self._prepare_create_user_args(**kwargs)
+ user = self.user_model(**kwargs)
return self.put(user)
def delete_user(self, user):
View
@@ -10,9 +10,10 @@
"""
import inspect
-import urlparse
-
-import flask_wtf as wtf
+try:
+ from urlparse import urlsplit
+except ImportError:
+ from urllib.parse import urlsplit
from flask import request, current_app
from flask_wtf import Form as BaseForm
@@ -22,7 +23,7 @@
from werkzeug.local import LocalProxy
from .confirmable import requires_confirmation
-from .utils import verify_and_update_password, get_message, encrypt_password, config_value
+from .utils import verify_and_update_password, get_message, config_value
# Convenient reference
_datastore = LocalProxy(lambda: current_app.extensions['security'].datastore)
@@ -137,8 +138,8 @@ class NextFormMixin():
def validate_next(self, field):
if field.data:
- url_next = urlparse.urlsplit(field.data)
- url_base = urlparse.urlsplit(request.host_url)
+ url_next = urlsplit(field.data)
+ url_base = urlsplit(request.host_url)
if url_next.netloc and url_next.netloc != url_base.netloc:
field.data = ''
raise ValidationError(get_message('INVALID_REDIRECT')[0])
View
@@ -14,6 +14,7 @@
import functools
import hashlib
import hmac
+import sys
from contextlib import contextmanager
from datetime import datetime, timedelta
@@ -37,6 +38,15 @@
_pwd_context = LocalProxy(lambda: _security.pwd_context)
+PY3 = sys.version_info[0] == 3
+
+if PY3:
+ string_types = str,
+ text_type = str
+else:
+ string_types = basestring,
+ text_type = unicode
+
def login_user(user, remember=None):
"""Performs the login routine.
@@ -85,16 +95,13 @@ def get_hmac(password):
:param password: The password to sign
"""
- if _security.password_hash == 'plaintext':
- return password
-
if _security.password_salt is None:
raise RuntimeError(
'The configuration value `SECURITY_PASSWORD_SALT` must '
'not be None when the value of `SECURITY_PASSWORD_HASH` is '
'set to "%s"' % _security.password_hash)
- h = hmac.new(_security.password_salt, password.encode('utf-8'), hashlib.sha512)
+ h = hmac.new(_security.password_salt.encode('utf-8'), password.encode('utf-8'), hashlib.sha512)
return base64.b64encode(h.digest())
@@ -104,7 +111,7 @@ def verify_password(password, password_hash):
:param password: A plaintext password to verify
:param password_hash: The expected hash value of the password (usually form your database)
"""
- return _pwd_context.verify(get_hmac(password), password_hash)
+ return _pwd_context.verify(encrypt_password(password), password_hash)
def verify_and_update_password(password, user):
@@ -114,7 +121,7 @@ def verify_and_update_password(password, user):
:param password: A plaintext password to verify
:param user: The user to verify against
"""
- verified, new_password = _pwd_context.verify_and_update(get_hmac(password), user.password)
+ verified, new_password = _pwd_context.verify_and_update(encrypt_password(password), user.password)
if verified and new_password:
user.password = new_password
_datastore.put(user)
@@ -126,11 +133,14 @@ def encrypt_password(password):
:param password: The plaintext passwrod to encrypt
"""
- return _pwd_context.encrypt(get_hmac(password))
+ if _security.password_hash == 'plaintext':
+ return password
+ signed = get_hmac(password)
+ return _pwd_context.encrypt(signed.decode('ascii'))
def md5(data):
- return hashlib.md5(data).hexdigest()
+ return hashlib.md5(data.encode('ascii')).hexdigest()
def do_flash(message, category=None):
@@ -408,19 +418,19 @@ def _record(self, signal, *args, **kwargs):
self._records[signal].append((args, kwargs))
def __enter__(self):
- for signal, receiver in self._receivers.iteritems():
+ for signal, receiver in self._receivers.items():
signal.connect(receiver)
return self
def __exit__(self, type, value, traceback):
- for signal, receiver in self._receivers.iteritems():
+ for signal, receiver in self._receivers.items():
signal.disconnect(receiver)
def signals_sent(self):
"""Return a set of the signals sent.
:rtype: list of blinker `NamedSignals`.
"""
- return set([signal for signal, _ in self._records.iteritems() if self._records[signal]])
+ return set([signal for signal, _ in self._records.items() if self._records[signal]])
def capture_signals():
View
@@ -61,7 +61,7 @@ def logout(self, endpoint=None):
return self._get(endpoint or '/logout', follow_redirects=True)
def assertIsHomePage(self, data):
- self.assertIn('Home Page', data)
+ self.assertIn(b'Home Page', data)
def assertIn(self, member, container, msg=None):
if hasattr(TestCase, 'assertIn'):
Oops, something went wrong.

0 comments on commit f1447b2

Please sign in to comment.