Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 67 additions & 17 deletions logto/LogtoClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,28 @@
"""

import time
import urllib.parse
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel
import urllib.parse

from .models.oidc import ReservedResource, Scope, UserInfoScope
from .Storage import MemoryStorage, Storage
from .LogtoException import LogtoException
from .models.oidc import (
DirectSignInOption,
FirstScreen,
Identifier,
ReservedResource,
Scope,
UserInfoScope,
)
from .OidcCore import (
AccessTokenClaims,
IdTokenClaims,
OidcCore,
TokenResponse,
UserInfoResponse,
)
from .Storage import MemoryStorage, Storage
from .utilities import OrganizationUrnPrefix, buildOrganizationUrn, removeFalsyKeys


Expand Down Expand Up @@ -162,7 +170,8 @@ def _getAccessTokenMap(self) -> AccessTokenMap:
"""
accessTokenMap = self._storage.get("accessTokenMap")
try:
return AccessTokenMap.model_validate_json(accessTokenMap)
# Returns parsed `AccessTokenMap` if valid JSON, otherwise will be caught by except clause
return AccessTokenMap.model_validate_json(accessTokenMap) # type: ignore
except:
return AccessTokenMap(x={})

Expand Down Expand Up @@ -218,6 +227,10 @@ async def _buildSignInUrl(
codeChallenge: str,
state: str,
interactionMode: Optional[InteractionMode] = None,
firstScreen: Optional[FirstScreen] = None,
identifiers: Optional[List[Identifier]] = None,
directSignIn: Optional[DirectSignInOption] = None,
extraParams: Optional[Dict[str, str]] = None,
) -> str:
appId, prompt, resources, scopes = (
self.config.appId,
Expand Down Expand Up @@ -248,6 +261,18 @@ async def _buildSignInUrl(
"code_challenge_method": "S256",
"state": state,
"interaction_mode": interactionMode,
"first_screen": firstScreen,
"identifier": (
" ".join(identifier.value for identifier in identifiers or [])
if identifiers
else None
),
"direct_sign_in": (
f"{directSignIn.method}:{directSignIn.identifier}"
if directSignIn
else None
),
**(extraParams or {}),
}
),
True,
Expand All @@ -270,27 +295,49 @@ def _getSignInSession(self) -> Optional[SignInSession]:
def _setSignInSession(self, signInSession: SignInSession) -> None:
self._storage.set("signInSession", signInSession.model_dump_json())

def _clearAllTokens(self) -> None:
self._storage.delete("idToken")
self._storage.delete("refreshToken")
self._storage.delete("accessTokenMap")

async def signIn(
self, redirectUri: str, interactionMode: Optional[InteractionMode] = None
self,
redirectUri: str,
interactionMode: Optional[InteractionMode] = None,
firstScreen: Optional[FirstScreen] = None,
identifiers: Optional[List[Identifier]] = None,
directSignIn: Optional[DirectSignInOption] = None,
extraParams: Optional[Dict[str, str]] = None,
) -> str:
"""
Returns the sign-in URL for the given redirect URI. You should redirect the user
to the returned URL to sign in.
Returns the sign-in URL for the given redirect URI.

By specifying the interaction mode, you can control whether the user will be
prompted for sign-in or sign-up on the first screen. If the interaction mode is
not specified, the default one will be used.
Args:
redirectUri: The URI to redirect after sign-in
interactionMode: Control whether to show sign-in or sign-up screen
firstScreen: Specify the first screen to show in sign-in experience
directSignIn: Configure direct sign-in options for SSO or social sign-in

Example:
```python
return redirect(await client.signIn('https://example.com/callback'))
return redirect(await client.signIn(
'https://example.com/callback',
firstScreen=FirstScreen.register
))
```
"""
codeVerifier = OidcCore.generateCodeVerifier()
codeChallenge = OidcCore.generateCodeChallenge(codeVerifier)
state = OidcCore.generateState()
signInUrl = await self._buildSignInUrl(
redirectUri, codeChallenge, state, interactionMode
redirectUri,
codeChallenge,
state,
interactionMode,
firstScreen,
identifiers,
directSignIn,
extraParams,
)

self._setSignInSession(
Expand All @@ -300,8 +347,7 @@ async def signIn(
state=state,
)
)
for key in ["idToken", "accessToken", "refreshToken"]:
self._storage.delete(key)
self._clearAllTokens()

return signInUrl

Expand All @@ -323,9 +369,7 @@ async def signOut(self, postLogoutRedirectUri: Optional[str] = None) -> str:
return redirect(await client.signOut('https://example.com'))
```
"""
self._storage.delete("idToken")
self._storage.delete("refreshToken")
self._storage.delete("accessTokenMap")
self._clearAllTokens()

endSessionEndpoint = (await self.getOidcCore()).metadata.end_session_endpoint

Expand Down Expand Up @@ -438,6 +482,8 @@ async def getAccessTokenClaims(self, resource: str = "") -> AccessTokenClaims:
access token, an exception will be thrown.
"""
accessToken = await self.getAccessToken(resource)
if accessToken is None:
raise LogtoException("Failed to get access token claims.")
return OidcCore.decodeAccessToken(accessToken)

async def getOrganizationTokenClaims(
Expand Down Expand Up @@ -486,4 +532,8 @@ async def fetchUserInfo(self) -> UserInfoResponse:
is expired, it will be refreshed automatically.
"""
accessToken = await self.getAccessToken()
if accessToken is None:
raise LogtoException(
"Can not get access token and fail to fetch user info."
)
return await (await self.getOidcCore()).fetchUserInfo(accessToken)
93 changes: 89 additions & 4 deletions logto/LogtoClient_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
from itertools import combinations
from typing import Any, Callable, Dict, Optional
from pytest_mock import MockerFixture
from urllib.parse import quote

import pytest
from pytest_mock import MockerFixture

from . import LogtoClient, LogtoConfig, LogtoException, Storage
from .utilities.test import mockHttp, mockProviderMetadata
from .models.oidc import (
AccessTokenClaims,
DirectSignInOption,
DirectSignInOptionMethod,
FirstScreen,
Identifier,
IdTokenClaims,
UserInfoScope,
)
from .models.response import TokenResponse, UserInfoResponse
from .models.oidc import IdTokenClaims, AccessTokenClaims, UserInfoScope
from .Storage import MemoryStorage, Storage
from .OidcCore import OidcCore
from .Storage import MemoryStorage, Storage
from .utilities.test import mockHttp, mockProviderMetadata

MockRequest = Callable[..., None]

Expand Down Expand Up @@ -110,6 +121,80 @@ async def test_signIn_allConfigs(self, client: LogtoClient) -> None:
== "https://logto.app/oidc/auth?client_id=replace-with-your-app-id&redirect_uri=redirectUri&response_type=code&scope=email+phone+openid+offline_access+profile&resource=https%3A%2F%2Fresource1&resource=https%3A%2F%2Fresource2&prompt=login&code_challenge=codeChallenge&code_challenge_method=S256&state=state&interaction_mode=signUp"
)

async def test_signIn_firstScreen(self, client: LogtoClient) -> None:
# Get all possible identifier combinations
all_identifiers = list(Identifier)
possible_identifier_combinations: list[tuple[Identifier]] = []
for r in range(1, len(all_identifiers) + 1):
possible_identifier_combinations.extend(combinations(all_identifiers, r))

for firstScreen in FirstScreen:
for identifiers in possible_identifier_combinations:
url = await client.signIn(
"redirectUri",
firstScreen=firstScreen,
identifiers=list(identifiers),
)

expected_identifiers = "+".join(
quote(identifier.value) for identifier in identifiers
)
assert (
url
== f"https://logto.app/oidc/auth?client_id=replace-with-your-app-id&redirect_uri=redirectUri&response_type=code&scope=openid+offline_access+profile&prompt=consent&code_challenge=codeChallenge&code_challenge_method=S256&state=state&first_screen={quote(firstScreen.value)}&identifier={expected_identifiers}"
)

async def test_signIn_directSignIn_sso(self, client: LogtoClient) -> None:
url = await client.signIn(
"redirectUri",
directSignIn=DirectSignInOption(
method=DirectSignInOptionMethod.sso, identifier="arbitrary-sso-id"
),
)

assert (
url
== "https://logto.app/oidc/auth?client_id=replace-with-your-app-id&redirect_uri=redirectUri&response_type=code&scope=openid+offline_access+profile&prompt=consent&code_challenge=codeChallenge&code_challenge_method=S256&state=state&direct_sign_in=sso%3Aarbitrary-sso-id"
)

async def test_signIn_directSignIn_social(self, client: LogtoClient) -> None:
url = await client.signIn(
"redirectUri",
directSignIn=DirectSignInOption(
method=DirectSignInOptionMethod.social, identifier="google"
),
)

assert (
url
== "https://logto.app/oidc/auth?client_id=replace-with-your-app-id&redirect_uri=redirectUri&response_type=code&scope=openid+offline_access+profile&prompt=consent&code_challenge=codeChallenge&code_challenge_method=S256&state=state&direct_sign_in=social%3Agoogle"
)

async def test_signIn_extraParams(self, client: LogtoClient) -> None:
url = await client.signIn(
"redirectUri",
extraParams={"custom_param_1": "value_1", "custom_param_2": "value_2"},
)

assert (
url
== "https://logto.app/oidc/auth?client_id=replace-with-your-app-id&redirect_uri=redirectUri&response_type=code&scope=openid+offline_access+profile&prompt=consent&code_challenge=codeChallenge&code_challenge_method=S256&state=state&custom_param_1=value_1&custom_param_2=value_2"
)

async def test_signIn_multipleParams(self, client: LogtoClient) -> None:
url = await client.signIn(
"redirectUri",
interactionMode="signUp",
firstScreen=FirstScreen.register,
identifiers=[Identifier.email],
extraParams={"custom_param_1": "value_1", "custom_param_2": "value_2"},
)

assert (
url
== "https://logto.app/oidc/auth?client_id=replace-with-your-app-id&redirect_uri=redirectUri&response_type=code&scope=openid+offline_access+profile&prompt=consent&code_challenge=codeChallenge&code_challenge_method=S256&state=state&interaction_mode=signUp&first_screen=identifier%3Aregister&identifier=email&custom_param_1=value_1&custom_param_2=value_2"
)

async def test_signOut(
self, client: LogtoClient, storage: Storage, mockRequest: MockRequest
) -> None:
Expand Down
11 changes: 9 additions & 2 deletions logto/OidcCore.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

import hashlib
import secrets
from typing import List, Optional

import aiohttp
from jwt import PyJWKClient
import jwt
from typing import List, Optional
from jwt import PyJWKClient

from .LogtoException import LogtoException
from .models.oidc import (
Expand Down Expand Up @@ -42,12 +43,14 @@ def __init__(self, metadata: OidcProviderMetadata) -> None:
metadata.jwks_uri, headers={"user-agent": "@logto/python", "accept": "*/*"}
)

@staticmethod
def generateState() -> str:
"""
Generate a random string (32 bytes) for the state parameter.
"""
return urlsafeEncode(secrets.token_bytes(32))

@staticmethod
def generateCodeVerifier() -> str:
"""
Generate a random code verifier string (32 bytes) for PKCE.
Expand All @@ -56,6 +59,7 @@ def generateCodeVerifier() -> str:
"""
return urlsafeEncode(secrets.token_bytes(32))

@staticmethod
def generateCodeChallenge(codeVerifier: str) -> str:
"""
Generate a code challenge string for the given code verifier string.
Expand All @@ -64,12 +68,14 @@ def generateCodeChallenge(codeVerifier: str) -> str:
"""
return urlsafeEncode(hashlib.sha256(codeVerifier.encode("ascii")).digest())

@staticmethod
def decodeIdToken(idToken: str) -> IdTokenClaims:
"""
Decode the ID Token and return the claims without verifying the signature.
"""
return IdTokenClaims(**jwt.decode(idToken, options={"verify_signature": False}))

@staticmethod
def decodeAccessToken(accessToken: str) -> AccessTokenClaims:
"""
Decode the access token and return the claims without verifying the signature.
Expand All @@ -78,6 +84,7 @@ def decodeAccessToken(accessToken: str) -> AccessTokenClaims:
**jwt.decode(accessToken, options={"verify_signature": False})
)

@staticmethod
async def getProviderMetadata(discoveryUrl: str) -> OidcProviderMetadata:
"""
Fetch the provider metadata from the discovery URL.
Expand Down
6 changes: 5 additions & 1 deletion logto/Storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class MemoryStorage(Storage):
See `Storage` for the interface.
"""

@staticmethod
def printWarning() -> None:
print(
"WARNING: Using MemoryStorage for Logto client, this should only be used for testing.",
Expand All @@ -67,7 +68,10 @@ def get(self, key: str) -> Optional[str]:

def set(self, key: str, value: Optional[str]) -> None:
MemoryStorage.printWarning()
self._data[key] = value
if value is not None:
self._data[key] = value
else:
self._data.pop(key, None)

def delete(self, key: str) -> None:
MemoryStorage.printWarning()
Expand Down
Loading