diff --git a/README.md b/README.md index 48504178..fe46f547 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,56 @@ export STAGEHAND_SERVER_URL="url-of-stagehand-server" ## Quickstart -Below is a minimal example to get started with Stagehand using the new schema-based options: +Stagehand supports both synchronous and asynchronous usage. Here are examples for both approaches: + +### Synchronous Usage + +```python +import os +from stagehand.sync.client import Stagehand +from stagehand.schemas import ActOptions, ExtractOptions +from pydantic import BaseModel +from dotenv import load_dotenv + +load_dotenv() + +class DescriptionSchema(BaseModel): + description: str + +def main(): + # Create a Stagehand client - it will automatically create a new session if needed + stagehand = Stagehand( + model_name="gpt-4", # Optional: defaults are available from the server + ) + + # Initialize Stagehand and create a new session + stagehand.init() + print(f"Created new session: {stagehand.session_id}") + + # Navigate to a webpage using local Playwright controls + stagehand.page.goto("https://www.example.com") + print("Navigation complete.") + + # Perform an action using the AI (e.g. simulate a button click) + stagehand.page.act("click on the 'Quickstart' button") + + # Extract data from the page with schema validation + data = stagehand.page.extract( + ExtractOptions( + instruction="extract the description of the page", + schemaDefinition=DescriptionSchema.model_json_schema() + ) + ) + description = data.get("description") if isinstance(data, dict) else data.description + print("Extracted description:", description) + + stagehand.close() + +if __name__ == "__main__": + main() +``` + +### Asynchronous Usage ```python import asyncio @@ -146,7 +195,7 @@ async def main(): print("Navigation complete.") # Perform an action using the AI (e.g. simulate a button click) - await stagehand.page.act(ActOptions(action="click on the 'Quickstart' button")) + await stagehand.page.act("click on the 'Quickstart' button") # Extract data from the page with schema validation data = await stagehand.page.extract( @@ -164,19 +213,14 @@ if __name__ == "__main__": asyncio.run(main()) ``` - -## Running Evaluations +## Evals To test all evaluations, run the following command in your terminal: - -```bash -python evals/run_all_evals.py -``` +`python evals/run_all_evals.py` This script will dynamically discover and execute every evaluation module within the `evals` directory and print the results for each. - ## More Examples For further examples, check out the scripts in the `examples/` directory: @@ -197,6 +241,8 @@ Stagehand can be configured via environment variables or through a `StagehandCon - `model_name`: Optional model name for the AI. - `dom_settle_timeout_ms`: Additional time (in ms) to have the DOM settle. - `debug_dom`: Enable debug mode for DOM operations. +- `stream_response`: Whether to stream responses from the server (default: True). +- `timeout_settings`: Custom timeout settings for HTTP requests. Example using a unified configuration: @@ -220,18 +266,34 @@ config = StagehandConfig( - **AI-powered Browser Control**: Execute natural language instructions over a running browser. - **Validated Data Extraction**: Use JSON schemas (or Pydantic models) to extract and validate information from pages. -- **Async/Await Support**: Built using Python's asyncio, making it easy to build scalable web automation workflows. +- **Async/Sync Support**: Choose between asynchronous and synchronous APIs based on your needs. +- **Context Manager Support**: Automatic resource cleanup with async and sync context managers. - **Extensible**: Seamlessly extend Playwright functionality with AI enrichments. +- **Streaming Support**: Sreaming responses for better performance with long-running operations. Default True. ## Requirements - Python 3.7+ -- httpx -- asyncio +- httpx (for async client) +- requests (for sync client) +- asyncio (for async client) - pydantic - python-dotenv (optional, for .env support) - playwright +## Contributing + +### Running Tests + +The project uses pytest for testing. To run the tests: + +```bash +# Install development dependencies +pip install -r requirements-dev.txt + +chmod +x run_tests.sh && ./run_tests.sh +``` + ## License MIT License (c) 2025 Browserbase, Inc. diff --git a/examples/example_sync.py b/examples/example_sync.py new file mode 100644 index 00000000..896ac846 --- /dev/null +++ b/examples/example_sync.py @@ -0,0 +1,115 @@ +import logging +import os + +from dotenv import load_dotenv +from rich.console import Console +from rich.panel import Panel +from rich.theme import Theme + +from stagehand.sync import Stagehand +from stagehand.config import StagehandConfig + +# Create a custom theme for consistent styling +custom_theme = Theme( + { + "info": "cyan", + "success": "green", + "warning": "yellow", + "error": "red bold", + "highlight": "magenta", + "url": "blue underline", + } +) + +# Create a Rich console instance with our theme +console = Console(theme=custom_theme) + +load_dotenv() + +# Configure logging with Rich handler +logging.basicConfig( + level=logging.WARNING, # Feel free to change this to INFO or DEBUG to see more logs + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + + +def main(): + # Build a unified configuration object for Stagehand + config = StagehandConfig( + env="BROWSERBASE", + api_key=os.getenv("BROWSERBASE_API_KEY"), + project_id=os.getenv("BROWSERBASE_PROJECT_ID"), + headless=False, + dom_settle_timeout_ms=3000, + model_name="gpt-4o", + model_client_options={"apiKey": os.getenv("MODEL_API_KEY")}, + ) + + # Create a Stagehand client using the configuration object. + stagehand = Stagehand( + config=config, server_url=os.getenv("STAGEHAND_SERVER_URL"), verbose=2 + ) + + # Initialize - this creates a new session automatically. + console.print("\n🚀 [info]Initializing Stagehand...[/]") + stagehand.init() + console.print(f"\n[yellow]Created new session:[/] {stagehand.session_id}") + console.print( + f"🌐 [white]View your live browser:[/] [url]https://www.browserbase.com/sessions/{stagehand.session_id}[/]" + ) + + import time + time.sleep(2) + + console.print("\n▶️ [highlight] Navigating[/] to Google") + stagehand.page.goto("https://google.com/") + console.print("✅ [success]Navigated to Google[/]") + + console.print("\n▶️ [highlight] Clicking[/] on About link") + # Click on the "About" link using Playwright + stagehand.page.get_by_role("link", name="About", exact=True).click() + console.print("✅ [success]Clicked on About link[/]") + + time.sleep(2) + console.print("\n▶️ [highlight] Navigating[/] back to Google") + stagehand.page.goto("https://google.com/") + console.print("✅ [success]Navigated back to Google[/]") + + console.print("\n▶️ [highlight] Performing action:[/] search for openai") + stagehand.page.act("search for openai") + stagehand.page.keyboard.press("Enter") + console.print("✅ [success]Performing Action:[/] Action completed successfully") + + console.print("\n▶️ [highlight] Observing page[/] for news button") + observed = stagehand.page.observe("find the news button on the page") + if len(observed) > 0: + element = observed[0] + console.print("✅ [success]Found element:[/] News button") + stagehand.page.act(element) + else: + console.print("❌ [error]No element found[/]") + + console.print("\n▶️ [highlight] Extracting[/] first search result") + data = stagehand.page.extract("extract the first result from the search") + console.print("📊 [info]Extracted data:[/]") + console.print_json(f"{data.model_dump_json()}") + + # Close the session + console.print("\n⏹️ [warning]Closing session...[/]") + stagehand.close() + console.print("✅ [success]Session closed successfully![/]") + console.rule("[bold]End of Example[/]") + + +if __name__ == "__main__": + # Add a fancy header + console.print( + "\n", + Panel.fit( + "[light_gray]Stagehand 🤘 Python Sync Example[/]", + border_style="green", + padding=(1, 10), + ), + ) + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 641e9eeb..561c8dd1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ asyncio>=3.4.3 python-dotenv>=1.0.0 pydantic>=1.10.0 playwright>=1.42.1 +requests>=2.31.0 rich \ No newline at end of file diff --git a/stagehand/__init__.py b/stagehand/__init__.py index 63b03284..a553d9bc 100644 --- a/stagehand/__init__.py +++ b/stagehand/__init__.py @@ -1,4 +1,7 @@ from .client import Stagehand +from .config import StagehandConfig +from .page import StagehandPage -__version__ = "0.1.0" -__all__ = ["Stagehand"] +__version__ = "0.2.2" + +__all__ = ["Stagehand", "StagehandConfig", "StagehandPage"] diff --git a/stagehand/base.py b/stagehand/base.py new file mode 100644 index 00000000..61dff7c1 --- /dev/null +++ b/stagehand/base.py @@ -0,0 +1,109 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Optional, Union +from playwright.async_api import Page + +from .config import StagehandConfig +from .page import StagehandPage +from .utils import default_log_handler +import os +import time +import logging + +logger = logging.getLogger(__name__) + + +class StagehandBase(ABC): + """ + Base class for Stagehand client implementations. + Defines the common interface and functionality for both sync and async versions. + """ + def __init__( + self, + config: Optional[StagehandConfig] = None, + server_url: Optional[str] = None, + session_id: Optional[str] = None, + browserbase_api_key: Optional[str] = None, + browserbase_project_id: Optional[str] = None, + model_api_key: Optional[str] = None, + on_log: Optional[Callable[[Dict[str, Any]], Any]] = default_log_handler, + verbose: int = 1, + model_name: Optional[str] = None, + dom_settle_timeout_ms: Optional[int] = None, + debug_dom: Optional[bool] = None, + timeout_settings: Optional[float] = None, + stream_response: Optional[bool] = None, + model_client_options: Optional[Dict[str, Any]] = None, + ): + """ + Initialize the Stagehand client with common configuration. + """ + self.server_url = server_url or os.getenv("STAGEHAND_SERVER_URL") + + if config: + self.browserbase_api_key = config.api_key or browserbase_api_key or os.getenv("BROWSERBASE_API_KEY") + self.browserbase_project_id = config.project_id or browserbase_project_id or os.getenv("BROWSERBASE_PROJECT_ID") + self.session_id = config.browserbase_session_id or session_id + self.model_name = config.model_name or model_name + self.dom_settle_timeout_ms = config.dom_settle_timeout_ms or dom_settle_timeout_ms + self.debug_dom = config.debug_dom if config.debug_dom is not None else debug_dom + else: + self.browserbase_api_key = browserbase_api_key or os.getenv("BROWSERBASE_API_KEY") + self.browserbase_project_id = browserbase_project_id or os.getenv("BROWSERBASE_PROJECT_ID") + self.session_id = session_id + self.model_name = model_name + self.dom_settle_timeout_ms = dom_settle_timeout_ms + self.debug_dom = debug_dom + + # Handle model-related settings directly + self.model_api_key = model_api_key or os.getenv("MODEL_API_KEY") + self.model_client_options = model_client_options or {} + if self.model_api_key and "apiKey" not in self.model_client_options: + self.model_client_options["apiKey"] = self.model_api_key + + # Handle streaming response setting directly + self.streamed_response = stream_response if stream_response is not None else True + + self.on_log = on_log + self.verbose = verbose + self.timeout_settings = timeout_settings or 180.0 + + self._initialized = False + self._closed = False + self.page: Optional[StagehandPage] = None + + # Validate essential fields if session_id was provided + if self.session_id: + if not self.browserbase_api_key: + raise ValueError("browserbase_api_key is required (or set BROWSERBASE_API_KEY in env).") + if not self.browserbase_project_id: + raise ValueError("browserbase_project_id is required (or set BROWSERBASE_PROJECT_ID in env).") + + @abstractmethod + def init(self): + """ + Initialize the Stagehand client. + Must be implemented by subclasses. + """ + pass + + @abstractmethod + def close(self): + """ + Clean up resources. + Must be implemented by subclasses. + """ + pass + + def _log(self, message: str, level: int = 1): + """ + Internal logging helper that maps verbosity to logging levels. + """ + if self.verbose >= level: + timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + formatted_msg = f"{timestamp}::[stagehand] {message}" + if level == 1: + logger.info(formatted_msg) + elif level == 2: + logger.warning(formatted_msg) + else: + logger.debug(formatted_msg) \ No newline at end of file diff --git a/stagehand/client.py b/stagehand/client.py index 8474b869..12441151 100644 --- a/stagehand/client.py +++ b/stagehand/client.py @@ -1,7 +1,6 @@ import asyncio import json import logging -import os import time from collections.abc import Awaitable from typing import Any, Callable, Dict, Optional @@ -13,13 +12,14 @@ from .config import StagehandConfig from .page import StagehandPage from .utils import default_log_handler, convert_dict_keys_to_camel_case +from .base import StagehandBase load_dotenv() logger = logging.getLogger(__name__) -class Stagehand: +class Stagehand(StagehandBase): """ Python client for interacting with a running Stagehand server and Browserbase remote headless browser. @@ -48,6 +48,7 @@ def __init__( httpx_client: Optional[httpx.AsyncClient] = None, timeout_settings: Optional[httpx.Timeout] = None, model_client_options: Optional[Dict[str, Any]] = None, + stream_response: Optional[bool] = None, ): """ Initialize the Stagehand client. @@ -67,50 +68,25 @@ def __init__( httpx_client (Optional[httpx.AsyncClient]): Optional custom httpx.AsyncClient instance. timeout_settings (Optional[httpx.Timeout]): Optional custom timeout settings for httpx. model_client_options (Optional[Dict[str, Any]]): Optional model client options. + stream_response (Optional[bool]): Whether to stream responses from the server. """ - self.server_url = server_url or os.getenv("STAGEHAND_SERVER_URL") - - if config: - self.browserbase_api_key = ( - config.api_key - or browserbase_api_key - or os.getenv("BROWSERBASE_API_KEY") - ) - self.browserbase_project_id = ( - config.project_id - or browserbase_project_id - or os.getenv("BROWSERBASE_PROJECT_ID") - ) - self.model_api_key = os.getenv("MODEL_API_KEY") - self.session_id = config.browserbase_session_id or session_id - self.model_name = config.model_name or model_name - self.dom_settle_timeout_ms = ( - config.dom_settle_timeout_ms or dom_settle_timeout_ms - ) - self.debug_dom = ( - config.debug_dom if config.debug_dom is not None else debug_dom - ) - self._custom_logger = config.logger # For future integration if needed - # Additional config parameters available for future use: - self.headless = config.headless - self.enable_caching = config.enable_caching - self.model_client_options = model_client_options - else: - self.browserbase_api_key = browserbase_api_key or os.getenv( - "BROWSERBASE_API_KEY" - ) - self.browserbase_project_id = browserbase_project_id or os.getenv( - "BROWSERBASE_PROJECT_ID" - ) - self.model_api_key = model_api_key or os.getenv("MODEL_API_KEY") - self.session_id = session_id - self.model_name = model_name - self.dom_settle_timeout_ms = dom_settle_timeout_ms - self.debug_dom = debug_dom - self.model_client_options = model_client_options - - self.on_log = on_log - self.verbose = verbose + super().__init__( + config=config, + server_url=server_url, + session_id=session_id, + browserbase_api_key=browserbase_api_key, + browserbase_project_id=browserbase_project_id, + model_api_key=model_api_key, + on_log=on_log, + verbose=verbose, + model_name=model_name, + dom_settle_timeout_ms=dom_settle_timeout_ms, + debug_dom=debug_dom, + timeout_settings=timeout_settings, + stream_response=stream_response, + model_client_options=model_client_options, + ) + self.httpx_client = httpx_client self.timeout_settings = timeout_settings or httpx.Timeout( connect=180.0, @@ -118,7 +94,6 @@ def __init__( write=180.0, pool=180.0, ) - self.streamed_response = True # Default to True for streamed responses self._client: Optional[httpx.AsyncClient] = None self._playwright = None @@ -375,6 +350,26 @@ async def _execute(self, method: str, payload: Dict[str, Any]) -> Any: async with client: try: + if not self.streamed_response: + # For non-streaming responses, just return the final result + response = await client.post( + f"{self.server_url}/sessions/{self.session_id}/{method}", + json=modified_payload, + headers=headers, + ) + if response.status_code != 200: + error_text = await response.aread() + error_message = error_text.decode("utf-8") + self._log(f"Error: {error_message}", level=3) + return None + + data = response.json() + if data.get("success"): + return data.get("data", {}).get("result") + else: + raise RuntimeError(f"Request failed: {data.get('error', 'Unknown error')}") + + # Handle streaming response async with client.stream( "POST", f"{self.server_url}/sessions/{self.session_id}/{method}", diff --git a/stagehand/sync/__init__.py b/stagehand/sync/__init__.py new file mode 100644 index 00000000..5a01f936 --- /dev/null +++ b/stagehand/sync/__init__.py @@ -0,0 +1,3 @@ +from .client import Stagehand + +__all__ = ["Stagehand"] \ No newline at end of file diff --git a/stagehand/sync/client.py b/stagehand/sync/client.py new file mode 100644 index 00000000..a44641be --- /dev/null +++ b/stagehand/sync/client.py @@ -0,0 +1,304 @@ +import os +import time +import logging +import json +from typing import Any, Dict, Optional, Callable + +import requests +from playwright.sync_api import sync_playwright + +from ..base import StagehandBase +from ..config import StagehandConfig +from .page import SyncStagehandPage +from ..utils import default_log_handler, convert_dict_keys_to_camel_case + +logger = logging.getLogger(__name__) + +class Stagehand(StagehandBase): + """ + Synchronous implementation of the Stagehand client. + """ + def __init__( + self, + config: Optional[StagehandConfig] = None, + server_url: Optional[str] = None, + session_id: Optional[str] = None, + browserbase_api_key: Optional[str] = None, + browserbase_project_id: Optional[str] = None, + model_api_key: Optional[str] = None, + on_log: Optional[Callable[[Dict[str, Any]], Any]] = default_log_handler, + verbose: int = 1, + model_name: Optional[str] = None, + dom_settle_timeout_ms: Optional[int] = None, + debug_dom: Optional[bool] = None, + timeout_settings: Optional[float] = None, + model_client_options: Optional[Dict[str, Any]] = None, + stream_response: Optional[bool] = None, + ): + super().__init__( + config=config, + server_url=server_url, + session_id=session_id, + browserbase_api_key=browserbase_api_key, + browserbase_project_id=browserbase_project_id, + model_api_key=model_api_key, + on_log=on_log, + verbose=verbose, + model_name=model_name, + dom_settle_timeout_ms=dom_settle_timeout_ms, + debug_dom=debug_dom, + timeout_settings=timeout_settings, + stream_response=stream_response, + ) + self._client: Optional[requests.Session] = None + self._playwright = None + self._browser = None + self._context = None + self._playwright_page = None + self.model_client_options = model_client_options + self.streamed_response = True # Default to True for streamed responses + + def init(self): + """ + Initialize the Stagehand client synchronously. + """ + if self._initialized: + self._log("Stagehand is already initialized; skipping init()", level=3) + return + + self._log("Initializing Stagehand...", level=3) + + if not self._client: + self._client = requests.Session() + + # Check server health + self._check_server_health() + + # Create session if we don't have one + if not self.session_id: + self._create_session() + self._log(f"Created new session: {self.session_id}", level=3) + + # Start Playwright and connect to remote + self._log("Starting Playwright...", level=3) + self._playwright = sync_playwright().start() + + connect_url = ( + f"wss://connect.browserbase.com?apiKey={self.browserbase_api_key}" + f"&sessionId={self.session_id}" + ) + self._log(f"Connecting to remote browser at: {connect_url}", level=3) + self._browser = self._playwright.chromium.connect_over_cdp(connect_url) + self._log(f"Connected to remote browser: {self._browser}", level=3) + + # Access or create a context + existing_contexts = self._browser.contexts + self._log(f"Existing contexts: {len(existing_contexts)}", level=3) + if existing_contexts: + self._context = existing_contexts[0] + else: + self._log("Creating a new context...", level=3) + self._context = self._browser.new_context() + + # Access or create a page + existing_pages = self._context.pages + self._log(f"Existing pages: {len(existing_pages)}", level=3) + if existing_pages: + self._log("Using existing page", level=3) + self._playwright_page = existing_pages[0] + else: + self._log("Creating a new page...", level=3) + self._playwright_page = self._context.new_page() + + # Wrap with SyncStagehandPage + self._log("Wrapping Playwright page in SyncStagehandPage", level=3) + self.page = SyncStagehandPage(self._playwright_page, self) + + self._initialized = True + + def close(self): + """ + Clean up resources synchronously. + """ + if self._closed: + return + + self._log("Closing resources...", level=3) + + # End the session on the server if we have a session ID + if self.session_id: + try: + self._log(f"Ending session {self.session_id} on the server...", level=3) + headers = { + "x-bb-api-key": self.browserbase_api_key, + "x-bb-project-id": self.browserbase_project_id, + "Content-Type": "application/json", + } + self._execute("end", {"sessionId": self.session_id}) + self._log(f"Session {self.session_id} ended successfully", level=3) + except Exception as e: + self._log(f"Error ending session: {str(e)}", level=3) + + if self._playwright: + self._log("Stopping Playwright...", level=3) + self._playwright.stop() + self._playwright = None + + if self._client: + self._log("Closing the HTTP client...", level=3) + self._client.close() + self._client = None + + self._closed = True + + def _check_server_health(self, timeout: int = 10): + """ + Check server health synchronously with exponential backoff. + """ + start = time.time() + attempt = 0 + while True: + try: + headers = { + "x-bb-api-key": self.browserbase_api_key, + } + resp = self._client.get(f"{self.server_url}/healthcheck", headers=headers) + if resp.status_code == 200: + data = resp.json() + if data.get("status") == "ok": + self._log("Healthcheck passed. Server is running.", level=3) + return + except Exception as e: + self._log(f"Healthcheck error: {str(e)}", level=3) + + if time.time() - start > timeout: + raise TimeoutError(f"Server not responding after {timeout} seconds.") + + wait_time = min(2 ** attempt * 0.5, 5.0) + time.sleep(wait_time) + attempt += 1 + + def _create_session(self): + """ + Create a new session synchronously. + """ + if not self.browserbase_api_key: + raise ValueError("browserbase_api_key is required to create a session.") + if not self.browserbase_project_id: + raise ValueError("browserbase_project_id is required to create a session.") + if not self.model_api_key: + raise ValueError("model_api_key is required to create a session.") + + payload = { + "modelName": self.model_name, + "domSettleTimeoutMs": self.dom_settle_timeout_ms, + "verbose": self.verbose, + "debugDom": self.debug_dom, + } + + if self.model_client_options: + payload["modelClientOptions"] = self.model_client_options + + headers = { + "x-bb-api-key": self.browserbase_api_key, + "x-bb-project-id": self.browserbase_project_id, + "x-model-api-key": self.model_api_key, + "Content-Type": "application/json", + } + + resp = self._client.post( + f"{self.server_url}/sessions/start", + json=payload, + headers=headers, + ) + if resp.status_code != 200: + raise RuntimeError(f"Failed to create session: {resp.text}") + data = resp.json() + self._log(f"Session created: {data}", level=3) + if not data.get("success") or "sessionId" not in data.get("data", {}): + raise RuntimeError(f"Invalid response format: {resp.text}") + self.session_id = data["data"]["sessionId"] + + def _execute(self, method: str, payload: Dict[str, Any]) -> Any: + """ + Execute a command synchronously. + """ + headers = { + "x-bb-api-key": self.browserbase_api_key, + "x-bb-project-id": self.browserbase_project_id, + "Content-Type": "application/json", + "Connection": "keep-alive", + "x-stream-response": str(self.streamed_response).lower(), + } + if self.model_api_key: + headers["x-model-api-key"] = self.model_api_key + + modified_payload = dict(payload) + if self.model_client_options and "modelClientOptions" not in modified_payload: + modified_payload["modelClientOptions"] = self.model_client_options + + # Convert snake_case keys to camelCase for the API + modified_payload = convert_dict_keys_to_camel_case(modified_payload) + + url = f"{self.server_url}/sessions/{self.session_id}/{method}" + self._log(f"\n==== EXECUTING {method.upper()} ====", level=3) + self._log(f"URL: {url}", level=3) + self._log(f"Payload: {modified_payload}", level=3) + self._log(f"Headers: {headers}", level=3) + + try: + if not self.streamed_response: + # For non-streaming responses, just return the final result + response = self._client.post(url, json=modified_payload, headers=headers) + if response.status_code != 200: + error_message = response.text + self._log(f"Error: {error_message}", level=3) + return None + + return response.json() # Return the raw response as the result + + # Handle streaming response + self._log("Starting to process streaming response...", level=3) + response = self._client.post(url, json=modified_payload, headers=headers, stream=True) + if response.status_code != 200: + error_message = response.text + self._log(f"Error: {error_message}", level=3) + return None + + for line in response.iter_lines(decode_unicode=True): + if not line.strip(): + continue + + try: + if line.startswith("data: "): + line = line[6:] + + message = json.loads(line) + msg_type = message.get("type") + + if msg_type == "system": + status = message.get("data", {}).get("status") + if status == "finished": + result = message.get("data", {}).get("result") + self._log(f"FINISHED WITH RESULT: {result}", level=3) + return result + elif msg_type == "log": + log_msg = message.get("data", {}).get("message", "") + self._log(log_msg, level=3) + if self.on_log: + # For sync implementation, we just log the message directly + self._log(f"Log message: {log_msg}", level=3) + else: + self._log(f"Unknown message type: {msg_type}", level=3) + if self.on_log: + self._log(f"Unknown message: {message}", level=3) + + except json.JSONDecodeError: + self._log(f"Could not parse line as JSON: {line}", level=3) + continue + except Exception as e: + self._log(f"EXCEPTION IN _EXECUTE: {str(e)}") + raise + + self._log("==== ERROR: No 'finished' message received ====", level=3) + raise RuntimeError("Server connection closed without sending 'finished' message") \ No newline at end of file diff --git a/stagehand/sync/page.py b/stagehand/sync/page.py new file mode 100644 index 00000000..d31261b9 --- /dev/null +++ b/stagehand/sync/page.py @@ -0,0 +1,153 @@ +from typing import List, Optional, Union + +from playwright.sync_api import Page + +from ..schemas import ( + ActOptions, + ActResult, + ExtractOptions, + ExtractResult, + ObserveOptions, + ObserveResult, +) + + +class SyncStagehandPage: + """Synchronous wrapper around Playwright Page that integrates with Stagehand server""" + + def __init__(self, page: Page, stagehand_client): + """ + Initialize a SyncStagehandPage instance. + + Args: + page (Page): The underlying Playwright page. + stagehand_client: The sync client used to interface with the Stagehand server. + """ + self.page = page + self._stagehand = stagehand_client + + def goto( + self, + url: str, + *, + referer: Optional[str] = None, + timeout: Optional[int] = None, + wait_until: Optional[str] = None + ): + """ + Navigate to URL using the Stagehand server synchronously. + + Args: + url (str): The URL to navigate to. + referer (Optional[str]): Optional referer URL. + timeout (Optional[int]): Optional navigation timeout in milliseconds. + wait_until (Optional[str]): Optional wait condition; one of ('load', 'domcontentloaded', 'networkidle', 'commit'). + + Returns: + The result from the Stagehand server's navigation execution. + """ + options = {} + if referer is not None: + options["referer"] = referer + if timeout is not None: + options["timeout"] = timeout + if wait_until is not None: + options["wait_until"] = wait_until + options["waitUntil"] = wait_until + + payload = {"url": url} + if options: + payload["options"] = options + + result = self._stagehand._execute("navigate", payload) + return result + + def act(self, options: Union[str, ActOptions, ObserveResult]) -> ActResult: + """ + Execute an AI action via the Stagehand server synchronously. + + Args: + options (Union[str, ActOptions, ObserveResult]): + - A string with the action command to be executed by the AI + - An ActOptions object encapsulating the action command and optional parameters + - An ObserveResult with selector and method fields for direct execution without LLM + + Returns: + ActResult: The result from the Stagehand server's action execution. + """ + # Check if options is an ObserveResult with both selector and method + if isinstance(options, ObserveResult) and hasattr(options, "selector") and hasattr(options, "method"): + # For ObserveResult, we directly pass it to the server which will + # execute the method against the selector + payload = options.model_dump(exclude_none=True, by_alias=True) + # Convert string to ActOptions if needed + elif isinstance(options, str): + options = ActOptions(action=options) + payload = options.model_dump(exclude_none=True, by_alias=True) + # Otherwise, it should be an ActOptions object + else: + payload = options.model_dump(exclude_none=True, by_alias=True) + + result = self._stagehand._execute("act", payload) + if isinstance(result, dict): + return ActResult(**result) + return result + + def observe(self, options: Union[str, ObserveOptions]) -> List[ObserveResult]: + """ + Make an AI observation via the Stagehand server synchronously. + + Args: + options (Union[str, ObserveOptions]): Either a string with the observation instruction + or a Pydantic model encapsulating the observation instruction. + + Returns: + List[ObserveResult]: A list of observation results from the Stagehand server. + """ + # Convert string to ObserveOptions if needed + if isinstance(options, str): + options = ObserveOptions(instruction=options) + + payload = options.model_dump(exclude_none=True, by_alias=True) + result = self._stagehand._execute("observe", payload) + + # Convert raw result to list of ObserveResult models + if isinstance(result, list): + return [ObserveResult(**item) for item in result] + elif isinstance(result, dict): + # If single dict, wrap in list + return [ObserveResult(**result)] + return [] + + def extract(self, options: Union[str, ExtractOptions]) -> ExtractResult: + """ + Extract data using AI via the Stagehand server synchronously. + + Args: + options (Union[str, ExtractOptions]): The extraction options describing what to extract and how. + + Returns: + ExtractResult: The result from the Stagehand server's extraction execution. + """ + # Convert string to ExtractOptions if needed + if isinstance(options, str): + options = ExtractOptions(instruction=options) + + payload = options.model_dump(exclude_none=True, by_alias=True) + result = self._stagehand._execute("extract", payload) + if isinstance(result, dict): + return ExtractResult(**result) + return result + + # Forward other Page methods to underlying Playwright page + def __getattr__(self, name): + """ + Forward attribute lookups to the underlying Playwright page. + + Args: + name (str): Name of the attribute to access. + + Returns: + The attribute from the underlying Playwright page. + """ + return getattr(self.page, name) \ No newline at end of file diff --git a/tests/functional/test_sync_client.py b/tests/functional/test_sync_client.py new file mode 100644 index 00000000..7bffec39 --- /dev/null +++ b/tests/functional/test_sync_client.py @@ -0,0 +1,79 @@ +import os +import pytest +from dotenv import load_dotenv +from stagehand.sync.client import Stagehand +from stagehand.config import StagehandConfig +from stagehand.schemas import ActOptions, ObserveOptions, ExtractOptions + +# Load environment variables +load_dotenv() + + +@pytest.fixture +def stagehand_client(): + """Fixture to create and manage a Stagehand client instance.""" + config = StagehandConfig( + env=( + "BROWSERBASE" + if os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID") + else "LOCAL" + ), + api_key=os.getenv("BROWSERBASE_API_KEY"), + project_id=os.getenv("BROWSERBASE_PROJECT_ID"), + debug_dom=True, + headless=True, # Run tests in headless mode + dom_settle_timeout_ms=3000, + model_name="gpt-4o-mini", + model_client_options={"apiKey": os.getenv("MODEL_API_KEY")}, + ) + + client = Stagehand( + config=config, server_url=os.getenv("STAGEHAND_SERVER_URL"), verbose=2 + ) + + # Initialize the client + client.init() + + yield client + + # Cleanup + client.close() + + +def test_navigation(stagehand_client): + """Test basic navigation functionality.""" + stagehand_client.page.goto("https://www.google.com") + # Add assertions based on the page state if needed + + +def test_act_command(stagehand_client): + """Test the act command functionality.""" + stagehand_client.page.goto("https://www.google.com") + stagehand_client.page.act(ActOptions(action="search for openai")) + # Add assertions based on the action result if needed + + +def test_observe_command(stagehand_client): + """Test the observe command functionality.""" + stagehand_client.page.goto("https://www.google.com") + result = stagehand_client.page.observe(ObserveOptions(instruction="find the search input box")) + assert result is not None + assert len(result) > 0 + assert hasattr(result[0], 'selector') + assert hasattr(result[0], 'description') + + +def test_extract_command(stagehand_client): + """Test the extract command functionality.""" + stagehand_client.page.goto("https://www.google.com") + result = stagehand_client.page.extract("title") + assert result is not None + assert hasattr(result, 'extraction') + assert isinstance(result.extraction, str) + assert result.extraction is not None + + +def test_session_management(stagehand_client): + """Test session management functionality.""" + assert stagehand_client.session_id is not None + assert isinstance(stagehand_client.session_id, str)