diff --git a/lib/ratelimit.rb b/lib/ratelimit.rb index 0d4efa8..378b7e2 100644 --- a/lib/ratelimit.rb +++ b/lib/ratelimit.rb @@ -40,7 +40,7 @@ def initialize(key, options = {}) # @return [Integer] The counter value def add(subject, count = 1) bucket = get_bucket - subject = "#{@key}:#{subject}" + subject = get_key_for_subject(subject) redis.multi do redis.hincrby(subject, bucket, count) redis.hdel(subject, (bucket + 1) % @bucket_count) @@ -55,14 +55,8 @@ def add(subject, count = 1) # @param [Integer] interval How far back (in seconds) to retrieve activity. def count(subject, interval) bucket = get_bucket - interval = [[interval, @bucket_interval].max, @bucket_span].min - count = (interval / @bucket_interval).floor - subject = "#{@key}:#{subject}" - - keys = (0..count - 1).map do |i| - (bucket - i) % @bucket_count - end - return redis.hmget(subject, *keys).inject(0) {|a, i| a + i.to_i} + keys = get_bucket_keys_for_interval(bucket, interval) + return redis.hmget(get_key_for_subject(subject), *keys).inject(0) {|a, i| a + i.to_i} end # Check if the rate limit has been exceeded. @@ -108,12 +102,58 @@ def exec_within_threshold(subject, options = {}, &block) yield(self) end + # Execute a block and increment the count once the rate limit is within bounds. + # This fixes the concurrency issue found in exec_within_threshold + # *WARNING* This will block the current thread until the rate limit is within bounds. + # + # @param [String] subject Subject for this rate limit + # @param [Hash] options Options hash + # @option options [Integer] :interval How far back to retrieve activity. + # @option options [Integer] :threshold Maximum number of actions + # @option options [Integer] :increment + # @yield The block to be run + # + # @example Send an email as long as we haven't send 5 in the last 10 minutes + # ratelimit.exec_with_threshold(email, [:threshold => 5, :interval => 600, :increment => 1]) do + # send_another_email + # end + def exec_and_increment_within_threshold(subject, options = {}, &block) + options[:threshold] ||= 30 + options[:interval] ||= 30 + options[:increment] ||= 1 + until count_incremented_within_threshold(subject, options) + sleep @bucket_interval + end + yield(self) + end + private def get_bucket(time = Time.now.to_i) ((time % @bucket_span) / @bucket_interval).floor end + def get_bucket_keys_for_interval(bucket, interval) + interval = [[interval, @bucket_interval].max, @bucket_span].min + count = (interval / @bucket_interval).floor + (0..count - 1).map do |i| + (bucket - i) % @bucket_count + end + end + + def get_key_for_subject(subject) + "#{@key}:#{subject}" + end + + def count_incremented_within_threshold(subject, options) + bucket = get_bucket + keys = get_bucket_keys_for_interval(bucket, options[:interval]) + evalScript = 'local a=KEYS[1]local b=tonumber(ARGV[1])local c=tonumber(ARGV[b+2])local d=tonumber(ARGV[b+3])local e=tonumber(ARGV[b+4])local f=tonumber(ARGV[b+5])local g=tonumber(ARGV[b+6])local h=0;local i=false;for j,k in ipairs(redis.call("HMGET",a,unpack(ARGV,2,b+1)))do h=h+(tonumber(k)or 0)end;if h