Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions stagehand/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
debug_dom: Optional[bool] = None,
httpx_client: Optional[httpx.AsyncClient] = None,
timeout_settings: Optional[httpx.Timeout] = None,
model_client_options: Optional[Dict[str, Any]] = None,
):
"""
Initialize the Stagehand client.
Expand All @@ -65,6 +66,7 @@ def __init__(
debug_dom (Optional[bool]): Whether to enable DOM debugging mode.
httpx_client (Optional[httpx.AsyncClient]): Optional custom httpx.AsyncClient instance.
timeout_settings (Optional[httpx.Timeout]): Optional custom timeout settings for httpx.
model_client_options (Optional[Dict[str, Any]]): Optional model client options.
"""
self.server_url = server_url or os.getenv("STAGEHAND_SERVER_URL")

Expand Down Expand Up @@ -92,6 +94,7 @@ def __init__(
# Additional config parameters available for future use:
self.headless = config.headless
self.enable_caching = config.enable_caching
self.model_client_options = model_client_options
else:
self.browserbase_api_key = browserbase_api_key or os.getenv(
"BROWSERBASE_API_KEY"
Expand All @@ -104,6 +107,7 @@ def __init__(
self.model_name = model_name
self.dom_settle_timeout_ms = dom_settle_timeout_ms
self.debug_dom = debug_dom
self.model_client_options = model_client_options

self.on_log = on_log
self.verbose = verbose
Expand Down Expand Up @@ -312,6 +316,9 @@ async def _create_session(self):
"verbose": self.verbose,
"debugDom": self.debug_dom,
}

if hasattr(self, "model_client_options") and self.model_client_options:
payload["modelClientOptions"] = self.model_client_options

headers = {
"x-bb-api-key": self.browserbase_api_key,
Expand Down Expand Up @@ -350,21 +357,25 @@ async def _execute(self, method: str, payload: Dict[str, Any]) -> Any:
}
if self.model_api_key:
headers["x-model-api-key"] = self.model_api_key


modified_payload = dict(payload)
if hasattr(self, "model_client_options") and self.model_client_options and "modelClientOptions" not in modified_payload:
modified_payload["modelClientOptions"] = self.model_client_options

client = self.httpx_client or httpx.AsyncClient(timeout=self.timeout_settings)
self._log(f"\n==== EXECUTING {method.upper()} ====", level=3)
self._log(
f"URL: {self.server_url}/sessions/{self.session_id}/{method}", level=3
)
self._log(f"Payload: {payload}", level=3)
self._log(f"Payload: {modified_payload}", level=3)
self._log(f"Headers: {headers}", level=3)

async with client:
try:
async with client.stream(
"POST",
f"{self.server_url}/sessions/{self.session_id}/{method}",
json=payload,
json=modified_payload,
headers=headers,
) as response:
if response.status_code != 200:
Expand Down
6 changes: 3 additions & 3 deletions stagehand/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class StagehandConfig(BaseModel):
enable_caching (Optional[bool]): Enable caching functionality.
browserbase_session_id (Optional[str]): Session ID for resuming Browserbase sessions.
model_name (Optional[str]): Name of the model to use.
selfHeal (Optional[bool]): Enable self-healing functionality.
self_heal (Optional[bool]): Enable self-healing functionality.
"""

env: str = "BROWSERBASE"
Expand Down Expand Up @@ -53,8 +53,8 @@ class StagehandConfig(BaseModel):
model_name: Optional[str] = Field(
AvailableModel.GPT_4O, alias="modelName", description="Name of the model to use"
)
selfHeal: Optional[bool] = Field(
True, description="Enable self-healing functionality"
self_heal: Optional[bool] = Field(
True, alias="selfHeal", description="Enable self-healing functionality"
)

class Config:
Expand Down
26 changes: 20 additions & 6 deletions stagehand/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ async def goto(
if timeout is not None:
options["timeout"] = timeout
if wait_until is not None:
options["wait_until"] = wait_until
options["waitUntil"] = wait_until

payload = {"url": url}
Expand All @@ -68,18 +69,31 @@ async def act(self, options: Union[str, ActOptions, ObserveResult]) -> ActResult
Execute an AI action via the Stagehand server.

Args:
options (Union[str, ActOptions]): Either a string with the action command or
a Pydantic model encapsulating the action.
See `stagehand.schemas.ActOptions` for details on expected fields.
options (Union[str, ActOptions, ObserveResult]):
- A string with the action command to be executed by the AI
- An ActOptions object encapsulating the action command and optional parameters
- An ObserveResult with selector and method fields for direct execution without LLM

When an ObserveResult with both 'selector' and 'method' fields is provided,
the SDK will directly execute the action against the selector using the method
and arguments provided, bypassing the LLM processing.

Returns:
Any: The result from the Stagehand server's action execution.
ActResult: The result from the Stagehand server's action execution.
"""
# Check if options is an ObserveResult with both selector and method
if isinstance(options, ObserveResult) and hasattr(options, "selector") and hasattr(options, "method"):
# For ObserveResult, we directly pass it to the server which will
# execute the method against the selector
payload = options.model_dump(exclude_none=True)
# Convert string to ActOptions if needed
if isinstance(options, str):
elif isinstance(options, str):
options = ActOptions(action=options)
payload = options.model_dump(exclude_none=True)
# Otherwise, it should be an ActOptions object
else:
payload = options.model_dump(exclude_none=True)

payload = options.model_dump(exclude_none=True)
lock = self._stagehand._get_lock_for_session()
async with lock:
result = await self._stagehand._execute("act", payload)
Expand Down
41 changes: 21 additions & 20 deletions stagehand/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ class ActOptions(BaseModel):
Attributes:
action (str): The action command to be executed by the AI.
variables: Optional[Dict[str, str]] = None
modelName: Optional[AvailableModel] = None
slowDomBasedAct: Optional[bool] = None
model_name: Optional[AvailableModel] = None
slow_dom_based_act: Optional[bool] = None
"""

action: str = Field(..., description="The action command to be executed by the AI.")
variables: Optional[Dict[str, str]] = None
modelName: Optional[AvailableModel] = None
slowDomBasedAct: Optional[bool] = None
model_name: Optional[AvailableModel] = Field(None, alias="modelName")
slow_dom_based_act: Optional[bool] = Field(None, alias="slowDomBasedAct")


class ActResult(BaseModel):
Expand All @@ -56,25 +56,26 @@ class ExtractOptions(BaseModel):

Attributes:
instruction (str): Instruction specifying what data to extract using AI.
modelName: Optional[AvailableModel] = None
model_name: Optional[AvailableModel] = None
selector: Optional[str] = None
schemaDefinition (Union[Dict[str, Any], Type[BaseModel]]): A JSON schema or Pydantic model that defines the structure of the expected data.
schema_definition (Union[Dict[str, Any], Type[BaseModel]]): A JSON schema or Pydantic model that defines the structure of the expected data.
Note: If passing a Pydantic model, invoke its .model_json_schema() method to ensure the schema is JSON serializable.
useTextExtract: Optional[bool] = None
use_text_extract: Optional[bool] = None
"""

instruction: str = Field(
..., description="Instruction specifying what data to extract using AI."
)
modelName: Optional[AvailableModel] = None
model_name: Optional[AvailableModel] = Field(None, alias="modelName")
selector: Optional[str] = None
# IMPORTANT: If using a Pydantic model for schemaDefinition, please call its .model_json_schema() method
# IMPORTANT: If using a Pydantic model for schema_definition, please call its .model_json_schema() method
# to convert it to a JSON serializable dictionary before sending it with the extract command.
schemaDefinition: Union[Dict[str, Any], Type[BaseModel]] = Field(
schema_definition: Union[Dict[str, Any], Type[BaseModel]] = Field(
default=DEFAULT_EXTRACT_SCHEMA,
description="A JSON schema or Pydantic model that defines the structure of the expected data.",
alias="schemaDefinition",
)
useTextExtract: Optional[bool] = True
use_text_extract: Optional[bool] = Field(True, alias="useTextExtract")

class Config:
arbitrary_types_allowed = True
Expand Down Expand Up @@ -108,19 +109,19 @@ class ObserveOptions(BaseModel):

Attributes:
instruction (str): Instruction detailing what the AI should observe.
modelName: Optional[AvailableModel] = None
onlyVisible: Optional[bool] = None
returnAction: Optional[bool] = None
drawOverlay: Optional[bool] = None
model_name: Optional[AvailableModel] = None
only_visible: Optional[bool] = None
return_action: Optional[bool] = None
draw_overlay: Optional[bool] = None
"""

instruction: str = Field(
..., description="Instruction detailing what the AI should observe."
)
onlyVisible: Optional[bool] = False
modelName: Optional[AvailableModel] = None
returnAction: Optional[bool] = None
drawOverlay: Optional[bool] = None
only_visible: Optional[bool] = Field(False, alias="onlyVisible")
model_name: Optional[AvailableModel] = Field(None, alias="modelName")
return_action: Optional[bool] = Field(None, alias="returnAction")
draw_overlay: Optional[bool] = Field(None, alias="drawOverlay")


class ObserveResult(BaseModel):
Expand All @@ -132,7 +133,7 @@ class ObserveResult(BaseModel):
description: str = Field(
..., description="The description of the observed element."
)
backendNodeId: Optional[int] = None
backend_node_id: Optional[int] = Field(None, alias="backendNodeId")
method: Optional[str] = None
arguments: Optional[List[str]] = None

Expand Down