Skip to content

Commit

Permalink
datastore: get user by numeric identity attribute (#59)
Browse files Browse the repository at this point in the history
* BETTER Supports login by a numeric identity attribute (e.g. phone
  number, social security number).

Signed-off-by: Jiri Kuncar <jiri.kuncar@gmail.com>
  • Loading branch information
jwag956 committed May 6, 2019
1 parent 113d16b commit f982c5a
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 39 deletions.
64 changes: 38 additions & 26 deletions flask_security/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ def _prepare_create_user_args(self, **kwargs):
kwargs['roles'] = roles
return kwargs

def _is_numeric(self, value):
try:
int(value)
except (TypeError, ValueError):
return False
return True

def get_user(self, id_or_email):
"""Returns a user matching the specified ID or email address."""
raise NotImplementedError
Expand Down Expand Up @@ -239,22 +246,20 @@ def get_user(self, identifier):
from sqlalchemy.orm import joinedload
user_model_query = user_model_query.options(joinedload('roles'))

if self._is_numeric(identifier):
return user_model_query.get(identifier)
rv = self.user_model.query.get(identifier)
if rv is not None:
return rv

# Not PK - iterate through other attributes and look for 'identifier'
for attr in get_identity_attributes():
# Look for exact case-insensitive match - 'ilike' honors wild cards
# which isn't what we want.
query = alchemyFn.lower(getattr(self.user_model, attr)) \
== alchemyFn.lower(identifier)
rv = user_model_query.filter(query).first()
if rv is not None:
return rv

def _is_numeric(self, value):
try:
int(value)
except (TypeError, ValueError):
return False
return True

def find_user(self, **kwargs):
query = self.user_model.query
if hasattr(self.user_model, 'roles'):
Expand Down Expand Up @@ -310,12 +315,20 @@ def get_user(self, identifier):
return self.user_model.objects(id=identifier).first()
except (ValidationError, ValueError):
pass

is_numeric = self._is_numeric(identifier)

for attr in get_identity_attributes():
query_key = '%s__iexact' % attr
query_key = attr if is_numeric else '%s__iexact' % attr
query = {query_key: identifier}
rv = self.user_model.objects(**query).first()
if rv is not None:
return rv
try:
rv = self.user_model.objects(**query).first()
if rv is not None:
return rv
except (ValidationError, ValueError):
# This can happen if identifier is a string but attribute is
# an int.
pass

def find_user(self, **kwargs):
try:
Expand Down Expand Up @@ -360,15 +373,15 @@ def get_user(self, identifier):
from peewee import fn as peeweeFn
try:
return self.user_model.get(self.user_model.id == identifier)
except ValueError:
except (self.user_model.DoesNotExist, ValueError):
pass

for attr in get_identity_attributes():
column = getattr(self.user_model, attr)
try:
return self.user_model.get(
peeweeFn.Lower(column) == peeweeFn.Lower(identifier))
except self.user_model.DoesNotExist:
except (self.user_model.DoesNotExist, ValueError):
pass

def find_user(self, **kwargs):
Expand Down Expand Up @@ -443,22 +456,21 @@ def __init__(self, db, user_model, role_model):

@with_pony_session
def get_user(self, identifier):
if self._is_numeric(identifier):
from pony.orm.core import ObjectNotFound
try:
return self.user_model[identifier]
except (ObjectNotFound, ValueError):
pass

for attr in get_identity_attributes():
# this is a nightmare, tl;dr we need to get the thing that
# corresponds to email (usually)
user = self.user_model.get(**{attr: identifier})
if user is not None:
return user

def _is_numeric(self, value):
try:
int(value)
except ValueError:
return False
return True
try:
user = self.user_model.get(**{attr: identifier})
if user is not None:
return user
except (TypeError, ValueError):
pass

@with_pony_session
def find_user(self, **kwargs):
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class User(db.Document, UserMixin):
email = db.StringField(unique=True, max_length=255)
username = db.StringField(max_length=255)
password = db.StringField(required=False, max_length=255)
security_number = db.IntField(unique=True)
last_login_at = db.DateTimeField()
current_login_at = db.DateTimeField()
last_login_ip = db.StringField(max_length=100)
Expand Down Expand Up @@ -211,6 +212,7 @@ class Role(db.Model, RoleMixin):
class User(db.Model, UserMixin):
id = db.Column(db.Integer, primary_key=True)
email = db.Column(db.String(255), unique=True)
security_number = db.Column(db.Integer, unique=True)
username = db.Column(db.String(255))
password = db.Column(db.String(255))
last_login_at = db.Column(db.DateTime())
Expand Down Expand Up @@ -274,6 +276,7 @@ class User(Base, UserMixin):
email = Column(String(255), unique=True)
username = Column(String(255))
password = Column(String(255))
security_number = Column(Integer, unique=True)
last_login_at = Column(DateTime())
current_login_at = Column(DateTime())
last_login_ip = Column(String(100))
Expand Down Expand Up @@ -319,6 +322,7 @@ class Role(db.Model, RoleMixin):
class User(db.Model, UserMixin):
email = TextField()
username = TextField()
security_number = IntegerField(null=True)
password = TextField(null=True)
last_login_at = DateTimeField(null=True)
current_login_at = DateTimeField(null=True)
Expand Down Expand Up @@ -366,6 +370,7 @@ class Role(db.Entity):
class User(db.Entity):
email = Required(str)
username = Optional(str)
security_number = Optional(int)
password = Optional(str, nullable=True)
last_login_at = Optional(datetime)
current_login_at = Optional(datetime)
Expand Down
28 changes: 25 additions & 3 deletions tests/test_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from pytest import raises
from utils import init_app_with_options, get_num_queries
from utils import init_app_with_options, get_num_queries, is_sqlalchemy

from flask_security import RoleMixin, Security, UserMixin
from flask_security.datastore import Datastore, UserDatastore
Expand Down Expand Up @@ -83,7 +83,8 @@ def test_activate_returns_false_if_already_true():

def test_get_user(app, datastore):
init_app_with_options(app, datastore, **{
'SECURITY_USER_IDENTITY_ATTRIBUTES': ('email', 'username')
'SECURITY_USER_IDENTITY_ATTRIBUTES': ('email', 'username',
'security_number')
})

with app.app_context():
Expand All @@ -98,10 +99,31 @@ def test_get_user(app, datastore):
user = datastore.get_user('matt')
assert user is not None

# Regression check
# Regression check (make sure we don't match wildcards)
user = datastore.get_user('%lp.com')
assert user is None

# Verify that numeric non PK works
user = datastore.get_user(123456)
assert user is not None


def test_find_user(app, datastore):
init_app_with_options(app, datastore)

with app.app_context():
user_id = datastore.find_user(email='gene@lp.com').id

current_nqueries = get_num_queries(datastore)
assert user_id == datastore.find_user(security_number=889900).id
end_nqueries = get_num_queries(datastore)
if current_nqueries is not None:
if is_sqlalchemy(datastore):
# This should have done just 1 query across all attrs.
assert end_nqueries == (current_nqueries + 1)

assert user_id == datastore.find_user(username='gene').id


def test_find_role(app, datastore):
init_app_with_options(app, datastore)
Expand Down
26 changes: 16 additions & 10 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,14 @@ def create_roles(ds):


def create_users(ds, count=None):
users = [('matt@lp.com', 'matt', 'password', ['admin'], True),
('joe@lp.com', 'joe', 'password', ['editor'], True),
('dave@lp.com', 'dave', 'password', ['admin', 'editor'], True),
('jill@lp.com', 'jill', 'password', ['author'], True),
('tiya@lp.com', 'tiya', 'password', [], False),
('gene@lp.com', 'gene', 'password', [], True),
('jess@lp.com', 'jess', None, [], True)]
users = [('matt@lp.com', 'matt', 'password', ['admin'], True, 123456),
('joe@lp.com', 'joe', 'password', ['editor'], True, 234567),
('dave@lp.com', 'dave', 'password', ['admin', 'editor'], True,
345678),
('jill@lp.com', 'jill', 'password', ['author'], True, 456789),
('tiya@lp.com', 'tiya', 'password', [], False, 567890),
('gene@lp.com', 'gene', 'password', [], True, 889900),
('jess@lp.com', 'jess', None, [], True, 678901)]
count = count or len(users)

for u in users[:count]:
Expand All @@ -69,7 +70,8 @@ def create_users(ds, count=None):
email=u[0],
username=u[1],
password=pw,
active=u[4])
active=u[4],
security_number=u[5])
ds.commit()
for role in roles:
ds.add_role_to_user(user, role)
Expand Down Expand Up @@ -108,8 +110,12 @@ def get_num_queries(datastore):
""" Return # of queries executed during test.
return None if datastore doesn't support this.
"""
if isinstance(datastore, SQLAlchemyUserDatastore) and\
not isinstance(datastore, SQLAlchemySessionUserDatastore):
if is_sqlalchemy(datastore):
from flask_sqlalchemy import get_debug_queries
return len(get_debug_queries())
return None


def is_sqlalchemy(datastore):
return isinstance(datastore, SQLAlchemyUserDatastore) and\
not isinstance(datastore, SQLAlchemySessionUserDatastore)

0 comments on commit f982c5a

Please sign in to comment.