diff --git a/docs/fastapi.md b/docs/fastapi.md deleted file mode 100644 index 8c20ac7..0000000 --- a/docs/fastapi.md +++ /dev/null @@ -1,46 +0,0 @@ -提供的实用程序可简化 FastAPI 中 OAuth2 流程的集成 - -## `FastAPIOAuth20` - -依赖关系可调用,用于处理授权回调,它读取查询参数并返回访问令牌和状态 - -```python -from fastapi import FastAPI, Depends -from fastapi_oauth20 import FastAPIOAuth20, LinuxDoOAuth20 - -client = LinuxDoOAuth20("CLIENT_ID", "CLIENT_SECRET") -linuxdo_oauth2_callback = FastAPIOAuth20(client, "oauth2-callback") - -app = FastAPI() - - -@app.get("/oauth2-callback", name="oauth-callback") -async def oauth2_callback(access_token_state=Depends(linuxdo_oauth2_callback)): - token, state = access_token_state - # Do something useful -``` - -## 自定义异常 - -如果回调逻辑内部发生错误(用户拒绝访问、授权代码无效......),依赖关系将引发 `OAuth20AuthorizeCallbackError` 错误 - -它继承自 FastAPI 的 [HTTPException](https://fastapi.tiangolo.com/reference/exceptions/#fastapi.HTTPException),因此默认的 -FastAPI 异常处理程序会自动对其进行处理。您可以通过为 `OAuth20AuthorizeCallbackError` 实现自己的异常处理程序来自定义此行为 - -```python -from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse -from fastapi_oauth20.callback import OAuth20AuthorizeCallbackError - -app = FastAPI() - - -@app.exception_handler(OAuth20AuthorizeCallbackError) -async def oauth2_authorize_callback_error_handler(request: Request, exc: OAuth20AuthorizeCallbackError): - detail = exc.detail - status_code = exc.status_code - return JSONResponse( - status_code=status_code, - content={"message": "The OAuth2 callback failed", "detail": detail}, - ) -``` diff --git a/docs/usage.md b/docs/usage.md index 9ac34ad..4a95f3a 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -1,3 +1,5 @@ +from fastapi_oauth20 import FastAPIOAuth20 + # 使用指南 本指南介绍如何将 FastAPI OAuth2.0 库与各种 OAuth2 提供程序一起使用。 @@ -19,12 +21,6 @@ github_client = GitHubOAuth20( client_id="your_github_client_id", client_secret="your_github_client_secret" ) - -# 初始化 Google OAuth2 客户端 -google_client = GoogleOAuth20( - client_id="your_google_client_id", - client_secret="your_google_client_secret" -) ``` ### 2. 创建 FastAPI OAuth2 依赖 @@ -35,35 +31,22 @@ github_oauth = FastAPIOAuth20( client=github_client, redirect_uri="http://localhost:8000/auth/github/callback" ) - -google_oauth = FastAPIOAuth20( - client=google_client, - redirect_uri="http://localhost:8000/auth/google/callback" -) ``` ### 3. 定义授权端点 ```python -@app.get("/auth/{provider}") -async def auth_provider(provider: str): +@app.get("/oauth2/github") +async def oauth2_github(): """重定向用户到 OAuth2 提供商进行授权""" # 生成安全的 state 参数用于 CSRF 保护 state = secrets.token_urlsafe(32) - if provider == "github": - auth_url = await github_client.get_authorization_url( - redirect_uri="http://localhost:8000/auth/github/callback", - state=state - ) - elif provider == "google": - auth_url = await google_client.get_authorization_url( - redirect_uri="http://localhost:8000/auth/google/callback", - state=state - ) - else: - raise HTTPException(status_code=404, detail="不支持的提供商") + auth_url = await github_client.get_authorization_url( + redirect_uri="http://localhost:8000/auth/github/callback", + state=state + ) return RedirectResponse(url=auth_url) ``` @@ -71,12 +54,9 @@ async def auth_provider(provider: str): ### 4. 处理 OAuth 回调 ```python -from typing import Tuple, Dict, Any - - -@app.get("/auth/github/callback") -async def github_callback( - oauth_result: Tuple[Dict[str, Any], str] = Depends(github_oauth) +@app.get("/oauth2/github/callback") +async def oauth2_github_callback( + oauth_result: Annotated[FastAPIOAuth20, Depends(github_oauth)] ): """处理 GitHub OAuth 回调""" token_data, state = oauth_result @@ -89,142 +69,73 @@ async def github_callback( "access_token": token_data["access_token"], "state": state } - - -@app.get("/auth/google/callback") -async def google_callback( - oauth_result: Tuple[Dict[str, Any], str] = Depends(google_oauth) -): - """处理 Google OAuth 回调""" - token_data, state = oauth_result - - # 获取用户信息 - user_info = await google_client.get_userinfo(token_data["access_token"]) - - return { - "user": user_info, - "access_token": token_data["access_token"], - "state": state - } ``` -## 高级用法 - -### PKCE (Proof Key for Code Exchange) - -对于公共客户端(移动应用、SPA),使用 PKCE 增强安全性: - -```python -import base64 -import hashlib -import secrets - - -def generate_pkce_challenge(): - """生成 PKCE 代码验证器和挑战码""" - code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=') - code_challenge = base64.urlsafe_b64encode( - hashlib.sha256(code_verifier.encode('utf-8')).digest() - ).decode('utf-8').rstrip('=') - return code_verifier, code_challenge - - -@app.get("/auth/github/pkce") -async def github_auth_pkce(): - """GitHub OAuth with PKCE""" - code_verifier, code_challenge = generate_pkce_challenge() - state = secrets.token_urlsafe(32) - - # 在实际应用中,应该将 code_verifier 和 state 存储在会话或数据库中 - # 这里为了演示目的简化处理 - - auth_url = await github_client.get_authorization_url( - redirect_uri="http://localhost:8000/auth/github/callback", - state=state, - code_challenge=code_challenge, - code_challenge_method="S256" - ) - - return RedirectResponse(url=auth_url) -``` - -### 令牌刷新 +## 令牌刷新 某些提供商支持刷新令牌来延长会话: ```python -@app.post("/auth/refresh") -async def refresh_token(refresh_token: str, provider: str): +@app.post("/oauth2/refresh") +async def oauth2_refresh_token(refresh_token: str): """使用刷新令牌获取新的访问令牌""" - - if provider == "github": - # GitHub 不支持 OAuth 刷新令牌 - raise HTTPException(status_code=400, detail="GitHub 不支持令牌刷新") - elif provider == "google": - try: - new_tokens = await google_client.refresh_token(refresh_token) - return {"access_token": new_tokens["access_token"]} - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) + try: + new_tokens = await google_client.refresh_token(refresh_token) + return {"access_token": new_tokens["access_token"]} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) ``` -### 令牌撤销 +## 令牌撤销 用户登出时撤销令牌: ```python -@app.post("/auth/revoke") -async def revoke_token(access_token: str, provider: str): +@app.post("/oauth2/revoke") +async def oauth2_revoke_token(access_token: str): """撤销访问令牌""" - - if provider == "google": - try: - await google_client.revoke_token(access_token) - return {"message": "令牌撤销成功"} - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) - else: - raise HTTPException(status_code=400, detail="该提供商不支持令牌撤销") + try: + await google_client.revoke_token(access_token) + return {"message": "令牌撤销成功"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) ``` ## 错误处理 -库提供了全面的错误处理: +如果回调逻辑内部发生错误(用户拒绝访问、授权代码无效......),依赖关系将引发 `OAuth20AuthorizeCallbackError` 错误 -```python -from fastapi_oauth20.errors import ( - OAuth20AuthorizeCallbackError, - AccessTokenError, - GetUserInfoError -) +它继承自 FastAPI 的 [HTTPException](https://fastapi.tiangolo.com/reference/exceptions/#fastapi.HTTPException),因此默认的 +FastAPI 异常处理程序会自动对其进行处理。您可以通过为 `OAuth20AuthorizeCallbackError` 实现自己的异常处理程序来自定义此行为 +```python +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from fastapi_oauth20 import OAuth20AuthorizeCallbackError -@app.exception_handler(OAuth20AuthorizeCallbackError) -async def oauth_callback_error_handler(request: Request, exc: OAuth20AuthorizeCallbackError): - """处理 OAuth 回调错误""" - return { - "error": "OAuth 授权失败", - "detail": exc.detail, - "status_code": exc.status_code - } +app = FastAPI() -@app.exception_handler(AccessTokenError) -async def access_token_error_handler(request: Request, exc: AccessTokenError): - """处理访问令牌错误""" - return { - "error": "访问令牌交换失败", - "detail": exc.msg - } +@app.exception_handler(OAuth20AuthorizeCallbackError) +async def oauth2_authorize_callback_error_handler(request: Request, exc: OAuth20AuthorizeCallbackError): + detail = exc.detail + status_code = exc.status_code + return JSONResponse( + status_code=status_code, + content={"message": "The OAuth2 callback failed", "detail": detail}, + ) ``` ## 完整示例 ```python +from typing import Annotated + from fastapi import FastAPI, Depends, HTTPException from fastapi.responses import RedirectResponse from fastapi_oauth20 import GitHubOAuth20, FastAPIOAuth20 -from fastapi_oauth20.errors import OAuth20AuthorizeCallbackError +from fastapi_oauth20 import OAuth20AuthorizeCallbackError + import secrets app = FastAPI() @@ -242,8 +153,8 @@ github_oauth = FastAPIOAuth20( ) -@app.get("/auth/github") -async def github_auth(): +@app.get("/oauth2/github") +async def oauth2_github(): """GitHub 授权入口""" state = secrets.token_urlsafe(32) auth_url = await github_client.get_authorization_url( @@ -253,9 +164,9 @@ async def github_auth(): return RedirectResponse(url=auth_url) -@app.get("/auth/github/callback") -async def github_callback( - oauth_result: tuple = Depends(github_oauth) +@app.get("/oauth2/github/callback") +async def oauth2_github_callback( + oauth_result: Annotated[FastAPIOAuth20, Depends(github_oauth)] ): """GitHub 授权回调""" token_data, state = oauth_result diff --git a/mkdocs.yml b/mkdocs.yml index 3fb5580..a78e217 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -10,8 +10,6 @@ nav: - 安装: install.md - 名词解释: explanation.md - 用法: usage.md - - 集成: - - FastAPI: fastapi.md - 客户端状态: status.md - 客户端申请: - LinuxDo: clients/linuxdo.md diff --git a/tests/clients/test_feishu.py b/tests/clients/test_feishu.py index a68b8e7..399d890 100644 --- a/tests/clients/test_feishu.py +++ b/tests/clients/test_feishu.py @@ -4,7 +4,7 @@ import pytest import respx -from fastapi_oauth20.clients.feishu import FeiShuOAuth20 +from fastapi_oauth20 import FeiShuOAuth20 from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error from fastapi_oauth20.oauth20 import OAuth20Base from tests.conftest import ( @@ -16,9 +16,13 @@ mock_user_info_response, ) -# Constants specific to this test file -CUSTOM_CLIENT_ID = 'custom_id' -CUSTOM_CLIENT_SECRET = 'custom_secret' +FEISHU_USER_INFO_URL = 'https://passport.feishu.cn/suite/passport/oauth/userinfo' + + +@pytest.fixture +def feishu_client(): + """Create FeiShu OAuth2 client instance for testing.""" + return FeiShuOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) class TestFeiShuOAuth20: @@ -39,9 +43,9 @@ def test_feishu_client_initialization(self, feishu_client): def test_feishu_client_initialization_with_custom_credentials(self): """Test FeiShu client initialization with custom credentials.""" - client = FeiShuOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET) - assert client.client_id == CUSTOM_CLIENT_ID - assert client.client_secret == CUSTOM_CLIENT_SECRET + client = FeiShuOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + assert client.client_id == TEST_CLIENT_ID + assert client.client_secret == TEST_CLIENT_SECRET def test_feishu_client_inheritance(self, feishu_client): """Test that FeiShu client properly inherits from OAuth20Base.""" @@ -81,11 +85,7 @@ def test_feishu_client_multiple_instances(self): async def test_get_userinfo_success(self, feishu_client): """Test successful user info retrieval from FeiShu API.""" mock_user_data = create_mock_user_data('feishu') - mock_user_info_response( - respx, - {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'}, - mock_user_data, - ) + mock_user_info_response(respx, FEISHU_USER_INFO_URL, mock_user_data) result = await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == mock_user_data @@ -95,11 +95,7 @@ async def test_get_userinfo_success(self, feishu_client): async def test_get_userinfo_with_different_access_token(self, feishu_client): """Test user info retrieval with different access tokens.""" mock_user_data = create_mock_user_data('feishu', user_id='user_789', name='Another User') - mock_user_info_response( - respx, - {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'}, - mock_user_data, - ) + mock_user_info_response(respx, FEISHU_USER_INFO_URL, mock_user_data) result = await feishu_client.get_userinfo('different_token') assert result == mock_user_data @@ -108,9 +104,7 @@ async def test_get_userinfo_with_different_access_token(self, feishu_client): @respx.mock async def test_get_userinfo_empty_response(self, feishu_client): """Test handling of empty user info response.""" - mock_user_info_response( - respx, {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'}, {} - ) + mock_user_info_response(respx, FEISHU_USER_INFO_URL, {}) result = await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == {} @@ -120,11 +114,7 @@ async def test_get_userinfo_empty_response(self, feishu_client): async def test_get_userinfo_partial_data(self, feishu_client): """Test handling of partial user info response.""" partial_data = {'user_id': 'test_user', 'name': 'Test User'} - mock_user_info_response( - respx, - {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'}, - partial_data, - ) + mock_user_info_response(respx, FEISHU_USER_INFO_URL, partial_data) result = await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == partial_data @@ -134,11 +124,7 @@ async def test_get_userinfo_partial_data(self, feishu_client): async def test_get_userinfo_authorization_header(self, feishu_client): """Test that authorization header is correctly formatted.""" mock_user_data = {'user_id': 'test_user'} - route = mock_user_info_response( - respx, - {'name': 'feishu', 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo'}, - mock_user_data, - ) + route = mock_user_info_response(respx, FEISHU_USER_INFO_URL, mock_user_data) await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) @@ -151,9 +137,7 @@ async def test_get_userinfo_authorization_header(self, feishu_client): @respx.mock async def test_get_userinfo_http_error_401(self, feishu_client): """Test handling of 401 HTTP error when getting user info.""" - respx.get('https://passport.feishu.cn/suite/passport/oauth/userinfo').mock( - return_value=httpx.Response(401, text='Unauthorized') - ) + respx.get(FEISHU_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) with pytest.raises(HTTPXOAuth20Error): await feishu_client.get_userinfo(INVALID_TOKEN) @@ -162,9 +146,7 @@ async def test_get_userinfo_http_error_401(self, feishu_client): @respx.mock async def test_get_userinfo_http_error_403(self, feishu_client): """Test handling of 403 HTTP error when getting user info.""" - respx.get('https://passport.feishu.cn/suite/passport/oauth/userinfo').mock( - return_value=httpx.Response(403, text='Forbidden') - ) + respx.get(FEISHU_USER_INFO_URL).mock(return_value=httpx.Response(403, text='Forbidden')) with pytest.raises(HTTPXOAuth20Error): await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) @@ -173,9 +155,7 @@ async def test_get_userinfo_http_error_403(self, feishu_client): @respx.mock async def test_get_userinfo_http_error_500(self, feishu_client): """Test handling of 500 HTTP error when getting user info.""" - respx.get('https://passport.feishu.cn/suite/passport/oauth/userinfo').mock( - return_value=httpx.Response(500, text='Internal Server Error') - ) + respx.get(FEISHU_USER_INFO_URL).mock(return_value=httpx.Response(500, text='Internal Server Error')) with pytest.raises(HTTPXOAuth20Error): await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) @@ -184,9 +164,7 @@ async def test_get_userinfo_http_error_500(self, feishu_client): @respx.mock async def test_get_userinfo_invalid_json(self, feishu_client): """Test handling of invalid JSON response.""" - respx.get('https://passport.feishu.cn/suite/passport/oauth/userinfo').mock( - return_value=httpx.Response(200, text='invalid json') - ) + respx.get(FEISHU_USER_INFO_URL).mock(return_value=httpx.Response(200, text='invalid json')) with pytest.raises(GetUserInfoError): await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) diff --git a/tests/clients/test_gitee.py b/tests/clients/test_gitee.py index 7afa4a9..cc6e516 100644 --- a/tests/clients/test_gitee.py +++ b/tests/clients/test_gitee.py @@ -4,7 +4,7 @@ import pytest import respx -from fastapi_oauth20.clients.gitee import GiteeOAuth20 +from fastapi_oauth20 import GiteeOAuth20 from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error from fastapi_oauth20.oauth20 import OAuth20Base from tests.conftest import ( @@ -16,9 +16,13 @@ mock_user_info_response, ) -# Constants specific to this test file -CUSTOM_CLIENT_ID = 'custom_id' -CUSTOM_CLIENT_SECRET = 'custom_secret' +GITEE_USER_INFO_URL = 'https://gitee.com/api/v5/user' + + +@pytest.fixture +def gitee_client(): + """Create Gitee OAuth2 client instance for testing.""" + return GiteeOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) class TestGiteeOAuth20: @@ -35,9 +39,9 @@ def test_gitee_client_initialization(self, gitee_client): def test_gitee_client_initialization_with_custom_credentials(self): """Test Gitee client initialization with custom credentials.""" - client = GiteeOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET) - assert client.client_id == CUSTOM_CLIENT_ID - assert client.client_secret == CUSTOM_CLIENT_SECRET + client = GiteeOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + assert client.client_id == TEST_CLIENT_ID + assert client.client_secret == TEST_CLIENT_SECRET def test_gitee_client_inheritance(self, gitee_client): """Test that Gitee client properly inherits from OAuth20Base.""" @@ -67,9 +71,7 @@ def test_gitee_client_endpoint_urls(self): async def test_get_userinfo_success(self, gitee_client): """Test successful user info retrieval from Gitee API.""" mock_user_data = create_mock_user_data('gitee') - mock_user_info_response( - respx, {'name': 'gitee', 'user_info_url': 'https://gitee.com/api/v5/user'}, mock_user_data - ) + mock_user_info_response(respx, GITEE_USER_INFO_URL, mock_user_data) result = await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == mock_user_data @@ -79,9 +81,7 @@ async def test_get_userinfo_success(self, gitee_client): async def test_get_userinfo_authorization_header(self, gitee_client): """Test that authorization header is correctly formatted.""" mock_user_data = {'id': 'test_user'} - route = mock_user_info_response( - respx, {'name': 'gitee', 'user_info_url': 'https://gitee.com/api/v5/user'}, mock_user_data - ) + route = mock_user_info_response(respx, GITEE_USER_INFO_URL, mock_user_data) await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) @@ -94,7 +94,7 @@ async def test_get_userinfo_authorization_header(self, gitee_client): @respx.mock async def test_get_userinfo_http_error(self, gitee_client): """Test handling of HTTP errors when getting user info.""" - respx.get('https://gitee.com/api/v5/user').mock(return_value=httpx.Response(401, text='Unauthorized')) + respx.get(GITEE_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) with pytest.raises(HTTPXOAuth20Error): await gitee_client.get_userinfo(INVALID_TOKEN) @@ -103,7 +103,7 @@ async def test_get_userinfo_http_error(self, gitee_client): @respx.mock async def test_get_userinfo_empty_response(self, gitee_client): """Test handling of empty user info response.""" - mock_user_info_response(respx, {'name': 'gitee', 'user_info_url': 'https://gitee.com/api/v5/user'}, {}) + mock_user_info_response(respx, GITEE_USER_INFO_URL, {}) result = await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == {} @@ -113,9 +113,7 @@ async def test_get_userinfo_empty_response(self, gitee_client): async def test_get_userinfo_partial_data(self, gitee_client): """Test handling of partial user info response.""" partial_data = {'id': 123456, 'login': 'testuser'} - mock_user_info_response( - respx, {'name': 'gitee', 'user_info_url': 'https://gitee.com/api/v5/user'}, partial_data - ) + mock_user_info_response(respx, GITEE_USER_INFO_URL, partial_data) result = await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == partial_data @@ -124,7 +122,7 @@ async def test_get_userinfo_partial_data(self, gitee_client): @respx.mock async def test_get_userinfo_invalid_json(self, gitee_client): """Test handling of invalid JSON response.""" - respx.get('https://gitee.com/api/v5/user').mock(return_value=httpx.Response(200, text='invalid json')) + respx.get(GITEE_USER_INFO_URL).mock(return_value=httpx.Response(200, text='invalid json')) with pytest.raises(GetUserInfoError): await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) diff --git a/tests/clients/test_github.py b/tests/clients/test_github.py index acd13d2..064ec57 100644 --- a/tests/clients/test_github.py +++ b/tests/clients/test_github.py @@ -4,7 +4,7 @@ import pytest import respx -from fastapi_oauth20.clients.github import GitHubOAuth20 +from fastapi_oauth20 import GitHubOAuth20 from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error from fastapi_oauth20.oauth20 import OAuth20Base from tests.conftest import ( @@ -16,9 +16,15 @@ mock_user_info_response, ) -# Constants specific to this test file -CUSTOM_CLIENT_ID = 'custom_id' -CUSTOM_CLIENT_SECRET = 'custom_secret' +GITHUB_TOKEN_URL = 'https://github.com/login/oauth/access_token' +GITHUB_USER_INFO_URL = 'https://api.github.com/user' +GITHUB_EMAILS_URL = 'https://api.github.com/user/emails' + + +@pytest.fixture +def github_client(): + """Create GitHub OAuth2 client instance for testing.""" + return GitHubOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) class TestGitHubOAuth20: @@ -34,9 +40,9 @@ def test_github_client_initialization(self, github_client): def test_github_client_initialization_with_custom_credentials(self): """Test GitHub client initialization with custom credentials.""" - client = GitHubOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET) - assert client.client_id == CUSTOM_CLIENT_ID - assert client.client_secret == CUSTOM_CLIENT_SECRET + client = GitHubOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + assert client.client_id == TEST_CLIENT_ID + assert client.client_secret == TEST_CLIENT_SECRET def test_github_client_inheritance(self, github_client): """Test that GitHub client properly inherits from OAuth20Base.""" @@ -75,9 +81,7 @@ def test_github_client_multiple_instances(self): async def test_get_userinfo_success_with_email(self, github_client): """Test successful user info retrieval from GitHub API with email included.""" mock_user_data = create_mock_user_data('github') - mock_user_info_response( - respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, mock_user_data - ) + mock_user_info_response(respx, GITHUB_USER_INFO_URL, mock_user_data) result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == mock_user_data @@ -87,12 +91,10 @@ async def test_get_userinfo_success_with_email(self, github_client): async def test_get_userinfo_success_without_email(self, github_client): """Test successful user info retrieval from GitHub API without email.""" mock_user_data = create_mock_user_data('github', email=None) - mock_user_info_response( - respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, mock_user_data - ) + mock_user_info_response(respx, GITHUB_USER_INFO_URL, mock_user_data) # Mock emails endpoint emails_data = [{'email': 'test@example.com', 'primary': True}] - respx.get('https://api.github.com/user/emails').mock(return_value=httpx.Response(200, json=emails_data)) + respx.get(GITHUB_EMAILS_URL).mock(return_value=httpx.Response(200, json=emails_data)) result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) assert result['login'] == mock_user_data['login'] @@ -103,9 +105,7 @@ async def test_get_userinfo_success_without_email(self, github_client): async def test_get_userinfo_with_different_access_token(self, github_client): """Test user info retrieval with different access tokens.""" mock_user_data = create_mock_user_data('github', id=789, login='different_user') - mock_user_info_response( - respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, mock_user_data - ) + mock_user_info_response(respx, GITHUB_USER_INFO_URL, mock_user_data) result = await github_client.get_userinfo('different_token') assert result == mock_user_data @@ -115,9 +115,7 @@ async def test_get_userinfo_with_different_access_token(self, github_client): async def test_get_userinfo_authorization_header(self, github_client): """Test that authorization header is correctly formatted.""" mock_user_data = {'id': 'test_user', 'email': 'test@example.com'} # Include email to avoid emails endpoint call - route = mock_user_info_response( - respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, mock_user_data - ) + route = mock_user_info_response(respx, GITHUB_USER_INFO_URL, mock_user_data) await github_client.get_userinfo(TEST_ACCESS_TOKEN) @@ -130,7 +128,7 @@ async def test_get_userinfo_authorization_header(self, github_client): @respx.mock async def test_get_userinfo_http_error_401(self, github_client): """Test handling of 401 HTTP error when getting user info.""" - respx.get('https://api.github.com/user').mock(return_value=httpx.Response(401, text='Unauthorized')) + respx.get(GITHUB_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) with pytest.raises(HTTPXOAuth20Error): await github_client.get_userinfo(INVALID_TOKEN) @@ -139,7 +137,7 @@ async def test_get_userinfo_http_error_401(self, github_client): @respx.mock async def test_get_userinfo_http_error_403(self, github_client): """Test handling of 403 HTTP error when getting user info.""" - respx.get('https://api.github.com/user').mock(return_value=httpx.Response(403, text='Forbidden')) + respx.get(GITHUB_USER_INFO_URL).mock(return_value=httpx.Response(403, text='Forbidden')) with pytest.raises(HTTPXOAuth20Error): await github_client.get_userinfo(TEST_ACCESS_TOKEN) @@ -148,7 +146,7 @@ async def test_get_userinfo_http_error_403(self, github_client): @respx.mock async def test_get_userinfo_http_error_500(self, github_client): """Test handling of 500 HTTP error when getting user info.""" - respx.get('https://api.github.com/user').mock(return_value=httpx.Response(500, text='Internal Server Error')) + respx.get(GITHUB_USER_INFO_URL).mock(return_value=httpx.Response(500, text='Internal Server Error')) with pytest.raises(HTTPXOAuth20Error): await github_client.get_userinfo(TEST_ACCESS_TOKEN) @@ -157,7 +155,7 @@ async def test_get_userinfo_http_error_500(self, github_client): @respx.mock async def test_get_userinfo_invalid_json(self, github_client): """Test handling of invalid JSON response.""" - respx.get('https://api.github.com/user').mock(return_value=httpx.Response(200, text='invalid json')) + respx.get(GITHUB_USER_INFO_URL).mock(return_value=httpx.Response(200, text='invalid json')) with pytest.raises(GetUserInfoError): await github_client.get_userinfo(TEST_ACCESS_TOKEN) @@ -166,10 +164,10 @@ async def test_get_userinfo_invalid_json(self, github_client): @respx.mock async def test_get_userinfo_empty_response(self, github_client): """Test handling of empty user info response.""" - mock_user_info_response(respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, {}) + mock_user_info_response(respx, GITHUB_USER_INFO_URL, {}) # Mock emails endpoint since empty response will trigger email lookup emails_data = [{'email': 'test@example.com', 'primary': True}] - respx.get('https://api.github.com/user/emails').mock(return_value=httpx.Response(200, json=emails_data)) + respx.get(GITHUB_EMAILS_URL).mock(return_value=httpx.Response(200, json=emails_data)) result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) assert result['email'] == 'test@example.com' @@ -183,7 +181,7 @@ async def test_get_userinfo_partial_data(self, github_client): 'login': 'testuser', 'email': 'test@example.com', } # Add email to avoid emails endpoint call - mock_user_info_response(respx, {'name': 'github', 'user_info_url': 'https://api.github.com/user'}, partial_data) + mock_user_info_response(respx, GITHUB_USER_INFO_URL, partial_data) result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == partial_data @@ -198,7 +196,7 @@ async def test_get_userinfo_rate_limit(self, github_client): 'documentation_url': 'https://docs.github.com/rest/overview/rate-limits-for-the-rest-api', } - respx.get('https://api.github.com/user').mock(return_value=httpx.Response(403, json=rate_limit_response)) + respx.get(GITHUB_USER_INFO_URL).mock(return_value=httpx.Response(403, json=rate_limit_response)) with pytest.raises(HTTPXOAuth20Error): await github_client.get_userinfo(TEST_ACCESS_TOKEN) diff --git a/tests/clients/test_google.py b/tests/clients/test_google.py index 505db53..e2ce547 100644 --- a/tests/clients/test_google.py +++ b/tests/clients/test_google.py @@ -4,7 +4,7 @@ import pytest import respx -from fastapi_oauth20.clients.google import GoogleOAuth20 +from fastapi_oauth20 import GoogleOAuth20 from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error from fastapi_oauth20.oauth20 import OAuth20Base from tests.conftest import ( @@ -16,9 +16,13 @@ mock_user_info_response, ) -# Constants specific to this test file -CUSTOM_CLIENT_ID = 'custom_id' -CUSTOM_CLIENT_SECRET = 'custom_secret' +GOOGLE_USER_INFO_URL = 'https://www.googleapis.com/oauth2/v1/userinfo' + + +@pytest.fixture +def google_client(): + """Create Google OAuth2 client instance for testing.""" + return GoogleOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) class TestGoogleOAuth20: @@ -36,9 +40,9 @@ def test_google_client_initialization(self, google_client): def test_google_client_initialization_with_custom_credentials(self): """Test Google client initialization with custom credentials.""" - client = GoogleOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET) - assert client.client_id == CUSTOM_CLIENT_ID - assert client.client_secret == CUSTOM_CLIENT_SECRET + client = GoogleOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + assert client.client_id == TEST_CLIENT_ID + assert client.client_secret == TEST_CLIENT_SECRET def test_google_client_inheritance(self, google_client): """Test that Google client properly inherits from OAuth20Base.""" @@ -71,9 +75,7 @@ def test_google_client_endpoint_urls(self): async def test_get_userinfo_success(self, google_client): """Test successful user info retrieval from Google OAuth2 API.""" mock_user_data = create_mock_user_data('google') - mock_user_info_response( - respx, {'name': 'google', 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo'}, mock_user_data - ) + mock_user_info_response(respx, GOOGLE_USER_INFO_URL, mock_user_data) result = await google_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == mock_user_data @@ -83,9 +85,7 @@ async def test_get_userinfo_success(self, google_client): async def test_get_userinfo_authorization_header(self, google_client): """Test that authorization header is correctly formatted.""" mock_user_data = {'id': 'test_user'} - route = mock_user_info_response( - respx, {'name': 'google', 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo'}, mock_user_data - ) + route = mock_user_info_response(respx, GOOGLE_USER_INFO_URL, mock_user_data) await google_client.get_userinfo(TEST_ACCESS_TOKEN) @@ -98,9 +98,7 @@ async def test_get_userinfo_authorization_header(self, google_client): @respx.mock async def test_get_userinfo_http_error(self, google_client): """Test handling of HTTP errors when getting user info.""" - respx.get('https://www.googleapis.com/oauth2/v1/userinfo').mock( - return_value=httpx.Response(401, text='Unauthorized') - ) + respx.get(GOOGLE_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) with pytest.raises(HTTPXOAuth20Error): await google_client.get_userinfo(INVALID_TOKEN) @@ -109,9 +107,7 @@ async def test_get_userinfo_http_error(self, google_client): @respx.mock async def test_get_userinfo_empty_response(self, google_client): """Test handling of empty user info response.""" - mock_user_info_response( - respx, {'name': 'google', 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo'}, {} - ) + mock_user_info_response(respx, GOOGLE_USER_INFO_URL, {}) result = await google_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == {} @@ -121,9 +117,7 @@ async def test_get_userinfo_empty_response(self, google_client): async def test_get_userinfo_partial_data(self, google_client): """Test handling of partial user info response.""" partial_data = {'id': '123456789', 'email': 'test@example.com'} - mock_user_info_response( - respx, {'name': 'google', 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo'}, partial_data - ) + mock_user_info_response(respx, GOOGLE_USER_INFO_URL, partial_data) result = await google_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == partial_data @@ -132,9 +126,7 @@ async def test_get_userinfo_partial_data(self, google_client): @respx.mock async def test_get_userinfo_invalid_json(self, google_client): """Test handling of invalid JSON response.""" - respx.get('https://www.googleapis.com/oauth2/v1/userinfo').mock( - return_value=httpx.Response(200, text='invalid json') - ) + respx.get(GOOGLE_USER_INFO_URL).mock(return_value=httpx.Response(200, text='invalid json')) with pytest.raises(GetUserInfoError): await google_client.get_userinfo(TEST_ACCESS_TOKEN) diff --git a/tests/clients/test_linuxdo.py b/tests/clients/test_linuxdo.py index 1201c35..f5cd3b9 100644 --- a/tests/clients/test_linuxdo.py +++ b/tests/clients/test_linuxdo.py @@ -4,7 +4,7 @@ import pytest import respx -from fastapi_oauth20.clients.linuxdo import LinuxDoOAuth20 +from fastapi_oauth20 import LinuxDoOAuth20 from fastapi_oauth20.errors import HTTPXOAuth20Error from fastapi_oauth20.oauth20 import OAuth20Base from tests.conftest import ( @@ -16,9 +16,13 @@ mock_user_info_response, ) -# Constants specific to this test file -CUSTOM_CLIENT_ID = 'custom_id' -CUSTOM_CLIENT_SECRET = 'custom_secret' +LINUXDO_USER_INFO_URL = 'https://connect.linux.do/api/user' + + +@pytest.fixture +def linuxdo_client(): + """Create LinuxDo OAuth2 client instance for testing.""" + return LinuxDoOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) class TestLinuxDoOAuth20: @@ -36,9 +40,9 @@ def test_linuxdo_client_initialization(self, linuxdo_client): def test_linuxdo_client_initialization_with_custom_credentials(self): """Test LinuxDo client initialization with custom credentials.""" - client = LinuxDoOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET) - assert client.client_id == CUSTOM_CLIENT_ID - assert client.client_secret == CUSTOM_CLIENT_SECRET + client = LinuxDoOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + assert client.client_id == TEST_CLIENT_ID + assert client.client_secret == TEST_CLIENT_SECRET def test_linuxdo_client_inheritance(self, linuxdo_client): """Test that LinuxDo client properly inherits from OAuth20Base.""" @@ -66,9 +70,7 @@ def test_linuxdo_client_endpoint_urls(self): async def test_get_userinfo_success(self, linuxdo_client): """Test successful user info retrieval from LinuxDo API.""" mock_user_data = create_mock_user_data('linuxdo') - mock_user_info_response( - respx, {'name': 'linuxdo', 'user_info_url': 'https://connect.linux.do/api/user'}, mock_user_data - ) + mock_user_info_response(respx, LINUXDO_USER_INFO_URL, mock_user_data) result = await linuxdo_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == mock_user_data @@ -78,9 +80,7 @@ async def test_get_userinfo_success(self, linuxdo_client): async def test_get_userinfo_authorization_header(self, linuxdo_client): """Test that authorization header is correctly formatted.""" mock_user_data = {'id': 'test_user'} - route = mock_user_info_response( - respx, {'name': 'linuxdo', 'user_info_url': 'https://connect.linux.do/api/user'}, mock_user_data - ) + route = mock_user_info_response(respx, LINUXDO_USER_INFO_URL, mock_user_data) await linuxdo_client.get_userinfo(TEST_ACCESS_TOKEN) @@ -93,7 +93,7 @@ async def test_get_userinfo_authorization_header(self, linuxdo_client): @respx.mock async def test_get_userinfo_http_error(self, linuxdo_client): """Test handling of HTTP errors when getting user info.""" - respx.get('https://connect.linux.do/api/user').mock(return_value=httpx.Response(401, text='Unauthorized')) + respx.get(LINUXDO_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) with pytest.raises(HTTPXOAuth20Error): await linuxdo_client.get_userinfo(INVALID_TOKEN) diff --git a/tests/clients/test_oschina.py b/tests/clients/test_oschina.py index 95cc411..b4b22d8 100644 --- a/tests/clients/test_oschina.py +++ b/tests/clients/test_oschina.py @@ -4,7 +4,7 @@ import pytest import respx -from fastapi_oauth20.clients.oschina import OSChinaOAuth20 +from fastapi_oauth20 import OSChinaOAuth20 from fastapi_oauth20.errors import HTTPXOAuth20Error from fastapi_oauth20.oauth20 import OAuth20Base from tests.conftest import ( @@ -16,9 +16,13 @@ mock_user_info_response, ) -# Constants specific to this test file -CUSTOM_CLIENT_ID = 'custom_id' -CUSTOM_CLIENT_SECRET = 'custom_secret' +OSCHINA_USER_INFO_URL = 'https://www.oschina.net/action/openapi/user' + + +@pytest.fixture +def oschina_client(): + """Create OSChina OAuth2 client instance for testing.""" + return OSChinaOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) class TestOSChinaOAuth20: @@ -35,9 +39,9 @@ def test_oschina_client_initialization(self, oschina_client): def test_oschina_client_initialization_with_custom_credentials(self): """Test OSChina client initialization with custom credentials.""" - client = OSChinaOAuth20(client_id=CUSTOM_CLIENT_ID, client_secret=CUSTOM_CLIENT_SECRET) - assert client.client_id == CUSTOM_CLIENT_ID - assert client.client_secret == CUSTOM_CLIENT_SECRET + client = OSChinaOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + assert client.client_id == TEST_CLIENT_ID + assert client.client_secret == TEST_CLIENT_SECRET def test_oschina_client_inheritance(self, oschina_client): """Test that OSChina client properly inherits from OAuth20Base.""" @@ -65,11 +69,7 @@ def test_oschina_client_endpoint_urls(self): async def test_get_userinfo_success(self, oschina_client): """Test successful user info retrieval from OSChina API.""" mock_user_data = create_mock_user_data('oschina') - mock_user_info_response( - respx, - {'name': 'oschina', 'user_info_url': 'https://www.oschina.net/action/openapi/user'}, - mock_user_data, - ) + mock_user_info_response(respx, OSCHINA_USER_INFO_URL, mock_user_data) result = await oschina_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == mock_user_data @@ -79,11 +79,7 @@ async def test_get_userinfo_success(self, oschina_client): async def test_get_userinfo_authorization_header(self, oschina_client): """Test that authorization header is correctly formatted.""" mock_user_data = {'id': 'test_user'} - route = mock_user_info_response( - respx, - {'name': 'oschina', 'user_info_url': 'https://www.oschina.net/action/openapi/user'}, - mock_user_data, - ) + route = mock_user_info_response(respx, OSCHINA_USER_INFO_URL, mock_user_data) await oschina_client.get_userinfo(TEST_ACCESS_TOKEN) @@ -96,9 +92,7 @@ async def test_get_userinfo_authorization_header(self, oschina_client): @respx.mock async def test_get_userinfo_http_error(self, oschina_client): """Test handling of HTTP errors when getting user info.""" - respx.get('https://www.oschina.net/action/openapi/user').mock( - return_value=httpx.Response(401, text='Unauthorized') - ) + respx.get(OSCHINA_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) with pytest.raises(HTTPXOAuth20Error): await oschina_client.get_userinfo(INVALID_TOKEN) diff --git a/tests/conftest.py b/tests/conftest.py index 80e6a66..9c1a48f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,147 +1,12 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - -""" -Global constants and shared fixtures for fastapi-oauth20 tests. -""" - import httpx -import pytest - -from fastapi import Depends, FastAPI -from fastapi.testclient import TestClient - -from fastapi_oauth20.clients.feishu import FeiShuOAuth20 -from fastapi_oauth20.clients.gitee import GiteeOAuth20 -from fastapi_oauth20.clients.github import GitHubOAuth20 -from fastapi_oauth20.clients.google import GoogleOAuth20 -from fastapi_oauth20.clients.linuxdo import LinuxDoOAuth20 -from fastapi_oauth20.clients.oschina import OSChinaOAuth20 TEST_CLIENT_ID = 'test_client_id' TEST_CLIENT_SECRET = 'test_client_secret' TEST_ACCESS_TOKEN = 'test_access_token' INVALID_TOKEN = 'invalid_token' TEST_STATE = 'test_state' -OAUTH_PROVIDERS = [ - { - 'name': 'github', - 'client_class': GitHubOAuth20, - 'token_url': 'https://github.com/login/oauth/access_token', - 'user_info_url': 'https://api.github.com/user', - 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'scope': 'user'}, - 'redirect_uri': 'http://localhost:8000/auth/github/callback', - }, - { - 'name': 'google', - 'client_class': GoogleOAuth20, - 'token_url': 'https://oauth2.googleapis.com/token', - 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo', - 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'Bearer', 'expires_in': 3600}, - 'redirect_uri': 'http://localhost:8000/auth/google/callback', - }, - { - 'name': 'gitee', - 'client_class': GiteeOAuth20, - 'token_url': 'https://gitee.com/oauth/token', - 'user_info_url': 'https://gitee.com/api/v5/user', - 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'scope': 'user_info'}, - 'redirect_uri': 'http://localhost:8000/auth/gitee/callback', - }, - { - 'name': 'feishu', - 'client_class': FeiShuOAuth20, - 'token_url': 'https://passport.feishu.cn/suite/passport/oauth/token', - 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo', - 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'expires_in': 3600}, - 'redirect_uri': 'http://localhost:8000/auth/feishu/callback', - }, - { - 'name': 'linuxdo', - 'client_class': LinuxDoOAuth20, - 'token_url': 'https://connect.linux.do/oauth2/token', - 'user_info_url': 'https://connect.linux.do/api/user', - 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'expires_in': 3600}, - 'redirect_uri': 'http://localhost:8000/auth/linuxdo/callback', - }, - { - 'name': 'oschina', - 'client_class': OSChinaOAuth20, - 'token_url': 'https://www.oschina.net/action/openapi/token', - 'user_info_url': 'https://www.oschina.net/action/openapi/user', - 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'expires_in': 3600}, - 'redirect_uri': 'http://localhost:8000/auth/oschina/callback', - }, -] - - -@pytest.fixture(params=OAUTH_PROVIDERS) -def oauth_provider_config(request): - """Get OAuth provider configuration.""" - return request.param - - -@pytest.fixture -def oauth_client(oauth_provider_config): - """Create OAuth2 client for testing based on provider config.""" - provider_config = oauth_provider_config - client_class = provider_config['client_class'] - return client_class(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) - - -@pytest.fixture -def fastapi_app(): - """Create FastAPI app for testing.""" - app = FastAPI() - - @app.get('/') - async def root(): - return {'message': 'OAuth2 Test App'} - - return app - - -@pytest.fixture -def test_client(fastapi_app): - """Create TestClient for FastAPI app.""" - return TestClient(fastapi_app) - - -# Individual OAuth client fixtures for non-parametrized tests -@pytest.fixture -def github_client(): - """Create GitHub OAuth2 client instance for testing.""" - return GitHubOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) - - -@pytest.fixture -def google_client(): - """Create Google OAuth2 client instance for testing.""" - return GoogleOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) - - -@pytest.fixture -def gitee_client(): - """Create Gitee OAuth2 client instance for testing.""" - return GiteeOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) - - -@pytest.fixture -def feishu_client(): - """Create FeiShu OAuth2 client instance for testing.""" - return FeiShuOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) - - -@pytest.fixture -def linuxdo_client(): - """Create LinuxDo OAuth2 client instance for testing.""" - return LinuxDoOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) - - -@pytest.fixture -def oschina_client(): - """Create OSChina OAuth2 client instance for testing.""" - return OSChinaOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) def create_mock_user_data(provider_name: str, **overrides): @@ -195,39 +60,6 @@ def create_mock_user_data(provider_name: str, **overrides): return base_data -def mock_oauth_token_response(respx_mock, provider_config: dict, status_code: int = 200): - """Mock OAuth token endpoint response for a provider.""" - return respx_mock.post(provider_config['token_url']).mock( - return_value=httpx.Response(status_code, json=provider_config['token_response']) - ) - - -def mock_user_info_response(respx_mock, provider_config: dict, user_data: dict = None, status_code: int = 200): - """Mock user info endpoint response for a provider.""" - if user_data is None: - user_data = create_mock_user_data(provider_config['name']) - return respx_mock.get(provider_config['user_info_url']).mock( - return_value=httpx.Response(status_code, json=user_data) - ) - - -def setup_oauth_callback_route(app: FastAPI, provider_config: dict, oauth_dependency): - """Setup OAuth callback route for testing.""" - callback_path = f'/auth/{provider_config["name"]}/callback' - - @app.get(callback_path) - async def oauth_callback_handler(access_token_state=Depends(oauth_dependency)): - token, state = access_token_state - return {'provider': provider_config['name'], 'access_token': token, 'state': state} - - return callback_path - - -def assert_oauth_error_response(response, expected_error: str, expected_status: int = 400): - """Assert OAuth error response has correct format.""" - assert response.status_code == expected_status - data = response.json() - if 'detail' in data: - assert data['detail'] == expected_error - else: - assert data.get('error') == expected_error +def mock_user_info_response(respx_mock, user_info_url: str, user_data: dict, status_code: int = 200): + """Mock user info endpoint response.""" + return respx_mock.get(user_info_url).mock(return_value=httpx.Response(status_code, json=user_data)) diff --git a/tests/test_fastapi.py b/tests/test_callback.py similarity index 75% rename from tests/test_fastapi.py rename to tests/test_callback.py index 7010f42..0640b0b 100644 --- a/tests/test_fastapi.py +++ b/tests/test_callback.py @@ -8,19 +8,74 @@ from fastapi.responses import JSONResponse from fastapi.testclient import TestClient -from fastapi_oauth20 import FastAPIOAuth20, OAuth20AuthorizeCallbackError +from fastapi_oauth20 import ( + FastAPIOAuth20, + FeiShuOAuth20, + GiteeOAuth20, + GitHubOAuth20, + GoogleOAuth20, + LinuxDoOAuth20, + OAuth20AuthorizeCallbackError, + OSChinaOAuth20, +) from tests.conftest import ( - OAUTH_PROVIDERS, TEST_ACCESS_TOKEN, TEST_CLIENT_ID, TEST_CLIENT_SECRET, TEST_STATE, - assert_oauth_error_response, - mock_oauth_token_response, - setup_oauth_callback_route, ) -# Integration test specific constants +OAUTH_PROVIDERS = [ + { + 'name': 'github', + 'client_class': GitHubOAuth20, + 'token_url': 'https://github.com/login/oauth/access_token', + 'user_info_url': 'https://api.github.com/user', + 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'scope': 'user'}, + 'redirect_uri': 'http://localhost:8000/auth/github/callback', + }, + { + 'name': 'google', + 'client_class': GoogleOAuth20, + 'token_url': 'https://oauth2.googleapis.com/token', + 'user_info_url': 'https://www.googleapis.com/oauth2/v1/userinfo', + 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'Bearer', 'expires_in': 3600}, + 'redirect_uri': 'http://localhost:8000/auth/google/callback', + }, + { + 'name': 'gitee', + 'client_class': GiteeOAuth20, + 'token_url': 'https://gitee.com/oauth/token', + 'user_info_url': 'https://gitee.com/api/v5/user', + 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'scope': 'user_info'}, + 'redirect_uri': 'http://localhost:8000/auth/gitee/callback', + }, + { + 'name': 'feishu', + 'client_class': FeiShuOAuth20, + 'token_url': 'https://passport.feishu.cn/suite/passport/oauth/token', + 'user_info_url': 'https://passport.feishu.cn/suite/passport/oauth/userinfo', + 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'expires_in': 3600}, + 'redirect_uri': 'http://localhost:8000/auth/feishu/callback', + }, + { + 'name': 'linuxdo', + 'client_class': LinuxDoOAuth20, + 'token_url': 'https://connect.linux.do/oauth2/token', + 'user_info_url': 'https://connect.linux.do/api/user', + 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'expires_in': 3600}, + 'redirect_uri': 'http://localhost:8000/auth/linuxdo/callback', + }, + { + 'name': 'oschina', + 'client_class': OSChinaOAuth20, + 'token_url': 'https://www.oschina.net/action/openapi/token', + 'user_info_url': 'https://www.oschina.net/action/openapi/user', + 'token_response': {'access_token': TEST_ACCESS_TOKEN, 'token_type': 'bearer', 'expires_in': 3600}, + 'redirect_uri': 'http://localhost:8000/auth/oschina/callback', + }, +] + LOCALHOST_URL = 'http://localhost:8000' DEV_URL = 'http://dev.example.com' APP_URL = 'http://app.example.org' @@ -30,12 +85,10 @@ SECURE_APP_URL = 'https://app.example.org' SECURE_IP_URL = 'https://192.168.1.100:8443' -# Callback paths AUTH_PATH = '/auth' CALLBACK_PATH = '/callback' OAUTH_CALLBACK_PATH = '/oauth2/callback' -# Test URIs HTTP_URIS = [ f'{LOCALHOST_URL}{CALLBACK_PATH}', f'{DEV_URL}{AUTH_PATH}/callback', @@ -51,6 +104,47 @@ ] +@pytest.fixture +def fastapi_app(): + """Create FastAPI app for testing.""" + app = FastAPI() + + @app.get('/') + async def root(): + return {'message': 'OAuth2 Test App'} + + return app + + +def mock_oauth_token_response(respx_mock, provider_config: dict, status_code: int = 200): + """Mock OAuth token endpoint response for a provider.""" + return respx_mock.post(provider_config['token_url']).mock( + return_value=httpx.Response(status_code, json=provider_config['token_response']) + ) + + +def setup_oauth_callback_route(app: FastAPI, provider_config: dict, oauth_dependency): + """Setup OAuth callback route for testing.""" + callback_path = f'/auth/{provider_config["name"]}/callback' + + @app.get(callback_path) + async def oauth_callback_handler(access_token_state=Depends(oauth_dependency)): + token, state = access_token_state + return {'provider': provider_config['name'], 'access_token': token, 'state': state} + + return callback_path + + +def assert_oauth_error_response(response, expected_error: str, expected_status: int = 400): + """Assert OAuth error response has correct format.""" + assert response.status_code == expected_status + data = response.json() + if 'detail' in data: + assert data['detail'] == expected_error + else: + assert data.get('error') == expected_error + + class TestFastAPIOAuth20Basic: """Basic tests for FastAPI OAuth2 integration using parametrized providers.""" @@ -126,8 +220,9 @@ async def test_callback_token_exchange_error_parametrized(self, provider_config, assert response.status_code == 500 assert 'detail' in response.json() - def test_custom_exception_handler(self, github_client, fastapi_app): + def test_custom_exception_handler(self, fastapi_app): """Test custom exception handler for OAuth2 errors.""" + github_client = GitHubOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) github_oauth2_callback = FastAPIOAuth20(github_client, redirect_uri=f'{LOCALHOST_URL}/auth/github/callback') @fastapi_app.get('/auth/github/callback') @@ -155,9 +250,10 @@ async def oauth2_error_handler(request: Request, exc: OAuth20AuthorizeCallbackEr assert data['error'] == 'access_denied' assert data['status_code'] == 400 - def test_multiple_oauth_providers(self, github_client, google_client, fastapi_app): + def test_multiple_oauth_providers(self, fastapi_app): """Test multiple OAuth providers in the same app.""" # Setup GitHub OAuth + github_client = GitHubOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) github_oauth2_callback = FastAPIOAuth20(github_client, redirect_uri=f'{LOCALHOST_URL}/auth/github/callback') @fastapi_app.get('/auth/github/callback') @@ -166,6 +262,7 @@ async def github_callback(access_token_state=Depends(github_oauth2_callback)): return {'provider': 'github', 'access_token': token, 'state': state} # Setup Google OAuth + google_client = GoogleOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) google_oauth2_callback = FastAPIOAuth20(google_client, redirect_uri=f'{LOCALHOST_URL}/auth/google/callback') @fastapi_app.get('/auth/google/callback') @@ -297,8 +394,9 @@ async def test_token_exchange_errors(self, provider_config, http_status, fastapi class TestFastAPIOAuth20Integration: """Integration tests for FastAPI OAuth2.""" - def test_oauth_dependency_creation(self, github_client): + def test_oauth_dependency_creation(self): """Test OAuth dependency creation with different parameters.""" + github_client = GitHubOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) # Basic OAuth dependency oauth_dep = FastAPIOAuth20(github_client) assert oauth_dep.client == github_client @@ -308,8 +406,9 @@ def test_oauth_dependency_creation(self, github_client): oauth_dep_custom = FastAPIOAuth20(github_client, redirect_uri=custom_redirect) assert oauth_dep_custom.client == github_client - def test_multiple_apps_same_provider(self, github_client): + def test_multiple_apps_same_provider(self): """Test the same OAuth provider in multiple FastAPI apps.""" + github_client = GitHubOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) app1 = FastAPI() app2 = FastAPI() diff --git a/tests/test_oauth20.py b/tests/test_oauth20.py index 47e2d07..ae738dc 100644 --- a/tests/test_oauth20.py +++ b/tests/test_oauth20.py @@ -33,6 +33,7 @@ def oauth_client(): client_secret='test_client_secret', authorize_endpoint='https://example.com/oauth/authorize', access_token_endpoint='https://example.com/oauth/token', + userinfo_endpoint='https://example.com/oauth/userinfo', refresh_token_endpoint='https://example.com/oauth/refresh', revoke_token_endpoint='https://example.com/oauth/revoke', default_scopes=['read', 'write'], @@ -60,6 +61,7 @@ def test_oauth_base_initialization_minimal(): client_secret='test_secret', authorize_endpoint='https://example.com/auth', access_token_endpoint='https://example.com/token', + userinfo_endpoint='https://example.com/userinfo', ) assert client.client_id == 'test_id' @@ -78,6 +80,7 @@ def test_oauth_base_initialization_with_basic_auth(): client_secret='test_secret', authorize_endpoint='https://example.com/auth', access_token_endpoint='https://example.com/token', + userinfo_endpoint='https://example.com/userinfo', token_endpoint_basic_auth=True, revoke_token_endpoint_basic_auth=True, ) @@ -187,6 +190,7 @@ async def test_get_access_token_with_basic_auth(): client_secret='test_secret', authorize_endpoint='https://example.com/auth', access_token_endpoint='https://example.com/token', + userinfo_endpoint='https://example.com/userinfo', token_endpoint_basic_auth=True, ) @@ -237,6 +241,7 @@ async def test_refresh_token_missing_endpoint(): client_secret='test_secret', authorize_endpoint='https://example.com/auth', access_token_endpoint='https://example.com/token', + userinfo_endpoint='https://example.com/userinfo', ) with pytest.raises(RefreshTokenError, match='refresh token address is missing'): @@ -288,6 +293,7 @@ async def test_revoke_token_missing_endpoint(): client_secret='test_secret', authorize_endpoint='https://example.com/auth', access_token_endpoint='https://example.com/token', + userinfo_endpoint='https://example.com/userinfo', ) with pytest.raises(RevokeTokenError, match='revoke token address is missing'): @@ -353,22 +359,19 @@ def test_get_json_result_invalid_json(): OAuth20Base.get_json_result(mock_response, err_class=AccessTokenError) -def test_abstract_method(): - """Test that get_userinfo is properly abstract.""" - with pytest.raises(TypeError): - OAuth20Base( - client_id='test', - client_secret='test', - authorize_endpoint='https://example.com/auth', - access_token_endpoint='https://example.com/token', - ) - - -def test_oauth_base_inheritance(): - """Test that OAuth20Base is properly abstract.""" - from abc import ABC +def test_concrete_implementation(): + """Test that OAuth20Base can be instantiated directly.""" + client = OAuth20Base( + client_id='test', + client_secret='test', + authorize_endpoint='https://example.com/auth', + access_token_endpoint='https://example.com/token', + userinfo_endpoint='https://example.com/userinfo', + ) - assert issubclass(OAuth20Base, ABC) + assert client.client_id == 'test' + assert client.client_secret == 'test' + assert client.userinfo_endpoint == 'https://example.com/userinfo' @pytest.mark.asyncio @@ -379,6 +382,7 @@ async def test_get_userinfo_implementation(): client_secret='test', authorize_endpoint='https://example.com/auth', access_token_endpoint='https://example.com/token', + userinfo_endpoint='https://example.com/userinfo', ) result = await client.get_userinfo('test_token')