diff --git a/README.md b/README.md index c3fbd90..3bc83b9 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ The protocol expose commands to interact with the distributed storage : - TOUCH : update the ttl of a key - DELETE : delete a key from the storage - INCR : increment the value for a key +- RATE_LIMITER : consume a token in the sliding window rate limiter identified by the key In a clustered deployment (2 or more instances), a client need to connect to only one instance to see all the storage. The goal is to provide a near storage associated with a nginx instance. @@ -136,42 +137,41 @@ named hazelcast.xml. This is an example of this file : ```xml - - + - - ngx-dshm - FIXME - + xsi:schemaLocation="http://www.hazelcast.com/schema/config + http://www.hazelcast.com/schema/config/hazelcast-config-5.0.xsd"> + + ngx-dshm + 5701 - 10.0.x.y + 127.0.0.1 - 10.0.x.y:5701 - 10.0.x.z:5701 + 127.0.0.1 + BINARY 1 0 0 0 - NONE - 0 - 25 - 100 - com.hazelcast.map.merge.PutIfAbsentMapMergePolicy + + PutIfAbsentMergePolicy INDEX-ONLY + + ``` The reference documentation for this configuration is @@ -363,6 +363,27 @@ INCR key -1 0\r\n INCR key -1 0 60\r\n ``` +**_RATE_LIMITER \ \ \_** + +**with data:** _no_ + +Consumes a token in a sliding window rate limiter with the key `key`. The sliding window duration is configured with `duration` seconds. The +rate limiter is created automatically + +The command attempts to consume a token and return the number of remaining available tokens. If there were no more tokens available, the +command returns -1, otherwise +the command return the number of tokens available between 0 and `capacity` + +note : GET command with this key return the available tokens + +This operation is atomic. + +Example : consumes a token in rate limiter `key` with capacity 1000 tokens every 10 seconds + +``` +RATE_LIMITER key 1000 10\r\n +``` + **_FLUSHALL [region]_** **with data:** _no_ @@ -460,10 +481,10 @@ The session_storage parameter control the storage module to be used. - An official docker image build is available at quay.io or directly in the GitHub registry : ```shell - docker pull quay.io/grrolland/ngx-distributed-shm + docker pull quay.io/grrolland/ngx-distributed-shm ``` ```shell - docker pull ghcr.io/grrolland/ngx-distributed-shm + docker pull ghcr.io/grrolland/ngx-distributed-shm ``` ## Kubernetes diff --git a/lua/dshm.lua b/lua/dshm.lua index 14a7327..7ebaca2 100644 --- a/lua/dshm.lua +++ b/lua/dshm.lua @@ -8,6 +8,7 @@ local tonumber = tonumber local math = math local tostring = tostring local table = table +local type = type local ngx = ngx local _M = { @@ -17,6 +18,11 @@ local _M = { local mt = { __index = _M } +--- +---Escape the key +---@param key string the key +---@return string the escaped key +--- local function escape(key) local i = 1 local result = {} @@ -28,7 +34,13 @@ local function escape(key) return table.concat(result, ":") end -function _M.new(self, opts) +--- +---Constructor +---@param _ self instance +---@param opts table options (optional) in order to override espace_key/unescape_key functions +---@return self instance +--- +function _M.new(_, opts) local sock, err = tcp() if not sock then return nil, err @@ -55,7 +67,11 @@ function _M.new(self, opts) unescape_key = unescape_key, }, mt) end - +--- +---Read response line +---@param self self the dshm instance +---@param data string +--- local function read_response_line(self, data) local sock = self.sock if not sock then @@ -88,7 +104,8 @@ local function read_response_line(self, data) ngx.log(ngx.DEBUG, "PROTOCOL ERROR") return nil, _M.PROTOCOL_ERROR else - local data, err = sock:receive(data) + local err + data, err = sock:receive(data) if not data then if err == "timeout" then sock:close() @@ -97,7 +114,8 @@ local function read_response_line(self, data) end ngx.log(ngx.DEBUG, "RECEIVE : ", data) -- Discard trailing \r\n - local trail, err = sock:receive() + local trail + trail, err = sock:receive() if not trail then if err == "timeout" then sock:close() @@ -110,25 +128,65 @@ local function read_response_line(self, data) end end -function _M.get(self, key) - - ngx.log(ngx.DEBUG, "Get : ", key) - +--- +---Send command +---@param self self instance +---@param command string the command to send +---@param key string the key +---@param params table the command arguments (optional) +---@param data string the data to send (optional) +---@return number number of bytes sent +---@return any error +--- +local function send_command(self, command, key, params, data) local sock = self.sock if not sock then return nil, "not initialized" end + -- Add args + local str_arg = "" + if params then + if type(params) == "table" then + str_arg = table.concat(params, " ") + else + str_arg = params + end + str_arg = " " .. str_arg + end + -- Add new line separator + str_arg = str_arg .. "\r\n" + + -- Add data + if data then + str_arg = str_arg .. data + end - local bytes, err = sock:send("get " .. self.escape_key(key) .. "\r\n") + -- Prepare the full command + local str_command = command + if key then + str_command = str_command .. " " .. self.escape_key(key) .. str_arg + end + ngx.log(ngx.DEBUG, "send command to dshm:", str_command) + local bytes, err = sock:send(str_command) if not bytes then return nil, err end + return bytes, nil +end + +--- +---Read and parse response that return DATA +---@param self self the dshm instance +---@return string data +---@return string error +--- +local function read_response_data(self) local resp, data = read_response_line(self) if resp == "LEN" then resp, data = read_response_line(self, data) if resp == "DATA" then - local resp, _ = read_response_line(self) + resp = read_response_line(self) if resp == "DONE" then return data, nil else @@ -138,28 +196,46 @@ function _M.get(self, key) return nil, _M.PROTOCOL_ERROR end elseif resp == "ERROR" then - if data == "not_found" then - return nil, "not found" - else - return nil, data - end + return nil, data else return nil, _M.PROTOCOL_ERROR end +end +--- +---Execute command get and return the data +---@param self self the dshm instance +---@param key string the key +---@return string data +--- +function _M.get(self, key) + + ngx.log(ngx.DEBUG, "Get : ", key) + + local _, err = send_command(self, "get", key) + + if err then + return nil, err + end + local resp + resp, err = read_response_data(self) + if err and err == "not_found" then + err = "not found" + end + return resp, err end +---Delete the key +---@param self self the dshm instance +---@param key string the key to delete +---@return number 1 if key has been deleted or nil +---@return string error function _M.delete(self, key) ngx.log(ngx.DEBUG, "Delete : ", key) - local sock = self.sock - if not sock then - return nil, "not initialized" - end - - local bytes, err = sock:send("delete " .. self.escape_key(key) .. "\r\n") - if not bytes then + local _, err = send_command(self, "delete", key) + if err then return nil, err end @@ -173,116 +249,75 @@ function _M.delete(self, key) end end - +---Increment the counter +---@param self self the dshm instance +---@param key string the key +---@param value string the incr value (example : 1, -1, 2) +---@param init string the initial value (optional). If counter doesn't exist, counter is initialized with this value +---@param init_ttl number the initial TTL value (optional). If counter doesn't exist, counter ttl is initialized with this value +---@return string data the counter value after command execution +---@return string error function _M.incr(self, key, value, init, init_ttl) ngx.log(ngx.DEBUG, "Incr : ", key, ", Value : ", value, ", Init : ", init, ", Init_TTL", init_ttl) - local sock = self.sock - if not sock then - return nil, "not initialized" - end - - local l_init = 0 - if init then - l_init = init - end - - local s_init_ttl = "" - if init_ttl then - s_init_ttl = " " .. init_ttl - end + local l_init = init or 0 + local l_init_ttl = init_ttl or 0 - local command = "incr " .. self.escape_key(key) .. " " .. value .. " " .. l_init .. s_init_ttl .. "\r\n" - local bytes, err = sock:send(command) - if not bytes then + local params = { value, l_init, l_init_ttl } + local _, err = send_command(self, "incr", key, params) + if err then return nil, err end - - local resp, data = read_response_line(self) - if resp == "LEN" then - resp, data = read_response_line(self, data) - if resp == "DATA" then - local resp, _ = read_response_line(self) - if resp == "DONE" then - return data, nil - else - return nil, _M.PROTOCOL_ERROR - end - else - return nil, _M.PROTOCOL_ERROR - end - elseif resp == "ERROR" then - return nil, data - else - return nil, _M.PROTOCOL_ERROR - end + return read_response_data(self) end +---Set a key +---@param self self the dshm instance +---@param key string the key +---@param value string the value +---@param exptime number the TTL value (optional). +---@return string data the value +---@return string error function _M.set(self, key, value, exptime) - if not exptime or exptime == 0 then - exptime = 0 - else + if exptime and exptime ~= 0 then exptime = math.floor(exptime + 0.5) end - ngx.log(ngx.DEBUG, "Key : ", key, ", Value : ", value, ", Exp : ", exptime) + ngx.log(ngx.DEBUG, "set Key : ", key, ", Value : ", value, ", Exp : ", exptime) local sock = self.sock if not sock then return nil, "not initialized" end + local params = { exptime, strlen(value) } + local _, err = send_command(self, "set", key, params, value) - local req = "set " .. self.escape_key(key) .. " " - .. exptime .. " " .. strlen(value) .. "\r\n" .. value - - local bytes, err = sock:send(req) - if not bytes then + if err then return nil, err end - local resp, data = read_response_line(self) - if resp == "LEN" then - resp, data = read_response_line(self, data) - if resp == "DATA" then - local resp, _ = read_response_line(self) - if resp == "DONE" then - return data, nil - else - return nil, _M.PROTOCOL_ERROR - end - else - return nil, _M.PROTOCOL_ERROR - end - elseif resp == "ERROR" then - return nil, data - else - return nil, _M.PROTOCOL_ERROR - end + return read_response_data(self) end - +---Touch a key +---@param self self the dshm instance +---@param key string the key +---@param exptime number the new TTL value (optional). +---@return string data the value +---@return string error function _M.touch(self, key, exptime) - if not exptime or exptime == 0 then - exptime = 0 - else + if exptime and exptime ~= 0 then exptime = math.floor(exptime + 0.5) end ngx.log(ngx.DEBUG, "Touch : ", key, ", Exp : ", exptime) - local sock = self.sock - if not sock then - ngx.log(ngx.DEBUG, "Socket not initialized") - return nil, "not initialized" - end - - local bytes, err = sock:send("touch " .. self.escape_key(key) .. " " - .. exptime .. "\r\n") - if not bytes then + local _, err = send_command(self, "touch", key, exptime) + if err then return nil, err end @@ -297,17 +332,52 @@ function _M.touch(self, key, exptime) end -function _M.quit(self) +--- +--- Sliding window rate limiter command. +---Rate limiter will try to 'consume' a token and return the remaining tokens available. +---If no tokens were available, this method will return nil, "rejected" +---Otherwise return the next remaining tokens available in the window +--- +---@param self self the dshm instance +---@param self string the key +---@param capacity number the tokens capacity +---@param duration number the sliding window duration in seconds +---@return number the remaining tokens available or nil if quota is exceeded +---@return string nil or error. Error code is rejected when quota is exceeded +--- +function _M.rate_limiter(self, key, capacity, duration) + + ngx.log(ngx.DEBUG, "rate_limiter : ", key, ", capacity : ", capacity, ", duration : ", duration) + local sock = self.sock if not sock then return nil, "not initialized" end - local bytes, err = sock:send("quit\r\n") - if not bytes then + local params = { capacity, duration } + local _, err = send_command(self, "rate_limiter", key, params) + if err then return nil, err end + local resp + resp, err = read_response_data(self) + if resp == "-1" then + resp = nil + err = "rejected" + end + return resp, err +end +--- +--- +---Quit command +---@return number 1 when successful +---@return string error +function _M.quit(self) + local _, err = send_command(self, "quit") + if err then + return nil, err + end return 1 end diff --git a/src/main/java/io/github/grrolland/hcshm/AbstractShmValue.java b/src/main/java/io/github/grrolland/hcshm/AbstractShmValue.java new file mode 100644 index 0000000..88175e0 --- /dev/null +++ b/src/main/java/io/github/grrolland/hcshm/AbstractShmValue.java @@ -0,0 +1,33 @@ +/** + * ngx-distributed-shm + * Copyright (C) 2018 Flu.Tech + *

+ * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + *

+ * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + *

+ * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package io.github.grrolland.hcshm; + +import java.io.Serializable; + +/*** + * Base class for value stored in Hazelcast + */ +public abstract class AbstractShmValue implements Serializable { + + /** + * Get string value + * + * @return the value as string + */ + public abstract String getValue(); +} diff --git a/src/main/java/io/github/grrolland/hcshm/ShmRegionLocator.java b/src/main/java/io/github/grrolland/hcshm/ShmRegionLocator.java index 8329dcd..5413852 100644 --- a/src/main/java/io/github/grrolland/hcshm/ShmRegionLocator.java +++ b/src/main/java/io/github/grrolland/hcshm/ShmRegionLocator.java @@ -35,7 +35,7 @@ public class ShmRegionLocator implements Serializable { * hazelcast instance * @return return the named IMap, if no region in the key return the default IMap */ - public IMap getMap(final HazelcastInstance hazelcast, final String key) { + public IMap getMap(final HazelcastInstance hazelcast, final String key) { return getMapRegion(hazelcast, getRegion(key)); } @@ -48,7 +48,7 @@ public IMap getMap(final HazelcastInstance hazelcast, final St * hazelcast instance * @return return the named IMap, if no region return the default IMap */ - public IMap getMapRegion(final HazelcastInstance hazelcast, final String region) { + public IMap getMapRegion(final HazelcastInstance hazelcast, final String region) { if (null != region) { return hazelcast.getMap(region); } else { diff --git a/src/main/java/io/github/grrolland/hcshm/ShmService.java b/src/main/java/io/github/grrolland/hcshm/ShmService.java index f89c55a..0b462ca 100644 --- a/src/main/java/io/github/grrolland/hcshm/ShmService.java +++ b/src/main/java/io/github/grrolland/hcshm/ShmService.java @@ -20,8 +20,12 @@ import com.hazelcast.core.HazelcastInstance; import com.hazelcast.map.IMap; import io.github.grrolland.hcshm.processor.IncrProcessor; +import io.github.grrolland.hcshm.processor.RateLimiterProcessor; import io.github.grrolland.hcshm.processor.TouchProcessor; +import io.github.grrolland.hcshm.ratelimiter.ConsumptionProbe; +import io.github.grrolland.hcshm.ratelimiter.RateLimiterShmValue; +import java.time.Duration; import java.util.concurrent.TimeUnit; /** @@ -57,7 +61,8 @@ public ShmService(HazelcastInstance hazelcast) { * @return the value as string or the error */ public String get(String key) { - ShmValue r = getMap(key).get(key); + IMap map = getMap(key); + AbstractShmValue r = map.get(key); if (null != r) { return r.getValue(); } else { @@ -107,7 +112,8 @@ public String set(String key, long value, long expire) { * the expiration in milliseconds */ public void touch(String key, long expire) { - getMap(key).executeOnKey(key, new TouchProcessor(expire)); + IMap map = getMap(key); + map.executeOnKey(key, new TouchProcessor(expire)); } /** @@ -124,7 +130,8 @@ public void touch(String key, long expire) { * @return the new value as string representation */ public String incr(String key, int value, int init, long initialExpire) { - return (String) getMap(key).executeOnKey(key, new IncrProcessor(value, init, initialExpire)); + IMap map = getMap(key); + return (String) map.executeOnKey(key, new IncrProcessor(value, init, initialExpire)); } /** @@ -147,6 +154,20 @@ public void flushall(String region) { regionLocator.getMapRegion(hazelcast, region).clear(); } + /*** + * Consume a token + * @param key the key + * @param capacity the maximum capacity + * @param duration the duration of a token in seconds + * @return the number of tokens remaining + */ + public String rateLimiter(String key, int capacity, long duration) { + final IMap map = regionLocator.getMap(hazelcast, key); + RateLimiterProcessor rateLimiterProcessor = new RateLimiterProcessor(capacity, Duration.ofMillis(duration)); + ConsumptionProbe consumptionProbe = (ConsumptionProbe) map.executeOnKey(key, rateLimiterProcessor); + return String.valueOf(consumptionProbe.getRemainingTokens()); + } + /** * Get the map form the key name * @@ -154,7 +175,7 @@ public void flushall(String region) { * the key * @return return the named IMap, if no region in the key return the default IMap */ - private IMap getMap(final String key) { + private IMap getMap(final String key) { return regionLocator.getMap(hazelcast, key); } diff --git a/src/main/java/io/github/grrolland/hcshm/ShmValue.java b/src/main/java/io/github/grrolland/hcshm/ShmValue.java index 01820bd..7584e62 100644 --- a/src/main/java/io/github/grrolland/hcshm/ShmValue.java +++ b/src/main/java/io/github/grrolland/hcshm/ShmValue.java @@ -17,12 +17,10 @@ */ package io.github.grrolland.hcshm; -import java.io.Serializable; - /** * Value in the SHM Map */ -public class ShmValue implements Serializable { +public class ShmValue extends AbstractShmValue { /** * The value diff --git a/src/main/java/io/github/grrolland/hcshm/commands/Command.java b/src/main/java/io/github/grrolland/hcshm/commands/Command.java index f26065b..acf1970 100644 --- a/src/main/java/io/github/grrolland/hcshm/commands/Command.java +++ b/src/main/java/io/github/grrolland/hcshm/commands/Command.java @@ -42,10 +42,14 @@ public abstract class Command { protected static final String RESPONSE_LINE_DELIMITER = "\r\n"; /** - * Protocol response : Error malformed resuqest + * Protocol response : Error malformed request */ protected static final String ERROR_MALFORMED_REQUEST = "ERROR malformed_request"; + /** + * Protocol response : Error malformed bad request + */ + protected static final String ERROR_BAD_REQUEST = "ERROR bad_request"; /** * Protocol response : Error malformed request */ @@ -280,6 +284,17 @@ protected void writeMalformedRequest(StringBuilder response) { response.append(RESPONSE_LINE_DELIMITER); } + /** + * Write the LEN protocol line + * + * @param response + * the constructing response + */ + protected void writeBadRequest(StringBuilder response) { + response.append(ERROR_BAD_REQUEST); + response.append(RESPONSE_LINE_DELIMITER); + } + /** * Write the LEN protocol line * diff --git a/src/main/java/io/github/grrolland/hcshm/commands/CommandFactory.java b/src/main/java/io/github/grrolland/hcshm/commands/CommandFactory.java index 6b227d0..dab88bb 100644 --- a/src/main/java/io/github/grrolland/hcshm/commands/CommandFactory.java +++ b/src/main/java/io/github/grrolland/hcshm/commands/CommandFactory.java @@ -70,6 +70,9 @@ public Command get(String[] commandTokens) { case FLUSHALL: command = new FlushAllCommand(service); break; + case RATE_LIMITER: + command = new RateLimiterCommand(service); + break; default: command = new UnknownCommand(service); break; diff --git a/src/main/java/io/github/grrolland/hcshm/commands/CommandVerb.java b/src/main/java/io/github/grrolland/hcshm/commands/CommandVerb.java index e298310..02ef809 100644 --- a/src/main/java/io/github/grrolland/hcshm/commands/CommandVerb.java +++ b/src/main/java/io/github/grrolland/hcshm/commands/CommandVerb.java @@ -37,6 +37,10 @@ public enum CommandVerb { * The INCR command */ INCR, + /** + * The RATE_LIMITER command + */ + RATE_LIMITER, /** * The Quit Command */ diff --git a/src/main/java/io/github/grrolland/hcshm/commands/IncrCommand.java b/src/main/java/io/github/grrolland/hcshm/commands/IncrCommand.java index 71f494e..bac0390 100644 --- a/src/main/java/io/github/grrolland/hcshm/commands/IncrCommand.java +++ b/src/main/java/io/github/grrolland/hcshm/commands/IncrCommand.java @@ -19,6 +19,7 @@ import io.github.grrolland.hcshm.ProtocolException; import io.github.grrolland.hcshm.ShmService; +import io.github.grrolland.hcshm.processor.BadRequestException; /** * The INCR Command @@ -50,12 +51,15 @@ public String execute(String[] commandTokens) { int incr = getIncrValue(commandTokens[2]); int initial = getIncrValue(commandTokens[3]); long initialExpire = commandTokens.length == 5 ? getExpire(commandTokens[4]) : 0; + String value = getService().incr(key, incr, initial, initialExpire); writeLen(response, value); writeValue(response, value); writeDone(response); } catch (ProtocolException e) { writeMalformedRequest(response); + } catch (BadRequestException e) { + writeBadRequest(response); } return response.toString(); } diff --git a/src/main/java/io/github/grrolland/hcshm/commands/RateLimiterCommand.java b/src/main/java/io/github/grrolland/hcshm/commands/RateLimiterCommand.java new file mode 100644 index 0000000..d273cb6 --- /dev/null +++ b/src/main/java/io/github/grrolland/hcshm/commands/RateLimiterCommand.java @@ -0,0 +1,64 @@ +/** + * ngx-distributed-shm + * Copyright (C) 2018 Flu.Tech + *

+ * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + *

+ * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + *

+ * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package io.github.grrolland.hcshm.commands; + +import io.github.grrolland.hcshm.ProtocolException; +import io.github.grrolland.hcshm.ShmService; +import io.github.grrolland.hcshm.processor.BadRequestException; + +/** + * The Rate limiter Command + */ +public class RateLimiterCommand extends Command { + + /** + * Default Constructor + * + * @param service + * the shm service + */ + RateLimiterCommand(ShmService service) { + super(service); + } + + /** + * Execute the command + * + * @param commandTokens + * the protocol tokens argument of the command + * @return the result of the command 'protocol encoded' + */ + public String execute(String[] commandTokens) { + final StringBuilder response = new StringBuilder(); + try { + assertTokens(commandTokens, 4); + String key = getKey(commandTokens[1]); + int capacity = getIncrValue(commandTokens[2]); + long duration = getExpire(commandTokens[3]); + String value = getService().rateLimiter(key, capacity, duration); + writeLen(response, value); + writeValue(response, value); + writeDone(response); + } catch (ProtocolException e) { + writeMalformedRequest(response); + } catch (BadRequestException e) { + writeBadRequest(response); + } + return response.toString(); + } +} diff --git a/src/main/java/io/github/grrolland/hcshm/commands/TouchCommand.java b/src/main/java/io/github/grrolland/hcshm/commands/TouchCommand.java index 1814c48..0f8b2eb 100644 --- a/src/main/java/io/github/grrolland/hcshm/commands/TouchCommand.java +++ b/src/main/java/io/github/grrolland/hcshm/commands/TouchCommand.java @@ -19,6 +19,7 @@ import io.github.grrolland.hcshm.ProtocolException; import io.github.grrolland.hcshm.ShmService; +import io.github.grrolland.hcshm.processor.BadRequestException; /** * The Touch Command @@ -52,6 +53,8 @@ public String execute(String[] commandTokens) { writeDone(response); } catch (ProtocolException e) { writeMalformedRequest(response); + } catch (BadRequestException e) { + writeBadRequest(response); } return response.toString(); } diff --git a/src/main/java/io/github/grrolland/hcshm/processor/BadRequestException.java b/src/main/java/io/github/grrolland/hcshm/processor/BadRequestException.java new file mode 100644 index 0000000..e3246cb --- /dev/null +++ b/src/main/java/io/github/grrolland/hcshm/processor/BadRequestException.java @@ -0,0 +1,16 @@ +package io.github.grrolland.hcshm.processor; + +/** + * The request is well-formed, but the command is not possible with the key stored in Hazelcast is not compatible with the processor + */ +public class BadRequestException extends RuntimeException { + /** + * Constructor + * + * @param cause + * the cause + */ + public BadRequestException(final ClassCastException cause) { + super(cause); + } +} diff --git a/src/main/java/io/github/grrolland/hcshm/processor/IncrProcessor.java b/src/main/java/io/github/grrolland/hcshm/processor/IncrProcessor.java index 2b440dc..29d401c 100644 --- a/src/main/java/io/github/grrolland/hcshm/processor/IncrProcessor.java +++ b/src/main/java/io/github/grrolland/hcshm/processor/IncrProcessor.java @@ -60,6 +60,14 @@ public IncrProcessor(long value, int init, long initialExpire) { this.initialExpire = initialExpire; } + private static ShmValue getCurrentValue(final Map.Entry entry) throws BadRequestException { + try { + return entry.getValue(); + } catch (ClassCastException e) { + throw new BadRequestException(e); + } + } + /** * Process the incrementation command * @@ -70,7 +78,7 @@ public IncrProcessor(long value, int init, long initialExpire) { @Override public Object process(Map.Entry entry) { - final ShmValue r = entry.getValue(); + final ShmValue r = getCurrentValue(entry); String newval; long expire; ExtendedMapEntry extendedMapEntry = (ExtendedMapEntry) entry; diff --git a/src/main/java/io/github/grrolland/hcshm/processor/RateLimiterProcessor.java b/src/main/java/io/github/grrolland/hcshm/processor/RateLimiterProcessor.java new file mode 100644 index 0000000..8ca750a --- /dev/null +++ b/src/main/java/io/github/grrolland/hcshm/processor/RateLimiterProcessor.java @@ -0,0 +1,87 @@ +/** + * ngx-distributed-shm + * Copyright (C) 2018 Flu.Tech + *

+ * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + *

+ * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + *

+ * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package io.github.grrolland.hcshm.processor; + +import com.hazelcast.map.EntryProcessor; +import io.github.grrolland.hcshm.ratelimiter.ConsumptionProbe; +import io.github.grrolland.hcshm.ratelimiter.RateLimiterShmValue; + +import java.io.Serializable; +import java.time.Duration; +import java.util.Map; +import java.util.Optional; + +/** + * Processor for the RATE_LIMITER command + */ +public class RateLimiterProcessor implements EntryProcessor, Serializable { + /** + * Capacity + */ + private final int capacity; + + /** + * Sliding window duration + */ + private final Duration duration; + + /** + * Constructor + * + * @param capacity + * the maximum capacity of the rate limiter + * @param duration + * the sliding window duration + */ + public RateLimiterProcessor(int capacity, Duration duration) { + this.capacity = capacity; + this.duration = duration; + } + + @Override + public ConsumptionProbe process(final Map.Entry entry) { + RateLimiterShmValue rateLimiterShmValue = Optional.ofNullable(getCurrentValue(entry)).orElseGet(this::create); + rateLimiterShmValue.setDuration(this.duration); + rateLimiterShmValue.setCapacity(this.capacity); + final ConsumptionProbe consumptionProbe = rateLimiterShmValue.take(); + entry.setValue(rateLimiterShmValue); + return consumptionProbe; + } + + /** + * Get the current value + * + * @param entry + * the entry + * @return the current value + * @throws BadRequestException + * exception + */ + private RateLimiterShmValue getCurrentValue(final Map.Entry entry) throws BadRequestException { + try { + return entry.getValue(); + } catch (ClassCastException e) { + throw new BadRequestException(e); + } + } + + private RateLimiterShmValue create() { + return new RateLimiterShmValue(this.capacity, this.duration); + } + +} diff --git a/src/main/java/io/github/grrolland/hcshm/processor/TouchProcessor.java b/src/main/java/io/github/grrolland/hcshm/processor/TouchProcessor.java index cb0a038..356e99f 100644 --- a/src/main/java/io/github/grrolland/hcshm/processor/TouchProcessor.java +++ b/src/main/java/io/github/grrolland/hcshm/processor/TouchProcessor.java @@ -43,6 +43,14 @@ public TouchProcessor(long expire) { this.expire = expire; } + private static ShmValue getCurrentValue(final Map.Entry entry) throws BadRequestException { + try { + return entry.getValue(); + } catch (ClassCastException e) { + throw new BadRequestException(e); + } + } + /** * Touch process * @@ -52,7 +60,7 @@ public TouchProcessor(long expire) { */ @Override public Object process(Map.Entry entry) { - final ShmValue r = entry.getValue(); + final ShmValue r = getCurrentValue(entry); if (null != r) { r.expire(expire); ((ExtendedMapEntry) entry).setValue(r, expire, TimeUnit.MILLISECONDS); diff --git a/src/main/java/io/github/grrolland/hcshm/ratelimiter/ConsumptionProbe.java b/src/main/java/io/github/grrolland/hcshm/ratelimiter/ConsumptionProbe.java new file mode 100644 index 0000000..5efb552 --- /dev/null +++ b/src/main/java/io/github/grrolland/hcshm/ratelimiter/ConsumptionProbe.java @@ -0,0 +1,52 @@ +/** + * ngx-distributed-shm + * Copyright (C) 2018 Flu.Tech + *

+ * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + *

+ * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + *

+ * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package io.github.grrolland.hcshm.ratelimiter; + +/** + * Describes token consumed, and number of tokens remaining + *

+ * remainingTokens -1 means no tokens have been consumed + * remainingTokens 0 means there is no more tokens to consume + */ +public class ConsumptionProbe { + + /** + * Number of remaining tokens + */ + private final int remainingTokens; + + /** + * Return the number of remaining available tokens or -1 if capacity was exceeded + * + * @return the number of remaining available tokens or -1 + */ + public int getRemainingTokens() { + return this.remainingTokens; + + } + + /** + * Constructor + * + * @param remainingTokens + * the number of remaining token + */ + ConsumptionProbe(int remainingTokens) { + this.remainingTokens = remainingTokens; + } +} diff --git a/src/main/java/io/github/grrolland/hcshm/ratelimiter/RateLimiterShmValue.java b/src/main/java/io/github/grrolland/hcshm/ratelimiter/RateLimiterShmValue.java new file mode 100644 index 0000000..fc2aa1e --- /dev/null +++ b/src/main/java/io/github/grrolland/hcshm/ratelimiter/RateLimiterShmValue.java @@ -0,0 +1,127 @@ +/** + * ngx-distributed-shm + * Copyright (C) 2018 Flu.Tech + *

+ * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + *

+ * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + *

+ * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package io.github.grrolland.hcshm.ratelimiter; + +import io.github.grrolland.hcshm.AbstractShmValue; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + +/** + * RateLimiterShmValue store rate limiter data + */ +public class RateLimiterShmValue extends AbstractShmValue { + + /** + * The number of tokens used + */ + private final List tokens; + /** + * The sliding window duration + */ + private Duration duration; + /** + * The capacity + */ + private int capacity; + + /** + * Set capacity + * + * @param pCapacity + * capacity + */ + public void setCapacity(final int pCapacity) { + this.capacity = pCapacity; + } + + /** + * Set duration + * + * @param pDuration + * duration + */ + public void setDuration(final Duration pDuration) { + this.duration = pDuration; + } + + /** + * Get the number of available tokens before capacity is exceeded + * + * @return the number of available tokens + */ + public int getRemaining() { + return Math.max(this.capacity - this.tokens.size(), 0); + } + + @Override + public String getValue() { + this.clearExpired(); + return String.valueOf(getRemaining()); + } + + /** + * Constructor + * + * @param capacity + * the capacity + * @param duration + * the sliding window duration + */ + public RateLimiterShmValue(int capacity, Duration duration) { + this.tokens = new ArrayList<>(capacity); + this.duration = duration; + this.capacity = capacity; + } + + /** + * Try to take a token and return the ConsumptionProbe + * + * @return the ConsumptionProbe + */ + public ConsumptionProbe take() { + // Clear expired tokens + this.clearExpired(); + + int remaining = -1; + // Try to consume + if (this.canConsume()) { + tokens.add(new Token()); + remaining = this.getRemaining(); + } + return new ConsumptionProbe(remaining); + } + + /** + * Can consume + * + * @return true if at least one token is available + */ + private boolean canConsume() { + return this.tokens.size() < this.capacity; + } + + /** + * Clear expired tokens + */ + private void clearExpired() { + + tokens.removeIf(pToken -> pToken.isExpired(this.duration)); + } +} diff --git a/src/main/java/io/github/grrolland/hcshm/ratelimiter/Token.java b/src/main/java/io/github/grrolland/hcshm/ratelimiter/Token.java new file mode 100644 index 0000000..b3ed5c2 --- /dev/null +++ b/src/main/java/io/github/grrolland/hcshm/ratelimiter/Token.java @@ -0,0 +1,50 @@ +/** + * ngx-distributed-shm + * Copyright (C) 2018 Flu.Tech + *

+ * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + *

+ * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + *

+ * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package io.github.grrolland.hcshm.ratelimiter; + +import java.io.Serializable; +import java.time.Duration; +import java.time.Instant; + +/** + * A token consumed at a specific date + */ +public class Token implements Serializable { + + /** + * Expiration + */ + private final long createdAt; + + /** + * Constructor + */ + public Token() { + this.createdAt = System.currentTimeMillis(); + } + + /** + * @param duration + * Duration + * @return true if expired + */ + boolean isExpired(Duration duration) { + // check if expiration date is before now + return Instant.ofEpochMilli(this.createdAt).plus(duration).isBefore(Instant.now()); + } +} diff --git a/src/test/java/io/github/grrolland/hcshm/HCSHMTestSuite.java b/src/test/java/io/github/grrolland/hcshm/HCSHMTestSuite.java index 3046a25..332bcde 100644 --- a/src/test/java/io/github/grrolland/hcshm/HCSHMTestSuite.java +++ b/src/test/java/io/github/grrolland/hcshm/HCSHMTestSuite.java @@ -19,6 +19,7 @@ import ch.qos.logback.classic.Level; import ch.qos.logback.classic.Logger; +import io.github.grrolland.hcshm.ratelimiter.RateLimiterShmValueTestCase; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.runner.RunWith; @@ -29,7 +30,9 @@ * Test Suite initializing the distributed SHM */ @RunWith(Suite.class) -@Suite.SuiteClasses({ShmValueTestCase.class, +@Suite.SuiteClasses({TouchTestCase.class, + RateLimiterShmValueTestCase.class, + ShmValueTestCase.class, DeleteTestCase.class, GetTestCase.class, IncrTestCase.class, @@ -37,6 +40,7 @@ SetTestCase.class, TouchTestCase.class, UnknownCommandTestCase.class, + RateLimiterTestCase.class, FlushAllTestCase.class}) public class HCSHMTestSuite { diff --git a/src/test/java/io/github/grrolland/hcshm/RateLimiterTestCase.java b/src/test/java/io/github/grrolland/hcshm/RateLimiterTestCase.java new file mode 100644 index 0000000..e5f9d63 --- /dev/null +++ b/src/test/java/io/github/grrolland/hcshm/RateLimiterTestCase.java @@ -0,0 +1,50 @@ +package io.github.grrolland.hcshm; + +import org.junit.Test; + +import java.io.IOException; + +public class RateLimiterTestCase extends AbstractHCSHMGetTestCase { + + /** + * Test Incrementation + */ + @Test + public void testConsume() throws IOException, InterruptedException { + + getWriter().write("RATE_LIMITER ratekey 10 1\r\n"); + getWriter().flush(); + assertResponseGetValue("9"); + + getWriter().write("RATE_LIMITER ratekey 10 1\r\n"); + getWriter().flush(); + assertResponseGetValue("8"); + + pause(2000); + + getWriter().write("RATE_LIMITER ratekey 10 1\r\n"); + getWriter().flush(); + assertResponseGetValue("9"); + + } + + @Test + public void testConsumeAll() throws IOException, InterruptedException { + + for (int i = 0; i < 10; i++) { + getWriter().write("RATE_LIMITER ratekey 10 2\r\n"); + getWriter().flush(); + assertResponseGetValue(String.valueOf(10 - 1 - i)); + } + + getWriter().write("RATE_LIMITER ratekey 10 1\r\n"); + getWriter().flush(); + assertResponseGetValue("-1"); + pause(3000); + + getWriter().write("RATE_LIMITER ratekey 10 2\r\n"); + getWriter().flush(); + assertResponseGetValue("9"); + + } +} diff --git a/src/test/java/io/github/grrolland/hcshm/ratelimiter/RateLimiterShmValueTestCase.java b/src/test/java/io/github/grrolland/hcshm/ratelimiter/RateLimiterShmValueTestCase.java new file mode 100644 index 0000000..306ef55 --- /dev/null +++ b/src/test/java/io/github/grrolland/hcshm/ratelimiter/RateLimiterShmValueTestCase.java @@ -0,0 +1,181 @@ +package io.github.grrolland.hcshm.ratelimiter; + +import org.junit.Test; + +import java.time.Duration; + +import static org.junit.Assert.assertEquals; + +/** + * RateLimiterValue TestCase + */ +public class RateLimiterShmValueTestCase { + + @Test + public void getValue() throws InterruptedException { + final RateLimiterShmValue rateLimiterShmValue = new RateLimiterShmValue(2, Duration.ofMillis(100)); + // Take a token + rateLimiterShmValue.take(); + assertEquals("1", rateLimiterShmValue.getValue()); + // Wait and take another + Thread.sleep(50); + rateLimiterShmValue.take(); + assertEquals("0", rateLimiterShmValue.getValue()); + // Wait and getValue : the first token is expired + Thread.sleep(50); + assertEquals("1", rateLimiterShmValue.getValue()); + + // Take another : remainng 0 + rateLimiterShmValue.take(); + assertEquals("0", rateLimiterShmValue.getValue()); + + // Take another : remaning 0 + rateLimiterShmValue.take(); + assertEquals("0", rateLimiterShmValue.getValue()); + + // Wait 100 : all token are expired + Thread.sleep(100); + assertEquals("2", rateLimiterShmValue.getValue()); + } + + @Test + public void getRemaining() throws InterruptedException { + final RateLimiterShmValue rateLimiterShmValue = new RateLimiterShmValue(2, Duration.ofMillis(100)); + rateLimiterShmValue.take(); + assertEquals(1, rateLimiterShmValue.getRemaining()); + // Pause and take + Thread.sleep(51); + rateLimiterShmValue.take(); + assertEquals(0, rateLimiterShmValue.getRemaining()); + + // Pause, take and get remaining : the first token is expired + Thread.sleep(55); + rateLimiterShmValue.take(); + assertEquals(0, rateLimiterShmValue.getRemaining()); + + // Take another + rateLimiterShmValue.take(); + assertEquals(0, rateLimiterShmValue.getRemaining()); + } + + @Test + public void take() { + final RateLimiterShmValue rateLimiterShmValue = new RateLimiterShmValue(10, Duration.ofMillis(100)); + ConsumptionProbe consumptionProbe = rateLimiterShmValue.take(); + assertEquals(9, consumptionProbe.getRemainingTokens()); + + consumptionProbe = rateLimiterShmValue.take(); + assertEquals(8, consumptionProbe.getRemainingTokens()); + + consumptionProbe = rateLimiterShmValue.take(); + assertEquals(7, consumptionProbe.getRemainingTokens()); + + } + + @Test + public void takeAll() { + final RateLimiterShmValue rateLimiterShmValue = new RateLimiterShmValue(2, Duration.ofMillis(100)); + ConsumptionProbe consumptionProbe = rateLimiterShmValue.take(); + assertEquals(1, consumptionProbe.getRemainingTokens()); + + consumptionProbe = rateLimiterShmValue.take(); + assertEquals(0, consumptionProbe.getRemainingTokens()); + + consumptionProbe = rateLimiterShmValue.take(); + assertEquals(-1, consumptionProbe.getRemainingTokens()); + + consumptionProbe = rateLimiterShmValue.take(); + assertEquals(-1, consumptionProbe.getRemainingTokens()); + + } + + @Test + public void takeChangeCapacity() { + final RateLimiterShmValue rateLimiterShmValue = new RateLimiterShmValue(3, Duration.ofMillis(100)); + assertEquals(3, rateLimiterShmValue.getRemaining()); + ConsumptionProbe consumptionProbe = rateLimiterShmValue.take(); + assertEquals(2, consumptionProbe.getRemainingTokens()); + assertEquals(2, rateLimiterShmValue.getRemaining()); + + rateLimiterShmValue.setCapacity(4); + assertEquals(3, rateLimiterShmValue.getRemaining()); + consumptionProbe = rateLimiterShmValue.take(); + assertEquals(2, consumptionProbe.getRemainingTokens()); + + rateLimiterShmValue.setCapacity(3); + assertEquals(1, rateLimiterShmValue.getRemaining()); + + // Take + consumptionProbe = rateLimiterShmValue.take(); + assertEquals(0, consumptionProbe.getRemainingTokens()); + assertEquals(0, rateLimiterShmValue.getRemaining()); + + } + + @Test + public void takeChangeDuration() throws InterruptedException { + final RateLimiterShmValue rateLimiterShmValue = new RateLimiterShmValue(3, Duration.ofMillis(100)); + assertEquals(3, rateLimiterShmValue.getRemaining()); + // Take token 1 + ConsumptionProbe consumptionProbe = rateLimiterShmValue.take(); + assertEquals(2, consumptionProbe.getRemainingTokens()); + // Wait + Thread.sleep(50); + + // Change duration to 10 ms + rateLimiterShmValue.setDuration(Duration.ofMillis(10)); + // Take token 2 + consumptionProbe = rateLimiterShmValue.take(); + // Token 1 should be expired + assertEquals(2, consumptionProbe.getRemainingTokens()); + assertEquals(2, rateLimiterShmValue.getRemaining()); + + // Change duration to 1000 ms + rateLimiterShmValue.setDuration(Duration.ofMillis(1000)); + // Take token 3 + consumptionProbe = rateLimiterShmValue.take(); + + // Wait + Thread.sleep(500); + // Token 2 and 3 should not be expired + assertEquals(1, consumptionProbe.getRemainingTokens()); + assertEquals(1, rateLimiterShmValue.getRemaining()); + + // Wait + Thread.sleep(500); + // Token 2 and 3 should be expired + assertEquals(1, consumptionProbe.getRemainingTokens()); + assertEquals(1, rateLimiterShmValue.getRemaining()); + + } + + @Test + public void consumeTokenShouldRelease() throws InterruptedException { + final RateLimiterShmValue rateLimiterShmValue = new RateLimiterShmValue(10, Duration.ofMillis(100)); + + // Take token 1 + ConsumptionProbe consumptionProbe = rateLimiterShmValue.take(); + assertEquals(9, consumptionProbe.getRemainingTokens()); + // Wait 101 ms, token 1 should be released + Thread.sleep(100); + // Take token 2 + consumptionProbe = rateLimiterShmValue.take(); + assertEquals(9, consumptionProbe.getRemainingTokens()); + + // Wait 50, and take token 3, token 2 should not be expired + Thread.sleep(50); + consumptionProbe = rateLimiterShmValue.take(); + assertEquals(8, consumptionProbe.getRemainingTokens()); + + // Wait 50 and take Token 4 : token 2 should be expired, token 3 not expired + Thread.sleep(50); + consumptionProbe = rateLimiterShmValue.take(); + assertEquals(8, consumptionProbe.getRemainingTokens()); + + // Wait 100 : token 2 and token 3 should be expired + Thread.sleep(100); + consumptionProbe = rateLimiterShmValue.take(); + assertEquals(9, consumptionProbe.getRemainingTokens()); + } + +} diff --git a/src/test/java/io/github/grrolland/hcshm/ratelimiter/TokenTestCase.java b/src/test/java/io/github/grrolland/hcshm/ratelimiter/TokenTestCase.java new file mode 100644 index 0000000..12e4f13 --- /dev/null +++ b/src/test/java/io/github/grrolland/hcshm/ratelimiter/TokenTestCase.java @@ -0,0 +1,30 @@ +package io.github.grrolland.hcshm.ratelimiter; + +import org.junit.Test; + +import java.time.Duration; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/*** + * Token test case + */ +public class TokenTestCase { + + @Test + public void isExpired() throws InterruptedException { + Token token = new Token(); + assertTrue(token.isExpired(Duration.ofMillis(0))); + Thread.sleep(10); + assertTrue(token.isExpired(Duration.ofMillis(500))); + } + + @Test + public void isExpiredFalse() throws InterruptedException { + Token token = new Token(); + assertFalse(token.isExpired(Duration.ofMillis(100))); + Thread.sleep(50); + assertFalse(token.isExpired(Duration.ofMillis(100))); + } +}