Skip to content

Commit edc57ac

Browse files
Adding initial support for LLM Customisation (#132) (#184)
* Adding initial support for LLM Customisation (#132) * Adding support for LLM Customisation * Added back model_api_key config * Fixed README * Update README.md Co-authored-by: Miguel <36487034+miguelg719@users.noreply.github.com> * Update stagehand/llm/client.py --------- Co-authored-by: Miguel <36487034+miguelg719@users.noreply.github.com> * Update stagehand/config.py * minor updates * Apply suggestions from code review * changeset * accept apiKey in model_client_options * passthrough apiKey if none in model_api_key --------- Co-authored-by: Sanveer Singh Osahan <sanveer.singh@atlan.com>
1 parent 61ed2e0 commit edc57ac

File tree

6 files changed

+50
-5
lines changed

6 files changed

+50
-5
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"stagehand": patch
3+
---
4+
5+
Add LLM customization support (eg. api_base)

stagehand/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class StagehandConfig(BaseModel):
2020
browserbase_session_id (Optional[str]): Session ID for resuming Browserbase sessions.
2121
model_name (Optional[str]): Name of the model to use.
2222
model_api_key (Optional[str]): Model API key.
23+
model_client_options (Optional[dict[str, Any]]): Options for the model client.
2324
logger (Optional[Callable[[Any], None]]): Custom logging function.
2425
verbose (Optional[int]): Verbosity level for logs (1=minimal, 2=medium, 3=detailed).
2526
use_rich_logging (bool): Whether to use Rich for colorized logging.
@@ -50,6 +51,11 @@ class StagehandConfig(BaseModel):
5051
model_api_key: Optional[str] = Field(
5152
None, alias="modelApiKey", description="Model API key"
5253
)
54+
model_client_options: Optional[dict[str, Any]] = Field(
55+
None,
56+
alias="modelClientOptions",
57+
description="Configuration options for the language model client (i.e. api_base)",
58+
)
5359
verbose: Optional[int] = Field(
5460
1,
5561
description="Verbosity level for logs: 0=minimal (ERROR), 1=medium (INFO), 2=detailed (DEBUG)",

stagehand/llm/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
setattr(litellm, key, value)
5555
self.logger.debug(f"Set global litellm.{key}", category="llm")
5656
# Handle common aliases or expected config names if necessary
57-
elif key == "api_base": # Example: map api_base if needed
57+
elif key == "api_base" or key == "baseURL":
5858
litellm.api_base = value
5959
self.logger.debug(
6060
f"Set global litellm.api_base to {value}", category="llm"

stagehand/main.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,11 @@ def __init__(
161161

162162
# Handle non-config parameters
163163
self.api_url = self.config.api_url
164+
165+
# Handle model-related settings
166+
self.model_client_options = self.config.model_client_options or {}
164167
self.model_api_key = self.config.model_api_key or os.getenv("MODEL_API_KEY")
168+
165169
self.model_name = self.config.model_name
166170

167171
# Extract frequently used values from config for convenience
@@ -181,11 +185,11 @@ def __init__(
181185
self.local_browser_launch_options = (
182186
self.config.local_browser_launch_options or {}
183187
)
184-
185-
# Handle model-related settings
186-
self.model_client_options = {}
187-
if self.model_api_key and "apiKey" not in self.model_client_options:
188+
if self.model_api_key:
188189
self.model_client_options["apiKey"] = self.model_api_key
190+
else:
191+
if "apiKey" in self.model_client_options:
192+
self.model_api_key = self.model_client_options["apiKey"]
189193

190194
# Handle browserbase session create params
191195
self.browserbase_session_create_params = make_serializable(

tests/unit/llm/test_llm_integration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def test_llm_client_with_custom_options(self):
4040
api_key="test-key",
4141
default_model="gpt-4o-mini",
4242
stagehand_logger=StagehandLogger(),
43+
api_base="https://test-api-base.com",
4344
)
4445

4546
assert client.default_model == "gpt-4o-mini"

tests/unit/test_client_initialization.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,32 @@ async def mock_create_session():
228228
# Call _create_session and expect error
229229
with pytest.raises(RuntimeError, match="Invalid response format"):
230230
await client._create_session()
231+
232+
@mock.patch.dict(os.environ, {"MODEL_API_KEY": "test-model-api-key"}, clear=True)
233+
def test_init_with_model_api_key_in_env(self):
234+
config = StagehandConfig(env="LOCAL")
235+
client = Stagehand(config=config)
236+
assert client.model_api_key == "test-model-api-key"
237+
238+
def test_init_with_custom_llm(self):
239+
config = StagehandConfig(
240+
env="LOCAL",
241+
model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"}
242+
)
243+
client = Stagehand(config=config)
244+
assert client.model_api_key == "custom-llm-key"
245+
assert client.model_client_options["apiKey"] == "custom-llm-key"
246+
assert client.model_client_options["baseURL"] == "https://custom-llm.com"
247+
248+
def test_init_with_custom_llm_override(self):
249+
config = StagehandConfig(
250+
env="LOCAL",
251+
model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"}
252+
)
253+
client = Stagehand(
254+
config=config,
255+
model_client_options={"apiKey": "override-llm-key", "baseURL": "https://override-llm.com"}
256+
)
257+
assert client.model_api_key == "override-llm-key"
258+
assert client.model_client_options["apiKey"] == "override-llm-key"
259+
assert client.model_client_options["baseURL"] == "https://override-llm.com"

0 commit comments

Comments
 (0)