Merged
api/tasks/document_indexing_sync_task.py: 16 changes (9 additions, 7 deletions)

--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -32,7 +32,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
     tenant_id = None
 
     with session_factory.create_session() as session, session.begin():
-        document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+        document = session.scalar(
+            select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
+        )
 
         if not document:
             logger.info(click.style(f"Document not found: {document_id}", fg="red"))
@@ -42,7 +44,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
             logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
             return
 
-        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+        dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
         if not dataset:
             raise Exception("Dataset not found")
 
@@ -87,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
        )
 
        with session_factory.create_session() as session, session.begin():
-           document = session.query(Document).filter_by(id=document_id).first()
+           document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
            if document:
                document.indexing_status = IndexingStatus.ERROR
                document.error = "Datasource credential not found. Please reconnect your Notion workspace."
@@ -112,15 +114,15 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
     try:
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
         with session_factory.create_session() as session:
-            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+            dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
             if dataset:
                 index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
                 logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
     except Exception:
         logger.exception("Failed to clean vector index for document %s", document_id)
 
     with session_factory.create_session() as session, session.begin():
-        document = session.query(Document).filter_by(id=document_id).first()
+        document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
         if not document:
             logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
             return
@@ -140,7 +142,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
     try:
         indexing_runner = IndexingRunner()
         with session_factory.create_session() as session:
-            document = session.query(Document).filter_by(id=document_id).first()
+            document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
             if document:
                 indexing_runner.run([document])
         end_at = time.perf_counter()
@@ -150,7 +152,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
     except Exception as e:
         logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
         with session_factory.create_session() as session, session.begin():
-            document = session.query(Document).filter_by(id=document_id).first()
+            document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
             if document:
                 document.indexing_status = IndexingStatus.ERROR
                 document.error = str(e)
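
Every hunk in this file is the same migration: the legacy SQLAlchemy 1.x Query API (session.query(...).first(), with .where() or .filter_by()) is replaced by a 2.0-style select() statement executed through session.scalar(), which returns the first result or None, the same contract as .first(). A minimal runnable sketch of that equivalence, using a stand-in Document model and an in-memory SQLite engine rather than the real Dify models:

    # Sketch of the query-style migration applied in this diff. The Document
    # model and engine below are stand-ins, not the real Dify code.
    from sqlalchemy import String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


    class Base(DeclarativeBase):
        pass


    class Document(Base):
        __tablename__ = "documents"
        id: Mapped[str] = mapped_column(String, primary_key=True)
        dataset_id: Mapped[str] = mapped_column(String)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session_factory = sessionmaker(bind=engine)

    with session_factory() as session:
        # Legacy 1.x style: build a Query, take the first row or None.
        legacy = session.query(Document).filter_by(id="doc-1").first()

        # 2.0 style: build a Select and let the Session execute it.
        # session.scalar() returns the first column of the first row, or None,
        # matching the .first() contract for a single-entity query.
        modern = session.scalar(
            select(Document).where(Document.id == "doc-1").limit(1)
        )

        assert legacy is None and modern is None  # empty table: both are None

The explicit .limit(1) mirrors the LIMIT that Query.first() applies implicitly, so the generated SQL stays the same.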
api/tests/unit_tests/tasks/test_document_indexing_sync_task.py: 11 changes (5 additions, 6 deletions)

--- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py
+++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py
@@ -80,7 +80,7 @@ def mock_db_session(mock_document, mock_dataset):
     with patch("tasks.document_indexing_sync_task.session_factory", autospec=True) as mock_session_factory:
         session = MagicMock()
         session.scalars.return_value.all.return_value = []
-        session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
+        session.scalar.side_effect = [mock_document, mock_dataset]
 
         begin_cm = MagicMock()
         begin_cm.__enter__.return_value = session
@@ -242,14 +242,13 @@ def test_data_source_info_serialized_as_json_string(
     # DB session mock — shared across all ``session_factory.create_session()`` calls
     session = MagicMock()
     session.scalars.return_value.all.return_value = []
-    # .where() path: session 1 reads document + dataset, session 2 reads dataset
-    session.query.return_value.where.return_value.first.side_effect = [
+    # All .first() calls are now session.scalar() — ordered by call sequence:
+    # session 1: document + dataset, session 2: dataset (clean), session 3: document (update),
+    # session 4: document (indexing)
+    session.scalar.side_effect = [
         mock_document,
         mock_dataset,
         mock_dataset,
-    ]
-    # .filter_by() path: session 3 (update), session 4 (indexing)
-    session.query.return_value.filter_by.return_value.first.side_effect = [
         mock_document,
         mock_document,
     ]
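
The test update follows directly from the production change: once every lookup goes through session.scalar(), a single side_effect list drives all four sessions' reads in call order, replacing the two lists previously split across the mocked .where() and .filter_by() chains. A small self-contained illustration of how MagicMock.side_effect sequences return values (the string values here are hypothetical, not the real fixtures):

    from unittest.mock import MagicMock

    session = MagicMock()
    # Each call to session.scalar() returns the next item from this list,
    # regardless of which select() statement it receives as an argument.
    session.scalar.side_effect = ["document", "dataset", "dataset"]

    assert session.scalar("stmt 1") == "document"
    assert session.scalar("stmt 2") == "dataset"
    assert session.scalar("stmt 3") == "dataset"

    # A fourth call raises StopIteration once the list is exhausted, which is
    # why the ordering comment in the updated test enumerates every session.
    try:
        session.scalar("stmt 4")
    except StopIteration:
        print("side_effect exhausted after three calls")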