diff --git a/src/cloudflare/internal/ai-api.ts b/src/cloudflare/internal/ai-api.ts index 48045179d2a..bd2d715a632 100644 --- a/src/cloudflare/internal/ai-api.ts +++ b/src/cloudflare/internal/ai-api.ts @@ -6,11 +6,18 @@ interface Fetcher { fetch: typeof fetch } -export type SessionOptions = { // Deprecated, do not use this +interface AiError { + internalCode: number + message: string + name: string + description: string +} + +export interface SessionOptions { // Deprecated, do not use this extraHeaders?: object; -}; +} -export type AiOptions = { +export interface AiOptions { debug?: boolean; prefix?: string; extraHeaders?: object; @@ -18,12 +25,12 @@ export type AiOptions = { * @deprecated this option is deprecated, do not use this */ sessionOptions?: SessionOptions; -}; +} export class InferenceUpstreamError extends Error { - public constructor(message: string) { + public constructor(message: string, name = "InferenceUpstreamError") { super(message); - this.name = "InferenceUpstreamError"; + this.name = name; } } @@ -52,7 +59,7 @@ export class Ai { const body = JSON.stringify({ inputs, options: { - debug: this.options?.debug, + debug: this.options.debug, }, }); @@ -60,29 +67,28 @@ export class Ai { method: "POST", body: body, headers: { - ...(this.options?.sessionOptions?.extraHeaders || {}), - ...(this.options?.extraHeaders || {}), + ...(this.options.sessionOptions?.extraHeaders || {}), + ...(this.options.extraHeaders || {}), "content-type": "application/json", - // 'content-encoding': 'gzip', "cf-consn-sdk-version": "2.0.0", "cf-consn-model-id": `${this.options.prefix ? `${this.options.prefix}:` : ""}${model}`, }, }; - const res = await this.fetcher.fetch("http://workers-binding.ai/run?version=2", fetchOptions); + const res = await this.fetcher.fetch("http://workers-binding.ai/run?version=3", fetchOptions); this.lastRequestId = res.headers.get("cf-ai-req-id"); if (inputs['stream']) { if (!res.ok) { - throw new InferenceUpstreamError(await res.text()); + throw await this._parseError(res) } return res.body; } else { if (!res.ok || !res.body) { - throw new InferenceUpstreamError(await res.text()); + throw await this._parseError(res) } const contentType = res.headers.get("content-type"); @@ -98,9 +104,18 @@ export class Ai { /* * @deprecated this method is deprecated, do not use this */ - public getLogs(): Array { + public getLogs(): string[] { return [] } + + private async _parseError(res: Response): Promise { + try { + const content = (await res.json()) as AiError + return new InferenceUpstreamError(`${content.internalCode}: ${content.description}`, content.name); + } catch { + return new InferenceUpstreamError(await res.text()); + } + } } export default function makeBinding(env: { fetcher: Fetcher }): Ai { diff --git a/src/cloudflare/internal/d1-api.ts b/src/cloudflare/internal/d1-api.ts index 4d3386c8793..afc0998675a 100644 --- a/src/cloudflare/internal/d1-api.ts +++ b/src/cloudflare/internal/d1-api.ts @@ -6,7 +6,7 @@ interface Fetcher { fetch: typeof fetch } -type D1Response = { +interface D1Response { success: true meta: Record error?: never @@ -16,11 +16,11 @@ type D1Result = D1Response & { results: T[] } -type D1RawOptions = { +interface D1RawOptions { columnNames?: boolean } -type D1UpstreamFailure = { +interface D1UpstreamFailure { results?: never error: string success: false @@ -41,12 +41,12 @@ type D1UpstreamSuccess = type D1UpstreamResponse = D1UpstreamSuccess | D1UpstreamFailure -type D1ExecResult = { +interface D1ExecResult { count: number duration: number } -type SQLError = { +interface SQLError { error: string } diff --git a/src/cloudflare/internal/sockets.d.ts b/src/cloudflare/internal/sockets.d.ts index 76dd06732f6..f074541e396 100644 --- a/src/cloudflare/internal/sockets.d.ts +++ b/src/cloudflare/internal/sockets.d.ts @@ -12,16 +12,16 @@ export class Socket { public startTls(options: TlsOptions): Socket } -export type TlsOptions = { +export interface TlsOptions { expectedServerHostname?: string } -export type SocketAddress = { +export interface SocketAddress { hostname: string port: number } -export type SocketOptions = { +export interface SocketOptions { secureTransport?: 'off' | 'on' | 'starttls' allowHalfOpen?: boolean } diff --git a/src/cloudflare/internal/test/ai/ai-api-test.js b/src/cloudflare/internal/test/ai/ai-api-test.js index 97f5a06c5b0..aebb98084e6 100644 --- a/src/cloudflare/internal/test/ai/ai-api-test.js +++ b/src/cloudflare/internal/test/ai/ai-api-test.js @@ -4,6 +4,16 @@ import * as assert from 'node:assert' +async function assertThrowsAsynchronously(test, error) { + try { + await test(); + } catch(e) { + if () + return "everything is fine"; + } + throw new assert.AssertionError("Missing rejection" + (error ? " with "+error.name : "")); +} + export const tests = { async test(_, env) { { @@ -33,5 +43,31 @@ export const tests = { }) assert.deepStrictEqual(await resp.json(), { response: 'model response' }); } + + { + // Test error response + try { + await env.ai.run('inputErrorModel', {prompt: 'test'}) + } catch(e) { + assert.deepEqual({ + name: e.name, message: e.message + }, { + name: 'InvalidInput', + message: '1001: prompt and messages are mutually exclusive', + }) + } + } + + { + // Test error properties + const err = await env.ai._parseError(Response.json({ + internalCode: 1001, + message: "InvalidInput: prompt and messages are mutually exclusive", + name: "InvalidInput", + description: "prompt and messages are mutually exclusive" + })) + assert.equal(err.name, 'InvalidInput') + assert.equal(err.message, '1001: prompt and messages are mutually exclusive') + } }, } diff --git a/src/cloudflare/internal/test/ai/ai-mock.js b/src/cloudflare/internal/test/ai/ai-mock.js index ff4cb18e6d9..7b7e1cee197 100644 --- a/src/cloudflare/internal/test/ai/ai-mock.js +++ b/src/cloudflare/internal/test/ai/ai-mock.js @@ -21,6 +21,21 @@ export default { }) } + if (modelName === 'inputErrorModel') { + return Response.json({ + internalCode: 1001, + message: "InvalidInput: prompt and messages are mutually exclusive", + name: "InvalidInput", + description: "prompt and messages are mutually exclusive" + }, { + status: 400, + headers: { + 'content-type': 'application/json', + ...respHeaders + } + }) + } + return Response.json({response: 'model response'}, { headers: { 'content-type': 'application/json', diff --git a/src/cloudflare/internal/vectorize.d.ts b/src/cloudflare/internal/vectorize.d.ts index 87ea095cdfa..f98f30831c0 100644 --- a/src/cloudflare/internal/vectorize.d.ts +++ b/src/cloudflare/internal/vectorize.d.ts @@ -38,17 +38,14 @@ type VectorizeVectorMetadataFilterOp = "$eq" | "$ne"; /** * Filter criteria for vector metadata used to limit the retrieved query result set. */ -type VectorizeVectorMetadataFilter = { - [field: string]: - | Exclude +type VectorizeVectorMetadataFilter = Record | null | { [Op in VectorizeVectorMetadataFilterOp]?: Exclude< VectorizeVectorMetadataValue, string[] > | null; - }; -}; + }>; /** * Supported distance metrics for an index.