Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/typeagent/aitools/model_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,21 @@
"""

from collections.abc import Sequence
import logging
import os

import numpy as np
from numpy.typing import NDArray
import stamina
from stamina import BoundAsyncRetryingCaller
from stamina.instrumentation import RetryDetails, set_on_retry_hooks

import openai
from pydantic_ai import Embedder as _PydanticAIEmbedder
from pydantic_ai.embeddings.base import EmbeddingModel as _PydanticAIEmbeddingModelBase
from pydantic_ai.embeddings.result import EmbeddingResult, EmbedInputType
from pydantic_ai.embeddings.settings import EmbeddingSettings
from pydantic_ai.exceptions import ModelAPIError
from pydantic_ai.messages import (
ModelMessage,
ModelRequest,
Expand All @@ -60,15 +63,42 @@
openai.APIConnectionError,
openai.APITimeoutError,
openai.InternalServerError,
ModelAPIError,
)

DEFAULT_CHAT_RETRIER = stamina.AsyncRetryingCaller(attempts=6, timeout=120).on(
_TRANSIENT_ERRORS
)
DEFAULT_EMBED_RETRIER = stamina.AsyncRetryingCaller(attempts=4, timeout=30).on(
DEFAULT_EMBED_RETRIER = stamina.AsyncRetryingCaller(attempts=6, timeout=120).on(
_TRANSIENT_ERRORS
)

_logger = logging.getLogger("stamina")

_CALLABLE_LABELS: dict[str, str] = {
"request": "chat",
"embed_documents": "embedding",
}


def _on_retry(details: RetryDetails) -> None:
    """Stamina retry hook: emit a warning describing the retried call.

    Maps the retried callable's name to a human-friendly label (chat /
    embedding) via _CALLABLE_LABELS, falling back to the raw name, and
    logs attempt number, cumulative and upcoming wait, and a truncated
    repr of the triggering exception.
    """
    label = _CALLABLE_LABELS.get(details.name, details.name)
    # Truncate the exception repr so a huge payload can't flood the log.
    cause_summary = repr(details.caused_by)[:200]
    _logger.warning(
        "stamina: retrying %s request (attempt %d, waited %.1fs so far, "
        "waiting %.1fs): %s",
        label,
        details.retry_num,
        details.waited_so_far,
        details.wait_for,
        cause_summary,
    )


set_on_retry_hooks([_on_retry])


Comment thread
gvanrossum marked this conversation as resolved.
# ---------------------------------------------------------------------------
# Chat model adapter
# ---------------------------------------------------------------------------
Expand Down
72 changes: 50 additions & 22 deletions src/typeagent/knowpro/conversation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,10 @@ async def add_messages_streaming(

Args:
messages: An async iterable of messages to ingest.
batch_size: Number of messages per commit batch.
batch_size: Target number of text chunks per commit batch.
Messages are never split across batches, so the actual
chunk count may exceed ``batch_size`` if a single message
has more chunks than that.
Comment thread
gvanrossum marked this conversation as resolved.
on_batch_committed: Optional callback invoked after each batch is
committed, receiving the batch's ``AddMessagesResult``.

Expand All @@ -255,23 +258,28 @@ def _accumulate(result: AddMessagesResult) -> None:
total.messages_added += result.messages_added
total.semrefs_added += result.semrefs_added
total.chunks_added += result.chunks_added
total.messages_skipped += result.messages_skipped
if on_batch_committed:
on_batch_committed(result)

pending_commit: asyncio.Task[AddMessagesResult] | None = None
pending_skipped: int = 0

async def _drain_commit() -> None:
nonlocal pending_commit
nonlocal pending_commit, pending_skipped
if pending_commit is not None:
_accumulate(await pending_commit)
result = await pending_commit
result.messages_skipped += pending_skipped
_accumulate(result)
pending_commit = None
pending_skipped = 0

async def _submit_batch(filtered: list[TMessage]) -> None:
nonlocal pending_commit
if not filtered:
async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:
nonlocal pending_commit, pending_skipped
if not filtered and not skipped:
return
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dead guard. All three call sites guarantee batch is non-empty before calling _filter_ingested, so len(filtered) + skipped == len(batch) > 0. This branch can never fire. The if not filtered: check at line 290 already handles the all-skipped case.

Suggested change
return
if filtered and should_extract:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@KRRT7 I am confused by your suggestions. This one, when applied, generates invalid syntax. Maybe you are using a tool to generate them and it's got an off-by-one error or is working from an outdated file version? It's the case for all suggestions in this batch.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right. If we ever revisit this code we can delete it. Bernhard merged it before I got a chance to look at your feedback, in part confused by the misaligned suggestions. It's harmless so I see no need to fix it in a separate PR.


if should_extract:
if filtered and should_extract:
next_extraction = asyncio.create_task(
self._extract_knowledge_for_batch(filtered)
)
Expand All @@ -281,6 +289,13 @@ async def _submit_batch(filtered: list[TMessage]) -> None:
# Wait for previous commit to finish (frees the DB connection)
await _drain_commit()

if not filtered:
# Nothing to commit, just report skipped
total.messages_skipped += skipped
if on_batch_committed:
on_batch_committed(AddMessagesResult(messages_skipped=skipped))
Comment thread
gvanrossum marked this conversation as resolved.
return

# Await extraction result for this batch
extraction = await next_extraction if next_extraction is not None else None

Expand All @@ -289,19 +304,29 @@ async def _submit_batch(filtered: list[TMessage]) -> None:
pending_commit = asyncio.create_task(
self._commit_batch_streaming(storage, filtered, extraction)
)
pending_skipped = skipped

try:
batch: list[TMessage] = []
batch_chunks = 0
async for msg in messages:
msg_chunks = len(msg.text_chunks)
if batch and batch_chunks + msg_chunks > batch_size:
filtered, skipped = await self._filter_ingested(storage, batch)
await _submit_batch(filtered, skipped)
batch = []
batch_chunks = 0
batch.append(msg)
if len(batch) >= batch_size:
filtered = await self._filter_ingested(storage, batch)
await _submit_batch(filtered)
batch_chunks += msg_chunks
if batch_chunks >= batch_size:
filtered, skipped = await self._filter_ingested(storage, batch)
await _submit_batch(filtered, skipped)
batch = []
batch_chunks = 0

if batch:
filtered = await self._filter_ingested(storage, batch)
await _submit_batch(filtered)
filtered, skipped = await self._filter_ingested(storage, batch)
await _submit_batch(filtered, skipped)

await _drain_commit()
except BaseException:
Expand All @@ -317,23 +342,26 @@ async def _filter_ingested(
self,
storage: IStorageProvider[TMessage],
batch: list[TMessage],
) -> list[TMessage]:
) -> tuple[list[TMessage], int]:
"""Filter out messages whose source_id has already been ingested.

Safe to call while a pending_commit task exists: is_source_ingested
is a synchronous SELECT on SQLite's single connection, so it won't
interleave with the commit task's cursor operations in asyncio's
cooperative model. If the storage provider becomes truly async
(e.g. aiosqlite), this assumption needs revisiting.
Returns (filtered_messages, skipped_count).

Uses a single batch query instead of per-message lookups.
"""
source_ids = [m.source_id for m in batch if m.source_id is not None]
if source_ids:
ingested = await storage.are_sources_ingested(source_ids)
else:
ingested = set[str]()
filtered: list[TMessage] = []
skipped = 0
for msg in batch:
if msg.source_id is not None and await storage.is_source_ingested(
msg.source_id
):
if msg.source_id is not None and msg.source_id in ingested:
skipped += 1
continue
filtered.append(msg)
return filtered
return filtered, skipped

async def _extract_knowledge_for_batch(
self,
Expand Down
1 change: 1 addition & 0 deletions src/typeagent/knowpro/interfaces_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class AddMessagesResult:
messages_added: int = 0
chunks_added: int = 0
semrefs_added: int = 0
messages_skipped: int = 0


# Messages are referenced by their sequential ordinal numbers.
Expand Down
4 changes: 4 additions & 0 deletions src/typeagent/knowpro/interfaces_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ async def is_source_ingested(self, source_id: str) -> bool:
"""Check if a source has already been ingested."""
...

async def are_sources_ingested(self, source_ids: list[str]) -> set[str]:
    """Return the subset of source_ids that have already been ingested.

    Batch counterpart to ``is_source_ingested``: a single call answers
    the membership question for many source ids at once.

    Args:
        source_ids: Candidate source identifiers to check.

    Returns:
        The (possibly empty) set of ids from ``source_ids`` that are
        recorded as ingested.
    """
    ...

async def get_source_status(self, source_id: str) -> str | None:
    """Get the ingestion status of a source.

    Args:
        source_id: Identifier of the source to look up.

    Returns:
        The status string for the source, or None — presumably when the
        source has never been seen; confirm against implementations.
    """
    ...
Expand Down
4 changes: 4 additions & 0 deletions src/typeagent/storage/memory/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ async def is_source_ingested(self, source_id: str) -> bool:
"""
return source_id in self._ingested_sources

async def are_sources_ingested(self, source_ids: list[str]) -> set[str]:
    """Return those of *source_ids* already present in the ingested set.

    Pure in-memory lookup against ``self._ingested_sources``; duplicates
    in the input collapse naturally because the result is a set.
    """
    return {sid for sid in source_ids if sid in self._ingested_sources}

async def get_source_status(self, source_id: str) -> str | None:
"""Get the ingestion status of a source.

Expand Down
20 changes: 20 additions & 0 deletions src/typeagent/storage/sqlite/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,26 @@ async def is_source_ingested(self, source_id: str) -> bool:
row = cursor.fetchone()
return row is not None and row[0] == STATUS_INGESTED

async def are_sources_ingested(self, source_ids: list[str]) -> set[str]:
    """Return the subset of source_ids that have already been ingested.

    Issues one parameterized SELECT per chunk of ids against the
    IngestedSources table, keeping only rows whose status equals
    STATUS_INGESTED.
    """
    if not source_ids:
        return set()
    cur = self.db.cursor()
    found: set[str] = set()
    # Chunk to stay within SQLite's SQLITE_MAX_VARIABLE_NUMBER
    # (999 on older builds, 32766 on 3.32.0+).
    step = 500
    for start in range(0, len(source_ids), step):
        batch = source_ids[start : start + step]
        marks = ",".join("?" * len(batch))
        query = (
            f"SELECT source_id FROM IngestedSources"
            f" WHERE source_id IN ({marks}) AND status = ?"
        )
        cur.execute(query, [*batch, STATUS_INGESTED])
        found.update(row[0] for row in cur.fetchall())
    return found

async def get_source_status(self, source_id: str) -> str | None:
"""Get the ingestion status of a source.

Expand Down
Loading
Loading