Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ python_files = test_*.py
python_classes = Test*
python_functions = test_*
asyncio_mode = auto
asyncio_default_test_loop_scope = session
asyncio_default_fixture_loop_scope = session

markers =
unit: marks tests as unit tests
Expand Down
2 changes: 1 addition & 1 deletion stagehand/handlers/extract_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,4 @@ async def _extract_page_text(self) -> ExtractResult:
output_string = tree["simplified"]
output_dict = {"page_text": output_string}
validated_model = EmptyExtractSchema.model_validate(output_dict)
return ExtractResult(data=validated_model).data
return ExtractResult(data=validated_model)
13 changes: 0 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,6 @@
# Set up pytest-asyncio as the default
pytest_plugins = ["pytest_asyncio"]


@pytest.fixture(scope="session")
def event_loop():
"""
Create an instance of the default event loop for each test session.
This helps with running async tests.
"""
policy = asyncio.get_event_loop_policy()
loop = policy.new_event_loop()
yield loop
loop.close()


@pytest.fixture
def mock_stagehand_config():
"""Provide a mock StagehandConfig for testing"""
Expand Down
8 changes: 4 additions & 4 deletions tests/e2e/test_act_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,17 @@ async def test_selecting_option_custom_input_local(self, local_stagehand):
await stagehand.page.goto("https://browserbase.github.io/stagehand-eval-sites/sites/expand-dropdown/")

# Select an option from the dropdown.
await stagehand.page.act("Click the 'Select a Country' dropdown")
await stagehand.page.act("Click on the text 'Select a country'")

# Wait for dropdown to expand
await asyncio.sleep(1)
await asyncio.sleep(2)

# We are expecting stagehand to click the dropdown to expand it, and therefore
# the available options should now be contained in the full a11y tree.

# To test, we'll grab the full a11y tree, and make sure it contains 'Canada'
extraction = await stagehand.page.extract()
assert "Canada" in extraction.data
assert "Canada" in extraction.data.page_text

@pytest.mark.asyncio
@pytest.mark.local
Expand All @@ -193,7 +193,7 @@ async def test_selecting_option_hidden_input_local(self, local_stagehand):

# To test, we'll grab the full a11y tree, and make sure it contains 'Green'
extraction = await stagehand.page.extract()
assert "Green" in extraction.data
assert "Green" in extraction.data.page_text

@pytest.mark.asyncio
@pytest.mark.local
Expand Down
104 changes: 52 additions & 52 deletions tests/e2e/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ async def test_search_and_extract_workflow(self, mock_stagehand_config, sample_h
stagehand._playwright = playwright
stagehand._browser = browser
stagehand._context = context
stagehand.page = MagicMock()
stagehand.page.goto = AsyncMock()
stagehand.page.act = AsyncMock(return_value=ActResult(
stagehand._page = MagicMock()
stagehand._page.goto = AsyncMock()
stagehand._page.act = AsyncMock(return_value=ActResult(
success=True,
message="Search executed",
action="search"
))
stagehand.page.extract = AsyncMock(return_value={
stagehand._page.extract = AsyncMock(return_value={
"title": "OpenAI Search Results",
"results": [
{"title": "OpenAI Official Website", "url": "https://openai.com"},
Expand All @@ -82,9 +82,9 @@ async def test_search_and_extract_workflow(self, mock_stagehand_config, sample_h
assert extracted_data["results"][0]["title"] == "OpenAI Official Website"

# Verify calls were made
stagehand.page.goto.assert_called_with("https://google.com")
stagehand.page.act.assert_called_with("search for openai")
stagehand.page.extract.assert_called_with("extract search results")
stagehand._page.goto.assert_called_with("https://google.com")
stagehand._page.act.assert_called_with("search for openai")
stagehand._page.extract.assert_called_with("extract search results")

finally:
stagehand._closed = True
Expand Down Expand Up @@ -149,22 +149,22 @@ def form_response_generator(messages, **kwargs):
stagehand._playwright = playwright
stagehand._browser = browser
stagehand._context = context
stagehand.page = MagicMock()
stagehand.page.goto = AsyncMock()
stagehand.page.act = AsyncMock()
stagehand.page.extract = AsyncMock()
stagehand._page = MagicMock()
stagehand._page.goto = AsyncMock()
stagehand._page.act = AsyncMock()
stagehand._page.extract = AsyncMock()
stagehand._initialized = True

# Mock act responses
stagehand.page.act.side_effect = [
stagehand._page.act.side_effect = [
ActResult(success=True, message="Username filled", action="fill"),
ActResult(success=True, message="Email filled", action="fill"),
ActResult(success=True, message="Password filled", action="fill"),
ActResult(success=True, message="Form submitted", action="click")
]

# Mock success verification
stagehand.page.extract.return_value = {"success": True, "message": "Registration successful!"}
stagehand._page.extract.return_value = {"success": True, "message": "Registration successful!"}

try:
# Execute form filling workflow
Expand All @@ -189,7 +189,7 @@ def form_response_generator(messages, **kwargs):
assert verification["success"] is True

# Verify all steps were executed
assert stagehand.page.act.call_count == 4
assert stagehand._page.act.call_count == 4

finally:
stagehand._closed = True
Expand Down Expand Up @@ -235,10 +235,10 @@ async def test_observe_then_act_workflow(self, mock_stagehand_config):
stagehand._playwright = playwright
stagehand._browser = browser
stagehand._context = context
stagehand.page = MagicMock()
stagehand.page.goto = AsyncMock()
stagehand.page.observe = AsyncMock()
stagehand.page.act = AsyncMock()
stagehand._page = MagicMock()
stagehand._page.goto = AsyncMock()
stagehand._page.observe = AsyncMock()
stagehand._page.act = AsyncMock()
stagehand._initialized = True

# Mock observe results
Expand Down Expand Up @@ -278,8 +278,8 @@ async def test_observe_then_act_workflow(self, mock_stagehand_config):
)
]

stagehand.page.observe.side_effect = [nav_buttons, add_to_cart_buttons]
stagehand.page.act.return_value = ActResult(
stagehand._page.observe.side_effect = [nav_buttons, add_to_cart_buttons]
stagehand._page.act.return_value = ActResult(
success=True,
message="Button clicked",
action="click"
Expand Down Expand Up @@ -307,8 +307,8 @@ async def test_observe_then_act_workflow(self, mock_stagehand_config):
assert add_to_cart_result.success is True

# Verify method calls
assert stagehand.page.observe.call_count == 2
assert stagehand.page.act.call_count == 2
assert stagehand._page.observe.call_count == 2
assert stagehand._page.act.call_count == 2

finally:
stagehand._closed = True
Expand Down Expand Up @@ -370,10 +370,10 @@ async def test_multi_page_navigation_workflow(self, mock_stagehand_config):
stagehand._playwright = playwright
stagehand._browser = browser
stagehand._context = context
stagehand.page = MagicMock()
stagehand.page.goto = AsyncMock()
stagehand.page.extract = AsyncMock()
stagehand.page.act = AsyncMock()
stagehand._page = MagicMock()
stagehand._page.goto = AsyncMock()
stagehand._page.extract = AsyncMock()
stagehand._page.act = AsyncMock()
stagehand._initialized = True

# Mock page responses
Expand Down Expand Up @@ -403,9 +403,9 @@ def navigation_side_effect(url):
else:
current_page[0] = "/products"

stagehand.page.extract.side_effect = lambda inst: extract_response(inst)
stagehand.page.goto.side_effect = navigation_side_effect
stagehand.page.act.return_value = ActResult(
stagehand._page.extract.side_effect = lambda inst: extract_response(inst)
stagehand._page.goto.side_effect = navigation_side_effect
stagehand._page.act.return_value = ActResult(
success=True,
message="Navigation successful",
action="click"
Expand Down Expand Up @@ -434,8 +434,8 @@ def navigation_side_effect(url):
assert len(details["specs"]) == 3

# Verify navigation flow
assert stagehand.page.goto.call_count == 2
assert stagehand.page.extract.call_count == 2
assert stagehand._page.goto.call_count == 2
assert stagehand._page.extract.call_count == 2

finally:
stagehand._closed = True
Expand All @@ -457,9 +457,9 @@ async def test_error_recovery_workflow(self, mock_stagehand_config):
stagehand._playwright = playwright
stagehand._browser = browser
stagehand._context = context
stagehand.page = MagicMock()
stagehand.page.goto = AsyncMock()
stagehand.page.act = AsyncMock()
stagehand._page = MagicMock()
stagehand._page.goto = AsyncMock()
stagehand._page.act = AsyncMock()
stagehand._initialized = True

# Simulate intermittent failures and recovery
Expand All @@ -481,7 +481,7 @@ def act_with_failures(*args, **kwargs):
action="click"
)

stagehand.page.act.side_effect = act_with_failures
stagehand._page.act.side_effect = act_with_failures

try:
await stagehand.page.goto("https://example.com")
Expand All @@ -498,7 +498,7 @@ def act_with_failures(*args, **kwargs):

assert success is True
assert failure_count == 3 # 2 failures + 1 success
assert stagehand.page.act.call_count == 3
assert stagehand._page.act.call_count == 3

finally:
stagehand._closed = True
Expand Down Expand Up @@ -538,10 +538,10 @@ async def test_browserbase_session_workflow(self, mock_browserbase_config):
# Mock the browser connection parts
stagehand._client = http_client
stagehand.session_id = "test-bb-session"
stagehand.page = MagicMock()
stagehand.page.goto = AsyncMock()
stagehand.page.act = AsyncMock()
stagehand.page.extract = AsyncMock()
stagehand._page = MagicMock()
stagehand._page.goto = AsyncMock()
stagehand._page.act = AsyncMock()
stagehand._page.extract = AsyncMock()
stagehand._initialized = True

# Mock page methods to use server
Expand All @@ -561,8 +561,8 @@ async def mock_extract(instruction, **kwargs):
)
return response.json()

stagehand.page.act = mock_act
stagehand.page.extract = mock_extract
stagehand._page.act = mock_act
stagehand._page.extract = mock_extract

try:
# Execute Browserbase workflow
Expand Down Expand Up @@ -616,9 +616,9 @@ class ProductList(BaseModel):
stagehand._playwright = playwright
stagehand._browser = browser
stagehand._context = context
stagehand.page = MagicMock()
stagehand.page.goto = AsyncMock()
stagehand.page.extract = AsyncMock()
stagehand._page = MagicMock()
stagehand._page.goto = AsyncMock()
stagehand._page.extract = AsyncMock()
stagehand._initialized = True

# Mock structured extraction responses
Expand All @@ -642,7 +642,7 @@ class ProductList(BaseModel):
"total_count": 2
}

stagehand.page.extract.return_value = mock_product_data
stagehand._page.extract.return_value = mock_product_data

try:
await stagehand.page.goto("https://electronics-store.com")
Expand Down Expand Up @@ -671,7 +671,7 @@ class ProductList(BaseModel):
assert product2["in_stock"] is False

# Verify extract was called with schema
stagehand.page.extract.assert_called_once()
stagehand._page.extract.assert_called_once()

finally:
stagehand._closed = True
Expand All @@ -696,8 +696,8 @@ async def test_concurrent_operations_workflow(self, mock_stagehand_config):
stagehand._playwright = playwright
stagehand._browser = browser
stagehand._context = context
stagehand.page = MagicMock()
stagehand.page.extract = AsyncMock()
stagehand._page = MagicMock()
stagehand._page.extract = AsyncMock()
stagehand._initialized = True

# Mock multiple concurrent extractions
Expand All @@ -707,7 +707,7 @@ async def test_concurrent_operations_workflow(self, mock_stagehand_config):
{"section": "footer", "content": "Footer content"}
]

stagehand.page.extract.side_effect = extraction_responses
stagehand._page.extract.side_effect = extraction_responses

try:
# Execute concurrent extractions
Expand All @@ -727,7 +727,7 @@ async def test_concurrent_operations_workflow(self, mock_stagehand_config):
assert results[2]["section"] == "footer"

# Verify all extractions were called
assert stagehand.page.extract.call_count == 3
assert stagehand._page.extract.call_count == 3

finally:
stagehand._closed = True
2 changes: 2 additions & 0 deletions tests/unit/test_client_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def test_init_with_model_api_key_in_env(self):
client = Stagehand(config=config)
assert client.model_api_key == "test-model-api-key"

@mock.patch.dict(os.environ, {}, clear=True)
def test_init_with_custom_llm(self):
config = StagehandConfig(
env="LOCAL",
Expand All @@ -245,6 +246,7 @@ def test_init_with_custom_llm(self):
assert client.model_client_options["apiKey"] == "custom-llm-key"
assert client.model_client_options["baseURL"] == "https://custom-llm.com"

@mock.patch.dict(os.environ, {}, clear=True)
def test_init_with_custom_llm_override(self):
config = StagehandConfig(
env="LOCAL",
Expand Down