4 changes: 4 additions & 0 deletions .env.example
@@ -0,0 +1,4 @@
MODEL_API_KEY = "anthropic-or-openai-api-key"
BROWSERBASE_API_KEY = "browserbase-api-key"
BROWSERBASE_PROJECT_ID = "browserbase-project-id"
STAGEHAND_SERVER_URL = "api_url"
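A minimal sketch of how an eval script might consume these four values, assuming python-dotenv is installed (the load_dotenv call is an assumption, not part of this PR; only the variable names come from this file):

# Hypothetical usage sketch -- not part of this PR.
import os

from dotenv import load_dotenv  # assumes python-dotenv is available

load_dotenv()  # copy .env entries into the process environment

model_api_key = os.getenv("MODEL_API_KEY")
browserbase_api_key = os.getenv("BROWSERBASE_API_KEY")
browserbase_project_id = os.getenv("BROWSERBASE_PROJECT_ID")
stagehand_server_url = os.getenv("STAGEHAND_SERVER_URL")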
6 changes: 0 additions & 6 deletions MANIFEST.in

This file was deleted.

70 changes: 41 additions & 29 deletions evals/act/google_jobs.py
@@ -1,7 +1,9 @@
import asyncio
import traceback
from typing import Optional, Any, Dict
from typing import Any, Dict, Optional

from pydantic import BaseModel

from evals.init_stagehand import init_stagehand
from stagehand.schemas import ActOptions, ExtractOptions

@@ -49,12 +51,12 @@ async def google_jobs(model_name: str, logger, use_text_extract: bool) -> dict:
- Clicking on the search button
- Clicking on the first job link
4. Extracting job posting details using an AI-driven extraction schema.

The extraction schema requires:
- applicationDeadline: The opening date until which applications are accepted.
- minimumQualifications: An object with degree and yearsOfExperience.
- preferredQualifications: An object with degree and yearsOfExperience.

Returns a dictionary containing:
- _success (bool): Whether valid job details were extracted.
- jobDetails (dict): The extracted job details.
Expand All @@ -77,23 +79,25 @@ async def google_jobs(model_name: str, logger, use_text_extract: bool) -> dict:

try:
await stagehand.page.navigate("https://www.google.com/")
await asyncio.sleep(3)
await asyncio.sleep(3)
await stagehand.page.act(ActOptions(action="click on the about page"))
await stagehand.page.act(ActOptions(action="click on the careers page"))
await stagehand.page.act(ActOptions(action="input data scientist into role"))
await stagehand.page.act(ActOptions(action="input new york city into location"))
await stagehand.page.act(ActOptions(action="click on the search button"))
await stagehand.page.act(ActOptions(action="click on the first job link"))

job_details = await stagehand.page.extract(ExtractOptions(
instruction=(
"Extract the following details from the job posting: application deadline, "
"minimum qualifications (degree and years of experience), and preferred qualifications "
"(degree and years of experience)"
),
schemaDefinition=JobDetails.model_json_schema(),
useTextExtract=use_text_extract
))
job_details = await stagehand.page.extract(
ExtractOptions(
instruction=(
"Extract the following details from the job posting: application deadline, "
"minimum qualifications (degree and years of experience), and preferred qualifications "
"(degree and years of experience)"
),
schemaDefinition=JobDetails.model_json_schema(),
useTextExtract=use_text_extract,
)
)

valid = is_job_details_valid(job_details)

Expand All @@ -104,19 +108,21 @@ async def google_jobs(model_name: str, logger, use_text_extract: bool) -> dict:
"jobDetails": job_details,
"debugUrl": debug_url,
"sessionUrl": session_url,
"logs": logger.get_logs() if hasattr(logger, "get_logs") else []
"logs": logger.get_logs() if hasattr(logger, "get_logs") else [],
}
except Exception as e:
err_message = str(e)
err_trace = traceback.format_exc()
logger.error({
"message": "error in google_jobs function",
"level": 0,
"auxiliary": {
"error": {"value": err_message, "type": "string"},
"trace": {"value": err_trace, "type": "string"}
logger.error(
{
"message": "error in google_jobs function",
"level": 0,
"auxiliary": {
"error": {"value": err_message, "type": "string"},
"trace": {"value": err_trace, "type": "string"},
},
}
})
)

await stagehand.close()

@@ -125,31 +131,37 @@ async def google_jobs(model_name: str, logger, use_text_extract: bool) -> dict:
"debugUrl": debug_url,
"sessionUrl": session_url,
"error": {"message": err_message, "trace": err_trace},
"logs": logger.get_logs() if hasattr(logger, "get_logs") else []
}

"logs": logger.get_logs() if hasattr(logger, "get_logs") else [],
}


# For quick local testing
if __name__ == "__main__":
import os
import asyncio
import logging

logging.basicConfig(level=logging.INFO)

class SimpleLogger:
def __init__(self):
self._logs = []

def info(self, message):
self._logs.append(message)
print("INFO:", message)

def error(self, message):
self._logs.append(message)
print("ERROR:", message)

def get_logs(self):
return self._logs

async def main():
logger = SimpleLogger()
result = await google_jobs("gpt-4o-mini", logger, use_text_extract=False) # TODO - use text extract
result = await google_jobs(
"gpt-4o-mini", logger, use_text_extract=False
) # TODO - use text extract
print("Result:", result)
asyncio.run(main())

asyncio.run(main())
59 changes: 38 additions & 21 deletions evals/extract/extract_press_releases.py
@@ -1,26 +1,31 @@
import asyncio

from pydantic import BaseModel
from stagehand.schemas import ExtractOptions

from evals.init_stagehand import init_stagehand
from evals.utils import compare_strings
from stagehand.schemas import ExtractOptions


# Define Pydantic models for validating press release data
class PressRelease(BaseModel):
title: str
publish_date: str


class PressReleases(BaseModel):
items: list[PressRelease]


async def extract_press_releases(model_name: str, logger, use_text_extract: bool):
"""
Extract press releases from the dummy press releases page using the Stagehand client.

Args:
model_name (str): Name of the AI model to use.
logger: A custom logger that provides .error() and .get_logs() methods.
use_text_extract (bool): Flag to control text extraction behavior.

Returns:
dict: A result object containing:
- _success (bool): Whether the eval was successful.
Expand All @@ -34,12 +39,16 @@ async def extract_press_releases(model_name: str, logger, use_text_extract: bool
session_url = None
try:
# Initialize Stagehand (mimicking the TS initStagehand)
stagehand, init_response = await init_stagehand(model_name, logger, dom_settle_timeout_ms=3000)
stagehand, init_response = await init_stagehand(
model_name, logger, dom_settle_timeout_ms=3000
)
debug_url = init_response["debugUrl"]
session_url = init_response["sessionUrl"]

# Navigate to the dummy press releases page # TODO - choose a different page
await stagehand.page.navigate("https://dummy-press-releases.surge.sh/news", wait_until="networkidle")
await stagehand.page.navigate(
"https://dummy-press-releases.surge.sh/news", wait_until="networkidle"
)
# Wait for 5 seconds to ensure content has loaded
await asyncio.sleep(5)

Expand All @@ -49,7 +58,7 @@ async def extract_press_releases(model_name: str, logger, use_text_extract: bool
ExtractOptions(
instruction="extract the title and corresponding publish date of EACH AND EVERY press releases on this page. DO NOT MISS ANY PRESS RELEASES.",
schemaDefinition=PressReleases.model_json_schema(),
useTextExtract=use_text_extract
useTextExtract=use_text_extract,
)
)
print("Raw result:", raw_result)
Expand All @@ -73,19 +82,21 @@ async def extract_press_releases(model_name: str, logger, use_text_extract: bool
expected_length = 28
expected_first = PressRelease(
title="UAW Region 9A Endorses Brad Lander for Mayor",
publish_date="Dec 4, 2024"
publish_date="Dec 4, 2024",
)
expected_last = PressRelease(
title="Fox Sued by New York City Pension Funds Over Election Falsehoods",
publish_date="Nov 12, 2023"
publish_date="Nov 12, 2023",
)

if len(items) <= expected_length:
logger.error({
"message": "Not enough items extracted",
"expected": f"> {expected_length}",
"actual": len(items)
})
logger.error(
{
"message": "Not enough items extracted",
"expected": f"> {expected_length}",
"actual": len(items),
}
)
return {
"_success": False,
"error": "Not enough items extracted",
Expand All @@ -111,10 +122,9 @@ def is_item_match(item: PressRelease, expected: PressRelease) -> bool:
await stagehand.close()
return result
except Exception as e:
logger.error({
"message": "Error in extract_press_releases function",
"error": str(e)
})
logger.error(
{"message": "Error in extract_press_releases function", "error": str(e)}
)
return {
"_success": False,
"error": str(e),
Expand All @@ -127,26 +137,33 @@ def is_item_match(item: PressRelease, expected: PressRelease) -> bool:
if stagehand:
await stagehand.close()


# For quick local testing.
if __name__ == "__main__":
import logging

logging.basicConfig(level=logging.INFO)

class SimpleLogger:
def __init__(self):
self._logs = []

def info(self, message):
self._logs.append(message)
print("INFO:", message)

def error(self, message):
self._logs.append(message)
print("ERROR:", message)

def get_logs(self):
return self._logs

async def main():
logger = SimpleLogger()
result = await extract_press_releases("gpt-4o", logger, use_text_extract=False) # TODO - use text extract
result = await extract_press_releases(
"gpt-4o", logger, use_text_extract=False
) # TODO - use text extract
print("Result:", result)
asyncio.run(main())

asyncio.run(main())
21 changes: 14 additions & 7 deletions evals/init_stagehand.py
@@ -1,28 +1,33 @@
import os
import asyncio

from stagehand import Stagehand
from stagehand.config import StagehandConfig


async def init_stagehand(model_name: str, logger, dom_settle_timeout_ms: int = 3000):
"""
Initialize a Stagehand client with the given model name, logger, and DOM settle timeout.

This function creates a configuration from environment variables, initializes the Stagehand client,
and returns a tuple of (stagehand, init_response). The init_response contains debug and session URLs.

Args:
model_name (str): The name of the AI model to use.
logger: A logger instance for logging errors and debug messages.
dom_settle_timeout_ms (int): Milliseconds to wait for the DOM to settle.

Returns:
tuple: (stagehand, init_response) where init_response is a dict containing:
- "debugUrl": A dict with a "value" key for the debug URL.
- "sessionUrl": A dict with a "value" key for the session URL.
"""
# Build a Stagehand configuration object using environment variables
config = StagehandConfig(
env="BROWSERBASE" if os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID") else "LOCAL",
env=(
"BROWSERBASE"
if os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")
else "LOCAL"
),
api_key=os.getenv("BROWSERBASE_API_KEY"),
project_id=os.getenv("BROWSERBASE_PROJECT_ID"),
debug_dom=True,
Expand All @@ -33,7 +38,9 @@ async def init_stagehand(model_name: str, logger, dom_settle_timeout_ms: int = 3
)

# Create a Stagehand client with the configuration; server_url is taken from environment variables.
stagehand = Stagehand(config=config, server_url=os.getenv("STAGEHAND_SERVER_URL"), verbose=2)
stagehand = Stagehand(
config=config, server_url=os.getenv("STAGEHAND_SERVER_URL"), verbose=2
)
await stagehand.init()

# Construct the URL from the session id using the new format.
Expand All @@ -43,4 +50,4 @@ async def init_stagehand(model_name: str, logger, dom_settle_timeout_ms: int = 3
url = f"wss://connect.browserbase.com?apiKey={api_key}&sessionId={stagehand.session_id}"

# Return both URLs as dictionaries with the "value" key.
return stagehand, {"debugUrl": {"value": url}, "sessionUrl": {"value": url}}
return stagehand, {"debugUrl": {"value": url}, "sessionUrl": {"value": url}}