From 5f29b21022008c0b9baf1dce1405aef2c956f258 Mon Sep 17 00:00:00 2001 From: Sanveer Singh Osahan Date: Tue, 19 Aug 2025 11:52:41 +0530 Subject: [PATCH 1/7] Adding initial support for LLM Customisation (#132) * Adding support for LLM Customisation * Added back model_api_key config * Fixed README * Update README.md Co-authored-by: Miguel <36487034+miguelg719@users.noreply.github.com> * Update stagehand/llm/client.py --------- Co-authored-by: Miguel <36487034+miguelg719@users.noreply.github.com> --- stagehand/config.py | 4 +++ stagehand/llm/client.py | 2 +- stagehand/main.py | 11 ++++----- tests/unit/llm/test_llm_integration.py | 1 + tests/unit/test_client_api.py | 2 +- tests/unit/test_client_initialization.py | 31 +++++++++++++++++++++++- 6 files changed, 42 insertions(+), 9 deletions(-) diff --git a/stagehand/config.py b/stagehand/config.py index a577230d..9557bf28 100644 --- a/stagehand/config.py +++ b/stagehand/config.py @@ -20,6 +20,7 @@ class StagehandConfig(BaseModel): browserbase_session_id (Optional[str]): Session ID for resuming Browserbase sessions. model_name (Optional[str]): Name of the model to use. model_api_key (Optional[str]): Model API key. + model_client_options (Optional[dict[str, Any]]): Options for the model client. logger (Optional[Callable[[Any], None]]): Custom logging function. verbose (Optional[int]): Verbosity level for logs (1=minimal, 2=medium, 3=detailed). use_rich_logging (bool): Whether to use Rich for colorized logging. @@ -50,6 +51,9 @@ class StagehandConfig(BaseModel): model_api_key: Optional[str] = Field( None, alias="modelApiKey", description="Model API key" ) + model_client_options: Optional[dict[str, Any]] = Field( + None, alias="modelClientOptions", description="Configuration options for the language model client (i.e. apiKey, baseURL)", + ) verbose: Optional[int] = Field( 1, description="Verbosity level for logs: 0=minimal (ERROR), 1=medium (INFO), 2=detailed (DEBUG)", diff --git a/stagehand/llm/client.py b/stagehand/llm/client.py index 855b0fef..fdd88540 100644 --- a/stagehand/llm/client.py +++ b/stagehand/llm/client.py @@ -54,7 +54,7 @@ def __init__( setattr(litellm, key, value) self.logger.debug(f"Set global litellm.{key}", category="llm") # Handle common aliases or expected config names if necessary - elif key == "api_base": # Example: map api_base if needed + elif key == "api_base" or key == "baseURL": litellm.api_base = value self.logger.debug( f"Set global litellm.api_base to {value}", category="llm" diff --git a/stagehand/main.py b/stagehand/main.py index 0de682e0..0f98906a 100644 --- a/stagehand/main.py +++ b/stagehand/main.py @@ -68,7 +68,11 @@ def __init__( # Handle non-config parameters self.api_url = self.config.api_url - self.model_api_key = self.config.model_api_key or os.getenv("MODEL_API_KEY") + + # Handle model-related settings + self.model_client_options = self.config.model_client_options or {} + self.model_api_key = self.config.model_api_key or self.model_client_options.get("apiKey") or os.getenv("MODEL_API_KEY") + self.model_name = self.config.model_name # Extract frequently used values from config for convenience @@ -89,11 +93,6 @@ def __init__( self.config.local_browser_launch_options or {} ) - # Handle model-related settings - self.model_client_options = {} - if self.model_api_key and "apiKey" not in self.model_client_options: - self.model_client_options["apiKey"] = self.model_api_key - # Handle browserbase session create params self.browserbase_session_create_params = make_serializable( self.config.browserbase_session_create_params diff --git a/tests/unit/llm/test_llm_integration.py b/tests/unit/llm/test_llm_integration.py index a01e7a73..00d09c40 100644 --- a/tests/unit/llm/test_llm_integration.py +++ b/tests/unit/llm/test_llm_integration.py @@ -40,6 +40,7 @@ def test_llm_client_with_custom_options(self): api_key="test-key", default_model="gpt-4o-mini", stagehand_logger=StagehandLogger(), + api_base="https://test-api-base.com", ) assert client.default_model == "gpt-4o-mini" diff --git a/tests/unit/test_client_api.py b/tests/unit/test_client_api.py index f6cb20b4..e76e30da 100644 --- a/tests/unit/test_client_api.py +++ b/tests/unit/test_client_api.py @@ -19,7 +19,7 @@ async def mock_client(self): browserbase_session_id="test-session-123", api_key="test-api-key", project_id="test-project-id", - model_api_key="test-model-api-key", + model_client_options={"apiKey": "test-model-api-key"} ) return client diff --git a/tests/unit/test_client_initialization.py b/tests/unit/test_client_initialization.py index cd748ac4..ff220396 100644 --- a/tests/unit/test_client_initialization.py +++ b/tests/unit/test_client_initialization.py @@ -23,7 +23,7 @@ def test_init_with_direct_params(self): browserbase_session_id="test-session", api_key="test-api-key", project_id="test-project-id", - model_api_key="test-model-api-key", + model_client_options={"apiKey": "test-model-api-key"}, verbose=2, ) @@ -203,3 +203,32 @@ async def mock_create_session(): # Call _create_session and expect error with pytest.raises(RuntimeError, match="Invalid response format"): await client._create_session() + + @mock.patch.dict(os.environ, {"MODEL_API_KEY": "test-model-api-key"}, clear=True) + def test_init_with_model_api_key_in_env(self): + config = StagehandConfig(env="LOCAL") + client = Stagehand(config=config) + assert client.model_api_key == "test-model-api-key" + + def test_init_with_custom_llm(self): + config = StagehandConfig( + env="LOCAL", + model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"} + ) + client = Stagehand(config=config) + assert client.model_api_key == "custom-llm-key" + assert client.model_client_options["apiKey"] == "custom-llm-key" + assert client.model_client_options["baseURL"] == "https://custom-llm.com" + + def test_init_with_custom_llm_override(self): + config = StagehandConfig( + env="LOCAL", + model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"} + ) + client = Stagehand( + config=config, + model_client_options={"apiKey": "override-llm-key", "baseURL": "https://override-llm.com"} + ) + assert client.model_api_key == "override-llm-key" + assert client.model_client_options["apiKey"] == "override-llm-key" + assert client.model_client_options["baseURL"] == "https://override-llm.com" \ No newline at end of file From 18afbde16cb6b70406663dc57a059c745051fd38 Mon Sep 17 00:00:00 2001 From: Miguel <36487034+miguelg719@users.noreply.github.com> Date: Mon, 18 Aug 2025 23:26:10 -0700 Subject: [PATCH 2/7] Update stagehand/config.py --- stagehand/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stagehand/config.py b/stagehand/config.py index 9557bf28..320ded82 100644 --- a/stagehand/config.py +++ b/stagehand/config.py @@ -52,7 +52,7 @@ class StagehandConfig(BaseModel): None, alias="modelApiKey", description="Model API key" ) model_client_options: Optional[dict[str, Any]] = Field( - None, alias="modelClientOptions", description="Configuration options for the language model client (i.e. apiKey, baseURL)", + None, alias="modelClientOptions", description="Configuration options for the language model client (i.e. api_base)", ) verbose: Optional[int] = Field( 1, From 4ee3a33bab4c110cfae063ba90eba1eafacd1bb0 Mon Sep 17 00:00:00 2001 From: miguel Date: Tue, 19 Aug 2025 08:55:44 -0700 Subject: [PATCH 3/7] minor updates --- stagehand/config.py | 4 +++- stagehand/main.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/stagehand/config.py b/stagehand/config.py index 320ded82..8805e40b 100644 --- a/stagehand/config.py +++ b/stagehand/config.py @@ -52,7 +52,9 @@ class StagehandConfig(BaseModel): None, alias="modelApiKey", description="Model API key" ) model_client_options: Optional[dict[str, Any]] = Field( - None, alias="modelClientOptions", description="Configuration options for the language model client (i.e. api_base)", + None, + alias="modelClientOptions", + description="Configuration options for the language model client (i.e. api_base)", ) verbose: Optional[int] = Field( 1, diff --git a/stagehand/main.py b/stagehand/main.py index 0f98906a..67095bf3 100644 --- a/stagehand/main.py +++ b/stagehand/main.py @@ -71,7 +71,7 @@ def __init__( # Handle model-related settings self.model_client_options = self.config.model_client_options or {} - self.model_api_key = self.config.model_api_key or self.model_client_options.get("apiKey") or os.getenv("MODEL_API_KEY") + self.model_api_key = self.config.model_api_key or os.getenv("MODEL_API_KEY") self.model_name = self.config.model_name @@ -92,6 +92,7 @@ def __init__( self.local_browser_launch_options = ( self.config.local_browser_launch_options or {} ) + self.model_client_options["apiKey"] = self.model_api_key # Handle browserbase session create params self.browserbase_session_create_params = make_serializable( From e99c1ff2858dbd931cdd5ae630dd70bf9cb947ae Mon Sep 17 00:00:00 2001 From: Miguel <36487034+miguelg719@users.noreply.github.com> Date: Tue, 19 Aug 2025 08:59:03 -0700 Subject: [PATCH 4/7] Apply suggestions from code review --- tests/unit/test_client_api.py | 2 +- tests/unit/test_client_initialization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_client_api.py b/tests/unit/test_client_api.py index e76e30da..f6cb20b4 100644 --- a/tests/unit/test_client_api.py +++ b/tests/unit/test_client_api.py @@ -19,7 +19,7 @@ async def mock_client(self): browserbase_session_id="test-session-123", api_key="test-api-key", project_id="test-project-id", - model_client_options={"apiKey": "test-model-api-key"} + model_api_key="test-model-api-key", ) return client diff --git a/tests/unit/test_client_initialization.py b/tests/unit/test_client_initialization.py index 8e8fb518..afec5c6b 100644 --- a/tests/unit/test_client_initialization.py +++ b/tests/unit/test_client_initialization.py @@ -23,7 +23,7 @@ def test_init_with_direct_params(self): browserbase_session_id="test-session", api_key="test-api-key", project_id="test-project-id", - model_client_options={"apiKey": "test-model-api-key"}, + model_api_key="test-model-api-key", verbose=2, ) From fe4dcb8d652b8442d1aa0b225d07ec93c9553512 Mon Sep 17 00:00:00 2001 From: miguel Date: Tue, 19 Aug 2025 09:00:46 -0700 Subject: [PATCH 5/7] changeset --- .changeset/nostalgic-tireless-kestrel.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/nostalgic-tireless-kestrel.md diff --git a/.changeset/nostalgic-tireless-kestrel.md b/.changeset/nostalgic-tireless-kestrel.md new file mode 100644 index 00000000..e7bd4fb9 --- /dev/null +++ b/.changeset/nostalgic-tireless-kestrel.md @@ -0,0 +1,5 @@ +--- +"stagehand": patch +--- + +Add LLM customization support (eg. api_base) From dac64dcc63a017d4efe9e769875d94f0188276d7 Mon Sep 17 00:00:00 2001 From: miguel Date: Tue, 19 Aug 2025 10:53:49 -0700 Subject: [PATCH 6/7] accept apiKey in model_client_options --- stagehand/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stagehand/main.py b/stagehand/main.py index 2681489a..e7ff8cf5 100644 --- a/stagehand/main.py +++ b/stagehand/main.py @@ -185,7 +185,8 @@ def __init__( self.local_browser_launch_options = ( self.config.local_browser_launch_options or {} ) - self.model_client_options["apiKey"] = self.model_api_key + if self.model_api_key: + self.model_client_options["apiKey"] = self.model_api_key # Handle browserbase session create params self.browserbase_session_create_params = make_serializable( From 75cde34e8bafaeb0cab91a47ebe2d90747981b04 Mon Sep 17 00:00:00 2001 From: miguel Date: Tue, 19 Aug 2025 10:57:06 -0700 Subject: [PATCH 7/7] passthrough apiKey if none in model_api_key --- stagehand/main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/stagehand/main.py b/stagehand/main.py index e7ff8cf5..03d6d45f 100644 --- a/stagehand/main.py +++ b/stagehand/main.py @@ -187,6 +187,9 @@ def __init__( ) if self.model_api_key: self.model_client_options["apiKey"] = self.model_api_key + else: + if "apiKey" in self.model_client_options: + self.model_api_key = self.model_client_options["apiKey"] # Handle browserbase session create params self.browserbase_session_create_params = make_serializable(