-
Notifications
You must be signed in to change notification settings - Fork 69
fixed snake_case migration #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
include README.md | ||
include requirements.txt | ||
global-exclude *.pyc | ||
global-exclude __pycache__ | ||
global-exclude .DS_Store | ||
global-exclude */node_modules/* |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
from enum import Enum | ||
from typing import Any, Dict, List, Optional, Type, Union | ||
|
||
from pydantic import BaseModel, Field | ||
from pydantic import BaseModel, Field, field_serializer | ||
|
||
# Default extraction schema that matches the TypeScript version | ||
DEFAULT_EXTRACT_SCHEMA = { | ||
|
@@ -18,7 +18,18 @@ class AvailableModel(str, Enum): | |
CLAUDE_3_7_SONNET_LATEST = "claude-3-7-sonnet-latest" | ||
|
||
|
||
class ActOptions(BaseModel): | ||
class StagehandBaseModel(BaseModel): | ||
"""Base model for all Stagehand models with camelCase conversion support""" | ||
|
||
class Config: | ||
populate_by_name = True # Allow accessing fields by their Python name | ||
alias_generator = lambda field_name: ''.join( | ||
[field_name.split('_')[0]] + | ||
[word.capitalize() for word in field_name.split('_')[1:]] | ||
) # snake_case to camelCase | ||
|
||
|
||
class ActOptions(StagehandBaseModel): | ||
""" | ||
Options for the 'act' command. | ||
|
||
|
@@ -31,11 +42,11 @@ class ActOptions(BaseModel): | |
|
||
action: str = Field(..., description="The action command to be executed by the AI.") | ||
variables: Optional[Dict[str, str]] = None | ||
model_name: Optional[AvailableModel] = Field(None, alias="modelName") | ||
slow_dom_based_act: Optional[bool] = Field(None, alias="slowDomBasedAct") | ||
model_name: Optional[AvailableModel] = None | ||
slow_dom_based_act: Optional[bool] = None | ||
|
||
|
||
class ActResult(BaseModel): | ||
class ActResult(StagehandBaseModel): | ||
""" | ||
Result of the 'act' command. | ||
|
||
|
@@ -50,7 +61,7 @@ class ActResult(BaseModel): | |
action: str = Field(..., description="The action command that was executed.") | ||
|
||
|
||
class ExtractOptions(BaseModel): | ||
class ExtractOptions(StagehandBaseModel): | ||
""" | ||
Options for the 'extract' command. | ||
|
||
|
@@ -66,22 +77,28 @@ class ExtractOptions(BaseModel): | |
instruction: str = Field( | ||
..., description="Instruction specifying what data to extract using AI." | ||
) | ||
model_name: Optional[AvailableModel] = Field(None, alias="modelName") | ||
model_name: Optional[AvailableModel] = None | ||
selector: Optional[str] = None | ||
# IMPORTANT: If using a Pydantic model for schema_definition, please call its .model_json_schema() method | ||
# to convert it to a JSON serializable dictionary before sending it with the extract command. | ||
schema_definition: Union[Dict[str, Any], Type[BaseModel]] = Field( | ||
default=DEFAULT_EXTRACT_SCHEMA, | ||
description="A JSON schema or Pydantic model that defines the structure of the expected data.", | ||
alias="schemaDefinition", | ||
) | ||
use_text_extract: Optional[bool] = Field(True, alias="useTextExtract") | ||
use_text_extract: Optional[bool] = True | ||
|
||
@field_serializer('schema_definition') | ||
def serialize_schema_definition(self, schema_definition: Union[Dict[str, Any], Type[BaseModel]]) -> Dict[str, Any]: | ||
"""Serialize schema_definition to a JSON schema if it's a Pydantic model""" | ||
if isinstance(schema_definition, type) and issubclass(schema_definition, BaseModel): | ||
return schema_definition.model_json_schema() | ||
return schema_definition | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you have a test for it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. smoke tested locally w/ example.py: class News(BaseModel):
description: str
url: str
...
data = await page.extract(ExtractOptions(
instruction="extract the first result from the search",
schema_definition=News,
)) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. awesome, confirming it works for me as well! |
||
|
||
class Config: | ||
arbitrary_types_allowed = True | ||
|
||
|
||
class ExtractResult(BaseModel): | ||
class ExtractResult(StagehandBaseModel): | ||
""" | ||
Result of the 'extract' command. | ||
|
||
|
@@ -103,7 +120,7 @@ def __getitem__(self, key): | |
return getattr(self, key) | ||
|
||
|
||
class ObserveOptions(BaseModel): | ||
class ObserveOptions(StagehandBaseModel): | ||
""" | ||
Options for the 'observe' command. | ||
|
||
|
@@ -118,13 +135,13 @@ class ObserveOptions(BaseModel): | |
instruction: str = Field( | ||
..., description="Instruction detailing what the AI should observe." | ||
) | ||
only_visible: Optional[bool] = Field(False, alias="onlyVisible") | ||
model_name: Optional[AvailableModel] = Field(None, alias="modelName") | ||
return_action: Optional[bool] = Field(None, alias="returnAction") | ||
draw_overlay: Optional[bool] = Field(None, alias="drawOverlay") | ||
only_visible: Optional[bool] = False | ||
model_name: Optional[AvailableModel] = None | ||
return_action: Optional[bool] = None | ||
draw_overlay: Optional[bool] = None | ||
|
||
|
||
class ObserveResult(BaseModel): | ||
class ObserveResult(StagehandBaseModel): | ||
""" | ||
Result of the 'observe' command. | ||
""" | ||
|
@@ -133,7 +150,7 @@ class ObserveResult(BaseModel): | |
description: str = Field( | ||
..., description="The description of the observed element." | ||
) | ||
backend_node_id: Optional[int] = Field(None, alias="backendNodeId") | ||
backend_node_id: Optional[int] = None | ||
method: Optional[str] = None | ||
arguments: Optional[List[str]] = None | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,57 @@ | ||
import asyncio | ||
import logging | ||
from typing import Any, Dict | ||
|
||
# Setup logging | ||
logger = logging.getLogger(__name__) | ||
handler = logging.StreamHandler() | ||
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) | ||
logger.addHandler(handler) | ||
|
||
async def default_log_handler(log_data: Dict[str, Any]) -> None: | ||
"""Default handler for log messages from the Stagehand server.""" | ||
level = log_data.get("level", "info").lower() | ||
message = log_data.get("message", "") | ||
|
||
async def default_log_handler(log_data: dict): | ||
log_method = getattr(logger, level, logger.info) | ||
log_method(message) | ||
|
||
|
||
def snake_to_camel(snake_str: str) -> str: | ||
""" | ||
Default async log handler that shows detailed server logs. | ||
Can be overridden by passing a custom handler to Stagehand's constructor. | ||
Convert a snake_case string to camelCase. | ||
|
||
Args: | ||
snake_str: The snake_case string to convert | ||
|
||
Returns: | ||
The converted camelCase string | ||
""" | ||
if "type" in log_data: | ||
log_type = log_data["type"] | ||
data = log_data.get("data", {}) | ||
components = snake_str.split('_') | ||
return components[0] + ''.join(x.title() for x in components[1:]) | ||
|
||
|
||
if log_type == "system": | ||
logger.info(f"🔧 SYSTEM: {data}") | ||
elif log_type == "log": | ||
logger.info(f"📝 LOG: {data}") | ||
else: | ||
logger.info(f"ℹ️ OTHER [{log_type}]: {data}") | ||
else: | ||
# Fallback for any other format | ||
logger.info(f"🤖 RAW LOG: {log_data}") | ||
def convert_dict_keys_to_camel_case(data: Dict[str, Any]) -> Dict[str, Any]: | ||
""" | ||
Convert all keys in a dictionary from snake_case to camelCase. | ||
Works recursively for nested dictionaries. | ||
|
||
Args: | ||
data: Dictionary with snake_case keys | ||
|
||
Returns: | ||
Dictionary with camelCase keys | ||
""" | ||
result = {} | ||
|
||
for key, value in data.items(): | ||
if isinstance(value, dict): | ||
value = convert_dict_keys_to_camel_case(value) | ||
elif isinstance(value, list): | ||
value = [convert_dict_keys_to_camel_case(item) if isinstance(item, dict) else item for item in value] | ||
|
||
# Convert snake_case key to camelCase | ||
camel_key = snake_to_camel(key) | ||
result[camel_key] = value | ||
|
||
return result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is the model name really optional?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes there's a default in config