diff --git a/typeagent/mcp/server.py b/typeagent/mcp/server.py
index 6fbdcd2..25b26d7 100644
--- a/typeagent/mcp/server.py
+++ b/typeagent/mcp/server.py
@@ -1,8 +1,9 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-"""Fledgling MCP server on top of knowpro."""
+"""Fledgling MCP server on top of typeagent."""
 
+import argparse
 from dataclasses import dataclass
 import time
 from typing import Any
@@ -19,6 +20,8 @@
 from typeagent.knowpro.search_query_schema import SearchQuery
 from typeagent.podcasts import podcast
 from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex
+from typeagent.storage.sqlite import SqliteStorageProvider
+from typeagent.storage.utils import create_storage_provider
 
 
 class MCPTypeChatModel(typechat.TypeChatLanguageModel):
@@ -101,18 +104,28 @@ def __repr__(self) -> str:
         return f"Context({', '.join(parts)})"
 
 
-async def make_context(session: ServerSession) -> ProcessingContext:
+async def make_context(
+    session: ServerSession, dbname: str | None = None
+) -> ProcessingContext:
     """Create processing context using MCP-based language model.
 
     Args:
         session: The MCP server session that provides create_message() for sampling.
+        dbname: Path to SQLite database file, or None to load from JSON file.
 
     Note:
         Embeddings still require API keys since MCP doesn't support embeddings yet.
        Make sure to set OPENAI_API_KEY or AZURE_OPENAI_API_KEY for embeddings.
     """
-    utils.load_dotenv()
-    settings = ConversationSettings()
+    settings = ConversationSettings()
+
+    # Use the SQLite provider if dbname is specified, otherwise the in-memory provider
+    settings.storage_provider = await create_storage_provider(
+        settings.message_text_index_settings,
+        settings.related_term_index_settings,
+        dbname,
+        podcast.PodcastMessage,
+    )
+
     lang_search_options = searchlang.LanguageSearchOptions(
         compile_options=searchlang.LanguageQueryCompileOptions(
             exact_scope=False, verb_scope=True, term_filter=None, apply_scope=True
         )
@@ -124,9 +137,7 @@ async def make_context(session: ServerSession) -> ProcessingContext:
         entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
     )
 
-    query_context = await load_podcast_index(
-        "testdata/Episode_53_AdrianTchaikovsky_index", settings
-    )
+    query_context = await load_podcast_index_or_database(settings, dbname)
 
     # Use MCP-based model instead of one that requires API keys
     model = MCPTypeChatModel(session)
@@ -145,18 +156,24 @@ async def make_context(session: ServerSession) -> ProcessingContext:
     return context
 
 
-async def load_podcast_index(
-    podcast_file_prefix: str, settings: ConversationSettings
+async def load_podcast_index_or_database(
+    settings: ConversationSettings,
+    dbname: str | None = None,
 ) -> query.QueryEvalContext[podcast.PodcastMessage, Any]:
-    conversation = await podcast.Podcast.read_from_file(podcast_file_prefix, settings)
-    assert (
-        conversation is not None
-    ), f"Failed to load podcast from {podcast_file_prefix!r}"
+    if dbname is None:
+        conversation = await podcast.Podcast.read_from_file(
+            "testdata/Episode_53_AdrianTchaikovsky_index", settings
+        )
+    else:
+        conversation = await podcast.Podcast.create(settings)
     return query.QueryEvalContext(conversation)
 
 
 # Create an MCP server
-mcp = FastMCP("knowpro")
+mcp = FastMCP("typeagent")
+
+# Global variable to store the database path (set via command-line argument)
+_dbname: str | None = None
 
 
 @dataclass
@@ -178,7 +195,7 @@ async def query_conversation(
         return QuestionResponse(
             success=False, answer="No question provided", time_used=dt
         )
-    context = await make_context(ctx.request_context.session)
+    context = await make_context(ctx.request_context.session, _dbname)
 
     # Stages 1, 2, 3 (LLM -> proto-query, compile, execute query)
     result = await searchlang.search_conversation_with_language(
@@ -213,5 +230,22 @@ async def query_conversation(
 
 # Run the MCP server
 if __name__ == "__main__":
+    # Load env vars
+    utils.load_dotenv()
+
+    # Parse command-line arguments
+    parser = argparse.ArgumentParser(description="MCP server for typeagent")
+    parser.add_argument(
+        "-d",
+        "--database",
+        type=str,
+        default=None,
+        help="Path to the SQLite database file (default: load from JSON file)",
+    )
+    args = parser.parse_args()
+
+    # Store the database path in the module-level global (no other straightforward way to pass it to the tool)
+    _dbname = args.database
+
     # Use stdio transport for simplicity
     mcp.run(transport="stdio")
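
Usage sketch for the new flag (assumptions: the server is runnable as a module matching the file path in the diff header, and the SQLite filename is only an example of a database previously created by create_storage_provider):

    # Serve a podcast index stored in SQLite
    python -m typeagent.mcp.server --database podcast_index.db

    # Without -d/--database the server falls back to the JSON index under testdata/
    python -m typeagent.mcp.server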