diff --git a/pyproject.toml b/pyproject.toml index f5b26924..24ac3e36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "gunicorn>=23.0.0", "ruff>=0.11.13", "mypy>=1.16.0", + "fastapi-sso>=0.18.0", ] [project.optional-dependencies] diff --git a/scripts/local_with_uvicorn/.env.example b/scripts/local_with_uvicorn/.env.example index 9f3e5f43..b3be8b3e 100644 --- a/scripts/local_with_uvicorn/.env.example +++ b/scripts/local_with_uvicorn/.env.example @@ -72,3 +72,13 @@ ENVIRONMENT="local" # ------------- first tier ------------- TIER_NAME="free" + +# ------------- auth settings ------------- +# ENABLE_PASSWORD_AUTH=true +# GOOGLE_CLIENT_ID= +# GOOGLE_CLIENT_SECRET= +# MICROSOFT_CLIENT_ID= +# MICROSOFT_CLIENT_SECRET= +# MICROSOFT_TENANT= +# GITHUB_CLIENT_ID= +# GITHUB_CLIENT_SECRET= diff --git a/src/app/api/v1/__init__.py b/src/app/api/v1/__init__.py index 7575848f..823fa147 100644 --- a/src/app/api/v1/__init__.py +++ b/src/app/api/v1/__init__.py @@ -3,6 +3,7 @@ from .health import router as health_router from .login import router as login_router from .logout import router as logout_router +from .oauth import router as oauth_router from .posts import router as posts_router from .rate_limits import router as rate_limits_router from .tasks import router as tasks_router @@ -13,8 +14,9 @@ router.include_router(health_router) router.include_router(login_router) router.include_router(logout_router) -router.include_router(users_router) +router.include_router(oauth_router) router.include_router(posts_router) +router.include_router(rate_limits_router) router.include_router(tasks_router) router.include_router(tiers_router) -router.include_router(rate_limits_router) +router.include_router(users_router) diff --git a/src/app/api/v1/login.py b/src/app/api/v1/login.py index e784731f..5303463c 100644 --- a/src/app/api/v1/login.py +++ b/src/app/api/v1/login.py @@ -1,4 +1,3 @@ -from datetime import timedelta from typing import Annotated from fastapi import APIRouter, Depends, Request, Response @@ -10,7 +9,6 @@ from ...core.exceptions.http_exceptions import UnauthorizedException from ...core.schemas import Token from ...core.security import ( - ACCESS_TOKEN_EXPIRE_MINUTES, TokenType, authenticate_user, create_access_token, @@ -21,27 +19,25 @@ router = APIRouter(tags=["login"]) -@router.post("/login", response_model=Token) -async def login_for_access_token( - response: Response, - form_data: Annotated[OAuth2PasswordRequestForm, Depends()], - db: Annotated[AsyncSession, Depends(async_get_db)], -) -> dict[str, str]: - user = await authenticate_user(username_or_email=form_data.username, password=form_data.password, db=db) - if not user: - raise UnauthorizedException("Wrong username, email or password.") - - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = await create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires) - - refresh_token = await create_refresh_token(data={"sub": user["username"]}) - max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60 - - response.set_cookie( - key="refresh_token", value=refresh_token, httponly=True, secure=True, samesite="lax", max_age=max_age - ) - - return {"access_token": access_token, "token_type": "bearer"} +if settings.ENABLE_PASSWORD_AUTH: + + @router.post("/login", response_model=Token) + async def login_with_password( + response: Response, + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + db: Annotated[AsyncSession, Depends(async_get_db)], + ) -> dict[str, str]: + user = await authenticate_user(username_or_email=form_data.username, password=form_data.password, db=db) + if not user: + raise UnauthorizedException("Wrong username, email or password.") + + access_token = await create_access_token(data={"sub": user["username"]}) + refresh_token = await create_refresh_token(data={"sub": user["username"]}) + max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60 + response.set_cookie( + key="refresh_token", value=refresh_token, httponly=True, secure=True, samesite="lax", max_age=max_age + ) + return {"access_token": access_token, "token_type": "bearer"} @router.post("/refresh") diff --git a/src/app/api/v1/oauth.py b/src/app/api/v1/oauth.py new file mode 100644 index 00000000..c7118205 --- /dev/null +++ b/src/app/api/v1/oauth.py @@ -0,0 +1,140 @@ +import logging +from abc import ABC +from typing import Any + +from fastapi import APIRouter, Depends, Request, Response +from fastapi_sso.sso.base import OpenID, SSOBase +from fastapi_sso.sso.github import GithubSSO +from fastapi_sso.sso.google import GoogleSSO +from fastapi_sso.sso.microsoft import MicrosoftSSO +from sqlalchemy.ext.asyncio import AsyncSession + +from ...core.config import settings +from ...core.db.database import async_get_db +from ...core.exceptions.http_exceptions import UnauthorizedException +from ...core.security import ( + create_access_token, + create_refresh_token, +) +from ...crud.crud_users import crud_users +from ...schemas.user import UserCreateInternal, UserRead +from .users import write_user_internal + +router = APIRouter(tags=["login", "oauth"]) +logger = logging.getLogger(__name__) + + +class BaseOAuthProvider(ABC): + provider_config: dict[str, Any] + sso_provider: type[SSOBase] + + def __init__(self, router: Any): + self.router = router + self.provider_name: str = self.sso_provider.provider + if self.is_enabled: + self.sso = self.sso_provider(redirect_uri=self.redirect_uri, **self.provider_config) + tag = f"{self.sso_provider.provider.title()} OAuth" + self.router.add_api_route( + f"/login/{self.provider_name}", + self._login_handler, + methods=["GET"], + tags=[tag], + summary=f"Login with {self.provider_name.title()} OAuth", + ) + self.router.add_api_route( + f"/callback/{self.provider_name}", + self._callback_handler, + methods=["GET"], + tags=[tag], + summary=f"Callback for {self.provider_name.title()} OAuth", + ) + + @property + def redirect_uri(self) -> str: + return f"{settings.APP_BACKEND_HOST}/api/v1/callback/{self.provider_name}" + + @property + def is_enabled(self) -> bool: + is_enabled = all(self.provider_config.values()) + if settings.ENABLE_PASSWORD_AUTH and is_enabled: + logger.warning( + f"Both password authentication and {self.provider_name} OAuth are enabled. " + "For enterprise or B2B deployments, it is recommended to disable password authentication " + "by setting ENABLE_PASSWORD_AUTH=false and relying solely on OAuth." + ) + return is_enabled + + async def _create_and_set_token(self, response: Response, user: dict[str, Any]) -> str: + access_token = await create_access_token(data={"sub": user["username"]}) + refresh_token = await create_refresh_token(data={"sub": user["username"]}) + max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60 + response.set_cookie( + key="refresh_token", value=refresh_token, httponly=True, secure=True, samesite="lax", max_age=max_age + ) + return access_token + + async def _login_handler(self): + async with self.sso: + return await self.sso.get_login_redirect() + + async def _callback_handler(self, request: Request, response: Response, db: AsyncSession = Depends(async_get_db)): + async with self.sso: + oauth_user: OpenID | None = await self.sso.verify_and_process(request) + if not oauth_user or not oauth_user.email: + raise UnauthorizedException(f"Invalid response from {self.provider_name.title()} OAuth.") + + db_user = await crud_users.get(db=db, email=oauth_user.email, is_deleted=False, schema_to_select=UserRead) + if not db_user: + user = await self._get_user_details(oauth_user) + db_user = await write_user_internal(user=user, db=db) + + access_token = await self._create_and_set_token(response, db_user) + return {"access_token": access_token, "token_type": "bearer"} + + async def _get_user_details(self, oauth_user: OpenID) -> UserCreateInternal: + """Get user details from the OAuth provider response. + + The exact details exposed by the OpenID class can be found here: + https://github.com/tomasvotava/fastapi-sso/blob/master/fastapi_sso/sso/base.py#L64 + """ + if not oauth_user.email: + raise UnauthorizedException(f"Invalid response from {self.provider_name.title()} OAuth.") + username = oauth_user.email.split("@")[0] + name = oauth_user.display_name or username + + return UserCreateInternal( + email=oauth_user.email, + name=name, + username=username, + hashed_password=None, # No password since OAuth is used + ) + + +class GoogleOAuthProvider(BaseOAuthProvider): + sso_provider = GoogleSSO + provider_config = { + "client_id": settings.GOOGLE_CLIENT_ID, + "client_secret": settings.GOOGLE_CLIENT_SECRET, + } + + +class MicrosoftOAuthProvider(BaseOAuthProvider): + sso_provider = MicrosoftSSO + provider_config = { + "client_id": settings.MICROSOFT_CLIENT_ID, + "client_secret": settings.MICROSOFT_CLIENT_SECRET, + "tenant": settings.MICROSOFT_TENANT, + } + + +class GithubSSOProvider(BaseOAuthProvider): + sso_provider = GithubSSO + provider_config = { + "client_id": settings.GITHUB_CLIENT_ID, + "client_secret": settings.GITHUB_CLIENT_SECRET, + } + + +GoogleOAuthProvider(router) +MicrosoftOAuthProvider(router) +GithubSSOProvider(router) diff --git a/src/app/api/v1/users.py b/src/app/api/v1/users.py index 60264cc2..49acb96d 100644 --- a/src/app/api/v1/users.py +++ b/src/app/api/v1/users.py @@ -5,6 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from ...api.dependencies import get_current_superuser, get_current_user +from ...core.config import settings from ...core.db.database import async_get_db from ...core.exceptions.http_exceptions import DuplicateValueException, ForbiddenException, NotFoundException from ...core.security import blacklist_token, get_password_hash, oauth2_scheme @@ -17,10 +18,17 @@ router = APIRouter(tags=["users"]) -@router.post("/user", response_model=UserRead, status_code=201) -async def write_user( - request: Request, user: UserCreate, db: Annotated[AsyncSession, Depends(async_get_db)] -) -> dict[str, Any]: +if settings.ENABLE_PASSWORD_AUTH: # If password auth is not enable there should be no way to create users via the API + + @router.post("/user", response_model=UserRead, status_code=201) + async def write_user( + request: Request, user: UserCreate, db: Annotated[AsyncSession, Depends(async_get_db)] + ) -> dict[str, Any]: + created_user = await write_user_internal(user=user, db=db) + return created_user + + +async def write_user_internal(user: UserCreate | UserCreateInternal, db: AsyncSession) -> dict[str, Any]: email_row = await crud_users.exists(db=db, email=user.email) if email_row: raise DuplicateValueException("Email is already registered") @@ -29,13 +37,13 @@ async def write_user( if username_row: raise DuplicateValueException("Username not available") - user_internal_dict = user.model_dump() - user_internal_dict["hashed_password"] = get_password_hash(password=user_internal_dict["password"]) - del user_internal_dict["password"] - - user_internal = UserCreateInternal(**user_internal_dict) - created_user = await crud_users.create(db=db, object=user_internal, schema_to_select=UserRead) + if isinstance(user, UserCreate): + user_internal_dict = user.model_dump() + user_internal_dict["hashed_password"] = get_password_hash(password=user_internal_dict["password"]) + del user_internal_dict["password"] + user = UserCreateInternal(**user_internal_dict) + created_user = await crud_users.create(db=db, object=user, schema_to_select=UserRead) if created_user is None: raise NotFoundException("Failed to create user") diff --git a/src/app/core/config.py b/src/app/core/config.py index c0312438..5693c52e 100644 --- a/src/app/core/config.py +++ b/src/app/core/config.py @@ -141,6 +141,17 @@ class CORSSettings(BaseSettings): CORS_HEADERS: list[str] = ["*"] +class AuthSettings(BaseSettings): + ENABLE_PASSWORD_AUTH: bool = True + GOOGLE_CLIENT_ID: str | None = None + GOOGLE_CLIENT_SECRET: str | None = None + MICROSOFT_CLIENT_ID: str | None = None + MICROSOFT_CLIENT_SECRET: str | None = None + MICROSOFT_TENANT: str | None = None + GITHUB_CLIENT_ID: str | None = None + GITHUB_CLIENT_SECRET: str | None = None + + class Settings( AppSettings, PostgresSettings, @@ -155,6 +166,7 @@ class Settings( CRUDAdminSettings, EnvironmentSettings, CORSSettings, + AuthSettings, ): model_config = SettingsConfigDict( env_file=os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", ".env"), diff --git a/src/app/core/security.py b/src/app/core/security.py index d589078b..d6f18ca2 100644 --- a/src/app/core/security.py +++ b/src/app/core/security.py @@ -45,7 +45,8 @@ async def authenticate_user(username_or_email: str, password: str, db: AsyncSess if not db_user: return False - if not await verify_password(password, db_user["hashed_password"]): + # If the user has no password set (e.g. OAuth2 only accounts), reject authentication + if db_user["hashed_password"] is None or not await verify_password(password, db_user["hashed_password"]): return False return db_user diff --git a/src/app/core/setup.py b/src/app/core/setup.py index b2cdcbf7..766eae6c 100644 --- a/src/app/core/setup.py +++ b/src/app/core/setup.py @@ -18,6 +18,7 @@ from ..models import * # noqa: F403 from .config import ( AppSettings, + AuthSettings, ClientSideCacheSettings, CORSSettings, DatabaseSettings, @@ -86,6 +87,7 @@ def lifespan_factory( | RedisQueueSettings | RedisRateLimiterSettings | EnvironmentSettings + | AuthSettings ), create_tables_on_start: bool = True, ) -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]]: @@ -142,6 +144,7 @@ def create_application( | RedisQueueSettings | RedisRateLimiterSettings | EnvironmentSettings + | AuthSettings ), create_tables_on_start: bool = True, lifespan: Callable[[FastAPI], _AsyncGeneratorContextManager[Any]] | None = None, diff --git a/src/app/models/user.py b/src/app/models/user.py index 07cca2d8..d8f1170e 100644 --- a/src/app/models/user.py +++ b/src/app/models/user.py @@ -17,7 +17,7 @@ class User(Base): name: Mapped[str] = mapped_column(String(30)) username: Mapped[str] = mapped_column(String(20), unique=True, index=True) email: Mapped[str] = mapped_column(String(50), unique=True, index=True) - hashed_password: Mapped[str] = mapped_column(String) + hashed_password: Mapped[str | None] = mapped_column(String, nullable=True) profile_image_url: Mapped[str] = mapped_column(String, default="https://profileimageurl.com") uuid: Mapped[uuid_pkg.UUID] = mapped_column(UUID(as_uuid=True), default_factory=uuid7, unique=True) diff --git a/src/app/schemas/user.py b/src/app/schemas/user.py index c33a94e3..6303168e 100644 --- a/src/app/schemas/user.py +++ b/src/app/schemas/user.py @@ -36,7 +36,9 @@ class UserCreate(UserBase): class UserCreateInternal(UserBase): - hashed_password: str + model_config = ConfigDict(extra="forbid") + + hashed_password: str | None class UserUpdate(BaseModel): diff --git a/uv.lock b/uv.lock index 5dda7a25..1dd702a5 100644 --- a/uv.lock +++ b/uv.lock @@ -387,6 +387,7 @@ dependencies = [ { name = "bcrypt" }, { name = "crudadmin" }, { name = "fastapi" }, + { name = "fastapi-sso" }, { name = "fastcrud" }, { name = "greenlet" }, { name = "gunicorn" }, @@ -435,6 +436,7 @@ requires-dist = [ { name = "crudadmin", specifier = ">=0.4.2" }, { name = "faker", marker = "extra == 'dev'", specifier = ">=26.0.0" }, { name = "fastapi", specifier = ">=0.109.1" }, + { name = "fastapi-sso", specifier = ">=0.18.0" }, { name = "fastcrud", specifier = ">=0.19.2" }, { name = "greenlet", specifier = ">=2.0.2" }, { name = "gunicorn", specifier = ">=23.0.0" }, @@ -470,6 +472,22 @@ dev = [ { name = "watchfiles", specifier = ">=1.1.1" }, ] +[[package]] +name = "fastapi-sso" +version = "0.18.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastapi" }, + { name = "httpx" }, + { name = "oauthlib" }, + { name = "pydantic", extra = ["email"] }, + { name = "pyjwt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d2/57/cc971c018af5d09eb5f8d1cd12abdd99ab4c59ea5c0b0b1b96349ffe117d/fastapi_sso-0.18.0.tar.gz", hash = "sha256:d8df5a686af7a6a7be248817544b405cf77f7e9ffcd5d0d7d2a196fd071964bc", size = 16811, upload-time = "2025-03-20T17:09:09.958Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6e/03/70ca13994f5569d343a9f99dba2930c8ae3471171f161b8887d44b6c526f/fastapi_sso-0.18.0-py3-none-any.whl", hash = "sha256:727754ad770b70690f1471f7b0a9e17c6dfd8ebd6e477616d3bde1eaf62e53dc", size = 26103, upload-time = "2025-03-20T17:09:08.656Z" }, +] + [[package]] name = "fastcrud" version = "0.19.2" @@ -816,6 +834,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, ] +[[package]] +name = "oauthlib" +version = "3.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/5f/19930f824ffeb0ad4372da4812c50edbd1434f678c90c2733e1188edfc63/oauthlib-3.3.1.tar.gz", hash = "sha256:0f0f8aa759826a193cf66c12ea1af1637f87b9b4622d46e866952bb022e538c9", size = 185918, upload-time = "2025-06-19T22:48:08.269Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/9c/92789c596b8df838baa98fa71844d84283302f7604ed565dafe5a6b5041a/oauthlib-3.3.1-py3-none-any.whl", hash = "sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1", size = 160065, upload-time = "2025-06-19T22:48:06.508Z" }, +] + [[package]] name = "packaging" version = "25.0" @@ -1039,11 +1066,11 @@ wheels = [ [[package]] name = "pyjwt" -version = "2.9.0" +version = "2.10.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fb/68/ce067f09fca4abeca8771fe667d89cc347d1e99da3e093112ac329c6020e/pyjwt-2.9.0.tar.gz", hash = "sha256:7e1e5b56cc735432a7369cbfa0efe50fa113ebecdc04ae6922deba8b84582d0c", size = 78825, upload-time = "2024-08-01T15:01:08.445Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/79/84/0fdf9b18ba31d69877bd39c9cd6052b47f3761e9910c15de788e519f079f/PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850", size = 22344, upload-time = "2024-08-01T15:01:06.481Z" }, + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, ] [[package]] @@ -1175,15 +1202,15 @@ wheels = [ [[package]] name = "redis" -version = "5.3.0" +version = "5.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "async-timeout", marker = "python_full_version < '3.11.3'" }, { name = "pyjwt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/71/dd/2b37032f4119dff2a2f9bbcaade03221b100ba26051bb96e275de3e5db7a/redis-5.3.0.tar.gz", hash = "sha256:8d69d2dde11a12dc85d0dbf5c45577a5af048e2456f7077d87ad35c1c81c310e", size = 4626288, upload-time = "2025-04-30T14:54:40.634Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/cf/128b1b6d7086200c9f387bd4be9b2572a30b90745ef078bd8b235042dc9f/redis-5.3.1.tar.gz", hash = "sha256:ca49577a531ea64039b5a36db3d6cd1a0c7a60c34124d46924a45b956e8cf14c", size = 4626200, upload-time = "2025-07-25T08:06:27.778Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/45/b0/aa601efe12180ba492b02e270554877e68467e66bda5d73e51eaa8ecc78a/redis-5.3.0-py3-none-any.whl", hash = "sha256:f1deeca1ea2ef25c1e4e46b07f4ea1275140526b1feea4c6459c0ec27a10ef83", size = 272836, upload-time = "2025-04-30T14:54:30.744Z" }, + { url = "https://files.pythonhosted.org/packages/7f/26/5c5fa0e83c3621db835cfc1f1d789b37e7fa99ed54423b5f519beb931aa7/redis-5.3.1-py3-none-any.whl", hash = "sha256:dc1909bd24669cc31b5f67a039700b16ec30571096c5f1f0d9d2324bff31af97", size = 272833, upload-time = "2025-07-25T08:06:26.317Z" }, ] [package.optional-dependencies]