diff --git a/automata/tests/unit/cli/test_cli_scripts_run_doc_embedding.py b/automata/tests/unit/cli/test_cli_scripts_run_doc_embedding.py index 039c0b15..80e29f65 100644 --- a/automata/tests/unit/cli/test_cli_scripts_run_doc_embedding.py +++ b/automata/tests/unit/cli/test_cli_scripts_run_doc_embedding.py @@ -56,14 +56,14 @@ def symbol_doc_embedding_handler_mock(): def test_initialize_providers( get_mock, set_overrides_mock, - OpenAIEmbeddingProvider_mock, + openai_embedding_provider_mock, ChromaSymbolEmbeddingVectorDatabase_mock, SymbolGraph_mock, symbol_graph_mock, ): SymbolGraph_mock.return_value = symbol_graph_mock ChromaSymbolEmbeddingVectorDatabase_mock.return_value = MagicMock() - OpenAIEmbeddingProvider_mock.return_value = MagicMock() + openai_embedding_provider_mock.return_value.return_value = MagicMock() symbol_code_embedding_handler = MagicMock() symbol_code_embedding_handler._get_sorted_supported_symbols.return_value = [ FakeSymbol("symbol1"), @@ -91,7 +91,7 @@ def test_initialize_providers( "symbol_graph": symbol_graph_mock, "code_embedding_db": ChromaSymbolEmbeddingVectorDatabase_mock.return_value, "doc_embedding_db": ChromaSymbolEmbeddingVectorDatabase_mock.return_value, - "embedding_provider": OpenAIEmbeddingProvider_mock.return_value, + "embedding_provider": openai_embedding_provider_mock.return_value, "disable_synchronization": True, } set_overrides_mock.assert_called_once_with(**overrides)