diff --git a/pyproject.toml b/pyproject.toml index 99371ae7..4a3601de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "pyreadline3>=3.5.4 ; sys_platform == 'win32'", "pyright>=1.1.409", "python-dotenv>=1.1.0", + "stamina>=26.1.0", "tiktoken>=0.12.0", "typechat>=0.0.4", "webvtt-py>=0.5.1", diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 46208f6e..230cec18 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -31,7 +31,10 @@ import numpy as np from numpy.typing import NDArray +import stamina +from stamina import BoundAsyncRetryingCaller +import openai from pydantic_ai import Embedder as _PydanticAIEmbedder from pydantic_ai.embeddings.base import EmbeddingModel as _PydanticAIEmbeddingModelBase from pydantic_ai.embeddings.result import EmbeddingResult, EmbedInputType @@ -52,6 +55,20 @@ NormalizedEmbeddings, ) +_TRANSIENT_ERRORS = ( + openai.RateLimitError, + openai.APIConnectionError, + openai.APITimeoutError, + openai.InternalServerError, +) + +DEFAULT_CHAT_RETRIER = stamina.AsyncRetryingCaller(attempts=6, timeout=120).on( + _TRANSIENT_ERRORS +) +DEFAULT_EMBED_RETRIER = stamina.AsyncRetryingCaller(attempts=4, timeout=30).on( + _TRANSIENT_ERRORS +) + # --------------------------------------------------------------------------- # Chat model adapter # --------------------------------------------------------------------------- @@ -65,8 +82,13 @@ class PydanticAIChatModel(typechat.TypeChatLanguageModel): used wherever TypeChat expects a ``TypeChatLanguageModel``. """ - def __init__(self, model: Model) -> None: + def __init__( + self, + model: Model, + retrier: BoundAsyncRetryingCaller | None = None, + ) -> None: self._model = model + self._retrier = retrier or DEFAULT_CHAT_RETRIER async def complete( self, prompt: str | list[typechat.PromptSection] @@ -84,7 +106,7 @@ async def complete( messages: list[ModelMessage] = [ModelRequest(parts=parts)] params = ModelRequestParameters() - response = await self._model.request(messages, None, params) + response = await self._retrier(self._model.request, messages, None, params) text_parts = [p.content for p in response.parts if isinstance(p, TextPart)] if text_parts: return typechat.Success("".join(text_parts)) @@ -111,24 +133,20 @@ def __init__( self, embedder: _PydanticAIEmbedder, model_name: str, + retrier: BoundAsyncRetryingCaller | None = None, ) -> None: self._embedder = embedder self.model_name = model_name + self._retrier = retrier or DEFAULT_EMBED_RETRIER async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: - result = await self._embedder.embed_documents([input]) - embedding: NDArray[np.float32] = np.array( - result.embeddings[0], dtype=np.float32 - ) - norm = float(np.linalg.norm(embedding)) - if norm > 0: - embedding = (embedding / norm).astype(np.float32) - return embedding + embeddings = await self.get_embeddings_nocache([input]) + return embeddings[0] async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: if not input: raise ValueError("Cannot embed an empty list") - result = await self._embedder.embed_documents(input) + result = await self._retrier(self._embedder.embed_documents, input) embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32) norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32) norms = np.where(norms > 0, norms, np.float32(1.0)) @@ -182,7 +200,7 @@ def _make_azure_provider( azure_endpoint=azure_endpoint, api_version=api_version, azure_ad_token_provider=token_provider.get_token, - max_retries=5, + max_retries=0, ) else: apim_key = os.getenv("AZURE_APIM_SUBSCRIPTION_KEY") @@ -193,7 +211,7 @@ def _make_azure_provider( default_headers=( {"Ocp-Apim-Subscription-Key": apim_key} if apim_key else None ), - max_retries=5, + max_retries=0, ) return AzureProvider(openai_client=client) @@ -208,6 +226,8 @@ def _make_azure_provider( def create_chat_model( model_spec: str | None = None, + *, + retrier: BoundAsyncRetryingCaller | None = None, ) -> PydanticAIChatModel: """Create a chat model from a ``provider:model`` spec. @@ -249,7 +269,7 @@ def create_chat_model( ) else: model = infer_model(model_spec) - return PydanticAIChatModel(model) + return PydanticAIChatModel(model, retrier) DEFAULT_EMBEDDING_SPEC = "openai:text-embedding-ada-002" @@ -257,6 +277,7 @@ def create_chat_model( def create_embedding_model( model_spec: str | None = None, + retrier: BoundAsyncRetryingCaller | None = None, ) -> CachingEmbeddingModel: """Create an embedding model from a ``provider:model`` spec. @@ -313,7 +334,7 @@ def create_embedding_model( embedder = _PydanticAIEmbedder(embedding_model) else: embedder = _PydanticAIEmbedder(model_spec) - return CachingEmbeddingModel(PydanticAIEmbedder(embedder, model_name)) + return CachingEmbeddingModel(PydanticAIEmbedder(embedder, model_name, retrier)) # --------------------------------------------------------------------------- @@ -400,6 +421,8 @@ def create_test_embedding_model( def configure_models( chat_model_spec: str, embedding_model_spec: str, + chat_retrier: BoundAsyncRetryingCaller | None = None, + embed_retrier: BoundAsyncRetryingCaller | None = None, ) -> tuple[PydanticAIChatModel, CachingEmbeddingModel]: """Configure both a chat model and an embedding model at once. @@ -416,6 +439,6 @@ def configure_models( extractor = KnowledgeExtractor(model=chat) """ return ( - create_chat_model(chat_model_spec), - create_embedding_model(embedding_model_spec), + create_chat_model(chat_model_spec, retrier=chat_retrier), + create_embedding_model(embedding_model_spec, retrier=embed_retrier), ) diff --git a/src/typeagent/emails/email_memory.py b/src/typeagent/emails/email_memory.py index 6dd50cc4..1523f3a8 100644 --- a/src/typeagent/emails/email_memory.py +++ b/src/typeagent/emails/email_memory.py @@ -23,7 +23,9 @@ class EmailMemorySettings: def __init__(self, conversation_settings: ConversationSettings) -> None: - self.language_model = model_adapters.create_chat_model() + self.language_model = model_adapters.create_chat_model( + retrier=conversation_settings.chat_retrier + ) self.query_translator = utils.create_translator( self.language_model, search_query_schema.SearchQuery ) diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 131b0ceb..e8a7db43 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -541,12 +541,12 @@ async def query( """ # Create translators lazily (once per conversation instance) if self._query_translator is None: - model = model_adapters.create_chat_model() + model = model_adapters.create_chat_model(retrier=self.settings.chat_retrier) self._query_translator = utils.create_translator( model, search_query_schema.SearchQuery ) if self._answer_translator is None: - model = model_adapters.create_chat_model() + model = model_adapters.create_chat_model(retrier=self.settings.chat_retrier) self._answer_translator = utils.create_translator( model, answer_response_schema.AnswerResponse ) diff --git a/src/typeagent/knowpro/convsettings.py b/src/typeagent/knowpro/convsettings.py index 97c2bee2..acf559fe 100644 --- a/src/typeagent/knowpro/convsettings.py +++ b/src/typeagent/knowpro/convsettings.py @@ -5,6 +5,8 @@ from dataclasses import dataclass +from stamina import BoundAsyncRetryingCaller + from ..aitools.embeddings import IEmbeddingModel from ..aitools.model_adapters import create_embedding_model from ..aitools.vectorbase import TextEmbeddingIndexSettings @@ -41,9 +43,16 @@ def __init__( self, model: IEmbeddingModel | None = None, storage_provider: IStorageProvider | None = None, + *, + chat_retrier: BoundAsyncRetryingCaller | None = None, + embed_retrier: BoundAsyncRetryingCaller | None = None, ): + # Retry callers -- None means "use the default" in model_adapters. + self.chat_retrier = chat_retrier + self.embed_retrier = embed_retrier + # All settings share the same model, so they share the embedding cache. - model = model or create_embedding_model() + model = model or create_embedding_model(retrier=embed_retrier) self.embedding_model = model min_score = 0.85 self.related_term_index_settings = RelatedTermIndexSettings( diff --git a/src/typeagent/knowpro/knowledge.py b/src/typeagent/knowpro/knowledge.py index e2503967..b16bdb81 100644 --- a/src/typeagent/knowpro/knowledge.py +++ b/src/typeagent/knowpro/knowledge.py @@ -5,25 +5,13 @@ from collections.abc import Callable from dataclasses import dataclass -from typechat import Result, TypeChatLanguageModel +from typechat import Result from . import convknowledge from . import knowledge_schema as kplib -from ..aitools import model_adapters from .interfaces import IKnowledgeExtractor -def create_knowledge_extractor( - chat_model: TypeChatLanguageModel | None = None, -) -> convknowledge.KnowledgeExtractor: - """Create a knowledge extractor using the given Chat Model.""" - chat_model = chat_model or model_adapters.create_chat_model() - extractor = convknowledge.KnowledgeExtractor( - chat_model, max_chars_per_chunk=4096, merge_action_knowledge=False - ) - return extractor - - async def extract_knowledge_from_text( knowledge_extractor: IKnowledgeExtractor, text: str, diff --git a/tests/conftest.py b/tests/conftest.py index 7f0f11f5..3ebedd21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,10 @@ from dotenv import load_dotenv import pytest import pytest_asyncio +import stamina + +stamina.set_testing(True) + from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.model_adapters import create_test_embedding_model diff --git a/tests/test_knowledge.py b/tests/test_knowledge.py index 0374dfec..44515ada 100644 --- a/tests/test_knowledge.py +++ b/tests/test_knowledge.py @@ -8,7 +8,6 @@ from typeagent.knowpro import convknowledge from typeagent.knowpro import knowledge_schema as kplib from typeagent.knowpro.knowledge import ( - create_knowledge_extractor, extract_knowledge_from_text, extract_knowledge_from_text_batch, merge_concrete_entities, @@ -34,12 +33,6 @@ def mock_knowledge_extractor() -> convknowledge.KnowledgeExtractor: return MockKnowledgeExtractor() # type: ignore -def test_create_knowledge_extractor(really_needs_auth: None): - """Test creating a knowledge extractor.""" - extractor = create_knowledge_extractor() - assert isinstance(extractor, convknowledge.KnowledgeExtractor) - - @pytest.mark.asyncio async def test_extract_knowledge_from_text( mock_knowledge_extractor: convknowledge.KnowledgeExtractor, diff --git a/tools/query.py b/tools/query.py index c7ec7908..6c5f08a2 100644 --- a/tools/query.py +++ b/tools/query.py @@ -577,7 +577,7 @@ async def main(): "Error: non-empty --search-results required for batch mode." ) - model = model_adapters.create_chat_model() + model = model_adapters.create_chat_model(retrier=settings.chat_retrier) query_translator = utils.create_translator(model, search_query_schema.SearchQuery) if args.alt_schema: if args.verbose: diff --git a/uv.lock b/uv.lock index 86f16351..8f7f34ad 100644 --- a/uv.lock +++ b/uv.lock @@ -2304,6 +2304,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/7f/3de5402f39890ac5660b86bcf5c03f9d855dad5c4ed764866d7b592b46fd/sse_starlette-3.3.4-py3-none-any.whl", hash = "sha256:84bb06e58939a8b38d8341f1bc9792f06c2b53f48c608dd207582b664fc8f3c1", size = 14330, upload-time = "2026-03-29T09:00:21.846Z" }, ] +[[package]] +name = "stamina" +version = "26.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tenacity" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/80/bd/b2f71ae14368a066f103d182f25bbc6c3bf4aa695889f3ed3cba026d6f36/stamina-26.1.0.tar.gz", hash = "sha256:0214d05fdf5102c518194a4aac7520ce53cf660550ae3b940701aad88cf50c17", size = 568171, upload-time = "2026-04-13T17:44:31.012Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/f0/1ff90a1d1dd02de23feafdf9dffaecef3958348be5c192df56670ccb4f86/stamina-26.1.0-py3-none-any.whl", hash = "sha256:62e06829bec87c06d4cafde520b32a6097d1017c378a9eb63253c5bf5ebbbb88", size = 18508, upload-time = "2026-04-13T17:44:29.545Z" }, +] + [[package]] name = "starlette" version = "1.0.0" @@ -2326,6 +2338,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/64/97/b4f2f442fee92a1406f08b4fbc990bd7d02dc84b3b5e6315a59fa9b2a9f4/std_uritemplate-2.0.8-py3-none-any.whl", hash = "sha256:839807a7f9d07f0bad1a88977c3428bd97b9ff0d229412a0bf36123d8c724257", size = 6512, upload-time = "2025-10-16T15:51:28.713Z" }, ] +[[package]] +name = "tenacity" +version = "9.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/c6/ee486fd809e357697ee8a44d3d69222b344920433d3b6666ccd9b374630c/tenacity-9.1.4.tar.gz", hash = "sha256:adb31d4c263f2bd041081ab33b498309a57c77f9acf2db65aadf0898179cf93a", size = 49413, upload-time = "2026-02-07T10:45:33.841Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/c1/eb8f9debc45d3b7918a32ab756658a0904732f75e555402972246b0b8e71/tenacity-9.1.4-py3-none-any.whl", hash = "sha256:6095a360c919085f28c6527de529e76a06ad89b23659fa881ae0649b867a9d55", size = 28926, upload-time = "2026-02-07T10:45:32.24Z" }, +] + [[package]] name = "tiktoken" version = "0.12.0" @@ -2400,6 +2421,7 @@ dependencies = [ { name = "pyreadline3", marker = "sys_platform == 'win32'" }, { name = "pyright" }, { name = "python-dotenv" }, + { name = "stamina" }, { name = "tiktoken" }, { name = "typechat" }, { name = "webvtt-py" }, @@ -2444,6 +2466,7 @@ requires-dist = [ { name = "pyreadline3", marker = "sys_platform == 'win32'", specifier = ">=3.5.4" }, { name = "pyright", specifier = ">=1.1.409" }, { name = "python-dotenv", specifier = ">=1.1.0" }, + { name = "stamina", specifier = ">=26.1.0" }, { name = "tiktoken", specifier = ">=0.12.0" }, { name = "typechat", specifier = ">=0.0.4" }, { name = "webvtt-py", specifier = ">=0.5.1" },