Skip to content

Commit 52bc809

Browse files
committed
refactor: extract provider stream wrappers
1 parent 6094035 commit 52bc809

File tree

4 files changed

+597
-676
lines changed

4 files changed

+597
-676
lines changed
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
import type { StreamFn } from "@mariozechner/pi-agent-core";
2+
import { streamSimple } from "@mariozechner/pi-ai";
3+
import {
4+
requiresOpenAiCompatibleAnthropicToolPayload,
5+
usesOpenAiFunctionAnthropicToolSchema,
6+
usesOpenAiStringModeAnthropicToolChoice,
7+
} from "../provider-capabilities.js";
8+
import { log } from "./logger.js";
9+
10+
const ANTHROPIC_CONTEXT_1M_BETA = "context-1m-2025-08-07";
11+
const ANTHROPIC_1M_MODEL_PREFIXES = ["claude-opus-4", "claude-sonnet-4"] as const;
12+
const PI_AI_DEFAULT_ANTHROPIC_BETAS = [
13+
"fine-grained-tool-streaming-2025-05-14",
14+
"interleaved-thinking-2025-05-14",
15+
] as const;
16+
const PI_AI_OAUTH_ANTHROPIC_BETAS = [
17+
"claude-code-20250219",
18+
"oauth-2025-04-20",
19+
...PI_AI_DEFAULT_ANTHROPIC_BETAS,
20+
] as const;
21+
22+
type CacheRetention = "none" | "short" | "long";
23+
24+
function isAnthropic1MModel(modelId: string): boolean {
25+
const normalized = modelId.trim().toLowerCase();
26+
return ANTHROPIC_1M_MODEL_PREFIXES.some((prefix) => normalized.startsWith(prefix));
27+
}
28+
29+
function parseHeaderList(value: unknown): string[] {
30+
if (typeof value !== "string") {
31+
return [];
32+
}
33+
return value
34+
.split(",")
35+
.map((item) => item.trim())
36+
.filter(Boolean);
37+
}
38+
39+
function mergeAnthropicBetaHeader(
40+
headers: Record<string, string> | undefined,
41+
betas: string[],
42+
): Record<string, string> {
43+
const merged = { ...headers };
44+
const existingKey = Object.keys(merged).find((key) => key.toLowerCase() === "anthropic-beta");
45+
const existing = existingKey ? parseHeaderList(merged[existingKey]) : [];
46+
const values = Array.from(new Set([...existing, ...betas]));
47+
const key = existingKey ?? "anthropic-beta";
48+
merged[key] = values.join(",");
49+
return merged;
50+
}
51+
52+
function isAnthropicOAuthApiKey(apiKey: unknown): boolean {
53+
return typeof apiKey === "string" && apiKey.includes("sk-ant-oat");
54+
}
55+
56+
function requiresAnthropicToolPayloadCompatibilityForModel(model: {
57+
api?: unknown;
58+
provider?: unknown;
59+
compat?: unknown;
60+
}): boolean {
61+
if (model.api !== "anthropic-messages") {
62+
return false;
63+
}
64+
65+
if (
66+
typeof model.provider === "string" &&
67+
requiresOpenAiCompatibleAnthropicToolPayload(model.provider)
68+
) {
69+
return true;
70+
}
71+
72+
if (!model.compat || typeof model.compat !== "object" || Array.isArray(model.compat)) {
73+
return false;
74+
}
75+
76+
return (
77+
(model.compat as { requiresOpenAiAnthropicToolPayload?: unknown })
78+
.requiresOpenAiAnthropicToolPayload === true
79+
);
80+
}
81+
82+
function usesOpenAiFunctionAnthropicToolSchemaForModel(model: {
83+
provider?: unknown;
84+
compat?: unknown;
85+
}): boolean {
86+
if (typeof model.provider === "string" && usesOpenAiFunctionAnthropicToolSchema(model.provider)) {
87+
return true;
88+
}
89+
if (!model.compat || typeof model.compat !== "object" || Array.isArray(model.compat)) {
90+
return false;
91+
}
92+
return (
93+
(model.compat as { requiresOpenAiAnthropicToolPayload?: unknown })
94+
.requiresOpenAiAnthropicToolPayload === true
95+
);
96+
}
97+
98+
function usesOpenAiStringModeAnthropicToolChoiceForModel(model: {
99+
provider?: unknown;
100+
compat?: unknown;
101+
}): boolean {
102+
if (
103+
typeof model.provider === "string" &&
104+
usesOpenAiStringModeAnthropicToolChoice(model.provider)
105+
) {
106+
return true;
107+
}
108+
if (!model.compat || typeof model.compat !== "object" || Array.isArray(model.compat)) {
109+
return false;
110+
}
111+
return (
112+
(model.compat as { requiresOpenAiAnthropicToolPayload?: unknown })
113+
.requiresOpenAiAnthropicToolPayload === true
114+
);
115+
}
116+
117+
function normalizeOpenAiFunctionAnthropicToolDefinition(
118+
tool: unknown,
119+
): Record<string, unknown> | undefined {
120+
if (!tool || typeof tool !== "object" || Array.isArray(tool)) {
121+
return undefined;
122+
}
123+
124+
const toolObj = tool as Record<string, unknown>;
125+
if (toolObj.function && typeof toolObj.function === "object") {
126+
return toolObj;
127+
}
128+
129+
const rawName = typeof toolObj.name === "string" ? toolObj.name.trim() : "";
130+
if (!rawName) {
131+
return toolObj;
132+
}
133+
134+
const functionSpec: Record<string, unknown> = {
135+
name: rawName,
136+
parameters:
137+
toolObj.input_schema && typeof toolObj.input_schema === "object"
138+
? toolObj.input_schema
139+
: toolObj.parameters && typeof toolObj.parameters === "object"
140+
? toolObj.parameters
141+
: { type: "object", properties: {} },
142+
};
143+
144+
if (typeof toolObj.description === "string" && toolObj.description.trim()) {
145+
functionSpec.description = toolObj.description;
146+
}
147+
if (typeof toolObj.strict === "boolean") {
148+
functionSpec.strict = toolObj.strict;
149+
}
150+
151+
return {
152+
type: "function",
153+
function: functionSpec,
154+
};
155+
}
156+
157+
function normalizeOpenAiStringModeAnthropicToolChoice(toolChoice: unknown): unknown {
158+
if (!toolChoice || typeof toolChoice !== "object" || Array.isArray(toolChoice)) {
159+
return toolChoice;
160+
}
161+
162+
const choice = toolChoice as Record<string, unknown>;
163+
if (choice.type === "auto") {
164+
return "auto";
165+
}
166+
if (choice.type === "none") {
167+
return "none";
168+
}
169+
if (choice.type === "required" || choice.type === "any") {
170+
return "required";
171+
}
172+
if (choice.type === "tool" && typeof choice.name === "string" && choice.name.trim()) {
173+
return {
174+
type: "function",
175+
function: { name: choice.name.trim() },
176+
};
177+
}
178+
179+
return toolChoice;
180+
}
181+
182+
export function resolveCacheRetention(
183+
extraParams: Record<string, unknown> | undefined,
184+
provider: string,
185+
): CacheRetention | undefined {
186+
const isAnthropicDirect = provider === "anthropic";
187+
const hasBedrockOverride =
188+
extraParams?.cacheRetention !== undefined || extraParams?.cacheControlTtl !== undefined;
189+
const isAnthropicBedrock = provider === "amazon-bedrock" && hasBedrockOverride;
190+
191+
if (!isAnthropicDirect && !isAnthropicBedrock) {
192+
return undefined;
193+
}
194+
195+
const newVal = extraParams?.cacheRetention;
196+
if (newVal === "none" || newVal === "short" || newVal === "long") {
197+
return newVal;
198+
}
199+
200+
const legacy = extraParams?.cacheControlTtl;
201+
if (legacy === "5m") {
202+
return "short";
203+
}
204+
if (legacy === "1h") {
205+
return "long";
206+
}
207+
208+
return isAnthropicDirect ? "short" : undefined;
209+
}
210+
211+
export function resolveAnthropicBetas(
212+
extraParams: Record<string, unknown> | undefined,
213+
provider: string,
214+
modelId: string,
215+
): string[] | undefined {
216+
if (provider !== "anthropic") {
217+
return undefined;
218+
}
219+
220+
const betas = new Set<string>();
221+
const configured = extraParams?.anthropicBeta;
222+
if (typeof configured === "string" && configured.trim()) {
223+
betas.add(configured.trim());
224+
} else if (Array.isArray(configured)) {
225+
for (const beta of configured) {
226+
if (typeof beta === "string" && beta.trim()) {
227+
betas.add(beta.trim());
228+
}
229+
}
230+
}
231+
232+
if (extraParams?.context1m === true) {
233+
if (isAnthropic1MModel(modelId)) {
234+
betas.add(ANTHROPIC_CONTEXT_1M_BETA);
235+
} else {
236+
log.warn(`ignoring context1m for non-opus/sonnet model: ${provider}/${modelId}`);
237+
}
238+
}
239+
240+
return betas.size > 0 ? [...betas] : undefined;
241+
}
242+
243+
export function createAnthropicBetaHeadersWrapper(
244+
baseStreamFn: StreamFn | undefined,
245+
betas: string[],
246+
): StreamFn {
247+
const underlying = baseStreamFn ?? streamSimple;
248+
return (model, context, options) => {
249+
const isOauth = isAnthropicOAuthApiKey(options?.apiKey);
250+
const requestedContext1m = betas.includes(ANTHROPIC_CONTEXT_1M_BETA);
251+
const effectiveBetas =
252+
isOauth && requestedContext1m
253+
? betas.filter((beta) => beta !== ANTHROPIC_CONTEXT_1M_BETA)
254+
: betas;
255+
if (isOauth && requestedContext1m) {
256+
log.warn(
257+
`ignoring context1m for OAuth token auth on ${model.provider}/${model.id}; Anthropic rejects context-1m beta with OAuth auth`,
258+
);
259+
}
260+
261+
const piAiBetas = isOauth
262+
? (PI_AI_OAUTH_ANTHROPIC_BETAS as readonly string[])
263+
: (PI_AI_DEFAULT_ANTHROPIC_BETAS as readonly string[]);
264+
const allBetas = [...new Set([...piAiBetas, ...effectiveBetas])];
265+
return underlying(model, context, {
266+
...options,
267+
headers: mergeAnthropicBetaHeader(options?.headers, allBetas),
268+
});
269+
};
270+
}
271+
272+
export function createAnthropicToolPayloadCompatibilityWrapper(
273+
baseStreamFn: StreamFn | undefined,
274+
): StreamFn {
275+
const underlying = baseStreamFn ?? streamSimple;
276+
return (model, context, options) => {
277+
const originalOnPayload = options?.onPayload;
278+
return underlying(model, context, {
279+
...options,
280+
onPayload: (payload) => {
281+
if (
282+
payload &&
283+
typeof payload === "object" &&
284+
requiresAnthropicToolPayloadCompatibilityForModel(model)
285+
) {
286+
const payloadObj = payload as Record<string, unknown>;
287+
if (
288+
Array.isArray(payloadObj.tools) &&
289+
usesOpenAiFunctionAnthropicToolSchemaForModel(model)
290+
) {
291+
payloadObj.tools = payloadObj.tools
292+
.map((tool) => normalizeOpenAiFunctionAnthropicToolDefinition(tool))
293+
.filter((tool): tool is Record<string, unknown> => !!tool);
294+
}
295+
if (usesOpenAiStringModeAnthropicToolChoiceForModel(model)) {
296+
payloadObj.tool_choice = normalizeOpenAiStringModeAnthropicToolChoice(
297+
payloadObj.tool_choice,
298+
);
299+
}
300+
}
301+
originalOnPayload?.(payload);
302+
},
303+
});
304+
};
305+
}
306+
307+
export function createBedrockNoCacheWrapper(baseStreamFn: StreamFn | undefined): StreamFn {
308+
const underlying = baseStreamFn ?? streamSimple;
309+
return (model, context, options) =>
310+
underlying(model, context, {
311+
...options,
312+
cacheRetention: "none",
313+
});
314+
}
315+
316+
export function isAnthropicBedrockModel(modelId: string): boolean {
317+
const normalized = modelId.toLowerCase();
318+
return normalized.includes("anthropic.claude") || normalized.includes("anthropic/claude");
319+
}

0 commit comments

Comments
 (0)