Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions libraries/python/openai-client/tests/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import openai_client
import pytest
from openai import OpenAI
from openai import AuthenticationError, OpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam


Expand All @@ -12,7 +12,19 @@ def client() -> OpenAI:
if not api_key:
pytest.skip("OPENAI_API_KEY is not set.")

return OpenAI(api_key=api_key)
client = OpenAI(api_key=api_key)

# Test if the API key is valid by making a minimal request
try:
client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "test"}],
max_tokens=1,
)
except AuthenticationError as e:
pytest.skip(f"OPENAI_API_KEY is invalid or deactivated: {e}")

return client


@pytest.mark.parametrize(
Expand Down
6 changes: 3 additions & 3 deletions workbench-app/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
# NOTE: If you set the environment variables in the host environment, those values will
# take precedence over the values in this file.
#
# The following is the client id from the Semantic Workbench GitHub sample app registration
# The following is the client id from the Semantic Workbench Consumer app registration
# This is used to authenticate the user with the Semantic Workbench API and must match the
# client id used by the Semantic Workbench service. The default value allows you to run the
# sample app without registering your own app while running this app locally (https://127.0.0.1:300)
# If you register your own app, you must update this value to match the client id of your app.
#
# See /docs/CUSTOM_APP_REGISTRATION.md for more information.
#
VITE_SEMANTIC_WORKBENCH_CLIENT_ID=22cb77c3-ca98-4a26-b4db-ac4dcecba690
VITE_SEMANTIC_WORKBENCH_CLIENT_ID=d0a2fed8-abb0-4831-8a24-09f5a0b54d97

# The authority to use for authentication requests.
# The authority value depends on the type of account you want to authenticate and should
Expand All @@ -23,7 +23,7 @@ VITE_SEMANTIC_WORKBENCH_CLIENT_ID=22cb77c3-ca98-4a26-b4db-ac4dcecba690
# - Work + School accounts: 'https://login.microsoftonline.com/organizations',
# - Work + School + Personal: 'https://login.microsoftonline.com/common'
#
VITE_SEMANTIC_WORKBENCH_AUTHORITY=https://login.microsoftonline.com/common
VITE_SEMANTIC_WORKBENCH_AUTHORITY=https://login.microsoftonline.com/consumer

# The URL for the Semantic Workbench service.
VITE_SEMANTIC_WORKBENCH_SERVICE_URL=http://localhost:3000
7 changes: 4 additions & 3 deletions workbench-app/src/Constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,19 @@ export const Constants = {
msal: {
method: 'redirect', // 'redirect' | 'popup'
auth: {
// Semantic Workbench GitHub sample app registration
// Semantic Workbench app registration
// The same value is set also in AuthSettings in
// "semantic_workbench_service.config.py" in the backend
// Can be overridden by env var VITE_SEMANTIC_WORKBENCH_CLIENT_ID
clientId: import.meta.env.VITE_SEMANTIC_WORKBENCH_CLIENT_ID || '22cb77c3-ca98-4a26-b4db-ac4dcecba690',
clientId: import.meta.env.VITE_SEMANTIC_WORKBENCH_CLIENT_ID || 'd0a2fed8-abb0-4831-8a24-09f5a0b54d97',

// Specific tenant only: 'https://login.microsoftonline.com/<tenant>/',
// Personal accounts only: 'https://login.microsoftonline.com/consumers',
// Work + School accounts: 'https://login.microsoftonline.com/organizations',
// Work + School + Personal: 'https://login.microsoftonline.com/common'
// Can be overridden by env var VITE_SEMANTIC_WORKBENCH_AUTHORITY
authority: import.meta.env.VITE_SEMANTIC_WORKBENCH_AUTHORITY || 'https://login.microsoftonline.com/common',
authority:
import.meta.env.VITE_SEMANTIC_WORKBENCH_AUTHORITY || 'https://login.microsoftonline.com/consumers',
},
cache: {
cacheLocation: 'localStorage',
Expand Down
31 changes: 1 addition & 30 deletions workbench-app/src/libs/useWorkbenchEventSource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ const useWorkbenchEventSource = (manager: EventSubscriptionManager, endpoint?: s
const startEventSource = async () => {
if (!isMounted) return;

const accessToken = await getAccessToken();
const idToken = await getIdTokenAsync();

// this promise is intentionally not awaited. it runs in the background and is cancelled when
Expand All @@ -34,8 +33,7 @@ const useWorkbenchEventSource = (manager: EventSubscriptionManager, endpoint?: s
signal: abortController.signal,
openWhenHidden: true,
headers: {
Authorization: `Bearer ${accessToken}`,
'X-OpenIdToken': idToken,
Authorization: `Bearer ${idToken}`,
},
async onopen(response) {
if (!isMounted) return;
Expand Down Expand Up @@ -85,33 +83,6 @@ const useWorkbenchEventSource = (manager: EventSubscriptionManager, endpoint?: s
}, [endpoint, manager]);
};

const getAccessToken = async (forceRefresh?: boolean) => {
const msalInstance = await getMsalInstance();

const account = msalInstance.getActiveAccount();
if (!account) {
throw new Error('No active account');
}

const response = await msalInstance
.acquireTokenSilent({
...AuthHelper.loginRequest,
account,
forceRefresh,
})
.catch(async (error) => {
if (error instanceof InteractionRequiredAuthError) {
return await AuthHelper.loginAsync(msalInstance);
}
throw error;
});
if (!response) {
throw new Error('Could not acquire access token');
}

return response.accessToken;
};

const getIdTokenAsync = async (forceRefresh?: boolean) => {
const msalInstance = await getMsalInstance();

Expand Down
35 changes: 4 additions & 31 deletions workbench-app/src/libs/useWorkbenchService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,6 @@ export const useWorkbenchService = () => {
const account = useAccount();
const msal = useMsal();

const getAccessTokenAsync = React.useCallback(async () => {
if (!account) {
throw new Error('No active account');
}

const response = await msal.instance
.acquireTokenSilent({
...AuthHelper.loginRequest,
account,
})
.catch(async (error) => {
if (error instanceof InteractionRequiredAuthError) {
return await AuthHelper.loginAsync(msal.instance);
}
throw error;
});
if (!response) {
dispatch(addError({ title: 'Failed to acquire token', message: 'Could not acquire access token' }));
throw new Error('Could not acquire access token');
}
return response.accessToken;
}, [account, dispatch, msal.instance]);

const getIdTokenAsync = React.useCallback(async () => {
if (!account) {
throw new Error('No active account');
Expand All @@ -69,14 +46,12 @@ export const useWorkbenchService = () => {

const tryFetchAsync = React.useCallback(
async (operationTitle: string, url: string, options?: RequestInit): Promise<Response> => {
const accessToken = await getAccessTokenAsync();
const idToken = await getIdTokenAsync();
const response = await fetch(url, {
...options,
headers: {
...options?.headers,
Authorization: `Bearer ${accessToken}`,
'X-OpenIdToken': idToken,
Authorization: `Bearer ${idToken}`,
},
});

Expand All @@ -89,19 +64,17 @@ export const useWorkbenchService = () => {

return response;
},
[dispatch, getAccessTokenAsync, getIdTokenAsync],
[dispatch, getIdTokenAsync],
);

const tryFetchStreamAsync = React.useCallback(
async (operationTitle: string, url: string, options?: RequestInit): Promise<Response> => {
const accessToken = await getAccessTokenAsync();
const idToken = await getIdTokenAsync();
const response = await fetch(url, {
...options,
headers: {
...options?.headers,
Authorization: `Bearer ${accessToken}`,
'X-OpenIdToken': idToken,
Authorization: `Bearer ${idToken}`,
},
});

Expand All @@ -114,7 +87,7 @@ export const useWorkbenchService = () => {

return response;
},
[dispatch, getAccessTokenAsync, getIdTokenAsync],
[dispatch, getIdTokenAsync],
);

const tryFetchFileAsync = React.useCallback(
Expand Down
3 changes: 2 additions & 1 deletion workbench-app/src/services/workbench/workbench.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryE
throw new Error('Could not acquire token');
}

headers.set('Authorization', `Bearer ${response.accessToken}`);
// Use idToken (always JWT format) instead of accessToken (may be compact format for MSA)
headers.set('Authorization', `Bearer ${response.idToken}`);
headers.set('X-Request-ID', generateUuid().replace(/-/g, '').toLowerCase());
return headers;
};
Expand Down
2 changes: 1 addition & 1 deletion workbench-service/semantic_workbench_service/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def is_secured(self) -> bool:

class AuthSettings(BaseSettings):
allowed_jwt_algorithms: set[str] = {"RS256"}
allowed_app_id: str = "22cb77c3-ca98-4a26-b4db-ac4dcecba690"
allowed_app_id: str = "d0a2fed8-abb0-4831-8a24-09f5a0b54d97"


class AssistantIdentifiers(BaseSettings):
Expand Down
27 changes: 22 additions & 5 deletions workbench-service/semantic_workbench_service/middleware.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
import secrets
import time
from collections.abc import Awaitable, Callable
from functools import lru_cache
from typing import Any, Awaitable, Callable
from typing import Any

import httpx
from fastapi import HTTPException, Request, Response, status
Expand Down Expand Up @@ -85,11 +86,22 @@ async def _user_principal_from_request(request: Request) -> auth.UserPrincipal |
key=keys,
options={"verify_signature": False, "verify_aud": False},
)
app_id: str = decoded.get("appid", "")
# ID tokens have 'aud', access tokens have 'appid'
app_id: str = decoded.get("appid", "") or decoded.get("aud", "")
azp: str = decoded.get("azp", "")
tid: str = decoded.get("tid", "")
oid: str = decoded.get("oid", "")
sub: str = decoded.get("sub", "")
name: str = decoded.get("name", "")
user_id = f"{tid}.{oid}"

# For Entra ID tokens: use tid.oid
# For MSA tokens: use sub (since tid/oid are not present)
if tid and oid:
user_id = f"{tid}.{oid}"
elif sub:
user_id = sub
else:
raise ValueError("Token missing required user identification claims")

except ExpiredSignatureError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Expired token")
Expand All @@ -101,8 +113,13 @@ async def _user_principal_from_request(request: Request) -> auth.UserPrincipal |
if algorithm not in allowed_jwt_algorithms:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token algorithm")

if app_id != settings.auth.allowed_app_id:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid app")
# Verify token is for our application
# - ID tokens: 'aud' claim = our app ID
# - Access tokens: 'appid' claim = our app ID, or 'azp' = our app ID
if app_id != settings.auth.allowed_app_id and azp != settings.auth.allowed_app_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid app. App ID must match in client and server."
)

return auth.UserPrincipal(user_id=user_id, name=name)

Expand Down
8 changes: 7 additions & 1 deletion workbench-service/tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ async def test_auth_middleware_rejects_disallowed_algo(monkeypatch: pytest.Monke
monkeypatch.setattr(settings.auth, "allowed_jwt_algorithms", {"RS256"})

tid = str(uuid.uuid4())
oid = str(uuid.uuid4())
token = jwt.encode(
claims={
"tid": tid,
"oid": oid,
},
key="",
algorithm="HS256",
Expand All @@ -44,9 +46,13 @@ def test_auth_middleware_rejects_disallowed_app_id(monkeypatch: pytest.MonkeyPat
monkeypatch.setattr(settings.auth, "allowed_app_id", "fake-app-id")
monkeypatch.setattr(settings.auth, "allowed_jwt_algorithms", {algo})

tid = str(uuid.uuid4())
oid = str(uuid.uuid4())
token = jwt.encode(
claims={
"appid": "not allowed",
"tid": tid,
"oid": oid,
},
key="",
algorithm=algo,
Expand All @@ -59,7 +65,7 @@ def test_auth_middleware_rejects_disallowed_app_id(monkeypatch: pytest.MonkeyPat
http_response = client.get("/", headers={"Authorization": f"Bearer {token}"})

assert http_response.status_code == 401
assert http_response.json()["detail"].lower() == "invalid app"
assert http_response.json()["detail"].lower() == "invalid app. app id must match in client and server."


def test_auth_middleware_rejects_missing_authorization_header():
Expand Down