diff --git a/packages/apps/src/microsoft/teams/apps/__init__.py b/packages/apps/src/microsoft/teams/apps/__init__.py index d7dc4712..2aa32332 100644 --- a/packages/apps/src/microsoft/teams/apps/__init__.py +++ b/packages/apps/src/microsoft/teams/apps/__init__.py @@ -5,7 +5,6 @@ from . import auth, contexts, events, plugins from .app import App -from .app_tokens import AppTokens from .auth import * # noqa: F403 from .contexts import * # noqa: F403 from .events import * # noqa: F401, F403 @@ -16,7 +15,7 @@ from .routing import ActivityContext # Combine all exports from submodules -__all__: list[str] = ["App", "AppOptions", "HttpPlugin", "HttpStream", "ActivityContext", "AppTokens"] +__all__: list[str] = ["App", "AppOptions", "HttpPlugin", "HttpStream", "ActivityContext"] __all__.extend(auth.__all__) __all__.extend(events.__all__) __all__.extend(plugins.__all__) diff --git a/packages/apps/src/microsoft/teams/apps/app.py b/packages/apps/src/microsoft/teams/apps/app.py index f82104e2..d30e92c2 100644 --- a/packages/apps/src/microsoft/teams/apps/app.py +++ b/packages/apps/src/microsoft/teams/apps/app.py @@ -21,7 +21,6 @@ ConversationAccount, ConversationReference, Credentials, - JsonWebToken, MessageActivityInput, TokenCredentials, ) @@ -32,7 +31,6 @@ from .app_oauth import OauthHandlers from .app_plugins import PluginProcessor from .app_process import ActivityProcessor -from .app_tokens import AppTokens from .auth import TokenValidator from .auth.remote_function_jwt_middleware import remote_function_jwt_validation from .container import Container @@ -45,12 +43,12 @@ get_event_type_from_signature, is_registered_event, ) -from .graph_token_manager import GraphTokenManager from .http_plugin import HttpPlugin from .options import AppOptions, InternalAppOptions from .plugins import PluginBase, PluginStartEvent, get_metadata from .routing import ActivityHandlerMixin, ActivityRouter from .routing.activity_context import ActivityContext +from .token_manager import TokenManager version = importlib.metadata.version("microsoft-teams-apps") @@ -83,22 +81,26 @@ def __init__(self, **options: Unpack[AppOptions]): self._events = EventEmitter[EventType]() self._router = ActivityRouter() - self._tokens = AppTokens() self.credentials = self._init_credentials() + self._token_manager = TokenManager( + http_client=self.http_client, + credentials=self.credentials, + logger=self.log, + default_connection_name=self.options.default_connection_name, + ) + self.container = Container() self.container.set_provider("id", providers.Object(self.id)) - self.container.set_provider("name", providers.Object(self.name)) self.container.set_provider("credentials", providers.Object(self.credentials)) - self.container.set_provider("bot_token", providers.Callable(lambda: self.tokens.bot)) - self.container.set_provider("graph_token", providers.Callable(lambda: self.tokens.graph)) + self.container.set_provider("bot_token", providers.Factory(lambda: self._get_or_get_bot_token)) self.container.set_provider("logger", providers.Object(self.log)) self.container.set_provider("storage", providers.Object(self.storage)) self.container.set_provider(self.http_client.__class__.__name__, providers.Factory(lambda: self.http_client)) self.api = ApiClient( "https://smba.trafficmanager.net/teams", - self.http_client.clone(ClientOptions(token=lambda: self.tokens.bot)), + self.http_client.clone(ClientOptions(token=self._get_or_get_bot_token)), ) plugins: List[PluginBase] = list(self.options.plugins) @@ -125,11 +127,6 @@ def __init__(self, **options: Unpack[AppOptions]): self._running = False # initialize all event, activity, and plugin processors - self.graph_token_manager = GraphTokenManager( - api_client=self.api, - credentials=self.credentials, - logger=self.log, - ) self.activity_processor = ActivityProcessor( self._router, self.log, @@ -137,8 +134,7 @@ def __init__(self, **options: Unpack[AppOptions]): self.storage, self.options.default_connection_name, self.http_client, - self.tokens, - self.graph_token_manager, + self._token_manager, ) self.event_manager = EventManager(self._events, self.activity_processor) self.activity_processor.event_manager = self.event_manager @@ -169,11 +165,6 @@ def is_running(self) -> bool: """Whether the app is currently running.""" return self._running - @property - def tokens(self) -> AppTokens: - """Current authentication tokens.""" - return self._tokens - @property def logger(self) -> Logger: """The logger instance used by the app.""" @@ -191,17 +182,10 @@ def router(self) -> ActivityRouter: @property def id(self) -> Optional[str]: - """The app's ID from tokens.""" - return ( - self._tokens.bot.app_id if self._tokens.bot else self._tokens.graph.app_id if self._tokens.graph else None - ) - - @property - def name(self) -> Optional[str]: - """The app's name from tokens.""" - return getattr(self._tokens.bot, "app_display_name", None) or getattr( - self._tokens.graph, "app_display_name", None - ) + """The app's ID from credentials.""" + if not self.credentials: + return None + return self.credentials.client_id async def start(self, port: Optional[int] = None) -> None: """ @@ -220,9 +204,6 @@ async def start(self, port: Optional[int] = None) -> None: self._port = port or int(os.getenv("PORT", "3978")) try: - await self._refresh_tokens(force=True) - self._running = True - for plugin in self.plugins: # Inject the dependencies self._plugin_processor.inject(plugin) @@ -234,6 +215,7 @@ async def on_http_ready() -> None: self.log.info("Teams app started successfully") assert self._port is not None, "Port must be set before emitting start event" self._events.emit("start", StartEvent(port=self._port)) + self._running = True self.http.on_ready_callback = on_http_ready @@ -280,13 +262,13 @@ async def on_http_stopped() -> None: async def send(self, conversation_id: str, activity: str | ActivityParams | AdaptiveCard): """Send an activity proactively.""" - if self.id is None or self.name is None: + if self.id is None: raise ValueError("app not started") conversation_ref = ConversationReference( channel_id="msteams", service_url=self.api.service_url, - bot=Account(id=self.id, name=self.name, role="bot"), + bot=Account(id=self.id, role="bot"), conversation=ConversationAccount(id=conversation_id, conversation_type="personal"), ) @@ -326,65 +308,6 @@ def _init_credentials(self) -> Optional[Credentials]: return None - async def _refresh_tokens(self, force: bool = False) -> None: - """Refresh bot and graph tokens.""" - await asyncio.gather(self._refresh_bot_token(force), self._refresh_graph_token(force), return_exceptions=True) - - async def _refresh_bot_token(self, force: bool = False) -> None: - """Refresh the bot authentication token.""" - if not self.credentials: - self.log.warning("No credentials provided, skipping bot token refresh") - return - - if not force and self._tokens.bot and not self._tokens.bot.is_expired(): - return - - if self._tokens.bot: - self.log.debug("Refreshing bot token") - - try: - token_response = await self.api.bots.token.get(self.credentials) - self._tokens.bot = JsonWebToken(token_response.access_token) - self.log.debug("Bot token refreshed successfully") - - except Exception as error: - self.log.error(f"Failed to refresh bot token: {error}") - - self._events.emit("error", ErrorEvent(error, context={"method": "_refresh_bot_token"})) - raise - - async def _refresh_graph_token(self, force: bool = False) -> None: - """Refresh the Graph API token.""" - if not self.credentials: - self.log.warning("No credentials provided, skipping graph token refresh") - return - - if not force and self._tokens.graph and not self._tokens.graph.is_expired(): - return - - if self._tokens.graph: - self.log.debug("Refreshing graph token") - - try: - # Use GraphTokenManager for tenant-aware token management - tenant_id = self.credentials.tenant_id if self.credentials else None - token = await self.graph_token_manager.get_token(tenant_id) - - if token: - self._tokens.graph = token - self.log.debug("Graph token refreshed successfully") - - # Emit token acquired event - self._events.emit("token", {"type": "graph", "token": self._tokens.graph}) - else: - self.log.debug("Failed to get graph token from GraphTokenManager") - - except Exception as error: - self.log.error(f"Failed to refresh graph token: {error}") - - self._events.emit("error", ErrorEvent(error, context={"method": "_refresh_graph_token"})) - raise - @overload def event(self, func_or_event_type: F) -> F: """Register event handler with auto-detected type from function signature.""" @@ -521,7 +444,6 @@ async def endpoint(req: Request): async def call_next(r: Request) -> Any: ctx = FunctionContext( id=self.id, - name=self.name, api=self.api, http=self.http, log=self.log, @@ -541,3 +463,6 @@ async def call_next(r: Request) -> Any: # Named decoration: @app.func("name") return decorator + + async def _get_or_get_bot_token(self): + return await self._token_manager.get_bot_token() diff --git a/packages/apps/src/microsoft/teams/apps/app_process.py b/packages/apps/src/microsoft/teams/apps/app_process.py index cf57e473..509b4888 100644 --- a/packages/apps/src/microsoft/teams/apps/app_process.py +++ b/packages/apps/src/microsoft/teams/apps/app_process.py @@ -11,23 +11,22 @@ ActivityParams, ApiClient, ConversationReference, - GetUserTokenParams, InvokeResponse, SentActivity, TokenProtocol, is_invoke_response, ) +from microsoft.teams.api.clients.user.params import GetUserTokenParams from microsoft.teams.cards import AdaptiveCard from microsoft.teams.common import Client, ClientOptions, LocalStorage, Storage if TYPE_CHECKING: from .app_events import EventManager -from .app_tokens import AppTokens from .events import ActivityEvent, ActivityResponseEvent, ActivitySentEvent -from .graph_token_manager import GraphTokenManager from .plugins import PluginActivityEvent, PluginBase, Sender from .routing.activity_context import ActivityContext from .routing.router import ActivityHandler, ActivityRouter +from .token_manager import TokenManager from .utils import extract_tenant_id @@ -42,8 +41,7 @@ def __init__( storage: Union[Storage[str, Any], LocalStorage[Any]], default_connection_name: str, http_client: Client, - token: AppTokens, - graph_token_manager: GraphTokenManager, + token_manager: TokenManager, ) -> None: self.router = router self.logger = logger @@ -51,21 +49,12 @@ def __init__( self.storage = storage self.default_connection_name = default_connection_name self.http_client = http_client - self.tokens = token - self._graph_token_manager = graph_token_manager + self.token_manager = token_manager # This will be set after the EventManager is initialized due to # a circular dependency self.event_manager: Optional["EventManager"] = None - async def _get_or_refresh_graph_token(self, tenant_id: Optional[str] = None) -> Optional[TokenProtocol]: - """Get the current graph token or refresh it if needed.""" - try: - return await self._graph_token_manager.get_token(tenant_id) - except Exception as e: - self.logger.error(f"Failed to get graph token via manager: {e}") - return self.tokens.graph - async def _build_context( self, activity: ActivityBase, @@ -92,7 +81,9 @@ async def _build_context( locale=activity.locale, user=activity.from_, ) - api_client = ApiClient(service_url, self.http_client.clone(ClientOptions(token=self.tokens.bot))) + api_client = ApiClient( + service_url, self.http_client.clone(ClientOptions(token=self.token_manager.get_bot_token)) + ) # Check if user is signed in is_signed_in = False @@ -100,21 +91,21 @@ async def _build_context( try: user_token_res = await api_client.users.token.get( GetUserTokenParams( - connection_name=self.default_connection_name, - user_id=activity.from_.id, channel_id=activity.channel_id, + user_id=activity.from_.id, + connection_name=self.default_connection_name, ) ) + user_token = user_token_res.token is_signed_in = True except Exception: # User token not available + self.logger.debug("No user token available") pass tenant_id = extract_tenant_id(activity) - graph_token = await self._get_or_refresh_graph_token(tenant_id) - activityCtx = ActivityContext( activity, self.id or "", @@ -126,7 +117,7 @@ async def _build_context( is_signed_in, self.default_connection_name, sender, - app_token=graph_token, + app_token=lambda: self.token_manager.get_graph_token(tenant_id), ) send = activityCtx.send diff --git a/packages/apps/src/microsoft/teams/apps/app_tokens.py b/packages/apps/src/microsoft/teams/apps/app_tokens.py deleted file mode 100644 index a9a9c1f4..00000000 --- a/packages/apps/src/microsoft/teams/apps/app_tokens.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. All rights reserved. -Licensed under the MIT License. -""" - -from dataclasses import dataclass -from typing import Optional - -from microsoft.teams.api.auth.token import TokenProtocol - - -@dataclass -class AppTokens: - """Application tokens for API access.""" - - bot: Optional[TokenProtocol] = None - graph: Optional[TokenProtocol] = None diff --git a/packages/apps/src/microsoft/teams/apps/graph_token_manager.py b/packages/apps/src/microsoft/teams/apps/graph_token_manager.py deleted file mode 100644 index 83c19471..00000000 --- a/packages/apps/src/microsoft/teams/apps/graph_token_manager.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. All rights reserved. -Licensed under the MIT License. -""" - -import logging -from typing import Dict, Optional - -from microsoft.teams.api import ApiClient, ClientCredentials, Credentials, JsonWebToken, TokenProtocol - - -class GraphTokenManager: - """Simple token manager for Graph API tokens.""" - - def __init__( - self, - api_client: "ApiClient", - credentials: Optional["Credentials"], - logger: Optional[logging.Logger] = None, - ): - self._api_client = api_client - self._credentials = credentials - - if not logger: - self._logger = logging.getLogger(__name__ + ".GraphTokenManager") - else: - self._logger = logger.getChild("GraphTokenManager") - - self._token_cache: Dict[str, TokenProtocol] = {} - - async def get_token(self, tenant_id: Optional[str] = None) -> Optional[TokenProtocol]: - """Get a Graph token for the specified tenant.""" - if not self._credentials: - return None - - if not tenant_id: - tenant_id = "botframework.com" # Default tenant ID, assuming multi-tenant app - - # Check cache first - cached_token = self._token_cache.get(tenant_id) - if cached_token and not cached_token.is_expired(): - return cached_token - - # Refresh token - try: - tenant_credentials = self._credentials - if isinstance(self._credentials, ClientCredentials): - tenant_credentials = ClientCredentials( - client_id=self._credentials.client_id, - client_secret=self._credentials.client_secret, - tenant_id=tenant_id, - ) - - response = await self._api_client.bots.token.get_graph(tenant_credentials) - token = JsonWebToken(response.access_token) - self._token_cache[tenant_id] = token - - self._logger.debug(f"Refreshed graph token for tenant {tenant_id}") - - return token - - except Exception as e: - self._logger.error(f"Failed to refresh graph token for {tenant_id}: {e}") - return None diff --git a/packages/apps/src/microsoft/teams/apps/http_plugin.py b/packages/apps/src/microsoft/teams/apps/http_plugin.py index ab0c6d40..651cfead 100644 --- a/packages/apps/src/microsoft/teams/apps/http_plugin.py +++ b/packages/apps/src/microsoft/teams/apps/http_plugin.py @@ -24,7 +24,7 @@ TokenProtocol, ) from microsoft.teams.apps.http_stream import HttpStream -from microsoft.teams.common.http.client import Client, ClientOptions +from microsoft.teams.common.http import Client, ClientOptions, Token from microsoft.teams.common.logging import ConsoleLogger from pydantic import BaseModel, ValidationError from starlette.applications import Starlette @@ -60,8 +60,7 @@ class HttpPlugin(Sender): client: Annotated[Client, DependencyMetadata()] - bot_token: Annotated[Optional[Callable[[], TokenProtocol]], DependencyMetadata(optional=True)] - graph_token: Annotated[Optional[Callable[[], TokenProtocol]], DependencyMetadata(optional=True)] + bot_token: Annotated[Optional[Callable[[], Token]], DependencyMetadata(optional=True)] lifespans: list[Lifespan[Starlette]] = [] diff --git a/packages/apps/src/microsoft/teams/apps/plugins/metadata.py b/packages/apps/src/microsoft/teams/apps/plugins/metadata.py index 3e0c9352..98b34a9f 100644 --- a/packages/apps/src/microsoft/teams/apps/plugins/metadata.py +++ b/packages/apps/src/microsoft/teams/apps/plugins/metadata.py @@ -99,12 +99,6 @@ class BotTokenDependencyOptions(DependencyMetadata): optional = True -@dataclass -class GraphTokenDependencyOptions(DependencyMetadata): - name = "graph_token" - optional = True - - @dataclass class LoggerDependencyOptions(DependencyMetadata): name = "logger" @@ -129,7 +123,6 @@ class PluginDependencyOptions(DependencyMetadata): ManifestDependencyOptions, CredentialsDependencyOptions, BotTokenDependencyOptions, - GraphTokenDependencyOptions, LoggerDependencyOptions, StorageDependencyOptions, PluginDependencyOptions, diff --git a/packages/apps/src/microsoft/teams/apps/routing/activity_context.py b/packages/apps/src/microsoft/teams/apps/routing/activity_context.py index eaf2f651..4095be7c 100644 --- a/packages/apps/src/microsoft/teams/apps/routing/activity_context.py +++ b/packages/apps/src/microsoft/teams/apps/routing/activity_context.py @@ -26,7 +26,6 @@ TokenExchangeResource, TokenExchangeState, TokenPostResource, - TokenProtocol, ) from microsoft.teams.api.models.attachment.card_attachment import ( OAuthCardAttachment, @@ -35,6 +34,7 @@ from microsoft.teams.api.models.oauth import OAuthCard from microsoft.teams.cards import AdaptiveCard from microsoft.teams.common import Storage +from microsoft.teams.common.http.client_token import Token if TYPE_CHECKING: from msgraph.graph_service_client import GraphServiceClient @@ -46,7 +46,7 @@ SendCallable = Callable[[str | ActivityParams | AdaptiveCard], Awaitable[SentActivity]] -def _get_graph_client(token: TokenProtocol): +def _get_graph_client(token: Token): """Lazy import and call get_graph_client when needed.""" try: from microsoft.teams.graph import get_graph_client @@ -97,7 +97,7 @@ def __init__( is_signed_in: bool, connection_name: str, sender: Sender, - app_token: Optional[TokenProtocol], + app_token: Token, ): self.activity = activity self.app_id = app_id @@ -158,9 +158,6 @@ def app_graph(self) -> "GraphServiceClient": ImportError: If the graph dependencies are not installed. """ - if not self._app_token: - raise ValueError("No app token available for Graph client") - if self._app_graph is None: try: self._app_graph = _get_graph_client(self._app_token) diff --git a/packages/apps/src/microsoft/teams/apps/token_manager.py b/packages/apps/src/microsoft/teams/apps/token_manager.py new file mode 100644 index 00000000..d2173a53 --- /dev/null +++ b/packages/apps/src/microsoft/teams/apps/token_manager.py @@ -0,0 +1,98 @@ +""" +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. +""" + +import logging +from typing import Optional + +from microsoft.teams.api import ( + BotTokenClient, + ClientCredentials, + Credentials, + JsonWebToken, + TokenProtocol, +) +from microsoft.teams.common import Client, ConsoleLogger, LocalStorage, LocalStorageOptions + + +class TokenManager: + """Manages authentication tokens for the Teams application.""" + + def __init__( + self, + http_client: Client, + credentials: Optional[Credentials], + logger: Optional[logging.Logger] = None, + default_connection_name: Optional[str] = None, + ): + self._bot_token_client = BotTokenClient(http_client.clone()) + self._credentials = credentials + self._default_connection_name = default_connection_name + + if not logger: + self._logger = ConsoleLogger().create_logger("TokenManager") + else: + self._logger = logger.getChild("TokenManager") + + self._bot_token: Optional[TokenProtocol] = None + + # Key: tenant_id (empty string "" for default app graph token) + self._graph_tokens: LocalStorage[TokenProtocol] = LocalStorage({}, LocalStorageOptions(max=20000)) + + async def get_bot_token(self, force: bool = False) -> Optional[TokenProtocol]: + """Refresh the bot authentication token.""" + if not self._credentials: + self._logger.warning("No credentials provided, skipping bot token refresh") + return None + + if not force and self._bot_token and not self._bot_token.is_expired(): + return self._bot_token + + if self._bot_token: + self._logger.debug("Refreshing bot token") + else: + self._logger.debug("Retrieving bot token") + + token_response = await self._bot_token_client.get(self._credentials) + self._bot_token = JsonWebToken(token_response.access_token) + self._logger.debug("Bot token refreshed successfully") + return self._bot_token + + async def get_graph_token(self, tenant_id: Optional[str] = None, force: bool = False) -> Optional[TokenProtocol]: + """ + Get or refresh a Graph API token. + + Args: + tenant_id: If provided, gets a tenant-specific token. Otherwise uses app's default. + force: Force refresh even if token is not expired + + Returns: + The graph token or None if not available + """ + if not self._credentials: + self._logger.debug("No credentials provided for graph token refresh") + return None + + # Use empty string as key for default graph token + key = tenant_id or "" + + cached = self._graph_tokens.get(key) + if not force and cached and not cached.is_expired(): + return cached + + creds = self._credentials + if tenant_id and isinstance(self._credentials, ClientCredentials): + creds = ClientCredentials( + client_id=self._credentials.client_id, + client_secret=self._credentials.client_secret, + tenant_id=tenant_id, + ) + + response = await self._bot_token_client.get_graph(creds) + token = JsonWebToken(response.access_token) + self._graph_tokens.set(key, token) + + self._logger.debug(f"Refreshed graph token tenant_id={tenant_id}") + + return token diff --git a/packages/apps/tests/test_app.py b/packages/apps/tests/test_app.py index 6c9b40c9..134c0151 100644 --- a/packages/apps/tests/test_app.py +++ b/packages/apps/tests/test_app.py @@ -131,11 +131,14 @@ def test_app_starts_successfully(self, basic_options): @pytest.mark.asyncio async def test_app_lifecycle_start_stop(self, app_with_options): """Test basic app lifecycle: start and stop.""" + # Mock the underlying HTTP server to avoid actual server startup - with ( - patch.object(app_with_options, "_refresh_tokens", new_callable=AsyncMock), - patch.object(app_with_options.http, "on_start", new_callable=AsyncMock), - ): + async def mock_on_start(event): + # Simulate the HTTP plugin calling the ready callback + if app_with_options.http.on_ready_callback: + await app_with_options.http.on_ready_callback() + + with patch.object(app_with_options.http, "on_start", new_callable=AsyncMock, side_effect=mock_on_start): # Test start start_task = asyncio.create_task(app_with_options.start(3978)) await asyncio.sleep(0.1) @@ -507,16 +510,18 @@ def get_token(scope, tenant_id=None): options = AppOptions(client_id="test-client-123", token=get_token) - app = App(**options) + # Mock environment variables to ensure they don't interfere + with patch.dict("os.environ", {"CLIENT_ID": "", "CLIENT_SECRET": "", "TENANT_ID": ""}, clear=False): + app = App(**options) - assert app.credentials is not None - assert type(app.credentials) is TokenCredentials - assert app.credentials.client_id == "test-client-123" - assert callable(app.credentials.token) + assert app.credentials is not None + assert type(app.credentials) is TokenCredentials + assert app.credentials.client_id == "test-client-123" + assert callable(app.credentials.token) - res = await app.api.bots.token.get(app.credentials) - assert token_called is True - assert res.access_token == "test.jwt.token" + res = await app.api.bots.token.get(app.credentials) + assert token_called is True + assert res.access_token == "test.jwt.token" def test_middleware_registration(self, app_with_options: App) -> None: """Test that middleware is registered correctly using app.use().""" diff --git a/packages/apps/tests/test_app_process.py b/packages/apps/tests/test_app_process.py index 56819d47..261cf139 100644 --- a/packages/apps/tests/test_app_process.py +++ b/packages/apps/tests/test_app_process.py @@ -8,11 +8,11 @@ import pytest from microsoft.teams.api import Activity, ActivityBase, ConversationReference -from microsoft.teams.apps import ActivityContext, AppTokens, Sender +from microsoft.teams.apps import ActivityContext, Sender from microsoft.teams.apps.app_events import EventManager from microsoft.teams.apps.app_process import ActivityProcessor -from microsoft.teams.apps.graph_token_manager import GraphTokenManager from microsoft.teams.apps.routing.router import ActivityHandler, ActivityRouter +from microsoft.teams.apps.token_manager import TokenManager from microsoft.teams.common import Client, ConsoleLogger, LocalStorage @@ -30,8 +30,7 @@ def activity_processor(self, mock_logger, mock_http_client): """Create an ActivityProcessor instance.""" mock_storage = MagicMock(spec=LocalStorage) mock_activity_router = MagicMock(spec=ActivityRouter) - mock_tokens = MagicMock(spec=AppTokens) - mock_graph_token_manager = MagicMock(spec=GraphTokenManager) + mock_token_manager = MagicMock(spec=TokenManager) return ActivityProcessor( mock_activity_router, mock_logger, @@ -39,8 +38,7 @@ def activity_processor(self, mock_logger, mock_http_client): mock_storage, "default_connection", mock_http_client, - mock_tokens, - mock_graph_token_manager, + mock_token_manager, ) @pytest.mark.asyncio diff --git a/packages/apps/tests/test_graph_token_manager.py b/packages/apps/tests/test_graph_token_manager.py deleted file mode 100644 index 4dfbf175..00000000 --- a/packages/apps/tests/test_graph_token_manager.py +++ /dev/null @@ -1,271 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. All rights reserved. -Licensed under the MIT License. -""" - -from unittest.mock import AsyncMock, MagicMock - -import pytest -from microsoft.teams.api import ClientCredentials, JsonWebToken -from microsoft.teams.apps.graph_token_manager import GraphTokenManager - -# Valid JWT-like token for testing (format: header.payload.signature) -VALID_TEST_TOKEN = ( - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." - "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ." - "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" -) -ANOTHER_VALID_TOKEN = ( - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." - "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkphbmUgRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ." - "Twzj7LKlhYUUe2GFRME4WOZdWq2TdayZhWjhBr1r5X4" -) - - -class TestGraphTokenManager: - """Test GraphTokenManager functionality.""" - - def test_initialization(self): - """Test GraphTokenManager initialization.""" - mock_api_client = MagicMock() - mock_credentials = MagicMock() - mock_logger = MagicMock() - - manager = GraphTokenManager( - api_client=mock_api_client, - credentials=mock_credentials, - logger=mock_logger, - ) - - assert manager is not None - # Test successful initialization by verifying the manager was created - - def test_initialization_without_logger(self): - """Test GraphTokenManager initialization without logger.""" - mock_api_client = MagicMock() - mock_credentials = MagicMock() - - manager = GraphTokenManager( - api_client=mock_api_client, - credentials=mock_credentials, - ) - - assert manager is not None - - @pytest.mark.asyncio - async def test_get_token_no_tenant_id(self): - """Test getting token with no tenant_id returns None.""" - mock_api_client = MagicMock() - mock_credentials = MagicMock() - - manager = GraphTokenManager( - api_client=mock_api_client, - credentials=mock_credentials, - ) - - token = await manager.get_token(None) - assert token is None - - @pytest.mark.asyncio - async def test_get_token_no_credentials(self): - """Test getting token with no credentials returns None.""" - mock_api_client = MagicMock() - - manager = GraphTokenManager( - api_client=mock_api_client, - credentials=None, - ) - - token = await manager.get_token("test-tenant") - assert token is None - - @pytest.mark.asyncio - async def test_get_token_success(self): - """Test successful token retrieval.""" - mock_api_client = MagicMock() - mock_token_response = MagicMock() - mock_token_response.access_token = VALID_TEST_TOKEN - mock_api_client.bots.token.get_graph = AsyncMock(return_value=mock_token_response) - - mock_credentials = ClientCredentials( - client_id="test-client-id", - client_secret="test-client-secret", - tenant_id="default-tenant-id", - ) - - mock_logger = MagicMock() - mock_child_logger = MagicMock() - mock_logger.getChild.return_value = mock_child_logger - - manager = GraphTokenManager( - api_client=mock_api_client, - credentials=mock_credentials, - logger=mock_logger, - ) - - token = await manager.get_token("test-tenant") - - assert token is not None - assert isinstance(token, JsonWebToken) - # Verify the API was called - mock_api_client.bots.token.get_graph.assert_called_once() - # Verify child logger was created and debug was called - mock_logger.getChild.assert_called_once_with("GraphTokenManager") - mock_child_logger.debug.assert_called_once() - - # Test that subsequent calls use cache by calling again - token2 = await manager.get_token("test-tenant") - assert token2 == token - # API should still only be called once due to caching - mock_api_client.bots.token.get_graph.assert_called_once() - - @pytest.mark.asyncio - async def test_get_token_from_cache(self): - """Test getting token from cache.""" - mock_api_client = MagicMock() - mock_credentials = MagicMock() - - manager = GraphTokenManager( - api_client=mock_api_client, - credentials=mock_credentials, - ) - - # Set up the API response for initial token - mock_token_response = MagicMock() - mock_token_response.access_token = VALID_TEST_TOKEN - mock_api_client.bots.token.get_graph = AsyncMock(return_value=mock_token_response) - - # First call should hit the API - token1 = await manager.get_token("test-tenant") - assert token1 is not None - assert isinstance(token1, JsonWebToken) - mock_api_client.bots.token.get_graph.assert_called_once() - - # Second call should use cache (API should not be called again) - token2 = await manager.get_token("test-tenant") - assert token2 == token1 # Should be the same cached token - # Still only called once due to caching - mock_api_client.bots.token.get_graph.assert_called_once() - - @pytest.mark.asyncio - async def test_get_token_api_error(self): - """Test token retrieval when API call fails.""" - mock_api_client = MagicMock() - mock_api_client.bots.token.get_graph = AsyncMock(side_effect=Exception("API Error")) - - mock_credentials = ClientCredentials( - client_id="test-client-id", - client_secret="test-client-secret", - tenant_id="default-tenant-id", - ) - - mock_logger = MagicMock() - mock_child_logger = MagicMock() - mock_logger.getChild.return_value = mock_child_logger - - manager = GraphTokenManager( - api_client=mock_api_client, - credentials=mock_credentials, - logger=mock_logger, - ) - - token = await manager.get_token("test-tenant") - - assert token is None - # Verify child logger was created and error was logged - mock_logger.getChild.assert_called_once_with("GraphTokenManager") - mock_child_logger.error.assert_called_once() - - @pytest.mark.asyncio - async def test_get_token_no_logger_on_error(self): - """Test token retrieval error handling without logger.""" - mock_api_client = MagicMock() - mock_api_client.bots.token.get_graph = AsyncMock(side_effect=Exception("API Error")) - - mock_credentials = ClientCredentials( - client_id="test-client-id", - client_secret="test-client-secret", - tenant_id="default-tenant-id", - ) - - manager = GraphTokenManager( - api_client=mock_api_client, - credentials=mock_credentials, - # No logger - ) - - token = await manager.get_token("test-tenant") - - assert token is None - # Should not raise exception even without logger - - @pytest.mark.asyncio - async def test_get_token_expired_cache_refresh(self): - """Test that expired tokens in cache are refreshed.""" - mock_api_client = MagicMock() - mock_token_response = MagicMock() - mock_token_response.access_token = ANOTHER_VALID_TOKEN - mock_api_client.bots.token.get_graph = AsyncMock(return_value=mock_token_response) - - mock_credentials = ClientCredentials( - client_id="test-client-id", - client_secret="test-client-secret", - tenant_id="default-tenant-id", - ) - - manager = GraphTokenManager( - api_client=mock_api_client, - credentials=mock_credentials, - ) - - # First, get a token to populate cache - first_token_response = MagicMock() - first_token_response.access_token = VALID_TEST_TOKEN - mock_api_client.bots.token.get_graph.return_value = first_token_response - - first_token = await manager.get_token("test-tenant") - assert first_token is not None - - # Now simulate the cached token being expired and get a new one - mock_api_client.bots.token.get_graph.return_value = mock_token_response - second_token = await manager.get_token("test-tenant") - - assert second_token is not None - assert isinstance(second_token, JsonWebToken) - # Verify the API was called multiple times (once for each get) - assert mock_api_client.bots.token.get_graph.call_count >= 1 - - @pytest.mark.asyncio - async def test_get_token_creates_tenant_specific_credentials(self): - """Test that tenant-specific credentials are created for the API call.""" - mock_api_client = MagicMock() - mock_token_response = MagicMock() - mock_token_response.access_token = VALID_TEST_TOKEN - mock_api_client.bots.token.get_graph = AsyncMock(return_value=mock_token_response) - - original_credentials = ClientCredentials( - client_id="test-client-id", - client_secret="test-client-secret", - tenant_id="original-tenant-id", - ) - - manager = GraphTokenManager( - api_client=mock_api_client, - credentials=original_credentials, - ) - - token = await manager.get_token("different-tenant-id") - - assert token is not None - # Verify the API was called - mock_api_client.bots.token.get_graph.assert_called_once() - - # Get the credentials that were passed to the API - call_args = mock_api_client.bots.token.get_graph.call_args - passed_credentials = call_args[0][0] # First positional argument - - # Verify it's a ClientCredentials with the correct tenant - assert isinstance(passed_credentials, ClientCredentials) - assert passed_credentials.client_id == "test-client-id" - assert passed_credentials.client_secret == "test-client-secret" - assert passed_credentials.tenant_id == "different-tenant-id" diff --git a/packages/apps/tests/test_optional_graph_dependencies.py b/packages/apps/tests/test_optional_graph_dependencies.py index f6cf0875..cc78ac1f 100644 --- a/packages/apps/tests/test_optional_graph_dependencies.py +++ b/packages/apps/tests/test_optional_graph_dependencies.py @@ -128,5 +128,5 @@ def test_app_graph_property_no_token(self) -> None: ) # app_graph should raise ValueError when no app token is available - with pytest.raises(ValueError, match="No app token available for Graph client"): + with pytest.raises(RuntimeError, match="Token cannot be None"): _ = activity_context.app_graph diff --git a/packages/apps/tests/test_token_manager.py b/packages/apps/tests/test_token_manager.py new file mode 100644 index 00000000..e46001eb --- /dev/null +++ b/packages/apps/tests/test_token_manager.py @@ -0,0 +1,242 @@ +""" +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from microsoft.teams.api import ClientCredentials, JsonWebToken +from microsoft.teams.apps.token_manager import TokenManager +from microsoft.teams.common import Client + +# Valid JWT-like token for testing (format: header.payload.signature) +VALID_TEST_TOKEN = ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ." + "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" +) + + +class TestTokenManager: + """Test TokenManager functionality.""" + + @pytest.mark.asyncio + async def test_get_bot_token_success(self): + """Test successful bot token refresh, caching, and expiration refresh.""" + # First token response + mock_token_response1 = MagicMock() + mock_token_response1.access_token = VALID_TEST_TOKEN + + # Second token response for expired token + mock_token_response2 = MagicMock() + mock_token_response2.access_token = ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + "eyJzdWIiOiI5ODc2NTQzMjEwIiwibmFtZSI6IkphbmUgRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ." + "Twzj7LKlhYUUe2GFRME4WOZdWq2TdayZhWjhBr1r5X4" + ) + + # Third token response for force refresh + mock_token_response3 = MagicMock() + mock_token_response3.access_token = ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + "eyJzdWIiOiIxMTExMTExMTExIiwibmFtZSI6IkZvcmNlIFJlZnJlc2giLCJpYXQiOjE1MTYyMzkwMjJ9." + "dQw4w9WgXcQ" + ) + + mock_credentials = ClientCredentials( + client_id="test-client-id", + client_secret="test-client-secret", + tenant_id="test-tenant-id", + ) + + # Mock the BotTokenClient + mock_bot_token_client = MagicMock() + mock_bot_token_client.get = AsyncMock( + side_effect=[mock_token_response1, mock_token_response2, mock_token_response3] + ) + + mock_http_client = MagicMock(spec=Client) + mock_http_client.clone = MagicMock(return_value=mock_http_client) + + with patch("microsoft.teams.apps.token_manager.BotTokenClient", return_value=mock_bot_token_client): + manager = TokenManager( + http_client=mock_http_client, + credentials=mock_credentials, + ) + + # First call + token1 = await manager.get_bot_token() + assert token1 is not None + assert isinstance(token1, JsonWebToken) + mock_bot_token_client.get.assert_called_once() + + # Second call should use cache (mock should still only be called once) + token2 = await manager.get_bot_token() + assert token2 == token1 + mock_bot_token_client.get.assert_called_once() # Still only called once due to caching + + # Mock the token as expired + token1.is_expired = MagicMock(return_value=True) + + # Third call should refresh because token is expired + token3 = await manager.get_bot_token() + assert token3 is not None + assert token3 != token1 # New token + assert mock_bot_token_client.get.call_count == 2 + + # Force refresh even if not expired + token3.is_expired = MagicMock(return_value=False) + token4 = await manager.get_bot_token(force=True) + assert token4 is not None + assert mock_bot_token_client.get.call_count == 3 + + @pytest.mark.asyncio + async def test_get_bot_token_no_credentials(self): + """Test refreshing bot token with no credentials returns None.""" + mock_http_client = MagicMock(spec=Client) + mock_http_client.clone = MagicMock(return_value=mock_http_client) + + with patch("microsoft.teams.apps.token_manager.BotTokenClient"): + manager = TokenManager( + http_client=mock_http_client, + credentials=None, + ) + + token = await manager.get_bot_token() + assert token is None + + @pytest.mark.asyncio + async def test_get_graph_token_default(self): + """Test getting default graph token with caching and expiration refresh.""" + # First token response + mock_token_response1 = MagicMock() + mock_token_response1.access_token = VALID_TEST_TOKEN + + # Second token response for expired token + mock_token_response2 = MagicMock() + mock_token_response2.access_token = ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + "eyJzdWIiOiI5ODc2NTQzMjEwIiwibmFtZSI6IkphbmUgRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ." + "Twzj7LKlhYUUe2GFRME4WOZdWq2TdayZhWjhBr1r5X4" + ) + + mock_credentials = ClientCredentials( + client_id="test-client-id", + client_secret="test-client-secret", + tenant_id="default-tenant-id", + ) + + # Mock the BotTokenClient + mock_bot_token_client = MagicMock() + mock_bot_token_client.get_graph = AsyncMock(side_effect=[mock_token_response1, mock_token_response2]) + + mock_http_client = MagicMock(spec=Client) + mock_http_client.clone = MagicMock(return_value=mock_http_client) + + with patch("microsoft.teams.apps.token_manager.BotTokenClient", return_value=mock_bot_token_client): + manager = TokenManager( + http_client=mock_http_client, + credentials=mock_credentials, + ) + + token1 = await manager.get_graph_token() + + assert token1 is not None + assert isinstance(token1, JsonWebToken) + + # Verify it's cached + token2 = await manager.get_graph_token() + assert token2 == token1 + mock_bot_token_client.get_graph.assert_called_once() + + # Mock the token as expired + token1.is_expired = MagicMock(return_value=True) + + # Third call should refresh because token is expired + token3 = await manager.get_graph_token() + assert token3 is not None + assert token3 != token1 # New token + assert mock_bot_token_client.get_graph.call_count == 2 + + @pytest.mark.asyncio + async def test_get_graph_token_with_tenant(self): + """Test getting tenant-specific graph token.""" + mock_token_response = MagicMock() + mock_token_response.access_token = VALID_TEST_TOKEN + + original_credentials = ClientCredentials( + client_id="test-client-id", + client_secret="test-client-secret", + tenant_id="original-tenant-id", + ) + + # Mock the BotTokenClient + mock_bot_token_client = MagicMock() + mock_bot_token_client.get_graph = AsyncMock(return_value=mock_token_response) + + mock_http_client = MagicMock(spec=Client) + mock_http_client.clone = MagicMock(return_value=mock_http_client) + + with patch("microsoft.teams.apps.token_manager.BotTokenClient", return_value=mock_bot_token_client): + manager = TokenManager( + http_client=mock_http_client, + credentials=original_credentials, + ) + + token = await manager.get_graph_token("different-tenant-id") + + assert token is not None + mock_bot_token_client.get_graph.assert_called_once() + + # Verify tenant-specific credentials were created + call_args = mock_bot_token_client.get_graph.call_args + passed_credentials = call_args[0][0] + assert isinstance(passed_credentials, ClientCredentials) + assert passed_credentials.tenant_id == "different-tenant-id" + + @pytest.mark.asyncio + async def test_graph_token_force_refresh(self): + """Test force refreshing graph token even when not expired.""" + mock_token_response1 = MagicMock() + mock_token_response1.access_token = VALID_TEST_TOKEN + + mock_token_response2 = MagicMock() + mock_token_response2.access_token = ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + "eyJzdWIiOiIxMTExMTExMTExIiwibmFtZSI6IkZvcmNlIFJlZnJlc2giLCJpYXQiOjE1MTYyMzkwMjJ9." + "dQw4w9WgXcQ" + ) + + mock_credentials = ClientCredentials( + client_id="test-client-id", + client_secret="test-client-secret", + tenant_id="test-tenant-id", + ) + + mock_bot_token_client = MagicMock() + mock_bot_token_client.get_graph = AsyncMock(side_effect=[mock_token_response1, mock_token_response2]) + + mock_http_client = MagicMock(spec=Client) + mock_http_client.clone = MagicMock(return_value=mock_http_client) + + with patch("microsoft.teams.apps.token_manager.BotTokenClient", return_value=mock_bot_token_client): + manager = TokenManager( + http_client=mock_http_client, + credentials=mock_credentials, + ) + + # First call + token1 = await manager.get_graph_token() + assert token1 is not None + mock_bot_token_client.get_graph.assert_called_once() + + # Second call should use cache + token2 = await manager.get_graph_token() + assert token2 == token1 + mock_bot_token_client.get_graph.assert_called_once() # Still only called once + + # Force refresh should call API even if not expired + token3 = await manager.get_graph_token(force=True) + assert token3 is not None + assert mock_bot_token_client.get_graph.call_count == 2 # Now called twice