diff --git a/source/lib.ts b/source/lib.ts index d7c5392f..1aea05bb 100644 --- a/source/lib.ts +++ b/source/lib.ts @@ -107,7 +107,24 @@ type Configuration = { skip: ValueDeterminingMiddleware requestWasSuccessful: ValueDeterminingMiddleware store: Store - validate: boolean + validations: Validations +} + +/** + * Converts a `Configuration` object to a valid `Options` object, in case the + * configuration needs to be passed back to the user. + * + * @param config {Configuration} - The configuration object to convert. + * + * @returns {Partial} - The options derived from the configuration. + */ +const getOptionsFromConfig = (config: Configuration): Options => { + const { validations, ...directlyPassableEntries } = config + + return { + ...directlyPassableEntries, + validate: validations.enabled, + } } /** @@ -123,11 +140,11 @@ type Configuration = { */ const omitUndefinedOptions = ( passedOptions: Partial, -): Partial => { - const omittedOptions: Partial = {} +): Partial => { + const omittedOptions: Partial = {} for (const k of Object.keys(passedOptions)) { - const key = k as keyof Configuration + const key = k as keyof Options if (passedOptions[key] !== undefined) { // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment @@ -145,15 +162,15 @@ const omitUndefinedOptions = ( * * @returns {Configuration} - A complete configuration object. */ -const parseOptions = ( - passedOptions: Partial, - validations: Validations, -): Configuration => { +const parseOptions = (passedOptions: Partial): Configuration => { // Passing undefined should be equivalent to not passing an option at all, so we'll // omit all fields where their value is undefined. const notUndefinedOptions: Partial = omitUndefinedOptions(passedOptions) + // Create the validator before even parsing the rest of the options + const validations = new Validations(notUndefinedOptions?.validate ?? true) + // See ./types.ts#Options for a detailed description of the options and their // defaults. const config: Configuration = { @@ -206,13 +223,13 @@ const parseOptions = ( _response: Response, _optionsUsed: Options, ): void {}, - // Print an error to the console if a few known misconfigurations are detected. - validate: true, // Allow the options object to be overriden by the options passed to the middleware. ...notUndefinedOptions, // Note that this field is declared after the user's options are spread in, // so that this field doesn't get overriden with an un-promisified store! store: promisifyStore(notUndefinedOptions.store ?? new MemoryStore()), + // Print an error to the console if a few known misconfigurations are detected. + validations, } // Ensure that the store passed implements the `Store` interface @@ -265,19 +282,18 @@ const handleAsyncErrors = const rateLimit = ( passedOptions?: Partial, ): RateLimitRequestHandler => { - // Create the validator before even parsing the rest of the options - const validations = new Validations(passedOptions?.validate ?? true) // Parse the options and add the default values for unspecified options - const options = parseOptions(passedOptions ?? {}, validations) + const config = parseOptions(passedOptions ?? {}) + const options = getOptionsFromConfig(config) // Call the `init` method on the store, if it exists - if (typeof options.store.init === 'function') options.store.init(options) + if (typeof config.store.init === 'function') config.store.init(options) // Then return the actual middleware const middleware = handleAsyncErrors( async (request: Request, response: Response, next: NextFunction) => { // First check if we should skip the request - const skip = await options.skip(request, response) + const skip = await config.skip(request, response) if (skip) { next() return @@ -287,19 +303,19 @@ const rateLimit = ( const augmentedRequest = request as AugmentedRequest // Get a unique key for the client - const key = await options.keyGenerator(request, response) + const key = await config.keyGenerator(request, response) // Increment the client's hit counter by one - const { totalHits, resetTime } = await options.store.increment(key) + const { totalHits, resetTime } = await config.store.increment(key) // Get the quota (max number of hits) for each client const retrieveQuota = - typeof options.max === 'function' - ? options.max(request, response) - : options.max + typeof config.max === 'function' + ? config.max(request, response) + : config.max const maxHits = await retrieveQuota // Set the rate limit information on the augmented request object - augmentedRequest[options.requestPropertyName] = { + augmentedRequest[config.requestPropertyName] = { limit: maxHits, current: totalHits, remaining: Math.max(maxHits - totalHits, 0), @@ -307,11 +323,11 @@ const rateLimit = ( } // Set the X-RateLimit headers on the response object if enabled - if (options.legacyHeaders && !response.headersSent) { + if (config.legacyHeaders && !response.headersSent) { response.setHeader('X-RateLimit-Limit', maxHits) response.setHeader( 'X-RateLimit-Remaining', - augmentedRequest[options.requestPropertyName].remaining, + augmentedRequest[config.requestPropertyName].remaining, ) // If we have a resetTime, also provide the current date to help avoid @@ -327,11 +343,11 @@ const rateLimit = ( // Set the standardized RateLimit headers on the response object // if enabled. - if (options.standardHeaders && !response.headersSent) { + if (config.standardHeaders && !response.headersSent) { response.setHeader('RateLimit-Limit', maxHits) response.setHeader( 'RateLimit-Remaining', - augmentedRequest[options.requestPropertyName].remaining, + augmentedRequest[config.requestPropertyName].remaining, ) if (resetTime) { @@ -344,18 +360,18 @@ const rateLimit = ( // If we are to skip failed/successfull requests, decrement the // counter accordingly once we know the status code of the request - if (options.skipFailedRequests || options.skipSuccessfulRequests) { + if (config.skipFailedRequests || config.skipSuccessfulRequests) { let decremented = false const decrementKey = async () => { if (!decremented) { - await options.store.decrement(key) + await config.store.decrement(key) decremented = true } } - if (options.skipFailedRequests) { + if (config.skipFailedRequests) { response.on('finish', async () => { - if (!options.requestWasSuccessful(request, response)) + if (!config.requestWasSuccessful(request, response)) await decrementKey() }) response.on('close', async () => { @@ -366,9 +382,9 @@ const rateLimit = ( }) } - if (options.skipSuccessfulRequests) { + if (config.skipSuccessfulRequests) { response.on('finish', async () => { - if (options.requestWasSuccessful(request, response)) + if (config.requestWasSuccessful(request, response)) await decrementKey() }) } @@ -378,23 +394,23 @@ const rateLimit = ( // exceeds their rate limit // NOTE: `onLimitReached` is deprecated, this should be removed in v7.x if (maxHits && totalHits === maxHits + 1) { - options.onLimitReached(request, response, options) + config.onLimitReached(request, response, options) } // Disable the validations, since they should have run at least once by now. - validations.disable() + config.validations.disable() // If the client has exceeded their rate limit, set the Retry-After header // and call the `handler` function if (maxHits && totalHits > maxHits) { if ( - (options.legacyHeaders || options.standardHeaders) && + (config.legacyHeaders || config.standardHeaders) && !response.headersSent ) { - response.setHeader('Retry-After', Math.ceil(options.windowMs / 1000)) + response.setHeader('Retry-After', Math.ceil(config.windowMs / 1000)) } - options.handler(request, response, next, options) + config.handler(request, response, next, options) return } @@ -405,7 +421,7 @@ const rateLimit = ( // Export the store's function to reset the hit counter for a particular // client based on their identifier ;(middleware as RateLimitRequestHandler).resetKey = - options.store.resetKey.bind(options.store) + config.store.resetKey.bind(config.store) return middleware as RateLimitRequestHandler } diff --git a/source/types.ts b/source/types.ts index f6e3f714..755be846 100644 --- a/source/types.ts +++ b/source/types.ts @@ -173,7 +173,7 @@ export type Options = { * * Defaults to `60000` ms (= 1 minute). */ - readonly windowMs: number + windowMs: number /** * The maximum number of connections to allow during the `window` before @@ -184,7 +184,7 @@ export type Options = { * * Defaults to `5`. */ - readonly max: number | ValueDeterminingMiddleware + max: number | ValueDeterminingMiddleware /** * The response body to send back when a client is rate limited. @@ -192,14 +192,14 @@ export type Options = { * Defaults to `'Too many requests, please try again later.'` */ // eslint-disable-next-line @typescript-eslint/no-redundant-type-constituents - readonly message: any | ValueDeterminingMiddleware + message: any | ValueDeterminingMiddleware /** * The HTTP status code to send back when a client is rate limited. * * Defaults to `HTTP 429 Too Many Requests` (RFC 6585). */ - readonly statusCode: number + statusCode: number /** * Whether to send `X-RateLimit-*` headers with the rate limit and the number @@ -207,21 +207,21 @@ export type Options = { * * Defaults to `true` (for backward compatibility). */ - readonly legacyHeaders: boolean + legacyHeaders: boolean /** * Whether to enable support for the standardized rate limit headers (`RateLimit-*`). * * Defaults to `false` (for backward compatibility, but its use is recommended). */ - readonly standardHeaders: boolean + standardHeaders: boolean /** * The name of the property on the request object to store the rate limit info. * * Defaults to `rateLimit`. */ - readonly requestPropertyName: string + requestPropertyName: string /** * If `true`, the library will (by default) skip all requests that have a 4XX @@ -229,7 +229,7 @@ export type Options = { * * Defaults to `false`. */ - readonly skipFailedRequests: boolean + skipFailedRequests: boolean /** * If `true`, the library will (by default) skip all requests that have a @@ -237,14 +237,14 @@ export type Options = { * * Defaults to `false`. */ - readonly skipSuccessfulRequests: boolean + skipSuccessfulRequests: boolean /** * Method to generate custom identifiers for clients. * * By default, the client's IP address is used. */ - readonly keyGenerator: ValueDeterminingMiddleware + keyGenerator: ValueDeterminingMiddleware /** * Express request handler that sends back a response when a client is @@ -252,7 +252,7 @@ export type Options = { * * By default, sends back the `statusCode` and `message` set via the options. */ - readonly handler: RateLimitExceededEventHandler + handler: RateLimitExceededEventHandler /** * Express request handler that sends back a response when a client has @@ -261,7 +261,7 @@ export type Options = { * @deprecated 6.x - Please use a custom `handler` that checks the number of * hits instead. */ - readonly onLimitReached: RateLimitReachedEventHandler + onLimitReached: RateLimitReachedEventHandler /** * Method (in the form of middleware) to determine whether or not this request @@ -269,7 +269,7 @@ export type Options = { * * By default, skips no requests. */ - readonly skip: ValueDeterminingMiddleware + skip: ValueDeterminingMiddleware /** * Method to determine whether or not the request counts as 'succesful'. Used @@ -278,7 +278,7 @@ export type Options = { * By default, requests with a response status code less than 400 are considered * successful. */ - readonly requestWasSuccessful: ValueDeterminingMiddleware + requestWasSuccessful: ValueDeterminingMiddleware /** * The `Store` to use to store the hit count for each client. @@ -290,7 +290,7 @@ export type Options = { /** * Whether or not the validation checks should run. */ - readonly validate: boolean + validate: boolean /** * Whether to send `X-RateLimit-*` headers with the rate limit and the number @@ -322,8 +322,8 @@ export type AugmentedRequest = Request & { * Express request object. */ export type RateLimitInfo = { - readonly limit: number - readonly current: number - readonly remaining: number - readonly resetTime: Date | undefined + limit: number + current: number + remaining: number + resetTime: Date | undefined } diff --git a/source/validations.ts b/source/validations.ts index ec525891..8a1c1699 100644 --- a/source/validations.ts +++ b/source/validations.ts @@ -30,8 +30,20 @@ class ValidationError extends Error { } } +/** + * The validations that can be run, as well as the methods to run them. + */ export class Validations { - constructor(private enabled: boolean) {} + // eslint-disable-next-line @typescript-eslint/parameter-properties + enabled: boolean + + constructor(enabled: boolean) { + this.enabled = enabled + } + + enable() { + this.enabled = true + } disable() { this.enabled = false