Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Block duration option for throttler #1668

Merged
merged 13 commits into from
Jul 8, 2024
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,10 @@ The following options are valid for the object passed to the array of the `Throt
<td><code>limit</code></td>
<td>the maximum number of requests within the TTL limit</td>
</tr>
<tr>
<td><code>blockDuration</code></td>
<td>the number of milliseconds that request will be blocked for that time</td>
</tr>
<tr>
<td><code>ignoreUserAgents</code></td>
<td>an array of regular expressions of user-agents to ignore when it comes to throttling requests</td>
Expand Down
5 changes: 5 additions & 0 deletions src/throttler-module-options.interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ export interface ThrottlerOptions {
*/
ttl: Resolvable<number>;

/**
* The number of millisecond the request will be blocked.
asif-jalil marked this conversation as resolved.
Show resolved Hide resolved
*/
blockDuration?: Resolvable<number>;

/**
* The user agents that should be ignored (checked against the User-Agent header).
*/
Expand Down
14 changes: 12 additions & 2 deletions src/throttler-storage-options.interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,22 @@ export interface ThrottlerStorageOptions {
/**
* Amount of requests done by a specific user (partially based on IP).
*/
totalHits: number;
totalHits: Record<string, number>;

/**
* Unix timestamp in milliseconds when the `totalHits` expire.
* Unix timestamp in milliseconds that indicates `ttl` lifetime.
*/
expiresAt: number;

/**
* Define whether the request is blocked or not.
*/
isBlocked: boolean;

/**
* Unix timestamp in milliseconds when the `totalHits` expire.
*/
blockExpiresAt: number;
}

export const ThrottlerStorageOptions = Symbol('ThrottlerStorageOptions');
12 changes: 11 additions & 1 deletion src/throttler-storage-record.interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,19 @@ export interface ThrottlerStorageRecord {
totalHits: number;

/**
* Amount of seconds when the `totalHits` should expire.
* Amount of seconds when the `ttl` should expire.
*/
timeToExpire: number;

/**
* Define whether the request is blocked or not.
*/
isBlocked: boolean;

/**
* Amount of seconds when the `totalHits` should expire.
*/
timeToBlockExpire: number;
}

export const ThrottlerStorageRecord = Symbol('ThrottlerStorageRecord');
8 changes: 7 additions & 1 deletion src/throttler-storage.interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ export interface ThrottlerStorage {
* Increment the amount of requests for a given record. The record will
* automatically be removed from the storage once its TTL has been reached.
*/
increment(key: string, ttl: number): Promise<ThrottlerStorageRecord>;
increment(
key: string,
ttl: number,
limit: number,
blockDuration: number,
throttlerName: string,
): Promise<ThrottlerStorageRecord>;
}

export const ThrottlerStorage = Symbol('ThrottlerStorage');
1 change: 1 addition & 0 deletions src/throttler.constants.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
export const THROTTLER_LIMIT = 'THROTTLER:LIMIT';
export const THROTTLER_TTL = 'THROTTLER:TTL';
export const THROTTLER_TRACKER = 'THROTTLER:TRACKER';
export const THROTTLER_BLOCK_DURATION = 'THROTTLER:BLOCK_DURATION';
export const THROTTLER_KEY_GENERATOR = 'THROTTLER:KEY_GENERATOR';
export const THROTTLER_OPTIONS = 'THROTTLER:MODULE_OPTIONS';
export const THROTTLER_SKIP = 'THROTTLER:SKIP';
3 changes: 3 additions & 0 deletions src/throttler.decorator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
ThrottlerGetTrackerFunction,
} from './throttler-module-options.interface';
import {
THROTTLER_BLOCK_DURATION,
THROTTLER_KEY_GENERATOR,
THROTTLER_LIMIT,
THROTTLER_SKIP,
Expand All @@ -16,6 +17,7 @@ import { getOptionsToken, getStorageToken } from './throttler.providers';
interface ThrottlerMethodOrControllerOptions {
limit?: Resolvable<number>;
ttl?: Resolvable<number>;
blockDuration?: Resolvable<number>;
getTracker?: ThrottlerGetTrackerFunction;
generateKey?: ThrottlerGenerateKeyFunction;
}
Expand All @@ -27,6 +29,7 @@ function setThrottlerMetadata(
for (const name in options) {
Reflect.defineMetadata(THROTTLER_TTL + name, options[name].ttl, target);
Reflect.defineMetadata(THROTTLER_LIMIT + name, options[name].limit, target);
Reflect.defineMetadata(THROTTLER_BLOCK_DURATION + name, options[name].blockDuration, target);
Reflect.defineMetadata(THROTTLER_TRACKER + name, options[name].getTracker, target);
Reflect.defineMetadata(THROTTLER_KEY_GENERATOR + name, options[name].generateKey, target);
}
Expand Down
43 changes: 43 additions & 0 deletions src/throttler.guard.interface.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import { ExecutionContext } from '@nestjs/common';
import { ThrottlerStorageRecord } from './throttler-storage-record.interface';
import {
ThrottlerGenerateKeyFunction,
ThrottlerGetTrackerFunction,
ThrottlerOptions,
} from './throttler-module-options.interface';

/**
* Interface describing the details of a rate limit applied by the ThrottlerGuard.
Expand All @@ -24,3 +30,40 @@ export interface ThrottlerLimitDetail extends ThrottlerStorageRecord {
*/
tracker: string;
}

export interface ThrottlerRequest {
/**
* Interface describing details about the current request pipeline.
*/
context: ExecutionContext;

/**
* The amount of requests that are allowed within the ttl's time window.
*/
limit: number;

/**
* The number of milliseconds that each request will last in storage.
*/
ttl: number;

/**
* Incoming options of the throttler.
*/
throttler: ThrottlerOptions;

/**
* The number of millisecond the request will be blocked.
asif-jalil marked this conversation as resolved.
Show resolved Hide resolved
*/
blockDuration: number;

/**
* A method to override the default tracker string.
*/
getTracker: ThrottlerGetTrackerFunction;

/**
* A method to override the default key generator.
*/
generateKey: ThrottlerGenerateKeyFunction;
}
64 changes: 51 additions & 13 deletions src/throttler.guard.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,32 @@ class ThrottlerStorageServiceMock implements ThrottlerStorage {
return Math.floor((this.storage[key].expiresAt - Date.now()) / 1000);
}

async increment(key: string, ttl: number): Promise<ThrottlerStorageRecord> {
const ttlMilliseconds = ttl * 1000;
private getBlockExpirationTime(key: string): number {
return Math.floor((this.storage[key].blockExpiresAt - Date.now()) / 1000);
}

private fireHitCount(key: string, throttlerName: string) {
this.storage[key].totalHits[throttlerName]++;
}

async increment(
key: string,
ttl: number,
limit: number,
blockDuration: number,
throttlerName: string,
): Promise<ThrottlerStorageRecord> {
const ttlMilliseconds = ttl;
const blockDurationMilliseconds = blockDuration;
if (!this.storage[key]) {
this.storage[key] = { totalHits: 0, expiresAt: Date.now() + ttlMilliseconds };
this.storage[key] = {
totalHits: {
[throttlerName]: 0,
},
expiresAt: Date.now() + ttlMilliseconds,
blockExpiresAt: 0,
isBlocked: false,
};
}

let timeToExpire = this.getExpirationTime(key);
Expand All @@ -32,11 +54,27 @@ class ThrottlerStorageServiceMock implements ThrottlerStorage {
timeToExpire = this.getExpirationTime(key);
}

this.storage[key].totalHits++;
if (!this.storage[key].isBlocked) {
this.fireHitCount(key, throttlerName);
}

// Reset the blockExpiresAt once it gets blocked
if (this.storage[key].totalHits[throttlerName] > limit && !this.storage[key].isBlocked) {
this.storage[key].isBlocked = true;
this.storage[key].blockExpiresAt = Date.now() + blockDurationMilliseconds;
}

const timeToBlockExpire = this.getBlockExpirationTime(key);

if (timeToBlockExpire <= 0 && this.storage[key].isBlocked) {
this.fireHitCount(key, throttlerName);
}

return {
totalHits: this.storage[key].totalHits,
totalHits: this.storage[key].totalHits[throttlerName],
timeToExpire,
isBlocked: this.storage[key].isBlocked,
timeToBlockExpire: timeToBlockExpire,
};
}
}
Expand All @@ -50,28 +88,28 @@ function contextMockFactory(
getClass: () => ThrottlerStorageServiceMock as any,
getHandler: () => handler,
switchToRpc: () => ({
getContext: () => ({} as any),
getData: () => ({} as any),
getContext: () => ({}) as any,
getData: () => ({}) as any,
}),
getArgs: () => [] as any,
getArgByIndex: () => ({} as any),
getArgByIndex: () => ({}) as any,
getType: () => type as any,
};
switch (type) {
case 'ws':
executionPartial.switchToHttp = () => ({} as any);
executionPartial.switchToHttp = () => ({}) as any;
executionPartial.switchToWs = () => mockFunc as any;
break;
case 'http':
executionPartial.switchToWs = () => ({} as any);
executionPartial.switchToWs = () => ({}) as any;
executionPartial.switchToHttp = () => mockFunc as any;
break;
case 'graphql':
executionPartial.switchToWs = () => ({} as any);
executionPartial.switchToWs = () => ({}) as any;
executionPartial.switchToHttp = () =>
({
getNext: () => ({} as any),
} as any);
getNext: () => ({}) as any,
}) as any;
executionPartial.getArgByIndex = () => mockFunc as any;
break;
}
Expand Down
42 changes: 28 additions & 14 deletions src/throttler.guard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
} from './throttler-module-options.interface';
import { ThrottlerStorage } from './throttler-storage.interface';
import {
THROTTLER_BLOCK_DURATION,
THROTTLER_KEY_GENERATOR,
THROTTLER_LIMIT,
THROTTLER_SKIP,
Expand All @@ -18,7 +19,7 @@ import {
} from './throttler.constants';
import { InjectThrottlerOptions, InjectThrottlerStorage } from './throttler.decorator';
import { ThrottlerException, throttlerMessage } from './throttler.exception';
import { ThrottlerLimitDetail } from './throttler.guard.interface';
import { ThrottlerLimitDetail, ThrottlerRequest } from './throttler.guard.interface';

/**
* @publicApi
Expand Down Expand Up @@ -99,6 +100,10 @@ export class ThrottlerGuard implements CanActivate {
THROTTLER_TTL + namedThrottler.name,
[handler, classRef],
);
const routeOrClassBlockDuration = this.reflector.getAllAndOverride<Resolvable<number>>(
THROTTLER_BLOCK_DURATION + namedThrottler.name,
[handler, classRef],
);
const routeOrClassGetTracker = this.reflector.getAllAndOverride<ThrottlerGetTrackerFunction>(
THROTTLER_TRACKER + namedThrottler.name,
[handler, classRef],
Expand All @@ -112,13 +117,24 @@ export class ThrottlerGuard implements CanActivate {
// Check if specific limits are set at class or route level, otherwise use global options.
const limit = await this.resolveValue(context, routeOrClassLimit || namedThrottler.limit);
const ttl = await this.resolveValue(context, routeOrClassTtl || namedThrottler.ttl);
const blockDuration = await this.resolveValue(
context,
routeOrClassBlockDuration || namedThrottler.blockDuration || ttl,
);
const getTracker =
routeOrClassGetTracker || namedThrottler.getTracker || this.commonOptions.getTracker;
const generateKey =
routeOrClassGetKeyGenerator || namedThrottler.generateKey || this.commonOptions.generateKey;

continues.push(
await this.handleRequest(context, limit, ttl, namedThrottler, getTracker, generateKey),
await this.handleRequest({
context,
limit,
ttl,
throttler: namedThrottler,
blockDuration,
getTracker,
generateKey,
}),
);
}
return continues.every((cont) => cont);
Expand All @@ -134,14 +150,9 @@ export class ThrottlerGuard implements CanActivate {
* @see https://tools.ietf.org/id/draft-polli-ratelimit-headers-00.html#header-specifications
* @throws {ThrottlerException}
*/
protected async handleRequest(
context: ExecutionContext,
limit: number,
ttl: number,
throttler: ThrottlerOptions,
getTracker: ThrottlerGetTrackerFunction,
generateKey: ThrottlerGenerateKeyFunction,
): Promise<boolean> {
protected async handleRequest(requestProps: ThrottlerRequest): Promise<boolean> {
const { context, limit, ttl, throttler, blockDuration, getTracker, generateKey } = requestProps;

// Here we start to check the amount of requests being done against the ttl.
const { req, res } = this.getRequestResponse(context);
const ignoreUserAgents = throttler.ignoreUserAgents ?? this.commonOptions.ignoreUserAgents;
Expand All @@ -155,20 +166,23 @@ export class ThrottlerGuard implements CanActivate {
}
const tracker = await getTracker(req);
const key = generateKey(context, tracker, throttler.name);
const { totalHits, timeToExpire } = await this.storageService.increment(key, ttl);
const { totalHits, timeToExpire, isBlocked, timeToBlockExpire } =
await this.storageService.increment(key, ttl, limit, blockDuration, throttler.name);

const getThrottlerSuffix = (name: string) => (name === 'default' ? '' : `-${name}`);

// Throw an error when the user reached their limit.
if (totalHits > limit) {
res.header(`Retry-After${getThrottlerSuffix(throttler.name)}`, timeToExpire);
if (isBlocked) {
res.header(`Retry-After${getThrottlerSuffix(throttler.name)}`, timeToBlockExpire);
await this.throwThrottlingException(context, {
limit,
ttl,
key,
tracker,
totalHits,
timeToExpire,
isBlocked,
timeToBlockExpire,
});
}

Expand Down
Loading