diff --git a/docs/advanced.md b/docs/advanced.md
deleted file mode 100644
index 6096c2a..0000000
--- a/docs/advanced.md
+++ /dev/null
@@ -1 +0,0 @@
-TODO...
diff --git a/docs/explanation.md b/docs/explanation.md
index 8ba6c61..4aab8c6 100644
--- a/docs/explanation.md
+++ b/docs/explanation.md
@@ -10,7 +10,7 @@
code 等参数
- ==state== 用于在请求和回调之间维护状态,主要用于防止跨站请求伪造(CSRF)攻击
- ==source== 支持的第三方客户端,比如:GitHub、LinuxDo 等
-- ==sid== 第三方客户端的用户 ID。以下是关于各平台的 sid 存储逻辑:
+- ==sid== 第三方客户端的用户 ID
!!! warning
diff --git a/docs/index.md b/docs/index.md
index f435039..c1423d9 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,22 +1,61 @@
-
-
🔐
-
+# FastAPI OAuth 2.0
----
+在 FastAPI 中异步授权 OAuth 2.0 客户端
-**Documentation
-**: https://fastapi-practices.github.io/fastapi-oauth20
+## Features
-**Source Code
-**: https://github.com/fastapi-practices/fastapi-oauth20
+- **异步支持** - 使用 async/await 构建以获得最佳性能
+- **令牌管理** - 内置令牌刷新和吊销
+- **FastAPI 集成** - 用于回调处理的无缝依赖注入
+- **类型安全** - 完整类型提示
+- **错误处理** - OAuth2 错误的综合异常层次结构
----
+## Quick Start
-在 FastAPI 中异步授权 OAuth 2.0 客户端
+### Installation
+
+```bash
+pip install fastapi-oauth20
+```
+
+### Basic Usage
+
+```python
+from fastapi import FastAPI, Depends
+from fastapi_oauth20 import GitHubOAuth20, FastAPIOAuth20
+from fastapi.responses import RedirectResponse
+import secrets
+
+app = FastAPI()
+
+# 定义重定向地址
+redirect_uri = "http://localhost:8000/auth/github/callback"
+
+# 初始化 GitHub OAuth2 客户端
+github_client = GitHubOAuth20(
+ client_id="your_github_client_id",
+ client_secret="your_github_client_secret"
+)
+
+# 创建 FastAPI OAuth2 依赖项
+github_oauth = FastAPIOAuth20(
+ client=github_client,
+ redirect_uri=redirect_uri
+)
+
+
+@app.get("/auth/github")
+async def github_auth():
+ auth_url = await github_client.get_authorization_url(redirect_uri=redirect_uri)
+ return RedirectResponse(url=auth_url)
-我们的目标是集成多个 CN 第三方客户端,敬请期待(🐦)...
-你可以在 [客户端状态](status.md) 获取当前集成情况
+@app.get("/auth/github/callback")
+async def github_callback(oauth_result: tuple = Depends(github_oauth)):
+ token_data, state = oauth_result
+ user_info = await github_client.get_userinfo(token_data["access_token"])
+ return {"user": user_info}
+```
## 互动
diff --git a/docs/install.md b/docs/install.md
index e39d606..98eab38 100644
--- a/docs/install.md
+++ b/docs/install.md
@@ -18,9 +18,3 @@
```sh
uv add fastapi-oauth20
```
-
-=== ":simple-pdm: pdm"
-
- ```sh
- pdm add fastapi-oauth20
- ```
diff --git a/docs/status.md b/docs/status.md
index 0bba581..ba32df0 100644
--- a/docs/status.md
+++ b/docs/status.md
@@ -4,7 +4,7 @@
对于强制要求【实名 + 人脸认证】的平台,植入变得困难,所以它们不会很快到来
-## END
+## FINISHED
- [x] [LinuxDo](clients/linuxdo.md)
- [x] [GitHub](clients/github.md)
diff --git a/docs/usage.md b/docs/usage.md
index 6096c2a..9ac34ad 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -1 +1,278 @@
-TODO...
+# 使用指南
+
+本指南介绍如何将 FastAPI OAuth2.0 库与各种 OAuth2 提供程序一起使用。
+
+## 基本用法
+
+### 1. 选择 OAuth2 提供商并初始化客户端
+
+```python
+from fastapi_oauth20 import GitHubOAuth20, GoogleOAuth20, FastAPIOAuth20
+from fastapi import FastAPI, Depends
+from fastapi.responses import RedirectResponse
+import secrets
+
+app = FastAPI()
+
+# 初始化 GitHub OAuth2 客户端
+github_client = GitHubOAuth20(
+ client_id="your_github_client_id",
+ client_secret="your_github_client_secret"
+)
+
+# 初始化 Google OAuth2 客户端
+google_client = GoogleOAuth20(
+ client_id="your_google_client_id",
+ client_secret="your_google_client_secret"
+)
+```
+
+### 2. 创建 FastAPI OAuth2 依赖
+
+```python
+# 创建 FastAPI OAuth2 依赖
+github_oauth = FastAPIOAuth20(
+ client=github_client,
+ redirect_uri="http://localhost:8000/auth/github/callback"
+)
+
+google_oauth = FastAPIOAuth20(
+ client=google_client,
+ redirect_uri="http://localhost:8000/auth/google/callback"
+)
+```
+
+### 3. 定义授权端点
+
+```python
+@app.get("/auth/{provider}")
+async def auth_provider(provider: str):
+ """重定向用户到 OAuth2 提供商进行授权"""
+
+ # 生成安全的 state 参数用于 CSRF 保护
+ state = secrets.token_urlsafe(32)
+
+ if provider == "github":
+ auth_url = await github_client.get_authorization_url(
+ redirect_uri="http://localhost:8000/auth/github/callback",
+ state=state
+ )
+ elif provider == "google":
+ auth_url = await google_client.get_authorization_url(
+ redirect_uri="http://localhost:8000/auth/google/callback",
+ state=state
+ )
+ else:
+ raise HTTPException(status_code=404, detail="不支持的提供商")
+
+ return RedirectResponse(url=auth_url)
+```
+
+### 4. 处理 OAuth 回调
+
+```python
+from typing import Tuple, Dict, Any
+
+
+@app.get("/auth/github/callback")
+async def github_callback(
+ oauth_result: Tuple[Dict[str, Any], str] = Depends(github_oauth)
+):
+ """处理 GitHub OAuth 回调"""
+ token_data, state = oauth_result
+
+ # 获取用户信息
+ user_info = await github_client.get_userinfo(token_data["access_token"])
+
+ return {
+ "user": user_info,
+ "access_token": token_data["access_token"],
+ "state": state
+ }
+
+
+@app.get("/auth/google/callback")
+async def google_callback(
+ oauth_result: Tuple[Dict[str, Any], str] = Depends(google_oauth)
+):
+ """处理 Google OAuth 回调"""
+ token_data, state = oauth_result
+
+ # 获取用户信息
+ user_info = await google_client.get_userinfo(token_data["access_token"])
+
+ return {
+ "user": user_info,
+ "access_token": token_data["access_token"],
+ "state": state
+ }
+```
+
+## 高级用法
+
+### PKCE (Proof Key for Code Exchange)
+
+对于公共客户端(移动应用、SPA),使用 PKCE 增强安全性:
+
+```python
+import base64
+import hashlib
+import secrets
+
+
+def generate_pkce_challenge():
+ """生成 PKCE 代码验证器和挑战码"""
+ code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
+ code_challenge = base64.urlsafe_b64encode(
+ hashlib.sha256(code_verifier.encode('utf-8')).digest()
+ ).decode('utf-8').rstrip('=')
+ return code_verifier, code_challenge
+
+
+@app.get("/auth/github/pkce")
+async def github_auth_pkce():
+ """GitHub OAuth with PKCE"""
+ code_verifier, code_challenge = generate_pkce_challenge()
+ state = secrets.token_urlsafe(32)
+
+ # 在实际应用中,应该将 code_verifier 和 state 存储在会话或数据库中
+ # 这里为了演示目的简化处理
+
+ auth_url = await github_client.get_authorization_url(
+ redirect_uri="http://localhost:8000/auth/github/callback",
+ state=state,
+ code_challenge=code_challenge,
+ code_challenge_method="S256"
+ )
+
+ return RedirectResponse(url=auth_url)
+```
+
+### 令牌刷新
+
+某些提供商支持刷新令牌来延长会话:
+
+```python
+@app.post("/auth/refresh")
+async def refresh_token(refresh_token: str, provider: str):
+ """使用刷新令牌获取新的访问令牌"""
+
+ if provider == "github":
+ # GitHub 不支持 OAuth 刷新令牌
+ raise HTTPException(status_code=400, detail="GitHub 不支持令牌刷新")
+ elif provider == "google":
+ try:
+ new_tokens = await google_client.refresh_token(refresh_token)
+ return {"access_token": new_tokens["access_token"]}
+ except Exception as e:
+ raise HTTPException(status_code=400, detail=str(e))
+```
+
+### 令牌撤销
+
+用户登出时撤销令牌:
+
+```python
+@app.post("/auth/revoke")
+async def revoke_token(access_token: str, provider: str):
+ """撤销访问令牌"""
+
+ if provider == "google":
+ try:
+ await google_client.revoke_token(access_token)
+ return {"message": "令牌撤销成功"}
+ except Exception as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ else:
+ raise HTTPException(status_code=400, detail="该提供商不支持令牌撤销")
+```
+
+## 错误处理
+
+库提供了全面的错误处理:
+
+```python
+from fastapi_oauth20.errors import (
+ OAuth20AuthorizeCallbackError,
+ AccessTokenError,
+ GetUserInfoError
+)
+
+
+@app.exception_handler(OAuth20AuthorizeCallbackError)
+async def oauth_callback_error_handler(request: Request, exc: OAuth20AuthorizeCallbackError):
+ """处理 OAuth 回调错误"""
+ return {
+ "error": "OAuth 授权失败",
+ "detail": exc.detail,
+ "status_code": exc.status_code
+ }
+
+
+@app.exception_handler(AccessTokenError)
+async def access_token_error_handler(request: Request, exc: AccessTokenError):
+ """处理访问令牌错误"""
+ return {
+ "error": "访问令牌交换失败",
+ "detail": exc.msg
+ }
+```
+
+## 完整示例
+
+```python
+from fastapi import FastAPI, Depends, HTTPException
+from fastapi.responses import RedirectResponse
+from fastapi_oauth20 import GitHubOAuth20, FastAPIOAuth20
+from fastapi_oauth20.errors import OAuth20AuthorizeCallbackError
+import secrets
+
+app = FastAPI()
+
+# 初始化客户端
+github_client = GitHubOAuth20(
+ client_id="your_client_id",
+ client_secret="your_client_secret"
+)
+
+# 创建依赖
+github_oauth = FastAPIOAuth20(
+ client=github_client,
+ redirect_uri="http://localhost:8000/auth/github/callback"
+)
+
+
+@app.get("/auth/github")
+async def github_auth():
+ """GitHub 授权入口"""
+ state = secrets.token_urlsafe(32)
+ auth_url = await github_client.get_authorization_url(
+ redirect_uri="http://localhost:8000/auth/github/callback",
+ state=state
+ )
+ return RedirectResponse(url=auth_url)
+
+
+@app.get("/auth/github/callback")
+async def github_callback(
+ oauth_result: tuple = Depends(github_oauth)
+):
+ """GitHub 授权回调"""
+ token_data, state = oauth_result
+ user_info = await github_client.get_userinfo(token_data["access_token"])
+ return {"user": user_info}
+
+
+# 错误处理
+@app.exception_handler(OAuth20AuthorizeCallbackError)
+async def oauth_error_handler(request, exc):
+ return {"error": "授权失败", "detail": exc.detail}
+```
+
+## 注意事项
+
+1. **安全性**: 始终使用 HTTPS 端点
+2. **状态管理**: 使用安全的 state 参数防止 CSRF 攻击
+3. **令牌存储**: 安全地存储访问令牌和刷新令牌
+4. **错误处理**: 妥善处理各种 OAuth2 错误场景
+5. **作用域**: 只请求必要的作用域,遵循最小权限原则
diff --git a/fastapi_oauth20/__init__.py b/fastapi_oauth20/__init__.py
index d1ea7b6..9112766 100644
--- a/fastapi_oauth20/__init__.py
+++ b/fastapi_oauth20/__init__.py
@@ -7,5 +7,6 @@
from .clients.linuxdo import LinuxDoOAuth20 as LinuxDoOAuth20
from .clients.oschina import OSChinaOAuth20 as OSChinaOAuth20
from .integrations.fastapi import FastAPIOAuth20 as FastAPIOAuth20
+from .integrations.fastapi import OAuth20AuthorizeCallbackError as OAuth20AuthorizeCallbackError
__version__ = '0.0.1'
diff --git a/fastapi_oauth20/clients/feishu.py b/fastapi_oauth20/clients/feishu.py
index 95ed41f..cd4b001 100644
--- a/fastapi_oauth20/clients/feishu.py
+++ b/fastapi_oauth20/clients/feishu.py
@@ -7,7 +7,15 @@
class FeiShuOAuth20(OAuth20Base):
+ """FeiShu (Lark) OAuth2 client implementation."""
+
def __init__(self, client_id: str, client_secret: str):
+ """
+ Initialize FeiShu OAuth2 client.
+
+ :param client_id: FeiShu app client ID from the FeiShu developer console.
+ :param client_secret: FeiShu app client secret from the FeiShu developer console.
+ """
super().__init__(
client_id=client_id,
client_secret=client_secret,
@@ -22,7 +30,12 @@ def __init__(self, client_id: str, client_secret: str):
)
async def get_userinfo(self, access_token: str) -> dict:
- """Get user info from FeiShu"""
+ """
+ Retrieve user information from FeiShu API.
+
+ :param access_token: Valid FeiShu access token with contact:user scopes.
+ :return:
+ """
headers = {'Authorization': f'Bearer {access_token}'}
async with httpx.AsyncClient() as client:
response = await client.get('https://passport.feishu.cn/suite/passport/oauth/userinfo', headers=headers)
diff --git a/fastapi_oauth20/clients/gitee.py b/fastapi_oauth20/clients/gitee.py
index 8265e28..eac8c0a 100644
--- a/fastapi_oauth20/clients/gitee.py
+++ b/fastapi_oauth20/clients/gitee.py
@@ -7,7 +7,15 @@
class GiteeOAuth20(OAuth20Base):
+ """Gitee OAuth2 client implementation."""
+
def __init__(self, client_id: str, client_secret: str):
+ """
+ Initialize Gitee OAuth2 client.
+
+ :param client_id: Gitee OAuth application client ID.
+ :param client_secret: Gitee OAuth application client secret.
+ """
super().__init__(
client_id=client_id,
client_secret=client_secret,
@@ -18,7 +26,12 @@ def __init__(self, client_id: str, client_secret: str):
)
async def get_userinfo(self, access_token: str) -> dict:
- """Get user info from Gitee"""
+ """
+ Retrieve user information from Gitee API.
+
+ :param access_token: Valid Gitee access token with user_info scope.
+ :return:
+ """
headers = {'Authorization': f'Bearer {access_token}'}
async with httpx.AsyncClient() as client:
response = await client.get('https://gitee.com/api/v5/user', headers=headers)
diff --git a/fastapi_oauth20/clients/github.py b/fastapi_oauth20/clients/github.py
index 40bc6ce..49b95c9 100644
--- a/fastapi_oauth20/clients/github.py
+++ b/fastapi_oauth20/clients/github.py
@@ -7,7 +7,15 @@
class GitHubOAuth20(OAuth20Base):
+ """GitHub OAuth2 client implementation."""
+
def __init__(self, client_id: str, client_secret: str):
+ """
+ Initialize GitHub OAuth2 client.
+
+ :param client_id: GitHub OAuth App client ID.
+ :param client_secret: GitHub OAuth App client secret.
+ """
super().__init__(
client_id=client_id,
client_secret=client_secret,
@@ -17,7 +25,12 @@ def __init__(self, client_id: str, client_secret: str):
)
async def get_userinfo(self, access_token: str) -> dict:
- """Get user info from GitHub"""
+ """
+ Retrieve user information from GitHub API.
+
+ :param access_token: Valid GitHub access token with appropriate scopes.
+ :return:
+ """
headers = {'Authorization': f'Bearer {access_token}'}
async with httpx.AsyncClient(headers=headers) as client:
response = await client.get('https://api.github.com/user')
diff --git a/fastapi_oauth20/clients/google.py b/fastapi_oauth20/clients/google.py
index 8f5c151..0a7a148 100644
--- a/fastapi_oauth20/clients/google.py
+++ b/fastapi_oauth20/clients/google.py
@@ -7,7 +7,15 @@
class GoogleOAuth20(OAuth20Base):
+ """Google OAuth2 client implementation."""
+
def __init__(self, client_id: str, client_secret: str):
+ """
+ Initialize Google OAuth2 client.
+
+ :param client_id: Google OAuth 2.0 client ID from Google Cloud Console.
+ :param client_secret: Google OAuth 2.0 client secret from Google Cloud Console.
+ """
super().__init__(
client_id=client_id,
client_secret=client_secret,
@@ -19,7 +27,12 @@ def __init__(self, client_id: str, client_secret: str):
)
async def get_userinfo(self, access_token: str) -> dict:
- """Get user info from Google"""
+ """
+ Retrieve user information from Google OAuth2 API.
+
+ :param access_token: Valid Google access token with appropriate scopes.
+ :return:
+ """
headers = {'Authorization': f'Bearer {access_token}'}
async with httpx.AsyncClient() as client:
response = await client.get('https://www.googleapis.com/oauth2/v1/userinfo', headers=headers)
diff --git a/fastapi_oauth20/clients/linuxdo.py b/fastapi_oauth20/clients/linuxdo.py
index fc65d98..6adb0e7 100644
--- a/fastapi_oauth20/clients/linuxdo.py
+++ b/fastapi_oauth20/clients/linuxdo.py
@@ -7,7 +7,15 @@
class LinuxDoOAuth20(OAuth20Base):
+ """Linux.do OAuth2 client implementation."""
+
def __init__(self, client_id: str, client_secret: str):
+ """
+ Initialize Linux.do OAuth2 client.
+
+ :param client_id: Linux.do OAuth application client ID.
+ :param client_secret: Linux.do OAuth application client secret.
+ """
super().__init__(
client_id=client_id,
client_secret=client_secret,
@@ -18,7 +26,12 @@ def __init__(self, client_id: str, client_secret: str):
)
async def get_userinfo(self, access_token: str) -> dict:
- """Get user info from Linux Do"""
+ """
+ Retrieve user information from Linux.do API.
+
+ :param access_token: Valid Linux.do access token.
+ :return:
+ """
headers = {'Authorization': f'Bearer {access_token}'}
async with httpx.AsyncClient() as client:
response = await client.get('https://connect.linux.do/api/user', headers=headers)
diff --git a/fastapi_oauth20/clients/oschina.py b/fastapi_oauth20/clients/oschina.py
index c34fde2..91edfb7 100644
--- a/fastapi_oauth20/clients/oschina.py
+++ b/fastapi_oauth20/clients/oschina.py
@@ -7,7 +7,15 @@
class OSChinaOAuth20(OAuth20Base):
+ """OSChina OAuth2 client implementation."""
+
def __init__(self, client_id: str, client_secret: str):
+ """
+ Initialize OSChina OAuth2 client.
+
+ :param client_id: OSChina OAuth application client ID.
+ :param client_secret: OSChina OAuth application client secret.
+ """
super().__init__(
client_id=client_id,
client_secret=client_secret,
@@ -17,7 +25,12 @@ def __init__(self, client_id: str, client_secret: str):
)
async def get_userinfo(self, access_token: str) -> dict:
- """Get user info from OSChina"""
+ """
+ Retrieve user information from OSChina API.
+
+ :param access_token: Valid OSChina access token.
+ :return:
+ """
headers = {'Authorization': f'Bearer {access_token}'}
async with httpx.AsyncClient() as client:
response = await client.get('https://www.oschina.net/action/openapi/user', headers=headers)
diff --git a/fastapi_oauth20/errors.py b/fastapi_oauth20/errors.py
index d28c0cd..e677ace 100644
--- a/fastapi_oauth20/errors.py
+++ b/fastapi_oauth20/errors.py
@@ -4,54 +4,65 @@
class OAuth20BaseError(Exception):
- """The oauth2 base error."""
+ """Base exception class for all OAuth2-related errors."""
msg: str
def __init__(self, msg: str) -> None:
+ """
+ Initialize base OAuth2 error.
+
+ :param msg: Human-readable error message describing the OAuth2 error.
+ """
self.msg = msg
super().__init__(msg)
class OAuth20RequestError(OAuth20BaseError):
- """OAuth2 httpx request error"""
+ """Base exception for OAuth2 HTTP request errors."""
def __init__(self, msg: str, response: httpx.Response | None = None) -> None:
+ """
+ Initialize OAuth2 request error.
+
+ :param msg: Human-readable error message describing the request error.
+ :param response: The HTTP response object that caused the error (if available).
+ """
self.response = response
super().__init__(msg)
class HTTPXOAuth20Error(OAuth20RequestError):
- """OAuth2 error for httpx raise for status"""
+ """Exception raised when httpx raises an HTTP status error."""
pass
class AccessTokenError(OAuth20RequestError):
- """Error raised when get access token fail."""
+ """Exception raised when access token exchange fails."""
pass
class RefreshTokenError(OAuth20RequestError):
- """Refresh token error when refresh token fail."""
+ """Exception raised when refresh token operation fails."""
pass
class RevokeTokenError(OAuth20RequestError):
- """Revoke token error when revoke token fail."""
+ """Exception raised when token revocation fails."""
pass
class GetUserInfoError(OAuth20RequestError):
- """Get user info error when get user info fail."""
+ """Exception raised when user info retrieval fails."""
pass
class RedirectURIError(OAuth20RequestError):
- """Redirect URI set error"""
+ """Exception raised for redirect URI configuration errors."""
pass
diff --git a/fastapi_oauth20/integrations/fastapi.py b/fastapi_oauth20/integrations/fastapi.py
index ec7e73c..5eb190c 100644
--- a/fastapi_oauth20/integrations/fastapi.py
+++ b/fastapi_oauth20/integrations/fastapi.py
@@ -11,7 +11,7 @@
class OAuth20AuthorizeCallbackError(HTTPException, OAuth20BaseError):
- """The OAuth2 authorization callback error."""
+ """Exception raised during OAuth2 authorization callback processing in FastAPI."""
def __init__(
self,
@@ -20,11 +20,21 @@ def __init__(
headers: dict[str, str] | None = None,
response: httpx.Response | None = None,
) -> None:
+ """
+ Initialize OAuth2 callback error.
+
+ :param status_code: HTTP status code to return in the response.
+ :param detail: Error detail message describing what went wrong.
+ :param headers: Additional HTTP headers to include in the error response.
+ :param response: The original HTTP response that caused the error (if any).
+ """
self.response = response
super().__init__(status_code=status_code, detail=detail, headers=headers)
class FastAPIOAuth20:
+ """FastAPI dependency for handling OAuth2 authorization callbacks."""
+
def __init__(
self,
client: OAuth20Base,
@@ -32,10 +42,10 @@ def __init__(
redirect_uri: str | None = None,
):
"""
- OAuth2 authorization callback dependency injection
+ Initialize FastAPI OAuth2 callback handler.
- :param client: A client base on OAuth20Base.
- :param redirect_uri: OAuth2 callback full URL.
+ :param client: An OAuth2 client instance that inherits from OAuth20Base.
+ :param redirect_uri: The full callback URL where the OAuth2 provider redirects after authorization. Must match the URL registered with the OAuth2 provider.
"""
self.client = client
self.redirect_uri = redirect_uri
@@ -48,6 +58,15 @@ async def __call__(
code_verifier: str | None = None,
error: str | None = None,
) -> tuple[dict[str, Any], str | None]:
+ """
+ Process OAuth2 callback request and exchange authorization code for access token.
+
+ :param request: The FastAPI Request object containing callback parameters.
+ :param code: The authorization code received from the OAuth2 provider (extracted from query parameters).
+ :param state: The state parameter for CSRF protection (extracted from query parameters).
+ :param code_verifier: PKCE code verifier if PKCE was used in the authorization request.
+ :param error: Error parameter from OAuth2 provider if authorization was denied or failed.
+ """
if code is None or error is not None:
raise OAuth20AuthorizeCallbackError(
status_code=400,
diff --git a/fastapi_oauth20/oauth20.py b/fastapi_oauth20/oauth20.py
index 2a1c551..9b8ac91 100644
--- a/fastapi_oauth20/oauth20.py
+++ b/fastapi_oauth20/oauth20.py
@@ -33,17 +33,17 @@ def __init__(
revoke_token_endpoint_basic_auth: bool = False,
):
"""
- Base OAuth2 client.
+ Base OAuth2 client implementing the OAuth 2.0 authorization framework.
:param client_id: The client ID provided by the OAuth2 provider.
:param client_secret: The client secret provided by the OAuth2 provider.
- :param authorize_endpoint: The authorization endpoint URL.
- :param access_token_endpoint: The access token endpoint URL.
- :param refresh_token_endpoint: The refresh token endpoint URL.
- :param revoke_token_endpoint: The revoke token endpoint URL.
- :param default_scopes:
- :param token_endpoint_basic_auth:
- :param revoke_token_endpoint_basic_auth:
+ :param authorize_endpoint: The authorization endpoint URL where users are redirected to grant access.
+ :param access_token_endpoint: The token endpoint URL for exchanging authorization codes for access tokens.
+ :param refresh_token_endpoint: The token endpoint URL for refreshing expired access tokens using refresh tokens.
+ :param revoke_token_endpoint: The endpoint URL for revoking access tokens or refresh tokens.
+ :param default_scopes: Default list of OAuth scopes to request if none are specified.
+ :param token_endpoint_basic_auth: Whether to use HTTP Basic Authentication for token endpoint requests.
+ :param revoke_token_endpoint_basic_auth: Whether to use HTTP Basic Authentication for revoke endpoint requests.
"""
self.client_id = client_id
self.client_secret = client_secret
@@ -69,14 +69,14 @@ async def get_authorization_url(
**kwargs,
) -> str:
"""
- Get authorization url for given.
-
- :param redirect_uri: redirected after authorization.
- :param state: An opaque value used by the client to maintain state between the request and the callback.
- :param scope: The scopes to be requested.
- :param code_challenge: [PKCE](https://datatracker.ietf.org/doc/html/rfc7636) code challenge.
- :param code_challenge_method: [PKCE](https://datatracker.ietf.org/doc/html/rfc7636) code challenge method.
- :param kwargs: Additional arguments passed to the OAuth2 client.
+ Generate OAuth2 authorization URL for redirecting users to grant access.
+
+ :param redirect_uri: The URL where the OAuth2 provider will redirect after authorization.
+ :param state: An opaque value used by the client to maintain state between the request and callback, preventing CSRF attacks.
+ :param scope: The list of OAuth scopes to request. If None, uses default_scopes from initialization.
+ :param code_challenge: PKCE code challenge generated from code_verifier using the specified method.
+ :param code_challenge_method: PKCE code challenge method, either 'plain' or 'S256' (recommended).
+ :param kwargs: Additional query parameters to include in the authorization URL.
:return:
"""
params = {
@@ -105,11 +105,11 @@ async def get_authorization_url(
async def get_access_token(self, code: str, redirect_uri: str, code_verifier: str | None = None) -> dict[str, Any]:
"""
- Get access token for given.
+ Exchange authorization code for access token.
- :param code: The authorization code.
- :param redirect_uri: redirect uri after authorization.
- :param code_verifier: the code verifier for the [PKCE](https://datatracker.ietf.org/doc/html/rfc7636).
+ :param code: The authorization code received from the OAuth2 provider callback.
+ :param redirect_uri: The exact redirect URI used in the authorization request (must match).
+ :param code_verifier: The PKCE code verifier used to generate the code challenge (required if PKCE was used).
:return:
"""
data = {
@@ -139,9 +139,9 @@ async def get_access_token(self, code: str, redirect_uri: str, code_verifier: st
async def refresh_token(self, refresh_token: str) -> dict[str, Any]:
"""
- Get new access token by refresh token.
+ Refresh an access token using a refresh token.
- :param refresh_token: The refresh token.
+ :param refresh_token: The refresh token received from the initial token exchange.
:return:
"""
if self.refresh_token_endpoint is None:
@@ -170,10 +170,10 @@ async def refresh_token(self, refresh_token: str) -> dict[str, Any]:
async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None:
"""
- Revoke a token.
+ Revoke an access token or refresh token.
- :param token: A token or refresh token to revoke.
- :param token_type_hint: Usually either `token` or `refresh_token`.
+ :param token: The access token or refresh token to revoke.
+ :param token_type_hint: Optional hint to the server about the token type ('access_token' or 'refresh_token').
:return:
"""
if self.revoke_token_endpoint is None:
@@ -184,7 +184,13 @@ async def revoke_token(self, token: str, token_type_hint: str | None = None) ->
if token_type_hint is not None:
data.update({'token_type_hint': token_type_hint})
- async with httpx.AsyncClient() as client:
+ auth = None
+ if not self.revoke_token_endpoint_basic_auth:
+ data.update({'client_id': self.client_id, 'client_secret': self.client_secret})
+ else:
+ auth = httpx.BasicAuth(self.client_id, self.client_secret)
+
+ async with httpx.AsyncClient(auth=auth) as client:
response = await client.post(
self.revoke_token_endpoint,
data=data,
@@ -194,7 +200,12 @@ async def revoke_token(self, token: str, token_type_hint: str | None = None) ->
@staticmethod
def raise_httpx_oauth20_errors(response: httpx.Response) -> None:
- """Raise HTTPXOAuth20Error if the response is invalid"""
+ """
+ Check HTTP response and raise appropriate OAuth2 errors for invalid responses.
+
+ :param response: The HTTP response object to validate.
+ :return:
+ """
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
@@ -204,7 +215,13 @@ def raise_httpx_oauth20_errors(response: httpx.Response) -> None:
@staticmethod
def get_json_result(response: httpx.Response, *, err_class: type[OAuth20RequestError]) -> dict[str, Any]:
- """Get response json"""
+ """
+ Parse JSON response and handle JSON decoding errors.
+
+ :param response: The HTTP response object containing JSON data.
+ :param err_class: The specific OAuth2RequestError subclass to raise on JSON parsing failure.
+ :return:
+ """
try:
return cast(dict[str, Any], response.json())
except json.JSONDecodeError as e:
@@ -213,9 +230,9 @@ def get_json_result(response: httpx.Response, *, err_class: type[OAuth20RequestE
@abc.abstractmethod
async def get_userinfo(self, access_token: str) -> dict[str, Any]:
"""
- Get user info from the API provider
+ Retrieve user information from the OAuth2 provider.
- :param access_token: The access token.
+ :param access_token: Valid access token to authenticate the request to the provider's user info endpoint.
:return:
"""
raise NotImplementedError()
diff --git a/mkdocs.yml b/mkdocs.yml
index 795fb60..3fb5580 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -10,7 +10,6 @@ nav:
- 安装: install.md
- 名词解释: explanation.md
- 用法: usage.md
- - 高级用法: advanced.md
- 集成:
- FastAPI: fastapi.md
- 客户端状态: status.md
diff --git a/pyproject.toml b/pyproject.toml
index b4f3560..8c80a0e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,7 @@ dev = [
"fastapi>=0.119.0",
"respx>=0.22.0",
"ty>=0.0.1a23",
+ "click==8.2.1",
]
lint = [
"pre-commit>=4.3.0",
diff --git a/requirements.txt b/requirements.txt
index d2617a2..6c0823f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,8 +15,11 @@ certifi==2025.10.5
# httpx
cfgv==3.4.0
# via pre-commit
+click==8.2.1
colorama==0.4.6 ; sys_platform == 'win32'
- # via pytest
+ # via
+ # click
+ # pytest
distlib==0.4.0
# via virtualenv
exceptiongroup==1.3.0 ; python_full_version < '3.11'
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29..56fafa5 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
diff --git a/tests/clients/__init__.py b/tests/clients/__init__.py
new file mode 100644
index 0000000..56fafa5
--- /dev/null
+++ b/tests/clients/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
diff --git a/tests/clients/test_feishu.py b/tests/clients/test_feishu.py
new file mode 100644
index 0000000..a68b8e7
--- /dev/null
+++ b/tests/clients/test_feishu.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import httpx
+import pytest
+import respx
+
+from fastapi_oauth20.clients.feishu import FeiShuOAuth20
+from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error
+from fastapi_oauth20.oauth20 import OAuth20Base
+from tests.conftest import (
+ INVALID_TOKEN,
+ TEST_ACCESS_TOKEN,
+ TEST_CLIENT_ID,
+ TEST_CLIENT_SECRET,
+ create_mock_user_data,
+ mock_user_info_response,
+)
+
+# Constants specific to this test file
+CUSTOM_CLIENT_ID = 'custom_id'
+CUSTOM_CLIENT_SECRET = 'custom_secret'
+
+
+class TestFeiShuOAuth20:
+ """Test FeiShu OAuth2 client functionality."""
+
+ def test_feishu_client_initialization(self, feishu_client):
+ """Test FeiShu client initialization with correct parameters."""
+ assert feishu_client.client_id == TEST_CLIENT_ID
+ assert feishu_client.client_secret == TEST_CLIENT_SECRET
+ assert feishu_client.authorize_endpoint == 'https://passport.feishu.cn/suite/passport/oauth/authorize'
+ assert feishu_client.access_token_endpoint == 'https://passport.feishu.cn/suite/passport/oauth/token'
+ assert feishu_client.refresh_token_endpoint == 'https://passport.feishu.cn/suite/passport/oauth/authorize'
+ assert feishu_client.default_scopes == [
+ 'contact:user.employee_id:readonly',
+ 'contact:user.base:readonly',
+ 'contact:user.email:readonly',
+ ]
+
+ def test_feishu_client_initialization_with_custom_credentials(self):
+ """Test FeiShu client initialization with custom credentials."""
+ client = FeiShuOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET)
+ assert client.client_id == CUSTOM_CLIENT_ID
+ assert client.client_secret == CUSTOM_CLIENT_SECRET
+
+ def test_feishu_client_inheritance(self, feishu_client):
+ """Test that FeiShu client properly inherits from OAuth20Base."""
+ assert isinstance(feishu_client, OAuth20Base)
+
+ def test_feishu_client_scopes_are_lists(self, feishu_client):
+ """Test that default scopes are properly configured as lists."""
+ assert isinstance(feishu_client.default_scopes, list)
+ assert len(feishu_client.default_scopes) == 3
+ assert all(isinstance(scope, str) for scope in feishu_client.default_scopes)
+
+ def test_feishu_client_endpoint_urls(self):
+ """Test that FeiShu client uses correct endpoint URLs."""
+ client = FeiShuOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET)
+
+ # Test that endpoints are correctly set without hardcoding them in tests
+ assert client.authorize_endpoint.endswith('/suite/passport/oauth/authorize')
+ assert client.access_token_endpoint.endswith('/suite/passport/oauth/token')
+ assert client.refresh_token_endpoint.endswith('/suite/passport/oauth/authorize')
+
+ # Test that all endpoints use the correct domain
+ for endpoint in [client.authorize_endpoint, client.access_token_endpoint, client.refresh_token_endpoint]:
+ assert 'passport.feishu.cn' in endpoint
+
+ def test_feishu_client_multiple_instances(self):
+ """Test that multiple FeiShu client instances work independently."""
+ client1 = FeiShuOAuth20('client1', 'secret1')
+ client2 = FeiShuOAuth20('client2', 'secret2')
+
+ assert client1.client_id != client2.client_id
+ assert client1.client_secret != client2.client_secret
+ assert client1.authorize_endpoint == client2.authorize_endpoint
+ assert client1.access_token_endpoint == client2.access_token_endpoint
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_success(self, feishu_client):
+ """Test successful user info retrieval from FeiShu API."""
+ mock_user_data = create_mock_user_data('feishu')
+ mock_user_info_response(
+ respx,
+ {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'},
+ mock_user_data,
+ )
+
+ result = await feishu_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == mock_user_data
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_with_different_access_token(self, feishu_client):
+ """Test user info retrieval with different access tokens."""
+ mock_user_data = create_mock_user_data('feishu', user_id='user_789', name='Another User')
+ mock_user_info_response(
+ respx,
+ {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'},
+ mock_user_data,
+ )
+
+ result = await feishu_client.get_userinfo('different_token')
+ assert result == mock_user_data
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_empty_response(self, feishu_client):
+ """Test handling of empty user info response."""
+ mock_user_info_response(
+ respx, {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'}, {}
+ )
+
+ result = await feishu_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == {}
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_partial_data(self, feishu_client):
+ """Test handling of partial user info response."""
+ partial_data = {'user_id': 'test_user', 'name': 'Test User'}
+ mock_user_info_response(
+ respx,
+ {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'},
+ partial_data,
+ )
+
+ result = await feishu_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == partial_data
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_authorization_header(self, feishu_client):
+ """Test that authorization header is correctly formatted."""
+ mock_user_data = {'user_id': 'test_user'}
+ route = mock_user_info_response(
+ respx,
+ {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'},
+ mock_user_data,
+ )
+
+ await feishu_client.get_userinfo(TEST_ACCESS_TOKEN)
+
+ # Verify the request was made with correct authorization header
+ assert route.called
+ request = route.calls[0].request
+ assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}'
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_http_error_401(self, feishu_client):
+ """Test handling of 401 HTTP error when getting user info."""
+ respx.get('https://passport.feishu.cn/suite/passport/oauth/userinfo').mock(
+ return_value=httpx.Response(401, text='Unauthorized')
+ )
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await feishu_client.get_userinfo(INVALID_TOKEN)
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_http_error_403(self, feishu_client):
+ """Test handling of 403 HTTP error when getting user info."""
+ respx.get('https://passport.feishu.cn/suite/passport/oauth/userinfo').mock(
+ return_value=httpx.Response(403, text='Forbidden')
+ )
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await feishu_client.get_userinfo(TEST_ACCESS_TOKEN)
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_http_error_500(self, feishu_client):
+ """Test handling of 500 HTTP error when getting user info."""
+ respx.get('https://passport.feishu.cn/suite/passport/oauth/userinfo').mock(
+ return_value=httpx.Response(500, text='Internal Server Error')
+ )
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await feishu_client.get_userinfo(TEST_ACCESS_TOKEN)
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_invalid_json(self, feishu_client):
+ """Test handling of invalid JSON response."""
+ respx.get('https://passport.feishu.cn/suite/passport/oauth/userinfo').mock(
+ return_value=httpx.Response(200, text='invalid json')
+ )
+
+ with pytest.raises(GetUserInfoError):
+ await feishu_client.get_userinfo(TEST_ACCESS_TOKEN)
diff --git a/tests/clients/test_gitee.py b/tests/clients/test_gitee.py
new file mode 100644
index 0000000..7afa4a9
--- /dev/null
+++ b/tests/clients/test_gitee.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import httpx
+import pytest
+import respx
+
+from fastapi_oauth20.clients.gitee import GiteeOAuth20
+from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error
+from fastapi_oauth20.oauth20 import OAuth20Base
+from tests.conftest import (
+ INVALID_TOKEN,
+ TEST_ACCESS_TOKEN,
+ TEST_CLIENT_ID,
+ TEST_CLIENT_SECRET,
+ create_mock_user_data,
+ mock_user_info_response,
+)
+
+# Constants specific to this test file
+CUSTOM_CLIENT_ID = 'custom_id'
+CUSTOM_CLIENT_SECRET = 'custom_secret'
+
+
+class TestGiteeOAuth20:
+ """Test Gitee OAuth2 client functionality."""
+
+ def test_gitee_client_initialization(self, gitee_client):
+ """Test Gitee client initialization with correct parameters."""
+ assert gitee_client.client_id == TEST_CLIENT_ID
+ assert gitee_client.client_secret == TEST_CLIENT_SECRET
+ assert gitee_client.authorize_endpoint == 'https://gitee.com/oauth/authorize'
+ assert gitee_client.access_token_endpoint == 'https://gitee.com/oauth/token'
+ assert gitee_client.refresh_token_endpoint == 'https://gitee.com/oauth/token'
+ assert gitee_client.default_scopes == ['user_info']
+
+ def test_gitee_client_initialization_with_custom_credentials(self):
+ """Test Gitee client initialization with custom credentials."""
+ client = GiteeOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET)
+ assert client.client_id == CUSTOM_CLIENT_ID
+ assert client.client_secret == CUSTOM_CLIENT_SECRET
+
+ def test_gitee_client_inheritance(self, gitee_client):
+ """Test that Gitee client properly inherits from OAuth20Base."""
+ assert isinstance(gitee_client, OAuth20Base)
+
+ def test_gitee_client_scopes_are_lists(self, gitee_client):
+ """Test that default scopes are properly configured as lists."""
+ assert isinstance(gitee_client.default_scopes, list)
+ assert len(gitee_client.default_scopes) == 1
+ assert all(isinstance(scope, str) for scope in gitee_client.default_scopes)
+
+ def test_gitee_client_endpoint_urls(self):
+ """Test that Gitee client uses correct endpoint URLs."""
+ client = GiteeOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET)
+
+ # Test that endpoints are correctly set without hardcoding them in tests
+ assert client.authorize_endpoint.endswith('/oauth/authorize')
+ assert client.access_token_endpoint.endswith('/oauth/token')
+ assert client.refresh_token_endpoint.endswith('/oauth/token')
+
+ # Test that all endpoints use the correct domain
+ for endpoint in [client.authorize_endpoint, client.access_token_endpoint, client.refresh_token_endpoint]:
+ assert 'gitee.com' in endpoint
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_success(self, gitee_client):
+ """Test successful user info retrieval from Gitee API."""
+ mock_user_data = create_mock_user_data('gitee')
+ mock_user_info_response(
+ respx, {'name': 'gitee', 'user_info_url': 'https://gitee.com/api/v5/user'}, mock_user_data
+ )
+
+ result = await gitee_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == mock_user_data
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_authorization_header(self, gitee_client):
+ """Test that authorization header is correctly formatted."""
+ mock_user_data = {'id': 'test_user'}
+ route = mock_user_info_response(
+ respx, {'name': 'gitee', 'user_info_url': 'https://gitee.com/api/v5/user'}, mock_user_data
+ )
+
+ await gitee_client.get_userinfo(TEST_ACCESS_TOKEN)
+
+ # Verify the request was made with correct authorization header
+ assert route.called
+ request = route.calls[0].request
+ assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}'
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_http_error(self, gitee_client):
+ """Test handling of HTTP errors when getting user info."""
+ respx.get('https://gitee.com/api/v5/user').mock(return_value=httpx.Response(401, text='Unauthorized'))
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await gitee_client.get_userinfo(INVALID_TOKEN)
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_empty_response(self, gitee_client):
+ """Test handling of empty user info response."""
+ mock_user_info_response(respx, {'name': 'gitee', 'user_info_url': 'https://gitee.com/api/v5/user'}, {})
+
+ result = await gitee_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == {}
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_partial_data(self, gitee_client):
+ """Test handling of partial user info response."""
+ partial_data = {'id': 123456, 'login': 'testuser'}
+ mock_user_info_response(
+ respx, {'name': 'gitee', 'user_info_url': 'https://gitee.com/api/v5/user'}, partial_data
+ )
+
+ result = await gitee_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == partial_data
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_invalid_json(self, gitee_client):
+ """Test handling of invalid JSON response."""
+ respx.get('https://gitee.com/api/v5/user').mock(return_value=httpx.Response(200, text='invalid json'))
+
+ with pytest.raises(GetUserInfoError):
+ await gitee_client.get_userinfo(TEST_ACCESS_TOKEN)
diff --git a/tests/clients/test_github.py b/tests/clients/test_github.py
new file mode 100644
index 0000000..acd13d2
--- /dev/null
+++ b/tests/clients/test_github.py
@@ -0,0 +1,204 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import httpx
+import pytest
+import respx
+
+from fastapi_oauth20.clients.github import GitHubOAuth20
+from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error
+from fastapi_oauth20.oauth20 import OAuth20Base
+from tests.conftest import (
+ INVALID_TOKEN,
+ TEST_ACCESS_TOKEN,
+ TEST_CLIENT_ID,
+ TEST_CLIENT_SECRET,
+ create_mock_user_data,
+ mock_user_info_response,
+)
+
+# Constants specific to this test file
+CUSTOM_CLIENT_ID = 'custom_id'
+CUSTOM_CLIENT_SECRET = 'custom_secret'
+
+
+class TestGitHubOAuth20:
+ """Test GitHub OAuth2 client functionality."""
+
+ def test_github_client_initialization(self, github_client):
+ """Test GitHub client initialization with correct parameters."""
+ assert github_client.client_id == TEST_CLIENT_ID
+ assert github_client.client_secret == TEST_CLIENT_SECRET
+ assert github_client.authorize_endpoint == 'https://github.com/login/oauth/authorize'
+ assert github_client.access_token_endpoint == 'https://github.com/login/oauth/access_token'
+ assert github_client.default_scopes == ['user', 'user:email']
+
+ def test_github_client_initialization_with_custom_credentials(self):
+ """Test GitHub client initialization with custom credentials."""
+ client = GitHubOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET)
+ assert client.client_id == CUSTOM_CLIENT_ID
+ assert client.client_secret == CUSTOM_CLIENT_SECRET
+
+ def test_github_client_inheritance(self, github_client):
+ """Test that GitHub client properly inherits from OAuth20Base."""
+ assert isinstance(github_client, OAuth20Base)
+
+ def test_github_client_scopes_are_lists(self, github_client):
+ """Test that default scopes are properly configured as lists."""
+ assert isinstance(github_client.default_scopes, list)
+ assert len(github_client.default_scopes) == 2
+ assert all(isinstance(scope, str) for scope in github_client.default_scopes)
+
+ def test_github_client_endpoint_urls(self):
+ """Test that GitHub client uses correct endpoint URLs."""
+ client = GitHubOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET)
+
+ # Test that endpoints are correctly set without hardcoding them in tests
+ assert client.authorize_endpoint.endswith('/login/oauth/authorize')
+ assert client.access_token_endpoint.endswith('/login/oauth/access_token')
+
+ # Test that all endpoints use the correct domain
+ for endpoint in [client.authorize_endpoint, client.access_token_endpoint]:
+ assert 'github.com' in endpoint
+
+ def test_github_client_multiple_instances(self):
+ """Test that multiple GitHub client instances work independently."""
+ client1 = GitHubOAuth20('client1', 'secret1')
+ client2 = GitHubOAuth20('client2', 'secret2')
+
+ assert client1.client_id != client2.client_id
+ assert client1.client_secret != client2.client_secret
+ assert client1.authorize_endpoint == client2.authorize_endpoint
+ assert client1.access_token_endpoint == client2.access_token_endpoint
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_success_with_email(self, github_client):
+ """Test successful user info retrieval from GitHub API with email included."""
+ mock_user_data = create_mock_user_data('github')
+ mock_user_info_response(
+ respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, mock_user_data
+ )
+
+ result = await github_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == mock_user_data
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_success_without_email(self, github_client):
+ """Test successful user info retrieval from GitHub API without email."""
+ mock_user_data = create_mock_user_data('github', email=None)
+ mock_user_info_response(
+ respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, mock_user_data
+ )
+ # Mock emails endpoint
+ emails_data = [{'email': 'test@example.com', 'primary': True}]
+ respx.get('https://api.github.com/user/emails').mock(return_value=httpx.Response(200, json=emails_data))
+
+ result = await github_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result['login'] == mock_user_data['login']
+ assert result['email'] == 'test@example.com'
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_with_different_access_token(self, github_client):
+ """Test user info retrieval with different access tokens."""
+ mock_user_data = create_mock_user_data('github', id=789, login='different_user')
+ mock_user_info_response(
+ respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, mock_user_data
+ )
+
+ result = await github_client.get_userinfo('different_token')
+ assert result == mock_user_data
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_authorization_header(self, github_client):
+ """Test that authorization header is correctly formatted."""
+ mock_user_data = {'id': 'test_user', 'email': 'test@example.com'} # Include email to avoid emails endpoint call
+ route = mock_user_info_response(
+ respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, mock_user_data
+ )
+
+ await github_client.get_userinfo(TEST_ACCESS_TOKEN)
+
+ # Verify the request was made with correct authorization header
+ assert route.called
+ request = route.calls[0].request
+ assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}'
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_http_error_401(self, github_client):
+ """Test handling of 401 HTTP error when getting user info."""
+ respx.get('https://api.github.com/user').mock(return_value=httpx.Response(401, text='Unauthorized'))
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await github_client.get_userinfo(INVALID_TOKEN)
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_http_error_403(self, github_client):
+ """Test handling of 403 HTTP error when getting user info."""
+ respx.get('https://api.github.com/user').mock(return_value=httpx.Response(403, text='Forbidden'))
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await github_client.get_userinfo(TEST_ACCESS_TOKEN)
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_http_error_500(self, github_client):
+ """Test handling of 500 HTTP error when getting user info."""
+ respx.get('https://api.github.com/user').mock(return_value=httpx.Response(500, text='Internal Server Error'))
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await github_client.get_userinfo(TEST_ACCESS_TOKEN)
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_invalid_json(self, github_client):
+ """Test handling of invalid JSON response."""
+ respx.get('https://api.github.com/user').mock(return_value=httpx.Response(200, text='invalid json'))
+
+ with pytest.raises(GetUserInfoError):
+ await github_client.get_userinfo(TEST_ACCESS_TOKEN)
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_empty_response(self, github_client):
+ """Test handling of empty user info response."""
+ mock_user_info_response(respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, {})
+ # Mock emails endpoint since empty response will trigger email lookup
+ emails_data = [{'email': 'test@example.com', 'primary': True}]
+ respx.get('https://api.github.com/user/emails').mock(return_value=httpx.Response(200, json=emails_data))
+
+ result = await github_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result['email'] == 'test@example.com'
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_partial_data(self, github_client):
+ """Test handling of partial user info response."""
+ partial_data = {
+ 'id': 123456,
+ 'login': 'testuser',
+ 'email': 'test@example.com',
+ } # Add email to avoid emails endpoint call
+ mock_user_info_response(respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, partial_data)
+
+ result = await github_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == partial_data
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_rate_limit(self, github_client):
+ """Test handling of GitHub API rate limit."""
+ # GitHub rate limit response
+ rate_limit_response = {
+ 'message': 'API rate limit exceeded for xxx.xxx.xxx.xxx.',
+ 'documentation_url': 'https://docs.github.com/rest/overview/rate-limits-for-the-rest-api',
+ }
+
+ respx.get('https://api.github.com/user').mock(return_value=httpx.Response(403, json=rate_limit_response))
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await github_client.get_userinfo(TEST_ACCESS_TOKEN)
diff --git a/tests/clients/test_google.py b/tests/clients/test_google.py
new file mode 100644
index 0000000..505db53
--- /dev/null
+++ b/tests/clients/test_google.py
@@ -0,0 +1,150 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import httpx
+import pytest
+import respx
+
+from fastapi_oauth20.clients.google import GoogleOAuth20
+from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error
+from fastapi_oauth20.oauth20 import OAuth20Base
+from tests.conftest import (
+ INVALID_TOKEN,
+ TEST_ACCESS_TOKEN,
+ TEST_CLIENT_ID,
+ TEST_CLIENT_SECRET,
+ create_mock_user_data,
+ mock_user_info_response,
+)
+
+# Constants specific to this test file
+CUSTOM_CLIENT_ID = 'custom_id'
+CUSTOM_CLIENT_SECRET = 'custom_secret'
+
+
+class TestGoogleOAuth20:
+ """Test Google OAuth2 client functionality."""
+
+ def test_google_client_initialization(self, google_client):
+ """Test Google client initialization with correct parameters."""
+ assert google_client.client_id == TEST_CLIENT_ID
+ assert google_client.client_secret == TEST_CLIENT_SECRET
+ assert google_client.authorize_endpoint == 'https://accounts.google.com/o/oauth2/v2/auth'
+ assert google_client.access_token_endpoint == 'https://oauth2.googleapis.com/token'
+ assert google_client.refresh_token_endpoint == 'https://oauth2.googleapis.com/token'
+ assert google_client.revoke_token_endpoint == 'https://accounts.google.com/o/oauth2/revoke'
+ assert google_client.default_scopes == ['email', 'openid', 'profile']
+
+ def test_google_client_initialization_with_custom_credentials(self):
+ """Test Google client initialization with custom credentials."""
+ client = GoogleOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET)
+ assert client.client_id == CUSTOM_CLIENT_ID
+ assert client.client_secret == CUSTOM_CLIENT_SECRET
+
+ def test_google_client_inheritance(self, google_client):
+ """Test that Google client properly inherits from OAuth20Base."""
+ assert isinstance(google_client, OAuth20Base)
+
+ def test_google_client_scopes_are_lists(self, google_client):
+ """Test that default scopes are properly configured as lists."""
+ assert isinstance(google_client.default_scopes, list)
+ assert len(google_client.default_scopes) == 3
+ assert all(isinstance(scope, str) for scope in google_client.default_scopes)
+
+ def test_google_client_endpoint_urls(self):
+ """Test that Google client uses correct endpoint URLs."""
+ client = GoogleOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET)
+
+ # Test that endpoints are correctly set without hardcoding them in tests
+ assert client.authorize_endpoint.endswith('/o/oauth2/v2/auth')
+ assert client.access_token_endpoint.endswith('/token')
+ assert client.refresh_token_endpoint.endswith('/token')
+ assert client.revoke_token_endpoint.endswith('/o/oauth2/revoke')
+
+ # Test that all endpoints use the correct domains
+ assert 'accounts.google.com' in client.authorize_endpoint
+ assert 'accounts.google.com' in client.revoke_token_endpoint
+ assert 'oauth2.googleapis.com' in client.access_token_endpoint
+ assert 'oauth2.googleapis.com' in client.refresh_token_endpoint
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_success(self, google_client):
+ """Test successful user info retrieval from Google OAuth2 API."""
+ mock_user_data = create_mock_user_data('google')
+ mock_user_info_response(
+ respx, {'name': 'google', 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo'}, mock_user_data
+ )
+
+ result = await google_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == mock_user_data
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_authorization_header(self, google_client):
+ """Test that authorization header is correctly formatted."""
+ mock_user_data = {'id': 'test_user'}
+ route = mock_user_info_response(
+ respx, {'name': 'google', 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo'}, mock_user_data
+ )
+
+ await google_client.get_userinfo(TEST_ACCESS_TOKEN)
+
+ # Verify the request was made with correct authorization header
+ assert route.called
+ request = route.calls[0].request
+ assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}'
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_http_error(self, google_client):
+ """Test handling of HTTP errors when getting user info."""
+ respx.get('https://www.googleapis.com/oauth2/v1/userinfo').mock(
+ return_value=httpx.Response(401, text='Unauthorized')
+ )
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await google_client.get_userinfo(INVALID_TOKEN)
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_empty_response(self, google_client):
+ """Test handling of empty user info response."""
+ mock_user_info_response(
+ respx, {'name': 'google', 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo'}, {}
+ )
+
+ result = await google_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == {}
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_partial_data(self, google_client):
+ """Test handling of partial user info response."""
+ partial_data = {'id': '123456789', 'email': 'test@example.com'}
+ mock_user_info_response(
+ respx, {'name': 'google', 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo'}, partial_data
+ )
+
+ result = await google_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == partial_data
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_invalid_json(self, google_client):
+ """Test handling of invalid JSON response."""
+ respx.get('https://www.googleapis.com/oauth2/v1/userinfo').mock(
+ return_value=httpx.Response(200, text='invalid json')
+ )
+
+ with pytest.raises(GetUserInfoError):
+ await google_client.get_userinfo(TEST_ACCESS_TOKEN)
+
+ def test_google_client_multiple_instances(self):
+ """Test that multiple Google client instances work independently."""
+ client1 = GoogleOAuth20('client1', 'secret1')
+ client2 = GoogleOAuth20('client2', 'secret2')
+
+ assert client1.client_id != client2.client_id
+ assert client1.client_secret != client2.client_secret
+ assert client1.authorize_endpoint == client2.authorize_endpoint
+ assert client1.access_token_endpoint == client2.access_token_endpoint
diff --git a/tests/clients/test_linuxdo.py b/tests/clients/test_linuxdo.py
new file mode 100644
index 0000000..1201c35
--- /dev/null
+++ b/tests/clients/test_linuxdo.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import httpx
+import pytest
+import respx
+
+from fastapi_oauth20.clients.linuxdo import LinuxDoOAuth20
+from fastapi_oauth20.errors import HTTPXOAuth20Error
+from fastapi_oauth20.oauth20 import OAuth20Base
+from tests.conftest import (
+ INVALID_TOKEN,
+ TEST_ACCESS_TOKEN,
+ TEST_CLIENT_ID,
+ TEST_CLIENT_SECRET,
+ create_mock_user_data,
+ mock_user_info_response,
+)
+
+# Constants specific to this test file
+CUSTOM_CLIENT_ID = 'custom_id'
+CUSTOM_CLIENT_SECRET = 'custom_secret'
+
+
+class TestLinuxDoOAuth20:
+ """Test LinuxDo OAuth2 client functionality."""
+
+ def test_linuxdo_client_initialization(self, linuxdo_client):
+ """Test LinuxDo client initialization with correct parameters."""
+ assert linuxdo_client.client_id == TEST_CLIENT_ID
+ assert linuxdo_client.client_secret == TEST_CLIENT_SECRET
+ assert linuxdo_client.authorize_endpoint == 'https://connect.linux.do/oauth2/authorize'
+ assert linuxdo_client.access_token_endpoint == 'https://connect.linux.do/oauth2/token'
+ assert linuxdo_client.refresh_token_endpoint == 'https://connect.linux.do/oauth2/token'
+ assert linuxdo_client.default_scopes is None
+ assert linuxdo_client.token_endpoint_basic_auth is True
+
+ def test_linuxdo_client_initialization_with_custom_credentials(self):
+ """Test LinuxDo client initialization with custom credentials."""
+ client = LinuxDoOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET)
+ assert client.client_id == CUSTOM_CLIENT_ID
+ assert client.client_secret == CUSTOM_CLIENT_SECRET
+
+ def test_linuxdo_client_inheritance(self, linuxdo_client):
+ """Test that LinuxDo client properly inherits from OAuth20Base."""
+ assert isinstance(linuxdo_client, OAuth20Base)
+
+ def test_linuxdo_client_basic_auth_enabled(self, linuxdo_client):
+ """Test that LinuxDo client has basic authentication enabled for token endpoint."""
+ assert linuxdo_client.token_endpoint_basic_auth is True
+
+ def test_linuxdo_client_endpoint_urls(self):
+ """Test that LinuxDo client uses correct endpoint URLs."""
+ client = LinuxDoOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET)
+
+ # Test that endpoints are correctly set without hardcoding them in tests
+ assert client.authorize_endpoint.endswith('/oauth2/authorize')
+ assert client.access_token_endpoint.endswith('/oauth2/token')
+ assert client.refresh_token_endpoint.endswith('/oauth2/token')
+
+ # Test that all endpoints use the correct domain
+ for endpoint in [client.authorize_endpoint, client.access_token_endpoint, client.refresh_token_endpoint]:
+ assert 'connect.linux.do' in endpoint
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_success(self, linuxdo_client):
+ """Test successful user info retrieval from LinuxDo API."""
+ mock_user_data = create_mock_user_data('linuxdo')
+ mock_user_info_response(
+ respx, {'name': 'linuxdo', 'user_info_url': 'https://connect.linux.do/api/user'}, mock_user_data
+ )
+
+ result = await linuxdo_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == mock_user_data
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_authorization_header(self, linuxdo_client):
+ """Test that authorization header is correctly formatted."""
+ mock_user_data = {'id': 'test_user'}
+ route = mock_user_info_response(
+ respx, {'name': 'linuxdo', 'user_info_url': 'https://connect.linux.do/api/user'}, mock_user_data
+ )
+
+ await linuxdo_client.get_userinfo(TEST_ACCESS_TOKEN)
+
+ # Verify the request was made with correct authorization header
+ assert route.called
+ request = route.calls[0].request
+ assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}'
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_http_error(self, linuxdo_client):
+ """Test handling of HTTP errors when getting user info."""
+ respx.get('https://connect.linux.do/api/user').mock(return_value=httpx.Response(401, text='Unauthorized'))
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await linuxdo_client.get_userinfo(INVALID_TOKEN)
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_access_token_uses_basic_auth(self, linuxdo_client):
+ """Test that access token requests use HTTP Basic Authentication."""
+ mock_token_data = {'access_token': 'new_access_token'}
+
+ # Mock the token endpoint and capture the request
+ route = respx.post('https://connect.linux.do/oauth2/token').mock(
+ return_value=httpx.Response(200, json=mock_token_data)
+ )
+
+ await linuxdo_client.get_access_token(code='auth_code_123', redirect_uri='https://example.com/callback')
+
+ # Verify BasicAuth was used
+ assert route.called
+ request = route.calls[0].request
+ assert 'authorization' in request.headers
+ # Basic auth should be present
+ assert request.headers['authorization'].startswith('Basic ')
diff --git a/tests/clients/test_oschina.py b/tests/clients/test_oschina.py
new file mode 100644
index 0000000..95cc411
--- /dev/null
+++ b/tests/clients/test_oschina.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import httpx
+import pytest
+import respx
+
+from fastapi_oauth20.clients.oschina import OSChinaOAuth20
+from fastapi_oauth20.errors import HTTPXOAuth20Error
+from fastapi_oauth20.oauth20 import OAuth20Base
+from tests.conftest import (
+ INVALID_TOKEN,
+ TEST_ACCESS_TOKEN,
+ TEST_CLIENT_ID,
+ TEST_CLIENT_SECRET,
+ create_mock_user_data,
+ mock_user_info_response,
+)
+
+# Constants specific to this test file
+CUSTOM_CLIENT_ID = 'custom_id'
+CUSTOM_CLIENT_SECRET = 'custom_secret'
+
+
+class TestOSChinaOAuth20:
+ """Test OSChina OAuth2 client functionality."""
+
+ def test_oschina_client_initialization(self, oschina_client):
+ """Test OSChina client initialization with correct parameters."""
+ assert oschina_client.client_id == TEST_CLIENT_ID
+ assert oschina_client.client_secret == TEST_CLIENT_SECRET
+ assert oschina_client.authorize_endpoint == 'https://www.oschina.net/action/oauth2/authorize'
+ assert oschina_client.access_token_endpoint == 'https://www.oschina.net/action/openapi/token'
+ assert oschina_client.refresh_token_endpoint == 'https://www.oschina.net/action/openapi/token'
+ assert oschina_client.default_scopes is None
+
+ def test_oschina_client_initialization_with_custom_credentials(self):
+ """Test OSChina client initialization with custom credentials."""
+ client = OSChinaOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET)
+ assert client.client_id == CUSTOM_CLIENT_ID
+ assert client.client_secret == CUSTOM_CLIENT_SECRET
+
+ def test_oschina_client_inheritance(self, oschina_client):
+ """Test that OSChina client properly inherits from OAuth20Base."""
+ assert isinstance(oschina_client, OAuth20Base)
+
+ def test_oschina_client_no_default_scopes(self, oschina_client):
+ """Test that OSChina client has no default scopes configured."""
+ assert oschina_client.default_scopes is None
+
+ def test_oschina_client_endpoint_urls(self):
+ """Test that OSChina client uses correct endpoint URLs."""
+ client = OSChinaOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET)
+
+ # Test that endpoints are correctly set without hardcoding them in tests
+ assert client.authorize_endpoint.endswith('/action/oauth2/authorize')
+ assert client.access_token_endpoint.endswith('/action/openapi/token')
+ assert client.refresh_token_endpoint.endswith('/action/openapi/token')
+
+ # Test that all endpoints use the correct domain
+ for endpoint in [client.authorize_endpoint, client.access_token_endpoint, client.refresh_token_endpoint]:
+ assert 'oschina.net' in endpoint
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_success(self, oschina_client):
+ """Test successful user info retrieval from OSChina API."""
+ mock_user_data = create_mock_user_data('oschina')
+ mock_user_info_response(
+ respx,
+ {'name': 'oschina', 'user_info_url': 'https://www.oschina.net/action/openapi/user'},
+ mock_user_data,
+ )
+
+ result = await oschina_client.get_userinfo(TEST_ACCESS_TOKEN)
+ assert result == mock_user_data
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_authorization_header(self, oschina_client):
+ """Test that authorization header is correctly formatted."""
+ mock_user_data = {'id': 'test_user'}
+ route = mock_user_info_response(
+ respx,
+ {'name': 'oschina', 'user_info_url': 'https://www.oschina.net/action/openapi/user'},
+ mock_user_data,
+ )
+
+ await oschina_client.get_userinfo(TEST_ACCESS_TOKEN)
+
+ # Verify the request was made with correct authorization header
+ assert route.called
+ request = route.calls[0].request
+ assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}'
+
+ @pytest.mark.asyncio
+ @respx.mock
+ async def test_get_userinfo_http_error(self, oschina_client):
+ """Test handling of HTTP errors when getting user info."""
+ respx.get('https://www.oschina.net/action/openapi/user').mock(
+ return_value=httpx.Response(401, text='Unauthorized')
+ )
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await oschina_client.get_userinfo(INVALID_TOKEN)
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..80e6a66
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,233 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+Global constants and shared fixtures for fastapi-oauth20 tests.
+"""
+
+import httpx
+import pytest
+
+from fastapi import Depends, FastAPI
+from fastapi.testclient import TestClient
+
+from fastapi_oauth20.clients.feishu import FeiShuOAuth20
+from fastapi_oauth20.clients.gitee import GiteeOAuth20
+from fastapi_oauth20.clients.github import GitHubOAuth20
+from fastapi_oauth20.clients.google import GoogleOAuth20
+from fastapi_oauth20.clients.linuxdo import LinuxDoOAuth20
+from fastapi_oauth20.clients.oschina import OSChinaOAuth20
+
+TEST_CLIENT_ID = 'test_client_id'
+TEST_CLIENT_SECRET = 'test_client_secret'
+TEST_ACCESS_TOKEN = 'test_access_token'
+INVALID_TOKEN = 'invalid_token'
+TEST_STATE = 'test_state'
+OAUTH_PROVIDERS = [
+ {
+ 'name': 'github',
+ 'client_class': GitHubOAuth20,
+ 'token_url': 'https://github.com/login/oauth/access_token',
+ 'user_info_url': 'https://api.github.com/user',
+ 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'scope': 'user'},
+ 'redirect_uri': 'http://localhost:8000/auth/github/callback',
+ },
+ {
+ 'name': 'google',
+ 'client_class': GoogleOAuth20,
+ 'token_url': 'https://oauth2.googleapis.com/token',
+ 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo',
+ 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'Bearer', 'expires_in': 3600},
+ 'redirect_uri': 'http://localhost:8000/auth/google/callback',
+ },
+ {
+ 'name': 'gitee',
+ 'client_class': GiteeOAuth20,
+ 'token_url': 'https://gitee.com/oauth/token',
+ 'user_info_url': 'https://gitee.com/api/v5/user',
+ 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'scope': 'user_info'},
+ 'redirect_uri': 'http://localhost:8000/auth/gitee/callback',
+ },
+ {
+ 'name': 'feishu',
+ 'client_class': FeiShuOAuth20,
+ 'token_url': 'https://passport.feishu.cn/suite/passport/oauth/token',
+ 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo',
+ 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'expires_in': 3600},
+ 'redirect_uri': 'http://localhost:8000/auth/feishu/callback',
+ },
+ {
+ 'name': 'linuxdo',
+ 'client_class': LinuxDoOAuth20,
+ 'token_url': 'https://connect.linux.do/oauth2/token',
+ 'user_info_url': 'https://connect.linux.do/api/user',
+ 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'expires_in': 3600},
+ 'redirect_uri': 'http://localhost:8000/auth/linuxdo/callback',
+ },
+ {
+ 'name': 'oschina',
+ 'client_class': OSChinaOAuth20,
+ 'token_url': 'https://www.oschina.net/action/openapi/token',
+ 'user_info_url': 'https://www.oschina.net/action/openapi/user',
+ 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'expires_in': 3600},
+ 'redirect_uri': 'http://localhost:8000/auth/oschina/callback',
+ },
+]
+
+
+@pytest.fixture(params=OAUTH_PROVIDERS)
+def oauth_provider_config(request):
+ """Get OAuth provider configuration."""
+ return request.param
+
+
+@pytest.fixture
+def oauth_client(oauth_provider_config):
+ """Create OAuth2 client for testing based on provider config."""
+ provider_config = oauth_provider_config
+ client_class = provider_config['client_class']
+ return client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+
+
+@pytest.fixture
+def fastapi_app():
+ """Create FastAPI app for testing."""
+ app = FastAPI()
+
+ @app.get('/')
+ async def root():
+ return {'message': 'OAuth2 Test App'}
+
+ return app
+
+
+@pytest.fixture
+def test_client(fastapi_app):
+ """Create TestClient for FastAPI app."""
+ return TestClient(fastapi_app)
+
+
+# Individual OAuth client fixtures for non-parametrized tests
+@pytest.fixture
+def github_client():
+ """Create GitHub OAuth2 client instance for testing."""
+ return GitHubOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+
+
+@pytest.fixture
+def google_client():
+ """Create Google OAuth2 client instance for testing."""
+ return GoogleOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+
+
+@pytest.fixture
+def gitee_client():
+ """Create Gitee OAuth2 client instance for testing."""
+ return GiteeOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+
+
+@pytest.fixture
+def feishu_client():
+ """Create FeiShu OAuth2 client instance for testing."""
+ return FeiShuOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+
+
+@pytest.fixture
+def linuxdo_client():
+ """Create LinuxDo OAuth2 client instance for testing."""
+ return LinuxDoOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+
+
+@pytest.fixture
+def oschina_client():
+ """Create OSChina OAuth2 client instance for testing."""
+ return OSChinaOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+
+
+def create_mock_user_data(provider_name: str, **overrides):
+ """Create mock user data for a specific provider with optional overrides."""
+ MOCK_USER_DATA = {
+ 'github': {
+ 'id': 123456,
+ 'login': 'testuser',
+ 'name': 'Test User',
+ 'email': 'test@example.com',
+ 'bio': 'Test bio',
+ 'location': 'Test Location',
+ },
+ 'google': {
+ 'id': '123456789',
+ 'email': 'test@gmail.com',
+ 'name': 'Test User',
+ 'picture': 'https://lh3.googleusercontent.com/test.jpg',
+ },
+ 'gitee': {
+ 'id': 123456,
+ 'login': 'testuser',
+ 'name': 'Test User',
+ 'email': 'test@example.com',
+ 'avatar_url': 'https://avatar.example.com/testuser.png',
+ },
+ 'feishu': {
+ 'user_id': 'test_user_123',
+ 'employee_id': 'emp_456',
+ 'name': 'Test User',
+ 'email': 'test@example.com',
+ 'mobile': '13800000000',
+ },
+ 'linuxdo': {
+ 'id': 123456,
+ 'username': 'testuser',
+ 'name': 'Test User',
+ 'email': 'test@example.com',
+ 'avatar_url': 'https://linux.do/avatar/testuser.png',
+ },
+ 'oschina': {
+ 'id': 123456,
+ 'name': 'Test User',
+ 'email': 'test@example.com',
+ 'avatar': 'https://oschina.net/img/test.jpg',
+ },
+ }
+
+ base_data = MOCK_USER_DATA.get(provider_name, {}).copy()
+ base_data.update(overrides)
+ return base_data
+
+
+def mock_oauth_token_response(respx_mock, provider_config: dict, status_code: int = 200):
+ """Mock OAuth token endpoint response for a provider."""
+ return respx_mock.post(provider_config['token_url']).mock(
+ return_value=httpx.Response(status_code, json=provider_config['token_response'])
+ )
+
+
+def mock_user_info_response(respx_mock, provider_config: dict, user_data: dict = None, status_code: int = 200):
+ """Mock user info endpoint response for a provider."""
+ if user_data is None:
+ user_data = create_mock_user_data(provider_config['name'])
+ return respx_mock.get(provider_config['user_info_url']).mock(
+ return_value=httpx.Response(status_code, json=user_data)
+ )
+
+
+def setup_oauth_callback_route(app: FastAPI, provider_config: dict, oauth_dependency):
+ """Setup OAuth callback route for testing."""
+ callback_path = f'/auth/{provider_config["name"]}/callback'
+
+ @app.get(callback_path)
+ async def oauth_callback_handler(access_token_state=Depends(oauth_dependency)):
+ token, state = access_token_state
+ return {'provider': provider_config['name'], 'access_token': token, 'state': state}
+
+ return callback_path
+
+
+def assert_oauth_error_response(response, expected_error: str, expected_status: int = 400):
+ """Assert OAuth error response has correct format."""
+ assert response.status_code == expected_status
+ data = response.json()
+ if 'detail' in data:
+ assert data['detail'] == expected_error
+ else:
+ assert data.get('error') == expected_error
diff --git a/tests/test_basic.py b/tests/test_basic.py
new file mode 100644
index 0000000..db7b0ed
--- /dev/null
+++ b/tests/test_basic.py
@@ -0,0 +1,221 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from fastapi_oauth20.clients.feishu import FeiShuOAuth20
+from fastapi_oauth20.clients.gitee import GiteeOAuth20
+from fastapi_oauth20.clients.github import GitHubOAuth20
+from fastapi_oauth20.clients.google import GoogleOAuth20
+from fastapi_oauth20.clients.linuxdo import LinuxDoOAuth20
+from fastapi_oauth20.clients.oschina import OSChinaOAuth20
+from fastapi_oauth20.errors import (
+ AccessTokenError,
+ GetUserInfoError,
+ HTTPXOAuth20Error,
+ OAuth20RequestError,
+ RefreshTokenError,
+ RevokeTokenError,
+)
+
+
+def test_feishu_client_creation():
+ """Test FeiShu client can be created."""
+ client = FeiShuOAuth20('test_id', 'test_secret')
+ assert client.client_id == 'test_id'
+ assert client.client_secret == 'test_secret'
+ assert client.authorize_endpoint is not None
+ assert client.access_token_endpoint is not None
+ assert client.default_scopes is not None
+
+
+def test_github_client_creation():
+ """Test GitHub client can be created."""
+ client = GitHubOAuth20('test_id', 'test_secret')
+ assert client.client_id == 'test_id'
+ assert client.client_secret == 'test_secret'
+ assert client.authorize_endpoint is not None
+ assert client.access_token_endpoint is not None
+ assert client.default_scopes is not None
+
+
+def test_google_client_creation():
+ """Test Google client can be created."""
+ client = GoogleOAuth20('test_id', 'test_secret')
+ assert client.client_id == 'test_id'
+ assert client.client_secret == 'test_secret'
+ assert client.authorize_endpoint is not None
+ assert client.access_token_endpoint is not None
+ assert client.refresh_token_endpoint is not None
+ assert client.revoke_token_endpoint is not None
+ assert client.default_scopes is not None
+
+
+def test_gitee_client_creation():
+ """Test Gitee client can be created."""
+ client = GiteeOAuth20('test_id', 'test_secret')
+ assert client.client_id == 'test_id'
+ assert client.client_secret == 'test_secret'
+ assert client.authorize_endpoint is not None
+ assert client.access_token_endpoint is not None
+ assert client.refresh_token_endpoint is not None
+ assert client.default_scopes is not None
+
+
+def test_oschina_client_creation():
+ """Test OSChina client can be created."""
+ client = OSChinaOAuth20('test_id', 'test_secret')
+ assert client.client_id == 'test_id'
+ assert client.client_secret == 'test_secret'
+ assert client.authorize_endpoint is not None
+ assert client.access_token_endpoint is not None
+ assert client.refresh_token_endpoint is not None
+
+
+def test_linuxdo_client_creation():
+ """Test Linux.do client can be created."""
+ client = LinuxDoOAuth20('test_id', 'test_secret')
+ assert client.client_id == 'test_id'
+ assert client.client_secret == 'test_secret'
+ assert client.authorize_endpoint is not None
+ assert client.access_token_endpoint is not None
+ assert client.refresh_token_endpoint is not None
+ assert client.token_endpoint_basic_auth is True
+
+
+def test_all_clients_inherit_from_base():
+ """Test all clients inherit from OAuth20Base."""
+ from fastapi_oauth20.oauth20 import OAuth20Base
+
+ clients = [
+ FeiShuOAuth20('id', 'secret'),
+ GitHubOAuth20('id', 'secret'),
+ GoogleOAuth20('id', 'secret'),
+ GiteeOAuth20('id', 'secret'),
+ OSChinaOAuth20('id', 'secret'),
+ LinuxDoOAuth20('id', 'secret'),
+ ]
+
+ for client in clients:
+ assert isinstance(client, OAuth20Base)
+
+
+def test_error_classes_creation():
+ """Test all error classes can be created."""
+ # Test basic error creation
+ base_error = OAuth20RequestError('Base error')
+ assert str(base_error) == 'Base error'
+
+ # Test specialized errors
+ access_error = AccessTokenError('Access error')
+ assert str(access_error) == 'Access error'
+
+ refresh_error = RefreshTokenError('Refresh error')
+ assert str(refresh_error) == 'Refresh error'
+
+ revoke_error = RevokeTokenError('Revoke error')
+ assert str(revoke_error) == 'Revoke error'
+
+ user_error = GetUserInfoError('User info error')
+ assert str(user_error) == 'User info error'
+
+ httpx_error = HTTPXOAuth20Error('HTTPX error')
+ assert str(httpx_error) == 'HTTPX error'
+
+
+def test_error_inheritance():
+ """Test error class inheritance hierarchy."""
+ assert issubclass(AccessTokenError, OAuth20RequestError)
+ assert issubclass(RefreshTokenError, OAuth20RequestError)
+ assert issubclass(RevokeTokenError, OAuth20RequestError)
+ assert issubclass(GetUserInfoError, OAuth20RequestError)
+ assert issubclass(HTTPXOAuth20Error, OAuth20RequestError)
+
+
+def test_client_endpoint_urls():
+ """Test clients have proper endpoint URLs."""
+
+ # FeiShu
+ feishu = FeiShuOAuth20('id', 'secret')
+ assert 'feishu.cn' in feishu.authorize_endpoint
+ assert 'feishu.cn' in feishu.access_token_endpoint
+
+ # GitHub
+ github = GitHubOAuth20('id', 'secret')
+ assert 'github.com' in github.authorize_endpoint
+ assert 'github.com' in github.access_token_endpoint
+
+ # Google
+ google = GoogleOAuth20('id', 'secret')
+ assert 'google.com' in google.authorize_endpoint
+ assert 'googleapis.com' in google.access_token_endpoint
+
+ # Gitee
+ gitee = GiteeOAuth20('id', 'secret')
+ assert 'gitee.com' in gitee.authorize_endpoint
+ assert 'gitee.com' in gitee.access_token_endpoint
+
+ # OSChina
+ oschina = OSChinaOAuth20('id', 'secret')
+ assert 'oschina.net' in oschina.authorize_endpoint
+ assert 'oschina.net' in oschina.access_token_endpoint
+
+ # Linux.do
+ linuxdo = LinuxDoOAuth20('id', 'secret')
+ assert 'linux.do' in linuxdo.authorize_endpoint
+ assert 'linux.do' in linuxdo.access_token_endpoint
+
+
+def test_client_scopes():
+ """Test clients have appropriate default scopes."""
+
+ feishu = FeiShuOAuth20('id', 'secret')
+ assert isinstance(feishu.default_scopes, list)
+ assert len(feishu.default_scopes) > 0
+
+ github = GitHubOAuth20('id', 'secret')
+ assert isinstance(github.default_scopes, list)
+ assert len(github.default_scopes) > 0
+
+ google = GoogleOAuth20('id', 'secret')
+ assert isinstance(google.default_scopes, list)
+ assert len(google.default_scopes) > 0
+
+ gitee = GiteeOAuth20('id', 'secret')
+ assert isinstance(gitee.default_scopes, list)
+ assert len(gitee.default_scopes) > 0
+
+ oschina = OSChinaOAuth20('id', 'secret')
+ # OSChina has no default scopes
+ assert oschina.default_scopes is None
+
+ linuxdo = LinuxDoOAuth20('id', 'secret')
+ # Linux.do has no default scopes
+ assert linuxdo.default_scopes is None
+
+
+def test_multiple_clients_independence():
+ """Test multiple client instances work independently."""
+ feishu1 = FeiShuOAuth20('id1', 'secret1')
+ feishu2 = FeiShuOAuth20('id2', 'secret2')
+
+ assert feishu1.client_id != feishu2.client_id
+ assert feishu1.client_secret != feishu2.client_secret
+ assert feishu1.authorize_endpoint == feishu2.authorize_endpoint
+
+
+def test_google_special_features():
+ """Test Google client special features."""
+ google = GoogleOAuth20('id', 'secret')
+
+ # Google has both refresh and revoke endpoints
+ assert google.refresh_token_endpoint is not None
+ assert google.revoke_token_endpoint is not None
+
+ # Google uses same endpoint for refresh and access
+ assert google.refresh_token_endpoint == google.access_token_endpoint
+
+
+def test_linuxdo_special_features():
+ """Test Linux.do client special features."""
+ linuxdo = LinuxDoOAuth20('id', 'secret')
+
+ # Linux.do uses basic auth for token endpoint
+ assert linuxdo.token_endpoint_basic_auth is True
diff --git a/tests/test_errors.py b/tests/test_errors.py
new file mode 100644
index 0000000..52303ff
--- /dev/null
+++ b/tests/test_errors.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import httpx
+
+from fastapi_oauth20.errors import (
+ AccessTokenError,
+ GetUserInfoError,
+ HTTPXOAuth20Error,
+ OAuth20RequestError,
+ RefreshTokenError,
+ RevokeTokenError,
+)
+
+
+def test_oauth20_request_error_basic():
+ """Test basic OAuth20RequestError creation."""
+ error = OAuth20RequestError('Test error')
+ assert str(error) == 'Test error'
+ assert error.msg == 'Test error'
+
+
+def test_oauth20_request_error_with_response():
+ """Test OAuth20RequestError with HTTP response."""
+ mock_response = httpx.Response(400)
+ error = OAuth20RequestError('Bad request', mock_response)
+ assert str(error) == 'Bad request'
+ assert error.msg == 'Bad request'
+ assert error.response is mock_response
+
+
+def test_httpx_oauth20_error_basic():
+ """Test basic HTTPXOAuth20Error creation."""
+ error = HTTPXOAuth20Error('HTTP error')
+ assert str(error) == 'HTTP error'
+ assert error.msg == 'HTTP error'
+
+
+def test_httpx_oauth20_error_with_response():
+ """Test HTTPXOAuth20Error with HTTP response."""
+ mock_response = httpx.Response(404)
+ error = HTTPXOAuth20Error('Not found', mock_response)
+ assert str(error) == 'Not found'
+ assert error.msg == 'Not found'
+ assert error.response is mock_response
+
+
+def test_access_token_error():
+ """Test AccessTokenError creation and inheritance."""
+ mock_response = httpx.Response(401)
+ error = AccessTokenError('Invalid token', mock_response)
+
+ assert str(error) == 'Invalid token'
+ assert isinstance(error, OAuth20RequestError)
+ assert error.response is mock_response
+
+
+def test_refresh_token_error():
+ """Test RefreshTokenError creation and inheritance."""
+ mock_response = httpx.Response(401)
+ error = RefreshTokenError('Invalid refresh token', mock_response)
+
+ assert str(error) == 'Invalid refresh token'
+ assert isinstance(error, OAuth20RequestError)
+ assert error.response is mock_response
+
+
+def test_revoke_token_error():
+ """Test RevokeTokenError creation and inheritance."""
+ mock_response = httpx.Response(400)
+ error = RevokeTokenError('Revocation failed', mock_response)
+
+ assert str(error) == 'Revocation failed'
+ assert isinstance(error, OAuth20RequestError)
+ assert error.response is mock_response
+
+
+def test_get_userinfo_error():
+ """Test GetUserInfoError creation and inheritance."""
+ mock_response = httpx.Response(403)
+ error = GetUserInfoError('Access denied', mock_response)
+
+ assert str(error) == 'Access denied'
+ assert isinstance(error, OAuth20RequestError)
+ assert error.response is mock_response
+
+
+def test_error_inheritance_chain():
+ """Test that all OAuth2 errors have proper inheritance."""
+ assert issubclass(AccessTokenError, OAuth20RequestError)
+ assert issubclass(RefreshTokenError, OAuth20RequestError)
+ assert issubclass(RevokeTokenError, OAuth20RequestError)
+ assert issubclass(GetUserInfoError, OAuth20RequestError)
+
+ assert issubclass(HTTPXOAuth20Error, OAuth20RequestError)
+ assert issubclass(OAuth20RequestError, Exception)
+
+
+def test_error_without_response():
+ """Test error creation without HTTP response."""
+ error = AccessTokenError('Simple error')
+ assert str(error) == 'Simple error'
+ assert error.msg == 'Simple error'
+ assert not hasattr(error, 'response') or error.response is None
+
+
+def test_error_catch_hierarchy():
+ """Test that errors can be caught at different levels of hierarchy."""
+ mock_response = httpx.Response(400)
+
+ # Specific error type
+ try:
+ raise AccessTokenError('Access token error', mock_response)
+ except AccessTokenError as e:
+ assert str(e) == 'Access token error'
+
+ # Parent OAuth20RequestError type
+ try:
+ raise RefreshTokenError('Refresh token error', mock_response)
+ except OAuth20RequestError as e:
+ assert str(e) == 'Refresh token error'
+
+ # HTTPXOAuth20Error type
+ try:
+ raise HTTPXOAuth20Error('HTTPX error', mock_response)
+ except HTTPXOAuth20Error as e:
+ assert str(e) == 'HTTPX error'
+
+
+def test_error_properties():
+ """Test that error objects have expected properties."""
+ mock_response = httpx.Response(500)
+
+ error = RevokeTokenError('Server error', mock_response)
+ assert hasattr(error, 'msg')
+ assert hasattr(error, 'response')
+ assert error.msg == 'Server error'
+ assert error.response == mock_response
+
+
+def test_error_str_representation():
+ """Test string representation of errors."""
+ # Error without response
+ error1 = AccessTokenError('Simple message')
+ assert str(error1) == 'Simple message'
+
+ # Error with response
+ mock_response = httpx.Response(404)
+ error2 = GetUserInfoError('User not found', mock_response)
+ assert str(error2) == 'User not found'
+
+
+def test_error_with_complex_message():
+ """Test errors with complex or multi-line messages."""
+ complex_message = "Error: Invalid request\nDetails: Missing required parameter 'code'"
+ error = OAuth20RequestError(complex_message)
+ assert str(error) == complex_message
diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py
new file mode 100644
index 0000000..7010f42
--- /dev/null
+++ b/tests/test_fastapi.py
@@ -0,0 +1,338 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import httpx
+import pytest
+import respx
+
+from fastapi import Depends, FastAPI, Request
+from fastapi.responses import JSONResponse
+from fastapi.testclient import TestClient
+
+from fastapi_oauth20 import FastAPIOAuth20, OAuth20AuthorizeCallbackError
+from tests.conftest import (
+ OAUTH_PROVIDERS,
+ TEST_ACCESS_TOKEN,
+ TEST_CLIENT_ID,
+ TEST_CLIENT_SECRET,
+ TEST_STATE,
+ assert_oauth_error_response,
+ mock_oauth_token_response,
+ setup_oauth_callback_route,
+)
+
+# Integration test specific constants
+LOCALHOST_URL = 'http://localhost:8000'
+DEV_URL = 'http://dev.example.com'
+APP_URL = 'http://app.example.org'
+IP_URL = 'http://192.168.1.100:8080'
+SECURE_LOCALHOST_URL = 'https://localhost:8000'
+SECURE_DEV_URL = 'https://secure.example.com'
+SECURE_APP_URL = 'https://app.example.org'
+SECURE_IP_URL = 'https://192.168.1.100:8443'
+
+# Callback paths
+AUTH_PATH = '/auth'
+CALLBACK_PATH = '/callback'
+OAUTH_CALLBACK_PATH = '/oauth2/callback'
+
+# Test URIs
+HTTP_URIS = [
+ f'{LOCALHOST_URL}{CALLBACK_PATH}',
+ f'{DEV_URL}{AUTH_PATH}/callback',
+ f'{APP_URL}{OAUTH_CALLBACK_PATH}',
+ f'{IP_URL}{AUTH_PATH}',
+]
+
+HTTPS_URIS = [
+ f'{SECURE_LOCALHOST_URL}{CALLBACK_PATH}',
+ f'{SECURE_DEV_URL}{AUTH_PATH}/callback',
+ f'{SECURE_APP_URL}{OAUTH_CALLBACK_PATH}',
+ f'{SECURE_IP_URL}{AUTH_PATH}',
+]
+
+
+class TestFastAPIOAuth20Basic:
+ """Basic tests for FastAPI OAuth2 integration using parametrized providers."""
+
+ @pytest.mark.asyncio
+ @respx.mock
+ @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS)
+ async def test_successful_callback_parametrized(self, provider_config, fastapi_app):
+ """Test successful OAuth2 callback with multiple providers."""
+ # Mock token exchange
+ mock_oauth_token_response(respx, provider_config)
+
+ # Create OAuth client and dependency
+ client_class = provider_config['client_class']
+ oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+ oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri'])
+
+ # Setup callback route
+ callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback)
+
+ # Test successful callback
+ client = TestClient(fastapi_app)
+ response = client.get(f'{callback_path}?code=test_code&state={TEST_STATE}')
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data['provider'] == provider_config['name']
+ assert data['access_token']['access_token'] == TEST_ACCESS_TOKEN
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS)
+ async def test_callback_missing_code_parametrized(self, provider_config, fastapi_app):
+ """Test OAuth2 callback with missing authorization code."""
+ client_class = provider_config['client_class']
+ oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+ oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri'])
+ callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback)
+
+ client = TestClient(fastapi_app)
+ response = client.get(f'{callback_path}?state={TEST_STATE}')
+
+ assert_oauth_error_response(response, 'Bad Request')
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS)
+ async def test_callback_with_error_parametrized(self, provider_config, fastapi_app):
+ """Test OAuth2 callback with error parameter."""
+ client_class = provider_config['client_class']
+ oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+ oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri'])
+ callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback)
+
+ client = TestClient(fastapi_app)
+ response = client.get(f'{callback_path}?error=access_denied&state={TEST_STATE}')
+
+ assert_oauth_error_response(response, 'access_denied')
+
+ @pytest.mark.asyncio
+ @respx.mock
+ @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS)
+ async def test_callback_token_exchange_error_parametrized(self, provider_config, fastapi_app):
+ """Test OAuth2 callback with token exchange error."""
+ # Mock token exchange error
+ respx.post(provider_config['token_url']).mock(return_value=httpx.Response(400, text='Bad Request'))
+
+ client_class = provider_config['client_class']
+ oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+ oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri'])
+ callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback)
+
+ client = TestClient(fastapi_app)
+ response = client.get(f'{callback_path}?code=invalid_code&state={TEST_STATE}')
+
+ assert response.status_code == 500
+ assert 'detail' in response.json()
+
+ def test_custom_exception_handler(self, github_client, fastapi_app):
+ """Test custom exception handler for OAuth2 errors."""
+ github_oauth2_callback = FastAPIOAuth20(github_client, redirect_uri=f'{LOCALHOST_URL}/auth/github/callback')
+
+ @fastapi_app.get('/auth/github/callback')
+ async def github_callback(access_token_state=Depends(github_oauth2_callback)):
+ token, state = access_token_state
+ return {'access_token': token, 'state': state}
+
+ @fastapi_app.exception_handler(OAuth20AuthorizeCallbackError)
+ async def oauth2_error_handler(request: Request, exc: OAuth20AuthorizeCallbackError):
+ return JSONResponse(
+ status_code=exc.status_code,
+ content={
+ 'message': 'OAuth2 authentication failed',
+ 'error': exc.detail,
+ 'status_code': exc.status_code,
+ },
+ )
+
+ client = TestClient(fastapi_app)
+ response = client.get('/auth/github/callback?error=access_denied')
+
+ assert response.status_code == 400
+ data = response.json()
+ assert data['message'] == 'OAuth2 authentication failed'
+ assert data['error'] == 'access_denied'
+ assert data['status_code'] == 400
+
+ def test_multiple_oauth_providers(self, github_client, google_client, fastapi_app):
+ """Test multiple OAuth providers in the same app."""
+ # Setup GitHub OAuth
+ github_oauth2_callback = FastAPIOAuth20(github_client, redirect_uri=f'{LOCALHOST_URL}/auth/github/callback')
+
+ @fastapi_app.get('/auth/github/callback')
+ async def github_callback(access_token_state=Depends(github_oauth2_callback)):
+ token, state = access_token_state
+ return {'provider': 'github', 'access_token': token, 'state': state}
+
+ # Setup Google OAuth
+ google_oauth2_callback = FastAPIOAuth20(google_client, redirect_uri=f'{LOCALHOST_URL}/auth/google/callback')
+
+ @fastapi_app.get('/auth/google/callback')
+ async def google_callback(access_token_state=Depends(google_oauth2_callback)):
+ token, state = access_token_state
+ return {'provider': 'google', 'access_token': token, 'state': state}
+
+ client = TestClient(fastapi_app)
+
+ # Test GitHub route
+ response = client.get('/auth/github/callback?error=access_denied')
+ assert response.status_code == 400
+
+ # Test Google route
+ response = client.get('/auth/google/callback?error=access_denied')
+ assert response.status_code == 400
+
+
+class TestFastAPIOAuth20PKCE:
+ """Test PKCE functionality across providers."""
+
+ @pytest.mark.asyncio
+ @respx.mock
+ @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS)
+ async def test_pkce_flow(self, provider_config, fastapi_app):
+ """Test PKCE flow with code verifier and challenge."""
+ # Mock token exchange
+ mock_oauth_token_response(respx, provider_config)
+
+ client_class = provider_config['client_class']
+ oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+ oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri'])
+ callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback)
+
+ client = TestClient(fastapi_app)
+ response = client.get(f'{callback_path}?code=test_code&state={TEST_STATE}&code_verifier=test_verifier')
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data['provider'] == provider_config['name']
+
+
+class TestFastAPIOAuth20URIs:
+ """Test URI validation and functionality."""
+
+ @pytest.mark.asyncio
+ @respx.mock
+ @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS)
+ @pytest.mark.parametrize('redirect_uri', HTTP_URIS)
+ async def test_http_redirect_uris(self, provider_config, redirect_uri, fastapi_app):
+ """Test OAuth2 with HTTP redirect URIs."""
+ # Mock token exchange
+ mock_oauth_token_response(respx, provider_config)
+
+ client_class = provider_config['client_class']
+ oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+ oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=redirect_uri)
+ callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback)
+
+ client = TestClient(fastapi_app)
+ response = client.get(f'{callback_path}?code=test_code&state={TEST_STATE}')
+
+ assert response.status_code == 200
+
+ @pytest.mark.asyncio
+ @respx.mock
+ @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS)
+ @pytest.mark.parametrize('redirect_uri', HTTPS_URIS)
+ async def test_https_redirect_uris(self, provider_config, redirect_uri, fastapi_app):
+ """Test OAuth2 with HTTPS redirect URIs."""
+ # Mock token exchange
+ mock_oauth_token_response(respx, provider_config)
+
+ client_class = provider_config['client_class']
+ oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+ oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=redirect_uri)
+ callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback)
+
+ client = TestClient(fastapi_app)
+ response = client.get(f'{callback_path}?code=test_code&state={TEST_STATE}')
+
+ assert response.status_code == 200
+
+
+class TestFastAPIOAuth20ErrorScenarios:
+ """Test various error scenarios across providers."""
+
+ @pytest.mark.asyncio
+ @respx.mock
+ @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS)
+ @pytest.mark.parametrize('error_code', ['access_denied', 'temporarily_unavailable', 'invalid_request'])
+ async def test_oauth_errors(self, provider_config, error_code, fastapi_app):
+ """Test various OAuth error codes."""
+ client_class = provider_config['client_class']
+ oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+ oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri'])
+ callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback)
+
+ client = TestClient(fastapi_app)
+ response = client.get(f'{callback_path}?error={error_code}&state={TEST_STATE}')
+
+ assert_oauth_error_response(response, error_code)
+
+ @pytest.mark.asyncio
+ @respx.mock
+ @pytest.mark.parametrize('provider_config', OAUTH_PROVIDERS)
+ @pytest.mark.parametrize('http_status', [400, 401, 500, 502])
+ async def test_token_exchange_errors(self, provider_config, http_status, fastapi_app):
+ """Test token exchange with various HTTP error codes."""
+ # Mock token exchange error
+ respx.post(provider_config['token_url']).mock(
+ return_value=httpx.Response(http_status, text=f'Error {http_status}')
+ )
+
+ client_class = provider_config['client_class']
+ oauth_client = client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET)
+ oauth_callback = FastAPIOAuth20(oauth_client, redirect_uri=provider_config['redirect_uri'])
+ callback_path = setup_oauth_callback_route(fastapi_app, provider_config, oauth_callback)
+
+ client = TestClient(fastapi_app)
+ response = client.get(f'{callback_path}?code=invalid_code&state={TEST_STATE}')
+
+ # Should return server error for token exchange failures
+ assert response.status_code >= 400
+ assert 'detail' in response.json()
+
+
+# Additional test classes for different scenarios
+class TestFastAPIOAuth20Integration:
+ """Integration tests for FastAPI OAuth2."""
+
+ def test_oauth_dependency_creation(self, github_client):
+ """Test OAuth dependency creation with different parameters."""
+ # Basic OAuth dependency
+ oauth_dep = FastAPIOAuth20(github_client)
+ assert oauth_dep.client == github_client
+
+ # OAuth dependency with custom redirect URI
+ custom_redirect = f'{LOCALHOST_URL}/custom/callback'
+ oauth_dep_custom = FastAPIOAuth20(github_client, redirect_uri=custom_redirect)
+ assert oauth_dep_custom.client == github_client
+
+ def test_multiple_apps_same_provider(self, github_client):
+ """Test the same OAuth provider in multiple FastAPI apps."""
+ app1 = FastAPI()
+ app2 = FastAPI()
+
+ # Setup first app
+ oauth_dep1 = FastAPIOAuth20(github_client, redirect_uri=f'{LOCALHOST_URL}/app1/callback')
+
+ @app1.get('/auth/github/callback')
+ async def github_callback1(access_token_state=Depends(oauth_dep1)):
+ return {'app': 'app1', 'access_token': access_token_state}
+
+ # Setup second app
+ oauth_dep2 = FastAPIOAuth20(github_client, redirect_uri=f'{LOCALHOST_URL}/app2/callback')
+
+ @app2.get('/auth/github/callback')
+ async def github_callback2(access_token_state=Depends(oauth_dep2)):
+ return {'app': 'app2', 'access_token': access_token_state}
+
+ client1 = TestClient(app1)
+ client2 = TestClient(app2)
+
+ # Both apps should work independently
+ response1 = client1.get('/auth/github/callback?error=access_denied')
+ response2 = client2.get('/auth/github/callback?error=access_denied')
+
+ assert response1.status_code == 400
+ assert response2.status_code == 400
diff --git a/tests/test_oauth20.py b/tests/test_oauth20.py
new file mode 100644
index 0000000..47e2d07
--- /dev/null
+++ b/tests/test_oauth20.py
@@ -0,0 +1,385 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import json
+
+from unittest.mock import Mock
+
+import httpx
+import pytest
+import respx
+
+from fastapi_oauth20.errors import (
+ AccessTokenError,
+ HTTPXOAuth20Error,
+ RefreshTokenError,
+ RevokeTokenError,
+)
+from fastapi_oauth20.oauth20 import OAuth20Base
+
+
+class MockOAuth20Client(OAuth20Base):
+ """Test implementation of OAuth20Base for testing purposes."""
+
+ async def get_userinfo(self, access_token: str) -> dict[str, any]:
+ """Mock implementation for testing."""
+ return {'user_id': 'test_user', 'access_token': access_token}
+
+
+@pytest.fixture
+def oauth_client():
+ """Create OAuth20Base client instance for testing."""
+ return MockOAuth20Client(
+ client_id='test_client_id',
+ client_secret='test_client_secret',
+ authorize_endpoint='https://example.com/oauth/authorize',
+ access_token_endpoint='https://example.com/oauth/token',
+ refresh_token_endpoint='https://example.com/oauth/refresh',
+ revoke_token_endpoint='https://example.com/oauth/revoke',
+ default_scopes=['read', 'write'],
+ )
+
+
+def test_oauth_base_initialization(oauth_client):
+ """Test OAuth20Base initialization with all parameters."""
+ assert oauth_client.client_id == 'test_client_id'
+ assert oauth_client.client_secret == 'test_client_secret'
+ assert oauth_client.authorize_endpoint == 'https://example.com/oauth/authorize'
+ assert oauth_client.access_token_endpoint == 'https://example.com/oauth/token'
+ assert oauth_client.refresh_token_endpoint == 'https://example.com/oauth/refresh'
+ assert oauth_client.revoke_token_endpoint == 'https://example.com/oauth/revoke'
+ assert oauth_client.default_scopes == ['read', 'write']
+ assert oauth_client.token_endpoint_basic_auth is False
+ assert oauth_client.revoke_token_endpoint_basic_auth is False
+ assert oauth_client.request_headers == {'Accept': 'application/json'}
+
+
+def test_oauth_base_initialization_minimal():
+ """Test OAuth20Base initialization with minimal required parameters."""
+ client = MockOAuth20Client(
+ client_id='test_id',
+ client_secret='test_secret',
+ authorize_endpoint='https://example.com/auth',
+ access_token_endpoint='https://example.com/token',
+ )
+
+ assert client.client_id == 'test_id'
+ assert client.client_secret == 'test_secret'
+ assert client.authorize_endpoint == 'https://example.com/auth'
+ assert client.access_token_endpoint == 'https://example.com/token'
+ assert client.refresh_token_endpoint is None
+ assert client.revoke_token_endpoint is None
+ assert client.default_scopes is None
+
+
+def test_oauth_base_initialization_with_basic_auth():
+ """Test OAuth20Base initialization with basic authentication enabled."""
+ client = MockOAuth20Client(
+ client_id='test_id',
+ client_secret='test_secret',
+ authorize_endpoint='https://example.com/auth',
+ access_token_endpoint='https://example.com/token',
+ token_endpoint_basic_auth=True,
+ revoke_token_endpoint_basic_auth=True,
+ )
+
+ assert client.token_endpoint_basic_auth is True
+ assert client.revoke_token_endpoint_basic_auth is True
+
+
+@pytest.mark.asyncio
+async def test_get_authorization_url_basic(oauth_client):
+ """Test basic authorization URL generation."""
+ url = await oauth_client.get_authorization_url(redirect_uri='https://example.com/callback')
+
+ assert 'https://example.com/oauth/authorize' in url
+ assert 'client_id=test_client_id' in url
+ assert 'redirect_uri=https%3A%2F%2Fexample.com%2Fcallback' in url
+ assert 'response_type=code' in url
+ assert 'scope=read+write' in url
+
+
+@pytest.mark.asyncio
+async def test_get_authorization_url_with_state(oauth_client):
+ """Test authorization URL generation with state parameter."""
+ url = await oauth_client.get_authorization_url(
+ redirect_uri='https://example.com/callback', state='random_state_123'
+ )
+
+ assert 'state=random_state_123' in url
+
+
+@pytest.mark.asyncio
+async def test_get_authorization_url_with_custom_scope(oauth_client):
+ """Test authorization URL generation with custom scope."""
+ url = await oauth_client.get_authorization_url(
+ redirect_uri='https://example.com/callback', scope=['read', 'delete']
+ )
+
+ assert 'scope=read+delete' in url
+ assert 'write' not in url
+
+
+@pytest.mark.asyncio
+async def test_get_authorization_url_with_pkce(oauth_client):
+ """Test authorization URL generation with PKCE parameters."""
+ url = await oauth_client.get_authorization_url(
+ redirect_uri='https://example.com/callback', code_challenge='challenge_123', code_challenge_method='S256'
+ )
+
+ assert 'code_challenge=challenge_123' in url
+ assert 'code_challenge_method=S256' in url
+
+
+@pytest.mark.asyncio
+async def test_get_authorization_url_with_extra_params(oauth_client):
+ """Test authorization URL generation with additional parameters."""
+ url = await oauth_client.get_authorization_url(
+ redirect_uri='https://example.com/callback', access_type='offline', prompt='consent'
+ )
+
+ assert 'access_type=offline' in url
+ assert 'prompt=consent' in url
+
+
+@pytest.mark.asyncio
+@respx.mock
+async def test_get_access_token_success(oauth_client):
+ """Test successful access token exchange."""
+ mock_token_data = {
+ 'access_token': 'new_access_token',
+ 'token_type': 'Bearer',
+ 'expires_in': 3600,
+ 'refresh_token': 'refresh_token_123',
+ }
+
+ # Mock the token endpoint
+ respx.post('https://example.com/oauth/token').mock(return_value=httpx.Response(200, json=mock_token_data))
+
+ result = await oauth_client.get_access_token(code='auth_code_123', redirect_uri='https://example.com/callback')
+ assert result == mock_token_data
+
+
+@pytest.mark.asyncio
+@respx.mock
+async def test_get_access_token_with_code_verifier(oauth_client):
+ """Test access token exchange with PKCE code verifier."""
+ mock_token_data = {'access_token': 'new_access_token'}
+
+ # Mock the token endpoint and capture the request
+ route = respx.post('https://example.com/oauth/token').mock(return_value=httpx.Response(200, json=mock_token_data))
+
+ await oauth_client.get_access_token(
+ code='auth_code_123', redirect_uri='https://example.com/callback', code_verifier='verifier_123'
+ )
+
+ # Verify the request was made with code_verifier
+ assert route.called
+ request_data = route.calls[0].request.content.decode()
+ assert 'code_verifier=verifier_123' in request_data
+
+
+@pytest.mark.asyncio
+@respx.mock
+async def test_get_access_token_with_basic_auth():
+ """Test access token exchange with HTTP Basic Authentication."""
+ client = MockOAuth20Client(
+ client_id='test_id',
+ client_secret='test_secret',
+ authorize_endpoint='https://example.com/auth',
+ access_token_endpoint='https://example.com/token',
+ token_endpoint_basic_auth=True,
+ )
+
+ mock_token_data = {'access_token': 'new_access_token'}
+
+ # Mock the token endpoint
+ route = respx.post('https://example.com/token').mock(return_value=httpx.Response(200, json=mock_token_data))
+
+ await client.get_access_token(code='auth_code_123', redirect_uri='https://example.com/callback')
+
+ # Verify BasicAuth was used
+ assert route.called
+ request = route.calls[0].request
+ assert 'authorization' in request.headers
+ # Basic auth should be present
+ assert request.headers['authorization'].startswith('Basic ')
+
+
+@pytest.mark.asyncio
+@respx.mock
+async def test_get_access_token_http_error(oauth_client):
+ """Test handling of HTTP errors during access token exchange."""
+ # Mock HTTP error response
+ respx.post('https://example.com/oauth/token').mock(return_value=httpx.Response(400, text='Bad Request'))
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await oauth_client.get_access_token(code='invalid_code', redirect_uri='https://example.com/callback')
+
+
+@pytest.mark.asyncio
+@respx.mock
+async def test_refresh_token_success(oauth_client):
+ """Test successful token refresh."""
+ mock_token_data = {'access_token': 'refreshed_access_token', 'token_type': 'Bearer', 'expires_in': 3600}
+
+ # Mock the refresh endpoint
+ respx.post('https://example.com/oauth/refresh').mock(return_value=httpx.Response(200, json=mock_token_data))
+
+ result = await oauth_client.refresh_token('refresh_token_123')
+ assert result == mock_token_data
+
+
+@pytest.mark.asyncio
+async def test_refresh_token_missing_endpoint():
+ """Test refresh token when refresh endpoint is not configured."""
+ client = MockOAuth20Client(
+ client_id='test_id',
+ client_secret='test_secret',
+ authorize_endpoint='https://example.com/auth',
+ access_token_endpoint='https://example.com/token',
+ )
+
+ with pytest.raises(RefreshTokenError, match='refresh token address is missing'):
+ await client.refresh_token('refresh_token_123')
+
+
+@pytest.mark.asyncio
+@respx.mock
+async def test_refresh_token_http_error(oauth_client):
+ """Test handling of HTTP errors during token refresh."""
+ # Mock HTTP error response
+ respx.post('https://example.com/oauth/refresh').mock(return_value=httpx.Response(401, text='Unauthorized'))
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await oauth_client.refresh_token('invalid_refresh_token')
+
+
+@pytest.mark.asyncio
+@respx.mock
+async def test_revoke_token_success(oauth_client):
+ """Test successful token revocation."""
+ # Mock successful revocation response
+ respx.post('https://example.com/oauth/revoke').mock(return_value=httpx.Response(200, text='OK'))
+
+ # Should not raise any exception for successful revocation
+ await oauth_client.revoke_token('access_token_123')
+
+
+@pytest.mark.asyncio
+@respx.mock
+async def test_revoke_token_with_type_hint(oauth_client):
+ """Test token revocation with token type hint."""
+ # Mock the revoke endpoint and capture the request
+ route = respx.post('https://example.com/oauth/revoke').mock(return_value=httpx.Response(200, text='OK'))
+
+ await oauth_client.revoke_token(token='refresh_token_123', token_type_hint='refresh_token')
+
+ # Verify token_type_hint was included in the request
+ assert route.called
+ request_data = route.calls[0].request.content.decode()
+ assert 'token_type_hint=refresh_token' in request_data
+
+
+@pytest.mark.asyncio
+async def test_revoke_token_missing_endpoint():
+ """Test token revocation when revoke endpoint is not configured."""
+ client = MockOAuth20Client(
+ client_id='test_id',
+ client_secret='test_secret',
+ authorize_endpoint='https://example.com/auth',
+ access_token_endpoint='https://example.com/token',
+ )
+
+ with pytest.raises(RevokeTokenError, match='revoke token address is missing'):
+ await client.revoke_token('access_token_123')
+
+
+@pytest.mark.asyncio
+@respx.mock
+async def test_revoke_token_http_error(oauth_client):
+ """Test handling of HTTP errors during token revocation."""
+ # Mock HTTP error response
+ respx.post('https://example.com/oauth/revoke').mock(return_value=httpx.Response(400, text='Bad Request'))
+
+ with pytest.raises(HTTPXOAuth20Error):
+ await oauth_client.revoke_token('invalid_token')
+
+
+def test_raise_httpx_oauth20_errors_success():
+ """Test successful HTTP response validation."""
+ mock_response = Mock()
+ mock_response.raise_for_status.return_value = None
+
+ # Should not raise any exception
+ OAuth20Base.raise_httpx_oauth20_errors(mock_response)
+
+
+def test_raise_httpx_oauth20_errors_http_status_error():
+ """Test handling of HTTP status errors."""
+ mock_response = Mock()
+ mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
+ 'Not Found', request=None, response=mock_response
+ )
+
+ with pytest.raises(HTTPXOAuth20Error):
+ OAuth20Base.raise_httpx_oauth20_errors(mock_response)
+
+
+def test_raise_httpx_oauth20_errors_network_error():
+ """Test handling of network errors."""
+ # Test with a mock response that will raise RequestError when raise_for_status is called
+ mock_response = Mock()
+ mock_response.raise_for_status.side_effect = httpx.RequestError('Network error')
+
+ with pytest.raises(HTTPXOAuth20Error):
+ OAuth20Base.raise_httpx_oauth20_errors(mock_response)
+
+
+def test_get_json_result_success():
+ """Test successful JSON result parsing."""
+ mock_response = Mock()
+ mock_response.json.return_value = {'key': 'value'}
+
+ result = OAuth20Base.get_json_result(mock_response, err_class=AccessTokenError)
+ assert result == {'key': 'value'}
+
+
+def test_get_json_result_invalid_json():
+ """Test handling of invalid JSON response."""
+ mock_response = Mock()
+ mock_response.json.side_effect = json.JSONDecodeError('Invalid JSON', '', 0)
+
+ with pytest.raises(AccessTokenError, match='Result serialization failed'):
+ OAuth20Base.get_json_result(mock_response, err_class=AccessTokenError)
+
+
+def test_abstract_method():
+ """Test that get_userinfo is properly abstract."""
+ with pytest.raises(TypeError):
+ OAuth20Base(
+ client_id='test',
+ client_secret='test',
+ authorize_endpoint='https://example.com/auth',
+ access_token_endpoint='https://example.com/token',
+ )
+
+
+def test_oauth_base_inheritance():
+ """Test that OAuth20Base is properly abstract."""
+ from abc import ABC
+
+ assert issubclass(OAuth20Base, ABC)
+
+
+@pytest.mark.asyncio
+async def test_get_userinfo_implementation():
+ """Test that concrete implementation of get_userinfo works."""
+ client = MockOAuth20Client(
+ client_id='test',
+ client_secret='test',
+ authorize_endpoint='https://example.com/auth',
+ access_token_endpoint='https://example.com/token',
+ )
+
+ result = await client.get_userinfo('test_token')
+ assert result == {'user_id': 'test_user', 'access_token': 'test_token'}
diff --git a/uv.lock b/uv.lock
index 4291c8d..2bc8536 100644
--- a/uv.lock
+++ b/uv.lock
@@ -167,14 +167,14 @@ wheels = [
[[package]]
name = "click"
-version = "8.3.0"
+version = "8.2.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/46/61/de6cd827efad202d7057d93e0fed9294b96952e188f7384832791c7b2254/click-8.3.0.tar.gz", hash = "sha256:e7b8232224eba16f4ebe410c25ced9f7875cb5f3263ffc93cc3e8da705e229c4", size = 276943, upload-time = "2025-09-18T17:32:23.696Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/db/d3/9dcc0f5797f070ec8edf30fbadfb200e71d9db6b84d211e3b2085a7589a0/click-8.3.0-py3-none-any.whl", hash = "sha256:9b9f285302c6e3064f4330c05f05b81945b2a39544279343e6e7c5f27a9baddc", size = 107295, upload-time = "2025-09-18T17:32:22.42Z" },
+ { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" },
]
[[package]]
@@ -230,6 +230,7 @@ dependencies = [
[package.dev-dependencies]
dev = [
+ { name = "click" },
{ name = "fastapi" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
@@ -249,6 +250,7 @@ requires-dist = [{ name = "httpx", specifier = ">=0.18.0" }]
[package.metadata.requires-dev]
dev = [
+ { name = "click", specifier = "==8.2.1" },
{ name = "fastapi", specifier = ">=0.119.0" },
{ name = "pytest", specifier = ">=8.4.0" },
{ name = "pytest-asyncio", specifier = ">=1.2.0" },