Skip to content

Commit d95974a

Browse files
authored
Fix stagehand.metrics (#176)
* Fix stagehand.metrics * address comments
1 parent 36ba981 commit d95974a

File tree

4 files changed

+170
-22
lines changed

4 files changed

+170
-22
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"stagehand": patch
3+
---
4+
5+
Fix stagehand.metrics on env:BROWSERBASE

pyproject.toml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,19 @@ description = "Python SDK for Stagehand"
99
readme = "README.md"
1010
classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent",]
1111
requires-python = ">=3.9"
12-
dependencies = [ "httpx>=0.24.0", "python-dotenv>=1.0.0", "pydantic>=1.10.0", "playwright>=1.42.1", "requests>=2.31.0", "browserbase>=1.4.0", "rich>=13.7.0", "openai>=1.83.0", "anthropic>=0.51.0", "litellm>=1.72.0",]
12+
dependencies = [
13+
"httpx>=0.24.0",
14+
"python-dotenv>=1.0.0",
15+
"pydantic>=1.10.0",
16+
"playwright>=1.42.1",
17+
"requests>=2.31.0",
18+
"browserbase>=1.4.0",
19+
"rich>=13.7.0",
20+
"openai>=1.83.0",
21+
"anthropic>=0.51.0",
22+
"litellm>=1.72.0",
23+
"nest-asyncio>=1.6.0",
24+
]
1325
[[project.authors]]
1426
name = "Browserbase, Inc."
1527
email = "support@browserbase.com"

stagehand/api.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from importlib.metadata import PackageNotFoundError, version
44
from typing import Any
55

6+
from .metrics import StagehandMetrics
67
from .utils import convert_dict_keys_to_camel_case
78

8-
__all__ = ["_create_session", "_execute"]
9+
__all__ = ["_create_session", "_execute", "_get_replay_metrics"]
910

1011

1112
async def _create_session(self):
@@ -177,3 +178,89 @@ async def _execute(self, method: str, payload: dict[str, Any]) -> Any:
177178
except Exception as e:
178179
self.logger.error(f"[EXCEPTION] {str(e)}")
179180
raise
181+
182+
183+
async def _get_replay_metrics(self):
184+
"""
185+
Fetch replay metrics from the API and parse them into StagehandMetrics.
186+
"""
187+
188+
if not self.session_id:
189+
raise ValueError("session_id is required to fetch metrics.")
190+
191+
headers = {
192+
"x-bb-api-key": self.browserbase_api_key,
193+
"x-bb-project-id": self.browserbase_project_id,
194+
"Content-Type": "application/json",
195+
}
196+
197+
try:
198+
response = await self._client.get(
199+
f"{self.api_url}/sessions/{self.session_id}/replay",
200+
headers=headers,
201+
)
202+
203+
if response.status_code != 200:
204+
error_text = (
205+
await response.aread() if hasattr(response, "aread") else response.text
206+
)
207+
self.logger.error(
208+
f"[HTTP ERROR] Failed to fetch metrics. Status {response.status_code}: {error_text}"
209+
)
210+
raise RuntimeError(
211+
f"Failed to fetch metrics with status {response.status_code}: {error_text}"
212+
)
213+
214+
data = response.json()
215+
216+
if not data.get("success"):
217+
raise RuntimeError(
218+
f"Failed to fetch metrics: {data.get('error', 'Unknown error')}"
219+
)
220+
221+
# Parse the API data into StagehandMetrics format
222+
api_data = data.get("data", {})
223+
metrics = StagehandMetrics()
224+
225+
# Parse pages and their actions
226+
pages = api_data.get("pages", [])
227+
for page in pages:
228+
actions = page.get("actions", [])
229+
for action in actions:
230+
# Get method name and token usage
231+
method = action.get("method", "").lower()
232+
token_usage = action.get("tokenUsage", {})
233+
234+
if token_usage:
235+
input_tokens = token_usage.get("inputTokens", 0)
236+
output_tokens = token_usage.get("outputTokens", 0)
237+
time_ms = token_usage.get("timeMs", 0)
238+
239+
# Map method to metrics fields
240+
if method == "act":
241+
metrics.act_prompt_tokens += input_tokens
242+
metrics.act_completion_tokens += output_tokens
243+
metrics.act_inference_time_ms += time_ms
244+
elif method == "extract":
245+
metrics.extract_prompt_tokens += input_tokens
246+
metrics.extract_completion_tokens += output_tokens
247+
metrics.extract_inference_time_ms += time_ms
248+
elif method == "observe":
249+
metrics.observe_prompt_tokens += input_tokens
250+
metrics.observe_completion_tokens += output_tokens
251+
metrics.observe_inference_time_ms += time_ms
252+
elif method == "agent":
253+
metrics.agent_prompt_tokens += input_tokens
254+
metrics.agent_completion_tokens += output_tokens
255+
metrics.agent_inference_time_ms += time_ms
256+
257+
# Always update totals for any method with token usage
258+
metrics.total_prompt_tokens += input_tokens
259+
metrics.total_completion_tokens += output_tokens
260+
metrics.total_inference_time_ms += time_ms
261+
262+
return metrics
263+
264+
except Exception as e:
265+
self.logger.error(f"[EXCEPTION] Error fetching replay metrics: {str(e)}")
266+
raise

stagehand/main.py

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Optional
88

99
import httpx
10+
import nest_asyncio
1011
from dotenv import load_dotenv
1112
from playwright.async_api import (
1213
BrowserContext,
@@ -16,7 +17,7 @@
1617
from playwright.async_api import Page as PlaywrightPage
1718

1819
from .agent import Agent
19-
from .api import _create_session, _execute
20+
from .api import _create_session, _execute, _get_replay_metrics
2021
from .browser import (
2122
cleanup_browser_resources,
2223
connect_browserbase_browser,
@@ -206,7 +207,7 @@ def __init__(
206207
)
207208

208209
# Initialize metrics tracking
209-
self.metrics = StagehandMetrics()
210+
self._local_metrics = StagehandMetrics() # Internal storage for local metrics
210211
self._inference_start_time = 0 # To track inference time
211212

212213
# Validate env
@@ -372,26 +373,26 @@ def update_metrics(
372373
inference_time_ms: Time taken for inference in milliseconds
373374
"""
374375
if function_name == StagehandFunctionName.ACT:
375-
self.metrics.act_prompt_tokens += prompt_tokens
376-
self.metrics.act_completion_tokens += completion_tokens
377-
self.metrics.act_inference_time_ms += inference_time_ms
376+
self._local_metrics.act_prompt_tokens += prompt_tokens
377+
self._local_metrics.act_completion_tokens += completion_tokens
378+
self._local_metrics.act_inference_time_ms += inference_time_ms
378379
elif function_name == StagehandFunctionName.EXTRACT:
379-
self.metrics.extract_prompt_tokens += prompt_tokens
380-
self.metrics.extract_completion_tokens += completion_tokens
381-
self.metrics.extract_inference_time_ms += inference_time_ms
380+
self._local_metrics.extract_prompt_tokens += prompt_tokens
381+
self._local_metrics.extract_completion_tokens += completion_tokens
382+
self._local_metrics.extract_inference_time_ms += inference_time_ms
382383
elif function_name == StagehandFunctionName.OBSERVE:
383-
self.metrics.observe_prompt_tokens += prompt_tokens
384-
self.metrics.observe_completion_tokens += completion_tokens
385-
self.metrics.observe_inference_time_ms += inference_time_ms
384+
self._local_metrics.observe_prompt_tokens += prompt_tokens
385+
self._local_metrics.observe_completion_tokens += completion_tokens
386+
self._local_metrics.observe_inference_time_ms += inference_time_ms
386387
elif function_name == StagehandFunctionName.AGENT:
387-
self.metrics.agent_prompt_tokens += prompt_tokens
388-
self.metrics.agent_completion_tokens += completion_tokens
389-
self.metrics.agent_inference_time_ms += inference_time_ms
388+
self._local_metrics.agent_prompt_tokens += prompt_tokens
389+
self._local_metrics.agent_completion_tokens += completion_tokens
390+
self._local_metrics.agent_inference_time_ms += inference_time_ms
390391

391392
# Always update totals
392-
self.metrics.total_prompt_tokens += prompt_tokens
393-
self.metrics.total_completion_tokens += completion_tokens
394-
self.metrics.total_inference_time_ms += inference_time_ms
393+
self._local_metrics.total_prompt_tokens += prompt_tokens
394+
self._local_metrics.total_completion_tokens += completion_tokens
395+
self._local_metrics.total_inference_time_ms += inference_time_ms
395396

396397
def update_metrics_from_response(
397398
self,
@@ -426,9 +427,9 @@ def update_metrics_from_response(
426427
f"{completion_tokens} completion tokens, {time_ms}ms"
427428
)
428429
self.logger.debug(
429-
f"Total metrics: {self.metrics.total_prompt_tokens} prompt tokens, "
430-
f"{self.metrics.total_completion_tokens} completion tokens, "
431-
f"{self.metrics.total_inference_time_ms}ms"
430+
f"Total metrics: {self._local_metrics.total_prompt_tokens} prompt tokens, "
431+
f"{self._local_metrics.total_completion_tokens} completion tokens, "
432+
f"{self._local_metrics.total_inference_time_ms}ms"
432433
)
433434
else:
434435
# Try to extract from _hidden_params or other locations
@@ -736,7 +737,50 @@ def page(self) -> Optional[StagehandPage]:
736737

737738
return self._live_page_proxy
738739

740+
def __getattribute__(self, name):
741+
"""
742+
Intercept access to 'metrics' to fetch from API when use_api=True.
743+
"""
744+
if name == "metrics":
745+
use_api = (
746+
object.__getattribute__(self, "use_api")
747+
if hasattr(self, "use_api")
748+
else False
749+
)
750+
751+
if use_api:
752+
# Need to fetch from API
753+
try:
754+
# Get the _get_replay_metrics method
755+
get_replay_metrics = object.__getattribute__(
756+
self, "_get_replay_metrics"
757+
)
758+
759+
# Try to get current event loop
760+
try:
761+
asyncio.get_running_loop()
762+
# We're in an async context, need to handle this carefully
763+
# Create a new task and wait for it
764+
nest_asyncio.apply()
765+
return asyncio.run(get_replay_metrics())
766+
except RuntimeError:
767+
# No event loop running, we can use asyncio.run directly
768+
return asyncio.run(get_replay_metrics())
769+
except Exception as e:
770+
# Log error and return empty metrics
771+
logger = object.__getattribute__(self, "logger")
772+
if logger:
773+
logger.error(f"Failed to fetch metrics from API: {str(e)}")
774+
return StagehandMetrics()
775+
else:
776+
# Return local metrics
777+
return object.__getattribute__(self, "_local_metrics")
778+
779+
# For all other attributes, use normal behavior
780+
return object.__getattribute__(self, name)
781+
739782

740783
# Bind the imported API methods to the Stagehand class
741784
Stagehand._create_session = _create_session
742785
Stagehand._execute = _execute
786+
Stagehand._get_replay_metrics = _get_replay_metrics

0 commit comments

Comments
 (0)