/
index.ts
165 lines (146 loc) · 4.09 KB
/
index.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
/**
* Module dependencies.
*/
import { Redis } from "ioredis"
/**
 * Options accepted by the rate-limit middleware factory exported below.
 *
 * The middleware implements a token bucket per client: each bucket holds at
 * most `capacity` tokens, is refilled at `rate` tokens per second, and every
 * admitted request consumes one token. When fewer than one token is left the
 * request is rejected.
 *
 * Every field is optional; the defaults listed are the ones applied by the
 * factory.
 */
export type RateLimit = {
/** Storage backend for bucket state. Default: "memory" (an in-process Map). */
driver?: "memory" | "redis";
/** ioredis client; required when `driver` is "redis" (the factory throws if it is not a `Redis` instance). */
redis?: Redis;
/**
 * Response header names, given as `{ rate, tokens, capacity }`. Any name left
 * out defaults to "X-RateLimit-Rate", "X-RateLimit-Tokens" and
 * "X-RateLimit-Capacity" respectively.
 */
headers?: any;
/** Returns the client key (default: `ctx.ip`); returning `false` skips rate limiting for the request. */
id?: (ctx) => string | boolean;
/** Awaited; a truthy result skips rate limiting for the request. */
whitelist?: (ctx) => boolean;
/** Awaited; a truthy result rejects the request with 403 "Forbidden" (checked before `whitelist` and `id`). */
blacklist?: (ctx) => boolean;
/** When true, no rate-limit headers are set on the response. */
disableHeader?: boolean;
/** Response status used when a request is rejected. Default: 429. */
status?: number;
/** Response body used when a request is rejected. Default: "Rate limit exceeded.". */
errorMessage?: string;
/** When true, a rejected request additionally calls `ctx.throw(status, body, { headers })`. */
throw?: boolean;
/** Tokens added to a bucket per second. Default: 10. */
rate?: number;
/** Maximum number of tokens a bucket can hold. Default: 100. */
capacity?: number;
/** Prefix of the storage key, which is `${namespace}:${id}`. Default: "limit". */
namespace?: string;
};
/**
 * Persisted state of a single token bucket, as read back from the store.
 * Missing fields are treated by the refill logic as a fresh, full bucket.
 */
interface GetTokenOptions {
  /** Tokens left after the last request; may be fractional. */
  tokens?: number;
  /** `Date.now()` timestamp (epoch milliseconds) of the last refill. */
  lastRefillTime?: number;
}
/**
 * Build a Koa-style token-bucket rate-limiting middleware.
 *
 * Each client, identified by `opts.id(ctx)`, owns a bucket that holds at most
 * `capacity` tokens and is refilled at `rate` tokens per second. Every admitted
 * request consumes one token; when fewer than one token is available the
 * request is rejected with `status` (default 429).
 *
 * @param options - middleware options; a field set to `undefined` or `null`
 *   falls back to its default instead of overriding it.
 * @returns an async `(ctx, next)` middleware.
 * @throws Error when `driver` is "redis" but `redis` is not an ioredis
 *   instance, or when `rate` / `capacity` is not a finite, non-negative number.
 */
module.exports = function ratelimit(options: RateLimit = {}) {
  const defaultOpts = {
    driver: "memory",
    id: (ctx) => ctx.ip,
    headers: {
      rate: "X-RateLimit-Rate", // tokens generated per second
      tokens: "X-RateLimit-Tokens", // tokens currently in the bucket
      capacity: "X-RateLimit-Capacity", // maximum tokens in the bucket
    },
    rate: 10, // tokens generated per second
    capacity: 100, // maximum tokens in the bucket
    namespace: "limit",
  };

  // Tolerate an explicit `null` argument, which a plain spread also tolerated.
  const provided: RateLimit = options ?? {};
  // A plain `{ ...defaults, ...options }` lets an explicit `undefined`
  // (e.g. `{ rate: undefined }`) overwrite a default. That turned every refill
  // computation into NaN and silently admitted every request, so fall back to
  // the default per field instead.
  const opts = {
    ...provided,
    driver: provided.driver ?? defaultOpts.driver,
    id: provided.id ?? defaultOpts.id,
    headers: provided.headers ?? defaultOpts.headers,
    rate: provided.rate ?? defaultOpts.rate,
    capacity: provided.capacity ?? defaultOpts.capacity,
    namespace: provided.namespace ?? defaultOpts.namespace,
  };

  // Response header names; any name the caller left out keeps its default.
  const {
    rate: rateHeader = "X-RateLimit-Rate",
    tokens: tokensHeader = "X-RateLimit-Tokens",
    capacity: capacityHeader = "X-RateLimit-Capacity",
  } = opts.headers;

  if (opts.driver === "redis" && !(opts.redis instanceof Redis)) {
    throw new Error("Invalid options: redis should be an ioredis instance");
  }
  // Fail fast on misconfiguration rather than producing NaN buckets at runtime.
  if (!Number.isFinite(opts.rate) || opts.rate < 0) {
    throw new Error("Invalid options: rate should be a finite, non-negative number");
  }
  if (!Number.isFinite(opts.capacity) || opts.capacity < 0) {
    throw new Error("Invalid options: capacity should be a finite, non-negative number");
  }

  // NOTE(review): stored buckets are written without any expiry, so the Map or
  // Redis keyspace grows with the number of distinct ids. A TTL of about
  // capacity / rate seconds (after which a bucket is full again) would bound it.
  const db = createStore(opts.driver, opts.redis, new Map());
  const disableHeader = opts.disableHeader || false;

  /**
   * Refill a stored bucket snapshot up to the current time.
   *
   * @param options - previously stored state; missing fields mean a fresh,
   *   full bucket.
   * @returns the refilled state (capped at `capacity`), stamped with "now".
   */
  function getTokens(options: GetTokenOptions = {}) {
    const { tokens = opts.capacity, lastRefillTime = Date.now() } = options;
    const currentTime = Date.now();
    // Clamp at 0: a timestamp from the future (e.g. clock skew between
    // processes sharing one Redis) must not drain the bucket.
    const timeElapsed = Math.max(0, currentTime - lastRefillTime);
    const tokensToAdd = (timeElapsed * opts.rate) / 1000; // tokens generated since the last refill
    const currentTokens = Math.min(tokens + tokensToAdd, opts.capacity); // never exceed capacity
    return {
      tokens: currentTokens,
      lastRefillTime: currentTime,
    };
  }

  return async function ratelimit(ctx, next) {
    const id = opts.id(ctx);
    const key = `${opts.namespace}:${id}`;
    const whitelisted =
      typeof opts.whitelist === "function" && (await opts.whitelist(ctx));
    const blacklisted =
      typeof opts.blacklist === "function" && (await opts.blacklist(ctx));
    if (blacklisted) {
      ctx.throw(403, "Forbidden");
    }
    if (id === false || whitelisted) return await next();

    // NOTE(review): this read-modify-write is not atomic. Concurrent requests
    // for the same key (especially across processes sharing Redis) can read the
    // same snapshot and all be admitted; a Lua script or WATCH/MULTI would be
    // required for a strict limit.
    const token = getTokens(await db.get(key));
    const pass = token.tokens >= 1;
    if (pass) {
      token.tokens -= 1;
    }
    await db.set(key, token);

    let headers = {};
    if (!disableHeader) {
      headers = {
        [rateHeader]: opts.rate,
        // Buckets refill fractionally; advertise only whole requests available.
        [tokensHeader]: Math.floor(token.tokens),
        [capacityHeader]: opts.capacity,
      };
      ctx.set(headers);
    }
    if (pass) return await next();

    ctx.status = opts.status || 429;
    ctx.body = opts.errorMessage || `Rate limit exceeded.`;
    if (opts.throw) {
      ctx.throw(ctx.status, ctx.body, { headers });
    }
  };
};
/**
 * Create the key/value store that backs the token buckets.
 *
 * Both variants expose async `get(key)` / `set(key, value)`:
 * - "redis": values are JSON-encoded through `redis.set`; a falsy reply from
 *   `redis.get` (e.g. a missing key) reads back as `undefined`.
 * - any other driver: values are kept as-is in `map`, and `set` resolves to
 *   the value it stored.
 *
 * @param driver - "redis" selects the Redis variant; anything else uses `map`.
 * @param redis - client for the Redis variant (presumably ioredis; only its
 *   `get`/`set` methods are used).
 * @param map - Map used by the in-memory variant.
 */
function createStore(driver, redis, map) {
  if (driver !== "redis") {
    return {
      async get(key) {
        return map.get(key);
      },
      async set(key, value) {
        map.set(key, value);
        return value;
      },
    };
  }

  return {
    async get(key) {
      const raw = await redis.get(key);
      return raw ? JSON.parse(raw) : undefined;
    },
    async set(key, value) {
      const encoded = JSON.stringify(value);
      return redis.set(key, encoded);
    },
  };
}