Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
include README.md
include requirements.txt
global-exclude *.pyc
global-exclude __pycache__
global-exclude .DS_Store
global-exclude */node_modules/*
1 change: 1 addition & 0 deletions examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ async def main():

console.print("\n▶️ [highlight] Performing action:[/] search for openai")
await page.act("search for openai")
await page.keyboard.press("Enter")
console.print("✅ [success]Performing Action:[/] Action completed successfully")

console.print("\n▶️ [highlight] Observing page[/] for news button")
Expand Down
5 changes: 4 additions & 1 deletion stagehand/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .config import StagehandConfig
from .page import StagehandPage
from .utils import default_log_handler
from .utils import default_log_handler, convert_dict_keys_to_camel_case

load_dotenv()

Expand Down Expand Up @@ -362,6 +362,9 @@ async def _execute(self, method: str, payload: Dict[str, Any]) -> Any:
if hasattr(self, "model_client_options") and self.model_client_options and "modelClientOptions" not in modified_payload:
modified_payload["modelClientOptions"] = self.model_client_options

# Convert snake_case keys to camelCase for the API
modified_payload = convert_dict_keys_to_camel_case(modified_payload)

client = self.httpx_client or httpx.AsyncClient(timeout=self.timeout_settings)
self._log(f"\n==== EXECUTING {method.upper()} ====", level=3)
self._log(
Expand Down
10 changes: 5 additions & 5 deletions stagehand/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ async def act(self, options: Union[str, ActOptions, ObserveResult]) -> ActResult
if isinstance(options, ObserveResult) and hasattr(options, "selector") and hasattr(options, "method"):
# For ObserveResult, we directly pass it to the server which will
# execute the method against the selector
payload = options.model_dump(exclude_none=True)
payload = options.model_dump(exclude_none=True, by_alias=True)
# Convert string to ActOptions if needed
elif isinstance(options, str):
options = ActOptions(action=options)
payload = options.model_dump(exclude_none=True)
payload = options.model_dump(exclude_none=True, by_alias=True)
# Otherwise, it should be an ActOptions object
else:
payload = options.model_dump(exclude_none=True)
payload = options.model_dump(exclude_none=True, by_alias=True)

lock = self._stagehand._get_lock_for_session()
async with lock:
Expand All @@ -117,7 +117,7 @@ async def observe(self, options: Union[str, ObserveOptions]) -> List[ObserveResu
if isinstance(options, str):
options = ObserveOptions(instruction=options)

payload = options.model_dump(exclude_none=True)
payload = options.model_dump(exclude_none=True, by_alias=True)
lock = self._stagehand._get_lock_for_session()
async with lock:
result = await self._stagehand._execute("observe", payload)
Expand Down Expand Up @@ -148,7 +148,7 @@ async def extract(self, options: Union[str, ExtractOptions]) -> ExtractResult:
if isinstance(options, str):
options = ExtractOptions(instruction=options)

payload = options.model_dump(exclude_none=True)
payload = options.model_dump(exclude_none=True, by_alias=True)
lock = self._stagehand._get_lock_for_session()
async with lock:
result = await self._stagehand._execute("extract", payload)
Expand Down
51 changes: 34 additions & 17 deletions stagehand/schemas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_serializer

# Default extraction schema that matches the TypeScript version
DEFAULT_EXTRACT_SCHEMA = {
Expand All @@ -18,7 +18,18 @@ class AvailableModel(str, Enum):
CLAUDE_3_7_SONNET_LATEST = "claude-3-7-sonnet-latest"


class ActOptions(BaseModel):
def _field_name_to_camel(field_name: str) -> str:
    """Build the camelCase alias for a snake_case field name."""
    head, *tail = field_name.split('_')
    return head + ''.join(word.capitalize() for word in tail)


class StagehandBaseModel(BaseModel):
    """Common base for Stagehand models: every field gets a camelCase alias."""

    class Config:
        # Let models be populated with the Python (snake_case) field name
        # in addition to the generated alias.
        populate_by_name = True
        # snake_case field name -> camelCase alias
        alias_generator = _field_name_to_camel


class ActOptions(StagehandBaseModel):
"""
Options for the 'act' command.

Expand All @@ -31,11 +42,11 @@ class ActOptions(BaseModel):

action: str = Field(..., description="The action command to be executed by the AI.")
variables: Optional[Dict[str, str]] = None
model_name: Optional[AvailableModel] = Field(None, alias="modelName")
slow_dom_based_act: Optional[bool] = Field(None, alias="slowDomBasedAct")
model_name: Optional[AvailableModel] = None
slow_dom_based_act: Optional[bool] = None


class ActResult(BaseModel):
class ActResult(StagehandBaseModel):
"""
Result of the 'act' command.

Expand All @@ -50,7 +61,7 @@ class ActResult(BaseModel):
action: str = Field(..., description="The action command that was executed.")


class ExtractOptions(BaseModel):
class ExtractOptions(StagehandBaseModel):
"""
Options for the 'extract' command.

Expand All @@ -66,22 +77,28 @@ class ExtractOptions(BaseModel):
instruction: str = Field(
..., description="Instruction specifying what data to extract using AI."
)
model_name: Optional[AvailableModel] = Field(None, alias="modelName")
model_name: Optional[AvailableModel] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the model name really optional?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes there's a default in config

selector: Optional[str] = None
# NOTE: schema_definition accepts either a JSON schema dictionary or a Pydantic model class.
# A model class is converted with its .model_json_schema() method by the field serializer below when the options are dumped.
schema_definition: Union[Dict[str, Any], Type[BaseModel]] = Field(
default=DEFAULT_EXTRACT_SCHEMA,
description="A JSON schema or Pydantic model that defines the structure of the expected data.",
alias="schemaDefinition",
)
use_text_extract: Optional[bool] = Field(True, alias="useTextExtract")
use_text_extract: Optional[bool] = True

@field_serializer('schema_definition')
def serialize_schema_definition(self, schema_definition: Union[Dict[str, Any], Type[BaseModel]]) -> Dict[str, Any]:
    """Dump schema_definition as a JSON schema; anything that is not a BaseModel subclass passes through unchanged."""
    is_model_class = isinstance(schema_definition, type) and issubclass(schema_definition, BaseModel)
    return schema_definition.model_json_schema() if is_model_class else schema_definition
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you have a test for it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

smoke tested locally w/ example.py:

class News(BaseModel):
    description: str
    url: str

...

  data = await page.extract(ExtractOptions(
      instruction="extract the first result from the search",
      schema_definition=News,
  ))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome, confirming it works for me as well!


class Config:
arbitrary_types_allowed = True


class ExtractResult(BaseModel):
class ExtractResult(StagehandBaseModel):
"""
Result of the 'extract' command.

Expand All @@ -103,7 +120,7 @@ def __getitem__(self, key):
return getattr(self, key)


class ObserveOptions(BaseModel):
class ObserveOptions(StagehandBaseModel):
"""
Options for the 'observe' command.

Expand All @@ -118,13 +135,13 @@ class ObserveOptions(BaseModel):
instruction: str = Field(
..., description="Instruction detailing what the AI should observe."
)
only_visible: Optional[bool] = Field(False, alias="onlyVisible")
model_name: Optional[AvailableModel] = Field(None, alias="modelName")
return_action: Optional[bool] = Field(None, alias="returnAction")
draw_overlay: Optional[bool] = Field(None, alias="drawOverlay")
only_visible: Optional[bool] = False
model_name: Optional[AvailableModel] = None
return_action: Optional[bool] = None
draw_overlay: Optional[bool] = None


class ObserveResult(BaseModel):
class ObserveResult(StagehandBaseModel):
"""
Result of the 'observe' command.
"""
Expand All @@ -133,7 +150,7 @@ class ObserveResult(BaseModel):
description: str = Field(
..., description="The description of the observed element."
)
backend_node_id: Optional[int] = Field(None, alias="backendNodeId")
backend_node_id: Optional[int] = None
method: Optional[str] = None
arguments: Optional[List[str]] = None

Expand Down
64 changes: 49 additions & 15 deletions stagehand/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,57 @@
import asyncio
import logging
from typing import Any, Dict

# Module logger: records go through a stream handler with a
# "time - logger name - level - message" layout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(_LOG_FORMAT))
logger.addHandler(handler)

# Level names accepted in a log payload, mapped to stdlib logging levels.
_LOG_LEVELS = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "warn": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
    "fatal": logging.CRITICAL,
}


async def default_log_handler(log_data: Dict[str, Any]) -> None:
    """Default handler for log messages from the Stagehand server.

    Args:
        log_data: Log payload. The optional "level" entry selects the logging
            level (default "info") and the optional "message" entry is the
            text to log (default "").

    The level is looked up in a fixed table rather than resolved with
    ``getattr`` on the logger, so a level such as "name" or "handlers" can no
    longer resolve to a non-callable logger attribute. Unknown or non-string
    levels fall back to INFO.
    """
    # str() guards against non-string levels (e.g. integers), which have no .lower().
    level = str(log_data.get("level", "info")).lower()
    message = log_data.get("message", "")

    logger.log(_LOG_LEVELS.get(level, logging.INFO), message)


def snake_to_camel(snake_str: str) -> str:
    """
    Convert a snake_case string to camelCase.

    Only the first character of each component after the first is
    upper-cased; the remainder of every component is left untouched. This
    keeps already-capitalised segments intact ("model_clientOptions" ->
    "modelClientOptions") and avoids capitalising letters that follow digits
    ("step_2nd" -> "step2nd"), both of which ``str.title()`` gets wrong.

    Args:
        snake_str: The snake_case string to convert

    Returns:
        The converted camelCase string
    """
    first, *rest = snake_str.split('_')
    # part[:1] is safe for the empty components produced by consecutive underscores.
    return first + ''.join(part[:1].upper() + part[1:] for part in rest)


if log_type == "system":
logger.info(f"🔧 SYSTEM: {data}")
elif log_type == "log":
logger.info(f"📝 LOG: {data}")
else:
logger.info(f"ℹ️ OTHER [{log_type}]: {data}")
else:
# Fallback for any other format
logger.info(f"🤖 RAW LOG: {log_data}")
def convert_dict_keys_to_camel_case(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert all keys in a dictionary from snake_case to camelCase.
    Works recursively for nested dictionaries and for dictionaries held in
    (arbitrarily nested) lists.

    Keys that are not strings are kept as they are instead of raising.
    If two keys map to the same camelCase name, the later one wins.

    NOTE(review): every nested key is converted, including user-supplied
    names such as property names inside a JSON schema -- confirm that is
    intended for payloads that carry a schema.

    Args:
        data: Dictionary with snake_case keys

    Returns:
        New dictionary with camelCase keys; the input is not modified
    """

    def _convert_value(value: Any) -> Any:
        # Recurse into containers; leave every other value untouched.
        if isinstance(value, dict):
            return convert_dict_keys_to_camel_case(value)
        if isinstance(value, list):
            return [_convert_value(item) for item in value]
        return value

    return {
        (snake_to_camel(key) if isinstance(key, str) else key): _convert_value(value)
        for key, value in data.items()
    }