Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/nice-parts-poke.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"openapi-fetch": minor
---

Enable request level middlewares option
1 change: 1 addition & 0 deletions packages/openapi-fetch/src/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ export type RequestOptions<T> = ParamsOption<T> &
parseAs?: ParseAs;
fetch?: ClientOptions["fetch"];
headers?: HeadersOptions;
middleware?: Middleware[];
};

export type MergedOptions<T = unknown> = {
Expand Down
28 changes: 16 additions & 12 deletions packages/openapi-fetch/src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ export default function createClient(clientOptions) {
} = { ...clientOptions };
requestInitExt = supportsRequestInitExt() ? requestInitExt : undefined;
baseUrl = removeTrailingSlash(baseUrl);
const middlewares = [];
const globalMiddlewares = [];

/**
* Per-request fetch (keeps settings created in createClient()
Expand All @@ -52,6 +52,7 @@ export default function createClient(clientOptions) {
querySerializer: requestQuerySerializer,
bodySerializer = globalBodySerializer ?? defaultBodySerializer,
body,
middleware: requestMiddlewares = [],
...init
} = fetchOptions || {};
let finalBaseUrl = baseUrl;
Expand Down Expand Up @@ -99,6 +100,9 @@ export default function createClient(clientOptions) {
params.header,
);

// Client level middleware take priority over request-level middleware
const finalMiddlewares = [...globalMiddlewares, ...requestMiddlewares];

const requestInit = {
redirect: "follow",
...baseOptions,
Expand All @@ -122,7 +126,7 @@ export default function createClient(clientOptions) {
}
}

if (middlewares.length) {
if (finalMiddlewares.length) {
id = randomID();

// middleware (request)
Expand All @@ -133,7 +137,7 @@ export default function createClient(clientOptions) {
querySerializer,
bodySerializer,
});
for (const m of middlewares) {
for (const m of finalMiddlewares) {
if (m && typeof m === "object" && typeof m.onRequest === "function") {
const result = await m.onRequest({
request,
Expand Down Expand Up @@ -164,9 +168,9 @@ export default function createClient(clientOptions) {
let errorAfterMiddleware = error;
// middleware (error)
// execute in reverse-array order (first priority gets last transform)
if (middlewares.length) {
for (let i = middlewares.length - 1; i >= 0; i--) {
const m = middlewares[i];
if (finalMiddlewares.length) {
for (let i = finalMiddlewares.length - 1; i >= 0; i--) {
const m = finalMiddlewares[i];
if (m && typeof m === "object" && typeof m.onError === "function") {
const result = await m.onError({
request,
Expand Down Expand Up @@ -203,9 +207,9 @@ export default function createClient(clientOptions) {

// middleware (response)
// execute in reverse-array order (first priority gets last transform)
if (middlewares.length) {
for (let i = middlewares.length - 1; i >= 0; i--) {
const m = middlewares[i];
if (finalMiddlewares.length) {
for (let i = finalMiddlewares.length - 1; i >= 0; i--) {
const m = finalMiddlewares[i];
if (m && typeof m === "object" && typeof m.onResponse === "function") {
const result = await m.onResponse({
request,
Expand Down Expand Up @@ -295,15 +299,15 @@ export default function createClient(clientOptions) {
if (typeof m !== "object" || !("onRequest" in m || "onResponse" in m || "onError" in m)) {
throw new Error("Middleware must be an object with one of `onRequest()`, `onResponse() or `onError()`");
}
middlewares.push(m);
globalMiddlewares.push(m);
}
},
/** Unregister middleware */
eject(...middleware) {
for (const m of middleware) {
const i = middlewares.indexOf(m);
const i = globalMiddlewares.indexOf(m);
if (i !== -1) {
middlewares.splice(i, 1);
globalMiddlewares.splice(i, 1);
}
}
},
Expand Down
45 changes: 41 additions & 4 deletions packages/openapi-fetch/test/middleware/middleware.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,16 +194,33 @@ test("executes in expected order", async () => {
return request;
},
onResponse({ response }) {
response.headers.set("step", "C");
return response;
const headers = new Headers(response.headers);
headers.set("step", "C");
if (response.headers.get("step") === "D") {
return new Response(response.body, { ...response, headers });
}
},
},
);

const { response } = await client.GET("/posts/{id}", { params: { path: { id: 123 } } });
const { response } = await client.GET("/posts/{id}", {
params: { path: { id: 123 } },
middleware: [
{
onRequest({ request }) {
request.headers.set("step", "D");
return request;
},
onResponse({ response }) {
response.headers.set("step", "D");
return response;
},
},
],
});

// assert requests ended up on step C (array order)
expect(actualRequest.headers.get("step")).toBe("C");
expect(actualRequest.headers.get("step")).toBe("D");

// assert responses ended up on step A (reverse order)
expect(response.headers.get("step")).toBe("A");
Expand Down Expand Up @@ -505,3 +522,23 @@ test("skips onResponse handlers when response is returned from onRequest", async

expect(onResponseCalled).toBe(false);
});

test("add middleware at the request level", async () => {
const customResponse = Response.json({});
const client = createObservedClient<paths>({}, async () => {
throw new Error("unexpected call to fetch");
});

const { response } = await client.GET("/posts/{id}", {
params: { path: { id: 123 } },
middleware: [
{
async onRequest() {
return customResponse;
},
},
],
});

expect(response).toBe(customResponse);
});