diff --git a/api/auth.py b/api/auth.py index c3e176d0..854382ce 100644 --- a/api/auth.py +++ b/api/auth.py @@ -12,6 +12,13 @@ from pydantic import BaseModel, BaseSettings, Field from .db import Database from .models import User +from fastapi_users.db import BeanieUserDatabase, ObjectIDIDMixin +from fastapi_users import BaseUserManager, FastAPIUsers +from typing import Optional, Any, Dict +from .models import TestUser +from beanie import PydanticObjectId +from fastapi import Depends, Request, Response +from fastapi_users.authentication import AuthenticationBackend, BearerTransport, JWTStrategy class Token(BaseModel): @@ -125,3 +132,71 @@ async def validate_scopes(self, requested_scopes): if scope not in self._user_scopes: return False, scope return True, None + + +SECRET = "SECRET" + +class UserManager(ObjectIDIDMixin, BaseUserManager[TestUser, PydanticObjectId]): + reset_password_token_secret = SECRET + verification_token_secret = SECRET + + async def on_after_register(self, user: TestUser, request: Optional[Request] = None): + print(f"User {user.id} has registered.") + + async def on_after_login( + self, + user: TestUser, + request: Optional[Request] = None, + response: Optional[Response] = None, + ): + print(f"User {user.id} logged in.") + + async def on_after_forgot_password( + self, user: TestUser, token: str, request: Optional[Request] = None + ): + print(f"User {user.id} has forgot their password. Reset token: {token}") + + async def on_after_request_verify( + self, user: TestUser, token: str, request: Optional[Request] = None + ): + print(f"Verification requested for user {user.id}. Verification token: {token}") + return {"token": token} + + async def on_after_verify( + self, user: TestUser, request: Optional[Request] = None + ): + print(f"Verification successful for user {user.id}") + + async def on_after_update( + self, + user: TestUser, + update_dict: Dict[str, Any], + request: Optional[Request] = None, + ): + print(f"User {user.id} has been updated with {update_dict}.") + + async def on_before_delete(self, user: TestUser, request: Optional[Request] = None): + print(f"User {user.id} is going to be deleted") + +async def get_user_db(): + """Database adapter for fastapi-users""" + yield BeanieUserDatabase(TestUser) + +async def get_user_manager(user_db: BeanieUserDatabase = Depends(get_user_db)): + yield UserManager(user_db) + +bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") + +def get_jwt_strategy() -> JWTStrategy: + return JWTStrategy(secret=SECRET, lifetime_seconds=3600) + +auth_backend = AuthenticationBackend( + name="jwt", + transport=bearer_transport, + get_strategy=get_jwt_strategy, +) + +fastapi_users_instance = FastAPIUsers[TestUser, PydanticObjectId]( + get_user_manager, + [auth_backend], +) diff --git a/api/db.py b/api/db.py index dab7b823..c5700604 100644 --- a/api/db.py +++ b/api/db.py @@ -9,7 +9,8 @@ from bson import ObjectId from fastapi_pagination.ext.motor import paginate from motor import motor_asyncio -from .models import Hierarchy, Node, User, Regression, UserGroup +from .models import Hierarchy, Node, User, Regression, UserGroup, TestUser +from fastapi_users.db import BeanieUserDatabase class Database: @@ -38,6 +39,10 @@ def __init__(self, service='mongodb://db:27017', db_name='kernelci'): self._motor = motor_asyncio.AsyncIOMotorClient(service) self._db = self._motor[db_name] + @property + def db(self): + return self._db + def _get_collection(self, model): col = self.COLLECTIONS[model] return self._db[col] diff --git a/api/main.py b/api/main.py index 6c88b3d2..f09f2883 100644 --- a/api/main.py +++ b/api/main.py @@ -37,11 +37,18 @@ User, UserGroup, UserProfile, + TestUser, + UserCreate, + UserRead, + UserUpdate, Password, get_model_from_kind ) from .paginator_models import PageModel from .pubsub import PubSub, Subscription +from beanie import init_beanie +from .auth import fastapi_users_instance, auth_backend + app = FastAPI() db = Database(service=(os.getenv('MONGO_SERVICE') or 'mongodb://db:27017')) @@ -63,6 +70,15 @@ async def create_indexes(): """Startup event handler to create database indexes""" await db.create_indexes() +@app.on_event('startup') +async def fastapi_users_beanie_init(): + """Startup event handler for beanie""" + await init_beanie( + database=db.db, + document_models=[ + TestUser, + ], + ) @app.exception_handler(ValueError) async def value_error_exception_handler(request: Request, exc: ValueError): @@ -630,6 +646,38 @@ async def put_regression(regression_id: str, regression: Regression, return obj +app.include_router( + fastapi_users_instance.get_auth_router(auth_backend, requires_verification=True), + prefix="/auth/jwt", + tags=["auth"] +) +app.include_router( + fastapi_users_instance.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], +) +app.include_router( + fastapi_users_instance.get_reset_password_router(), + prefix="/auth", + tags=["auth"], +) +app.include_router( + fastapi_users_instance.get_verify_router(UserRead), + prefix="/auth", + tags=["auth"], +) +app.include_router( + fastapi_users_instance.get_users_router(UserRead, UserUpdate, requires_verification=True), + prefix="/users", + tags=["users"], +) + +current_active_user = fastapi_users_instance.current_user(active=True) + +@app.get("/authenticated-route") +async def authenticated_route(user: TestUser = Depends(current_active_user)): + return {"message": f"Hello {user.username}!"} + app = VersionedFastAPI( app, version_format='{major}', @@ -639,6 +687,7 @@ async def put_regression(regression_id: str, regression: Regression, on_startup=[ pubsub_startup, create_indexes, + fastapi_users_beanie_init, ] ) diff --git a/api/models.py b/api/models.py index 52b79e75..b1d817db 100644 --- a/api/models.py +++ b/api/models.py @@ -24,6 +24,11 @@ FileUrl, SecretStr, ) +from beanie import Indexed, Document +from fastapi_users.db import BeanieBaseUser +from fastapi_users import schemas +from beanie import PydanticObjectId, Indexed +from typing import Optional class PyObjectId(ObjectId): @@ -148,6 +153,24 @@ def create_indexes(cls, collection): collection.create_index("profile.username", unique=True) +class TestUser(BeanieBaseUser, Document): + """Test user""" + username: Indexed(str, unique=True) + + +class UserRead(schemas.BaseUser[PydanticObjectId]): + username: Indexed(str, unique=True) + + +class UserCreate(schemas.BaseUserCreate): + username: Indexed(str, unique=True) + + +class UserUpdate(schemas.BaseUserUpdate): + username: Optional[Indexed(str, unique=True)] + + + class KernelVersion(BaseModel): """Linux kernel version model""" version: int = Field( diff --git a/docker/api/requirements.txt b/docker/api/requirements.txt index 5ff9a01b..b00e02d3 100644 --- a/docker/api/requirements.txt +++ b/docker/api/requirements.txt @@ -10,3 +10,4 @@ motor==2.5.1 pymongo-migrate==0.11.0 pyyaml==5.3.1 fastapi-versioning==0.10.0 +fastapi-users[beanie, oauth]==10.4.0