Skip to content

Commit

Permalink
(feat, typescript): accept abort signals as request options (#3694)
Browse files Browse the repository at this point in the history
  • Loading branch information
dsinghvi committed May 24, 2024
1 parent c222939 commit b0bcea3
Show file tree
Hide file tree
Showing 311 changed files with 4,799 additions and 56,261 deletions.
12 changes: 12 additions & 0 deletions generators/typescript/sdk/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.20.0-rc0] - 2024-05-20

- Feature: Add `abortSignal` to `RequestOptions`. SDK consumers can now specify an
an arbitrary abort signal that can interrupt the API call.

```ts
const controller = new AbortController();
client.endpoint.call(..., {
abortSignal: controller.signal,
})
```

## [0.19.0] - 2024-05-20

- Feature: Add `inlineFileProperties` configuration to support generating file upload properties
Expand Down
2 changes: 1 addition & 1 deletion generators/typescript/sdk/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.19.0
0.20.0-rc0
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ export class GeneratedSdkClientClassImpl implements GeneratedSdkClientClass {
private static REQUEST_OPTIONS_INTERFACE_NAME = "RequestOptions";
private static IDEMPOTENT_REQUEST_OPTIONS_INTERFACE_NAME = "IdempotentRequestOptions";
private static TIMEOUT_IN_SECONDS_REQUEST_OPTION_PROPERTY_NAME = "timeoutInSeconds";
private static ABORT_SIGNAL_PROPERTY_NAME = "abortSignal";
private static MAX_RETRIES_REQUEST_OPTION_PROPERTY_NAME = "maxRetries";
private static OPTIONS_INTERFACE_NAME = "Options";
private static OPTIONS_PRIVATE_MEMBER = "_options";
Expand Down Expand Up @@ -766,6 +767,11 @@ export class GeneratedSdkClientClassImpl implements GeneratedSdkClientClass {
name: GeneratedSdkClientClassImpl.MAX_RETRIES_REQUEST_OPTION_PROPERTY_NAME,
type: getTextOfTsNode(ts.factory.createKeywordTypeNode(ts.SyntaxKind.NumberKeyword)),
hasQuestionToken: true
},
{
name: GeneratedSdkClientClassImpl.ABORT_SIGNAL_PROPERTY_NAME,
type: getTextOfTsNode(ts.factory.createIdentifier("AbortSignal")),
hasQuestionToken: true
}
]
};
Expand Down Expand Up @@ -1125,6 +1131,18 @@ export class GeneratedSdkClientClassImpl implements GeneratedSdkClientClass {
);
}

public getReferenceToAbortSignal({
referenceToRequestOptions
}: {
referenceToRequestOptions: ts.Expression;
}): ts.Expression {
return ts.factory.createPropertyAccessChain(
referenceToRequestOptions,
ts.factory.createToken(ts.SyntaxKind.QuestionDotToken),
ts.factory.createIdentifier(GeneratedSdkClientClassImpl.ABORT_SIGNAL_PROPERTY_NAME)
);
}

public getReferenceToOptions(): ts.Expression {
return ts.factory.createPropertyAccessExpression(
ts.factory.createThis(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { GeneratedSdkClientClassImpl } from "../GeneratedSdkClientClassImpl";
import { GeneratedEndpointResponse } from "./default/endpoint-response/GeneratedEndpointResponse";
import { buildUrl } from "./utils/buildUrl";
import {
getAbortSignalExpression,
getMaxRetriesExpression,
getRequestOptionsParameter,
getTimeoutExpression
Expand Down Expand Up @@ -160,6 +161,11 @@ export class GeneratedFileDownloadEndpointImplementation implements GeneratedEnd
this.generatedSdkClientClass
)
}),
abortSignal: getAbortSignalExpression({
abortSignalReference: this.generatedSdkClientClass.getReferenceToAbortSignal.bind(
this.generatedSdkClientClass
)
}),
withCredentials: this.includeCredentialsOnCrossOriginRequests,
responseType: visitJavaScriptRuntime(context.targetRuntime, {
browser: () => "blob",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { GeneratedSdkClientClassImpl } from "../GeneratedSdkClientClassImpl";
import { GeneratedEndpointResponse } from "./default/endpoint-response/GeneratedEndpointResponse";
import { buildUrl } from "./utils/buildUrl";
import {
getAbortSignalExpression,
getMaxRetriesExpression,
getRequestOptionsParameter,
getTimeoutExpression
Expand Down Expand Up @@ -131,6 +132,11 @@ export class GeneratedStreamingEndpointImplementation implements GeneratedEndpoi
this.generatedSdkClientClass
)
}),
abortSignal: getAbortSignalExpression({
abortSignalReference: this.generatedSdkClientClass.getReferenceToAbortSignal.bind(
this.generatedSdkClientClass
)
}),
responseType: "streaming",
withCredentials: this.includeCredentialsOnCrossOriginRequests
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { GeneratedEndpointRequest } from "../../endpoint-request/GeneratedEndpoi
import { GeneratedSdkClientClassImpl } from "../../GeneratedSdkClientClassImpl";
import { buildUrl } from "../utils/buildUrl";
import {
getAbortSignalExpression,
getMaxRetriesExpression,
getRequestOptionsParameter,
getTimeoutExpression,
Expand Down Expand Up @@ -206,6 +207,12 @@ export class GeneratedDefaultEndpointImplementation implements GeneratedEndpoint
this.generatedSdkClientClass
)
}),
abortSignal: getAbortSignalExpression({
abortSignalReference: this.generatedSdkClientClass.getReferenceToAbortSignal.bind(
this.generatedSdkClientClass
)
}),

withCredentials: this.includeCredentialsOnCrossOriginRequests
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,13 @@ export const getMaxRetriesExpression = ({
isNullable: true
});
};

export const getAbortSignalExpression = ({
abortSignalReference
}: {
abortSignalReference: (args: { referenceToRequestOptions: ts.Expression }) => ts.Expression;
}): ts.Expression => {
return abortSignalReference({
referenceToRequestOptions: ts.factory.createIdentifier(REQUEST_OPTIONS_PARAMETER_NAME)
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ export declare namespace Fetcher {
contentType?: string | ts.Expression;
queryParameters: ts.Expression | undefined;
body: ts.Expression | undefined;
abortSignal: ts.Expression | undefined;
withCredentials: boolean;
timeoutInSeconds: ts.Expression;
maxRetries?: ts.Expression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ export class FetcherImpl extends CoreUtility implements Fetcher {
body: "body",
timeoutMs: "timeoutMs",
withCredentials: "withCredentials",
responseType: "responseType"
responseType: "responseType",
abortSignal: "abortSignal"
},
_getReferenceToType: this.getReferenceToTypeInFetcherModule("Args")
},
Expand Down Expand Up @@ -132,6 +133,11 @@ export class FetcherImpl extends CoreUtility implements Fetcher {
)
);
}
if (args.abortSignal) {
properties.push(
ts.factory.createPropertyAssignment(this.Fetcher.Args.properties.abortSignal, args.abortSignal)
);
}

return ts.factory.createAwaitExpression(
ts.factory.createCallExpression(referenceToFetcher, cast != null ? [cast] : [], [
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{ "ignores": ["@types/jest", "@types/node"], "ignore-patterns": ["lib"] }
{ "ignores": ["@types/jest", "@types/node", "node-fetch", "qs", "@types/node-fetch", "@types/qs"], "ignore-patterns": ["lib"] }
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
"rules": {
"duplicate-dependencies": {
"exclude": ["@fern-fern/ir-sdk"]
},
"depcheck": {
"ignores": ["@types/jest", "@types/node", "node-fetch", "qs", "@types/node-fetch", "@types/qs"]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export declare namespace Fetcher {
timeoutMs?: number;
maxRetries?: number;
withCredentials?: boolean;
abortSignal?: AbortSignal;
responseType?: "json" | "blob" | "streaming" | "text";
}

Expand Down Expand Up @@ -103,21 +104,33 @@ async function fetcherImpl<R = unknown>(args: Fetcher.Args): Promise<APIResponse
: ((await import("node-fetch")).default as any);

const makeRequest = async (): Promise<Response> => {
const controller = new AbortController();
let abortId = undefined;
const signals: AbortSignal[] = [];

// Add timeout signal
let timeoutAbortId: NodeJS.Timeout | undefined = undefined;
if (args.timeoutMs != null) {
abortId = setTimeout(() => controller.abort(), args.timeoutMs);
const { signal, abortId } = getTimeoutSignal(args.timeoutMs);
timeoutAbortId = abortId;
signals.push(signal);
}

// Add arbitrary signal
if (args.abortSignal != null) {
signals.push(args.abortSignal);
}

const response = await fetchFn(url, {
method: args.method,
headers,
body,
signal: controller.signal,
signal: anySignal(signals),
credentials: args.withCredentials ? "include" : undefined
});
if (abortId != null) {
clearTimeout(abortId);

if (timeoutAbortId != null) {
clearTimeout(timeoutAbortId);
}

return response;
};

Expand Down Expand Up @@ -181,7 +194,15 @@ async function fetcherImpl<R = unknown>(args: Fetcher.Args): Promise<APIResponse
};
}
} catch (error) {
if (error instanceof Error && error.name === "AbortError") {
if (args.abortSignal != null && args.abortSignal.aborted) {
return {
ok: false,
error: {
reason: "unknown",
errorMessage: "The user aborted a request"
}
};
} else if (error instanceof Error && error.name === "AbortError") {
return {
ok: false,
error: {
Expand All @@ -208,4 +229,43 @@ async function fetcherImpl<R = unknown>(args: Fetcher.Args): Promise<APIResponse
}
}

const TIMEOUT = "timeout";

function getTimeoutSignal(timeoutMs: number): { signal: AbortSignal; abortId: NodeJS.Timeout } {
const controller = new AbortController();
const abortId = setTimeout(() => controller.abort(TIMEOUT), timeoutMs);
return { signal: controller.signal, abortId };
}

/**
* Returns an abort signal that is getting aborted when
* at least one of the specified abort signals is aborted.
*
* Requires at least node.js 18.
*/
function anySignal(...args: AbortSignal[] | [AbortSignal[]]): AbortSignal {
// Allowing signals to be passed either as array
// of signals or as multiple arguments.
const signals = <AbortSignal[]>(args.length === 1 && Array.isArray(args[0]) ? args[0] : args);

const controller = new AbortController();

for (const signal of signals) {
if (signal.aborted) {
// Exiting early if one of the signals
// is already aborted.
controller.abort((signal as any)?.reason);
break;
}

// Listening for signals and removing the listeners
// when at least one symbol is aborted.
signal.addEventListener("abort", () => controller.abort((signal as any)?.reason), {
signal: controller.signal
});
}

return controller.signal;
}

export const fetcher: FetchFunction = fetcherImpl;
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ export declare namespace Stream {
* The event shape to use for parsing the stream data.
*/
eventShape: JsonEvent | SseEvent;
/**
* An abort signal to stop the stream.
*/
signal?: AbortSignal;
}

interface JsonEvent {
Expand All @@ -36,8 +40,9 @@ export class Stream<T> implements AsyncIterable<T> {
private prefix: string | undefined;
private messageTerminator: string;
private streamTerminator: string | undefined;
private controller: AbortController = new AbortController();

constructor({ stream, parse, eventShape }: Stream.Args & { parse: (val: unknown) => Promise<T> }) {
constructor({ stream, parse, eventShape, signal }: Stream.Args & { parse: (val: unknown) => Promise<T> }) {
this.stream = stream;
this.parse = parse;
if (eventShape.type === "sse") {
Expand All @@ -47,9 +52,11 @@ export class Stream<T> implements AsyncIterable<T> {
} else {
this.messageTerminator = eventShape.messageTerminator;
}
signal?.addEventListener("abort", () => this.controller.abort());
}

private async *iterMessages(): AsyncGenerator<T, void> {
this.controller.signal;
const stream = readableStreamAsyncIterable<any>(this.stream);
let buf = "";
let prefixSeen = false;
Expand Down
65 changes: 65 additions & 0 deletions packages/cli/ete-tests/src/tests/fetcher/fetcher.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,69 @@ describe("Fetcher Tests", () => {
process.stdout.write(JSON.stringify(message));
}
}, 90_000);

it.skip("abort while making request", async () => {
const controller = new AbortController();
const call = fetcher<stream.Readable>({
url: "https://api.cohere.ai/v1/chat",
method: "POST",
responseType: "streaming",
headers: {
Authorization: "Bearer <>",
"Content-Type": "application/json"
},
body: {
message: "Write a long essay about devtools",
stream: true
},
// timeoutMs: 10,
abortSignal: controller.signal
});
controller.abort();
const response = await call;
expect(response.ok).toEqual(false);
if (response.ok) {
throw new Error("Expected response to fail");
}
expect(response.error.reason === "unknown" && response.error.errorMessage.includes("aborted")).toBe(true);
}, 90_000);

it.skip("abort while streaming events", async () => {
const controller = new AbortController();
const response = await fetcher<stream.Readable>({
url: "https://api.cohere.ai/v1/chat",
method: "POST",
responseType: "streaming",
headers: {
Authorization: "Bearer ",
"Content-Type": "application/json"
},
body: {
message: "Write a long essay about devtools",
stream: true
},
// timeoutMs: 10,
abortSignal: controller.signal
});
expect(response.ok).toEqual(true);
if (!response.ok) {
throw new Error("Response failed");
}
const stream = new Stream<unknown>({
stream: response.body,
parse: async (data) => data,
eventShape: {
type: "json",
messageTerminator: "\n"
}
});
let i = 1;
for await (const event of stream) {
if (i === 10) {
controller.abort();
}
console.log(JSON.stringify(event));

Check failure on line 179 in packages/cli/ete-tests/src/tests/fetcher/fetcher.test.ts

View workflow job for this annotation

GitHub Actions / eslint

Unexpected console statement
i += 1;
}
}, 90_000);
});
1 change: 1 addition & 0 deletions seed/ts-sdk/api-wide-base-path/src/Client.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit b0bcea3

Please sign in to comment.