diff --git a/pytest.ini b/pytest.ini index 4f2c039f..e17cf7c9 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,6 +4,8 @@ python_files = test_*.py python_classes = Test* python_functions = test_* asyncio_mode = auto +asyncio_default_test_loop_scope = session +asyncio_default_fixture_loop_scope = session markers = unit: marks tests as unit tests diff --git a/stagehand/handlers/extract_handler.py b/stagehand/handlers/extract_handler.py index 5fa24212..2970d1cb 100644 --- a/stagehand/handlers/extract_handler.py +++ b/stagehand/handlers/extract_handler.py @@ -183,4 +183,4 @@ async def _extract_page_text(self) -> ExtractResult: output_string = tree["simplified"] output_dict = {"page_text": output_string} validated_model = EmptyExtractSchema.model_validate(output_dict) - return ExtractResult(data=validated_model).data + return ExtractResult(data=validated_model) diff --git a/tests/conftest.py b/tests/conftest.py index 36767e1a..5f1bf16a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,19 +11,6 @@ # Set up pytest-asyncio as the default pytest_plugins = ["pytest_asyncio"] - -@pytest.fixture(scope="session") -def event_loop(): - """ - Create an instance of the default event loop for each test session. - This helps with running async tests. - """ - policy = asyncio.get_event_loop_policy() - loop = policy.new_event_loop() - yield loop - loop.close() - - @pytest.fixture def mock_stagehand_config(): """Provide a mock StagehandConfig for testing""" diff --git a/tests/e2e/test_act_integration.py b/tests/e2e/test_act_integration.py index 221dea8b..87e53375 100644 --- a/tests/e2e/test_act_integration.py +++ b/tests/e2e/test_act_integration.py @@ -161,17 +161,17 @@ async def test_selecting_option_custom_input_local(self, local_stagehand): await stagehand.page.goto("https://browserbase.github.io/stagehand-eval-sites/sites/expand-dropdown/") # Select an option from the dropdown. - await stagehand.page.act("Click the 'Select a Country' dropdown") + await stagehand.page.act("Click on the text 'Select a country'") # Wait for dropdown to expand - await asyncio.sleep(1) + await asyncio.sleep(2) # We are expecting stagehand to click the dropdown to expand it, and therefore # the available options should now be contained in the full a11y tree. # To test, we'll grab the full a11y tree, and make sure it contains 'Canada' extraction = await stagehand.page.extract() - assert "Canada" in extraction.data + assert "Canada" in extraction.data.page_text @pytest.mark.asyncio @pytest.mark.local @@ -193,7 +193,7 @@ async def test_selecting_option_hidden_input_local(self, local_stagehand): # To test, we'll grab the full a11y tree, and make sure it contains 'Green' extraction = await stagehand.page.extract() - assert "Green" in extraction.data + assert "Green" in extraction.data.page_text @pytest.mark.asyncio @pytest.mark.local diff --git a/tests/e2e/test_workflows.py b/tests/e2e/test_workflows.py index a03a06fb..cc1720bb 100644 --- a/tests/e2e/test_workflows.py +++ b/tests/e2e/test_workflows.py @@ -51,14 +51,14 @@ async def test_search_and_extract_workflow(self, mock_stagehand_config, sample_h stagehand._playwright = playwright stagehand._browser = browser stagehand._context = context - stagehand.page = MagicMock() - stagehand.page.goto = AsyncMock() - stagehand.page.act = AsyncMock(return_value=ActResult( + stagehand._page = MagicMock() + stagehand._page.goto = AsyncMock() + stagehand._page.act = AsyncMock(return_value=ActResult( success=True, message="Search executed", action="search" )) - stagehand.page.extract = AsyncMock(return_value={ + stagehand._page.extract = AsyncMock(return_value={ "title": "OpenAI Search Results", "results": [ {"title": "OpenAI Official Website", "url": "https://openai.com"}, @@ -82,9 +82,9 @@ async def test_search_and_extract_workflow(self, mock_stagehand_config, sample_h assert extracted_data["results"][0]["title"] == "OpenAI Official Website" # Verify calls were made - stagehand.page.goto.assert_called_with("https://google.com") - stagehand.page.act.assert_called_with("search for openai") - stagehand.page.extract.assert_called_with("extract search results") + stagehand._page.goto.assert_called_with("https://google.com") + stagehand._page.act.assert_called_with("search for openai") + stagehand._page.extract.assert_called_with("extract search results") finally: stagehand._closed = True @@ -149,14 +149,14 @@ def form_response_generator(messages, **kwargs): stagehand._playwright = playwright stagehand._browser = browser stagehand._context = context - stagehand.page = MagicMock() - stagehand.page.goto = AsyncMock() - stagehand.page.act = AsyncMock() - stagehand.page.extract = AsyncMock() + stagehand._page = MagicMock() + stagehand._page.goto = AsyncMock() + stagehand._page.act = AsyncMock() + stagehand._page.extract = AsyncMock() stagehand._initialized = True - + # Mock act responses - stagehand.page.act.side_effect = [ + stagehand._page.act.side_effect = [ ActResult(success=True, message="Username filled", action="fill"), ActResult(success=True, message="Email filled", action="fill"), ActResult(success=True, message="Password filled", action="fill"), @@ -164,7 +164,7 @@ def form_response_generator(messages, **kwargs): ] # Mock success verification - stagehand.page.extract.return_value = {"success": True, "message": "Registration successful!"} + stagehand._page.extract.return_value = {"success": True, "message": "Registration successful!"} try: # Execute form filling workflow @@ -189,7 +189,7 @@ def form_response_generator(messages, **kwargs): assert verification["success"] is True # Verify all steps were executed - assert stagehand.page.act.call_count == 4 + assert stagehand._page.act.call_count == 4 finally: stagehand._closed = True @@ -235,10 +235,10 @@ async def test_observe_then_act_workflow(self, mock_stagehand_config): stagehand._playwright = playwright stagehand._browser = browser stagehand._context = context - stagehand.page = MagicMock() - stagehand.page.goto = AsyncMock() - stagehand.page.observe = AsyncMock() - stagehand.page.act = AsyncMock() + stagehand._page = MagicMock() + stagehand._page.goto = AsyncMock() + stagehand._page.observe = AsyncMock() + stagehand._page.act = AsyncMock() stagehand._initialized = True # Mock observe results @@ -278,8 +278,8 @@ async def test_observe_then_act_workflow(self, mock_stagehand_config): ) ] - stagehand.page.observe.side_effect = [nav_buttons, add_to_cart_buttons] - stagehand.page.act.return_value = ActResult( + stagehand._page.observe.side_effect = [nav_buttons, add_to_cart_buttons] + stagehand._page.act.return_value = ActResult( success=True, message="Button clicked", action="click" @@ -307,8 +307,8 @@ async def test_observe_then_act_workflow(self, mock_stagehand_config): assert add_to_cart_result.success is True # Verify method calls - assert stagehand.page.observe.call_count == 2 - assert stagehand.page.act.call_count == 2 + assert stagehand._page.observe.call_count == 2 + assert stagehand._page.act.call_count == 2 finally: stagehand._closed = True @@ -370,10 +370,10 @@ async def test_multi_page_navigation_workflow(self, mock_stagehand_config): stagehand._playwright = playwright stagehand._browser = browser stagehand._context = context - stagehand.page = MagicMock() - stagehand.page.goto = AsyncMock() - stagehand.page.extract = AsyncMock() - stagehand.page.act = AsyncMock() + stagehand._page = MagicMock() + stagehand._page.goto = AsyncMock() + stagehand._page.extract = AsyncMock() + stagehand._page.act = AsyncMock() stagehand._initialized = True # Mock page responses @@ -403,9 +403,9 @@ def navigation_side_effect(url): else: current_page[0] = "/products" - stagehand.page.extract.side_effect = lambda inst: extract_response(inst) - stagehand.page.goto.side_effect = navigation_side_effect - stagehand.page.act.return_value = ActResult( + stagehand._page.extract.side_effect = lambda inst: extract_response(inst) + stagehand._page.goto.side_effect = navigation_side_effect + stagehand._page.act.return_value = ActResult( success=True, message="Navigation successful", action="click" @@ -434,8 +434,8 @@ def navigation_side_effect(url): assert len(details["specs"]) == 3 # Verify navigation flow - assert stagehand.page.goto.call_count == 2 - assert stagehand.page.extract.call_count == 2 + assert stagehand._page.goto.call_count == 2 + assert stagehand._page.extract.call_count == 2 finally: stagehand._closed = True @@ -457,9 +457,9 @@ async def test_error_recovery_workflow(self, mock_stagehand_config): stagehand._playwright = playwright stagehand._browser = browser stagehand._context = context - stagehand.page = MagicMock() - stagehand.page.goto = AsyncMock() - stagehand.page.act = AsyncMock() + stagehand._page = MagicMock() + stagehand._page.goto = AsyncMock() + stagehand._page.act = AsyncMock() stagehand._initialized = True # Simulate intermittent failures and recovery @@ -481,7 +481,7 @@ def act_with_failures(*args, **kwargs): action="click" ) - stagehand.page.act.side_effect = act_with_failures + stagehand._page.act.side_effect = act_with_failures try: await stagehand.page.goto("https://example.com") @@ -498,7 +498,7 @@ def act_with_failures(*args, **kwargs): assert success is True assert failure_count == 3 # 2 failures + 1 success - assert stagehand.page.act.call_count == 3 + assert stagehand._page.act.call_count == 3 finally: stagehand._closed = True @@ -538,10 +538,10 @@ async def test_browserbase_session_workflow(self, mock_browserbase_config): # Mock the browser connection parts stagehand._client = http_client stagehand.session_id = "test-bb-session" - stagehand.page = MagicMock() - stagehand.page.goto = AsyncMock() - stagehand.page.act = AsyncMock() - stagehand.page.extract = AsyncMock() + stagehand._page = MagicMock() + stagehand._page.goto = AsyncMock() + stagehand._page.act = AsyncMock() + stagehand._page.extract = AsyncMock() stagehand._initialized = True # Mock page methods to use server @@ -561,8 +561,8 @@ async def mock_extract(instruction, **kwargs): ) return response.json() - stagehand.page.act = mock_act - stagehand.page.extract = mock_extract + stagehand._page.act = mock_act + stagehand._page.extract = mock_extract try: # Execute Browserbase workflow @@ -616,9 +616,9 @@ class ProductList(BaseModel): stagehand._playwright = playwright stagehand._browser = browser stagehand._context = context - stagehand.page = MagicMock() - stagehand.page.goto = AsyncMock() - stagehand.page.extract = AsyncMock() + stagehand._page = MagicMock() + stagehand._page.goto = AsyncMock() + stagehand._page.extract = AsyncMock() stagehand._initialized = True # Mock structured extraction responses @@ -642,7 +642,7 @@ class ProductList(BaseModel): "total_count": 2 } - stagehand.page.extract.return_value = mock_product_data + stagehand._page.extract.return_value = mock_product_data try: await stagehand.page.goto("https://electronics-store.com") @@ -671,7 +671,7 @@ class ProductList(BaseModel): assert product2["in_stock"] is False # Verify extract was called with schema - stagehand.page.extract.assert_called_once() + stagehand._page.extract.assert_called_once() finally: stagehand._closed = True @@ -696,8 +696,8 @@ async def test_concurrent_operations_workflow(self, mock_stagehand_config): stagehand._playwright = playwright stagehand._browser = browser stagehand._context = context - stagehand.page = MagicMock() - stagehand.page.extract = AsyncMock() + stagehand._page = MagicMock() + stagehand._page.extract = AsyncMock() stagehand._initialized = True # Mock multiple concurrent extractions @@ -707,7 +707,7 @@ async def test_concurrent_operations_workflow(self, mock_stagehand_config): {"section": "footer", "content": "Footer content"} ] - stagehand.page.extract.side_effect = extraction_responses + stagehand._page.extract.side_effect = extraction_responses try: # Execute concurrent extractions @@ -727,7 +727,7 @@ async def test_concurrent_operations_workflow(self, mock_stagehand_config): assert results[2]["section"] == "footer" # Verify all extractions were called - assert stagehand.page.extract.call_count == 3 + assert stagehand._page.extract.call_count == 3 finally: stagehand._closed = True \ No newline at end of file diff --git a/tests/unit/test_client_initialization.py b/tests/unit/test_client_initialization.py index afec5c6b..062e3319 100644 --- a/tests/unit/test_client_initialization.py +++ b/tests/unit/test_client_initialization.py @@ -235,6 +235,7 @@ def test_init_with_model_api_key_in_env(self): client = Stagehand(config=config) assert client.model_api_key == "test-model-api-key" + @mock.patch.dict(os.environ, {}, clear=True) def test_init_with_custom_llm(self): config = StagehandConfig( env="LOCAL", @@ -245,6 +246,7 @@ def test_init_with_custom_llm(self): assert client.model_client_options["apiKey"] == "custom-llm-key" assert client.model_client_options["baseURL"] == "https://custom-llm.com" + @mock.patch.dict(os.environ, {}, clear=True) def test_init_with_custom_llm_override(self): config = StagehandConfig( env="LOCAL",