From 240b20015f461605f46af4bb0310225ff0dcc49a Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Fri, 31 Oct 2025 18:16:17 +0800 Subject: [PATCH 1/3] Add docs and code comments and tests --- docs/advanced.md | 1 - docs/explanation.md | 2 +- docs/index.md | 63 +++++- docs/install.md | 6 - docs/status.md | 2 +- docs/usage.md | 279 +++++++++++++++++++++++- fastapi_oauth20/clients/feishu.py | 15 +- fastapi_oauth20/clients/gitee.py | 15 +- fastapi_oauth20/clients/github.py | 15 +- fastapi_oauth20/clients/google.py | 15 +- fastapi_oauth20/clients/linuxdo.py | 15 +- fastapi_oauth20/clients/oschina.py | 15 +- fastapi_oauth20/errors.py | 27 ++- fastapi_oauth20/integrations/fastapi.py | 27 ++- fastapi_oauth20/oauth20.py | 69 +++--- mkdocs.yml | 1 - pyproject.toml | 1 + requirements.txt | 5 +- tests/__init__.py | 2 + uv.lock | 8 +- 20 files changed, 509 insertions(+), 74 deletions(-) delete mode 100644 docs/advanced.md diff --git a/docs/advanced.md b/docs/advanced.md deleted file mode 100644 index 6096c2a..0000000 --- a/docs/advanced.md +++ /dev/null @@ -1 +0,0 @@ -TODO... diff --git a/docs/explanation.md b/docs/explanation.md index 8ba6c61..4aab8c6 100644 --- a/docs/explanation.md +++ b/docs/explanation.md @@ -10,7 +10,7 @@ code 等参数 - ==state== 用于在请求和回调之间维护状态,主要用于防止跨站请求伪造(CSRF)攻击 - ==source== 支持的第三方客户端,比如:GitHub、LinuxDo 等 -- ==sid== 第三方客户端的用户 ID。以下是关于各平台的 sid 存储逻辑: +- ==sid== 第三方客户端的用户 ID !!! warning diff --git a/docs/index.md b/docs/index.md index f435039..c1423d9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,22 +1,61 @@ -
-

🔐

-
+# FastAPI OAuth 2.0 ---- +在 FastAPI 中异步授权 OAuth 2.0 客户端 -**Documentation -**: https://fastapi-practices.github.io/fastapi-oauth20 +## Features -**Source Code -**: https://github.com/fastapi-practices/fastapi-oauth20 +- **异步支持** - 使用 async/await 构建以获得最佳性能 +- **令牌管理** - 内置令牌刷新和吊销 +- **FastAPI 集成** - 用于回调处理的无缝依赖注入 +- **类型安全** - 完整类型提示 +- **错误处理** - OAuth2 错误的综合异常层次结构 ---- +## Quick Start -在 FastAPI 中异步授权 OAuth 2.0 客户端 +### Installation + +```bash +pip install fastapi-oauth20 +``` + +### Basic Usage + +```python +from fastapi import FastAPI, Depends +from fastapi_oauth20 import GitHubOAuth20, FastAPIOAuth20 +from fastapi.responses import RedirectResponse +import secrets + +app = FastAPI() + +# 定义重定向地址 +redirect_uri = "http://localhost:8000/auth/github/callback" + +# 初始化 GitHub OAuth2 客户端 +github_client = GitHubOAuth20( + client_id="your_github_client_id", + client_secret="your_github_client_secret" +) + +# 创建 FastAPI OAuth2 依赖项 +github_oauth = FastAPIOAuth20( + client=github_client, + redirect_uri=redirect_uri +) + + +@app.get("/auth/github") +async def github_auth(): + auth_url = await github_client.get_authorization_url(redirect_uri=redirect_uri) + return RedirectResponse(url=auth_url) -我们的目标是集成多个 CN 第三方客户端,敬请期待(🐦)... -你可以在 [客户端状态](status.md) 获取当前集成情况 +@app.get("/auth/github/callback") +async def github_callback(oauth_result: tuple = Depends(github_oauth)): + token_data, state = oauth_result + user_info = await github_client.get_userinfo(token_data["access_token"]) + return {"user": user_info} +``` ## 互动 diff --git a/docs/install.md b/docs/install.md index e39d606..98eab38 100644 --- a/docs/install.md +++ b/docs/install.md @@ -18,9 +18,3 @@ ```sh uv add fastapi-oauth20 ``` - -=== ":simple-pdm: pdm" - - ```sh - pdm add fastapi-oauth20 - ``` diff --git a/docs/status.md b/docs/status.md index 0bba581..ba32df0 100644 --- a/docs/status.md +++ b/docs/status.md @@ -4,7 +4,7 @@ 对于强制要求【实名 + 人脸认证】的平台,植入变得困难,所以它们不会很快到来 -## END +## FINISHED - [x] [LinuxDo](clients/linuxdo.md) - [x] [GitHub](clients/github.md) diff --git a/docs/usage.md b/docs/usage.md index 6096c2a..9ac34ad 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -1 +1,278 @@ -TODO... +# 使用指南 + +本指南介绍如何将 FastAPI OAuth2.0 库与各种 OAuth2 提供程序一起使用。 + +## 基本用法 + +### 1. 选择 OAuth2 提供商并初始化客户端 + +```python +from fastapi_oauth20 import GitHubOAuth20, GoogleOAuth20, FastAPIOAuth20 +from fastapi import FastAPI, Depends +from fastapi.responses import RedirectResponse +import secrets + +app = FastAPI() + +# 初始化 GitHub OAuth2 客户端 +github_client = GitHubOAuth20( + client_id="your_github_client_id", + client_secret="your_github_client_secret" +) + +# 初始化 Google OAuth2 客户端 +google_client = GoogleOAuth20( + client_id="your_google_client_id", + client_secret="your_google_client_secret" +) +``` + +### 2. 创建 FastAPI OAuth2 依赖 + +```python +# 创建 FastAPI OAuth2 依赖 +github_oauth = FastAPIOAuth20( + client=github_client, + redirect_uri="http://localhost:8000/auth/github/callback" +) + +google_oauth = FastAPIOAuth20( + client=google_client, + redirect_uri="http://localhost:8000/auth/google/callback" +) +``` + +### 3. 定义授权端点 + +```python +@app.get("/auth/{provider}") +async def auth_provider(provider: str): + """重定向用户到 OAuth2 提供商进行授权""" + + # 生成安全的 state 参数用于 CSRF 保护 + state = secrets.token_urlsafe(32) + + if provider == "github": + auth_url = await github_client.get_authorization_url( + redirect_uri="http://localhost:8000/auth/github/callback", + state=state + ) + elif provider == "google": + auth_url = await google_client.get_authorization_url( + redirect_uri="http://localhost:8000/auth/google/callback", + state=state + ) + else: + raise HTTPException(status_code=404, detail="不支持的提供商") + + return RedirectResponse(url=auth_url) +``` + +### 4. 处理 OAuth 回调 + +```python +from typing import Tuple, Dict, Any + + +@app.get("/auth/github/callback") +async def github_callback( + oauth_result: Tuple[Dict[str, Any], str] = Depends(github_oauth) +): + """处理 GitHub OAuth 回调""" + token_data, state = oauth_result + + # 获取用户信息 + user_info = await github_client.get_userinfo(token_data["access_token"]) + + return { + "user": user_info, + "access_token": token_data["access_token"], + "state": state + } + + +@app.get("/auth/google/callback") +async def google_callback( + oauth_result: Tuple[Dict[str, Any], str] = Depends(google_oauth) +): + """处理 Google OAuth 回调""" + token_data, state = oauth_result + + # 获取用户信息 + user_info = await google_client.get_userinfo(token_data["access_token"]) + + return { + "user": user_info, + "access_token": token_data["access_token"], + "state": state + } +``` + +## 高级用法 + +### PKCE (Proof Key for Code Exchange) + +对于公共客户端(移动应用、SPA),使用 PKCE 增强安全性: + +```python +import base64 +import hashlib +import secrets + + +def generate_pkce_challenge(): + """生成 PKCE 代码验证器和挑战码""" + code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=') + code_challenge = base64.urlsafe_b64encode( + hashlib.sha256(code_verifier.encode('utf-8')).digest() + ).decode('utf-8').rstrip('=') + return code_verifier, code_challenge + + +@app.get("/auth/github/pkce") +async def github_auth_pkce(): + """GitHub OAuth with PKCE""" + code_verifier, code_challenge = generate_pkce_challenge() + state = secrets.token_urlsafe(32) + + # 在实际应用中,应该将 code_verifier 和 state 存储在会话或数据库中 + # 这里为了演示目的简化处理 + + auth_url = await github_client.get_authorization_url( + redirect_uri="http://localhost:8000/auth/github/callback", + state=state, + code_challenge=code_challenge, + code_challenge_method="S256" + ) + + return RedirectResponse(url=auth_url) +``` + +### 令牌刷新 + +某些提供商支持刷新令牌来延长会话: + +```python +@app.post("/auth/refresh") +async def refresh_token(refresh_token: str, provider: str): + """使用刷新令牌获取新的访问令牌""" + + if provider == "github": + # GitHub 不支持 OAuth 刷新令牌 + raise HTTPException(status_code=400, detail="GitHub 不支持令牌刷新") + elif provider == "google": + try: + new_tokens = await google_client.refresh_token(refresh_token) + return {"access_token": new_tokens["access_token"]} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) +``` + +### 令牌撤销 + +用户登出时撤销令牌: + +```python +@app.post("/auth/revoke") +async def revoke_token(access_token: str, provider: str): + """撤销访问令牌""" + + if provider == "google": + try: + await google_client.revoke_token(access_token) + return {"message": "令牌撤销成功"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + else: + raise HTTPException(status_code=400, detail="该提供商不支持令牌撤销") +``` + +## 错误处理 + +库提供了全面的错误处理: + +```python +from fastapi_oauth20.errors import ( + OAuth20AuthorizeCallbackError, + AccessTokenError, + GetUserInfoError +) + + +@app.exception_handler(OAuth20AuthorizeCallbackError) +async def oauth_callback_error_handler(request: Request, exc: OAuth20AuthorizeCallbackError): + """处理 OAuth 回调错误""" + return { + "error": "OAuth 授权失败", + "detail": exc.detail, + "status_code": exc.status_code + } + + +@app.exception_handler(AccessTokenError) +async def access_token_error_handler(request: Request, exc: AccessTokenError): + """处理访问令牌错误""" + return { + "error": "访问令牌交换失败", + "detail": exc.msg + } +``` + +## 完整示例 + +```python +from fastapi import FastAPI, Depends, HTTPException +from fastapi.responses import RedirectResponse +from fastapi_oauth20 import GitHubOAuth20, FastAPIOAuth20 +from fastapi_oauth20.errors import OAuth20AuthorizeCallbackError +import secrets + +app = FastAPI() + +# 初始化客户端 +github_client = GitHubOAuth20( + client_id="your_client_id", + client_secret="your_client_secret" +) + +# 创建依赖 +github_oauth = FastAPIOAuth20( + client=github_client, + redirect_uri="http://localhost:8000/auth/github/callback" +) + + +@app.get("/auth/github") +async def github_auth(): + """GitHub 授权入口""" + state = secrets.token_urlsafe(32) + auth_url = await github_client.get_authorization_url( + redirect_uri="http://localhost:8000/auth/github/callback", + state=state + ) + return RedirectResponse(url=auth_url) + + +@app.get("/auth/github/callback") +async def github_callback( + oauth_result: tuple = Depends(github_oauth) +): + """GitHub 授权回调""" + token_data, state = oauth_result + user_info = await github_client.get_userinfo(token_data["access_token"]) + return {"user": user_info} + + +# 错误处理 +@app.exception_handler(OAuth20AuthorizeCallbackError) +async def oauth_error_handler(request, exc): + return {"error": "授权失败", "detail": exc.detail} +``` + +## 注意事项 + +1. **安全性**: 始终使用 HTTPS 端点 +2. **状态管理**: 使用安全的 state 参数防止 CSRF 攻击 +3. **令牌存储**: 安全地存储访问令牌和刷新令牌 +4. **错误处理**: 妥善处理各种 OAuth2 错误场景 +5. **作用域**: 只请求必要的作用域,遵循最小权限原则 diff --git a/fastapi_oauth20/clients/feishu.py b/fastapi_oauth20/clients/feishu.py index 95ed41f..cd4b001 100644 --- a/fastapi_oauth20/clients/feishu.py +++ b/fastapi_oauth20/clients/feishu.py @@ -7,7 +7,15 @@ class FeiShuOAuth20(OAuth20Base): + """FeiShu (Lark) OAuth2 client implementation.""" + def __init__(self, client_id: str, client_secret: str): + """ + Initialize FeiShu OAuth2 client. + + :param client_id: FeiShu app client ID from the FeiShu developer console. + :param client_secret: FeiShu app client secret from the FeiShu developer console. + """ super().__init__( client_id=client_id, client_secret=client_secret, @@ -22,7 +30,12 @@ def __init__(self, client_id: str, client_secret: str): ) async def get_userinfo(self, access_token: str) -> dict: - """Get user info from FeiShu""" + """ + Retrieve user information from FeiShu API. + + :param access_token: Valid FeiShu access token with contact:user scopes. + :return: + """ headers = {'Authorization': f'Bearer {access_token}'} async with httpx.AsyncClient() as client: response = await client.get('https://passport.feishu.cn/suite/passport/oauth/userinfo', headers=headers) diff --git a/fastapi_oauth20/clients/gitee.py b/fastapi_oauth20/clients/gitee.py index 8265e28..eac8c0a 100644 --- a/fastapi_oauth20/clients/gitee.py +++ b/fastapi_oauth20/clients/gitee.py @@ -7,7 +7,15 @@ class GiteeOAuth20(OAuth20Base): + """Gitee OAuth2 client implementation.""" + def __init__(self, client_id: str, client_secret: str): + """ + Initialize Gitee OAuth2 client. + + :param client_id: Gitee OAuth application client ID. + :param client_secret: Gitee OAuth application client secret. + """ super().__init__( client_id=client_id, client_secret=client_secret, @@ -18,7 +26,12 @@ def __init__(self, client_id: str, client_secret: str): ) async def get_userinfo(self, access_token: str) -> dict: - """Get user info from Gitee""" + """ + Retrieve user information from Gitee API. + + :param access_token: Valid Gitee access token with user_info scope. + :return: + """ headers = {'Authorization': f'Bearer {access_token}'} async with httpx.AsyncClient() as client: response = await client.get('https://gitee.com/api/v5/user', headers=headers) diff --git a/fastapi_oauth20/clients/github.py b/fastapi_oauth20/clients/github.py index 40bc6ce..49b95c9 100644 --- a/fastapi_oauth20/clients/github.py +++ b/fastapi_oauth20/clients/github.py @@ -7,7 +7,15 @@ class GitHubOAuth20(OAuth20Base): + """GitHub OAuth2 client implementation.""" + def __init__(self, client_id: str, client_secret: str): + """ + Initialize GitHub OAuth2 client. + + :param client_id: GitHub OAuth App client ID. + :param client_secret: GitHub OAuth App client secret. + """ super().__init__( client_id=client_id, client_secret=client_secret, @@ -17,7 +25,12 @@ def __init__(self, client_id: str, client_secret: str): ) async def get_userinfo(self, access_token: str) -> dict: - """Get user info from GitHub""" + """ + Retrieve user information from GitHub API. + + :param access_token: Valid GitHub access token with appropriate scopes. + :return: + """ headers = {'Authorization': f'Bearer {access_token}'} async with httpx.AsyncClient(headers=headers) as client: response = await client.get('https://api.github.com/user') diff --git a/fastapi_oauth20/clients/google.py b/fastapi_oauth20/clients/google.py index 8f5c151..0a7a148 100644 --- a/fastapi_oauth20/clients/google.py +++ b/fastapi_oauth20/clients/google.py @@ -7,7 +7,15 @@ class GoogleOAuth20(OAuth20Base): + """Google OAuth2 client implementation.""" + def __init__(self, client_id: str, client_secret: str): + """ + Initialize Google OAuth2 client. + + :param client_id: Google OAuth 2.0 client ID from Google Cloud Console. + :param client_secret: Google OAuth 2.0 client secret from Google Cloud Console. + """ super().__init__( client_id=client_id, client_secret=client_secret, @@ -19,7 +27,12 @@ def __init__(self, client_id: str, client_secret: str): ) async def get_userinfo(self, access_token: str) -> dict: - """Get user info from Google""" + """ + Retrieve user information from Google OAuth2 API. + + :param access_token: Valid Google access token with appropriate scopes. + :return: + """ headers = {'Authorization': f'Bearer {access_token}'} async with httpx.AsyncClient() as client: response = await client.get('https://www.googleapis.com/oauth2/v1/userinfo', headers=headers) diff --git a/fastapi_oauth20/clients/linuxdo.py b/fastapi_oauth20/clients/linuxdo.py index fc65d98..6adb0e7 100644 --- a/fastapi_oauth20/clients/linuxdo.py +++ b/fastapi_oauth20/clients/linuxdo.py @@ -7,7 +7,15 @@ class LinuxDoOAuth20(OAuth20Base): + """Linux.do OAuth2 client implementation.""" + def __init__(self, client_id: str, client_secret: str): + """ + Initialize Linux.do OAuth2 client. + + :param client_id: Linux.do OAuth application client ID. + :param client_secret: Linux.do OAuth application client secret. + """ super().__init__( client_id=client_id, client_secret=client_secret, @@ -18,7 +26,12 @@ def __init__(self, client_id: str, client_secret: str): ) async def get_userinfo(self, access_token: str) -> dict: - """Get user info from Linux Do""" + """ + Retrieve user information from Linux.do API. + + :param access_token: Valid Linux.do access token. + :return: + """ headers = {'Authorization': f'Bearer {access_token}'} async with httpx.AsyncClient() as client: response = await client.get('https://connect.linux.do/api/user', headers=headers) diff --git a/fastapi_oauth20/clients/oschina.py b/fastapi_oauth20/clients/oschina.py index c34fde2..91edfb7 100644 --- a/fastapi_oauth20/clients/oschina.py +++ b/fastapi_oauth20/clients/oschina.py @@ -7,7 +7,15 @@ class OSChinaOAuth20(OAuth20Base): + """OSChina OAuth2 client implementation.""" + def __init__(self, client_id: str, client_secret: str): + """ + Initialize OSChina OAuth2 client. + + :param client_id: OSChina OAuth application client ID. + :param client_secret: OSChina OAuth application client secret. + """ super().__init__( client_id=client_id, client_secret=client_secret, @@ -17,7 +25,12 @@ def __init__(self, client_id: str, client_secret: str): ) async def get_userinfo(self, access_token: str) -> dict: - """Get user info from OSChina""" + """ + Retrieve user information from OSChina API. + + :param access_token: Valid OSChina access token. + :return: + """ headers = {'Authorization': f'Bearer {access_token}'} async with httpx.AsyncClient() as client: response = await client.get('https://www.oschina.net/action/openapi/user', headers=headers) diff --git a/fastapi_oauth20/errors.py b/fastapi_oauth20/errors.py index d28c0cd..e677ace 100644 --- a/fastapi_oauth20/errors.py +++ b/fastapi_oauth20/errors.py @@ -4,54 +4,65 @@ class OAuth20BaseError(Exception): - """The oauth2 base error.""" + """Base exception class for all OAuth2-related errors.""" msg: str def __init__(self, msg: str) -> None: + """ + Initialize base OAuth2 error. + + :param msg: Human-readable error message describing the OAuth2 error. + """ self.msg = msg super().__init__(msg) class OAuth20RequestError(OAuth20BaseError): - """OAuth2 httpx request error""" + """Base exception for OAuth2 HTTP request errors.""" def __init__(self, msg: str, response: httpx.Response | None = None) -> None: + """ + Initialize OAuth2 request error. + + :param msg: Human-readable error message describing the request error. + :param response: The HTTP response object that caused the error (if available). + """ self.response = response super().__init__(msg) class HTTPXOAuth20Error(OAuth20RequestError): - """OAuth2 error for httpx raise for status""" + """Exception raised when httpx raises an HTTP status error.""" pass class AccessTokenError(OAuth20RequestError): - """Error raised when get access token fail.""" + """Exception raised when access token exchange fails.""" pass class RefreshTokenError(OAuth20RequestError): - """Refresh token error when refresh token fail.""" + """Exception raised when refresh token operation fails.""" pass class RevokeTokenError(OAuth20RequestError): - """Revoke token error when revoke token fail.""" + """Exception raised when token revocation fails.""" pass class GetUserInfoError(OAuth20RequestError): - """Get user info error when get user info fail.""" + """Exception raised when user info retrieval fails.""" pass class RedirectURIError(OAuth20RequestError): - """Redirect URI set error""" + """Exception raised for redirect URI configuration errors.""" pass diff --git a/fastapi_oauth20/integrations/fastapi.py b/fastapi_oauth20/integrations/fastapi.py index ec7e73c..5eb190c 100644 --- a/fastapi_oauth20/integrations/fastapi.py +++ b/fastapi_oauth20/integrations/fastapi.py @@ -11,7 +11,7 @@ class OAuth20AuthorizeCallbackError(HTTPException, OAuth20BaseError): - """The OAuth2 authorization callback error.""" + """Exception raised during OAuth2 authorization callback processing in FastAPI.""" def __init__( self, @@ -20,11 +20,21 @@ def __init__( headers: dict[str, str] | None = None, response: httpx.Response | None = None, ) -> None: + """ + Initialize OAuth2 callback error. + + :param status_code: HTTP status code to return in the response. + :param detail: Error detail message describing what went wrong. + :param headers: Additional HTTP headers to include in the error response. + :param response: The original HTTP response that caused the error (if any). + """ self.response = response super().__init__(status_code=status_code, detail=detail, headers=headers) class FastAPIOAuth20: + """FastAPI dependency for handling OAuth2 authorization callbacks.""" + def __init__( self, client: OAuth20Base, @@ -32,10 +42,10 @@ def __init__( redirect_uri: str | None = None, ): """ - OAuth2 authorization callback dependency injection + Initialize FastAPI OAuth2 callback handler. - :param client: A client base on OAuth20Base. - :param redirect_uri: OAuth2 callback full URL. + :param client: An OAuth2 client instance that inherits from OAuth20Base. + :param redirect_uri: The full callback URL where the OAuth2 provider redirects after authorization. Must match the URL registered with the OAuth2 provider. """ self.client = client self.redirect_uri = redirect_uri @@ -48,6 +58,15 @@ async def __call__( code_verifier: str | None = None, error: str | None = None, ) -> tuple[dict[str, Any], str | None]: + """ + Process OAuth2 callback request and exchange authorization code for access token. + + :param request: The FastAPI Request object containing callback parameters. + :param code: The authorization code received from the OAuth2 provider (extracted from query parameters). + :param state: The state parameter for CSRF protection (extracted from query parameters). + :param code_verifier: PKCE code verifier if PKCE was used in the authorization request. + :param error: Error parameter from OAuth2 provider if authorization was denied or failed. + """ if code is None or error is not None: raise OAuth20AuthorizeCallbackError( status_code=400, diff --git a/fastapi_oauth20/oauth20.py b/fastapi_oauth20/oauth20.py index 2a1c551..881dafc 100644 --- a/fastapi_oauth20/oauth20.py +++ b/fastapi_oauth20/oauth20.py @@ -33,17 +33,17 @@ def __init__( revoke_token_endpoint_basic_auth: bool = False, ): """ - Base OAuth2 client. + Base OAuth2 client implementing the OAuth 2.0 authorization framework. :param client_id: The client ID provided by the OAuth2 provider. :param client_secret: The client secret provided by the OAuth2 provider. - :param authorize_endpoint: The authorization endpoint URL. - :param access_token_endpoint: The access token endpoint URL. - :param refresh_token_endpoint: The refresh token endpoint URL. - :param revoke_token_endpoint: The revoke token endpoint URL. - :param default_scopes: - :param token_endpoint_basic_auth: - :param revoke_token_endpoint_basic_auth: + :param authorize_endpoint: The authorization endpoint URL where users are redirected to grant access. + :param access_token_endpoint: The token endpoint URL for exchanging authorization codes for access tokens. + :param refresh_token_endpoint: The token endpoint URL for refreshing expired access tokens using refresh tokens. + :param revoke_token_endpoint: The endpoint URL for revoking access tokens or refresh tokens. + :param default_scopes: Default list of OAuth scopes to request if none are specified. + :param token_endpoint_basic_auth: Whether to use HTTP Basic Authentication for token endpoint requests. + :param revoke_token_endpoint_basic_auth: Whether to use HTTP Basic Authentication for revoke endpoint requests. """ self.client_id = client_id self.client_secret = client_secret @@ -69,14 +69,14 @@ async def get_authorization_url( **kwargs, ) -> str: """ - Get authorization url for given. - - :param redirect_uri: redirected after authorization. - :param state: An opaque value used by the client to maintain state between the request and the callback. - :param scope: The scopes to be requested. - :param code_challenge: [PKCE](https://datatracker.ietf.org/doc/html/rfc7636) code challenge. - :param code_challenge_method: [PKCE](https://datatracker.ietf.org/doc/html/rfc7636) code challenge method. - :param kwargs: Additional arguments passed to the OAuth2 client. + Generate OAuth2 authorization URL for redirecting users to grant access. + + :param redirect_uri: The URL where the OAuth2 provider will redirect after authorization. + :param state: An opaque value used by the client to maintain state between the request and callback, preventing CSRF attacks. + :param scope: The list of OAuth scopes to request. If None, uses default_scopes from initialization. + :param code_challenge: PKCE code challenge generated from code_verifier using the specified method. + :param code_challenge_method: PKCE code challenge method, either 'plain' or 'S256' (recommended). + :param kwargs: Additional query parameters to include in the authorization URL. :return: """ params = { @@ -105,11 +105,11 @@ async def get_authorization_url( async def get_access_token(self, code: str, redirect_uri: str, code_verifier: str | None = None) -> dict[str, Any]: """ - Get access token for given. + Exchange authorization code for access token. - :param code: The authorization code. - :param redirect_uri: redirect uri after authorization. - :param code_verifier: the code verifier for the [PKCE](https://datatracker.ietf.org/doc/html/rfc7636). + :param code: The authorization code received from the OAuth2 provider callback. + :param redirect_uri: The exact redirect URI used in the authorization request (must match). + :param code_verifier: The PKCE code verifier used to generate the code challenge (required if PKCE was used). :return: """ data = { @@ -139,9 +139,9 @@ async def get_access_token(self, code: str, redirect_uri: str, code_verifier: st async def refresh_token(self, refresh_token: str) -> dict[str, Any]: """ - Get new access token by refresh token. + Refresh an access token using a refresh token. - :param refresh_token: The refresh token. + :param refresh_token: The refresh token received from the initial token exchange. :return: """ if self.refresh_token_endpoint is None: @@ -170,10 +170,10 @@ async def refresh_token(self, refresh_token: str) -> dict[str, Any]: async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None: """ - Revoke a token. + Revoke an access token or refresh token. - :param token: A token or refresh token to revoke. - :param token_type_hint: Usually either `token` or `refresh_token`. + :param token: The access token or refresh token to revoke. + :param token_type_hint: Optional hint to the server about the token type ('access_token' or 'refresh_token'). :return: """ if self.revoke_token_endpoint is None: @@ -194,7 +194,12 @@ async def revoke_token(self, token: str, token_type_hint: str | None = None) -> @staticmethod def raise_httpx_oauth20_errors(response: httpx.Response) -> None: - """Raise HTTPXOAuth20Error if the response is invalid""" + """ + Check HTTP response and raise appropriate OAuth2 errors for invalid responses. + + :param response: The HTTP response object to validate. + :return: + """ try: response.raise_for_status() except httpx.HTTPStatusError as e: @@ -204,7 +209,13 @@ def raise_httpx_oauth20_errors(response: httpx.Response) -> None: @staticmethod def get_json_result(response: httpx.Response, *, err_class: type[OAuth20RequestError]) -> dict[str, Any]: - """Get response json""" + """ + Parse JSON response and handle JSON decoding errors. + + :param response: The HTTP response object containing JSON data. + :param err_class: The specific OAuth2RequestError subclass to raise on JSON parsing failure. + :return: + """ try: return cast(dict[str, Any], response.json()) except json.JSONDecodeError as e: @@ -213,9 +224,9 @@ def get_json_result(response: httpx.Response, *, err_class: type[OAuth20RequestE @abc.abstractmethod async def get_userinfo(self, access_token: str) -> dict[str, Any]: """ - Get user info from the API provider + Retrieve user information from the OAuth2 provider. - :param access_token: The access token. + :param access_token: Valid access token to authenticate the request to the provider's user info endpoint. :return: """ raise NotImplementedError() diff --git a/mkdocs.yml b/mkdocs.yml index 795fb60..3fb5580 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -10,7 +10,6 @@ nav: - 安装: install.md - 名词解释: explanation.md - 用法: usage.md - - 高级用法: advanced.md - 集成: - FastAPI: fastapi.md - 客户端状态: status.md diff --git a/pyproject.toml b/pyproject.toml index b4f3560..8c80a0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dev = [ "fastapi>=0.119.0", "respx>=0.22.0", "ty>=0.0.1a23", + "click==8.2.1", ] lint = [ "pre-commit>=4.3.0", diff --git a/requirements.txt b/requirements.txt index d2617a2..6c0823f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,8 +15,11 @@ certifi==2025.10.5 # httpx cfgv==3.4.0 # via pre-commit +click==8.2.1 colorama==0.4.6 ; sys_platform == 'win32' - # via pytest + # via + # click + # pytest distlib==0.4.0 # via virtualenv exceptiongroup==1.3.0 ; python_full_version < '3.11' diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..56fafa5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- diff --git a/uv.lock b/uv.lock index 4291c8d..2bc8536 100644 --- a/uv.lock +++ b/uv.lock @@ -167,14 +167,14 @@ wheels = [ [[package]] name = "click" -version = "8.3.0" +version = "8.2.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/46/61/de6cd827efad202d7057d93e0fed9294b96952e188f7384832791c7b2254/click-8.3.0.tar.gz", hash = "sha256:e7b8232224eba16f4ebe410c25ced9f7875cb5f3263ffc93cc3e8da705e229c4", size = 276943, upload-time = "2025-09-18T17:32:23.696Z" } +sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/db/d3/9dcc0f5797f070ec8edf30fbadfb200e71d9db6b84d211e3b2085a7589a0/click-8.3.0-py3-none-any.whl", hash = "sha256:9b9f285302c6e3064f4330c05f05b81945b2a39544279343e6e7c5f27a9baddc", size = 107295, upload-time = "2025-09-18T17:32:22.42Z" }, + { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" }, ] [[package]] @@ -230,6 +230,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "click" }, { name = "fastapi" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -249,6 +250,7 @@ requires-dist = [{ name = "httpx", specifier = ">=0.18.0" }] [package.metadata.requires-dev] dev = [ + { name = "click", specifier = "==8.2.1" }, { name = "fastapi", specifier = ">=0.119.0" }, { name = "pytest", specifier = ">=8.4.0" }, { name = "pytest-asyncio", specifier = ">=1.2.0" }, From eb2f35a123c14d06ebcfa4f60499410defebbd5b Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Sat, 1 Nov 2025 19:06:43 +0800 Subject: [PATCH 2/3] Add tests --- fastapi_oauth20/__init__.py | 1 + tests/clients/__init__.py | 2 + tests/clients/test_feishu.py | 192 +++++++++++++++++ tests/clients/test_gitee.py | 130 ++++++++++++ tests/clients/test_github.py | 204 ++++++++++++++++++ tests/clients/test_google.py | 150 +++++++++++++ tests/clients/test_linuxdo.py | 119 +++++++++++ tests/clients/test_oschina.py | 104 +++++++++ tests/conftest.py | 233 ++++++++++++++++++++ tests/test_basic.py | 221 +++++++++++++++++++ tests/test_errors.py | 156 ++++++++++++++ tests/test_fastapi.py | 338 +++++++++++++++++++++++++++++ tests/test_oauth20.py | 385 ++++++++++++++++++++++++++++++++++ 13 files changed, 2235 insertions(+) create mode 100644 tests/clients/__init__.py create mode 100644 tests/clients/test_feishu.py create mode 100644 tests/clients/test_gitee.py create mode 100644 tests/clients/test_github.py create mode 100644 tests/clients/test_google.py create mode 100644 tests/clients/test_linuxdo.py create mode 100644 tests/clients/test_oschina.py create mode 100644 tests/conftest.py create mode 100644 tests/test_basic.py create mode 100644 tests/test_errors.py create mode 100644 tests/test_fastapi.py create mode 100644 tests/test_oauth20.py diff --git a/fastapi_oauth20/__init__.py b/fastapi_oauth20/__init__.py index d1ea7b6..9112766 100644 --- a/fastapi_oauth20/__init__.py +++ b/fastapi_oauth20/__init__.py @@ -7,5 +7,6 @@ from .clients.linuxdo import LinuxDoOAuth20 as LinuxDoOAuth20 from .clients.oschina import OSChinaOAuth20 as OSChinaOAuth20 from .integrations.fastapi import FastAPIOAuth20 as FastAPIOAuth20 +from .integrations.fastapi import OAuth20AuthorizeCallbackError as OAuth20AuthorizeCallbackError __version__ = '0.0.1' diff --git a/tests/clients/__init__.py b/tests/clients/__init__.py new file mode 100644 index 0000000..56fafa5 --- /dev/null +++ b/tests/clients/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- diff --git a/tests/clients/test_feishu.py b/tests/clients/test_feishu.py new file mode 100644 index 0000000..a68b8e7 --- /dev/null +++ b/tests/clients/test_feishu.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import httpx +import pytest +import respx + +from fastapi_oauth20.clients.feishu import FeiShuOAuth20 +from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error +from fastapi_oauth20.oauth20 import OAuth20Base +from tests.conftest import ( + INVALID_TOKEN, + TEST_ACCESS_TOKEN, + TEST_CLIENT_ID, + TEST_CLIENT_SECRET, + create_mock_user_data, + mock_user_info_response, +) + +# Constants specific to this test file +CUSTOM_CLIENT_ID = 'custom_id' +CUSTOM_CLIENT_SECRET = 'custom_secret' + + +class TestFeiShuOAuth20: + """Test FeiShu OAuth2 client functionality.""" + + def test_feishu_client_initialization(self, feishu_client): + """Test FeiShu client initialization with correct parameters.""" + assert feishu_client.client_id == TEST_CLIENT_ID + assert feishu_client.client_secret == TEST_CLIENT_SECRET + assert feishu_client.authorize_endpoint == 'https://passport.feishu.cn/suite/passport/oauth/authorize' + assert feishu_client.access_token_endpoint == 'https://passport.feishu.cn/suite/passport/oauth/token' + assert feishu_client.refresh_token_endpoint == 'https://passport.feishu.cn/suite/passport/oauth/authorize' + assert feishu_client.default_scopes == [ + 'contact:user.employee_id:readonly', + 'contact:user.base:readonly', + 'contact:user.email:readonly', + ] + + def test_feishu_client_initialization_with_custom_credentials(self): + """Test FeiShu client initialization with custom credentials.""" + client = FeiShuOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET) + assert client.client_id == CUSTOM_CLIENT_ID + assert client.client_secret == CUSTOM_CLIENT_SECRET + + def test_feishu_client_inheritance(self, feishu_client): + """Test that FeiShu client properly inherits from OAuth20Base.""" + assert isinstance(feishu_client, OAuth20Base) + + def test_feishu_client_scopes_are_lists(self, feishu_client): + """Test that default scopes are properly configured as lists.""" + assert isinstance(feishu_client.default_scopes, list) + assert len(feishu_client.default_scopes) == 3 + assert all(isinstance(scope, str) for scope in feishu_client.default_scopes) + + def test_feishu_client_endpoint_urls(self): + """Test that FeiShu client uses correct endpoint URLs.""" + client = FeiShuOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET) + + # Test that endpoints are correctly set without hardcoding them in tests + assert client.authorize_endpoint.endswith('/suite/passport/oauth/authorize') + assert client.access_token_endpoint.endswith('/suite/passport/oauth/token') + assert client.refresh_token_endpoint.endswith('/suite/passport/oauth/authorize') + + # Test that all endpoints use the correct domain + for endpoint in [client.authorize_endpoint, client.access_token_endpoint, client.refresh_token_endpoint]: + assert 'passport.feishu.cn' in endpoint + + def test_feishu_client_multiple_instances(self): + """Test that multiple FeiShu client instances work independently.""" + client1 = FeiShuOAuth20('client1', 'secret1') + client2 = FeiShuOAuth20('client2', 'secret2') + + assert client1.client_id != client2.client_id + assert client1.client_secret != client2.client_secret + assert client1.authorize_endpoint == client2.authorize_endpoint + assert client1.access_token_endpoint == client2.access_token_endpoint + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_success(self, feishu_client): + """Test successful user info retrieval from FeiShu API.""" + mock_user_data = create_mock_user_data('feishu') + mock_user_info_response( + respx, + {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'}, + mock_user_data, + ) + + result = await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == mock_user_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_with_different_access_token(self, feishu_client): + """Test user info retrieval with different access tokens.""" + mock_user_data = create_mock_user_data('feishu', user_id='user_789', name='Another User') + mock_user_info_response( + respx, + {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'}, + mock_user_data, + ) + + result = await feishu_client.get_userinfo('different_token') + assert result == mock_user_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_empty_response(self, feishu_client): + """Test handling of empty user info response.""" + mock_user_info_response( + respx, {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'}, {} + ) + + result = await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == {} + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_partial_data(self, feishu_client): + """Test handling of partial user info response.""" + partial_data = {'user_id': 'test_user', 'name': 'Test User'} + mock_user_info_response( + respx, + {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'}, + partial_data, + ) + + result = await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == partial_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_authorization_header(self, feishu_client): + """Test that authorization header is correctly formatted.""" + mock_user_data = {'user_id': 'test_user'} + route = mock_user_info_response( + respx, + {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'}, + mock_user_data, + ) + + await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) + + # Verify the request was made with correct authorization header + assert route.called + request = route.calls[0].request + assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}' + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error_401(self, feishu_client): + """Test handling of 401 HTTP error when getting user info.""" + respx.get('https://passport.feishu.cn/suite/passport/oauth/userinfo').mock( + return_value=httpx.Response(401, text='Unauthorized') + ) + + with pytest.raises(HTTPXOAuth20Error): + await feishu_client.get_userinfo(INVALID_TOKEN) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error_403(self, feishu_client): + """Test handling of 403 HTTP error when getting user info.""" + respx.get('https://passport.feishu.cn/suite/passport/oauth/userinfo').mock( + return_value=httpx.Response(403, text='Forbidden') + ) + + with pytest.raises(HTTPXOAuth20Error): + await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error_500(self, feishu_client): + """Test handling of 500 HTTP error when getting user info.""" + respx.get('https://passport.feishu.cn/suite/passport/oauth/userinfo').mock( + return_value=httpx.Response(500, text='Internal Server Error') + ) + + with pytest.raises(HTTPXOAuth20Error): + await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_invalid_json(self, feishu_client): + """Test handling of invalid JSON response.""" + respx.get('https://passport.feishu.cn/suite/passport/oauth/userinfo').mock( + return_value=httpx.Response(200, text='invalid json') + ) + + with pytest.raises(GetUserInfoError): + await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) diff --git a/tests/clients/test_gitee.py b/tests/clients/test_gitee.py new file mode 100644 index 0000000..7afa4a9 --- /dev/null +++ b/tests/clients/test_gitee.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import httpx +import pytest +import respx + +from fastapi_oauth20.clients.gitee import GiteeOAuth20 +from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error +from fastapi_oauth20.oauth20 import OAuth20Base +from tests.conftest import ( + INVALID_TOKEN, + TEST_ACCESS_TOKEN, + TEST_CLIENT_ID, + TEST_CLIENT_SECRET, + create_mock_user_data, + mock_user_info_response, +) + +# Constants specific to this test file +CUSTOM_CLIENT_ID = 'custom_id' +CUSTOM_CLIENT_SECRET = 'custom_secret' + + +class TestGiteeOAuth20: + """Test Gitee OAuth2 client functionality.""" + + def test_gitee_client_initialization(self, gitee_client): + """Test Gitee client initialization with correct parameters.""" + assert gitee_client.client_id == TEST_CLIENT_ID + assert gitee_client.client_secret == TEST_CLIENT_SECRET + assert gitee_client.authorize_endpoint == 'https://gitee.com/oauth/authorize' + assert gitee_client.access_token_endpoint == 'https://gitee.com/oauth/token' + assert gitee_client.refresh_token_endpoint == 'https://gitee.com/oauth/token' + assert gitee_client.default_scopes == ['user_info'] + + def test_gitee_client_initialization_with_custom_credentials(self): + """Test Gitee client initialization with custom credentials.""" + client = GiteeOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET) + assert client.client_id == CUSTOM_CLIENT_ID + assert client.client_secret == CUSTOM_CLIENT_SECRET + + def test_gitee_client_inheritance(self, gitee_client): + """Test that Gitee client properly inherits from OAuth20Base.""" + assert isinstance(gitee_client, OAuth20Base) + + def test_gitee_client_scopes_are_lists(self, gitee_client): + """Test that default scopes are properly configured as lists.""" + assert isinstance(gitee_client.default_scopes, list) + assert len(gitee_client.default_scopes) == 1 + assert all(isinstance(scope, str) for scope in gitee_client.default_scopes) + + def test_gitee_client_endpoint_urls(self): + """Test that Gitee client uses correct endpoint URLs.""" + client = GiteeOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET) + + # Test that endpoints are correctly set without hardcoding them in tests + assert client.authorize_endpoint.endswith('/oauth/authorize') + assert client.access_token_endpoint.endswith('/oauth/token') + assert client.refresh_token_endpoint.endswith('/oauth/token') + + # Test that all endpoints use the correct domain + for endpoint in [client.authorize_endpoint, client.access_token_endpoint, client.refresh_token_endpoint]: + assert 'gitee.com' in endpoint + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_success(self, gitee_client): + """Test successful user info retrieval from Gitee API.""" + mock_user_data = create_mock_user_data('gitee') + mock_user_info_response( + respx, {'name': 'gitee', 'user_info_url': 'https://gitee.com/api/v5/user'}, mock_user_data + ) + + result = await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == mock_user_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_authorization_header(self, gitee_client): + """Test that authorization header is correctly formatted.""" + mock_user_data = {'id': 'test_user'} + route = mock_user_info_response( + respx, {'name': 'gitee', 'user_info_url': 'https://gitee.com/api/v5/user'}, mock_user_data + ) + + await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) + + # Verify the request was made with correct authorization header + assert route.called + request = route.calls[0].request + assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}' + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error(self, gitee_client): + """Test handling of HTTP errors when getting user info.""" + respx.get('https://gitee.com/api/v5/user').mock(return_value=httpx.Response(401, text='Unauthorized')) + + with pytest.raises(HTTPXOAuth20Error): + await gitee_client.get_userinfo(INVALID_TOKEN) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_empty_response(self, gitee_client): + """Test handling of empty user info response.""" + mock_user_info_response(respx, {'name': 'gitee', 'user_info_url': 'https://gitee.com/api/v5/user'}, {}) + + result = await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == {} + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_partial_data(self, gitee_client): + """Test handling of partial user info response.""" + partial_data = {'id': 123456, 'login': 'testuser'} + mock_user_info_response( + respx, {'name': 'gitee', 'user_info_url': 'https://gitee.com/api/v5/user'}, partial_data + ) + + result = await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == partial_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_invalid_json(self, gitee_client): + """Test handling of invalid JSON response.""" + respx.get('https://gitee.com/api/v5/user').mock(return_value=httpx.Response(200, text='invalid json')) + + with pytest.raises(GetUserInfoError): + await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) diff --git a/tests/clients/test_github.py b/tests/clients/test_github.py new file mode 100644 index 0000000..acd13d2 --- /dev/null +++ b/tests/clients/test_github.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import httpx +import pytest +import respx + +from fastapi_oauth20.clients.github import GitHubOAuth20 +from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error +from fastapi_oauth20.oauth20 import OAuth20Base +from tests.conftest import ( + INVALID_TOKEN, + TEST_ACCESS_TOKEN, + TEST_CLIENT_ID, + TEST_CLIENT_SECRET, + create_mock_user_data, + mock_user_info_response, +) + +# Constants specific to this test file +CUSTOM_CLIENT_ID = 'custom_id' +CUSTOM_CLIENT_SECRET = 'custom_secret' + + +class TestGitHubOAuth20: + """Test GitHub OAuth2 client functionality.""" + + def test_github_client_initialization(self, github_client): + """Test GitHub client initialization with correct parameters.""" + assert github_client.client_id == TEST_CLIENT_ID + assert github_client.client_secret == TEST_CLIENT_SECRET + assert github_client.authorize_endpoint == 'https://github.com/login/oauth/authorize' + assert github_client.access_token_endpoint == 'https://github.com/login/oauth/access_token' + assert github_client.default_scopes == ['user', 'user:email'] + + def test_github_client_initialization_with_custom_credentials(self): + """Test GitHub client initialization with custom credentials.""" + client = GitHubOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET) + assert client.client_id == CUSTOM_CLIENT_ID + assert client.client_secret == CUSTOM_CLIENT_SECRET + + def test_github_client_inheritance(self, github_client): + """Test that GitHub client properly inherits from OAuth20Base.""" + assert isinstance(github_client, OAuth20Base) + + def test_github_client_scopes_are_lists(self, github_client): + """Test that default scopes are properly configured as lists.""" + assert isinstance(github_client.default_scopes, list) + assert len(github_client.default_scopes) == 2 + assert all(isinstance(scope, str) for scope in github_client.default_scopes) + + def test_github_client_endpoint_urls(self): + """Test that GitHub client uses correct endpoint URLs.""" + client = GitHubOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET) + + # Test that endpoints are correctly set without hardcoding them in tests + assert client.authorize_endpoint.endswith('/login/oauth/authorize') + assert client.access_token_endpoint.endswith('/login/oauth/access_token') + + # Test that all endpoints use the correct domain + for endpoint in [client.authorize_endpoint, client.access_token_endpoint]: + assert 'github.com' in endpoint + + def test_github_client_multiple_instances(self): + """Test that multiple GitHub client instances work independently.""" + client1 = GitHubOAuth20('client1', 'secret1') + client2 = GitHubOAuth20('client2', 'secret2') + + assert client1.client_id != client2.client_id + assert client1.client_secret != client2.client_secret + assert client1.authorize_endpoint == client2.authorize_endpoint + assert client1.access_token_endpoint == client2.access_token_endpoint + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_success_with_email(self, github_client): + """Test successful user info retrieval from GitHub API with email included.""" + mock_user_data = create_mock_user_data('github') + mock_user_info_response( + respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, mock_user_data + ) + + result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == mock_user_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_success_without_email(self, github_client): + """Test successful user info retrieval from GitHub API without email.""" + mock_user_data = create_mock_user_data('github', email=None) + mock_user_info_response( + respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, mock_user_data + ) + # Mock emails endpoint + emails_data = [{'email': 'test@example.com', 'primary': True}] + respx.get('https://api.github.com/user/emails').mock(return_value=httpx.Response(200, json=emails_data)) + + result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result['login'] == mock_user_data['login'] + assert result['email'] == 'test@example.com' + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_with_different_access_token(self, github_client): + """Test user info retrieval with different access tokens.""" + mock_user_data = create_mock_user_data('github', id=789, login='different_user') + mock_user_info_response( + respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, mock_user_data + ) + + result = await github_client.get_userinfo('different_token') + assert result == mock_user_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_authorization_header(self, github_client): + """Test that authorization header is correctly formatted.""" + mock_user_data = {'id': 'test_user', 'email': 'test@example.com'} # Include email to avoid emails endpoint call + route = mock_user_info_response( + respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, mock_user_data + ) + + await github_client.get_userinfo(TEST_ACCESS_TOKEN) + + # Verify the request was made with correct authorization header + assert route.called + request = route.calls[0].request + assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}' + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error_401(self, github_client): + """Test handling of 401 HTTP error when getting user info.""" + respx.get('https://api.github.com/user').mock(return_value=httpx.Response(401, text='Unauthorized')) + + with pytest.raises(HTTPXOAuth20Error): + await github_client.get_userinfo(INVALID_TOKEN) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error_403(self, github_client): + """Test handling of 403 HTTP error when getting user info.""" + respx.get('https://api.github.com/user').mock(return_value=httpx.Response(403, text='Forbidden')) + + with pytest.raises(HTTPXOAuth20Error): + await github_client.get_userinfo(TEST_ACCESS_TOKEN) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error_500(self, github_client): + """Test handling of 500 HTTP error when getting user info.""" + respx.get('https://api.github.com/user').mock(return_value=httpx.Response(500, text='Internal Server Error')) + + with pytest.raises(HTTPXOAuth20Error): + await github_client.get_userinfo(TEST_ACCESS_TOKEN) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_invalid_json(self, github_client): + """Test handling of invalid JSON response.""" + respx.get('https://api.github.com/user').mock(return_value=httpx.Response(200, text='invalid json')) + + with pytest.raises(GetUserInfoError): + await github_client.get_userinfo(TEST_ACCESS_TOKEN) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_empty_response(self, github_client): + """Test handling of empty user info response.""" + mock_user_info_response(respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, {}) + # Mock emails endpoint since empty response will trigger email lookup + emails_data = [{'email': 'test@example.com', 'primary': True}] + respx.get('https://api.github.com/user/emails').mock(return_value=httpx.Response(200, json=emails_data)) + + result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result['email'] == 'test@example.com' + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_partial_data(self, github_client): + """Test handling of partial user info response.""" + partial_data = { + 'id': 123456, + 'login': 'testuser', + 'email': 'test@example.com', + } # Add email to avoid emails endpoint call + mock_user_info_response(respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, partial_data) + + result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == partial_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_rate_limit(self, github_client): + """Test handling of GitHub API rate limit.""" + # GitHub rate limit response + rate_limit_response = { + 'message': 'API rate limit exceeded for xxx.xxx.xxx.xxx.', + 'documentation_url': 'https://docs.github.com/rest/overview/rate-limits-for-the-rest-api', + } + + respx.get('https://api.github.com/user').mock(return_value=httpx.Response(403, json=rate_limit_response)) + + with pytest.raises(HTTPXOAuth20Error): + await github_client.get_userinfo(TEST_ACCESS_TOKEN) diff --git a/tests/clients/test_google.py b/tests/clients/test_google.py new file mode 100644 index 0000000..505db53 --- /dev/null +++ b/tests/clients/test_google.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import httpx +import pytest +import respx + +from fastapi_oauth20.clients.google import GoogleOAuth20 +from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error +from fastapi_oauth20.oauth20 import OAuth20Base +from tests.conftest import ( + INVALID_TOKEN, + TEST_ACCESS_TOKEN, + TEST_CLIENT_ID, + TEST_CLIENT_SECRET, + create_mock_user_data, + mock_user_info_response, +) + +# Constants specific to this test file +CUSTOM_CLIENT_ID = 'custom_id' +CUSTOM_CLIENT_SECRET = 'custom_secret' + + +class TestGoogleOAuth20: + """Test Google OAuth2 client functionality.""" + + def test_google_client_initialization(self, google_client): + """Test Google client initialization with correct parameters.""" + assert google_client.client_id == TEST_CLIENT_ID + assert google_client.client_secret == TEST_CLIENT_SECRET + assert google_client.authorize_endpoint == 'https://accounts.google.com/o/oauth2/v2/auth' + assert google_client.access_token_endpoint == 'https://oauth2.googleapis.com/token' + assert google_client.refresh_token_endpoint == 'https://oauth2.googleapis.com/token' + assert google_client.revoke_token_endpoint == 'https://accounts.google.com/o/oauth2/revoke' + assert google_client.default_scopes == ['email', 'openid', 'profile'] + + def test_google_client_initialization_with_custom_credentials(self): + """Test Google client initialization with custom credentials.""" + client = GoogleOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET) + assert client.client_id == CUSTOM_CLIENT_ID + assert client.client_secret == CUSTOM_CLIENT_SECRET + + def test_google_client_inheritance(self, google_client): + """Test that Google client properly inherits from OAuth20Base.""" + assert isinstance(google_client, OAuth20Base) + + def test_google_client_scopes_are_lists(self, google_client): + """Test that default scopes are properly configured as lists.""" + assert isinstance(google_client.default_scopes, list) + assert len(google_client.default_scopes) == 3 + assert all(isinstance(scope, str) for scope in google_client.default_scopes) + + def test_google_client_endpoint_urls(self): + """Test that Google client uses correct endpoint URLs.""" + client = GoogleOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET) + + # Test that endpoints are correctly set without hardcoding them in tests + assert client.authorize_endpoint.endswith('/o/oauth2/v2/auth') + assert client.access_token_endpoint.endswith('/token') + assert client.refresh_token_endpoint.endswith('/token') + assert client.revoke_token_endpoint.endswith('/o/oauth2/revoke') + + # Test that all endpoints use the correct domains + assert 'accounts.google.com' in client.authorize_endpoint + assert 'accounts.google.com' in client.revoke_token_endpoint + assert 'oauth2.googleapis.com' in client.access_token_endpoint + assert 'oauth2.googleapis.com' in client.refresh_token_endpoint + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_success(self, google_client): + """Test successful user info retrieval from Google OAuth2 API.""" + mock_user_data = create_mock_user_data('google') + mock_user_info_response( + respx, {'name': 'google', 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo'}, mock_user_data + ) + + result = await google_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == mock_user_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_authorization_header(self, google_client): + """Test that authorization header is correctly formatted.""" + mock_user_data = {'id': 'test_user'} + route = mock_user_info_response( + respx, {'name': 'google', 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo'}, mock_user_data + ) + + await google_client.get_userinfo(TEST_ACCESS_TOKEN) + + # Verify the request was made with correct authorization header + assert route.called + request = route.calls[0].request + assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}' + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error(self, google_client): + """Test handling of HTTP errors when getting user info.""" + respx.get('https://www.googleapis.com/oauth2/v1/userinfo').mock( + return_value=httpx.Response(401, text='Unauthorized') + ) + + with pytest.raises(HTTPXOAuth20Error): + await google_client.get_userinfo(INVALID_TOKEN) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_empty_response(self, google_client): + """Test handling of empty user info response.""" + mock_user_info_response( + respx, {'name': 'google', 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo'}, {} + ) + + result = await google_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == {} + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_partial_data(self, google_client): + """Test handling of partial user info response.""" + partial_data = {'id': '123456789', 'email': 'test@example.com'} + mock_user_info_response( + respx, {'name': 'google', 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo'}, partial_data + ) + + result = await google_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == partial_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_invalid_json(self, google_client): + """Test handling of invalid JSON response.""" + respx.get('https://www.googleapis.com/oauth2/v1/userinfo').mock( + return_value=httpx.Response(200, text='invalid json') + ) + + with pytest.raises(GetUserInfoError): + await google_client.get_userinfo(TEST_ACCESS_TOKEN) + + def test_google_client_multiple_instances(self): + """Test that multiple Google client instances work independently.""" + client1 = GoogleOAuth20('client1', 'secret1') + client2 = GoogleOAuth20('client2', 'secret2') + + assert client1.client_id != client2.client_id + assert client1.client_secret != client2.client_secret + assert client1.authorize_endpoint == client2.authorize_endpoint + assert client1.access_token_endpoint == client2.access_token_endpoint diff --git a/tests/clients/test_linuxdo.py b/tests/clients/test_linuxdo.py new file mode 100644 index 0000000..1201c35 --- /dev/null +++ b/tests/clients/test_linuxdo.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import httpx +import pytest +import respx + +from fastapi_oauth20.clients.linuxdo import LinuxDoOAuth20 +from fastapi_oauth20.errors import HTTPXOAuth20Error +from fastapi_oauth20.oauth20 import OAuth20Base +from tests.conftest import ( + INVALID_TOKEN, + TEST_ACCESS_TOKEN, + TEST_CLIENT_ID, + TEST_CLIENT_SECRET, + create_mock_user_data, + mock_user_info_response, +) + +# Constants specific to this test file +CUSTOM_CLIENT_ID = 'custom_id' +CUSTOM_CLIENT_SECRET = 'custom_secret' + + +class TestLinuxDoOAuth20: + """Test LinuxDo OAuth2 client functionality.""" + + def test_linuxdo_client_initialization(self, linuxdo_client): + """Test LinuxDo client initialization with correct parameters.""" + assert linuxdo_client.client_id == TEST_CLIENT_ID + assert linuxdo_client.client_secret == TEST_CLIENT_SECRET + assert linuxdo_client.authorize_endpoint == 'https://connect.linux.do/oauth2/authorize' + assert linuxdo_client.access_token_endpoint == 'https://connect.linux.do/oauth2/token' + assert linuxdo_client.refresh_token_endpoint == 'https://connect.linux.do/oauth2/token' + assert linuxdo_client.default_scopes is None + assert linuxdo_client.token_endpoint_basic_auth is True + + def test_linuxdo_client_initialization_with_custom_credentials(self): + """Test LinuxDo client initialization with custom credentials.""" + client = LinuxDoOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET) + assert client.client_id == CUSTOM_CLIENT_ID + assert client.client_secret == CUSTOM_CLIENT_SECRET + + def test_linuxdo_client_inheritance(self, linuxdo_client): + """Test that LinuxDo client properly inherits from OAuth20Base.""" + assert isinstance(linuxdo_client, OAuth20Base) + + def test_linuxdo_client_basic_auth_enabled(self, linuxdo_client): + """Test that LinuxDo client has basic authentication enabled for token endpoint.""" + assert linuxdo_client.token_endpoint_basic_auth is True + + def test_linuxdo_client_endpoint_urls(self): + """Test that LinuxDo client uses correct endpoint URLs.""" + client = LinuxDoOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET) + + # Test that endpoints are correctly set without hardcoding them in tests + assert client.authorize_endpoint.endswith('/oauth2/authorize') + assert client.access_token_endpoint.endswith('/oauth2/token') + assert client.refresh_token_endpoint.endswith('/oauth2/token') + + # Test that all endpoints use the correct domain + for endpoint in [client.authorize_endpoint, client.access_token_endpoint, client.refresh_token_endpoint]: + assert 'connect.linux.do' in endpoint + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_success(self, linuxdo_client): + """Test successful user info retrieval from LinuxDo API.""" + mock_user_data = create_mock_user_data('linuxdo') + mock_user_info_response( + respx, {'name': 'linuxdo', 'user_info_url': 'https://connect.linux.do/api/user'}, mock_user_data + ) + + result = await linuxdo_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == mock_user_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_authorization_header(self, linuxdo_client): + """Test that authorization header is correctly formatted.""" + mock_user_data = {'id': 'test_user'} + route = mock_user_info_response( + respx, {'name': 'linuxdo', 'user_info_url': 'https://connect.linux.do/api/user'}, mock_user_data + ) + + await linuxdo_client.get_userinfo(TEST_ACCESS_TOKEN) + + # Verify the request was made with correct authorization header + assert route.called + request = route.calls[0].request + assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}' + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error(self, linuxdo_client): + """Test handling of HTTP errors when getting user info.""" + respx.get('https://connect.linux.do/api/user').mock(return_value=httpx.Response(401, text='Unauthorized')) + + with pytest.raises(HTTPXOAuth20Error): + await linuxdo_client.get_userinfo(INVALID_TOKEN) + + @pytest.mark.asyncio + @respx.mock + async def test_get_access_token_uses_basic_auth(self, linuxdo_client): + """Test that access token requests use HTTP Basic Authentication.""" + mock_token_data = {'access_token': 'new_access_token'} + + # Mock the token endpoint and capture the request + route = respx.post('https://connect.linux.do/oauth2/token').mock( + return_value=httpx.Response(200, json=mock_token_data) + ) + + await linuxdo_client.get_access_token(code='auth_code_123', redirect_uri='https://example.com/callback') + + # Verify BasicAuth was used + assert route.called + request = route.calls[0].request + assert 'authorization' in request.headers + # Basic auth should be present + assert request.headers['authorization'].startswith('Basic ') diff --git a/tests/clients/test_oschina.py b/tests/clients/test_oschina.py new file mode 100644 index 0000000..95cc411 --- /dev/null +++ b/tests/clients/test_oschina.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import httpx +import pytest +import respx + +from fastapi_oauth20.clients.oschina import OSChinaOAuth20 +from fastapi_oauth20.errors import HTTPXOAuth20Error +from fastapi_oauth20.oauth20 import OAuth20Base +from tests.conftest import ( + INVALID_TOKEN, + TEST_ACCESS_TOKEN, + TEST_CLIENT_ID, + TEST_CLIENT_SECRET, + create_mock_user_data, + mock_user_info_response, +) + +# Constants specific to this test file +CUSTOM_CLIENT_ID = 'custom_id' +CUSTOM_CLIENT_SECRET = 'custom_secret' + + +class TestOSChinaOAuth20: + """Test OSChina OAuth2 client functionality.""" + + def test_oschina_client_initialization(self, oschina_client): + """Test OSChina client initialization with correct parameters.""" + assert oschina_client.client_id == TEST_CLIENT_ID + assert oschina_client.client_secret == TEST_CLIENT_SECRET + assert oschina_client.authorize_endpoint == 'https://www.oschina.net/action/oauth2/authorize' + assert oschina_client.access_token_endpoint == 'https://www.oschina.net/action/openapi/token' + assert oschina_client.refresh_token_endpoint == 'https://www.oschina.net/action/openapi/token' + assert oschina_client.default_scopes is None + + def test_oschina_client_initialization_with_custom_credentials(self): + """Test OSChina client initialization with custom credentials.""" + client = OSChinaOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET) + assert client.client_id == CUSTOM_CLIENT_ID + assert client.client_secret == CUSTOM_CLIENT_SECRET + + def test_oschina_client_inheritance(self, oschina_client): + """Test that OSChina client properly inherits from OAuth20Base.""" + assert isinstance(oschina_client, OAuth20Base) + + def test_oschina_client_no_default_scopes(self, oschina_client): + """Test that OSChina client has no default scopes configured.""" + assert oschina_client.default_scopes is None + + def test_oschina_client_endpoint_urls(self): + """Test that OSChina client uses correct endpoint URLs.""" + client = OSChinaOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET) + + # Test that endpoints are correctly set without hardcoding them in tests + assert client.authorize_endpoint.endswith('/action/oauth2/authorize') + assert client.access_token_endpoint.endswith('/action/openapi/token') + assert client.refresh_token_endpoint.endswith('/action/openapi/token') + + # Test that all endpoints use the correct domain + for endpoint in [client.authorize_endpoint, client.access_token_endpoint, client.refresh_token_endpoint]: + assert 'oschina.net' in endpoint + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_success(self, oschina_client): + """Test successful user info retrieval from OSChina API.""" + mock_user_data = create_mock_user_data('oschina') + mock_user_info_response( + respx, + {'name': 'oschina', 'user_info_url': 'https://www.oschina.net/action/openapi/user'}, + mock_user_data, + ) + + result = await oschina_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == mock_user_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_authorization_header(self, oschina_client): + """Test that authorization header is correctly formatted.""" + mock_user_data = {'id': 'test_user'} + route = mock_user_info_response( + respx, + {'name': 'oschina', 'user_info_url': 'https://www.oschina.net/action/openapi/user'}, + mock_user_data, + ) + + await oschina_client.get_userinfo(TEST_ACCESS_TOKEN) + + # Verify the request was made with correct authorization header + assert route.called + request = route.calls[0].request + assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}' + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error(self, oschina_client): + """Test handling of HTTP errors when getting user info.""" + respx.get('https://www.oschina.net/action/openapi/user').mock( + return_value=httpx.Response(401, text='Unauthorized') + ) + + with pytest.raises(HTTPXOAuth20Error): + await oschina_client.get_userinfo(INVALID_TOKEN) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..80e6a66 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Global constants and shared fixtures for fastapi-oauth20 tests. +""" + +import httpx +import pytest + +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient + +from fastapi_oauth20.clients.feishu import FeiShuOAuth20 +from fastapi_oauth20.clients.gitee import GiteeOAuth20 +from fastapi_oauth20.clients.github import GitHubOAuth20 +from fastapi_oauth20.clients.google import GoogleOAuth20 +from fastapi_oauth20.clients.linuxdo import LinuxDoOAuth20 +from fastapi_oauth20.clients.oschina import OSChinaOAuth20 + +TEST_CLIENT_ID = 'test_client_id' +TEST_CLIENT_SECRET = 'test_client_secret' +TEST_ACCESS_TOKEN = 'test_access_token' +INVALID_TOKEN = 'invalid_token' +TEST_STATE = 'test_state' +OAUTH_PROVIDERS = [ + { + 'name': 'github', + 'client_class': GitHubOAuth20, + 'token_url': 'https://github.com/login/oauth/access_token', + 'user_info_url': 'https://api.github.com/user', + 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'scope': 'user'}, + 'redirect_uri': 'http://localhost:8000/auth/github/callback', + }, + { + 'name': 'google', + 'client_class': GoogleOAuth20, + 'token_url': 'https://oauth2.googleapis.com/token', + 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo', + 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'Bearer', 'expires_in': 3600}, + 'redirect_uri': 'http://localhost:8000/auth/google/callback', + }, + { + 'name': 'gitee', + 'client_class': GiteeOAuth20, + 'token_url': 'https://gitee.com/oauth/token', + 'user_info_url': 'https://gitee.com/api/v5/user', + 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'scope': 'user_info'}, + 'redirect_uri': 'http://localhost:8000/auth/gitee/callback', + }, + { + 'name': 'feishu', + 'client_class': FeiShuOAuth20, + 'token_url': 'https://passport.feishu.cn/suite/passport/oauth/token', + 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo', + 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'expires_in': 3600}, + 'redirect_uri': 'http://localhost:8000/auth/feishu/callback', + }, + { + 'name': 'linuxdo', + 'client_class': LinuxDoOAuth20, + 'token_url': 'https://connect.linux.do/oauth2/token', + 'user_info_url': 'https://connect.linux.do/api/user', + 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'expires_in': 3600}, + 'redirect_uri': 'http://localhost:8000/auth/linuxdo/callback', + }, + { + 'name': 'oschina', + 'client_class': OSChinaOAuth20, + 'token_url': 'https://www.oschina.net/action/openapi/token', + 'user_info_url': 'https://www.oschina.net/action/openapi/user', + 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'expires_in': 3600}, + 'redirect_uri': 'http://localhost:8000/auth/oschina/callback', + }, +] + + +@pytest.fixture(params=OAUTH_PROVIDERS) +def oauth_provider_config(request): + """Get OAuth provider configuration.""" + return request.param + + +@pytest.fixture +def oauth_client(oauth_provider_config): + """Create OAuth2 client for testing based on provider config.""" + provider_config = oauth_provider_config + client_class = provider_config['client_class'] + return client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + + +@pytest.fixture +def fastapi_app(): + """Create FastAPI app for testing.""" + app = FastAPI() + + @app.get('/') + async def root(): + return {'message': 'OAuth2 Test App'} + + return app + + +@pytest.fixture +def test_client(fastapi_app): + """Create TestClient for FastAPI app.""" + return TestClient(fastapi_app) + + +# Individual OAuth client fixtures for non-parametrized tests +@pytest.fixture +def github_client(): + """Create GitHub OAuth2 client instance for testing.""" + return GitHubOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + + +@pytest.fixture +def google_client(): + """Create Google OAuth2 client instance for testing.""" + return GoogleOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + + +@pytest.fixture +def gitee_client(): + """Create Gitee OAuth2 client instance for testing.""" + return GiteeOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + + +@pytest.fixture +def feishu_client(): + """Create FeiShu OAuth2 client instance for testing.""" + return FeiShuOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + + +@pytest.fixture +def linuxdo_client(): + """Create LinuxDo OAuth2 client instance for testing.""" + return LinuxDoOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + + +@pytest.fixture +def oschina_client(): + """Create OSChina OAuth2 client instance for testing.""" + return OSChinaOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + + +def create_mock_user_data(provider_name: str, **overrides): + """Create mock user data for a specific provider with optional overrides.""" + MOCK_USER_DATA = { + 'github': { + 'id': 123456, + 'login': 'testuser', + 'name': 'Test User', + 'email': 'test@example.com', + 'bio': 'Test bio', + 'location': 'Test Location', + }, + 'google': { + 'id': '123456789', + 'email': 'test@gmail.com', + 'name': 'Test User', + 'picture': 'https://lh3.googleusercontent.com/test.jpg', + }, + 'gitee': { + 'id': 123456, + 'login': 'testuser', + 'name': 'Test User', + 'email': 'test@example.com', + 'avatar_url': 'https://avatar.example.com/testuser.png', + }, + 'feishu': { + 'user_id': 'test_user_123', + 'employee_id': 'emp_456', + 'name': 'Test User', + 'email': 'test@example.com', + 'mobile': '13800000000', + }, + 'linuxdo': { + 'id': 123456, + 'username': 'testuser', + 'name': 'Test User', + 'email': 'test@example.com', + 'avatar_url': 'https://linux.do/avatar/testuser.png', + }, + 'oschina': { + 'id': 123456, + 'name': 'Test User', + 'email': 'test@example.com', + 'avatar': 'https://oschina.net/img/test.jpg', + }, + } + + base_data = MOCK_USER_DATA.get(provider_name, {}).copy() + base_data.update(overrides) + return base_data + + +def mock_oauth_token_response(respx_mock, provider_config: dict, status_code: int = 200): + """Mock OAuth token endpoint response for a provider.""" + return respx_mock.post(provider_config['token_url']).mock( + return_value=httpx.Response(status_code, json=provider_config['token_response']) + ) + + +def mock_user_info_response(respx_mock, provider_config: dict, user_data: dict = None, status_code: int = 200): + """Mock user info endpoint response for a provider.""" + if user_data is None: + user_data = create_mock_user_data(provider_config['name']) + return respx_mock.get(provider_config['user_info_url']).mock( + return_value=httpx.Response(status_code, json=user_data) + ) + + +def setup_oauth_callback_route(app: FastAPI, provider_config: dict, oauth_dependency): + """Setup OAuth callback route for testing.""" + callback_path = f'/auth/{provider_config["name"]}/callback' + + @app.get(callback_path) + async def oauth_callback_handler(access_token_state=Depends(oauth_dependency)): + token, state = access_token_state + return {'provider': provider_config['name'], 'access_token': token, 'state': state} + + return callback_path + + +def assert_oauth_error_response(response, expected_error: str, expected_status: int = 400): + """Assert OAuth error response has correct format.""" + assert response.status_code == expected_status + data = response.json() + if 'detail' in data: + assert data['detail'] == expected_error + else: + assert data.get('error') == expected_error diff --git a/tests/test_basic.py b/tests/test_basic.py new file mode 100644 index 0000000..db7b0ed --- /dev/null +++ b/tests/test_basic.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from fastapi_oauth20.clients.feishu import FeiShuOAuth20 +from fastapi_oauth20.clients.gitee import GiteeOAuth20 +from fastapi_oauth20.clients.github import GitHubOAuth20 +from fastapi_oauth20.clients.google import GoogleOAuth20 +from fastapi_oauth20.clients.linuxdo import LinuxDoOAuth20 +from fastapi_oauth20.clients.oschina import OSChinaOAuth20 +from fastapi_oauth20.errors import ( + AccessTokenError, + GetUserInfoError, + HTTPXOAuth20Error, + OAuth20RequestError, + RefreshTokenError, + RevokeTokenError, +) + + +def test_feishu_client_creation(): + """Test FeiShu client can be created.""" + client = FeiShuOAuth20('test_id', 'test_secret') + assert client.client_id == 'test_id' + assert client.client_secret == 'test_secret' + assert client.authorize_endpoint is not None + assert client.access_token_endpoint is not None + assert client.default_scopes is not None + + +def test_github_client_creation(): + """Test GitHub client can be created.""" + client = GitHubOAuth20('test_id', 'test_secret') + assert client.client_id == 'test_id' + assert client.client_secret == 'test_secret' + assert client.authorize_endpoint is not None + assert client.access_token_endpoint is not None + assert client.default_scopes is not None + + +def test_google_client_creation(): + """Test Google client can be created.""" + client = GoogleOAuth20('test_id', 'test_secret') + assert client.client_id == 'test_id' + assert client.client_secret == 'test_secret' + assert client.authorize_endpoint is not None + assert client.access_token_endpoint is not None + assert client.refresh_token_endpoint is not None + assert client.revoke_token_endpoint is not None + assert client.default_scopes is not None + + +def test_gitee_client_creation(): + """Test Gitee client can be created.""" + client = GiteeOAuth20('test_id', 'test_secret') + assert client.client_id == 'test_id' + assert client.client_secret == 'test_secret' + assert client.authorize_endpoint is not None + assert client.access_token_endpoint is not None + assert client.refresh_token_endpoint is not None + assert client.default_scopes is not None + + +def test_oschina_client_creation(): + """Test OSChina client can be created.""" + client = OSChinaOAuth20('test_id', 'test_secret') + assert client.client_id == 'test_id' + assert client.client_secret == 'test_secret' + assert client.authorize_endpoint is not None + assert client.access_token_endpoint is not None + assert client.refresh_token_endpoint is not None + + +def test_linuxdo_client_creation(): + """Test Linux.do client can be created.""" + client = LinuxDoOAuth20('test_id', 'test_secret') + assert client.client_id == 'test_id' + assert client.client_secret == 'test_secret' + assert client.authorize_endpoint is not None + assert client.access_token_endpoint is not None + assert client.refresh_token_endpoint is not None + assert client.token_endpoint_basic_auth is True + + +def test_all_clients_inherit_from_base(): + """Test all clients inherit from OAuth20Base.""" + from fastapi_oauth20.oauth20 import OAuth20Base + + clients = [ + FeiShuOAuth20('id', 'secret'), + GitHubOAuth20('id', 'secret'), + GoogleOAuth20('id', 'secret'), + GiteeOAuth20('id', 'secret'), + OSChinaOAuth20('id', 'secret'), + LinuxDoOAuth20('id', 'secret'), + ] + + for client in clients: + assert isinstance(client, OAuth20Base) + + +def test_error_classes_creation(): + """Test all error classes can be created.""" + # Test basic error creation + base_error = OAuth20RequestError('Base error') + assert str(base_error) == 'Base error' + + # Test specialized errors + access_error = AccessTokenError('Access error') + assert str(access_error) == 'Access error' + + refresh_error = RefreshTokenError('Refresh error') + assert str(refresh_error) == 'Refresh error' + + revoke_error = RevokeTokenError('Revoke error') + assert str(revoke_error) == 'Revoke error' + + user_error = GetUserInfoError('User info error') + assert str(user_error) == 'User info error' + + httpx_error = HTTPXOAuth20Error('HTTPX error') + assert str(httpx_error) == 'HTTPX error' + + +def test_error_inheritance(): + """Test error class inheritance hierarchy.""" + assert issubclass(AccessTokenError, OAuth20RequestError) + assert issubclass(RefreshTokenError, OAuth20RequestError) + assert issubclass(RevokeTokenError, OAuth20RequestError) + assert issubclass(GetUserInfoError, OAuth20RequestError) + assert issubclass(HTTPXOAuth20Error, OAuth20RequestError) + + +def test_client_endpoint_urls(): + """Test clients have proper endpoint URLs.""" + + # FeiShu + feishu = FeiShuOAuth20('id', 'secret') + assert 'feishu.cn' in feishu.authorize_endpoint + assert 'feishu.cn' in feishu.access_token_endpoint + + # GitHub + github = GitHubOAuth20('id', 'secret') + assert 'github.com' in github.authorize_endpoint + assert 'github.com' in github.access_token_endpoint + + # Google + google = GoogleOAuth20('id', 'secret') + assert 'google.com' in google.authorize_endpoint + assert 'googleapis.com' in google.access_token_endpoint + + # Gitee + gitee = GiteeOAuth20('id', 'secret') + assert 'gitee.com' in gitee.authorize_endpoint + assert 'gitee.com' in gitee.access_token_endpoint + + # OSChina + oschina = OSChinaOAuth20('id', 'secret') + assert 'oschina.net' in oschina.authorize_endpoint + assert 'oschina.net' in oschina.access_token_endpoint + + # Linux.do + linuxdo = LinuxDoOAuth20('id', 'secret') + assert 'linux.do' in linuxdo.authorize_endpoint + assert 'linux.do' in linuxdo.access_token_endpoint + + +def test_client_scopes(): + """Test clients have appropriate default scopes.""" + + feishu = FeiShuOAuth20('id', 'secret') + assert isinstance(feishu.default_scopes, list) + assert len(feishu.default_scopes) > 0 + + github = GitHubOAuth20('id', 'secret') + assert isinstance(github.default_scopes, list) + assert len(github.default_scopes) > 0 + + google = GoogleOAuth20('id', 'secret') + assert isinstance(google.default_scopes, list) + assert len(google.default_scopes) > 0 + + gitee = GiteeOAuth20('id', 'secret') + assert isinstance(gitee.default_scopes, list) + assert len(gitee.default_scopes) > 0 + + oschina = OSChinaOAuth20('id', 'secret') + # OSChina has no default scopes + assert oschina.default_scopes is None + + linuxdo = LinuxDoOAuth20('id', 'secret') + # Linux.do has no default scopes + assert linuxdo.default_scopes is None + + +def test_multiple_clients_independence(): + """Test multiple client instances work independently.""" + feishu1 = FeiShuOAuth20('id1', 'secret1') + feishu2 = FeiShuOAuth20('id2', 'secret2') + + assert feishu1.client_id != feishu2.client_id + assert feishu1.client_secret != feishu2.client_secret + assert feishu1.authorize_endpoint == feishu2.authorize_endpoint + + +def test_google_special_features(): + """Test Google client special features.""" + google = GoogleOAuth20('id', 'secret') + + # Google has both refresh and revoke endpoints + assert google.refresh_token_endpoint is not None + assert google.revoke_token_endpoint is not None + + # Google uses same endpoint for refresh and access + assert google.refresh_token_endpoint == google.access_token_endpoint + + +def test_linuxdo_special_features(): + """Test Linux.do client special features.""" + linuxdo = LinuxDoOAuth20('id', 'secret') + + # Linux.do uses basic auth for token endpoint + assert linuxdo.token_endpoint_basic_auth is True diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 0000000..52303ff --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import httpx + +from fastapi_oauth20.errors import ( + AccessTokenError, + GetUserInfoError, + HTTPXOAuth20Error, + OAuth20RequestError, + RefreshTokenError, + RevokeTokenError, +) + + +def test_oauth20_request_error_basic(): + """Test basic OAuth20RequestError creation.""" + error = OAuth20RequestError('Test error') + assert str(error) == 'Test error' + assert error.msg == 'Test error' + + +def test_oauth20_request_error_with_response(): + """Test OAuth20RequestError with HTTP response.""" + mock_response = httpx.Response(400) + error = OAuth20RequestError('Bad request', mock_response) + assert str(error) == 'Bad request' + assert error.msg == 'Bad request' + assert error.response is mock_response + + +def test_httpx_oauth20_error_basic(): + """Test basic HTTPXOAuth20Error creation.""" + error = HTTPXOAuth20Error('HTTP error') + assert str(error) == 'HTTP error' + assert error.msg == 'HTTP error' + + +def test_httpx_oauth20_error_with_response(): + """Test HTTPXOAuth20Error with HTTP response.""" + mock_response = httpx.Response(404) + error = HTTPXOAuth20Error('Not found', mock_response) + assert str(error) == 'Not found' + assert error.msg == 'Not found' + assert error.response is mock_response + + +def test_access_token_error(): + """Test AccessTokenError creation and inheritance.""" + mock_response = httpx.Response(401) + error = AccessTokenError('Invalid token', mock_response) + + assert str(error) == 'Invalid token' + assert isinstance(error, OAuth20RequestError) + assert error.response is mock_response + + +def test_refresh_token_error(): + """Test RefreshTokenError creation and inheritance.""" + mock_response = httpx.Response(401) + error = RefreshTokenError('Invalid refresh token', mock_response) + + assert str(error) == 'Invalid refresh token' + assert isinstance(error, OAuth20RequestError) + assert error.response is mock_response + + +def test_revoke_token_error(): + """Test RevokeTokenError creation and inheritance.""" + mock_response = httpx.Response(400) + error = RevokeTokenError('Revocation failed', mock_response) + + assert str(error) == 'Revocation failed' + assert isinstance(error, OAuth20RequestError) + assert error.response is mock_response + + +def test_get_userinfo_error(): + """Test GetUserInfoError creation and inheritance.""" + mock_response = httpx.Response(403) + error = GetUserInfoError('Access denied', mock_response) + + assert str(error) == 'Access denied' + assert isinstance(error, OAuth20RequestError) + assert error.response is mock_response + + +def test_error_inheritance_chain(): + """Test that all OAuth2 errors have proper inheritance.""" + assert issubclass(AccessTokenError, OAuth20RequestError) + assert issubclass(RefreshTokenError, OAuth20RequestError) + assert issubclass(RevokeTokenError, OAuth20RequestError) + assert issubclass(GetUserInfoError, OAuth20RequestError) + + assert issubclass(HTTPXOAuth20Error, OAuth20RequestError) + assert issubclass(OAuth20RequestError, Exception) + + +def test_error_without_response(): + """Test error creation without HTTP response.""" + error = AccessTokenError('Simple error') + assert str(error) == 'Simple error' + assert error.msg == 'Simple error' + assert not hasattr(error, 'response') or error.response is None + + +def test_error_catch_hierarchy(): + """Test that errors can be caught at different levels of hierarchy.""" + mock_response = httpx.Response(400) + + # Specific error type + try: + raise AccessTokenError('Access token error', mock_response) + except AccessTokenError as e: + assert str(e) == 'Access token error' + + # Parent OAuth20RequestError type + try: + raise RefreshTokenError('Refresh token error', mock_response) + except OAuth20RequestError as e: + assert str(e) == 'Refresh token error' + + # HTTPXOAuth20Error type + try: + raise HTTPXOAuth20Error('HTTPX error', mock_response) + except HTTPXOAuth20Error as e: + assert str(e) == 'HTTPX error' + + +def test_error_properties(): + """Test that error objects have expected properties.""" + mock_response = httpx.Response(500) + + error = RevokeTokenError('Server error', mock_response) + assert hasattr(error, 'msg') + assert hasattr(error, 'response') + assert error.msg == 'Server error' + assert error.response == mock_response + + +def test_error_str_representation(): + """Test string representation of errors.""" + # Error without response + error1 = AccessTokenError('Simple message') + assert str(error1) == 'Simple message' + + # Error with response + mock_response = httpx.Response(404) + error2 = GetUserInfoError('User not found', mock_response) + assert str(error2) == 'User not found' + + +def test_error_with_complex_message(): + """Test errors with complex or multi-line messages.""" + complex_message = "Error: Invalid request\nDetails: Missing required parameter 'code'" + error = OAuth20RequestError(complex_message) + assert str(error) == complex_message diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py new file mode 100644 index 0000000..7010f42 --- /dev/null +++ b/tests/test_fastapi.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import httpx +import pytest +import respx + +from fastapi import Depends, FastAPI, Request +from fastapi.responses import JSONResponse +from fastapi.testclient import TestClient + +from fastapi_oauth20 import FastAPIOAuth20, OAuth20AuthorizeCallbackError +from tests.conftest import ( + OAUTH_PROVIDERS, + TEST_ACCESS_TOKEN, + TEST_CLIENT_ID, + TEST_CLIENT_SECRET, + TEST_STATE, + assert_oauth_error_response, + mock_oauth_token_response, + setup_oauth_callback_route, +) + +# Integration test specific constants +LOCALHOST_URL = 'http://localhost:8000' +DEV_URL = 'http://dev.example.com' +APP_URL = 'http://app.example.org' +IP_URL = 'http://192.168.1.100:8080' +SECURE_LOCALHOST_URL = 'https://localhost:8000' +SECURE_DEV_URL = 'https://secure.example.com' +SECURE_APP_URL = 'https://app.example.org' +SECURE_IP_URL = 'https://192.168.1.100:8443' + +# Callback paths +AUTH_PATH = '/auth' +CALLBACK_PATH = '/callback' +OAUTH_CALLBACK_PATH = '/oauth2/callback' + +# Test URIs +HTTP_URIS = [ + f'{LOCALHOST_URL}{CALLBACK_PATH}', + f'{DEV_URL}{AUTH_PATH}/callback', + f'{APP_URL}{OAUTH_CALLBACK_PATH}', + f'{IP_URL}{AUTH_PATH}', +] + +HTTPS_URIS = [ + f'{SECURE_LOCALHOST_URL}{CALLBACK_PATH}', + f'{SECURE_DEV_URL}{AUTH_PATH}/callback', + f'{SECURE_APP_URL}{OAUTH_CALLBACK_PATH}', + f'{SECURE_IP_URL}{AUTH_PATH}', +] + + +class TestFastAPIOAuth20Basic: + """Basic tests for FastAPI OAuth2 integration using parametrized providers.""" + + @pytest.mark.asyncio + @respx.mock + @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS) + async def test_successful_callback_parametrized(self, provider_config, fastapi_app): + """Test successful OAuth2 callback with multiple providers.""" + # Mock token exchange + mock_oauth_token_response(respx, provider_config) + + # Create OAuth client and dependency + client_class = provider_config['client_class'] + oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri']) + + # Setup callback route + callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback) + + # Test successful callback + client = TestClient(fastapi_app) + response = client.get(f'{callback_path}?code=test_code&state={TEST_STATE}') + + assert response.status_code == 200 + data = response.json() + assert data['provider'] == provider_config['name'] + assert data['access_token']['access_token'] == TEST_ACCESS_TOKEN + + @pytest.mark.asyncio + @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS) + async def test_callback_missing_code_parametrized(self, provider_config, fastapi_app): + """Test OAuth2 callback with missing authorization code.""" + client_class = provider_config['client_class'] + oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri']) + callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback) + + client = TestClient(fastapi_app) + response = client.get(f'{callback_path}?state={TEST_STATE}') + + assert_oauth_error_response(response, 'Bad Request') + + @pytest.mark.asyncio + @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS) + async def test_callback_with_error_parametrized(self, provider_config, fastapi_app): + """Test OAuth2 callback with error parameter.""" + client_class = provider_config['client_class'] + oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri']) + callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback) + + client = TestClient(fastapi_app) + response = client.get(f'{callback_path}?error=access_denied&state={TEST_STATE}') + + assert_oauth_error_response(response, 'access_denied') + + @pytest.mark.asyncio + @respx.mock + @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS) + async def test_callback_token_exchange_error_parametrized(self, provider_config, fastapi_app): + """Test OAuth2 callback with token exchange error.""" + # Mock token exchange error + respx.post(provider_config['token_url']).mock(return_value=httpx.Response(400, text='Bad Request')) + + client_class = provider_config['client_class'] + oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri']) + callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback) + + client = TestClient(fastapi_app) + response = client.get(f'{callback_path}?code=invalid_code&state={TEST_STATE}') + + assert response.status_code == 500 + assert 'detail' in response.json() + + def test_custom_exception_handler(self, github_client, fastapi_app): + """Test custom exception handler for OAuth2 errors.""" + github_oauth2_callback = FastAPIOAuth20(github_client, redirect_uri=f'{LOCALHOST_URL}/auth/github/callback') + + @fastapi_app.get('/auth/github/callback') + async def github_callback(access_token_state=Depends(github_oauth2_callback)): + token, state = access_token_state + return {'access_token': token, 'state': state} + + @fastapi_app.exception_handler(OAuth20AuthorizeCallbackError) + async def oauth2_error_handler(request: Request, exc: OAuth20AuthorizeCallbackError): + return JSONResponse( + status_code=exc.status_code, + content={ + 'message': 'OAuth2 authentication failed', + 'error': exc.detail, + 'status_code': exc.status_code, + }, + ) + + client = TestClient(fastapi_app) + response = client.get('/auth/github/callback?error=access_denied') + + assert response.status_code == 400 + data = response.json() + assert data['message'] == 'OAuth2 authentication failed' + assert data['error'] == 'access_denied' + assert data['status_code'] == 400 + + def test_multiple_oauth_providers(self, github_client, google_client, fastapi_app): + """Test multiple OAuth providers in the same app.""" + # Setup GitHub OAuth + github_oauth2_callback = FastAPIOAuth20(github_client, redirect_uri=f'{LOCALHOST_URL}/auth/github/callback') + + @fastapi_app.get('/auth/github/callback') + async def github_callback(access_token_state=Depends(github_oauth2_callback)): + token, state = access_token_state + return {'provider': 'github', 'access_token': token, 'state': state} + + # Setup Google OAuth + google_oauth2_callback = FastAPIOAuth20(google_client, redirect_uri=f'{LOCALHOST_URL}/auth/google/callback') + + @fastapi_app.get('/auth/google/callback') + async def google_callback(access_token_state=Depends(google_oauth2_callback)): + token, state = access_token_state + return {'provider': 'google', 'access_token': token, 'state': state} + + client = TestClient(fastapi_app) + + # Test GitHub route + response = client.get('/auth/github/callback?error=access_denied') + assert response.status_code == 400 + + # Test Google route + response = client.get('/auth/google/callback?error=access_denied') + assert response.status_code == 400 + + +class TestFastAPIOAuth20PKCE: + """Test PKCE functionality across providers.""" + + @pytest.mark.asyncio + @respx.mock + @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS) + async def test_pkce_flow(self, provider_config, fastapi_app): + """Test PKCE flow with code verifier and challenge.""" + # Mock token exchange + mock_oauth_token_response(respx, provider_config) + + client_class = provider_config['client_class'] + oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri']) + callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback) + + client = TestClient(fastapi_app) + response = client.get(f'{callback_path}?code=test_code&state={TEST_STATE}&code_verifier=test_verifier') + + assert response.status_code == 200 + data = response.json() + assert data['provider'] == provider_config['name'] + + +class TestFastAPIOAuth20URIs: + """Test URI validation and functionality.""" + + @pytest.mark.asyncio + @respx.mock + @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS) + @pytest.mark.parametrize('redirect_uri', HTTP_URIS) + async def test_http_redirect_uris(self, provider_config, redirect_uri, fastapi_app): + """Test OAuth2 with HTTP redirect URIs.""" + # Mock token exchange + mock_oauth_token_response(respx, provider_config) + + client_class = provider_config['client_class'] + oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=redirect_uri) + callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback) + + client = TestClient(fastapi_app) + response = client.get(f'{callback_path}?code=test_code&state={TEST_STATE}') + + assert response.status_code == 200 + + @pytest.mark.asyncio + @respx.mock + @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS) + @pytest.mark.parametrize('redirect_uri', HTTPS_URIS) + async def test_https_redirect_uris(self, provider_config, redirect_uri, fastapi_app): + """Test OAuth2 with HTTPS redirect URIs.""" + # Mock token exchange + mock_oauth_token_response(respx, provider_config) + + client_class = provider_config['client_class'] + oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=redirect_uri) + callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback) + + client = TestClient(fastapi_app) + response = client.get(f'{callback_path}?code=test_code&state={TEST_STATE}') + + assert response.status_code == 200 + + +class TestFastAPIOAuth20ErrorScenarios: + """Test various error scenarios across providers.""" + + @pytest.mark.asyncio + @respx.mock + @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS) + @pytest.mark.parametrize('error_code', ['access_denied', 'temporarily_unavailable', 'invalid_request']) + async def test_oauth_errors(self, provider_config, error_code, fastapi_app): + """Test various OAuth error codes.""" + client_class = provider_config['client_class'] + oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri']) + callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback) + + client = TestClient(fastapi_app) + response = client.get(f'{callback_path}?error={error_code}&state={TEST_STATE}') + + assert_oauth_error_response(response, error_code) + + @pytest.mark.asyncio + @respx.mock + @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS) + @pytest.mark.parametrize('http_status', [400, 401, 500, 502]) + async def test_token_exchange_errors(self, provider_config, http_status, fastapi_app): + """Test token exchange with various HTTP error codes.""" + # Mock token exchange error + respx.post(provider_config['token_url']).mock( + return_value=httpx.Response(http_status, text=f'Error {http_status}') + ) + + client_class = provider_config['client_class'] + oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri']) + callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback) + + client = TestClient(fastapi_app) + response = client.get(f'{callback_path}?code=invalid_code&state={TEST_STATE}') + + # Should return server error for token exchange failures + assert response.status_code >= 400 + assert 'detail' in response.json() + + +# Additional test classes for different scenarios +class TestFastAPIOAuth20Integration: + """Integration tests for FastAPI OAuth2.""" + + def test_oauth_dependency_creation(self, github_client): + """Test OAuth dependency creation with different parameters.""" + # Basic OAuth dependency + oauth_dep = FastAPIOAuth20(github_client) + assert oauth_dep.client == github_client + + # OAuth dependency with custom redirect URI + custom_redirect = f'{LOCALHOST_URL}/custom/callback' + oauth_dep_custom = FastAPIOAuth20(github_client, redirect_uri=custom_redirect) + assert oauth_dep_custom.client == github_client + + def test_multiple_apps_same_provider(self, github_client): + """Test the same OAuth provider in multiple FastAPI apps.""" + app1 = FastAPI() + app2 = FastAPI() + + # Setup first app + oauth_dep1 = FastAPIOAuth20(github_client, redirect_uri=f'{LOCALHOST_URL}/app1/callback') + + @app1.get('/auth/github/callback') + async def github_callback1(access_token_state=Depends(oauth_dep1)): + return {'app': 'app1', 'access_token': access_token_state} + + # Setup second app + oauth_dep2 = FastAPIOAuth20(github_client, redirect_uri=f'{LOCALHOST_URL}/app2/callback') + + @app2.get('/auth/github/callback') + async def github_callback2(access_token_state=Depends(oauth_dep2)): + return {'app': 'app2', 'access_token': access_token_state} + + client1 = TestClient(app1) + client2 = TestClient(app2) + + # Both apps should work independently + response1 = client1.get('/auth/github/callback?error=access_denied') + response2 = client2.get('/auth/github/callback?error=access_denied') + + assert response1.status_code == 400 + assert response2.status_code == 400 diff --git a/tests/test_oauth20.py b/tests/test_oauth20.py new file mode 100644 index 0000000..47e2d07 --- /dev/null +++ b/tests/test_oauth20.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import json + +from unittest.mock import Mock + +import httpx +import pytest +import respx + +from fastapi_oauth20.errors import ( + AccessTokenError, + HTTPXOAuth20Error, + RefreshTokenError, + RevokeTokenError, +) +from fastapi_oauth20.oauth20 import OAuth20Base + + +class MockOAuth20Client(OAuth20Base): + """Test implementation of OAuth20Base for testing purposes.""" + + async def get_userinfo(self, access_token: str) -> dict[str, any]: + """Mock implementation for testing.""" + return {'user_id': 'test_user', 'access_token': access_token} + + +@pytest.fixture +def oauth_client(): + """Create OAuth20Base client instance for testing.""" + return MockOAuth20Client( + client_id='test_client_id', + client_secret='test_client_secret', + authorize_endpoint='https://example.com/oauth/authorize', + access_token_endpoint='https://example.com/oauth/token', + refresh_token_endpoint='https://example.com/oauth/refresh', + revoke_token_endpoint='https://example.com/oauth/revoke', + default_scopes=['read', 'write'], + ) + + +def test_oauth_base_initialization(oauth_client): + """Test OAuth20Base initialization with all parameters.""" + assert oauth_client.client_id == 'test_client_id' + assert oauth_client.client_secret == 'test_client_secret' + assert oauth_client.authorize_endpoint == 'https://example.com/oauth/authorize' + assert oauth_client.access_token_endpoint == 'https://example.com/oauth/token' + assert oauth_client.refresh_token_endpoint == 'https://example.com/oauth/refresh' + assert oauth_client.revoke_token_endpoint == 'https://example.com/oauth/revoke' + assert oauth_client.default_scopes == ['read', 'write'] + assert oauth_client.token_endpoint_basic_auth is False + assert oauth_client.revoke_token_endpoint_basic_auth is False + assert oauth_client.request_headers == {'Accept': 'application/json'} + + +def test_oauth_base_initialization_minimal(): + """Test OAuth20Base initialization with minimal required parameters.""" + client = MockOAuth20Client( + client_id='test_id', + client_secret='test_secret', + authorize_endpoint='https://example.com/auth', + access_token_endpoint='https://example.com/token', + ) + + assert client.client_id == 'test_id' + assert client.client_secret == 'test_secret' + assert client.authorize_endpoint == 'https://example.com/auth' + assert client.access_token_endpoint == 'https://example.com/token' + assert client.refresh_token_endpoint is None + assert client.revoke_token_endpoint is None + assert client.default_scopes is None + + +def test_oauth_base_initialization_with_basic_auth(): + """Test OAuth20Base initialization with basic authentication enabled.""" + client = MockOAuth20Client( + client_id='test_id', + client_secret='test_secret', + authorize_endpoint='https://example.com/auth', + access_token_endpoint='https://example.com/token', + token_endpoint_basic_auth=True, + revoke_token_endpoint_basic_auth=True, + ) + + assert client.token_endpoint_basic_auth is True + assert client.revoke_token_endpoint_basic_auth is True + + +@pytest.mark.asyncio +async def test_get_authorization_url_basic(oauth_client): + """Test basic authorization URL generation.""" + url = await oauth_client.get_authorization_url(redirect_uri='https://example.com/callback') + + assert 'https://example.com/oauth/authorize' in url + assert 'client_id=test_client_id' in url + assert 'redirect_uri=https%3A%2F%2Fexample.com%2Fcallback' in url + assert 'response_type=code' in url + assert 'scope=read+write' in url + + +@pytest.mark.asyncio +async def test_get_authorization_url_with_state(oauth_client): + """Test authorization URL generation with state parameter.""" + url = await oauth_client.get_authorization_url( + redirect_uri='https://example.com/callback', state='random_state_123' + ) + + assert 'state=random_state_123' in url + + +@pytest.mark.asyncio +async def test_get_authorization_url_with_custom_scope(oauth_client): + """Test authorization URL generation with custom scope.""" + url = await oauth_client.get_authorization_url( + redirect_uri='https://example.com/callback', scope=['read', 'delete'] + ) + + assert 'scope=read+delete' in url + assert 'write' not in url + + +@pytest.mark.asyncio +async def test_get_authorization_url_with_pkce(oauth_client): + """Test authorization URL generation with PKCE parameters.""" + url = await oauth_client.get_authorization_url( + redirect_uri='https://example.com/callback', code_challenge='challenge_123', code_challenge_method='S256' + ) + + assert 'code_challenge=challenge_123' in url + assert 'code_challenge_method=S256' in url + + +@pytest.mark.asyncio +async def test_get_authorization_url_with_extra_params(oauth_client): + """Test authorization URL generation with additional parameters.""" + url = await oauth_client.get_authorization_url( + redirect_uri='https://example.com/callback', access_type='offline', prompt='consent' + ) + + assert 'access_type=offline' in url + assert 'prompt=consent' in url + + +@pytest.mark.asyncio +@respx.mock +async def test_get_access_token_success(oauth_client): + """Test successful access token exchange.""" + mock_token_data = { + 'access_token': 'new_access_token', + 'token_type': 'Bearer', + 'expires_in': 3600, + 'refresh_token': 'refresh_token_123', + } + + # Mock the token endpoint + respx.post('https://example.com/oauth/token').mock(return_value=httpx.Response(200, json=mock_token_data)) + + result = await oauth_client.get_access_token(code='auth_code_123', redirect_uri='https://example.com/callback') + assert result == mock_token_data + + +@pytest.mark.asyncio +@respx.mock +async def test_get_access_token_with_code_verifier(oauth_client): + """Test access token exchange with PKCE code verifier.""" + mock_token_data = {'access_token': 'new_access_token'} + + # Mock the token endpoint and capture the request + route = respx.post('https://example.com/oauth/token').mock(return_value=httpx.Response(200, json=mock_token_data)) + + await oauth_client.get_access_token( + code='auth_code_123', redirect_uri='https://example.com/callback', code_verifier='verifier_123' + ) + + # Verify the request was made with code_verifier + assert route.called + request_data = route.calls[0].request.content.decode() + assert 'code_verifier=verifier_123' in request_data + + +@pytest.mark.asyncio +@respx.mock +async def test_get_access_token_with_basic_auth(): + """Test access token exchange with HTTP Basic Authentication.""" + client = MockOAuth20Client( + client_id='test_id', + client_secret='test_secret', + authorize_endpoint='https://example.com/auth', + access_token_endpoint='https://example.com/token', + token_endpoint_basic_auth=True, + ) + + mock_token_data = {'access_token': 'new_access_token'} + + # Mock the token endpoint + route = respx.post('https://example.com/token').mock(return_value=httpx.Response(200, json=mock_token_data)) + + await client.get_access_token(code='auth_code_123', redirect_uri='https://example.com/callback') + + # Verify BasicAuth was used + assert route.called + request = route.calls[0].request + assert 'authorization' in request.headers + # Basic auth should be present + assert request.headers['authorization'].startswith('Basic ') + + +@pytest.mark.asyncio +@respx.mock +async def test_get_access_token_http_error(oauth_client): + """Test handling of HTTP errors during access token exchange.""" + # Mock HTTP error response + respx.post('https://example.com/oauth/token').mock(return_value=httpx.Response(400, text='Bad Request')) + + with pytest.raises(HTTPXOAuth20Error): + await oauth_client.get_access_token(code='invalid_code', redirect_uri='https://example.com/callback') + + +@pytest.mark.asyncio +@respx.mock +async def test_refresh_token_success(oauth_client): + """Test successful token refresh.""" + mock_token_data = {'access_token': 'refreshed_access_token', 'token_type': 'Bearer', 'expires_in': 3600} + + # Mock the refresh endpoint + respx.post('https://example.com/oauth/refresh').mock(return_value=httpx.Response(200, json=mock_token_data)) + + result = await oauth_client.refresh_token('refresh_token_123') + assert result == mock_token_data + + +@pytest.mark.asyncio +async def test_refresh_token_missing_endpoint(): + """Test refresh token when refresh endpoint is not configured.""" + client = MockOAuth20Client( + client_id='test_id', + client_secret='test_secret', + authorize_endpoint='https://example.com/auth', + access_token_endpoint='https://example.com/token', + ) + + with pytest.raises(RefreshTokenError, match='refresh token address is missing'): + await client.refresh_token('refresh_token_123') + + +@pytest.mark.asyncio +@respx.mock +async def test_refresh_token_http_error(oauth_client): + """Test handling of HTTP errors during token refresh.""" + # Mock HTTP error response + respx.post('https://example.com/oauth/refresh').mock(return_value=httpx.Response(401, text='Unauthorized')) + + with pytest.raises(HTTPXOAuth20Error): + await oauth_client.refresh_token('invalid_refresh_token') + + +@pytest.mark.asyncio +@respx.mock +async def test_revoke_token_success(oauth_client): + """Test successful token revocation.""" + # Mock successful revocation response + respx.post('https://example.com/oauth/revoke').mock(return_value=httpx.Response(200, text='OK')) + + # Should not raise any exception for successful revocation + await oauth_client.revoke_token('access_token_123') + + +@pytest.mark.asyncio +@respx.mock +async def test_revoke_token_with_type_hint(oauth_client): + """Test token revocation with token type hint.""" + # Mock the revoke endpoint and capture the request + route = respx.post('https://example.com/oauth/revoke').mock(return_value=httpx.Response(200, text='OK')) + + await oauth_client.revoke_token(token='refresh_token_123', token_type_hint='refresh_token') + + # Verify token_type_hint was included in the request + assert route.called + request_data = route.calls[0].request.content.decode() + assert 'token_type_hint=refresh_token' in request_data + + +@pytest.mark.asyncio +async def test_revoke_token_missing_endpoint(): + """Test token revocation when revoke endpoint is not configured.""" + client = MockOAuth20Client( + client_id='test_id', + client_secret='test_secret', + authorize_endpoint='https://example.com/auth', + access_token_endpoint='https://example.com/token', + ) + + with pytest.raises(RevokeTokenError, match='revoke token address is missing'): + await client.revoke_token('access_token_123') + + +@pytest.mark.asyncio +@respx.mock +async def test_revoke_token_http_error(oauth_client): + """Test handling of HTTP errors during token revocation.""" + # Mock HTTP error response + respx.post('https://example.com/oauth/revoke').mock(return_value=httpx.Response(400, text='Bad Request')) + + with pytest.raises(HTTPXOAuth20Error): + await oauth_client.revoke_token('invalid_token') + + +def test_raise_httpx_oauth20_errors_success(): + """Test successful HTTP response validation.""" + mock_response = Mock() + mock_response.raise_for_status.return_value = None + + # Should not raise any exception + OAuth20Base.raise_httpx_oauth20_errors(mock_response) + + +def test_raise_httpx_oauth20_errors_http_status_error(): + """Test handling of HTTP status errors.""" + mock_response = Mock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + 'Not Found', request=None, response=mock_response + ) + + with pytest.raises(HTTPXOAuth20Error): + OAuth20Base.raise_httpx_oauth20_errors(mock_response) + + +def test_raise_httpx_oauth20_errors_network_error(): + """Test handling of network errors.""" + # Test with a mock response that will raise RequestError when raise_for_status is called + mock_response = Mock() + mock_response.raise_for_status.side_effect = httpx.RequestError('Network error') + + with pytest.raises(HTTPXOAuth20Error): + OAuth20Base.raise_httpx_oauth20_errors(mock_response) + + +def test_get_json_result_success(): + """Test successful JSON result parsing.""" + mock_response = Mock() + mock_response.json.return_value = {'key': 'value'} + + result = OAuth20Base.get_json_result(mock_response, err_class=AccessTokenError) + assert result == {'key': 'value'} + + +def test_get_json_result_invalid_json(): + """Test handling of invalid JSON response.""" + mock_response = Mock() + mock_response.json.side_effect = json.JSONDecodeError('Invalid JSON', '', 0) + + with pytest.raises(AccessTokenError, match='Result serialization failed'): + OAuth20Base.get_json_result(mock_response, err_class=AccessTokenError) + + +def test_abstract_method(): + """Test that get_userinfo is properly abstract.""" + with pytest.raises(TypeError): + OAuth20Base( + client_id='test', + client_secret='test', + authorize_endpoint='https://example.com/auth', + access_token_endpoint='https://example.com/token', + ) + + +def test_oauth_base_inheritance(): + """Test that OAuth20Base is properly abstract.""" + from abc import ABC + + assert issubclass(OAuth20Base, ABC) + + +@pytest.mark.asyncio +async def test_get_userinfo_implementation(): + """Test that concrete implementation of get_userinfo works.""" + client = MockOAuth20Client( + client_id='test', + client_secret='test', + authorize_endpoint='https://example.com/auth', + access_token_endpoint='https://example.com/token', + ) + + result = await client.get_userinfo('test_token') + assert result == {'user_id': 'test_user', 'access_token': 'test_token'} From 1bf563284307e907d82ae7047cda0c93dc06e3a8 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Sat, 1 Nov 2025 19:13:12 +0800 Subject: [PATCH 3/3] Fix the revoke token --- fastapi_oauth20/oauth20.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/fastapi_oauth20/oauth20.py b/fastapi_oauth20/oauth20.py index 881dafc..9b8ac91 100644 --- a/fastapi_oauth20/oauth20.py +++ b/fastapi_oauth20/oauth20.py @@ -184,7 +184,13 @@ async def revoke_token(self, token: str, token_type_hint: str | None = None) -> if token_type_hint is not None: data.update({'token_type_hint': token_type_hint}) - async with httpx.AsyncClient() as client: + auth = None + if not self.revoke_token_endpoint_basic_auth: + data.update({'client_id': self.client_id, 'client_secret': self.client_secret}) + else: + auth = httpx.BasicAuth(self.client_id, self.client_secret) + + async with httpx.AsyncClient(auth=auth) as client: response = await client.post( self.revoke_token_endpoint, data=data,