diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index f99e90062fbbc1..90c80be3a1e743 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -32,7 +32,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): tenant_id = None with session_factory.create_session() as session, session.begin(): - document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) + ) if not document: logger.info(click.style(f"Document not found: {document_id}", fg="red")) @@ -42,7 +44,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow")) return - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Dataset not found") @@ -87,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): ) with session_factory.create_session() as session, session.begin(): - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if document: document.indexing_status = IndexingStatus.ERROR document.error = "Datasource credential not found. Please reconnect your Notion workspace." @@ -112,7 +114,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): try: index_processor = IndexProcessorFactory(index_type).init_index_processor() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if dataset: index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green")) @@ -120,7 +122,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.exception("Failed to clean vector index for document %s", document_id) with session_factory.create_session() as session, session.begin(): - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if not document: logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow")) return @@ -140,7 +142,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): try: indexing_runner = IndexingRunner() with session_factory.create_session() as session: - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if document: indexing_runner.run([document]) end_at = time.perf_counter() @@ -150,7 +152,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): except Exception as e: logger.exception("document_indexing_sync_task failed for document_id: %s", document_id) with session_factory.create_session() as session, session.begin(): - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if document: document.indexing_status = IndexingStatus.ERROR document.error = str(e) diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index f49f4535af77c1..41d3068a103d94 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -80,7 +80,7 @@ def mock_db_session(mock_document, mock_dataset): with patch("tasks.document_indexing_sync_task.session_factory", autospec=True) as mock_session_factory: session = MagicMock() session.scalars.return_value.all.return_value = [] - session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + session.scalar.side_effect = [mock_document, mock_dataset] begin_cm = MagicMock() begin_cm.__enter__.return_value = session @@ -242,14 +242,13 @@ def test_data_source_info_serialized_as_json_string( # DB session mock — shared across all ``session_factory.create_session()`` calls session = MagicMock() session.scalars.return_value.all.return_value = [] - # .where() path: session 1 reads document + dataset, session 2 reads dataset - session.query.return_value.where.return_value.first.side_effect = [ + # All .first() calls are now session.scalar() — ordered by call sequence: + # session 1: document + dataset, session 2: dataset (clean), session 3: document (update), + # session 4: document (indexing) + session.scalar.side_effect = [ mock_document, mock_dataset, mock_dataset, - ] - # .filter_by() path: session 3 (update), session 4 (indexing) - session.query.return_value.filter_by.return_value.first.side_effect = [ mock_document, mock_document, ]