Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/vertex-model-forwarding.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@browserbasehq/stagehand": minor
"@browserbasehq/stagehand-server-v3": minor
---

Forward constructor and request model configuration when initializing API-backed sessions.
119 changes: 88 additions & 31 deletions packages/core/lib/v3/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ interface ClientSessionStartParams extends Api.SessionStartRequest {
* Optional: when omitted, requests are sent without the x-model-api-key header
* and the server is expected to handle model authentication on its own. */
modelApiKey?: string;
/** Default model config for later action requests. Not sent to /sessions/start. */
defaultModelConfig?: ModelConfiguration;
}

/**
Expand All @@ -115,6 +117,11 @@ type ApiResponse<T> =
| { success: true; data: T }
| { success: false; message: string };

type PreparedModelConfig = { modelName: string; apiKey?: string } & Record<
string,
unknown
>;

/**
* Union of all API request body types for type-safe execute() calls
*/
Expand Down Expand Up @@ -180,6 +187,7 @@ export class StagehandAPIClient {
private sessionId?: string;
private modelApiKey?: string;
private modelProvider?: string;
private defaultModelConfig?: PreparedModelConfig;
private region?: BrowserbaseRegion;
private logger: (message: LogLine) => void;
private fetchWithCookies;
Expand All @@ -205,6 +213,7 @@ export class StagehandAPIClient {
async init({
modelName,
modelApiKey,
defaultModelConfig,
domSettleTimeoutMs,
verbose,
systemPrompt,
Expand All @@ -218,6 +227,9 @@ export class StagehandAPIClient {
this.modelProvider = modelName?.includes("/")
? modelName.split("/")[0]
: undefined;
this.defaultModelConfig = defaultModelConfig
? this.prepareModelConfig(defaultModelConfig)
: undefined;

// Store the region for multi-region API URL resolution
this.region = browserbaseSessionCreateParams?.region;
Expand Down Expand Up @@ -288,12 +300,18 @@ export class StagehandAPIClient {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { page: _, serverCache: enableCache, ...restOptions } = options;
serverCache = enableCache;
if (restOptions.model) {
restOptions.model = this.prepareModelConfig(restOptions.model);
} else if (this.defaultModelConfig) {
restOptions.model = this.getDefaultModelConfig();
}
if (Object.keys(restOptions).length > 0) {
if (restOptions.model) {
restOptions.model = this.prepareModelConfig(restOptions.model);
}
wireOptions = restOptions as unknown as Api.ActRequest["options"];
}
} else if (this.defaultModelConfig) {
wireOptions = {
model: this.getDefaultModelConfig(),
} as unknown as Api.ActRequest["options"];
}

// Build wire-format request body
Expand Down Expand Up @@ -326,12 +344,18 @@ export class StagehandAPIClient {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { page: _, serverCache: enableCache, ...restOptions } = options;
serverCache = enableCache;
if (restOptions.model) {
restOptions.model = this.prepareModelConfig(restOptions.model);
} else if (this.defaultModelConfig) {
restOptions.model = this.getDefaultModelConfig();
}
if (Object.keys(restOptions).length > 0) {
if (restOptions.model) {
restOptions.model = this.prepareModelConfig(restOptions.model);
}
wireOptions = restOptions as unknown as Api.ExtractRequest["options"];
}
} else if (this.defaultModelConfig) {
wireOptions = {
model: this.getDefaultModelConfig(),
} as unknown as Api.ExtractRequest["options"];
}

// Build wire-format request body
Expand Down Expand Up @@ -361,12 +385,18 @@ export class StagehandAPIClient {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { page: _, serverCache: enableCache, ...restOptions } = options;
serverCache = enableCache;
if (restOptions.model) {
restOptions.model = this.prepareModelConfig(restOptions.model);
} else if (this.defaultModelConfig) {
restOptions.model = this.getDefaultModelConfig();
}
if (Object.keys(restOptions).length > 0) {
if (restOptions.model) {
restOptions.model = this.prepareModelConfig(restOptions.model);
}
wireOptions = restOptions as unknown as Api.ObserveRequest["options"];
}
} else if (this.defaultModelConfig) {
wireOptions = {
model: this.getDefaultModelConfig(),
} as unknown as Api.ObserveRequest["options"];
}

// Build wire-format request body
Expand All @@ -388,7 +418,22 @@ export class StagehandAPIClient {
options?: Api.NavigateRequest["options"],
frameId?: string,
): Promise<SerializableResponse | null> {
const requestBody: Api.NavigateRequest = { url, options, frameId };
const publicOptions = { ...(options ?? {}) } as NonNullable<
Api.NavigateRequest["options"]
> & { model?: unknown };
delete publicOptions.model;

const wireOptions = {
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
...publicOptions,
...(this.defaultModelConfig
? { model: this.getDefaultModelConfig() }
: {}),
};
const requestBody: Api.NavigateRequest = {
url,
options: Object.keys(wireOptions).length > 0 ? wireOptions : undefined,
frameId,
};

return this.execute<SerializableResponse | null>({
method: "navigate",
Expand Down Expand Up @@ -425,7 +470,7 @@ export class StagehandAPIClient {
cua: agentConfig.mode === undefined ? agentConfig.cua : undefined,
model: agentConfig.model
? this.prepareModelConfig(agentConfig.model)
: undefined,
: this.getDefaultModelConfig(),
executionModel: agentConfig.executionModel
? this.prepareModelConfig(agentConfig.executionModel)
: undefined,
Expand Down Expand Up @@ -605,40 +650,52 @@ export class StagehandAPIClient {
* In API mode, we only attempt to load an API key from env vars when the
* model provider differs from the one used to init the session.
*/
private prepareModelConfig(
model: ModelConfiguration,
): { modelName: string; apiKey?: string } & Record<string, unknown> {
private prepareModelConfig(model: ModelConfiguration): PreparedModelConfig {
if (typeof model === "string") {
// Extract provider from model string (e.g., "openai/gpt-5-nano" -> "openai")
const provider = model.includes("/") ? model.split("/")[0] : undefined;
const provider = this.getModelProvider(model);
const inheritedDefault =
provider && provider === this.modelProvider
? this.getDefaultModelConfig()
: undefined;
const apiKey =
provider && provider !== this.modelProvider
? (loadApiKeyFromEnv(provider, this.logger) ?? this.modelApiKey)
: this.modelApiKey;
return {
...inheritedDefault,
modelName: model,
...(apiKey ? { apiKey } : {}),
};
}

if (!model.apiKey) {
const provider = model.modelName?.includes("/")
? model.modelName.split("/")[0]
const provider = this.getModelProvider(model.modelName);
const inheritedDefault =
provider && provider === this.modelProvider
? this.getDefaultModelConfig()
: undefined;
const apiKey =
provider && provider !== this.modelProvider
? (loadApiKeyFromEnv(provider, this.logger) ?? this.modelApiKey)
: this.modelApiKey;
return {
...model,
...(apiKey ? { apiKey } : {}),
};
}
const apiKey =
!model.apiKey && provider && provider !== this.modelProvider
? (loadApiKeyFromEnv(provider, this.logger) ?? this.modelApiKey)
: !model.apiKey
? this.modelApiKey
: undefined;

return {
...inheritedDefault,
...model,
...(apiKey ? { apiKey } : {}),
} as PreparedModelConfig;
}

private getDefaultModelConfig(): PreparedModelConfig | undefined {
return this.defaultModelConfig
? ({ ...this.defaultModelConfig } as PreparedModelConfig)
: undefined;
}

return model as { modelName: string; apiKey: string } & Record<
string,
unknown
>;
private getModelProvider(modelName: string | undefined): string | undefined {
return modelName?.includes("/") ? modelName.split("/")[0] : undefined;
}

private consumeFinishedEventData<T>(): T | null {
Expand Down
48 changes: 37 additions & 11 deletions packages/core/lib/v3/llm/LLMProvider.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import type { LanguageModelV2Middleware } from "@ai-sdk/provider";
import {
ExperimentalNotConfiguredError,
UnsupportedAISDKModelProviderError,
UnsupportedModelError,
UnsupportedModelProviderError,
Expand Down Expand Up @@ -70,6 +69,37 @@ const AISDKProvidersWithAPIKey: Record<string, AISDKCustomProvider> = {
gateway: createGateway,
};

type AISDKProviderClientOptions = ClientOptions & Record<string, unknown>;

export function toAISDKClientOptions(
subProvider: string,
clientOptions?: ClientOptions,
): AISDKProviderClientOptions | undefined {
if (!clientOptions || subProvider !== "vertex") {
return clientOptions as AISDKProviderClientOptions | undefined;
}

const { auth, providerOptions, ...rest } = clientOptions;
const vertexOptions = providerOptions?.vertex;

return {
...rest,
...(vertexOptions ?? {}),
...(auth?.type === "googleServiceAccount"
? {
googleAuthOptions: {
credentials: auth.credentials,
...(auth.scopes ? { scopes: auth.scopes } : {}),
...(auth.projectId ? { projectId: auth.projectId } : {}),
...(auth.universeDomain
? { universeDomain: auth.universeDomain }
: {}),
},
}
: {}),
};
}

const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = {
"gpt-4.1": "openai",
"gpt-4.1-mini": "openai",
Expand Down Expand Up @@ -106,9 +136,12 @@ export function getAISDKLanguageModel(
clientOptions?: ClientOptions,
middleware?: LanguageModelV2Middleware,
) {
const aiSdkClientOptions = toAISDKClientOptions(subProvider, clientOptions);
const hasValidOptions =
clientOptions &&
Object.values(clientOptions).some((v) => v !== undefined && v !== null);
aiSdkClientOptions &&
Object.values(aiSdkClientOptions).some(
(v) => v !== undefined && v !== null,
);

let model;
if (hasValidOptions) {
Expand All @@ -119,7 +152,7 @@ export function getAISDKLanguageModel(
Object.keys(AISDKProvidersWithAPIKey),
);
}
const provider = creator(clientOptions);
const provider = creator(aiSdkClientOptions as ClientOptions);
model = provider(subModelName);
} else {
const provider = AISDKProviders[subProvider];
Expand Down Expand Up @@ -163,13 +196,6 @@ export class LLMProvider {
const firstSlashIndex = modelName.indexOf("/");
const subProvider = modelName.substring(0, firstSlashIndex);
const subModelName = modelName.substring(firstSlashIndex + 1);
if (
subProvider === "vertex" &&
!options?.disableAPI &&
!options?.experimental
) {
throw new ExperimentalNotConfiguredError("Vertex provider");
}

const effectiveMiddleware = options?.middleware ?? this.middleware;
const languageModel = getAISDKLanguageModel(
Expand Down
Loading
Loading