Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 49 additions & 15 deletions typeagent/mcp/server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Fledgling MCP server on top of knowpro."""
"""Fledgling MCP server on top of typeagent."""

import argparse
from dataclasses import dataclass
import time
from typing import Any
Expand All @@ -19,6 +20,8 @@
from typeagent.knowpro.search_query_schema import SearchQuery
from typeagent.podcasts import podcast
from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex
from typeagent.storage.sqlite import SqliteStorageProvider
from typeagent.storage.utils import create_storage_provider


class MCPTypeChatModel(typechat.TypeChatLanguageModel):
Expand Down Expand Up @@ -101,18 +104,28 @@ def __repr__(self) -> str:
return f"Context({', '.join(parts)})"


async def make_context(session: ServerSession) -> ProcessingContext:
async def make_context(
session: ServerSession, dbname: str | None = None
) -> ProcessingContext:
"""Create processing context using MCP-based language model.

Args:
session: The MCP server session that provides create_message() for sampling.
dbname: Path to SQLite database file, or None to load from JSON file.

Note: Embeddings still require API keys since MCP doesn't support embeddings yet.
Make sure to set OPENAI_API_KEY or AZURE_OPENAI_API_KEY for embeddings.
"""
utils.load_dotenv()

settings = ConversationSettings()

# Uses SQLite provider if dbname is specified, otherwise use memory provider
settings.storage_provider = await create_storage_provider(
settings.message_text_index_settings,
settings.related_term_index_settings,
dbname,
podcast.PodcastMessage,
)

lang_search_options = searchlang.LanguageSearchOptions(
compile_options=searchlang.LanguageQueryCompileOptions(
exact_scope=False, verb_scope=True, term_filter=None, apply_scope=True
Expand All @@ -124,9 +137,7 @@ async def make_context(session: ServerSession) -> ProcessingContext:
entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
)

query_context = await load_podcast_index(
"testdata/Episode_53_AdrianTchaikovsky_index", settings
)
query_context = await load_podcast_index_or_database(settings, dbname)

# Use MCP-based model instead of one that requires API keys
model = MCPTypeChatModel(session)
Expand All @@ -145,18 +156,24 @@ async def make_context(session: ServerSession) -> ProcessingContext:
return context


async def load_podcast_index(
podcast_file_prefix: str, settings: ConversationSettings
async def load_podcast_index_or_database(
settings: ConversationSettings,
dbname: str | None = None,
) -> query.QueryEvalContext[podcast.PodcastMessage, Any]:
conversation = await podcast.Podcast.read_from_file(podcast_file_prefix, settings)
assert (
conversation is not None
), f"Failed to load podcast from {podcast_file_prefix!r}"
if dbname is None:
conversation = await podcast.Podcast.read_from_file(
"testdata/Episode_53_AdrianTchaikovsky_index", settings
)
else:
conversation = await podcast.Podcast.create(settings)
return query.QueryEvalContext(conversation)


# Create an MCP server
mcp = FastMCP("knowpro")
mcp = FastMCP("typagent")

# Global variable to store database path (set via command-line argument)
_dbname: str | None = None


@dataclass
Expand All @@ -178,7 +195,7 @@ async def query_conversation(
return QuestionResponse(
success=False, answer="No question provided", time_used=dt
)
context = await make_context(ctx.request_context.session)
context = await make_context(ctx.request_context.session, _dbname)

# Stages 1, 2, 3 (LLM -> proto-query, compile, execute query)
result = await searchlang.search_conversation_with_language(
Expand Down Expand Up @@ -213,5 +230,22 @@ async def query_conversation(

# Run the MCP server
if __name__ == "__main__":
# Load env vars
utils.load_dotenv()

# Set up command-line argument parsing and parse command line
parser = argparse.ArgumentParser(description="MCP server for knowpro")
parser.add_argument(
"-d",
"--database",
type=str,
default=None,
help="Path to the SQLite database file (default: load from JSON file)",
)
args = parser.parse_args()

# Store database path in global variable (no other straightforward way to pass to tool)
_dbname = args.database

# Use stdio transport for simplicity
mcp.run(transport="stdio")