From 2aa6e58ba1c8fa84e82d7860f335cc17ae5711e5 Mon Sep 17 00:00:00 2001 From: miguel Date: Wed, 26 Feb 2025 18:38:52 -0800 Subject: [PATCH 1/2] CI updates for linting, testing & publishing package --- .bumpversion.cfg | 12 +++++ .github/README_PUBLISHING.md | 81 ++++++++++++++++++++++++++++++++++ .github/workflows/publish.yml | 82 +++++++++++++++++++++++++++++++++++ examples/example.py | 3 ++ stagehand/__init__.py | 2 +- 5 files changed, 179 insertions(+), 1 deletion(-) create mode 100644 .bumpversion.cfg create mode 100644 .github/README_PUBLISHING.md create mode 100644 .github/workflows/publish.yml diff --git a/.bumpversion.cfg b/.bumpversion.cfg new file mode 100644 index 00000000..bf8366a6 --- /dev/null +++ b/.bumpversion.cfg @@ -0,0 +1,12 @@ +[bumpversion] +current_version = 0.3.0 +commit = True +tag = True + +[bumpversion:file:setup.py] +search = version="{current_version}" +replace = version="{new_version}" + +[bumpversion:file:stagehand/__init__.py] +search = __version__ = "{current_version}" +replace = __version__ = "{new_version}" \ No newline at end of file diff --git a/.github/README_PUBLISHING.md b/.github/README_PUBLISHING.md new file mode 100644 index 00000000..2b4ec69f --- /dev/null +++ b/.github/README_PUBLISHING.md @@ -0,0 +1,81 @@ +# Publishing stagehand-python to PyPI + +This repository is configured with a GitHub Actions workflow to automate the process of publishing new versions to PyPI. + +## Prerequisites + +Before using the publishing workflow, ensure you have: + +1. Set up the following secrets in your GitHub repository settings: + - `PYPI_USERNAME`: Your PyPI username + - `PYPI_API_TOKEN`: Your PyPI API token (not your password) + +## How to Publish a New Version + +### Manual Trigger + +1. Go to the "Actions" tab in your GitHub repository +2. Select the "Publish to PyPI" workflow from the list +3. Click "Run workflow" on the right side +4. Configure the workflow: + - Choose the release type: + - `patch` (e.g., 0.3.0 → 0.3.1) for bug fixes + - `minor` (e.g., 0.3.0 → 0.4.0) for backward-compatible features + - `major` (e.g., 0.3.0 → 1.0.0) for breaking changes + - Toggle "Create GitHub Release" if you want to create a GitHub release +5. Click "Run workflow" to start the process + +### What Happens During Publishing + +The workflow will: + +1. Checkout the repository +2. Set up Python environment +3. Install dependencies +4. **Run Ruff linting checks**: + - Checks for code style and quality issues + - Verifies formatting according to project standards + - Fails the workflow if issues are found +5. Run tests to ensure everything works +6. Update the version number using bumpversion +7. Build the package +8. Upload to PyPI +9. Push the version bump commit and tag +10. Create a GitHub release (if selected) + +## Code Quality Standards + +This project uses Ruff for linting and formatting. The workflow enforces these standards before publishing: + +- Style checks following configured rules in `pyproject.toml` +- Format verification without making changes +- All linting issues must be fixed before a successful publish + +To run the same checks locally: +```bash +# Install Ruff +pip install ruff + +# Run linting +ruff check . + +# Check formatting +ruff format --check . + +# Auto-fix issues where possible +ruff check --fix . +ruff format . + +# Use Black to format the code +black . +``` + +## Troubleshooting + +If the workflow fails, check the following: + +1. **Linting errors**: Fix any issues reported by Ruff +2. Ensure all secrets are properly set +3. Verify that tests pass locally +4. Check if you have proper permissions on the repository +5. Make sure you have a PyPI account with publishing permissions \ No newline at end of file diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 00000000..ae0a9c12 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,82 @@ +name: Publish to PyPI + +on: + workflow_dispatch: + inputs: + release_type: + description: 'Release type (patch, minor, major)' + required: true + default: 'patch' + type: choice + options: + - patch + - minor + - major + create_release: + description: 'Create GitHub Release' + required: true + default: true + type: boolean + +jobs: + build-and-publish: + runs-on: ubuntu-latest + steps: + - name: Check out repository + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build twine wheel setuptools bumpversion ruff + pip install -r requirements.txt + if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi + + - name: Run Ruff linting + run: | + # Run Ruff linter + ruff check . + + # Run Ruff formatter check (without modifying files) + ruff format --check . + + - name: Run tests + run: | + pytest + + - name: Update version + run: | + git config --local user.email "action@github.com" + git config --local user.name "GitHub Action" + bumpversion ${{ github.event.inputs.release_type }} + + - name: Build package + run: | + python -m build + + - name: Upload to PyPI + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + run: | + twine upload dist/* + + - name: Push version bump + run: | + git push + git push --tags + + - name: Create GitHub Release + if: ${{ github.event.inputs.create_release == 'true' }} + uses: softprops/action-gh-release@v1 + with: + tag_name: v$(python setup.py --version) + name: Release v$(python setup.py --version) + generate_release_notes: true \ No newline at end of file diff --git a/examples/example.py b/examples/example.py index 0cb9f051..e8ce17f5 100644 --- a/examples/example.py +++ b/examples/example.py @@ -86,7 +86,10 @@ async def main(): if len(observed) > 0: element = observed[0] console.print("✅ [success]Found element:[/] News button") + console.print(f"\n▶️ [highlight] Performing action on observed element") await page.act(element) + console.print("✅ [success]Performing Action:[/] Action completed successfully") + else: console.print("❌ [error]No element found[/]") diff --git a/stagehand/__init__.py b/stagehand/__init__.py index 63b03284..59db66ce 100644 --- a/stagehand/__init__.py +++ b/stagehand/__init__.py @@ -1,4 +1,4 @@ from .client import Stagehand -__version__ = "0.1.0" +__version__ = "0.3.0" __all__ = ["Stagehand"] From 28a51bd433f9f4dc08c4850f93f9e2e90d5763d8 Mon Sep 17 00:00:00 2001 From: miguel Date: Mon, 10 Mar 2025 12:07:10 -0700 Subject: [PATCH 2/2] fixed most linting errors --- evals/act/google_jobs.py | 28 ++++---- evals/extract/extract_press_releases.py | 8 ++- evals/init_stagehand.py | 11 +-- examples/example.py | 2 +- examples/example_sync.py | 5 +- format.sh | 21 ++++++ pyproject.toml | 41 ++++++++++- requirements-dev.txt | 3 +- stagehand/__init__.py | 2 - stagehand/base.py | 56 ++++++++++----- stagehand/client.py | 42 +++++++----- stagehand/page.py | 18 +++-- stagehand/schemas.py | 30 +++++---- stagehand/sync/__init__.py | 2 +- stagehand/sync/client.py | 45 ++++++++----- stagehand/sync/page.py | 90 ++++++++++++++----------- stagehand/utils.py | 39 ++++++----- tests/functional/test_sync_client.py | 10 +-- 18 files changed, 294 insertions(+), 159 deletions(-) create mode 100755 format.sh diff --git a/evals/act/google_jobs.py b/evals/act/google_jobs.py index 5b78122d..f79b2bd1 100644 --- a/evals/act/google_jobs.py +++ b/evals/act/google_jobs.py @@ -1,6 +1,6 @@ import asyncio import traceback -from typing import Any, Dict, Optional +from typing import Any, Optional, dict from pydantic import BaseModel @@ -10,23 +10,24 @@ class Qualifications(BaseModel): degree: Optional[str] = None - yearsOfExperience: Optional[float] = None # Representing the number + years_of_experience: Optional[float] = None # Representing the number class JobDetails(BaseModel): - applicationDeadline: Optional[str] = None - minimumQualifications: Qualifications - preferredQualifications: Qualifications + application_deadline: Optional[str] = None + minimum_qualifications: Qualifications + preferred_qualifications: Qualifications -def is_job_details_valid(details: Dict[str, Any]) -> bool: +def is_job_details_valid(details: dict[str, Any]) -> bool: """ Validates that each top-level field in the extracted job details is not None. - For nested dictionary values, each sub-value must be non-null and a string or a number. + For nested dictionary values, each sub-value must be non-null and a string + or a number. """ if not details: return False - for key, value in details.items(): + for _key, value in details.items(): if value is None: return False if isinstance(value, dict): @@ -53,9 +54,9 @@ async def google_jobs(model_name: str, logger, use_text_extract: bool) -> dict: 4. Extracting job posting details using an AI-driven extraction schema. The extraction schema requires: - - applicationDeadline: The opening date until which applications are accepted. - - minimumQualifications: An object with degree and yearsOfExperience. - - preferredQualifications: An object with degree and yearsOfExperience. + - application_deadline: The opening date until which applications are accepted. + - minimum_qualifications: An object with degree and years_of_experience. + - preferred_qualifications: An object with degree and years_of_experience. Returns a dictionary containing: - _success (bool): Whether valid job details were extracted. @@ -90,8 +91,9 @@ async def google_jobs(model_name: str, logger, use_text_extract: bool) -> dict: job_details = await stagehand.page.extract( ExtractOptions( instruction=( - "Extract the following details from the job posting: application deadline, " - "minimum qualifications (degree and years of experience), and preferred qualifications " + "Extract the following details from the job posting: " + "application deadline, minimum qualifications " + "(degree and years of experience), and preferred qualifications " "(degree and years of experience)" ), schemaDefinition=JobDetails.model_json_schema(), diff --git a/evals/extract/extract_press_releases.py b/evals/extract/extract_press_releases.py index 800fed5f..89d98181 100644 --- a/evals/extract/extract_press_releases.py +++ b/evals/extract/extract_press_releases.py @@ -19,7 +19,8 @@ class PressReleases(BaseModel): async def extract_press_releases(model_name: str, logger, use_text_extract: bool): """ - Extract press releases from the dummy press releases page using the Stagehand client. + Extract press releases from the dummy press releases page using the Stagehand + client. Args: model_name (str): Name of the AI model to use. @@ -56,7 +57,10 @@ async def extract_press_releases(model_name: str, logger, use_text_extract: bool # TODO - FAILING - extract is likely timing out raw_result = await stagehand.page.extract( ExtractOptions( - instruction="extract the title and corresponding publish date of EACH AND EVERY press releases on this page. DO NOT MISS ANY PRESS RELEASES.", + instruction=( + "extract the title and corresponding publish date of EACH AND EVERY " + "press releases on this page. DO NOT MISS ANY PRESS RELEASES." + ), schemaDefinition=PressReleases.model_json_schema(), useTextExtract=use_text_extract, ) diff --git a/evals/init_stagehand.py b/evals/init_stagehand.py index ed93966f..ac710983 100644 --- a/evals/init_stagehand.py +++ b/evals/init_stagehand.py @@ -6,10 +6,12 @@ async def init_stagehand(model_name: str, logger, dom_settle_timeout_ms: int = 3000): """ - Initialize a Stagehand client with the given model name, logger, and DOM settle timeout. + Initialize a Stagehand client with the given model name, logger, and DOM settle + timeout. - This function creates a configuration from environment variables, initializes the Stagehand client, - and returns a tuple of (stagehand, init_response). The init_response contains debug and session URLs. + This function creates a configuration from environment variables, initializes + the Stagehand client, and returns a tuple of (stagehand, init_response). + The init_response contains debug and session URLs. Args: model_name (str): The name of the AI model to use. @@ -37,7 +39,8 @@ async def init_stagehand(model_name: str, logger, dom_settle_timeout_ms: int = 3 model_client_options={"apiKey": os.getenv("MODEL_API_KEY")}, ) - # Create a Stagehand client with the configuration; server_url is taken from environment variables. + # Create a Stagehand client with the configuration; server_url is taken from + # environment variables. stagehand = Stagehand( config=config, server_url=os.getenv("STAGEHAND_SERVER_URL"), verbose=2 ) diff --git a/examples/example.py b/examples/example.py index dd53f5e8..9eb61aad 100644 --- a/examples/example.py +++ b/examples/example.py @@ -87,7 +87,7 @@ async def main(): if len(observed) > 0: element = observed[0] console.print("✅ [success]Found element:[/] News button") - console.print(f"\n▶️ [highlight] Performing action on observed element") + console.print("\n▶️ [highlight] Performing action on observed element") await page.act(element) console.print("✅ [success]Performing Action:[/] Action completed successfully") diff --git a/examples/example_sync.py b/examples/example_sync.py index 896ac846..625aee30 100644 --- a/examples/example_sync.py +++ b/examples/example_sync.py @@ -6,8 +6,8 @@ from rich.panel import Panel from rich.theme import Theme -from stagehand.sync import Stagehand from stagehand.config import StagehandConfig +from stagehand.sync import Stagehand # Create a custom theme for consistent styling custom_theme = Theme( @@ -60,6 +60,7 @@ def main(): ) import time + time.sleep(2) console.print("\n▶️ [highlight] Navigating[/] to Google") @@ -112,4 +113,4 @@ def main(): padding=(1, 10), ), ) - main() \ No newline at end of file + main() diff --git a/format.sh b/format.sh new file mode 100755 index 00000000..fb408fd1 --- /dev/null +++ b/format.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Define source directories (adjust as needed) +SOURCE_DIRS="evals stagehand" + +# Apply Black formatting only to source directories +echo "Applying Black formatting..." +black $SOURCE_DIRS + +# Fix import sorting (addresses I001 errors) +echo "Sorting imports..." +isort $SOURCE_DIRS + +# Apply Ruff with autofix for remaining issues +echo "Applying Ruff autofixes..." +ruff check --fix $SOURCE_DIRS + +echo "Checking for remaining issues..." +ruff check $SOURCE_DIRS + +echo "Done! Code has been formatted and linted." \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4263a1f4..31baa275 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,8 @@ [tool.ruff] # Enable flake8-comprehensions, flake8-bugbear, naming, etc. select = ["E", "F", "B", "C4", "UP", "N", "I", "C"] -ignore = [] +# Ignore line length errors - let Black handle those +ignore = ["E501"] # Same as Black line-length = 88 @@ -46,4 +47,40 @@ classmethod-decorators = ["classmethod", "validator"] # Add more customizations if needed [tool.ruff.lint.pydocstyle] -convention = "google" \ No newline at end of file +convention = "google" + +# Black configuration +[tool.black] +line-length = 88 +target-version = ["py39"] +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + | __pycache__ + | python-sdk +)/ +''' +# Ensure Black will wrap long strings and docstrings +skip-string-normalization = false +preview = true + +# isort configuration to work with Black +[tool.isort] +profile = "black" +line_length = 88 +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +ensure_newline_before_comments = true +skip_gitignore = true +skip_glob = ["**/venv/**", "**/.venv/**", "**/__pycache__/**"] \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 08d8882e..d6b18229 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,4 +4,5 @@ pytest-mock>=3.10.0 pytest-cov>=4.1.0 black>=23.3.0 isort>=5.12.0 -mypy>=1.3.0 \ No newline at end of file +mypy>=1.3.0 +ruff \ No newline at end of file diff --git a/stagehand/__init__.py b/stagehand/__init__.py index 9536c45d..a553d9bc 100644 --- a/stagehand/__init__.py +++ b/stagehand/__init__.py @@ -2,8 +2,6 @@ from .config import StagehandConfig from .page import StagehandPage - __version__ = "0.2.2" __all__ = ["Stagehand", "StagehandConfig", "StagehandPage"] - diff --git a/stagehand/base.py b/stagehand/base.py index 61dff7c1..7d44f93c 100644 --- a/stagehand/base.py +++ b/stagehand/base.py @@ -1,13 +1,12 @@ +import logging +import os +import time from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional, Union -from playwright.async_api import Page +from typing import Any, Callable, Optional from .config import StagehandConfig from .page import StagehandPage from .utils import default_log_handler -import os -import time -import logging logger = logging.getLogger(__name__) @@ -17,6 +16,7 @@ class StagehandBase(ABC): Base class for Stagehand client implementations. Defines the common interface and functionality for both sync and async versions. """ + def __init__( self, config: Optional[StagehandConfig] = None, @@ -25,14 +25,14 @@ def __init__( browserbase_api_key: Optional[str] = None, browserbase_project_id: Optional[str] = None, model_api_key: Optional[str] = None, - on_log: Optional[Callable[[Dict[str, Any]], Any]] = default_log_handler, + on_log: Optional[Callable[[dict[str, Any]], Any]] = default_log_handler, verbose: int = 1, model_name: Optional[str] = None, dom_settle_timeout_ms: Optional[int] = None, debug_dom: Optional[bool] = None, timeout_settings: Optional[float] = None, stream_response: Optional[bool] = None, - model_client_options: Optional[Dict[str, Any]] = None, + model_client_options: Optional[dict[str, Any]] = None, ): """ Initialize the Stagehand client with common configuration. @@ -40,15 +40,31 @@ def __init__( self.server_url = server_url or os.getenv("STAGEHAND_SERVER_URL") if config: - self.browserbase_api_key = config.api_key or browserbase_api_key or os.getenv("BROWSERBASE_API_KEY") - self.browserbase_project_id = config.project_id or browserbase_project_id or os.getenv("BROWSERBASE_PROJECT_ID") + self.browserbase_api_key = ( + config.api_key + or browserbase_api_key + or os.getenv("BROWSERBASE_API_KEY") + ) + self.browserbase_project_id = ( + config.project_id + or browserbase_project_id + or os.getenv("BROWSERBASE_PROJECT_ID") + ) self.session_id = config.browserbase_session_id or session_id self.model_name = config.model_name or model_name - self.dom_settle_timeout_ms = config.dom_settle_timeout_ms or dom_settle_timeout_ms - self.debug_dom = config.debug_dom if config.debug_dom is not None else debug_dom + self.dom_settle_timeout_ms = ( + config.dom_settle_timeout_ms or dom_settle_timeout_ms + ) + self.debug_dom = ( + config.debug_dom if config.debug_dom is not None else debug_dom + ) else: - self.browserbase_api_key = browserbase_api_key or os.getenv("BROWSERBASE_API_KEY") - self.browserbase_project_id = browserbase_project_id or os.getenv("BROWSERBASE_PROJECT_ID") + self.browserbase_api_key = browserbase_api_key or os.getenv( + "BROWSERBASE_API_KEY" + ) + self.browserbase_project_id = browserbase_project_id or os.getenv( + "BROWSERBASE_PROJECT_ID" + ) self.session_id = session_id self.model_name = model_name self.dom_settle_timeout_ms = dom_settle_timeout_ms @@ -61,7 +77,9 @@ def __init__( self.model_client_options["apiKey"] = self.model_api_key # Handle streaming response setting directly - self.streamed_response = stream_response if stream_response is not None else True + self.streamed_response = ( + stream_response if stream_response is not None else True + ) self.on_log = on_log self.verbose = verbose @@ -74,9 +92,13 @@ def __init__( # Validate essential fields if session_id was provided if self.session_id: if not self.browserbase_api_key: - raise ValueError("browserbase_api_key is required (or set BROWSERBASE_API_KEY in env).") + raise ValueError( + "browserbase_api_key is required (or set BROWSERBASE_API_KEY in env)." + ) if not self.browserbase_project_id: - raise ValueError("browserbase_project_id is required (or set BROWSERBASE_PROJECT_ID in env).") + raise ValueError( + "browserbase_project_id is required (or set BROWSERBASE_PROJECT_ID in env)." + ) @abstractmethod def init(self): @@ -106,4 +128,4 @@ def _log(self, message: str, level: int = 1): elif level == 2: logger.warning(formatted_msg) else: - logger.debug(formatted_msg) \ No newline at end of file + logger.debug(formatted_msg) diff --git a/stagehand/client.py b/stagehand/client.py index 12441151..ab444f82 100644 --- a/stagehand/client.py +++ b/stagehand/client.py @@ -3,16 +3,16 @@ import logging import time from collections.abc import Awaitable -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import httpx from dotenv import load_dotenv from playwright.async_api import async_playwright +from .base import StagehandBase from .config import StagehandConfig from .page import StagehandPage -from .utils import default_log_handler, convert_dict_keys_to_camel_case -from .base import StagehandBase +from .utils import convert_dict_keys_to_camel_case, default_log_handler load_dotenv() @@ -39,7 +39,7 @@ def __init__( browserbase_project_id: Optional[str] = None, model_api_key: Optional[str] = None, on_log: Optional[ - Callable[[Dict[str, Any]], Awaitable[None]] + Callable[[dict[str, Any]], Awaitable[None]] ] = default_log_handler, verbose: int = 1, model_name: Optional[str] = None, @@ -47,7 +47,7 @@ def __init__( debug_dom: Optional[bool] = None, httpx_client: Optional[httpx.AsyncClient] = None, timeout_settings: Optional[httpx.Timeout] = None, - model_client_options: Optional[Dict[str, Any]] = None, + model_client_options: Optional[dict[str, Any]] = None, stream_response: Optional[bool] = None, ): """ @@ -60,14 +60,14 @@ def __init__( browserbase_api_key (Optional[str]): Your Browserbase API key. browserbase_project_id (Optional[str]): Your Browserbase project ID. model_api_key (Optional[str]): Your model API key (e.g. OpenAI, Anthropic, etc.). - on_log (Optional[Callable[[Dict[str, Any]], Awaitable[None]]]): Async callback for log messages from the server. + on_log (Optional[Callable[[dict[str, Any]], Awaitable[None]]]): Async callback for log messages from the server. verbose (int): Verbosity level for logs. model_name (Optional[str]): Model name to use when creating a new session. dom_settle_timeout_ms (Optional[int]): Additional time for the DOM to settle (in ms). debug_dom (Optional[bool]): Whether to enable DOM debugging mode. httpx_client (Optional[httpx.AsyncClient]): Optional custom httpx.AsyncClient instance. timeout_settings (Optional[httpx.Timeout]): Optional custom timeout settings for httpx. - model_client_options (Optional[Dict[str, Any]]): Optional model client options. + model_client_options (Optional[dict[str, Any]]): Optional model client options. stream_response (Optional[bool]): Whether to stream responses from the server. """ super().__init__( @@ -86,7 +86,7 @@ def __init__( stream_response=stream_response, model_client_options=model_client_options, ) - + self.httpx_client = httpx_client self.timeout_settings = timeout_settings or httpx.Timeout( connect=180.0, @@ -291,7 +291,7 @@ async def _create_session(self): "verbose": self.verbose, "debugDom": self.debug_dom, } - + if hasattr(self, "model_client_options") and self.model_client_options: payload["modelClientOptions"] = self.model_client_options @@ -318,7 +318,7 @@ async def _create_session(self): self.session_id = data["data"]["sessionId"] - async def _execute(self, method: str, payload: Dict[str, Any]) -> Any: + async def _execute(self, method: str, payload: dict[str, Any]) -> Any: """ Internal helper to call /sessions/{session_id}/{method} with the given method and payload. Streams line-by-line, returning the 'result' from the final message (if any). @@ -332,14 +332,18 @@ async def _execute(self, method: str, payload: Dict[str, Any]) -> Any: } if self.model_api_key: headers["x-model-api-key"] = self.model_api_key - + modified_payload = dict(payload) - if hasattr(self, "model_client_options") and self.model_client_options and "modelClientOptions" not in modified_payload: + if ( + hasattr(self, "model_client_options") + and self.model_client_options + and "modelClientOptions" not in modified_payload + ): modified_payload["modelClientOptions"] = self.model_client_options - + # Convert snake_case keys to camelCase for the API modified_payload = convert_dict_keys_to_camel_case(modified_payload) - + client = self.httpx_client or httpx.AsyncClient(timeout=self.timeout_settings) self._log(f"\n==== EXECUTING {method.upper()} ====", level=3) self._log( @@ -347,7 +351,7 @@ async def _execute(self, method: str, payload: Dict[str, Any]) -> Any: ) self._log(f"Payload: {modified_payload}", level=3) self._log(f"Headers: {headers}", level=3) - + async with client: try: if not self.streamed_response: @@ -362,12 +366,14 @@ async def _execute(self, method: str, payload: Dict[str, Any]) -> Any: error_message = error_text.decode("utf-8") self._log(f"Error: {error_message}", level=3) return None - + data = response.json() if data.get("success"): return data.get("data", {}).get("result") else: - raise RuntimeError(f"Request failed: {data.get('error', 'Unknown error')}") + raise RuntimeError( + f"Request failed: {data.get('error', 'Unknown error')}" + ) # Handle streaming response async with client.stream( @@ -430,7 +436,7 @@ async def _execute(self, method: str, payload: Dict[str, Any]) -> Any: "Server connection closed without sending 'finished' message" ) - async def _handle_log(self, msg: Dict[str, Any]): + async def _handle_log(self, msg: dict[str, Any]): """ Handle a log line from the server. If on_log is set, call it asynchronously. """ diff --git a/stagehand/page.py b/stagehand/page.py index aeb878be..92d7ea89 100644 --- a/stagehand/page.py +++ b/stagehand/page.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional, Union from playwright.async_api import Page @@ -32,7 +32,7 @@ async def goto( *, referer: Optional[str] = None, timeout: Optional[int] = None, - wait_until: Optional[str] = None + wait_until: Optional[str] = None, ): """ Navigate to URL using the Stagehand server. @@ -73,16 +73,20 @@ async def act(self, options: Union[str, ActOptions, ObserveResult]) -> ActResult - A string with the action command to be executed by the AI - An ActOptions object encapsulating the action command and optional parameters - An ObserveResult with selector and method fields for direct execution without LLM - + When an ObserveResult with both 'selector' and 'method' fields is provided, - the SDK will directly execute the action against the selector using the method + the SDK will directly execute the action against the selector using the method and arguments provided, bypassing the LLM processing. Returns: ActResult: The result from the Stagehand server's action execution. """ # Check if options is an ObserveResult with both selector and method - if isinstance(options, ObserveResult) and hasattr(options, "selector") and hasattr(options, "method"): + if ( + isinstance(options, ObserveResult) + and hasattr(options, "selector") + and hasattr(options, "method") + ): # For ObserveResult, we directly pass it to the server which will # execute the method against the selector payload = options.model_dump(exclude_none=True, by_alias=True) @@ -101,7 +105,7 @@ async def act(self, options: Union[str, ActOptions, ObserveResult]) -> ActResult return ActResult(**result) return result - async def observe(self, options: Union[str, ObserveOptions]) -> List[ObserveResult]: + async def observe(self, options: Union[str, ObserveOptions]) -> list[ObserveResult]: """ Make an AI observation via the Stagehand server. @@ -111,7 +115,7 @@ async def observe(self, options: Union[str, ObserveOptions]) -> List[ObserveResu See `stagehand.schemas.ObserveOptions` for details on expected fields. Returns: - List[ObserveResult]: A list of observation results from the Stagehand server. + list[ObserveResult]: A list of observation results from the Stagehand server. """ # Convert string to ObserveOptions if needed if isinstance(options, str): diff --git a/stagehand/schemas.py b/stagehand/schemas.py index fd40dfe5..adf58db5 100644 --- a/stagehand/schemas.py +++ b/stagehand/schemas.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Optional, Union from pydantic import BaseModel, Field, field_serializer @@ -20,12 +20,12 @@ class AvailableModel(str, Enum): class StagehandBaseModel(BaseModel): """Base model for all Stagehand models with camelCase conversion support""" - + class Config: populate_by_name = True # Allow accessing fields by their Python name - alias_generator = lambda field_name: ''.join( - [field_name.split('_')[0]] + - [word.capitalize() for word in field_name.split('_')[1:]] + alias_generator = lambda field_name: "".join( + [field_name.split("_")[0]] + + [word.capitalize() for word in field_name.split("_")[1:]] ) # snake_case to camelCase @@ -35,13 +35,13 @@ class ActOptions(StagehandBaseModel): Attributes: action (str): The action command to be executed by the AI. - variables: Optional[Dict[str, str]] = None + variables: Optional[dict[str, str]] = None model_name: Optional[AvailableModel] = None slow_dom_based_act: Optional[bool] = None """ action: str = Field(..., description="The action command to be executed by the AI.") - variables: Optional[Dict[str, str]] = None + variables: Optional[dict[str, str]] = None model_name: Optional[AvailableModel] = None slow_dom_based_act: Optional[bool] = None @@ -69,7 +69,7 @@ class ExtractOptions(StagehandBaseModel): instruction (str): Instruction specifying what data to extract using AI. model_name: Optional[AvailableModel] = None selector: Optional[str] = None - schema_definition (Union[Dict[str, Any], Type[BaseModel]]): A JSON schema or Pydantic model that defines the structure of the expected data. + schema_definition (Union[dict[str, Any], type[BaseModel]]): A JSON schema or Pydantic model that defines the structure of the expected data. Note: If passing a Pydantic model, invoke its .model_json_schema() method to ensure the schema is JSON serializable. use_text_extract: Optional[bool] = None """ @@ -81,16 +81,20 @@ class ExtractOptions(StagehandBaseModel): selector: Optional[str] = None # IMPORTANT: If using a Pydantic model for schema_definition, please call its .model_json_schema() method # to convert it to a JSON serializable dictionary before sending it with the extract command. - schema_definition: Union[Dict[str, Any], Type[BaseModel]] = Field( + schema_definition: Union[dict[str, Any], type[BaseModel]] = Field( default=DEFAULT_EXTRACT_SCHEMA, description="A JSON schema or Pydantic model that defines the structure of the expected data.", ) use_text_extract: Optional[bool] = True - @field_serializer('schema_definition') - def serialize_schema_definition(self, schema_definition: Union[Dict[str, Any], Type[BaseModel]]) -> Dict[str, Any]: + @field_serializer("schema_definition") + def serialize_schema_definition( + self, schema_definition: Union[dict[str, Any], type[BaseModel]] + ) -> dict[str, Any]: """Serialize schema_definition to a JSON schema if it's a Pydantic model""" - if isinstance(schema_definition, type) and issubclass(schema_definition, BaseModel): + if isinstance(schema_definition, type) and issubclass( + schema_definition, BaseModel + ): return schema_definition.model_json_schema() return schema_definition @@ -152,7 +156,7 @@ class ObserveResult(StagehandBaseModel): ) backend_node_id: Optional[int] = None method: Optional[str] = None - arguments: Optional[List[str]] = None + arguments: Optional[list[str]] = None def __getitem__(self, key): """ diff --git a/stagehand/sync/__init__.py b/stagehand/sync/__init__.py index 5a01f936..d0e16fbe 100644 --- a/stagehand/sync/__init__.py +++ b/stagehand/sync/__init__.py @@ -1,3 +1,3 @@ from .client import Stagehand -__all__ = ["Stagehand"] \ No newline at end of file +__all__ = ["Stagehand"] diff --git a/stagehand/sync/client.py b/stagehand/sync/client.py index a44641be..9258ce1f 100644 --- a/stagehand/sync/client.py +++ b/stagehand/sync/client.py @@ -1,23 +1,24 @@ -import os -import time -import logging import json -from typing import Any, Dict, Optional, Callable +import logging +import time +from typing import Any, Callable, Optional import requests from playwright.sync_api import sync_playwright from ..base import StagehandBase from ..config import StagehandConfig +from ..utils import convert_dict_keys_to_camel_case, default_log_handler from .page import SyncStagehandPage -from ..utils import default_log_handler, convert_dict_keys_to_camel_case logger = logging.getLogger(__name__) + class Stagehand(StagehandBase): """ Synchronous implementation of the Stagehand client. """ + def __init__( self, config: Optional[StagehandConfig] = None, @@ -26,13 +27,13 @@ def __init__( browserbase_api_key: Optional[str] = None, browserbase_project_id: Optional[str] = None, model_api_key: Optional[str] = None, - on_log: Optional[Callable[[Dict[str, Any]], Any]] = default_log_handler, + on_log: Optional[Callable[[dict[str, Any]], Any]] = default_log_handler, verbose: int = 1, model_name: Optional[str] = None, dom_settle_timeout_ms: Optional[int] = None, debug_dom: Optional[bool] = None, timeout_settings: Optional[float] = None, - model_client_options: Optional[Dict[str, Any]] = None, + model_client_options: Optional[dict[str, Any]] = None, stream_response: Optional[bool] = None, ): super().__init__( @@ -162,7 +163,9 @@ def _check_server_health(self, timeout: int = 10): headers = { "x-bb-api-key": self.browserbase_api_key, } - resp = self._client.get(f"{self.server_url}/healthcheck", headers=headers) + resp = self._client.get( + f"{self.server_url}/healthcheck", headers=headers + ) if resp.status_code == 200: data = resp.json() if data.get("status") == "ok": @@ -174,7 +177,7 @@ def _check_server_health(self, timeout: int = 10): if time.time() - start > timeout: raise TimeoutError(f"Server not responding after {timeout} seconds.") - wait_time = min(2 ** attempt * 0.5, 5.0) + wait_time = min(2**attempt * 0.5, 5.0) time.sleep(wait_time) attempt += 1 @@ -195,7 +198,7 @@ def _create_session(self): "verbose": self.verbose, "debugDom": self.debug_dom, } - + if self.model_client_options: payload["modelClientOptions"] = self.model_client_options @@ -219,7 +222,7 @@ def _create_session(self): raise RuntimeError(f"Invalid response format: {resp.text}") self.session_id = data["data"]["sessionId"] - def _execute(self, method: str, payload: Dict[str, Any]) -> Any: + def _execute(self, method: str, payload: dict[str, Any]) -> Any: """ Execute a command synchronously. """ @@ -236,30 +239,34 @@ def _execute(self, method: str, payload: Dict[str, Any]) -> Any: modified_payload = dict(payload) if self.model_client_options and "modelClientOptions" not in modified_payload: modified_payload["modelClientOptions"] = self.model_client_options - + # Convert snake_case keys to camelCase for the API modified_payload = convert_dict_keys_to_camel_case(modified_payload) - + url = f"{self.server_url}/sessions/{self.session_id}/{method}" self._log(f"\n==== EXECUTING {method.upper()} ====", level=3) self._log(f"URL: {url}", level=3) self._log(f"Payload: {modified_payload}", level=3) self._log(f"Headers: {headers}", level=3) - + try: if not self.streamed_response: # For non-streaming responses, just return the final result - response = self._client.post(url, json=modified_payload, headers=headers) + response = self._client.post( + url, json=modified_payload, headers=headers + ) if response.status_code != 200: error_message = response.text self._log(f"Error: {error_message}", level=3) return None - + return response.json() # Return the raw response as the result # Handle streaming response self._log("Starting to process streaming response...", level=3) - response = self._client.post(url, json=modified_payload, headers=headers, stream=True) + response = self._client.post( + url, json=modified_payload, headers=headers, stream=True + ) if response.status_code != 200: error_message = response.text self._log(f"Error: {error_message}", level=3) @@ -301,4 +308,6 @@ def _execute(self, method: str, payload: Dict[str, Any]) -> Any: raise self._log("==== ERROR: No 'finished' message received ====", level=3) - raise RuntimeError("Server connection closed without sending 'finished' message") \ No newline at end of file + raise RuntimeError( + "Server connection closed without sending 'finished' message" + ) diff --git a/stagehand/sync/page.py b/stagehand/sync/page.py index d31261b9..efe0bb18 100644 --- a/stagehand/sync/page.py +++ b/stagehand/sync/page.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional, Union from playwright.sync_api import Page @@ -13,15 +13,17 @@ class SyncStagehandPage: - """Synchronous wrapper around Playwright Page that integrates with Stagehand server""" + """Synchronous wrapper around Playwright Page that integrates with Stagehand + server""" def __init__(self, page: Page, stagehand_client): """ - Initialize a SyncStagehandPage instance. + Initialize a SyncStagehandPage instance. - Args: - page (Page): The underlying Playwright page. - stagehand_client: The sync client used to interface with the Stagehand server. + Args: + page (Page): The underlying Playwright page. + stagehand_client: The sync client used to interface with the Stagehand + server. """ self.page = page self._stagehand = stagehand_client @@ -32,19 +34,20 @@ def goto( *, referer: Optional[str] = None, timeout: Optional[int] = None, - wait_until: Optional[str] = None + wait_until: Optional[str] = None, ): """ - Navigate to URL using the Stagehand server synchronously. + Navigate to URL using the Stagehand server synchronously. - Args: - url (str): The URL to navigate to. - referer (Optional[str]): Optional referer URL. - timeout (Optional[int]): Optional navigation timeout in milliseconds. - wait_until (Optional[str]): Optional wait condition; one of ('load', 'domcontentloaded', 'networkidle', 'commit'). + Args: + url (str): The URL to navigate to. + referer (Optional[str]): Optional referer URL. + timeout (Optional[int]): Optional navigation timeout in milliseconds. + wait_until (Optional[str]): Optional wait condition; one of ('load', + 'domcontentloaded', 'networkidle', 'commit'). - Returns: - The result from the Stagehand server's navigation execution. + Returns: + The result from the Stagehand server's navigation execution. """ options = {} if referer is not None: @@ -64,19 +67,25 @@ def goto( def act(self, options: Union[str, ActOptions, ObserveResult]) -> ActResult: """ - Execute an AI action via the Stagehand server synchronously. - - Args: - options (Union[str, ActOptions, ObserveResult]): - - A string with the action command to be executed by the AI - - An ActOptions object encapsulating the action command and optional parameters - - An ObserveResult with selector and method fields for direct execution without LLM - - Returns: - ActResult: The result from the Stagehand server's action execution. + Execute an AI action via the Stagehand server synchronously. + + Args: + options (Union[str, ActOptions, ObserveResult]): + - A string with the action command to be executed by the AI + - An ActOptions object encapsulating the action command and optional + parameters + - An ObserveResult with selector and method fields for direct execution + without LLM + + Returns: + ActResult: The result from the Stagehand server's action execution. """ # Check if options is an ObserveResult with both selector and method - if isinstance(options, ObserveResult) and hasattr(options, "selector") and hasattr(options, "method"): + if ( + isinstance(options, ObserveResult) + and hasattr(options, "selector") + and hasattr(options, "method") + ): # For ObserveResult, we directly pass it to the server which will # execute the method against the selector payload = options.model_dump(exclude_none=True, by_alias=True) @@ -93,16 +102,18 @@ def act(self, options: Union[str, ActOptions, ObserveResult]) -> ActResult: return ActResult(**result) return result - def observe(self, options: Union[str, ObserveOptions]) -> List[ObserveResult]: + def observe(self, options: Union[str, ObserveOptions]) -> list[ObserveResult]: """ - Make an AI observation via the Stagehand server synchronously. + Make an AI observation via the Stagehand server synchronously. - Args: - options (Union[str, ObserveOptions]): Either a string with the observation instruction - or a Pydantic model encapsulating the observation instruction. + Args: + options (Union[str, ObserveOptions]): Either a string with the observation + instruction + or a Pydantic model encapsulating the observation instruction. - Returns: - List[ObserveResult]: A list of observation results from the Stagehand server. + Returns: + list[ObserveResult]: A list of observation results from the Stagehand + server. """ # Convert string to ObserveOptions if needed if isinstance(options, str): @@ -121,13 +132,14 @@ def observe(self, options: Union[str, ObserveOptions]) -> List[ObserveResult]: def extract(self, options: Union[str, ExtractOptions]) -> ExtractResult: """ - Extract data using AI via the Stagehand server synchronously. + Extract data using AI via the Stagehand server synchronously. - Args: - options (Union[str, ExtractOptions]): The extraction options describing what to extract and how. + Args: + options (Union[str, ExtractOptions]): The extraction options describing + what to extract and how. - Returns: - ExtractResult: The result from the Stagehand server's extraction execution. + Returns: + ExtractResult: The result from the Stagehand server's extraction execution. """ # Convert string to ExtractOptions if needed if isinstance(options, str): @@ -150,4 +162,4 @@ def __getattr__(self, name): Returns: The attribute from the underlying Playwright page. """ - return getattr(self.page, name) \ No newline at end of file + return getattr(self.page, name) diff --git a/stagehand/utils.py b/stagehand/utils.py index 459f4c53..c0d3f30c 100644 --- a/stagehand/utils.py +++ b/stagehand/utils.py @@ -1,14 +1,16 @@ -import asyncio import logging -from typing import Any, Dict +from typing import Any # Setup logging logger = logging.getLogger(__name__) handler = logging.StreamHandler() -handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) +handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +) logger.addHandler(handler) -async def default_log_handler(log_data: Dict[str, Any]) -> None: + +async def default_log_handler(log_data: dict[str, Any]) -> None: """Default handler for log messages from the Stagehand server.""" level = log_data.get("level", "info").lower() message = log_data.get("message", "") @@ -20,38 +22,45 @@ async def default_log_handler(log_data: Dict[str, Any]) -> None: def snake_to_camel(snake_str: str) -> str: """ Convert a snake_case string to camelCase. - + Args: snake_str: The snake_case string to convert - + Returns: The converted camelCase string """ - components = snake_str.split('_') - return components[0] + ''.join(x.title() for x in components[1:]) + components = snake_str.split("_") + return components[0] + "".join(x.title() for x in components[1:]) -def convert_dict_keys_to_camel_case(data: Dict[str, Any]) -> Dict[str, Any]: +def convert_dict_keys_to_camel_case(data: dict[str, Any]) -> dict[str, Any]: """ Convert all keys in a dictionary from snake_case to camelCase. Works recursively for nested dictionaries. - + Args: data: Dictionary with snake_case keys - + Returns: Dictionary with camelCase keys """ result = {} - + for key, value in data.items(): if isinstance(value, dict): value = convert_dict_keys_to_camel_case(value) elif isinstance(value, list): - value = [convert_dict_keys_to_camel_case(item) if isinstance(item, dict) else item for item in value] - + value = [ + ( + convert_dict_keys_to_camel_case(item) + if isinstance(item, dict) + else item + ) + for item in value + ] + # Convert snake_case key to camelCase camel_key = snake_to_camel(key) result[camel_key] = value - + return result diff --git a/tests/functional/test_sync_client.py b/tests/functional/test_sync_client.py index 7bffec39..daeb7ab0 100644 --- a/tests/functional/test_sync_client.py +++ b/tests/functional/test_sync_client.py @@ -56,11 +56,13 @@ def test_act_command(stagehand_client): def test_observe_command(stagehand_client): """Test the observe command functionality.""" stagehand_client.page.goto("https://www.google.com") - result = stagehand_client.page.observe(ObserveOptions(instruction="find the search input box")) + result = stagehand_client.page.observe( + ObserveOptions(instruction="find the search input box") + ) assert result is not None assert len(result) > 0 - assert hasattr(result[0], 'selector') - assert hasattr(result[0], 'description') + assert hasattr(result[0], "selector") + assert hasattr(result[0], "description") def test_extract_command(stagehand_client): @@ -68,7 +70,7 @@ def test_extract_command(stagehand_client): stagehand_client.page.goto("https://www.google.com") result = stagehand_client.page.extract("title") assert result is not None - assert hasattr(result, 'extraction') + assert hasattr(result, "extraction") assert isinstance(result.extraction, str) assert result.extraction is not None