Skip to content

Commit

Permalink
Add better error handling to ai binding
Browse files Browse the repository at this point in the history
  • Loading branch information
G4brym committed May 10, 2024
1 parent 9bb3e21 commit 75baabe
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 27 deletions.
43 changes: 29 additions & 14 deletions src/cloudflare/internal/ai-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,31 @@ interface Fetcher {
fetch: typeof fetch
}

export type SessionOptions = { // Deprecated, do not use this
interface AiError {
internalCode: number
message: string
name: string
description: string
}

export interface SessionOptions { // Deprecated, do not use this
extraHeaders?: object;
};
}

export type AiOptions = {
export interface AiOptions {
debug?: boolean;
prefix?: string;
extraHeaders?: object;
/*
* @deprecated this option is deprecated, do not use this
*/
sessionOptions?: SessionOptions;
};
}

export class InferenceUpstreamError extends Error {
public constructor(message: string) {
public constructor(message: string, name = "InferenceUpstreamError") {
super(message);
this.name = "InferenceUpstreamError";
this.name = name;
}
}

Expand Down Expand Up @@ -52,37 +59,36 @@ export class Ai {
const body = JSON.stringify({
inputs,
options: {
debug: this.options?.debug,
debug: this.options.debug,
},
});

const fetchOptions = {
method: "POST",
body: body,
headers: {
...(this.options?.sessionOptions?.extraHeaders || {}),
...(this.options?.extraHeaders || {}),
...(this.options.sessionOptions?.extraHeaders || {}),
...(this.options.extraHeaders || {}),
"content-type": "application/json",
// 'content-encoding': 'gzip',
"cf-consn-sdk-version": "2.0.0",
"cf-consn-model-id": `${this.options.prefix ? `${this.options.prefix}:` : ""}${model}`,
},
};

const res = await this.fetcher.fetch("http://workers-binding.ai/run?version=2", fetchOptions);
const res = await this.fetcher.fetch("http://workers-binding.ai/run?version=3", fetchOptions);

this.lastRequestId = res.headers.get("cf-ai-req-id");

if (inputs['stream']) {
if (!res.ok) {
throw new InferenceUpstreamError(await res.text());
throw await this._parseError(res)
}

return res.body;

} else {
if (!res.ok || !res.body) {
throw new InferenceUpstreamError(await res.text());
throw await this._parseError(res)
}

const contentType = res.headers.get("content-type");
Expand All @@ -98,9 +104,18 @@ export class Ai {
/*
* @deprecated this method is deprecated, do not use this
*/
public getLogs(): Array<string> {
public getLogs(): string[] {
return []
}

private async _parseError(res: Response): Promise<InferenceUpstreamError> {
try {
const content = (await res.json()) as AiError
return new InferenceUpstreamError(`${content.internalCode}: ${content.description}`, content.name);
} catch {
return new InferenceUpstreamError(await res.text());
}
}
}

export default function makeBinding(env: { fetcher: Fetcher }): Ai {
Expand Down
10 changes: 5 additions & 5 deletions src/cloudflare/internal/d1-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ interface Fetcher {
fetch: typeof fetch
}

type D1Response = {
interface D1Response {
success: true
meta: Record<string, unknown>
error?: never
Expand All @@ -16,11 +16,11 @@ type D1Result<T = unknown> = D1Response & {
results: T[]
}

type D1RawOptions = {
interface D1RawOptions {
columnNames?: boolean
}

type D1UpstreamFailure = {
interface D1UpstreamFailure {
results?: never
error: string
success: false
Expand All @@ -41,12 +41,12 @@ type D1UpstreamSuccess<T = unknown> =

type D1UpstreamResponse<T = unknown> = D1UpstreamSuccess<T> | D1UpstreamFailure

type D1ExecResult = {
interface D1ExecResult {
count: number
duration: number
}

type SQLError = {
interface SQLError {
error: string
}

Expand Down
6 changes: 3 additions & 3 deletions src/cloudflare/internal/sockets.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ export class Socket {
public startTls(options: TlsOptions): Socket
}

export type TlsOptions = {
export interface TlsOptions {
expectedServerHostname?: string
}

export type SocketAddress = {
export interface SocketAddress {
hostname: string
port: number
}

export type SocketOptions = {
export interface SocketOptions {
secureTransport?: 'off' | 'on' | 'starttls'
allowHalfOpen?: boolean
}
Expand Down
36 changes: 36 additions & 0 deletions src/cloudflare/internal/test/ai/ai-api-test.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@

import * as assert from 'node:assert'

async function assertThrowsAsynchronously(test, error) {
try {
await test();
} catch(e) {
if ()
return "everything is fine";
}
throw new assert.AssertionError("Missing rejection" + (error ? " with "+error.name : ""));
}

export const tests = {
async test(_, env) {
{
Expand Down Expand Up @@ -33,5 +43,31 @@ export const tests = {
})
assert.deepStrictEqual(await resp.json(), { response: 'model response' });
}

{
// Test error response
try {
await env.ai.run('inputErrorModel', {prompt: 'test'})
} catch(e) {
assert.deepEqual({
name: e.name, message: e.message
}, {
name: 'InvalidInput',
message: '1001: prompt and messages are mutually exclusive',
})
}
}

{
// Test error properties
const err = await env.ai._parseError(Response.json({
internalCode: 1001,
message: "InvalidInput: prompt and messages are mutually exclusive",
name: "InvalidInput",
description: "prompt and messages are mutually exclusive"
}))
assert.equal(err.name, 'InvalidInput')
assert.equal(err.message, '1001: prompt and messages are mutually exclusive')
}
},
}
15 changes: 15 additions & 0 deletions src/cloudflare/internal/test/ai/ai-mock.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@ export default {
})
}

if (modelName === 'inputErrorModel') {
return Response.json({
internalCode: 1001,
message: "InvalidInput: prompt and messages are mutually exclusive",
name: "InvalidInput",
description: "prompt and messages are mutually exclusive"
}, {
status: 400,
headers: {
'content-type': 'application/json',
...respHeaders
}
})
}

return Response.json({response: 'model response'}, {
headers: {
'content-type': 'application/json',
Expand Down
7 changes: 2 additions & 5 deletions src/cloudflare/internal/vectorize.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,14 @@ type VectorizeVectorMetadataFilterOp = "$eq" | "$ne";
/**
* Filter criteria for vector metadata used to limit the retrieved query result set.
*/
type VectorizeVectorMetadataFilter = {
[field: string]:
| Exclude<VectorizeVectorMetadataValue, string[]>
type VectorizeVectorMetadataFilter = Record<string, | Exclude<VectorizeVectorMetadataValue, string[]>
| null
| {
[Op in VectorizeVectorMetadataFilterOp]?: Exclude<
VectorizeVectorMetadataValue,
string[]
> | null;
};
};
}>;

/**
* Supported distance metrics for an index.
Expand Down

0 comments on commit 75baabe

Please sign in to comment.