Skip to content

Commit

Permalink
refc(validate): pass validations as part of config object
Browse files Browse the repository at this point in the history
  • Loading branch information
gamemaker1 committed Jul 16, 2023
1 parent ee5c18a commit 9b605a2
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 57 deletions.
90 changes: 53 additions & 37 deletions source/lib.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,24 @@ type Configuration = {
skip: ValueDeterminingMiddleware<boolean>
requestWasSuccessful: ValueDeterminingMiddleware<boolean>
store: Store
validate: boolean
validations: Validations
}

/**
* Converts a `Configuration` object to a valid `Options` object, in case the
* configuration needs to be passed back to the user.
*
* @param config {Configuration} - The configuration object to convert.
*
* @returns {Partial<Options>} - The options derived from the configuration.
*/
const getOptionsFromConfig = (config: Configuration): Options => {
const { validations, ...directlyPassableEntries } = config

return {
...directlyPassableEntries,
validate: validations.enabled,
}
}

/**
Expand All @@ -123,11 +140,11 @@ type Configuration = {
*/
const omitUndefinedOptions = (
passedOptions: Partial<Options>,
): Partial<Configuration> => {
const omittedOptions: Partial<Configuration> = {}
): Partial<Options> => {
const omittedOptions: Partial<Options> = {}

for (const k of Object.keys(passedOptions)) {
const key = k as keyof Configuration
const key = k as keyof Options

if (passedOptions[key] !== undefined) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
Expand All @@ -145,15 +162,15 @@ const omitUndefinedOptions = (
*
* @returns {Configuration} - A complete configuration object.
*/
const parseOptions = (
passedOptions: Partial<Options>,
validations: Validations,
): Configuration => {
const parseOptions = (passedOptions: Partial<Options>): Configuration => {
// Passing undefined should be equivalent to not passing an option at all, so we'll
// omit all fields where their value is undefined.
const notUndefinedOptions: Partial<Options> =
omitUndefinedOptions(passedOptions)

// Create the validator before even parsing the rest of the options
const validations = new Validations(notUndefinedOptions?.validate ?? true)

// See ./types.ts#Options for a detailed description of the options and their
// defaults.
const config: Configuration = {
Expand Down Expand Up @@ -206,13 +223,13 @@ const parseOptions = (
_response: Response,
_optionsUsed: Options,
): void {},
// Print an error to the console if a few known misconfigurations are detected.
validate: true,
// Allow the options object to be overriden by the options passed to the middleware.
...notUndefinedOptions,
// Note that this field is declared after the user's options are spread in,
// so that this field doesn't get overriden with an un-promisified store!
store: promisifyStore(notUndefinedOptions.store ?? new MemoryStore()),
// Print an error to the console if a few known misconfigurations are detected.
validations,
}

// Ensure that the store passed implements the `Store` interface
Expand Down Expand Up @@ -265,19 +282,18 @@ const handleAsyncErrors =
const rateLimit = (
passedOptions?: Partial<Options>,
): RateLimitRequestHandler => {
// Create the validator before even parsing the rest of the options
const validations = new Validations(passedOptions?.validate ?? true)
// Parse the options and add the default values for unspecified options
const options = parseOptions(passedOptions ?? {}, validations)
const config = parseOptions(passedOptions ?? {})
const options = getOptionsFromConfig(config)

// Call the `init` method on the store, if it exists
if (typeof options.store.init === 'function') options.store.init(options)
if (typeof config.store.init === 'function') config.store.init(options)

// Then return the actual middleware
const middleware = handleAsyncErrors(
async (request: Request, response: Response, next: NextFunction) => {
// First check if we should skip the request
const skip = await options.skip(request, response)
const skip = await config.skip(request, response)
if (skip) {
next()
return
Expand All @@ -287,31 +303,31 @@ const rateLimit = (
const augmentedRequest = request as AugmentedRequest

// Get a unique key for the client
const key = await options.keyGenerator(request, response)
const key = await config.keyGenerator(request, response)
// Increment the client's hit counter by one
const { totalHits, resetTime } = await options.store.increment(key)
const { totalHits, resetTime } = await config.store.increment(key)

// Get the quota (max number of hits) for each client
const retrieveQuota =
typeof options.max === 'function'
? options.max(request, response)
: options.max
typeof config.max === 'function'
? config.max(request, response)
: config.max

const maxHits = await retrieveQuota
// Set the rate limit information on the augmented request object
augmentedRequest[options.requestPropertyName] = {
augmentedRequest[config.requestPropertyName] = {
limit: maxHits,
current: totalHits,
remaining: Math.max(maxHits - totalHits, 0),
resetTime,
}

// Set the X-RateLimit headers on the response object if enabled
if (options.legacyHeaders && !response.headersSent) {
if (config.legacyHeaders && !response.headersSent) {
response.setHeader('X-RateLimit-Limit', maxHits)
response.setHeader(
'X-RateLimit-Remaining',
augmentedRequest[options.requestPropertyName].remaining,
augmentedRequest[config.requestPropertyName].remaining,
)

// If we have a resetTime, also provide the current date to help avoid
Expand All @@ -327,11 +343,11 @@ const rateLimit = (

// Set the standardized RateLimit headers on the response object
// if enabled.
if (options.standardHeaders && !response.headersSent) {
if (config.standardHeaders && !response.headersSent) {
response.setHeader('RateLimit-Limit', maxHits)
response.setHeader(
'RateLimit-Remaining',
augmentedRequest[options.requestPropertyName].remaining,
augmentedRequest[config.requestPropertyName].remaining,
)

if (resetTime) {
Expand All @@ -344,18 +360,18 @@ const rateLimit = (

// If we are to skip failed/successfull requests, decrement the
// counter accordingly once we know the status code of the request
if (options.skipFailedRequests || options.skipSuccessfulRequests) {
if (config.skipFailedRequests || config.skipSuccessfulRequests) {
let decremented = false
const decrementKey = async () => {
if (!decremented) {
await options.store.decrement(key)
await config.store.decrement(key)
decremented = true
}
}

if (options.skipFailedRequests) {
if (config.skipFailedRequests) {
response.on('finish', async () => {
if (!options.requestWasSuccessful(request, response))
if (!config.requestWasSuccessful(request, response))
await decrementKey()
})
response.on('close', async () => {
Expand All @@ -366,9 +382,9 @@ const rateLimit = (
})
}

if (options.skipSuccessfulRequests) {
if (config.skipSuccessfulRequests) {
response.on('finish', async () => {
if (options.requestWasSuccessful(request, response))
if (config.requestWasSuccessful(request, response))
await decrementKey()
})
}
Expand All @@ -378,23 +394,23 @@ const rateLimit = (
// exceeds their rate limit
// NOTE: `onLimitReached` is deprecated, this should be removed in v7.x
if (maxHits && totalHits === maxHits + 1) {
options.onLimitReached(request, response, options)
config.onLimitReached(request, response, options)
}

// Disable the validations, since they should have run at least once by now.
validations.disable()
config.validations.disable()

// If the client has exceeded their rate limit, set the Retry-After header
// and call the `handler` function
if (maxHits && totalHits > maxHits) {
if (
(options.legacyHeaders || options.standardHeaders) &&
(config.legacyHeaders || config.standardHeaders) &&
!response.headersSent
) {
response.setHeader('Retry-After', Math.ceil(options.windowMs / 1000))
response.setHeader('Retry-After', Math.ceil(config.windowMs / 1000))
}

options.handler(request, response, next, options)
config.handler(request, response, next, options)
return
}

Expand All @@ -405,7 +421,7 @@ const rateLimit = (
// Export the store's function to reset the hit counter for a particular
// client based on their identifier
;(middleware as RateLimitRequestHandler).resetKey =
options.store.resetKey.bind(options.store)
config.store.resetKey.bind(config.store)

return middleware as RateLimitRequestHandler
}
Expand Down
38 changes: 19 additions & 19 deletions source/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ export type Options = {
*
* Defaults to `60000` ms (= 1 minute).
*/
readonly windowMs: number
windowMs: number

/**
* The maximum number of connections to allow during the `window` before
Expand All @@ -184,75 +184,75 @@ export type Options = {
*
* Defaults to `5`.
*/
readonly max: number | ValueDeterminingMiddleware<number>
max: number | ValueDeterminingMiddleware<number>

/**
* The response body to send back when a client is rate limited.
*
* Defaults to `'Too many requests, please try again later.'`
*/
// eslint-disable-next-line @typescript-eslint/no-redundant-type-constituents
readonly message: any | ValueDeterminingMiddleware<any>
message: any | ValueDeterminingMiddleware<any>

/**
* The HTTP status code to send back when a client is rate limited.
*
* Defaults to `HTTP 429 Too Many Requests` (RFC 6585).
*/
readonly statusCode: number
statusCode: number

/**
* Whether to send `X-RateLimit-*` headers with the rate limit and the number
* of requests.
*
* Defaults to `true` (for backward compatibility).
*/
readonly legacyHeaders: boolean
legacyHeaders: boolean

/**
* Whether to enable support for the standardized rate limit headers (`RateLimit-*`).
*
* Defaults to `false` (for backward compatibility, but its use is recommended).
*/
readonly standardHeaders: boolean
standardHeaders: boolean

/**
* The name of the property on the request object to store the rate limit info.
*
* Defaults to `rateLimit`.
*/
readonly requestPropertyName: string
requestPropertyName: string

/**
* If `true`, the library will (by default) skip all requests that have a 4XX
* or 5XX status.
*
* Defaults to `false`.
*/
readonly skipFailedRequests: boolean
skipFailedRequests: boolean

/**
* If `true`, the library will (by default) skip all requests that have a
* status code less than 400.
*
* Defaults to `false`.
*/
readonly skipSuccessfulRequests: boolean
skipSuccessfulRequests: boolean

/**
* Method to generate custom identifiers for clients.
*
* By default, the client's IP address is used.
*/
readonly keyGenerator: ValueDeterminingMiddleware<string>
keyGenerator: ValueDeterminingMiddleware<string>

/**
* Express request handler that sends back a response when a client is
* rate-limited.
*
* By default, sends back the `statusCode` and `message` set via the options.
*/
readonly handler: RateLimitExceededEventHandler
handler: RateLimitExceededEventHandler

/**
* Express request handler that sends back a response when a client has
Expand All @@ -261,15 +261,15 @@ export type Options = {
* @deprecated 6.x - Please use a custom `handler` that checks the number of
* hits instead.
*/
readonly onLimitReached: RateLimitReachedEventHandler
onLimitReached: RateLimitReachedEventHandler

/**
* Method (in the form of middleware) to determine whether or not this request
* counts towards a client's quota.
*
* By default, skips no requests.
*/
readonly skip: ValueDeterminingMiddleware<boolean>
skip: ValueDeterminingMiddleware<boolean>

/**
* Method to determine whether or not the request counts as 'succesful'. Used
Expand All @@ -278,7 +278,7 @@ export type Options = {
* By default, requests with a response status code less than 400 are considered
* successful.
*/
readonly requestWasSuccessful: ValueDeterminingMiddleware<boolean>
requestWasSuccessful: ValueDeterminingMiddleware<boolean>

/**
* The `Store` to use to store the hit count for each client.
Expand All @@ -290,7 +290,7 @@ export type Options = {
/**
* Whether or not the validation checks should run.
*/
readonly validate: boolean
validate: boolean

/**
* Whether to send `X-RateLimit-*` headers with the rate limit and the number
Expand Down Expand Up @@ -322,8 +322,8 @@ export type AugmentedRequest = Request & {
* Express request object.
*/
export type RateLimitInfo = {
readonly limit: number
readonly current: number
readonly remaining: number
readonly resetTime: Date | undefined
limit: number
current: number
remaining: number
resetTime: Date | undefined
}
14 changes: 13 additions & 1 deletion source/validations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,20 @@ class ValidationError extends Error {
}
}

/**
* The validations that can be run, as well as the methods to run them.
*/
export class Validations {
constructor(private enabled: boolean) {}
// eslint-disable-next-line @typescript-eslint/parameter-properties
enabled: boolean

constructor(enabled: boolean) {
this.enabled = enabled
}

enable() {
this.enabled = true
}

disable() {
this.enabled = false
Expand Down

0 comments on commit 9b605a2

Please sign in to comment.