diff --git a/docs/api/Agent.md b/docs/api/Agent.md index 854dc6aedfc..8e8321667e7 100644 --- a/docs/api/Agent.md +++ b/docs/api/Agent.md @@ -20,6 +20,7 @@ Extends: [`ClientOptions`](Pool.md#parameter-pooloptions) * **factory** `(origin: URL, opts: Object) => Dispatcher` - Default: `(origin, opts) => new Pool(origin, opts)` * **maxRedirections** `Integer` - Default: `0`. The number of HTTP redirection to follow unless otherwise specified in `DispatchOptions`. +* **interceptors** `{ Agent: DispatchInterceptor[] }` - Default: `[RedirectInterceptor]` - A list of interceptors that are applied to the dispatch method. Additional logic can be applied (such as, but not limited to: 302 status code handling, authentication, cookies, compression and caching). ## Instance Properties diff --git a/docs/api/Client.md b/docs/api/Client.md index 76a22253ffc..913ff54bcc9 100644 --- a/docs/api/Client.md +++ b/docs/api/Client.md @@ -26,6 +26,7 @@ Returns: `Client` * **pipelining** `number | null` (optional) - Default: `1` - The amount of concurrent requests to be sent over the single TCP/TLS connection according to [RFC7230](https://tools.ietf.org/html/rfc7230#section-6.3.2). Carefully consider your workload and environment before enabling concurrent requests as pipelining may reduce performance if used incorrectly. Pipelining is sensitive to network stack settings as well as head of line blocking caused by e.g. long running requests. Set to `0` to disable keep-alive connections. * **connect** `ConnectOptions | Function | null` (optional) - Default: `null`. * **strictContentLength** `Boolean` (optional) - Default: `true` - Whether to treat request content length mismatches as errors. If true, an error is thrown when the request content-length header doesn't match the length of the request body. +* **interceptors** `{ Client: DispatchInterceptor[] }` - Default: `[RedirectInterceptor]` - A list of interceptors that are applied to the dispatch method. Additional logic can be applied (such as, but not limited to: 302 status code handling, authentication, cookies, compression and caching). #### Parameter: `ConnectOptions` diff --git a/docs/api/DispatchInterceptor.md b/docs/api/DispatchInterceptor.md new file mode 100644 index 00000000000..652b2e86bf9 --- /dev/null +++ b/docs/api/DispatchInterceptor.md @@ -0,0 +1,60 @@ +#Interface: DispatchInterceptor + +Extends: `Function` + +A function that can be applied to the `Dispatcher.Dispatch` function before it is invoked with a dispatch request. + +This allows one to write logic to intercept both the outgoing request, and the incoming response. + +### Parameter: `Dispatcher.Dispatch` + +The base dispatch function you are decorating. + +### ReturnType: `Dispatcher.Dispatch` + +A dispatch function that has been altered to provide additional logic + +### Basic Example + +Here is an example of an interceptor being used to provide a JWT bearer token + +```js +'use strict' + +const insertHeaderInterceptor = dispatch => { + return function InterceptedDispatch(opts, handler){ + opts.headers.push('Authorization', 'Bearer [Some token]') + return dispatch(opts, handler) + } +} + +const client = new Client('https://localhost:3000', { + interceptors: { Client: [insertHeaderInterceptor] } +}) + +``` + +### Basic Example 2 + +Here is a contrived example of an interceptor stripping the headers from a response. + +```js +'use strict' + +const clearHeadersInterceptor = dispatch => { + const { DecoratorHandler } = require('undici') + class ResultInterceptor extends DecoratorHandler { + onHeaders (statusCode, headers, resume) { + return super.onHeaders(statusCode, [], resume) + } + } + return function InterceptedDispatch(opts, handler){ + return dispatch(opts, new ResultInterceptor(handler)) + } +} + +const client = new Client('https://localhost:3000', { + interceptors: { Client: [clearHeadersInterceptor] } +}) + +``` diff --git a/docs/api/Pool.md b/docs/api/Pool.md index 6b08294b61c..8fcabac3154 100644 --- a/docs/api/Pool.md +++ b/docs/api/Pool.md @@ -19,6 +19,7 @@ Extends: [`ClientOptions`](Client.md#parameter-clientoptions) * **factory** `(origin: URL, opts: Object) => Dispatcher` - Default: `(origin, opts) => new Client(origin, opts)` * **connections** `number | null` (optional) - Default: `null` - The number of `Client` instances to create. When set to `null`, the `Pool` instance will create an unlimited amount of `Client` instances. +* **interceptors** `{ Pool: DispatchInterceptor[] } }` - Default: `{ Pool: [] }` - A list of interceptors that are applied to the dispatch method. Additional logic can be applied (such as, but not limited to: 302 status code handling, authentication, cookies, compression and caching). ## Instance Properties diff --git a/index.d.ts b/index.d.ts index 87d44e3dcab..814d5c1a390 100644 --- a/index.d.ts +++ b/index.d.ts @@ -2,6 +2,8 @@ import Dispatcher = require('./types/dispatcher') import { setGlobalDispatcher, getGlobalDispatcher } from './types/global-dispatcher' import { setGlobalOrigin, getGlobalOrigin } from './types/global-origin' import Pool = require('./types/pool') +import { RedirectHandler, DecoratorHandler } from './types/handlers' + import BalancedPool = require('./types/balanced-pool') import Client = require('./types/client') import buildConnector = require('./types/connector') @@ -20,7 +22,7 @@ export * from './types/formdata' export * from './types/diagnostics-channel' export { Interceptable } from './types/mock-interceptor' -export { Dispatcher, BalancedPool, Pool, Client, buildConnector, errors, Agent, request, stream, pipeline, connect, upgrade, setGlobalDispatcher, getGlobalDispatcher, setGlobalOrigin, getGlobalOrigin, MockClient, MockPool, MockAgent, mockErrors, ProxyAgent } +export { Dispatcher, BalancedPool, Pool, Client, buildConnector, errors, Agent, request, stream, pipeline, connect, upgrade, setGlobalDispatcher, getGlobalDispatcher, setGlobalOrigin, getGlobalOrigin, MockClient, MockPool, MockAgent, mockErrors, ProxyAgent, RedirectHandler, DecoratorHandler } export default Undici declare function Undici(url: string, opts: Pool.Options): Pool @@ -28,6 +30,9 @@ declare function Undici(url: string, opts: Pool.Options): Pool declare namespace Undici { var Dispatcher: typeof import('./types/dispatcher') var Pool: typeof import('./types/pool'); + var RedirectHandler: typeof import ('./types/handlers').RedirectHandler + var DecoratorHandler: typeof import ('./types/handlers').DecoratorHandler + var createRedirectInterceptor: typeof import ('./types/interceptors').createRedirectInterceptor var BalancedPool: typeof import('./types/balanced-pool'); var Client: typeof import('./types/client'); var buildConnector: typeof import('./types/connector'); diff --git a/index.js b/index.js index 8020774681e..9cde34aed6f 100644 --- a/index.js +++ b/index.js @@ -16,6 +16,9 @@ const MockPool = require('./lib/mock/mock-pool') const mockErrors = require('./lib/mock/mock-errors') const ProxyAgent = require('./lib/proxy-agent') const { getGlobalDispatcher, setGlobalDispatcher } = require('./lib/global') +const DecoratorHandler = require('./lib/handler/DecoratorHandler') +const RedirectHandler = require('./lib/handler/RedirectHandler') +const createRedirectInterceptor = require('./lib/interceptor/redirectInterceptor') const nodeVersion = process.versions.node.split('.') const nodeMajor = Number(nodeVersion[0]) @@ -30,6 +33,10 @@ module.exports.BalancedPool = BalancedPool module.exports.Agent = Agent module.exports.ProxyAgent = ProxyAgent +module.exports.DecoratorHandler = DecoratorHandler +module.exports.RedirectHandler = RedirectHandler +module.exports.createRedirectInterceptor = createRedirectInterceptor + module.exports.buildConnector = buildConnector module.exports.errors = errors diff --git a/lib/agent.js b/lib/agent.js index 47aa2365e61..0b18f2a91bd 100644 --- a/lib/agent.js +++ b/lib/agent.js @@ -1,12 +1,12 @@ 'use strict' const { InvalidArgumentError } = require('./core/errors') -const { kClients, kRunning, kClose, kDestroy, kDispatch } = require('./core/symbols') +const { kClients, kRunning, kClose, kDestroy, kDispatch, kInterceptors } = require('./core/symbols') const DispatcherBase = require('./dispatcher-base') const Pool = require('./pool') const Client = require('./client') const util = require('./core/util') -const RedirectHandler = require('./handler/redirect') +const createRedirectInterceptor = require('./interceptor/redirectInterceptor') const { WeakRef, FinalizationRegistry } = require('./compat/dispatcher-weakref')() const kOnConnect = Symbol('onConnect') @@ -44,7 +44,14 @@ class Agent extends DispatcherBase { connect = { ...connect } } + this[kInterceptors] = options.interceptors && options.interceptors.Agent && Array.isArray(options.interceptors.Agent) + ? options.interceptors.Agent + : [createRedirectInterceptor({ maxRedirections })] + this[kOptions] = { ...util.deepClone(options), connect } + this[kOptions].interceptors = options.interceptors + ? { ...options.interceptors } + : undefined this[kMaxRedirections] = maxRedirections this[kFactory] = factory this[kClients] = new Map() @@ -108,12 +115,6 @@ class Agent extends DispatcherBase { this[kFinalizer].register(dispatcher, key) } - const { maxRedirections = this[kMaxRedirections] } = opts - if (maxRedirections != null && maxRedirections !== 0) { - opts = { ...opts, maxRedirections: 0 } // Stop sub dispatcher from also redirecting. - handler = new RedirectHandler(this, maxRedirections, opts, handler) - } - return dispatcher.dispatch(opts, handler) } diff --git a/lib/balanced-pool.js b/lib/balanced-pool.js index 47468ec0460..10bc6a47baf 100644 --- a/lib/balanced-pool.js +++ b/lib/balanced-pool.js @@ -13,7 +13,7 @@ const { kGetDispatcher } = require('./pool-base') const Pool = require('./pool') -const { kUrl } = require('./core/symbols') +const { kUrl, kInterceptors } = require('./core/symbols') const { parseOrigin } = require('./core/util') const kFactory = Symbol('factory') @@ -53,6 +53,9 @@ class BalancedPool extends PoolBase { throw new InvalidArgumentError('factory must be a function.') } + this[kInterceptors] = opts.interceptors && opts.interceptors.BalancedPool && Array.isArray(opts.interceptors.BalancedPool) + ? opts.interceptors.BalancedPool + : [] this[kFactory] = factory for (const upstream of upstreams) { diff --git a/lib/client.js b/lib/client.js index 14fcaee2e3c..46ec0b99330 100644 --- a/lib/client.js +++ b/lib/client.js @@ -7,7 +7,6 @@ const net = require('net') const util = require('./core/util') const Request = require('./core/request') const DispatcherBase = require('./dispatcher-base') -const RedirectHandler = require('./handler/redirect') const { RequestContentLengthMismatchError, ResponseContentLengthMismatchError, @@ -60,7 +59,8 @@ const { kCounter, kClose, kDestroy, - kDispatch + kDispatch, + kInterceptors } = require('./core/symbols') const kClosedResolve = Symbol('kClosedResolve') @@ -82,6 +82,7 @@ try { class Client extends DispatcherBase { constructor (url, { + interceptors, maxHeaderSize, headersTimeout, socketTimeout, @@ -179,6 +180,9 @@ class Client extends DispatcherBase { }) } + this[kInterceptors] = interceptors && interceptors.Client && Array.isArray(interceptors.Client) + ? interceptors.Client + : [createRedirectInterceptor({ maxRedirections })] this[kUrl] = util.parseOrigin(url) this[kConnector] = connect this[kSocket] = null @@ -254,11 +258,6 @@ class Client extends DispatcherBase { } [kDispatch] (opts, handler) { - const { maxRedirections = this[kMaxRedirections] } = opts - if (maxRedirections) { - handler = new RedirectHandler(this, maxRedirections, opts, handler) - } - const origin = opts.origin || this[kUrl].origin const request = new Request(origin, opts, handler) @@ -319,6 +318,7 @@ class Client extends DispatcherBase { } const constants = require('./llhttp/constants') +const createRedirectInterceptor = require('./interceptor/redirectInterceptor') const EMPTY_BUF = Buffer.alloc(0) async function lazyllhttp () { diff --git a/lib/core/symbols.js b/lib/core/symbols.js index 30108827a84..34e4b9fd2aa 100644 --- a/lib/core/symbols.js +++ b/lib/core/symbols.js @@ -48,5 +48,6 @@ module.exports = { kMaxRedirections: Symbol('maxRedirections'), kMaxRequests: Symbol('maxRequestsPerClient'), kProxy: Symbol('proxy agent options'), - kCounter: Symbol('socket request counter') + kCounter: Symbol('socket request counter'), + kInterceptors: Symbol('dispatch interceptors') } diff --git a/lib/dispatcher-base.js b/lib/dispatcher-base.js index 2c12ba80f35..14a5c0acd70 100644 --- a/lib/dispatcher-base.js +++ b/lib/dispatcher-base.js @@ -6,12 +6,13 @@ const { ClientClosedError, InvalidArgumentError } = require('./core/errors') -const { kDestroy, kClose, kDispatch } = require('./core/symbols') +const { kDestroy, kClose, kDispatch, kInterceptors } = require('./core/symbols') const kDestroyed = Symbol('destroyed') const kClosed = Symbol('closed') const kOnDestroyed = Symbol('onDestroyed') const kOnClosed = Symbol('onClosed') +const kInterceptedDispatch = Symbol('Intercepted Dispatch') class DispatcherBase extends Dispatcher { constructor () { @@ -31,6 +32,23 @@ class DispatcherBase extends Dispatcher { return this[kClosed] } + get interceptors () { + return this[kInterceptors] + } + + set interceptors (newInterceptors) { + if (newInterceptors) { + for (let i = newInterceptors.length - 1; i >= 0; i--) { + const interceptor = this[kInterceptors][i] + if (typeof interceptor !== 'function') { + throw new InvalidArgumentError('interceptor must be an function') + } + } + } + + this[kInterceptors] = newInterceptors + } + close (callback) { if (callback === undefined) { return new Promise((resolve, reject) => { @@ -125,6 +143,20 @@ class DispatcherBase extends Dispatcher { }) } + [kInterceptedDispatch] (opts, handler) { + if (!this[kInterceptors] || this[kInterceptors].length === 0) { + this[kInterceptedDispatch] = this[kDispatch] + return this[kDispatch](opts, handler) + } + + let dispatch = this[kDispatch].bind(this) + for (let i = this[kInterceptors].length - 1; i >= 0; i--) { + dispatch = this[kInterceptors][i](dispatch) + } + this[kInterceptedDispatch] = dispatch + return dispatch(opts, handler) + } + dispatch (opts, handler) { if (!handler || typeof handler !== 'object') { throw new InvalidArgumentError('handler must be an object') @@ -143,7 +175,7 @@ class DispatcherBase extends Dispatcher { throw new ClientClosedError() } - return this[kDispatch](opts, handler) + return this[kInterceptedDispatch](opts, handler) } catch (err) { if (typeof handler.onError !== 'function') { throw new InvalidArgumentError('invalid onError method') diff --git a/lib/handler/DecoratorHandler.js b/lib/handler/DecoratorHandler.js new file mode 100644 index 00000000000..9d70a767f1e --- /dev/null +++ b/lib/handler/DecoratorHandler.js @@ -0,0 +1,35 @@ +'use strict' + +module.exports = class DecoratorHandler { + constructor (handler) { + this.handler = handler + } + + onConnect (...args) { + return this.handler.onConnect(...args) + } + + onError (...args) { + return this.handler.onError(...args) + } + + onUpgrade (...args) { + return this.handler.onUpgrade(...args) + } + + onHeaders (...args) { + return this.handler.onHeaders(...args) + } + + onData (...args) { + return this.handler.onData(...args) + } + + onComplete (...args) { + return this.handler.onComplete(...args) + } + + onBodySent (...args) { + return this.handler.onBodySent(...args) + } +} diff --git a/lib/handler/redirect.js b/lib/handler/RedirectHandler.js similarity index 98% rename from lib/handler/redirect.js rename to lib/handler/RedirectHandler.js index a464e052dc7..2f726e79f2f 100644 --- a/lib/handler/redirect.js +++ b/lib/handler/RedirectHandler.js @@ -24,14 +24,14 @@ class BodyAsyncIterable { } class RedirectHandler { - constructor (dispatcher, maxRedirections, opts, handler) { + constructor (dispatch, maxRedirections, opts, handler) { if (maxRedirections != null && (!Number.isInteger(maxRedirections) || maxRedirections < 0)) { throw new InvalidArgumentError('maxRedirections must be a positive number') } util.validateHandler(handler, opts.method, opts.upgrade) - this.dispatcher = dispatcher + this.dispatch = dispatch this.location = null this.abort = null this.opts = { ...opts, maxRedirections: 0 } // opts must be a copy @@ -156,7 +156,7 @@ class RedirectHandler { this.location = null this.abort = null - this.dispatcher.dispatch(this.opts, this) + this.dispatch(this.opts, this) } else { this.handler.onComplete(trailers) } diff --git a/lib/interceptor/redirectInterceptor.js b/lib/interceptor/redirectInterceptor.js new file mode 100644 index 00000000000..7cc035e096c --- /dev/null +++ b/lib/interceptor/redirectInterceptor.js @@ -0,0 +1,21 @@ +'use strict' + +const RedirectHandler = require('../handler/RedirectHandler') + +function createRedirectInterceptor ({ maxRedirections: defaultMaxRedirections }) { + return (dispatch) => { + return function Intercept (opts, handler) { + const { maxRedirections = defaultMaxRedirections } = opts + + if (!maxRedirections) { + return dispatch(opts, handler) + } + + const redirectHandler = new RedirectHandler(dispatch, maxRedirections, opts, handler) + opts = { ...opts, maxRedirections: 0 } // Stop sub dispatcher from also redirecting. + return dispatch(opts, redirectHandler) + } + } +} + +module.exports = createRedirectInterceptor diff --git a/lib/pool.js b/lib/pool.js index 155dd3604b2..c1c20dd6b87 100644 --- a/lib/pool.js +++ b/lib/pool.js @@ -12,7 +12,7 @@ const { InvalidArgumentError } = require('./core/errors') const util = require('./core/util') -const { kUrl } = require('./core/symbols') +const { kUrl, kInterceptors } = require('./core/symbols') const buildConnector = require('./core/connect') const kOptions = Symbol('options') @@ -58,9 +58,15 @@ class Pool extends PoolBase { }) } + this[kInterceptors] = options.interceptors && options.interceptors.Pool && Array.isArray(options.interceptors.Pool) + ? options.interceptors.Pool + : [] this[kConnections] = connections || null this[kUrl] = util.parseOrigin(origin) this[kOptions] = { ...util.deepClone(options), connect } + this[kOptions].interceptors = options.interceptors + ? { ...options.interceptors } + : undefined this[kFactory] = factory } diff --git a/lib/proxy-agent.js b/lib/proxy-agent.js index bfc75d796ed..6d879d02216 100644 --- a/lib/proxy-agent.js +++ b/lib/proxy-agent.js @@ -1,8 +1,9 @@ 'use strict' -const { kClose, kDestroy } = require('./core/symbols') -const Client = require('./agent') +const { kProxy, kClose, kDestroy, kInterceptors } = require('./core/symbols') +const { URL } = require('url') const Agent = require('./agent') +const Client = require('./client') const DispatcherBase = require('./dispatcher-base') const { InvalidArgumentError, RequestAbortedError } = require('./core/errors') const buildConnector = require('./core/connect') @@ -18,9 +19,29 @@ function defaultProtocolPort (protocol) { return protocol === 'https:' ? 443 : 80 } +function buildProxyOptions (opts) { + if (typeof opts === 'string') { + opts = { uri: opts } + } + + if (!opts || !opts.uri) { + throw new InvalidArgumentError('Proxy opts.uri is mandatory') + } + + return { + uri: opts.uri, + protocol: opts.protocol || 'https' + } +} + class ProxyAgent extends DispatcherBase { constructor (opts) { super(opts) + this[kProxy] = buildProxyOptions(opts) + this[kAgent] = new Agent(opts) + this[kInterceptors] = opts.interceptors && opts.interceptors.ProxyAgent && Array.isArray(opts.interceptors.ProxyAgent) + ? opts.interceptors.ProxyAgent + : [] if (typeof opts === 'string') { opts = { uri: opts } @@ -38,11 +59,12 @@ class ProxyAgent extends DispatcherBase { this[kProxyHeaders]['proxy-authorization'] = `Basic ${opts.auth}` } - const { origin, port } = new URL(opts.uri) + const resolvedUrl = new URL(opts.uri) + const { origin, port } = resolvedUrl const connect = buildConnector({ ...opts.proxyTls }) this[kConnectEndpoint] = buildConnector({ ...opts.requestTls }) - this[kClient] = new Client({ origin: opts.origin, connect }) + this[kClient] = new Client(resolvedUrl, { connect }) this[kAgent] = new Agent({ ...opts, connect: async (opts, callback) => { diff --git a/package.json b/package.json index 0a3d9e5226c..54397ed89c6 100644 --- a/package.json +++ b/package.json @@ -65,7 +65,7 @@ }, "devDependencies": { "@sinonjs/fake-timers": "^9.1.2", - "@types/node": "^17.0.29", + "@types/node": "^17.0.45", "abort-controller": "^3.0.0", "atomic-sleep": "^1.0.0", "busboy": "^1.6.0", @@ -94,7 +94,7 @@ "standard": "^17.0.0", "table": "^6.8.0", "tap": "^16.1.0", - "tsd": "^0.22.0", + "tsd": "^0.23.0", "wait-on": "^6.0.0" }, "engines": { diff --git a/test/jest/interceptor.test.js b/test/jest/interceptor.test.js new file mode 100644 index 00000000000..c3a9bdeb682 --- /dev/null +++ b/test/jest/interceptor.test.js @@ -0,0 +1,196 @@ +'use strict' + +const { createServer } = require('http') +const { Agent, request } = require('../../index') +const DecoratorHandler = require('../../lib/handler/DecoratorHandler') +/* global expect */ + +describe('interceptors', () => { + let server + beforeEach(async () => { + server = createServer((req, res) => { + res.setHeader('Content-Type', 'text/plain') + res.end('hello') + }) + await new Promise((resolve) => { server.listen(0, resolve) }) + }) + afterEach(async () => { + await server.close() + }) + + test('interceptors are applied on client from an agent', async () => { + const interceptors = [] + const buildInterceptor = dispatch => { + const interceptorContext = { requestCount: 0 } + interceptors.push(interceptorContext) + return (opts, handler) => { + interceptorContext.requestCount++ + return dispatch(opts, handler) + } + } + + // await new Promise(resolve => server.listen(0, () => resolve())) + const opts = { interceptors: { Client: [buildInterceptor] } } + const agent = new Agent(opts) + const origin = new URL(`http://localhost:${server.address().port}`) + await Promise.all([ + request(origin, { dispatcher: agent }), + request(origin, { dispatcher: agent }) + ]) + + // Assert that the requests are run on different interceptors (different Clients) + const requestCounts = interceptors.map(x => x.requestCount) + expect(requestCounts).toEqual([1, 1]) + }) + + test('interceptors are applied in the correct order', async () => { + const setHeaderInterceptor = (dispatch) => { + return (opts, handler) => { + opts.headers.push('foo', 'bar') + return dispatch(opts, handler) + } + } + + const assertHeaderInterceptor = (dispatch) => { + return (opts, handler) => { + expect(opts.headers).toEqual(['foo', 'bar']) + return dispatch(opts, handler) + } + } + + const opts = { interceptors: { Pool: [setHeaderInterceptor, assertHeaderInterceptor] } } + const agent = new Agent(opts) + const origin = new URL(`http://localhost:${server.address().port}`) + await request(origin, { dispatcher: agent, headers: [] }) + }) + + test('interceptors handlers are called in reverse order', async () => { + const clearResponseHeadersInterceptor = (dispatch) => { + return (opts, handler) => { + class ResultInterceptor extends DecoratorHandler { + onHeaders (statusCode, headers, resume) { + return super.onHeaders(statusCode, [], resume) + } + } + + return dispatch(opts, new ResultInterceptor(handler)) + } + } + + const assertHeaderInterceptor = (dispatch) => { + return (opts, handler) => { + class ResultInterceptor extends DecoratorHandler { + onHeaders (statusCode, headers, resume) { + expect(headers).toEqual([]) + return super.onHeaders(statusCode, headers, resume) + } + } + + return dispatch(opts, new ResultInterceptor(handler)) + } + } + + const opts = { interceptors: { Agent: [assertHeaderInterceptor, clearResponseHeadersInterceptor] } } + const agent = new Agent(opts) + const origin = new URL(`http://localhost:${server.address().port}`) + await request(origin, { dispatcher: agent, headers: [] }) + }) +}) + +describe('interceptors with NtlmRequestHandler', () => { + class FakeNtlmRequestHandler { + constructor (dispatch, opts, handler) { + this.dispatch = dispatch + this.opts = opts + this.handler = handler + this.requestCount = 0 + } + + onConnect (...args) { + return this.handler.onConnect(...args) + } + + onError (...args) { + return this.handler.onError(...args) + } + + onUpgrade (...args) { + return this.handler.onUpgrade(...args) + } + + onHeaders (statusCode, headers, resume, statusText) { + this.requestCount++ + if (this.requestCount < 2) { + // Do nothing + } else { + return this.handler.onHeaders(statusCode, headers, resume, statusText) + } + } + + onData (...args) { + if (this.requestCount < 2) { + // Do nothing + } else { + return this.handler.onData(...args) + } + } + + onComplete (...args) { + if (this.requestCount < 2) { + this.dispatch(this.opts, this) + } else { + return this.handler.onComplete(...args) + } + } + + onBodySent (...args) { + if (this.requestCount < 2) { + // Do nothing + } else { + return this.handler.onBodySent(...args) + } + } + } + let server + + beforeEach(async () => { + // This Test is important because NTLM and Negotiate require several + // http requests in sequence to run on the same keepAlive socket + + const socketRequestCountSymbol = Symbol('Socket Request Count') + server = createServer((req, res) => { + if (req.socket[socketRequestCountSymbol] === undefined) { + req.socket[socketRequestCountSymbol] = 0 + } + req.socket[socketRequestCountSymbol]++ + res.setHeader('Content-Type', 'text/plain') + + // Simulate NTLM/Negotiate logic, by returning 200 + // on the second request of each socket + if (req.socket[socketRequestCountSymbol] >= 2) { + res.statusCode = 200 + res.end() + } else { + res.statusCode = 401 + res.end() + } + }) + await new Promise((resolve) => { server.listen(0, resolve) }) + }) + afterEach(async () => { + await server.close() + }) + + test('Retry interceptor on Client will use the same socket', async () => { + const interceptor = dispatch => { + return (opts, handler) => { + return dispatch(opts, new FakeNtlmRequestHandler(dispatch, opts, handler)) + } + } + const opts = { interceptors: { Client: [interceptor] } } + const agent = new Agent(opts) + const origin = new URL(`http://localhost:${server.address().port}`) + const { statusCode } = await request(origin, { dispatcher: agent, headers: [] }) + expect(statusCode).toEqual(200) + }) +}) diff --git a/test/jest/mock-agent.test.js b/test/jest/mock-agent.test.js index 7a94eae641e..6f6bac27bd9 100644 --- a/test/jest/mock-agent.test.js +++ b/test/jest/mock-agent.test.js @@ -1,6 +1,5 @@ 'use strict' -const { afterEach } = require('tap') const { request, setGlobalDispatcher, MockAgent } = require('../..') const { getResponse } = require('../../lib/mock/mock-utils') diff --git a/test/types/connector.test-d.ts b/test/types/connector.test-d.ts index ce0b310edfc..ba40f0e6c08 100644 --- a/test/types/connector.test-d.ts +++ b/test/types/connector.test-d.ts @@ -6,17 +6,17 @@ import {IpcNetConnectOpts, NetConnectOpts, TcpNetConnectOpts} from "net"; const connector = buildConnector({ rejectUnauthorized: false }) expectAssignable(new Client('', { connect (opts: buildConnector.Options, cb: buildConnector.Callback) { - connector(opts, (err, socket) => { - if (err) { - return cb(err, null) + connector(opts, (...args) => { + if (args[0]) { + return cb(args[0], null) } - if (socket instanceof TLSSocket) { - if (socket.getPeerCertificate().fingerprint256 !== 'FO:OB:AR') { - socket.destroy() + if (args[1] instanceof TLSSocket) { + if (args[1].getPeerCertificate().fingerprint256 !== 'FO:OB:AR') { + args[1].destroy() return cb(new Error('Fingerprint does not match'), null) } } - return cb(null, socket) + return cb(null, args[1]) }) } })) diff --git a/test/types/dispatcher.events.test-d.ts b/test/types/dispatcher.events.test-d.ts new file mode 100644 index 00000000000..1c58add84f2 --- /dev/null +++ b/test/types/dispatcher.events.test-d.ts @@ -0,0 +1,45 @@ +import { Dispatcher } from '../..' +import {expectAssignable} from "tsd"; +import {URL} from "url"; +import {UndiciError} from "../../types/errors"; + +interface EventHandler { + connect(origin: URL, targets: readonly Dispatcher[]): void + disconnect(origin: URL, targets: readonly Dispatcher[], error: UndiciError): void + connectionError(origin: URL, targets: readonly Dispatcher[], error: UndiciError): void + drain(origin: URL): void +} + +{ + const dispatcher = new Dispatcher() + const eventHandler: EventHandler = {} as EventHandler + + expectAssignable(dispatcher.rawListeners('connect')) + expectAssignable(dispatcher.rawListeners('disconnect')) + expectAssignable(dispatcher.rawListeners('connectionError')) + expectAssignable(dispatcher.rawListeners('drain')) + + expectAssignable(dispatcher.listeners('connect')) + expectAssignable(dispatcher.listeners('disconnect')) + expectAssignable(dispatcher.listeners('connectionError')) + expectAssignable(dispatcher.listeners('drain')) + + const eventHandlerMethods: ['on', 'once', 'off', 'addListener', "removeListener", "prependListener", "prependOnceListener"] + = ['on', 'once', 'off', 'addListener', "removeListener", "prependListener", "prependOnceListener"] + + for (const method of eventHandlerMethods) { + expectAssignable(dispatcher[method]('connect', eventHandler["connect"])) + expectAssignable(dispatcher[method]('disconnect', eventHandler["disconnect"])) + expectAssignable(dispatcher[method]('connectionError', eventHandler["connectionError"])) + expectAssignable(dispatcher[method]('drain', eventHandler["drain"])) + } + + const origin = new URL('') + const targets = new Array() + const error = new UndiciError() + expectAssignable(dispatcher.emit('connect', origin, targets)) + expectAssignable(dispatcher.emit('disconnect', origin, targets, error)) + expectAssignable(dispatcher.emit('connectionError', origin, targets, error)) + expectAssignable(dispatcher.emit('drain', origin)) +} + diff --git a/test/types/index.test-d.ts b/test/types/index.test-d.ts index 774750f4ebe..ed343618f4a 100644 --- a/test/types/index.test-d.ts +++ b/test/types/index.test-d.ts @@ -1,5 +1,6 @@ import { expectAssignable } from 'tsd' -import Undici, { Pool, Client, errors, fetch, Interceptable } from '../..' +import Undici, {Pool, Client, errors, fetch, Interceptable, RedirectHandler, DecoratorHandler} from '../..' +import Dispatcher from "../../types/dispatcher"; expectAssignable(Undici('', {})) expectAssignable(new Undici.Pool('', {})) @@ -7,3 +8,11 @@ expectAssignable(new Undici.Client('', {})) expectAssignable(new Undici.MockAgent().get('')) expectAssignable(Undici.errors) expectAssignable(Undici.fetch) + +const client = new Undici.Client('', {}) +const handler: Dispatcher.DispatchHandlers = {} + +expectAssignable(new Undici.RedirectHandler(client, 10, { + path: '/', method: 'GET' +}, handler)) +expectAssignable(new Undici.DecoratorHandler(handler)) diff --git a/test/types/interceptor.test-d.ts b/test/types/interceptor.test-d.ts new file mode 100644 index 00000000000..ea69405548d --- /dev/null +++ b/test/types/interceptor.test-d.ts @@ -0,0 +1,5 @@ +import {expectAssignable} from "tsd"; +import Undici from "../.."; +import Dispatcher, {DispatchInterceptor} from "../../types/dispatcher"; + +expectAssignable(Undici.createRedirectInterceptor({ maxRedirections: 3 })) diff --git a/types/agent.d.ts b/types/agent.d.ts index ebadc194daf..c09260b2913 100644 --- a/types/agent.d.ts +++ b/types/agent.d.ts @@ -1,6 +1,7 @@ import { URL } from 'url' import Dispatcher = require('./dispatcher') import Pool = require('./pool') +import {DispatchInterceptor} from "./dispatcher"; export = Agent @@ -20,6 +21,8 @@ declare namespace Agent { factory?(origin: URL, opts: Object): Dispatcher; /** Integer. Default: `0` */ maxRedirections?: number; + + interceptors?: { Agent?: readonly DispatchInterceptor[] } & Pool.Options["interceptors"] } export interface DispatchOptions extends Dispatcher.DispatchOptions { diff --git a/types/client.d.ts b/types/client.d.ts index 22fcb42cfe8..4932ece3ffe 100644 --- a/types/client.d.ts +++ b/types/client.d.ts @@ -1,8 +1,8 @@ import { URL } from 'url' import { TlsOptions } from 'tls' import Dispatcher = require('./dispatcher') -import { DispatchOptions, RequestOptions } from './dispatcher' -import buildConnector = require('./connector') +import {DispatchInterceptor} from './dispatcher' +import buildConnector, {connector} from "./connector"; export = Client @@ -28,7 +28,7 @@ declare namespace Client { /** The amount of concurrent requests to be sent over the single TCP/TLS connection according to [RFC7230](https://tools.ietf.org/html/rfc7230#section-6.3.2). Default: `1`. */ pipelining?: number | null; /** **/ - connect?: buildConnector.BuildOptions | Function | null; + connect?: buildConnector.BuildOptions | connector | null; /** The maximum length of request headers in bytes. Default: `16384` (16KiB). */ maxHeaderSize?: number | null; /** The timeout after which a request will time out, in milliseconds. Monitors time between receiving body data. Use `0` to disable it entirely. Default: `30e3` milliseconds (30s). */ @@ -41,6 +41,8 @@ declare namespace Client { tls?: TlsOptions | null; /** */ maxRequestsPerClient?: number; + + interceptors?: {Client: readonly DispatchInterceptor[] | undefined} } export interface SocketInfo { @@ -53,4 +55,6 @@ declare namespace Client { bytesWritten?: number bytesRead?: number } + + } diff --git a/types/connector.d.ts b/types/connector.d.ts index 38016b00008..9a47f87e599 100644 --- a/types/connector.d.ts +++ b/types/connector.d.ts @@ -2,7 +2,7 @@ import {TLSSocket, ConnectionOptions} from 'tls' import {IpcNetConnectOpts, Socket, TcpNetConnectOpts} from 'net' export = buildConnector -declare function buildConnector (options?: buildConnector.BuildOptions): typeof buildConnector.connector +declare function buildConnector (options?: buildConnector.BuildOptions): buildConnector.connector declare namespace buildConnector { export type BuildOptions = (ConnectionOptions | TcpNetConnectOpts | IpcNetConnectOpts) & { @@ -20,7 +20,16 @@ declare namespace buildConnector { servername?: string } - export type Callback = (err: Error | null, socket: Socket | TLSSocket | null) => void + export type Callback = (...args: CallbackArgs) => void + type CallbackArgs = [null, Socket | TLSSocket] | [Error, null] - export function connector (options: buildConnector.Options, callback: buildConnector.Callback): Socket | TLSSocket; + export type connector = connectorAsync | connectorSync + + interface connectorSync { + (options: buildConnector.Options): Socket | TLSSocket + } + + interface connectorAsync { + (options: buildConnector.Options, callback: buildConnector.Callback): void + } } diff --git a/types/diagnostics-channel.d.ts b/types/diagnostics-channel.d.ts index c6131482280..6c754491b89 100644 --- a/types/diagnostics-channel.d.ts +++ b/types/diagnostics-channel.d.ts @@ -25,7 +25,7 @@ declare namespace DiagnosticsChannel { port: URL["port"]; servername: string | null; } - type Connector = typeof connector; + type Connector = connector; export interface RequestCreateMessage { request: Request; } diff --git a/types/dispatcher.d.ts b/types/dispatcher.d.ts index 8744f04e2d7..cbd558a932f 100644 --- a/types/dispatcher.d.ts +++ b/types/dispatcher.d.ts @@ -3,8 +3,9 @@ import { Duplex, Readable, Writable } from 'stream' import { EventEmitter } from 'events' import { IncomingHttpHeaders } from 'http' import { Blob } from 'buffer' -import BodyReadable = require('./readable') +import type BodyReadable from './readable' import { FormData } from './formdata' +import { UndiciError } from './errors' type AbortSignal = unknown; @@ -36,6 +37,59 @@ declare class Dispatcher extends EventEmitter { destroy(err: Error | null): Promise; destroy(callback: () => void): void; destroy(err: Error | null, callback: () => void): void; + + on(eventName: 'connect', callback: (origin: URL, targets: readonly Dispatcher[]) => void): this; + on(eventName: 'disconnect', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + on(eventName: 'connectionError', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + on(eventName: 'drain', callback: (origin: URL) => void): this; + + + once(eventName: 'connect', callback: (origin: URL, targets: readonly Dispatcher[]) => void): this; + once(eventName: 'disconnect', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + once(eventName: 'connectionError', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + once(eventName: 'drain', callback: (origin: URL) => void): this; + + + off(eventName: 'connect', callback: (origin: URL, targets: readonly Dispatcher[]) => void): this; + off(eventName: 'disconnect', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + off(eventName: 'connectionError', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + off(eventName: 'drain', callback: (origin: URL) => void): this; + + + addListener(eventName: 'connect', callback: (origin: URL, targets: readonly Dispatcher[]) => void): this; + addListener(eventName: 'disconnect', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + addListener(eventName: 'connectionError', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + addListener(eventName: 'drain', callback: (origin: URL) => void): this; + + removeListener(eventName: 'connect', callback: (origin: URL, targets: readonly Dispatcher[]) => void): this; + removeListener(eventName: 'disconnect', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + removeListener(eventName: 'connectionError', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + removeListener(eventName: 'drain', callback: (origin: URL) => void): this; + + prependListener(eventName: 'connect', callback: (origin: URL, targets: readonly Dispatcher[]) => void): this; + prependListener(eventName: 'disconnect', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + prependListener(eventName: 'connectionError', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + prependListener(eventName: 'drain', callback: (origin: URL) => void): this; + + prependOnceListener(eventName: 'connect', callback: (origin: URL, targets: readonly Dispatcher[]) => void): this; + prependOnceListener(eventName: 'disconnect', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + prependOnceListener(eventName: 'connectionError', callback: (origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void): this; + prependOnceListener(eventName: 'drain', callback: (origin: URL) => void): this; + + listeners(eventName: 'connect'): ((origin: URL, targets: readonly Dispatcher[]) => void)[] + listeners(eventName: 'disconnect'): ((origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void)[]; + listeners(eventName: 'connectionError'): ((origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void)[]; + listeners(eventName: 'drain'): ((origin: URL) => void)[]; + + rawListeners(eventName: 'connect'): ((origin: URL, targets: readonly Dispatcher[]) => void)[] + rawListeners(eventName: 'disconnect'): ((origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void)[]; + rawListeners(eventName: 'connectionError'): ((origin: URL, targets: readonly Dispatcher[], error: UndiciError) => void)[]; + rawListeners(eventName: 'drain'): ((origin: URL) => void)[]; + + emit(eventName: 'connect', origin: URL, targets: readonly Dispatcher[]): boolean; + emit(eventName: 'disconnect', origin: URL, targets: readonly Dispatcher[], error: UndiciError): boolean; + emit(eventName: 'connectionError', origin: URL, targets: readonly Dispatcher[], error: UndiciError): boolean; + emit(eventName: 'drain', origin: URL): boolean; } declare namespace Dispatcher { @@ -147,9 +201,9 @@ declare namespace Dispatcher { /** Invoked when an error has occurred. */ onError?(err: Error): void; /** Invoked when request is upgraded either due to a `Upgrade` header or `CONNECT` method. */ - onUpgrade?(statusCode: number, headers: string[] | null, socket: Duplex): void; + onUpgrade?(statusCode: number, headers: Buffer[] | string[] | null, socket: Duplex): void; /** Invoked when statusCode and headers have been received. May be invoked multiple times due to 1xx informational headers. */ - onHeaders?(statusCode: number, headers: string[] | null, resume: () => void): boolean; + onHeaders?(statusCode: number, headers: Buffer[] | string[] | null, resume: () => void): boolean; /** Invoked when response payload data is received. */ onData?(chunk: Buffer): boolean; /** Invoked when response payload and trailers have been received and the request has completed. */ @@ -172,4 +226,8 @@ declare namespace Dispatcher { json(): Promise; text(): Promise; } + + export interface DispatchInterceptor { + (dispatch: Dispatcher['dispatch']): Dispatcher['dispatch'] + } } diff --git a/types/handlers.d.ts b/types/handlers.d.ts new file mode 100644 index 00000000000..eb4f5a9e8dd --- /dev/null +++ b/types/handlers.d.ts @@ -0,0 +1,9 @@ +import Dispatcher from "./dispatcher"; + +export declare class RedirectHandler implements Dispatcher.DispatchHandlers{ + constructor (dispatch: Dispatcher, maxRedirections: number, opts: Dispatcher.DispatchOptions, handler: Dispatcher.DispatchHandlers) +} + +export declare class DecoratorHandler implements Dispatcher.DispatchHandlers{ + constructor (handler: Dispatcher.DispatchHandlers) +} diff --git a/types/interceptors.d.ts b/types/interceptors.d.ts new file mode 100644 index 00000000000..a920ea982e8 --- /dev/null +++ b/types/interceptors.d.ts @@ -0,0 +1,5 @@ +import {DispatchInterceptor} from "./dispatcher"; + +type RedirectInterceptorOpts = { maxRedirections?: number } + +export declare function createRedirectInterceptor (opts: RedirectInterceptorOpts): DispatchInterceptor diff --git a/types/pool.d.ts b/types/pool.d.ts index af7fb94a9a6..0ef0bc39884 100644 --- a/types/pool.d.ts +++ b/types/pool.d.ts @@ -2,6 +2,7 @@ import Client = require('./client') import Dispatcher = require('./dispatcher') import TPoolStats = require('./pool-stats') import { URL } from 'url' +import {DispatchInterceptor} from "./dispatcher"; export = Pool @@ -22,5 +23,7 @@ declare namespace Pool { factory?(origin: URL, opts: object): Dispatcher; /** The max number of clients to create. `null` if no limit. Default `null`. */ connections?: number | null; + + interceptors?: { Pool?: readonly DispatchInterceptor[] } & Client.Options["interceptors"] } }