-
Notifications
You must be signed in to change notification settings - Fork 21
Integrate fastapi-users
for user management
#377
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a4092e9
018157c
2239da5
aade3cb
b53127e
5f4d261
5aca0dc
a72b8fd
87eaddd
e91df8b
bbd0390
766e406
e40c882
49e20de
cb1c7b5
0450034
887441d
be9916b
8836489
ee67577
56a429f
46557c5
c9dde29
77a7b89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,8 @@ | |
|
||
from .auth import Authentication | ||
from .db import Database | ||
from .models import User, UserGroup, UserProfile | ||
from .models import UserGroup | ||
gctucker marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from .user_models import User | ||
|
||
|
||
async def setup_admin_group(db, admin_group): | ||
|
@@ -42,19 +43,18 @@ async def setup_admin_user(db, username, email, admin_group): | |
return None | ||
hashed_password = Authentication.get_password_hash(password) | ||
print(f"Creating {username} user...") | ||
profile = UserProfile( | ||
return await db.create(User( | ||
username=username, | ||
hashed_password=hashed_password, | ||
email=email, | ||
groups=[admin_group] | ||
) | ||
return await db.create(User( | ||
profile=profile | ||
groups=[admin_group], | ||
is_superuser=1 | ||
)) | ||
|
||
|
||
async def main(args): | ||
db = Database(args.mongo, args.database) | ||
await db.initialize_beanie() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need Beanie here. I'll dig a bit deeper, as far as I know we don't actually use Beanie anywhere in the API code at the moment, just fastapi-users uses it internally. We might however decide to rely more on Beanie going forward to simplify the code, it could basically be a replacement for the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are accessing DB to create an admin user with |
||
group = await setup_admin_group(db, args.admin_group) | ||
user = await setup_admin_user(db, args.username, args.email, group) | ||
return True | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,127 +1,62 @@ | ||
# SPDX-License-Identifier: LGPL-2.1-or-later | ||
# | ||
# Copyright (C) 2021 Collabora Limited | ||
# Copyright (C) 2021-2023 Collabora Limited | ||
# Author: Guillaume Tucker <guillaume.tucker@collabora.com> | ||
# Author: Jeny Sadadia <jeny.sadadia@collabora.com> | ||
|
||
"""User authentication utilities""" | ||
|
||
from datetime import datetime, timedelta | ||
from fastapi.security import OAuth2PasswordBearer | ||
from jose import JWTError, jwt | ||
from passlib.context import CryptContext | ||
from pydantic import BaseModel, BaseSettings, Field | ||
from .db import Database | ||
from .models import User | ||
|
||
|
||
class Token(BaseModel): | ||
"""Authentication token model""" | ||
access_token: str = Field( | ||
description='Authentication access token' | ||
) | ||
token_type: str = Field( | ||
description='Access token type e.g. Bearer' | ||
) | ||
from pydantic import BaseSettings | ||
from fastapi_users.authentication import ( | ||
AuthenticationBackend, | ||
BearerTransport, | ||
JWTStrategy, | ||
) | ||
|
||
|
||
class Settings(BaseSettings): | ||
"""Authentication settings""" | ||
secret_key: str | ||
algorithm: str = "HS256" | ||
# Set to None so tokens don't expire | ||
access_token_expire_minutes: float = None | ||
access_token_expire_seconds: int = None | ||
|
||
|
||
class Authentication: | ||
"""Authentication utility class | ||
|
||
This class accepts a single argument `database` in its constructor, which | ||
should be a db.Database object. | ||
""" | ||
"""Authentication utility class""" | ||
|
||
CRYPT_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto") | ||
|
||
def __init__(self, database: Database, token_url: str, user_scopes: dict): | ||
self._db = database | ||
def __init__(self, token_url: str): | ||
self._settings = Settings() | ||
self._user_scopes = user_scopes | ||
self._oauth2_scheme = OAuth2PasswordBearer( | ||
tokenUrl=token_url, | ||
scopes=self._user_scopes | ||
) | ||
|
||
@property | ||
def oauth2_scheme(self): | ||
"""Get authentication scheme""" | ||
return self._oauth2_scheme | ||
self._token_url = token_url | ||
|
||
@classmethod | ||
def get_password_hash(cls, password): | ||
"""Get a password hash for a given clear text password string""" | ||
return cls.CRYPT_CTX.hash(password) | ||
|
||
@classmethod | ||
def verify_password(cls, password_hash, user): | ||
"""Verify that the password hash matches the user's password""" | ||
return cls.CRYPT_CTX.verify(password_hash, user.hashed_password) | ||
def get_jwt_strategy(self) -> JWTStrategy: | ||
"""Get JWT strategy for authentication backend""" | ||
return JWTStrategy( | ||
secret=self._settings.secret_key, | ||
algorithm=self._settings.algorithm, | ||
lifetime_seconds=self._settings.access_token_expire_seconds | ||
) | ||
|
||
async def authenticate_user(self, username: str, password: str): | ||
"""Authenticate a username / password pair | ||
def get_user_authentication_backend(self): | ||
"""Authentication backend for user management | ||
|
||
Look up a `User` in the database with the provided `username` | ||
and check whether the provided clear text `password` matches the hash | ||
associated with it. | ||
Authentication backend for `fastapi-users` is composed of two | ||
parts: Transaport and Strategy. | ||
Transport is a mechanism for token transmisson i.e. bearer or cookie. | ||
Strategy is a method to generate and secure tokens. It can be JWT, | ||
database or Redis. | ||
""" | ||
user = await self._db.find_one_by_attributes( | ||
User, {'profile.username': username}) | ||
if not user: | ||
return False | ||
if not self.verify_password(password, user.profile): | ||
return False | ||
return user.profile | ||
|
||
def create_access_token(self, data: dict): | ||
"""Create a JWT access token using the provided arbitrary `data`""" | ||
to_encode = data.copy() | ||
if self._settings.access_token_expire_minutes: | ||
expires_delta = timedelta( | ||
minutes=self._settings.access_token_expire_minutes | ||
) | ||
expire = datetime.utcnow() + expires_delta | ||
to_encode.update({"exp": expire}) | ||
encoded_jwt = jwt.encode( | ||
to_encode, | ||
self._settings.secret_key, algorithm=self._settings.algorithm | ||
) | ||
return encoded_jwt | ||
|
||
async def get_current_user(self, token, security_scopes): | ||
"""Decode the given JWT `token` and look up a matching `User`""" | ||
try: | ||
payload = jwt.decode( | ||
token, | ||
self._settings.secret_key, | ||
algorithms=[self._settings.algorithm] | ||
) | ||
username: str = payload.get("sub") | ||
token_scopes = payload.get("scopes", []) | ||
if username is None: | ||
return None, "Could not validate credentials" | ||
|
||
for scope in security_scopes: | ||
if scope not in token_scopes: | ||
return None, "Access denied" | ||
|
||
except JWTError as error: | ||
return None, str(error) | ||
|
||
user = await self._db.find_one_by_attributes( | ||
User, {'profile.username': username}) | ||
return user, None | ||
|
||
async def validate_scopes(self, requested_scopes): | ||
"""Check if requested scopes are valid user scopes""" | ||
for scope in requested_scopes: | ||
if scope not in self._user_scopes: | ||
return False, scope | ||
return True, None | ||
bearer_transport = BearerTransport(tokenUrl=self._token_url) | ||
return AuthenticationBackend( | ||
name="jwt", | ||
transport=bearer_transport, | ||
get_strategy=self.get_jwt_strategy, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,9 +7,11 @@ | |
"""Database abstraction""" | ||
|
||
from bson import ObjectId | ||
from beanie import init_beanie | ||
from fastapi_pagination.ext.motor import paginate | ||
from motor import motor_asyncio | ||
from .models import Hierarchy, Node, User, Regression, UserGroup | ||
from .models import Hierarchy, Node, Regression, UserGroup | ||
from .user_models import User | ||
|
||
|
||
class Database: | ||
|
@@ -39,6 +41,15 @@ def __init__(self, service='mongodb://db:27017', db_name='kernelci'): | |
self._motor = motor_asyncio.AsyncIOMotorClient(service) | ||
self._db = self._motor[db_name] | ||
|
||
async def initialize_beanie(self): | ||
"""Initialize Beanie ODM to use `fastapi-users` tools for MongoDB""" | ||
await init_beanie( | ||
database=self._db, | ||
document_models=[ | ||
User, | ||
], | ||
) | ||
Comment on lines
+44
to
+51
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if this should instead be in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method is specific to DB-related initialization. Hence, I added it to the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What I mean is that the |
||
|
||
def _get_collection(self, model): | ||
col = self.COLLECTIONS[model] | ||
return self._db[col] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The commit message for fixing the staging bug doesn't explain what this does. It looks like it's temporary so should it really be merged? This looks like a staging configuration issue, or maybe the API code should be able to deal with such cases and not try to send emails when the sender is not valid. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#!/usr/bin/env python3 | ||
# | ||
# SPDX-License-Identifier: LGPL-2.1-or-later | ||
# | ||
# Copyright (C) 2023 Jeny Sadadia | ||
# Author: Jeny Sadadia <jeny.sadadia@collabora.com> | ||
|
||
"""SMTP Email Sender module""" | ||
|
||
from email.mime.multipart import MIMEMultipart | ||
import email | ||
import email.mime.text | ||
import smtplib | ||
# from pydantic import BaseSettings, EmailStr | ||
from pydantic import BaseSettings | ||
|
||
|
||
class Settings(BaseSettings): | ||
"""Email settings""" | ||
smtp_host: str | ||
smtp_port: int | ||
# email_sender: EmailStr | ||
email_sender: str | ||
email_password: str | ||
|
||
|
||
class EmailSender: | ||
"""Class to send email report using SMTP""" | ||
def __init__(self): | ||
self._settings = Settings() | ||
|
||
def _smtp_connect(self): | ||
"""Method to create a connection with SMTP server""" | ||
if self._settings.smtp_port == 465: | ||
smtp = smtplib.SMTP_SSL(self._settings.smtp_host, | ||
self._settings.smtp_port) | ||
else: | ||
smtp = smtplib.SMTP(self._settings.smtp_host, | ||
self._settings.smtp_port) | ||
smtp.starttls() | ||
smtp.login(self._settings.email_sender, | ||
self._settings.email_password) | ||
return smtp | ||
|
||
def _create_email(self, email_subject, email_content, email_recipient): | ||
"""Method to create an email message from email subject, contect, | ||
sender, and receiver""" | ||
email_msg = MIMEMultipart() | ||
email_text = email.mime.text.MIMEText(email_content, "plain", "utf-8") | ||
email_text.replace_header('Content-Transfer-Encoding', 'quopri') | ||
email_text.set_payload(email_content, 'utf-8') | ||
email_msg.attach(email_text) | ||
email_msg['To'] = email_recipient | ||
email_msg['From'] = self._settings.email_sender | ||
email_msg['Subject'] = email_subject | ||
return email_msg | ||
|
||
def _send_email(self, email_msg): | ||
"""Method to send an email message using SMTP""" | ||
smtp = self._smtp_connect() | ||
if smtp: | ||
smtp.send_message(email_msg) | ||
smtp.quit() | ||
|
||
def create_and_send_email(self, email_subject, email_content, | ||
email_recipient): | ||
"""Method to create and send email""" | ||
email_msg = self._create_email( | ||
email_subject, email_content, email_recipient | ||
) | ||
self._send_email(email_msg) |
Uh oh!
There was an error while loading. Please reload this page.