From 436668ddc1820998576bdcf4fc7fbe477c4f9578 Mon Sep 17 00:00:00 2001 From: miguel Date: Sun, 2 Mar 2025 18:28:20 -0800 Subject: [PATCH] fixed snake_case migration --- MANIFEST.in | 6 +++++ examples/example.py | 1 + stagehand/client.py | 5 +++- stagehand/page.py | 10 +++---- stagehand/schemas.py | 51 +++++++++++++++++++++++------------ stagehand/utils.py | 64 +++++++++++++++++++++++++++++++++----------- 6 files changed, 99 insertions(+), 38 deletions(-) create mode 100644 MANIFEST.in diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..7240e718 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,6 @@ +include README.md +include requirements.txt +global-exclude *.pyc +global-exclude __pycache__ +global-exclude .DS_Store +global-exclude */node_modules/* \ No newline at end of file diff --git a/examples/example.py b/examples/example.py index 0cb9f051..a7ef2271 100644 --- a/examples/example.py +++ b/examples/example.py @@ -79,6 +79,7 @@ async def main(): console.print("\n▶️ [highlight] Performing action:[/] search for openai") await page.act("search for openai") + await page.keyboard.press("Enter") console.print("✅ [success]Performing Action:[/] Action completed successfully") console.print("\n▶️ [highlight] Observing page[/] for news button") diff --git a/stagehand/client.py b/stagehand/client.py index 40624f60..8474b869 100644 --- a/stagehand/client.py +++ b/stagehand/client.py @@ -12,7 +12,7 @@ from .config import StagehandConfig from .page import StagehandPage -from .utils import default_log_handler +from .utils import default_log_handler, convert_dict_keys_to_camel_case load_dotenv() @@ -362,6 +362,9 @@ async def _execute(self, method: str, payload: Dict[str, Any]) -> Any: if hasattr(self, "model_client_options") and self.model_client_options and "modelClientOptions" not in modified_payload: modified_payload["modelClientOptions"] = self.model_client_options + # Convert snake_case keys to camelCase for the API + modified_payload = convert_dict_keys_to_camel_case(modified_payload) + client = self.httpx_client or httpx.AsyncClient(timeout=self.timeout_settings) self._log(f"\n==== EXECUTING {method.upper()} ====", level=3) self._log( diff --git a/stagehand/page.py b/stagehand/page.py index d14ae605..aeb878be 100644 --- a/stagehand/page.py +++ b/stagehand/page.py @@ -85,14 +85,14 @@ async def act(self, options: Union[str, ActOptions, ObserveResult]) -> ActResult if isinstance(options, ObserveResult) and hasattr(options, "selector") and hasattr(options, "method"): # For ObserveResult, we directly pass it to the server which will # execute the method against the selector - payload = options.model_dump(exclude_none=True) + payload = options.model_dump(exclude_none=True, by_alias=True) # Convert string to ActOptions if needed elif isinstance(options, str): options = ActOptions(action=options) - payload = options.model_dump(exclude_none=True) + payload = options.model_dump(exclude_none=True, by_alias=True) # Otherwise, it should be an ActOptions object else: - payload = options.model_dump(exclude_none=True) + payload = options.model_dump(exclude_none=True, by_alias=True) lock = self._stagehand._get_lock_for_session() async with lock: @@ -117,7 +117,7 @@ async def observe(self, options: Union[str, ObserveOptions]) -> List[ObserveResu if isinstance(options, str): options = ObserveOptions(instruction=options) - payload = options.model_dump(exclude_none=True) + payload = options.model_dump(exclude_none=True, by_alias=True) lock = self._stagehand._get_lock_for_session() async with lock: result = await self._stagehand._execute("observe", payload) @@ -148,7 +148,7 @@ async def extract(self, options: Union[str, ExtractOptions]) -> ExtractResult: if isinstance(options, str): options = ExtractOptions(instruction=options) - payload = options.model_dump(exclude_none=True) + payload = options.model_dump(exclude_none=True, by_alias=True) lock = self._stagehand._get_lock_for_session() async with lock: result = await self._stagehand._execute("extract", payload) diff --git a/stagehand/schemas.py b/stagehand/schemas.py index d4962f46..fd40dfe5 100644 --- a/stagehand/schemas.py +++ b/stagehand/schemas.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Type, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_serializer # Default extraction schema that matches the TypeScript version DEFAULT_EXTRACT_SCHEMA = { @@ -18,7 +18,18 @@ class AvailableModel(str, Enum): CLAUDE_3_7_SONNET_LATEST = "claude-3-7-sonnet-latest" -class ActOptions(BaseModel): +class StagehandBaseModel(BaseModel): + """Base model for all Stagehand models with camelCase conversion support""" + + class Config: + populate_by_name = True # Allow accessing fields by their Python name + alias_generator = lambda field_name: ''.join( + [field_name.split('_')[0]] + + [word.capitalize() for word in field_name.split('_')[1:]] + ) # snake_case to camelCase + + +class ActOptions(StagehandBaseModel): """ Options for the 'act' command. @@ -31,11 +42,11 @@ class ActOptions(BaseModel): action: str = Field(..., description="The action command to be executed by the AI.") variables: Optional[Dict[str, str]] = None - model_name: Optional[AvailableModel] = Field(None, alias="modelName") - slow_dom_based_act: Optional[bool] = Field(None, alias="slowDomBasedAct") + model_name: Optional[AvailableModel] = None + slow_dom_based_act: Optional[bool] = None -class ActResult(BaseModel): +class ActResult(StagehandBaseModel): """ Result of the 'act' command. @@ -50,7 +61,7 @@ class ActResult(BaseModel): action: str = Field(..., description="The action command that was executed.") -class ExtractOptions(BaseModel): +class ExtractOptions(StagehandBaseModel): """ Options for the 'extract' command. @@ -66,22 +77,28 @@ class ExtractOptions(BaseModel): instruction: str = Field( ..., description="Instruction specifying what data to extract using AI." ) - model_name: Optional[AvailableModel] = Field(None, alias="modelName") + model_name: Optional[AvailableModel] = None selector: Optional[str] = None # IMPORTANT: If using a Pydantic model for schema_definition, please call its .model_json_schema() method # to convert it to a JSON serializable dictionary before sending it with the extract command. schema_definition: Union[Dict[str, Any], Type[BaseModel]] = Field( default=DEFAULT_EXTRACT_SCHEMA, description="A JSON schema or Pydantic model that defines the structure of the expected data.", - alias="schemaDefinition", ) - use_text_extract: Optional[bool] = Field(True, alias="useTextExtract") + use_text_extract: Optional[bool] = True + + @field_serializer('schema_definition') + def serialize_schema_definition(self, schema_definition: Union[Dict[str, Any], Type[BaseModel]]) -> Dict[str, Any]: + """Serialize schema_definition to a JSON schema if it's a Pydantic model""" + if isinstance(schema_definition, type) and issubclass(schema_definition, BaseModel): + return schema_definition.model_json_schema() + return schema_definition class Config: arbitrary_types_allowed = True -class ExtractResult(BaseModel): +class ExtractResult(StagehandBaseModel): """ Result of the 'extract' command. @@ -103,7 +120,7 @@ def __getitem__(self, key): return getattr(self, key) -class ObserveOptions(BaseModel): +class ObserveOptions(StagehandBaseModel): """ Options for the 'observe' command. @@ -118,13 +135,13 @@ class ObserveOptions(BaseModel): instruction: str = Field( ..., description="Instruction detailing what the AI should observe." ) - only_visible: Optional[bool] = Field(False, alias="onlyVisible") - model_name: Optional[AvailableModel] = Field(None, alias="modelName") - return_action: Optional[bool] = Field(None, alias="returnAction") - draw_overlay: Optional[bool] = Field(None, alias="drawOverlay") + only_visible: Optional[bool] = False + model_name: Optional[AvailableModel] = None + return_action: Optional[bool] = None + draw_overlay: Optional[bool] = None -class ObserveResult(BaseModel): +class ObserveResult(StagehandBaseModel): """ Result of the 'observe' command. """ @@ -133,7 +150,7 @@ class ObserveResult(BaseModel): description: str = Field( ..., description="The description of the observed element." ) - backend_node_id: Optional[int] = Field(None, alias="backendNodeId") + backend_node_id: Optional[int] = None method: Optional[str] = None arguments: Optional[List[str]] = None diff --git a/stagehand/utils.py b/stagehand/utils.py index 6301247d..459f4c53 100644 --- a/stagehand/utils.py +++ b/stagehand/utils.py @@ -1,23 +1,57 @@ +import asyncio import logging +from typing import Any, Dict +# Setup logging logger = logging.getLogger(__name__) +handler = logging.StreamHandler() +handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) +logger.addHandler(handler) +async def default_log_handler(log_data: Dict[str, Any]) -> None: + """Default handler for log messages from the Stagehand server.""" + level = log_data.get("level", "info").lower() + message = log_data.get("message", "") -async def default_log_handler(log_data: dict): + log_method = getattr(logger, level, logger.info) + log_method(message) + + +def snake_to_camel(snake_str: str) -> str: """ - Default async log handler that shows detailed server logs. - Can be overridden by passing a custom handler to Stagehand's constructor. + Convert a snake_case string to camelCase. + + Args: + snake_str: The snake_case string to convert + + Returns: + The converted camelCase string """ - if "type" in log_data: - log_type = log_data["type"] - data = log_data.get("data", {}) + components = snake_str.split('_') + return components[0] + ''.join(x.title() for x in components[1:]) + - if log_type == "system": - logger.info(f"🔧 SYSTEM: {data}") - elif log_type == "log": - logger.info(f"📝 LOG: {data}") - else: - logger.info(f"ℹ️ OTHER [{log_type}]: {data}") - else: - # Fallback for any other format - logger.info(f"🤖 RAW LOG: {log_data}") +def convert_dict_keys_to_camel_case(data: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert all keys in a dictionary from snake_case to camelCase. + Works recursively for nested dictionaries. + + Args: + data: Dictionary with snake_case keys + + Returns: + Dictionary with camelCase keys + """ + result = {} + + for key, value in data.items(): + if isinstance(value, dict): + value = convert_dict_keys_to_camel_case(value) + elif isinstance(value, list): + value = [convert_dict_keys_to_camel_case(item) if isinstance(item, dict) else item for item in value] + + # Convert snake_case key to camelCase + camel_key = snake_to_camel(key) + result[camel_key] = value + + return result