Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"gunicorn>=23.0.0",
"ruff>=0.11.13",
"mypy>=1.16.0",
"fastapi-sso>=0.18.0",
]

[project.optional-dependencies]
Expand Down
12 changes: 12 additions & 0 deletions scripts/local_with_uvicorn/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
APP_NAME="My Project"
APP_DESCRIPTION="My Project Description"
APP_VERSION="0.1"
APP_BACKEND_HOST="http://localhost:8000"
APP_FRONTEND_HOST="http://localhost:3000"
CONTACT_NAME="Me"
CONTACT_EMAIL="my.email@example.com"
LICENSE_NAME="MIT"
Expand Down Expand Up @@ -70,3 +72,13 @@ ENVIRONMENT="local"

# ------------- first tier -------------
TIER_NAME="free"

# ------------- auth settings -------------
# ENABLE_PASSWORD_AUTH=true
# GOOGLE_CLIENT_ID=
# GOOGLE_CLIENT_SECRET=
# MICROSOFT_CLIENT_ID=
# MICROSOFT_CLIENT_SECRET=
# MICROSOFT_TENANT=
# GITHUB_CLIENT_ID=
# GITHUB_CLIENT_SECRET=
6 changes: 4 additions & 2 deletions src/app/api/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .health import router as health_router
from .login import router as login_router
from .logout import router as logout_router
from .oauth import router as oauth_router
from .posts import router as posts_router
from .rate_limits import router as rate_limits_router
from .tasks import router as tasks_router
Expand All @@ -13,8 +14,9 @@
router.include_router(health_router)
router.include_router(login_router)
router.include_router(logout_router)
router.include_router(users_router)
router.include_router(oauth_router)
router.include_router(posts_router)
router.include_router(rate_limits_router)
router.include_router(tasks_router)
router.include_router(tiers_router)
router.include_router(rate_limits_router)
router.include_router(users_router)
42 changes: 19 additions & 23 deletions src/app/api/v1/login.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from datetime import timedelta
from typing import Annotated

from fastapi import APIRouter, Depends, Request, Response
Expand All @@ -10,7 +9,6 @@
from ...core.exceptions.http_exceptions import UnauthorizedException
from ...core.schemas import Token
from ...core.security import (
ACCESS_TOKEN_EXPIRE_MINUTES,
TokenType,
authenticate_user,
create_access_token,
Expand All @@ -21,27 +19,25 @@
router = APIRouter(tags=["login"])


@router.post("/login", response_model=Token)
async def login_for_access_token(
response: Response,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Annotated[AsyncSession, Depends(async_get_db)],
) -> dict[str, str]:
user = await authenticate_user(username_or_email=form_data.username, password=form_data.password, db=db)
if not user:
raise UnauthorizedException("Wrong username, email or password.")

access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = await create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires)

refresh_token = await create_refresh_token(data={"sub": user["username"]})
max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60

response.set_cookie(
key="refresh_token", value=refresh_token, httponly=True, secure=True, samesite="lax", max_age=max_age
)

return {"access_token": access_token, "token_type": "bearer"}
if settings.ENABLE_PASSWORD_AUTH:

@router.post("/login", response_model=Token)
async def login_with_password(
response: Response,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Annotated[AsyncSession, Depends(async_get_db)],
) -> dict[str, str]:
user = await authenticate_user(username_or_email=form_data.username, password=form_data.password, db=db)
if not user:
raise UnauthorizedException("Wrong username, email or password.")

access_token = await create_access_token(data={"sub": user["username"]})
refresh_token = await create_refresh_token(data={"sub": user["username"]})
max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
response.set_cookie(
key="refresh_token", value=refresh_token, httponly=True, secure=True, samesite="lax", max_age=max_age
)
return {"access_token": access_token, "token_type": "bearer"}


@router.post("/refresh")
Expand Down
140 changes: 140 additions & 0 deletions src/app/api/v1/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import logging
from abc import ABC
from typing import Any

from fastapi import APIRouter, Depends, Request, Response
from fastapi_sso.sso.base import OpenID, SSOBase
from fastapi_sso.sso.github import GithubSSO
from fastapi_sso.sso.google import GoogleSSO
from fastapi_sso.sso.microsoft import MicrosoftSSO
from sqlalchemy.ext.asyncio import AsyncSession

from ...core.config import settings
from ...core.db.database import async_get_db
from ...core.exceptions.http_exceptions import UnauthorizedException
from ...core.security import (
create_access_token,
create_refresh_token,
)
from ...crud.crud_users import crud_users
from ...schemas.user import UserCreateInternal, UserRead
from .users import write_user_internal

router = APIRouter(tags=["login", "oauth"])
logger = logging.getLogger(__name__)


class BaseOAuthProvider(ABC):
provider_config: dict[str, Any]
sso_provider: type[SSOBase]

def __init__(self, router: Any):
self.router = router
self.provider_name: str = self.sso_provider.provider
if self.is_enabled:
self.sso = self.sso_provider(redirect_uri=self.redirect_uri, **self.provider_config)
tag = f"{self.sso_provider.provider.title()} OAuth"
self.router.add_api_route(
f"/login/{self.provider_name}",
self._login_handler,
methods=["GET"],
tags=[tag],
summary=f"Login with {self.provider_name.title()} OAuth",
)
self.router.add_api_route(
f"/callback/{self.provider_name}",
self._callback_handler,
methods=["GET"],
tags=[tag],
summary=f"Callback for {self.provider_name.title()} OAuth",
)

@property
def redirect_uri(self) -> str:
return f"{settings.APP_BACKEND_HOST}/api/v1/callback/{self.provider_name}"

@property
def is_enabled(self) -> bool:
is_enabled = all(self.provider_config.values())
if settings.ENABLE_PASSWORD_AUTH and is_enabled:
logger.warning(
f"Both password authentication and {self.provider_name} OAuth are enabled. "
"For enterprise or B2B deployments, it is recommended to disable password authentication "
"by setting ENABLE_PASSWORD_AUTH=false and relying solely on OAuth."
)
return is_enabled

async def _create_and_set_token(self, response: Response, user: dict[str, Any]) -> str:
access_token = await create_access_token(data={"sub": user["username"]})
refresh_token = await create_refresh_token(data={"sub": user["username"]})
max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
response.set_cookie(
key="refresh_token", value=refresh_token, httponly=True, secure=True, samesite="lax", max_age=max_age
)
return access_token

async def _login_handler(self):
async with self.sso:
return await self.sso.get_login_redirect()

async def _callback_handler(self, request: Request, response: Response, db: AsyncSession = Depends(async_get_db)):
async with self.sso:
oauth_user: OpenID | None = await self.sso.verify_and_process(request)
if not oauth_user or not oauth_user.email:
raise UnauthorizedException(f"Invalid response from {self.provider_name.title()} OAuth.")

db_user = await crud_users.get(db=db, email=oauth_user.email, is_deleted=False, schema_to_select=UserRead)
if not db_user:
user = await self._get_user_details(oauth_user)
db_user = await write_user_internal(user=user, db=db)

access_token = await self._create_and_set_token(response, db_user)
return {"access_token": access_token, "token_type": "bearer"}

async def _get_user_details(self, oauth_user: OpenID) -> UserCreateInternal:
"""Get user details from the OAuth provider response.

The exact details exposed by the OpenID class can be found here:
https://github.com/tomasvotava/fastapi-sso/blob/master/fastapi_sso/sso/base.py#L64
"""
if not oauth_user.email:
raise UnauthorizedException(f"Invalid response from {self.provider_name.title()} OAuth.")
username = oauth_user.email.split("@")[0]
name = oauth_user.display_name or username

return UserCreateInternal(
email=oauth_user.email,
name=name,
username=username,
hashed_password=None, # No password since OAuth is used
)


class GoogleOAuthProvider(BaseOAuthProvider):
sso_provider = GoogleSSO
provider_config = {
"client_id": settings.GOOGLE_CLIENT_ID,
"client_secret": settings.GOOGLE_CLIENT_SECRET,
}


class MicrosoftOAuthProvider(BaseOAuthProvider):
sso_provider = MicrosoftSSO
provider_config = {
"client_id": settings.MICROSOFT_CLIENT_ID,
"client_secret": settings.MICROSOFT_CLIENT_SECRET,
"tenant": settings.MICROSOFT_TENANT,
}


class GithubSSOProvider(BaseOAuthProvider):
sso_provider = GithubSSO
provider_config = {
"client_id": settings.GITHUB_CLIENT_ID,
"client_secret": settings.GITHUB_CLIENT_SECRET,
}


GoogleOAuthProvider(router)
MicrosoftOAuthProvider(router)
GithubSSOProvider(router)
21 changes: 15 additions & 6 deletions src/app/api/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy.ext.asyncio import AsyncSession

from ...api.dependencies import get_current_superuser, get_current_user
from ...core.config import settings
from ...core.db.database import async_get_db
from ...core.exceptions.http_exceptions import DuplicateValueException, ForbiddenException, NotFoundException
from ...core.security import blacklist_token, get_password_hash, oauth2_scheme
Expand All @@ -17,10 +18,17 @@
router = APIRouter(tags=["users"])


@router.post("/user", response_model=UserRead, status_code=201)
async def write_user(
request: Request, user: UserCreate, db: Annotated[AsyncSession, Depends(async_get_db)]
) -> dict[str, Any]:
if settings.ENABLE_PASSWORD_AUTH:

@router.post("/user", response_model=UserRead, status_code=201)
async def write_user(
request: Request, user: UserCreate, db: Annotated[AsyncSession, Depends(async_get_db)]
) -> dict[str, Any]:
created_user = await write_user_internal(user=user, db=db)
return created_user


async def write_user_internal(user: UserCreate | UserCreateInternal, db: AsyncSession) -> dict[str, Any]:
email_row = await crud_users.exists(db=db, email=user.email)
if email_row:
raise DuplicateValueException("Email is already registered")
Expand All @@ -30,8 +38,9 @@ async def write_user(
raise DuplicateValueException("Username not available")

user_internal_dict = user.model_dump()
user_internal_dict["hashed_password"] = get_password_hash(password=user_internal_dict["password"])
del user_internal_dict["password"]
if isinstance(user, UserCreate):
user_internal_dict["hashed_password"] = get_password_hash(password=user_internal_dict["password"])
del user_internal_dict["password"]

user_internal = UserCreateInternal(**user_internal_dict)
created_user = await crud_users.create(db=db, object=user_internal, schema_to_select=UserRead)
Expand Down
52 changes: 51 additions & 1 deletion src/app/core/config.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,31 @@
import os
import warnings
from enum import Enum
from typing import Self

from pydantic import SecretStr, computed_field
from pydantic import SecretStr, computed_field, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict


class AppSettings(BaseSettings):
APP_NAME: str = "FastAPI app"
APP_DESCRIPTION: str | None = None
APP_VERSION: str | None = None
APP_BACKEND_HOST: str = "http://localhost:8000"
APP_FRONTEND_HOST: str | None = None
LICENSE_NAME: str | None = None
CONTACT_NAME: str | None = None
CONTACT_EMAIL: str | None = None

@field_validator("APP_BACKEND_HOST", "APP_FRONTEND_HOST", mode="after")
@classmethod
def validate_hosts(cls, host: str) -> str:
if host is not None and not (host.startswith("http://") or host.startswith("https://")):
raise ValueError(
f"HOSTS must define their protocol and start with http:// or https://. Received the host '{host}'."
)
return host


class CryptSettings(BaseSettings):
SECRET_KEY: SecretStr = SecretStr("secret-key")
Expand Down Expand Up @@ -149,6 +162,17 @@ class CORSSettings(BaseSettings):
CORS_HEADERS: list[str] = ["*"]


class AuthSettings(BaseSettings):
ENABLE_PASSWORD_AUTH: bool = True
GOOGLE_CLIENT_ID: str | None = None
GOOGLE_CLIENT_SECRET: str | None = None
MICROSOFT_CLIENT_ID: str | None = None
MICROSOFT_CLIENT_SECRET: str | None = None
MICROSOFT_TENANT: str | None = None
GITHUB_CLIENT_ID: str | None = None
GITHUB_CLIENT_SECRET: str | None = None


class Settings(
AppSettings,
SQLiteSettings,
Expand All @@ -164,6 +188,7 @@ class Settings(
CRUDAdminSettings,
EnvironmentSettings,
CORSSettings,
AuthSettings,
):
model_config = SettingsConfigDict(
env_file=os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", ".env"),
Expand All @@ -172,5 +197,30 @@ class Settings(
extra="ignore",
)

@model_validator(mode="after")
def validate_environment_settings(self) -> Self:
"The validation should not modify any of the settings. It should provide"
"feedback to the user if any misconfiguration is detected."
if self.ENVIRONMENT == EnvironmentOption.LOCAL:
pass
elif self.ENVIRONMENT == EnvironmentOption.STAGING:
if "*" in self.CORS_ORIGINS:
warnings.warn(
"For security, in a staging environment CORS_ORIGINS should not include '*'. "
"It's recommended to specify explicit origins (e.g., ['https://staging.example.com'])."
)
elif self.ENVIRONMENT == EnvironmentOption.PRODUCTION:
if "*" in self.CORS_ORIGINS:
raise ValueError(
"For security, in a production environment CORS_ORIGINS cannot include '*'. "
"You must specify explicit allowed origins (e.g., ['https://example.com', 'https://www.example.com'])."
)
if self.APP_FRONTEND_HOST and not self.APP_FRONTEND_HOST.startswith("https://"):
raise ValueError(
"In production, APP_FRONTEND_HOST must start with the https:// protocol. "
f"Received the host '{self.APP_FRONTEND_HOST}'."
)
return self


settings = Settings()
2 changes: 1 addition & 1 deletion src/app/core/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def authenticate_user(username_or_email: str, password: str, db: AsyncSess
if not db_user:
return False

if not await verify_password(password, db_user["hashed_password"]):
if db_user["hashed_password"] is None or not await verify_password(password, db_user["hashed_password"]):
return False

return db_user
Expand Down
Loading