diff --git a/libraries/python/openai-client/tests/test_tokens.py b/libraries/python/openai-client/tests/test_tokens.py index dc47ad723..814179ab8 100644 --- a/libraries/python/openai-client/tests/test_tokens.py +++ b/libraries/python/openai-client/tests/test_tokens.py @@ -2,7 +2,7 @@ import openai_client import pytest -from openai import OpenAI +from openai import AuthenticationError, OpenAI from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam @@ -12,7 +12,19 @@ def client() -> OpenAI: if not api_key: pytest.skip("OPENAI_API_KEY is not set.") - return OpenAI(api_key=api_key) + client = OpenAI(api_key=api_key) + + # Test if the API key is valid by making a minimal request + try: + client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "test"}], + max_tokens=1, + ) + except AuthenticationError as e: + pytest.skip(f"OPENAI_API_KEY is invalid or deactivated: {e}") + + return client @pytest.mark.parametrize( diff --git a/workbench-app/.env.example b/workbench-app/.env.example index 7b1f58aaf..9e724df0f 100644 --- a/workbench-app/.env.example +++ b/workbench-app/.env.example @@ -4,7 +4,7 @@ # NOTE: If you set the environment variables in the host environment, those values will # take precedence over the values in this file. # -# The following is the client id from the Semantic Workbench GitHub sample app registration +# The following is the client id from the Semantic Workbench Consumer app registration # This is used to authenticate the user with the Semantic Workbench API and must match the # client id used by the Semantic Workbench service. The default value allows you to run the # sample app without registering your own app while running this app locally (https://127.0.0.1:300) @@ -12,7 +12,7 @@ # # See /docs/CUSTOM_APP_REGISTRATION.md for more information. # -VITE_SEMANTIC_WORKBENCH_CLIENT_ID=22cb77c3-ca98-4a26-b4db-ac4dcecba690 +VITE_SEMANTIC_WORKBENCH_CLIENT_ID=d0a2fed8-abb0-4831-8a24-09f5a0b54d97 # The authority to use for authentication requests. # The authority value depends on the type of account you want to authenticate and should @@ -23,7 +23,7 @@ VITE_SEMANTIC_WORKBENCH_CLIENT_ID=22cb77c3-ca98-4a26-b4db-ac4dcecba690 # - Work + School accounts: 'https://login.microsoftonline.com/organizations', # - Work + School + Personal: 'https://login.microsoftonline.com/common' # -VITE_SEMANTIC_WORKBENCH_AUTHORITY=https://login.microsoftonline.com/common +VITE_SEMANTIC_WORKBENCH_AUTHORITY=https://login.microsoftonline.com/consumer # The URL for the Semantic Workbench service. VITE_SEMANTIC_WORKBENCH_SERVICE_URL=http://localhost:3000 diff --git a/workbench-app/src/Constants.ts b/workbench-app/src/Constants.ts index da65980d3..385becef1 100644 --- a/workbench-app/src/Constants.ts +++ b/workbench-app/src/Constants.ts @@ -62,18 +62,19 @@ export const Constants = { msal: { method: 'redirect', // 'redirect' | 'popup' auth: { - // Semantic Workbench GitHub sample app registration + // Semantic Workbench app registration // The same value is set also in AuthSettings in // "semantic_workbench_service.config.py" in the backend // Can be overridden by env var VITE_SEMANTIC_WORKBENCH_CLIENT_ID - clientId: import.meta.env.VITE_SEMANTIC_WORKBENCH_CLIENT_ID || '22cb77c3-ca98-4a26-b4db-ac4dcecba690', + clientId: import.meta.env.VITE_SEMANTIC_WORKBENCH_CLIENT_ID || 'd0a2fed8-abb0-4831-8a24-09f5a0b54d97', // Specific tenant only: 'https://login.microsoftonline.com//', // Personal accounts only: 'https://login.microsoftonline.com/consumers', // Work + School accounts: 'https://login.microsoftonline.com/organizations', // Work + School + Personal: 'https://login.microsoftonline.com/common' // Can be overridden by env var VITE_SEMANTIC_WORKBENCH_AUTHORITY - authority: import.meta.env.VITE_SEMANTIC_WORKBENCH_AUTHORITY || 'https://login.microsoftonline.com/common', + authority: + import.meta.env.VITE_SEMANTIC_WORKBENCH_AUTHORITY || 'https://login.microsoftonline.com/consumers', }, cache: { cacheLocation: 'localStorage', diff --git a/workbench-app/src/libs/useWorkbenchEventSource.ts b/workbench-app/src/libs/useWorkbenchEventSource.ts index a0fc21d3c..adea69021 100644 --- a/workbench-app/src/libs/useWorkbenchEventSource.ts +++ b/workbench-app/src/libs/useWorkbenchEventSource.ts @@ -25,7 +25,6 @@ const useWorkbenchEventSource = (manager: EventSubscriptionManager, endpoint?: s const startEventSource = async () => { if (!isMounted) return; - const accessToken = await getAccessToken(); const idToken = await getIdTokenAsync(); // this promise is intentionally not awaited. it runs in the background and is cancelled when @@ -34,8 +33,7 @@ const useWorkbenchEventSource = (manager: EventSubscriptionManager, endpoint?: s signal: abortController.signal, openWhenHidden: true, headers: { - Authorization: `Bearer ${accessToken}`, - 'X-OpenIdToken': idToken, + Authorization: `Bearer ${idToken}`, }, async onopen(response) { if (!isMounted) return; @@ -85,33 +83,6 @@ const useWorkbenchEventSource = (manager: EventSubscriptionManager, endpoint?: s }, [endpoint, manager]); }; -const getAccessToken = async (forceRefresh?: boolean) => { - const msalInstance = await getMsalInstance(); - - const account = msalInstance.getActiveAccount(); - if (!account) { - throw new Error('No active account'); - } - - const response = await msalInstance - .acquireTokenSilent({ - ...AuthHelper.loginRequest, - account, - forceRefresh, - }) - .catch(async (error) => { - if (error instanceof InteractionRequiredAuthError) { - return await AuthHelper.loginAsync(msalInstance); - } - throw error; - }); - if (!response) { - throw new Error('Could not acquire access token'); - } - - return response.accessToken; -}; - const getIdTokenAsync = async (forceRefresh?: boolean) => { const msalInstance = await getMsalInstance(); diff --git a/workbench-app/src/libs/useWorkbenchService.ts b/workbench-app/src/libs/useWorkbenchService.ts index 7b23ec873..c131fb4f4 100644 --- a/workbench-app/src/libs/useWorkbenchService.ts +++ b/workbench-app/src/libs/useWorkbenchService.ts @@ -21,29 +21,6 @@ export const useWorkbenchService = () => { const account = useAccount(); const msal = useMsal(); - const getAccessTokenAsync = React.useCallback(async () => { - if (!account) { - throw new Error('No active account'); - } - - const response = await msal.instance - .acquireTokenSilent({ - ...AuthHelper.loginRequest, - account, - }) - .catch(async (error) => { - if (error instanceof InteractionRequiredAuthError) { - return await AuthHelper.loginAsync(msal.instance); - } - throw error; - }); - if (!response) { - dispatch(addError({ title: 'Failed to acquire token', message: 'Could not acquire access token' })); - throw new Error('Could not acquire access token'); - } - return response.accessToken; - }, [account, dispatch, msal.instance]); - const getIdTokenAsync = React.useCallback(async () => { if (!account) { throw new Error('No active account'); @@ -69,14 +46,12 @@ export const useWorkbenchService = () => { const tryFetchAsync = React.useCallback( async (operationTitle: string, url: string, options?: RequestInit): Promise => { - const accessToken = await getAccessTokenAsync(); const idToken = await getIdTokenAsync(); const response = await fetch(url, { ...options, headers: { ...options?.headers, - Authorization: `Bearer ${accessToken}`, - 'X-OpenIdToken': idToken, + Authorization: `Bearer ${idToken}`, }, }); @@ -89,19 +64,17 @@ export const useWorkbenchService = () => { return response; }, - [dispatch, getAccessTokenAsync, getIdTokenAsync], + [dispatch, getIdTokenAsync], ); const tryFetchStreamAsync = React.useCallback( async (operationTitle: string, url: string, options?: RequestInit): Promise => { - const accessToken = await getAccessTokenAsync(); const idToken = await getIdTokenAsync(); const response = await fetch(url, { ...options, headers: { ...options?.headers, - Authorization: `Bearer ${accessToken}`, - 'X-OpenIdToken': idToken, + Authorization: `Bearer ${idToken}`, }, }); @@ -114,7 +87,7 @@ export const useWorkbenchService = () => { return response; }, - [dispatch, getAccessTokenAsync, getIdTokenAsync], + [dispatch, getIdTokenAsync], ); const tryFetchFileAsync = React.useCallback( diff --git a/workbench-app/src/services/workbench/workbench.ts b/workbench-app/src/services/workbench/workbench.ts index 7b30dadda..36b67eb7d 100644 --- a/workbench-app/src/services/workbench/workbench.ts +++ b/workbench-app/src/services/workbench/workbench.ts @@ -50,7 +50,8 @@ const dynamicBaseQuery: BaseQueryFn bool: class AuthSettings(BaseSettings): allowed_jwt_algorithms: set[str] = {"RS256"} - allowed_app_id: str = "22cb77c3-ca98-4a26-b4db-ac4dcecba690" + allowed_app_id: str = "d0a2fed8-abb0-4831-8a24-09f5a0b54d97" class AssistantIdentifiers(BaseSettings): diff --git a/workbench-service/semantic_workbench_service/middleware.py b/workbench-service/semantic_workbench_service/middleware.py index 0bae49508..f8fd3249c 100644 --- a/workbench-service/semantic_workbench_service/middleware.py +++ b/workbench-service/semantic_workbench_service/middleware.py @@ -1,8 +1,9 @@ import logging import secrets import time +from collections.abc import Awaitable, Callable from functools import lru_cache -from typing import Any, Awaitable, Callable +from typing import Any import httpx from fastapi import HTTPException, Request, Response, status @@ -85,11 +86,22 @@ async def _user_principal_from_request(request: Request) -> auth.UserPrincipal | key=keys, options={"verify_signature": False, "verify_aud": False}, ) - app_id: str = decoded.get("appid", "") + # ID tokens have 'aud', access tokens have 'appid' + app_id: str = decoded.get("appid", "") or decoded.get("aud", "") + azp: str = decoded.get("azp", "") tid: str = decoded.get("tid", "") oid: str = decoded.get("oid", "") + sub: str = decoded.get("sub", "") name: str = decoded.get("name", "") - user_id = f"{tid}.{oid}" + + # For Entra ID tokens: use tid.oid + # For MSA tokens: use sub (since tid/oid are not present) + if tid and oid: + user_id = f"{tid}.{oid}" + elif sub: + user_id = sub + else: + raise ValueError("Token missing required user identification claims") except ExpiredSignatureError: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Expired token") @@ -101,8 +113,13 @@ async def _user_principal_from_request(request: Request) -> auth.UserPrincipal | if algorithm not in allowed_jwt_algorithms: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token algorithm") - if app_id != settings.auth.allowed_app_id: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid app") + # Verify token is for our application + # - ID tokens: 'aud' claim = our app ID + # - Access tokens: 'appid' claim = our app ID, or 'azp' = our app ID + if app_id != settings.auth.allowed_app_id and azp != settings.auth.allowed_app_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid app. App ID must match in client and server." + ) return auth.UserPrincipal(user_id=user_id, name=name) diff --git a/workbench-service/tests/test_middleware.py b/workbench-service/tests/test_middleware.py index 7ce20164c..366aa516a 100644 --- a/workbench-service/tests/test_middleware.py +++ b/workbench-service/tests/test_middleware.py @@ -20,9 +20,11 @@ async def test_auth_middleware_rejects_disallowed_algo(monkeypatch: pytest.Monke monkeypatch.setattr(settings.auth, "allowed_jwt_algorithms", {"RS256"}) tid = str(uuid.uuid4()) + oid = str(uuid.uuid4()) token = jwt.encode( claims={ "tid": tid, + "oid": oid, }, key="", algorithm="HS256", @@ -44,9 +46,13 @@ def test_auth_middleware_rejects_disallowed_app_id(monkeypatch: pytest.MonkeyPat monkeypatch.setattr(settings.auth, "allowed_app_id", "fake-app-id") monkeypatch.setattr(settings.auth, "allowed_jwt_algorithms", {algo}) + tid = str(uuid.uuid4()) + oid = str(uuid.uuid4()) token = jwt.encode( claims={ "appid": "not allowed", + "tid": tid, + "oid": oid, }, key="", algorithm=algo, @@ -59,7 +65,7 @@ def test_auth_middleware_rejects_disallowed_app_id(monkeypatch: pytest.MonkeyPat http_response = client.get("/", headers={"Authorization": f"Bearer {token}"}) assert http_response.status_code == 401 - assert http_response.json()["detail"].lower() == "invalid app" + assert http_response.json()["detail"].lower() == "invalid app. app id must match in client and server." def test_auth_middleware_rejects_missing_authorization_header():