diff --git a/logto/LogtoClient.py b/logto/LogtoClient.py index 44caa44..1d0c37e 100644 --- a/logto/LogtoClient.py +++ b/logto/LogtoClient.py @@ -3,13 +3,20 @@ """ import time +import urllib.parse from typing import Dict, List, Literal, Optional, Union + from pydantic import BaseModel -import urllib.parse -from .models.oidc import ReservedResource, Scope, UserInfoScope -from .Storage import MemoryStorage, Storage from .LogtoException import LogtoException +from .models.oidc import ( + DirectSignInOption, + FirstScreen, + Identifier, + ReservedResource, + Scope, + UserInfoScope, +) from .OidcCore import ( AccessTokenClaims, IdTokenClaims, @@ -17,6 +24,7 @@ TokenResponse, UserInfoResponse, ) +from .Storage import MemoryStorage, Storage from .utilities import OrganizationUrnPrefix, buildOrganizationUrn, removeFalsyKeys @@ -162,7 +170,8 @@ def _getAccessTokenMap(self) -> AccessTokenMap: """ accessTokenMap = self._storage.get("accessTokenMap") try: - return AccessTokenMap.model_validate_json(accessTokenMap) + # Returns parsed `AccessTokenMap` if valid JSON, otherwise will be caught by except clause + return AccessTokenMap.model_validate_json(accessTokenMap) # type: ignore except: return AccessTokenMap(x={}) @@ -218,6 +227,10 @@ async def _buildSignInUrl( codeChallenge: str, state: str, interactionMode: Optional[InteractionMode] = None, + firstScreen: Optional[FirstScreen] = None, + identifiers: Optional[List[Identifier]] = None, + directSignIn: Optional[DirectSignInOption] = None, + extraParams: Optional[Dict[str, str]] = None, ) -> str: appId, prompt, resources, scopes = ( self.config.appId, @@ -248,6 +261,18 @@ async def _buildSignInUrl( "code_challenge_method": "S256", "state": state, "interaction_mode": interactionMode, + "first_screen": firstScreen, + "identifier": ( + " ".join(identifier.value for identifier in identifiers or []) + if identifiers + else None + ), + "direct_sign_in": ( + f"{directSignIn.method}:{directSignIn.identifier}" + if directSignIn + else None + ), + **(extraParams or {}), } ), True, @@ -270,27 +295,49 @@ def _getSignInSession(self) -> Optional[SignInSession]: def _setSignInSession(self, signInSession: SignInSession) -> None: self._storage.set("signInSession", signInSession.model_dump_json()) + def _clearAllTokens(self) -> None: + self._storage.delete("idToken") + self._storage.delete("refreshToken") + self._storage.delete("accessTokenMap") + async def signIn( - self, redirectUri: str, interactionMode: Optional[InteractionMode] = None + self, + redirectUri: str, + interactionMode: Optional[InteractionMode] = None, + firstScreen: Optional[FirstScreen] = None, + identifiers: Optional[List[Identifier]] = None, + directSignIn: Optional[DirectSignInOption] = None, + extraParams: Optional[Dict[str, str]] = None, ) -> str: """ - Returns the sign-in URL for the given redirect URI. You should redirect the user - to the returned URL to sign in. + Returns the sign-in URL for the given redirect URI. - By specifying the interaction mode, you can control whether the user will be - prompted for sign-in or sign-up on the first screen. If the interaction mode is - not specified, the default one will be used. + Args: + redirectUri: The URI to redirect after sign-in + interactionMode: Control whether to show sign-in or sign-up screen + firstScreen: Specify the first screen to show in sign-in experience + directSignIn: Configure direct sign-in options for SSO or social sign-in Example: ```python - return redirect(await client.signIn('https://example.com/callback')) + return redirect(await client.signIn( + 'https://example.com/callback', + firstScreen=FirstScreen.register + )) ``` """ codeVerifier = OidcCore.generateCodeVerifier() codeChallenge = OidcCore.generateCodeChallenge(codeVerifier) state = OidcCore.generateState() signInUrl = await self._buildSignInUrl( - redirectUri, codeChallenge, state, interactionMode + redirectUri, + codeChallenge, + state, + interactionMode, + firstScreen, + identifiers, + directSignIn, + extraParams, ) self._setSignInSession( @@ -300,8 +347,7 @@ async def signIn( state=state, ) ) - for key in ["idToken", "accessToken", "refreshToken"]: - self._storage.delete(key) + self._clearAllTokens() return signInUrl @@ -323,9 +369,7 @@ async def signOut(self, postLogoutRedirectUri: Optional[str] = None) -> str: return redirect(await client.signOut('https://example.com')) ``` """ - self._storage.delete("idToken") - self._storage.delete("refreshToken") - self._storage.delete("accessTokenMap") + self._clearAllTokens() endSessionEndpoint = (await self.getOidcCore()).metadata.end_session_endpoint @@ -438,6 +482,8 @@ async def getAccessTokenClaims(self, resource: str = "") -> AccessTokenClaims: access token, an exception will be thrown. """ accessToken = await self.getAccessToken(resource) + if accessToken is None: + raise LogtoException("Failed to get access token claims.") return OidcCore.decodeAccessToken(accessToken) async def getOrganizationTokenClaims( @@ -486,4 +532,8 @@ async def fetchUserInfo(self) -> UserInfoResponse: is expired, it will be refreshed automatically. """ accessToken = await self.getAccessToken() + if accessToken is None: + raise LogtoException( + "Can not get access token and fail to fetch user info." + ) return await (await self.getOidcCore()).fetchUserInfo(accessToken) diff --git a/logto/LogtoClient_test.py b/logto/LogtoClient_test.py index d05d460..18706fe 100644 --- a/logto/LogtoClient_test.py +++ b/logto/LogtoClient_test.py @@ -1,13 +1,24 @@ +from itertools import combinations from typing import Any, Callable, Dict, Optional -from pytest_mock import MockerFixture +from urllib.parse import quote + import pytest +from pytest_mock import MockerFixture from . import LogtoClient, LogtoConfig, LogtoException, Storage -from .utilities.test import mockHttp, mockProviderMetadata +from .models.oidc import ( + AccessTokenClaims, + DirectSignInOption, + DirectSignInOptionMethod, + FirstScreen, + Identifier, + IdTokenClaims, + UserInfoScope, +) from .models.response import TokenResponse, UserInfoResponse -from .models.oidc import IdTokenClaims, AccessTokenClaims, UserInfoScope -from .Storage import MemoryStorage, Storage from .OidcCore import OidcCore +from .Storage import MemoryStorage, Storage +from .utilities.test import mockHttp, mockProviderMetadata MockRequest = Callable[..., None] @@ -110,6 +121,80 @@ async def test_signIn_allConfigs(self, client: LogtoClient) -> None: == "https://logto.app/oidc/auth?client_id=replace-with-your-app-id&redirect_uri=redirectUri&response_type=code&scope=email+phone+openid+offline_access+profile&resource=https%3A%2F%2Fresource1&resource=https%3A%2F%2Fresource2&prompt=login&code_challenge=codeChallenge&code_challenge_method=S256&state=state&interaction_mode=signUp" ) + async def test_signIn_firstScreen(self, client: LogtoClient) -> None: + # Get all possible identifier combinations + all_identifiers = list(Identifier) + possible_identifier_combinations: list[tuple[Identifier]] = [] + for r in range(1, len(all_identifiers) + 1): + possible_identifier_combinations.extend(combinations(all_identifiers, r)) + + for firstScreen in FirstScreen: + for identifiers in possible_identifier_combinations: + url = await client.signIn( + "redirectUri", + firstScreen=firstScreen, + identifiers=list(identifiers), + ) + + expected_identifiers = "+".join( + quote(identifier.value) for identifier in identifiers + ) + assert ( + url + == f"https://logto.app/oidc/auth?client_id=replace-with-your-app-id&redirect_uri=redirectUri&response_type=code&scope=openid+offline_access+profile&prompt=consent&code_challenge=codeChallenge&code_challenge_method=S256&state=state&first_screen={quote(firstScreen.value)}&identifier={expected_identifiers}" + ) + + async def test_signIn_directSignIn_sso(self, client: LogtoClient) -> None: + url = await client.signIn( + "redirectUri", + directSignIn=DirectSignInOption( + method=DirectSignInOptionMethod.sso, identifier="arbitrary-sso-id" + ), + ) + + assert ( + url + == "https://logto.app/oidc/auth?client_id=replace-with-your-app-id&redirect_uri=redirectUri&response_type=code&scope=openid+offline_access+profile&prompt=consent&code_challenge=codeChallenge&code_challenge_method=S256&state=state&direct_sign_in=sso%3Aarbitrary-sso-id" + ) + + async def test_signIn_directSignIn_social(self, client: LogtoClient) -> None: + url = await client.signIn( + "redirectUri", + directSignIn=DirectSignInOption( + method=DirectSignInOptionMethod.social, identifier="google" + ), + ) + + assert ( + url + == "https://logto.app/oidc/auth?client_id=replace-with-your-app-id&redirect_uri=redirectUri&response_type=code&scope=openid+offline_access+profile&prompt=consent&code_challenge=codeChallenge&code_challenge_method=S256&state=state&direct_sign_in=social%3Agoogle" + ) + + async def test_signIn_extraParams(self, client: LogtoClient) -> None: + url = await client.signIn( + "redirectUri", + extraParams={"custom_param_1": "value_1", "custom_param_2": "value_2"}, + ) + + assert ( + url + == "https://logto.app/oidc/auth?client_id=replace-with-your-app-id&redirect_uri=redirectUri&response_type=code&scope=openid+offline_access+profile&prompt=consent&code_challenge=codeChallenge&code_challenge_method=S256&state=state&custom_param_1=value_1&custom_param_2=value_2" + ) + + async def test_signIn_multipleParams(self, client: LogtoClient) -> None: + url = await client.signIn( + "redirectUri", + interactionMode="signUp", + firstScreen=FirstScreen.register, + identifiers=[Identifier.email], + extraParams={"custom_param_1": "value_1", "custom_param_2": "value_2"}, + ) + + assert ( + url + == "https://logto.app/oidc/auth?client_id=replace-with-your-app-id&redirect_uri=redirectUri&response_type=code&scope=openid+offline_access+profile&prompt=consent&code_challenge=codeChallenge&code_challenge_method=S256&state=state&interaction_mode=signUp&first_screen=identifier%3Aregister&identifier=email&custom_param_1=value_1&custom_param_2=value_2" + ) + async def test_signOut( self, client: LogtoClient, storage: Storage, mockRequest: MockRequest ) -> None: diff --git a/logto/OidcCore.py b/logto/OidcCore.py index f2bc99f..87dded6 100644 --- a/logto/OidcCore.py +++ b/logto/OidcCore.py @@ -6,10 +6,11 @@ import hashlib import secrets +from typing import List, Optional + import aiohttp -from jwt import PyJWKClient import jwt -from typing import List, Optional +from jwt import PyJWKClient from .LogtoException import LogtoException from .models.oidc import ( @@ -42,12 +43,14 @@ def __init__(self, metadata: OidcProviderMetadata) -> None: metadata.jwks_uri, headers={"user-agent": "@logto/python", "accept": "*/*"} ) + @staticmethod def generateState() -> str: """ Generate a random string (32 bytes) for the state parameter. """ return urlsafeEncode(secrets.token_bytes(32)) + @staticmethod def generateCodeVerifier() -> str: """ Generate a random code verifier string (32 bytes) for PKCE. @@ -56,6 +59,7 @@ def generateCodeVerifier() -> str: """ return urlsafeEncode(secrets.token_bytes(32)) + @staticmethod def generateCodeChallenge(codeVerifier: str) -> str: """ Generate a code challenge string for the given code verifier string. @@ -64,12 +68,14 @@ def generateCodeChallenge(codeVerifier: str) -> str: """ return urlsafeEncode(hashlib.sha256(codeVerifier.encode("ascii")).digest()) + @staticmethod def decodeIdToken(idToken: str) -> IdTokenClaims: """ Decode the ID Token and return the claims without verifying the signature. """ return IdTokenClaims(**jwt.decode(idToken, options={"verify_signature": False})) + @staticmethod def decodeAccessToken(accessToken: str) -> AccessTokenClaims: """ Decode the access token and return the claims without verifying the signature. @@ -78,6 +84,7 @@ def decodeAccessToken(accessToken: str) -> AccessTokenClaims: **jwt.decode(accessToken, options={"verify_signature": False}) ) + @staticmethod async def getProviderMetadata(discoveryUrl: str) -> OidcProviderMetadata: """ Fetch the provider metadata from the discovery URL. diff --git a/logto/Storage.py b/logto/Storage.py index 585cfbb..502e9fb 100644 --- a/logto/Storage.py +++ b/logto/Storage.py @@ -52,6 +52,7 @@ class MemoryStorage(Storage): See `Storage` for the interface. """ + @staticmethod def printWarning() -> None: print( "WARNING: Using MemoryStorage for Logto client, this should only be used for testing.", @@ -67,7 +68,10 @@ def get(self, key: str) -> Optional[str]: def set(self, key: str, value: Optional[str]) -> None: MemoryStorage.printWarning() - self._data[key] = value + if value is not None: + self._data[key] = value + else: + self._data.pop(key, None) def delete(self, key: str) -> None: MemoryStorage.printWarning() diff --git a/logto/models/oidc.py b/logto/models/oidc.py index 8a7ad49..864f85d 100644 --- a/logto/models/oidc.py +++ b/logto/models/oidc.py @@ -1,7 +1,8 @@ +import warnings from enum import Enum -from typing import List, Optional, Any +from typing import Any, List, Optional, Type, TypeVar + from pydantic import BaseModel, ConfigDict -import warnings class OidcProviderMetadata(BaseModel): @@ -59,8 +60,10 @@ def __new__(cls, value: Any): return member @classmethod - def _get_deprecated_member(cls, member): - # _get_deprecated_member is a protect util method to get the deprecated member with warning. + def _get_deprecated_member(cls, member: Any) -> Any: + """ + Get the deprecated member with warning. + """ warnings.warn(f"{member.name} is deprecated.", DeprecationWarning, stacklevel=2) return member @@ -119,7 +122,7 @@ class UserInfoScope(Scope): """ @classmethod - def _missing_(cls, value): + def _missing_(cls, value: Any) -> Any: """ `_missing_` is a [built-in method](https://docs.python.org/3/library/enum.html#supported-sunder-names) to handle missing members, we overwrite it and throws a warning for deprecated members. @@ -184,6 +187,108 @@ class ReservedResource(Enum): """The resource for organization template per [RFC 0001](https://github.com/logto-io/rfcs).""" +T = TypeVar("T", bound="BaseEnum") + + +class BaseEnum(Enum): + """Base enum class with common functionality""" + + def __new__(cls, value: str): + member = object.__new__(cls) + member._value_ = value + return member + + @classmethod + def from_value(cls: Type[T], value: str) -> Optional[T]: + """Create an enum member from a string value""" + try: + return cls(value) + except ValueError: + return None + + def __str__(self) -> str: + return self.value + + +class FirstScreen(BaseEnum): + """ + The first screen for the sign-in experience. + + Refer to [Authentication parameters > First screen](https://docs.logto.io/docs/references/openid-connect/authentication-parameters/#first-screen) for more information. + """ + + register = "identifier:register" + """ + Show the register form on first screen. + """ + sign_in = "identifier:sign_in" + """ + Show the sign-in form on first screen. + """ + single_sign_on = "single_sign_on" + """ + Show the single sign-on form on first screen. + """ + reset_password = "reset_password" + """ + Show the reset password form on first screen. + """ + + +class Identifier(BaseEnum): + """ + The identifiers for the sign-in experience. MUST work with `first_screen`. + """ + + username = "username" + """ + Use `username` as identifier. + """ + email = "email" + """ + Use `email` as identifier. + """ + phone = "phone" + """ + Use `phone` as identifier. + """ + + +class DirectSignInOptionMethod(BaseEnum): + """ + The prefix for the direct sign-in methods. + """ + + sso = "sso" + """ + The prefix for the single sign-on (SSO) sign-in method. + """ + social = "social" + """ + The prefix for the social sign-in method. + """ + + +class DirectSignInOption(BaseModel): + """ + The direct sign-in options. + + Refer to [Authentication parameters > Direct sign-in](https://docs.logto.io/docs/references/openid-connect/authentication-parameters/#direct-sign-in) for more information. + """ + + method: DirectSignInOptionMethod + """ + The method for the direct sign-in. See `DirectSignInOptionMethod` for more information. + """ + identifier: str + """ + The identifier for the direct sign-in. + + `identifier` is IdP name when using social sign-in (method is `DirectSignInOptionMethod.social`). + `identifier` is enterprise SSO ID when using SSO sign-in (method is `DirectSignInOptionMethod.sso`). + """ + + class AccessTokenClaims(BaseModel): """ The access token claims object. diff --git a/pdm.lock b/pdm.lock index 1a906b6..7bb1db7 100644 --- a/pdm.lock +++ b/pdm.lock @@ -4,8 +4,11 @@ [metadata] groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] -lock_version = "4.4.1" -content_hash = "sha256:9490efac1103224b3f2d97600639dc903b3a55ed211a5943c6ad61c00b7b4580" +lock_version = "4.5.0" +content_hash = "sha256:eb328ae2a89f2933376218a8805a1acbd3020cec3b20aeeb62441bf9f38a4f67" + +[[metadata.targets]] +requires_python = "~=3.8" [[package]] name = "aiohttp" @@ -149,6 +152,9 @@ requires_python = ">=3.7" summary = "Timeout context manager for asyncio programs" groups = ["default"] marker = "python_version < \"3.11\"" +dependencies = [ + "typing-extensions>=3.6.5; python_version < \"3.8\"", +] files = [ {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, @@ -160,6 +166,9 @@ version = "23.2.0" requires_python = ">=3.7" summary = "Classes Without Boilerplate" groups = ["default"] +dependencies = [ + "importlib-metadata; python_version < \"3.8\"", +] files = [ {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, @@ -386,6 +395,7 @@ summary = "Composable command line interface toolkit" groups = ["dev"] dependencies = [ "colorama; platform_system == \"Windows\"", + "importlib-metadata; python_version < \"3.8\"", ] files = [ {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, @@ -659,6 +669,7 @@ requires_python = ">=3.7,<4.0" summary = "A parser based on lib2to3 producing docspec data from Python source code." groups = ["dev"] dependencies = [ + "black<24.0.0,>=23.1.0", "docspec<3.0.0,>=2.2.1", "nr-util>=0.7.0", ] @@ -828,6 +839,7 @@ requires_python = ">=3.8" summary = "Read metadata from Python packages" groups = ["dev"] dependencies = [ + "typing-extensions>=3.6.4; python_version < \"3.8\"", "zipp>=0.5", ] files = [ @@ -1034,6 +1046,9 @@ version = "2.1.0" requires_python = ">=3.6,<4.0" summary = "" groups = ["dev"] +dependencies = [ + "dataclasses<0.9,>=0.8; python_version == \"3.6\"", +] files = [ {file = "nr_date-2.1.0-py3-none-any.whl", hash = "sha256:bd672a9dfbdcf7c4b9289fea6750c42490eaee08036a72059dcc78cb236ed568"}, {file = "nr_date-2.1.0.tar.gz", hash = "sha256:0643aea13bcdc2a8bc56af9d5e6a89ef244c9744a1ef00cdc735902ba7f7d2e6"}, @@ -1261,6 +1276,9 @@ version = "2.8.0" requires_python = ">=3.7" summary = "JSON Web Token implementation in Python" groups = ["default"] +dependencies = [ + "typing-extensions; python_version <= \"3.7\"", +] files = [ {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"}, {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"}, @@ -1291,6 +1309,7 @@ groups = ["dev"] dependencies = [ "colorama; sys_platform == \"win32\"", "exceptiongroup>=1.0.0rc8; python_version < \"3.11\"", + "importlib-metadata>=0.12; python_version < \"3.8\"", "iniconfig", "packaging", "pluggy<2.0,>=0.12", @@ -1628,6 +1647,7 @@ groups = ["default"] dependencies = [ "idna>=2.0", "multidict>=4.0", + "typing-extensions>=3.7.4; python_version < \"3.8\"", ] files = [ {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, diff --git a/samples/flask.py b/samples/flask.py index a373848..2800ee9 100644 --- a/samples/flask.py +++ b/samples/flask.py @@ -1,13 +1,15 @@ -from flask import Flask, g, redirect, request, jsonify -from logto import LogtoException from dotenv import load_dotenv +from flask import Flask, g, jsonify, redirect, request + +from logto import LogtoException +from logto.models.oidc import FirstScreen, Identifier from samples.authenticated import authenticated +from samples.client import client from samples.config import ( APP_SECRET_KEY, LOGTO_POST_LOGOUT_REDIRECT_URI, LOGTO_REDIRECT_URI, ) -from samples.client import client load_dotenv() app = Flask(__name__) @@ -39,12 +41,22 @@ async def index(): @app.route("/sign-in") async def sign_in(): - return redirect( - await client.signIn( - redirectUri=LOGTO_REDIRECT_URI, - interactionMode="signUp", # Remove to show the sign-in as the first screen - ) + signInUrl = await client.signIn( + redirectUri=LOGTO_REDIRECT_URI, + interactionMode="signIn", + # Show sign in form on first screen + firstScreen=FirstScreen.sign_in, + # Show username/email on sign in form, MUST be used with `firstScreen` parameter + identifiers=[Identifier.email, Identifier.username], + # Go directly to `github` social sign-in + # E.g.: + # directSignIn=DirectSignInOption( + # method=DirectSignInOptionMethod.social.value, + # identifier="github", + # ), + directSignIn=None, ) + return redirect(signInUrl) @app.route("/sign-out")