Skip to content

Fm/improvements #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits from fm/improvements into the base branch
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,44 @@ config = StagehandConfig(
)
```

## Evaluations

The Stagehand Python SDK includes a set of evaluations to test its core functionality. These evaluations are organized by the primary methods they test: `act`, `extract`, and `observe`.

### Running Evaluations

You can run evaluations using the `run_all_evals.py` script in the `evals/` directory:

```bash
# Run only observe evaluations (default behavior)
python -m evals.run_all_evals

# Run all evaluations (act, extract, and observe)
python -m evals.run_all_evals --all

# Run a specific evaluation
python -m evals.run_all_evals --all --eval observe_taxes
python -m evals.run_all_evals --all --eval google_jobs

# Specify a different model
python -m evals.run_all_evals --model gpt-4o-mini
```

### Evaluation Types

The evaluations test the following capabilities:

- **act**: Tests for browser actions (clicking, typing)
- `google_jobs`: Google jobs search and extraction

- **extract**: Tests for data extraction capabilities
- `extract_press_releases`: Extracting press releases from a dummy site

- **observe**: Tests for element observation and identification
- `observe_taxes`: Tax form elements observation

Results are printed to the console with a summary showing success/failure for each evaluation.

## License

MIT License (c) 2025 Browserbase, Inc.
50 changes: 36 additions & 14 deletions evals/act/google_jobs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import traceback
from typing import Any, Optional, dict
from typing import Any, Optional, Dict

from pydantic import BaseModel

Expand All @@ -19,23 +19,43 @@ class JobDetails(BaseModel):
preferred_qualifications: Qualifications


def is_job_details_valid(details: "Dict[str, Any] | JobDetails") -> bool:
    """
    Validate that extracted job details are in the expected format.

    Rules:
    - ``minimum_qualifications`` and ``preferred_qualifications`` must be
      present, non-None dictionaries containing the keys ``degree`` and
      ``years_of_experience``.
    - ``application_deadline`` is allowed to be None.
    - Within each qualifications dict, ``degree`` and ``years_of_experience``
      may be None; any non-None value must be a str, int, or float.

    Args:
        details: The extracted job details, either as a plain dict or a
            JobDetails Pydantic model.

    Returns:
        True if the details pass validation, False otherwise.
    """
    # NOTE: the annotation is a string (forward reference) so it is not
    # evaluated at definition time — the `X | Y` union syntax otherwise
    # requires Python 3.10+ and JobDetails to already be defined.
    if not details:
        return False

    # Accept either a Pydantic model or a plain dict.
    details_dict = details.model_dump() if hasattr(details, "model_dump") else details

    # Both qualification sections must exist, be non-None dicts, and carry
    # the expected keys; their values may be None or simple scalar types.
    for field in ("minimum_qualifications", "preferred_qualifications"):
        quals = details_dict.get(field)
        if not isinstance(quals, dict):
            return False
        if "degree" not in quals or "years_of_experience" not in quals:
            return False
        for value in quals.values():
            if value is not None and not isinstance(value, (str, int, float)):
                return False

    return True


Expand Down Expand Up @@ -79,7 +99,7 @@ async def google_jobs(model_name: str, logger, use_text_extract: bool) -> dict:
)

try:
await stagehand.page.navigate("https://www.google.com/")
await stagehand.page.goto("https://www.google.com/")
await asyncio.sleep(3)
await stagehand.page.act(ActOptions(action="click on the about page"))
await stagehand.page.act(ActOptions(action="click on the careers page"))
Expand All @@ -96,11 +116,13 @@ async def google_jobs(model_name: str, logger, use_text_extract: bool) -> dict:
"(degree and years of experience), and preferred qualifications "
"(degree and years of experience)"
),
schemaDefinition=JobDetails.model_json_schema(),
schemaDefinition=JobDetails,
useTextExtract=use_text_extract,
)
)

print("Extracted job details:", job_details)

valid = is_job_details_valid(job_details)

await stagehand.close()
Expand Down
41 changes: 41 additions & 0 deletions evals/env_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Environment variable loader for Stagehand evaluations.

This module loads environment variables from an .env file in the evals directory,
making them available to all submodules (act, extract, observe).
"""
import os
from pathlib import Path
from dotenv import load_dotenv

def load_evals_env():
    """
    Load environment variables from the .env file in the evals directory.
    This ensures all submodules have access to the same environment variables.
    """
    # Candidate .env locations, in priority order: evals/.env, then the
    # repository root .env as a fallback.
    here = Path(__file__).parent.absolute()
    candidates = [here / '.env', here.parent / '.env']

    for candidate in candidates:
        if candidate.exists():
            print(f"Loading environment variables from {candidate}")
            load_dotenv(candidate)
            break
    else:
        # No .env found anywhere — tell the user what is expected.
        print("No .env file found. Please create one in the evals directory.")
        print("Required variables: MODEL_API_KEY, BROWSERBASE_API_KEY, BROWSERBASE_PROJECT_ID")

    # Warn about any essential variables that are still unset after loading.
    essential_vars = ['MODEL_API_KEY', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID']
    missing_vars = [var for var in essential_vars if not os.getenv(var)]
    if missing_vars:
        print(f"Warning: Missing essential environment variables: {', '.join(missing_vars)}")
        print("Some evaluations may fail without these variables.")
24 changes: 14 additions & 10 deletions evals/extract/extract_press_releases.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import os

from pydantic import BaseModel

Expand Down Expand Up @@ -47,8 +48,8 @@ async def extract_press_releases(model_name: str, logger, use_text_extract: bool
session_url = init_response["sessionUrl"]

# Navigate to the dummy press releases page # TODO - choose a different page
await stagehand.page.navigate(
"https://dummy-press-releases.surge.sh/news", wait_until="networkidle"
await stagehand.page.goto(
"https://dummy-press-releases.surge.sh/news"
)
# Wait for 5 seconds to ensure content has loaded
await asyncio.sleep(5)
Expand All @@ -61,14 +62,21 @@ async def extract_press_releases(model_name: str, logger, use_text_extract: bool
"extract the title and corresponding publish date of EACH AND EVERY "
"press releases on this page. DO NOT MISS ANY PRESS RELEASES."
),
schemaDefinition=PressReleases.model_json_schema(),
schemaDefinition=PressReleases,
useTextExtract=use_text_extract,
)
)
print("Raw result:", raw_result)
# Check that the extraction returned a valid dictionary
if not raw_result or not isinstance(raw_result, dict):
error_message = "Extraction did not return a valid dictionary."

# Get the items list from the raw_result, which could be a dict or a PressReleases object
if isinstance(raw_result, PressReleases):
items = raw_result.items
elif isinstance(raw_result, dict) and "items" in raw_result:
# Parse the raw result using the defined schema if it's a dictionary
parsed = PressReleases.model_validate(raw_result)
items = parsed.items
else:
error_message = "Extraction did not return valid press releases data."
logger.error({"message": error_message, "raw_result": raw_result})
return {
"_success": False,
Expand All @@ -78,10 +86,6 @@ async def extract_press_releases(model_name: str, logger, use_text_extract: bool
"sessionUrl": session_url,
}

# Parse the raw result using the defined schema.
parsed = PressReleases.parse_obj(raw_result)
items = parsed.items

# Expected results (from the TS eval)
expected_length = 28
expected_first = PressRelease(
Expand Down
4 changes: 4 additions & 0 deletions evals/run_all_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import inspect
import os

from .env_loader import load_evals_env

# Load environment variables at module import time
load_evals_env()

# A simple logger to collect logs for the evals
class SimpleLogger:
Expand Down
5 changes: 5 additions & 0 deletions evals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import sys
from typing import Any, Optional

from .env_loader import load_evals_env

# Try to import LiteLLM, which is used for model inference
try:
import litellm
Expand Down Expand Up @@ -89,6 +91,9 @@ async def complete(

def setup_environment():
"""Set up the environment for running evaluations."""
# First, load environment variables from .env files
load_evals_env()

# If OPENAI_API_KEY is set but MODEL_API_KEY is not, copy it over
if os.getenv("OPENAI_API_KEY") and not os.getenv("MODEL_API_KEY"):
os.environ["MODEL_API_KEY"] = os.getenv("OPENAI_API_KEY")
Expand Down
3 changes: 0 additions & 3 deletions stagehand/handlers/extract_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,13 @@ def __init__(
async def extract(
self,
options: Optional[ExtractOptions] = None,
request_id: str = "",
schema: Optional[type[BaseModel]] = None,
) -> ExtractResult:
"""
Execute an extraction operation locally.

Args:
options: ExtractOptions containing the instruction and other parameters
request_id: Unique identifier for the request
schema: Optional Pydantic model for structured output

Returns:
Expand Down Expand Up @@ -101,7 +99,6 @@ async def extract(
tree_elements=output_string,
schema=transformed_schema,
llm_client=self.stagehand.llm,
request_id=request_id,
user_provided_instructions=self.user_provided_instructions,
logger=self.logger,
log_inference_to_file=False, # TODO: Implement logging to file if needed
Expand Down
3 changes: 0 additions & 3 deletions stagehand/handlers/observe_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@ def __init__(
async def observe(
self,
options: ObserveOptions,
*request_id: str,
from_act: bool = False,
) -> list[ObserveResult]:
"""
Execute an observation operation locally.

Args:
options: ObserveOptions containing the instruction and other parameters
request_id: Unique identifier for the request

Returns:
list of ObserveResult instances
Expand Down Expand Up @@ -80,7 +78,6 @@ async def observe(
instruction=instruction,
tree_elements=output_string,
llm_client=self.stagehand.llm,
request_id=request_id,
user_provided_instructions=self.user_provided_instructions,
logger=self.logger,
log_inference_to_file=False, # TODO: Implement logging to file if needed
Expand Down
3 changes: 1 addition & 2 deletions stagehand/llm/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""LLM client for model interactions."""

import logging
import time
from typing import Any, Callable, Optional

import litellm

from stagehand.metrics import start_inference_timer, get_inference_time_ms
from stagehand.metrics import get_inference_time_ms, start_inference_timer

# Configure logger for the module
logger = logging.getLogger(__name__)
Expand Down
7 changes: 0 additions & 7 deletions stagehand/llm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def observe(
instruction: str,
tree_elements: str,
llm_client: Any,
request_id: str,
user_provided_instructions: Optional[str] = None,
logger: Optional[Callable] = None,
log_inference_to_file: bool = False,
Expand All @@ -38,7 +37,6 @@ def observe(
instruction: The instruction to follow when finding elements
tree_elements: String representation of DOM/accessibility tree elements
llm_client: Client for calling LLM
request_id: Unique ID for this request
user_provided_instructions: Optional custom system instructions
logger: Optional logger function
log_inference_to_file: Whether to log inference to file
Expand Down Expand Up @@ -73,7 +71,6 @@ def observe(
messages=messages,
response_format=ObserveInferenceSchema,
temperature=0.1,
request_id=request_id,
function_name="ACT" if from_act else "OBSERVE",
)
inference_time_ms = int((time.time() - start_time) * 1000)
Expand Down Expand Up @@ -131,7 +128,6 @@ def extract(
tree_elements: str,
schema: Optional[Union[type[BaseModel], dict]] = None,
llm_client: Any = None,
request_id: str = "",
user_provided_instructions: Optional[str] = None,
logger: Optional[Callable] = None,
log_inference_to_file: bool = False,
Expand All @@ -146,7 +142,6 @@ def extract(
tree_elements: The DOM or accessibility tree representation
schema: Pydantic model defining the structure of the data to extract
llm_client: The LLM client to use for the request
request_id: Unique identifier for the request
user_provided_instructions: Optional custom system instructions
logger: Logger instance for logging
log_inference_to_file: Whether to log inference to file
Expand Down Expand Up @@ -187,7 +182,6 @@ def extract(
messages=extract_messages,
response_format=response_format,
temperature=0.1,
request_id=request_id,
function_name="EXTRACT", # Always set to EXTRACT
**kwargs,
)
Expand Down Expand Up @@ -238,7 +232,6 @@ def extract(
messages=metadata_messages,
response_format=metadata_schema,
temperature=0.1,
request_id=request_id,
function_name="EXTRACT", # Metadata for extraction should also be tracked as EXTRACT
)
metadata_end_time = time.time()
Expand Down
Loading