1 change: 1 addition & 0 deletions pyproject.toml

@@ -39,6 +39,7 @@ dependencies = [
     "pyreadline3>=3.5.4 ; sys_platform == 'win32'",
     "pyright>=1.1.409",
     "python-dotenv>=1.1.0",
+    "stamina>=26.1.0",
     "tiktoken>=0.12.0",
     "typechat>=0.0.4",
     "webvtt-py>=0.5.1",
57 changes: 40 additions & 17 deletions src/typeagent/aitools/model_adapters.py

@@ -31,7 +31,10 @@

 import numpy as np
 from numpy.typing import NDArray
+import stamina
+from stamina import BoundAsyncRetryingCaller

+import openai
 from pydantic_ai import Embedder as _PydanticAIEmbedder
 from pydantic_ai.embeddings.base import EmbeddingModel as _PydanticAIEmbeddingModelBase
 from pydantic_ai.embeddings.result import EmbeddingResult, EmbedInputType
@@ -52,6 +55,20 @@
     NormalizedEmbeddings,
 )

+_TRANSIENT_ERRORS = (
+    openai.RateLimitError,
+    openai.APIConnectionError,
+    openai.APITimeoutError,
+    openai.InternalServerError,
+)
+
+DEFAULT_CHAT_RETRIER = stamina.AsyncRetryingCaller(attempts=6, timeout=120).on(
+    _TRANSIENT_ERRORS
+)
+DEFAULT_EMBED_RETRIER = stamina.AsyncRetryingCaller(attempts=4, timeout=30).on(
+    _TRANSIENT_ERRORS
+)
+
 # ---------------------------------------------------------------------------
 # Chat model adapter
 # ---------------------------------------------------------------------------
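These two module-level retriers give chat and embedding calls separate budgets: up to 6 attempts within 120 seconds for chat, 4 attempts within 30 seconds for embeddings, retrying only on the transient OpenAI errors listed above. For readers new to stamina, a minimal sketch of the pattern (the flaky_call coroutine and its argument are hypothetical, for illustration only):

import asyncio
import openai
import stamina

# Bind the transient error types once; reuse the caller everywhere.
retrier = stamina.AsyncRetryingCaller(attempts=4, timeout=30).on(
    (openai.RateLimitError, openai.APITimeoutError)
)

async def flaky_call(prompt: str) -> str:
    ...  # hypothetical coroutine that may raise transient OpenAI errors

async def main() -> None:
    # Awaits flaky_call("hi"), retrying with exponential backoff on the
    # bound exception types until 4 attempts or ~30 seconds are spent.
    result = await retrier(flaky_call, "hi")
    print(result)

asyncio.run(main())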
@@ -65,8 +82,13 @@ class PydanticAIChatModel(typechat.TypeChatLanguageModel):
     used wherever TypeChat expects a ``TypeChatLanguageModel``.
     """

-    def __init__(self, model: Model) -> None:
+    def __init__(
+        self,
+        model: Model,
+        retrier: BoundAsyncRetryingCaller | None = None,
+    ) -> None:
         self._model = model
+        self._retrier = retrier or DEFAULT_CHAT_RETRIER

     async def complete(
         self, prompt: str | list[typechat.PromptSection]
@@ -84,7 +106,7 @@ async def complete(
         messages: list[ModelMessage] = [ModelRequest(parts=parts)]
         params = ModelRequestParameters()

-        response = await self._model.request(messages, None, params)
+        response = await self._retrier(self._model.request, messages, None, params)
         text_parts = [p.content for p in response.parts if isinstance(p, TextPart)]
         if text_parts:
             return typechat.Success("".join(text_parts))
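Because the request is awaited through the bound caller, every chat completion now gets backoff-and-retry on the transient errors without decorating the method itself. Under the chat defaults this is roughly equivalent to stamina's context-manager form (a sketch, not code from this PR):

async for attempt in stamina.retry_context(
    on=_TRANSIENT_ERRORS, attempts=6, timeout=120
):
    with attempt:
        response = await self._model.request(messages, None, params)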
@@ -111,24 +133,20 @@ def __init__(
         self,
         embedder: _PydanticAIEmbedder,
         model_name: str,
+        retrier: BoundAsyncRetryingCaller | None = None,
     ) -> None:
         self._embedder = embedder
         self.model_name = model_name
+        self._retrier = retrier or DEFAULT_EMBED_RETRIER

     async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding:
-        result = await self._embedder.embed_documents([input])
-        embedding: NDArray[np.float32] = np.array(
-            result.embeddings[0], dtype=np.float32
-        )
-        norm = float(np.linalg.norm(embedding))
-        if norm > 0:
-            embedding = (embedding / norm).astype(np.float32)
-        return embedding
+        embeddings = await self.get_embeddings_nocache([input])
+        return embeddings[0]

     async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings:
         if not input:
             raise ValueError("Cannot embed an empty list")
-        result = await self._retrier(self._embedder.embed_documents, input)
+        result = await self._retrier(self._embedder.embed_documents, input)
         embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32)
         norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32)
         norms = np.where(norms > 0, norms, np.float32(1.0))
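Note that get_embedding_nocache no longer normalizes by hand: the single-string path now funnels through get_embeddings_nocache, so it inherits both the retry wrapper and the zero-norm guard from one place. The normalization itself is unchanged: each row is divided by its L2 norm (for example, the vector (3, 4) has norm 5 and normalizes to (0.6, 0.8)), and zero-norm rows are left untouched by substituting a divisor of 1.0.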
@@ -182,7 +200,7 @@ def _make_azure_provider(
             azure_endpoint=azure_endpoint,
             api_version=api_version,
             azure_ad_token_provider=token_provider.get_token,
-            max_retries=5,
+            max_retries=0,
         )
     else:
         apim_key = os.getenv("AZURE_APIM_SUBSCRIPTION_KEY")
@@ -193,7 +211,7 @@ def _make_azure_provider(
             default_headers=(
                 {"Ocp-Apim-Subscription-Key": apim_key} if apim_key else None
             ),
-            max_retries=5,
+            max_retries=0,
         )
     return AzureProvider(openai_client=client)
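Dropping max_retries from 5 to 0 on both Azure OpenAI clients is presumably deliberate: with stamina now wrapping every call, leaving the SDK's built-in retries enabled would multiply the two policies (up to 5 client retries inside each of stamina's attempts), so the client-level mechanism is turned off and stamina becomes the single retry layer.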

@@ -208,6 +226,8 @@ def _make_azure_provider(

 def create_chat_model(
     model_spec: str | None = None,
+    *,
+    retrier: BoundAsyncRetryingCaller | None = None,
 ) -> PydanticAIChatModel:
     """Create a chat model from a ``provider:model`` spec.

@@ -249,14 +269,15 @@ def create_chat_model(
         )
     else:
         model = infer_model(model_spec)
-    return PydanticAIChatModel(model)
+    return PydanticAIChatModel(model, retrier)


 DEFAULT_EMBEDDING_SPEC = "openai:text-embedding-ada-002"


 def create_embedding_model(
     model_spec: str | None = None,
+    retrier: BoundAsyncRetryingCaller | None = None,
 ) -> CachingEmbeddingModel:
     """Create an embedding model from a ``provider:model`` spec.

@@ -313,7 +334,7 @@ def create_embedding_model(
         embedder = _PydanticAIEmbedder(embedding_model)
     else:
         embedder = _PydanticAIEmbedder(model_spec)
-    return CachingEmbeddingModel(PydanticAIEmbedder(embedder, model_name))
+    return CachingEmbeddingModel(PydanticAIEmbedder(embedder, model_name, retrier))


 # ---------------------------------------------------------------------------
@@ -400,6 +421,8 @@ def create_test_embedding_model(
 def configure_models(
     chat_model_spec: str,
     embedding_model_spec: str,
+    chat_retrier: BoundAsyncRetryingCaller | None = None,
+    embed_retrier: BoundAsyncRetryingCaller | None = None,
 ) -> tuple[PydanticAIChatModel, CachingEmbeddingModel]:
     """Configure both a chat model and an embedding model at once.

@@ -416,6 +439,6 @@ def configure_models(
         extractor = KnowledgeExtractor(model=chat)
     """
     return (
-        create_chat_model(chat_model_spec),
-        create_embedding_model(embedding_model_spec),
+        create_chat_model(chat_model_spec, retrier=chat_retrier),
+        create_embedding_model(embedding_model_spec, retrier=embed_retrier),
     )
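A quick usage sketch of the new parameters (the model specs and retry budget here are illustrative, not defaults from this change):

import openai
import stamina
from typeagent.aitools import model_adapters

chat_model, embedding_model = model_adapters.configure_models(
    "openai:gpt-4o",
    "openai:text-embedding-ada-002",
    # A tighter chat budget than DEFAULT_CHAT_RETRIER's 6 attempts / 120 s.
    chat_retrier=stamina.AsyncRetryingCaller(attempts=3, timeout=30).on(
        openai.RateLimitError
    ),
)
# embed_retrier is omitted, so embeddings fall back to DEFAULT_EMBED_RETRIER.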
4 changes: 3 additions & 1 deletion src/typeagent/emails/email_memory.py

@@ -23,7 +23,9 @@

 class EmailMemorySettings:
     def __init__(self, conversation_settings: ConversationSettings) -> None:
-        self.language_model = model_adapters.create_chat_model()
+        self.language_model = model_adapters.create_chat_model(
+            retrier=conversation_settings.chat_retrier
+        )
         self.query_translator = utils.create_translator(
             self.language_model, search_query_schema.SearchQuery
         )
4 changes: 2 additions & 2 deletions src/typeagent/knowpro/conversation_base.py

@@ -541,12 +541,12 @@ async def query(
         """
         # Create translators lazily (once per conversation instance)
         if self._query_translator is None:
-            model = model_adapters.create_chat_model()
+            model = model_adapters.create_chat_model(retrier=self.settings.chat_retrier)
             self._query_translator = utils.create_translator(
                 model, search_query_schema.SearchQuery
             )
         if self._answer_translator is None:
-            model = model_adapters.create_chat_model()
+            model = model_adapters.create_chat_model(retrier=self.settings.chat_retrier)
             self._answer_translator = utils.create_translator(
                 model, answer_response_schema.AnswerResponse
             )
11 changes: 10 additions & 1 deletion src/typeagent/knowpro/convsettings.py

@@ -5,6 +5,8 @@

 from dataclasses import dataclass

+from stamina import BoundAsyncRetryingCaller
+
 from ..aitools.embeddings import IEmbeddingModel
 from ..aitools.model_adapters import create_embedding_model
 from ..aitools.vectorbase import TextEmbeddingIndexSettings

@@ -41,9 +43,16 @@ def __init__(
         self,
         model: IEmbeddingModel | None = None,
         storage_provider: IStorageProvider | None = None,
+        *,
+        chat_retrier: BoundAsyncRetryingCaller | None = None,
+        embed_retrier: BoundAsyncRetryingCaller | None = None,
     ):
+        # Retry callers -- None means "use the default" in model_adapters.
+        self.chat_retrier = chat_retrier
+        self.embed_retrier = embed_retrier
+
         # All settings share the same model, so they share the embedding cache.
-        model = model or create_embedding_model()
+        model = model or create_embedding_model(retrier=embed_retrier)
         self.embedding_model = model
         min_score = 0.85
         self.related_term_index_settings = RelatedTermIndexSettings(
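With these hooks, callers can tune retry policy per deployment and the settings object threads the retriers through to the model factories. A construction sketch (the retry budget and exception choice are illustrative):

import openai
import stamina
from typeagent.knowpro.convsettings import ConversationSettings

settings = ConversationSettings(
    embed_retrier=stamina.AsyncRetryingCaller(attempts=2, timeout=10).on(
        openai.RateLimitError
    ),
)
# chat_retrier stays None here, so chat models created from these
# settings fall back to DEFAULT_CHAT_RETRIER.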
14 changes: 1 addition & 13 deletions src/typeagent/knowpro/knowledge.py

@@ -5,25 +5,13 @@
 from collections.abc import Callable
 from dataclasses import dataclass

-from typechat import Result, TypeChatLanguageModel
+from typechat import Result

 from . import convknowledge
 from . import knowledge_schema as kplib
-from ..aitools import model_adapters
 from .interfaces import IKnowledgeExtractor


-def create_knowledge_extractor(
-    chat_model: TypeChatLanguageModel | None = None,
-) -> convknowledge.KnowledgeExtractor:
-    """Create a knowledge extractor using the given Chat Model."""
-    chat_model = chat_model or model_adapters.create_chat_model()
-    extractor = convknowledge.KnowledgeExtractor(
-        chat_model, max_chars_per_chunk=4096, merge_action_knowledge=False
-    )
-    return extractor
-
-
 async def extract_knowledge_from_text(
     knowledge_extractor: IKnowledgeExtractor,
     text: str,
4 changes: 4 additions & 0 deletions tests/conftest.py

@@ -10,6 +10,10 @@

 from dotenv import load_dotenv
 import pytest
 import pytest_asyncio
+import stamina
+
+stamina.set_testing(True)
+

 from typeagent.aitools.embeddings import IEmbeddingModel
 from typeagent.aitools.model_adapters import create_test_embedding_model
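stamina.set_testing(True) flips stamina into testing mode globally: backoff sleeps are skipped and retrying is capped at a single attempt by default, so tests that hit the new retriers fail fast instead of sleeping through exponential backoff. If a suite wanted fast retries rather than none, stamina also accepts an attempts cap (illustrative):

import stamina

stamina.set_testing(True, attempts=2)  # no sleeps, at most 2 attempts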
7 changes: 0 additions & 7 deletions tests/test_knowledge.py

@@ -8,7 +8,6 @@
 from typeagent.knowpro import convknowledge
 from typeagent.knowpro import knowledge_schema as kplib
 from typeagent.knowpro.knowledge import (
-    create_knowledge_extractor,
     extract_knowledge_from_text,
     extract_knowledge_from_text_batch,
     merge_concrete_entities,

@@ -34,12 +33,6 @@ def mock_knowledge_extractor() -> convknowledge.KnowledgeExtractor:
     return MockKnowledgeExtractor()  # type: ignore


-def test_create_knowledge_extractor(really_needs_auth: None):
-    """Test creating a knowledge extractor."""
-    extractor = create_knowledge_extractor()
-    assert isinstance(extractor, convknowledge.KnowledgeExtractor)
-
-
 @pytest.mark.asyncio
 async def test_extract_knowledge_from_text(
     mock_knowledge_extractor: convknowledge.KnowledgeExtractor,
2 changes: 1 addition & 1 deletion tools/query.py

@@ -577,7 +577,7 @@ async def main():
                 "Error: non-empty --search-results required for batch mode."
             )

-    model = model_adapters.create_chat_model()
+    model = model_adapters.create_chat_model(retrier=settings.chat_retrier)
     query_translator = utils.create_translator(model, search_query_schema.SearchQuery)
     if args.alt_schema:
         if args.verbose:
23 changes: 23 additions & 0 deletions uv.lock

Some generated files are not rendered by default.