Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class _FlowErrorTag(Enum):
NONE = "none"
MAGIC_FORMAT = "magic_format"
MAGIC_CODE_INCORRECT = "magic_code_incorrect"
PRECONDITION_FAILED = "precondition_failed"
OTHER = "other"


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,16 +247,32 @@ async def _continue_from_invoke_verify_state(

async def _continue_from_invoke_token_exchange(
self, activity: Activity
) -> TokenResponse:
) -> tuple[TokenResponse, _FlowErrorTag]:
"""Handles the continuation of the flow from an invoke activity for token exchange."""
token_exchange_request = activity.value
token_response = await self._user_token_client.user_token.exchange_token(
user_id=self._user_id,
connection_name=self._abs_oauth_connection_name,
channel_id=self._channel_id,
body=token_exchange_request,
)
return token_response
try:
token_response = await self._user_token_client.user_token.exchange_token(
user_id=self._user_id,
connection_name=self._abs_oauth_connection_name,
channel_id=self._channel_id,
body=token_exchange_request,
)
return token_response, _FlowErrorTag.NONE
except Exception as e:
# A 400 with 'ConsentRequired' means the user hasn't consented yet.
# Return None so the caller can send a 412 back to Teams, which will
# prompt the user for consent and retry the token exchange.
# Any other error is a critical failure and should propagate.
if getattr(e, "status", None) == 400 and "Consent Required" in getattr(
e, "message", ""
):
logger.info(
"Token exchange requires consent for user %s, returning None to trigger consent prompt",
self._user_id,
)

return None, _FlowErrorTag.PRECONDITION_FAILED
raise

async def continue_flow(self, activity: Activity) -> _FlowResponse:
"""Continues the OAuth flow based on the incoming activity.
Expand Down Expand Up @@ -289,7 +305,15 @@ async def continue_flow(self, activity: Activity) -> _FlowResponse:
activity.type == ActivityTypes.invoke
and activity.name == "signin/tokenExchange"
):
token_response = await self._continue_from_invoke_token_exchange(activity)
token_response, flow_error_tag = (
await self._continue_from_invoke_token_exchange(activity)
)
elif (
activity.type == ActivityTypes.invoke and activity.name == "signin/failure"
):
logger.debug("Handling signin/failure invoke activity")
token_response = None
flow_error_tag = _FlowErrorTag.OTHER
else:
raise ValueError(f"Unknown activity type {activity.type}")

Expand All @@ -299,7 +323,8 @@ async def continue_flow(self, activity: Activity) -> _FlowResponse:
if flow_error_tag != _FlowErrorTag.NONE:
logger.debug("Flow error occurred: %s", flow_error_tag)
self._flow_state.tag = _FlowStateTag.CONTINUE
self._use_attempt()
if flow_error_tag != _FlowErrorTag.PRECONDITION_FAILED:
self._use_attempt()
else:
self._flow_state.tag = _FlowStateTag.COMPLETE
self._flow_state.expiration = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,23 @@
from typing import Optional

from microsoft_agents.activity import (
Activity,
ActivityTypes,
Attachment,
ActionTypes,
CardAction,
OAuthCard,
TokenResponse,
)

from microsoft_agents.activity.invoke_response import InvokeResponse
from microsoft_agents.activity.token_exchange_invoke_request import (
TokenExchangeInvokeRequest,
)
from microsoft_agents.activity.token_exchange_invoke_response import (
TokenExchangeInvokeResponse,
)
from microsoft_agents.hosting.core._oauth._flow_state import _FlowErrorTag
from microsoft_agents.hosting.core.card_factory import CardFactory
from microsoft_agents.hosting.core.message_factory import MessageFactory
from microsoft_agents.hosting.core.connector.client import UserTokenClient
Expand Down Expand Up @@ -201,6 +211,26 @@ async def _handle_flow_response(
else:
logger.warning("Sign-in flow failed for unknown reasons.")
await context.send_activity("Sign-in failed. Please try again.")
elif (
flow_state.tag == _FlowStateTag.CONTINUE
and flow_response.flow_error_tag == _FlowErrorTag.PRECONDITION_FAILED
):
token_exchange_request = TokenExchangeInvokeRequest().model_validate(
context.activity.value
)
await context.send_activity(
Activity(
type=ActivityTypes.invoke_response,
value=InvokeResponse(
status=412,
body=TokenExchangeInvokeResponse(
id=token_exchange_request.id,
connection_name=flow_state.connection,
failure_detail="The Agent is unable to exchange token. Proceed with regular login.",
),
).model_dump(exclude_unset=True),
)
)

async def _sign_in(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from typing import TypeVar, Optional, Callable, Awaitable, Generic, cast
import jwt

from microsoft_agents.activity import Activity, TokenResponse
from microsoft_agents.activity import Activity, Channels, SignInConstants, TokenResponse
from microsoft_agents.activity.activity_types import ActivityTypes

from ...turn_context import TurnContext
from ...storage import Storage
Expand Down Expand Up @@ -261,9 +262,17 @@ async def _start_or_continue_sign_in(
await self._delete_sign_in_state(context)

elif sign_in_response.tag in [_FlowStateTag.BEGIN, _FlowStateTag.CONTINUE]:
# store continuation activity and wait for next turn
sign_in_state.continuation_activity = context.activity
await self._save_sign_in_state(context, sign_in_state)
# Handling special case for Teams SSO, ConsentRequired
if not (
context.activity.channel_id.channel == Channels.ms_teams
and sign_in_state.continuation_activity
and context.activity.type == ActivityTypes.invoke
and context.activity.name
== SignInConstants.token_exchange_operation_name
):
# store continuation activity and wait for next turn
sign_in_state.continuation_activity = context.activity
await self._save_sign_in_state(context, sign_in_state)

return sign_in_response

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def _process_turn_results(self, context: TurnContext) -> Optional[InvokeResponse
self.INVOKE_RESPONSE_KEY
)
if not activity_invoke_response:
return InvokeResponse(status=HTTPStatus.NOT_IMPLEMENTED)
return InvokeResponse(status=HTTPStatus.OK)

return InvokeResponse.model_validate(activity_invoke_response.value)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import logging
from typing import Optional
from aiohttp import ClientSession
from aiohttp import ClientResponseError, ClientSession

from microsoft_agents.hosting.core.connector import UserTokenClientBase
from microsoft_agents.activity import (
Expand Down Expand Up @@ -297,8 +297,19 @@ async def exchange_token(
span.share(http_method="POST", status_code=response.status)

if response.status >= 300:
logger.error("Error exchanging token: %s", response.status)
response.raise_for_status()
response_text = await response.text("utf-8")
logger.error(
"Error exchanging token: %s %s",
response.status,
response_text,
)
raise ClientResponseError(
response.request_info,
response.history,
status=response.status,
message=response_text,
headers=response.headers,
)

data = await response.json()
return TokenResponse.model_validate(data)
Expand Down
127 changes: 126 additions & 1 deletion tests/hosting_core/app/_oauth/test_authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@

from typing import Optional

from microsoft_agents.activity import Activity, ActivityTypes, TokenResponse
from microsoft_agents.activity import (
Activity,
ActivityTypes,
Channels,
SignInConstants,
TokenResponse,
)

from microsoft_agents.hosting.core.app.oauth import (
_SignInResponse,
Expand Down Expand Up @@ -721,3 +727,122 @@ async def test_on_turn_auth_intercept_with_intercept_complete(
final_state = await authorization._load_sign_in_state(context)
assert sign_in_state_eq(final_state, initial_state)
assert context.turn_state == expected_cache


class TestTeamsSSOConsentRequired(TestEnv):
"""Tests for the Teams SSO ConsentRequired special case in _start_or_continue_sign_in.

When a pending signin/tokenExchange invoke arrives on Teams and a continuation_activity
already exists in the sign-in state, the state must NOT be updated so the original
activity is preserved for the eventual resume turn.
"""

@pytest.fixture(params=[_FlowStateTag.BEGIN, _FlowStateTag.CONTINUE])
def pending_tag(self, request):
return request.param

@pytest.mark.asyncio
async def test_skips_state_update_for_teams_sso_consent_required(
self, mocker, storage, authorization, pending_tag
):
"""All four conditions met → continuation_activity is NOT overwritten."""
original_continuation = Activity(type=ActivityTypes.message, text="original")
teams_invoke_activity = Activity(
type=ActivityTypes.invoke,
channel_id=Channels.ms_teams,
from_property={"id": DEFAULTS.user_id},
name=SignInConstants.token_exchange_operation_name,
)
context = create_testing_TurnContext(mocker, activity=teams_invoke_activity)
initial_state = _SignInState(
active_handler_id=DEFAULTS.auth_handler_id,
continuation_activity=original_continuation,
)
await authorization._save_sign_in_state(context, initial_state)
mock_variants(mocker, sign_in_return=_SignInResponse(tag=pending_tag))

res = await authorization._start_or_continue_sign_in(
context, None, DEFAULTS.auth_handler_id
)

assert res.tag == pending_tag
final_state = await authorization._load_sign_in_state(context)
assert final_state is not None
assert final_state.continuation_activity == original_continuation
assert final_state.active_handler_id == DEFAULTS.auth_handler_id

@pytest.mark.asyncio
@pytest.mark.parametrize(
"channel_id, activity_type, activity_name, has_existing_continuation",
[
# Teams + invoke + tokenExchange but NO existing continuation_activity → saves
(
Channels.ms_teams,
ActivityTypes.invoke,
SignInConstants.token_exchange_operation_name,
False,
),
# Non-Teams channel with all other conditions met → saves
(
"directline",
ActivityTypes.invoke,
SignInConstants.token_exchange_operation_name,
True,
),
# Teams channel with non-invoke activity type → saves
(
Channels.ms_teams,
ActivityTypes.message,
None,
True,
),
# Teams invoke with a name other than token_exchange → saves
(
Channels.ms_teams,
ActivityTypes.invoke,
SignInConstants.verify_state_operation_name,
True,
),
],
)
async def test_saves_state_when_consent_required_condition_not_fully_met(
self,
mocker,
storage,
authorization,
pending_tag,
channel_id,
activity_type,
activity_name,
has_existing_continuation,
):
"""When any single condition of the Teams SSO guard is false, the state IS saved."""
activity_kwargs = dict(
type=activity_type,
channel_id=channel_id,
from_property={"id": DEFAULTS.user_id},
)
if activity_name is not None:
activity_kwargs["name"] = activity_name
activity = Activity(**activity_kwargs)
context = create_testing_TurnContext(mocker, activity=activity)

original_continuation = (
Activity(type=ActivityTypes.message, text="original")
if has_existing_continuation
else None
)
initial_state = _SignInState(
active_handler_id=DEFAULTS.auth_handler_id,
continuation_activity=original_continuation,
)
await authorization._save_sign_in_state(context, initial_state)
mock_variants(mocker, sign_in_return=_SignInResponse(tag=pending_tag))

await authorization._start_or_continue_sign_in(
context, None, DEFAULTS.auth_handler_id
)

final_state = await authorization._load_sign_in_state(context)
assert final_state is not None
assert final_state.continuation_activity == context.activity
Loading