diff --git a/chash.c b/chash.c index 5c3148b..b8c1fec 100644 --- a/chash.c +++ b/chash.c @@ -10,6 +10,8 @@ #define u_char unsigned char #endif +#define CHASH_OK 0 +#define CHASH_ERR -1 #define crc32_final(crc) \ crc ^= 0xffffffff @@ -144,7 +146,7 @@ chash_point_init(chash_point_t *arr, uint32_t base_hash, uint32_t start, } -void +int chash_point_sort(chash_point_t arr[], uint32_t n) { chash_point_t *points; @@ -163,6 +165,9 @@ chash_point_sort(chash_point_t arr[], uint32_t n) step = pow(2, 32) / m; points = (chash_point_t *) calloc(m, sizeof(chash_point_t)); + if (points == NULL) { + return CHASH_ERR; + } for (i = 0; i < n; i++) { node = &arr[i]; @@ -246,10 +251,12 @@ chash_point_sort(chash_point_t arr[], uint32_t n) } free(points); + + return CHASH_OK; } -void +int chash_point_add(chash_point_t *old_points, uint32_t old_length, uint32_t base_hash, uint32_t from, uint32_t num, uint32_t id, chash_point_t *new_points) @@ -258,9 +265,16 @@ chash_point_add(chash_point_t *old_points, uint32_t old_length, chash_point_t *tmp_points; tmp_points = (chash_point_t *) calloc(num, sizeof(chash_point_t)); + if (tmp_points == NULL) { + return CHASH_ERR; + } chash_point_init_crc(tmp_points, 0, base_hash, from, num, id); - chash_point_sort(tmp_points, num); + + if (chash_point_sort(tmp_points, num) != CHASH_OK) { + free(tmp_points); + return CHASH_ERR; + } j = num - 1; k = old_length + num - 1; @@ -283,10 +297,12 @@ chash_point_add(chash_point_t *old_points, uint32_t old_length, } free(tmp_points); + + return CHASH_OK; } -void +int chash_point_reduce(chash_point_t *old_points, uint32_t old_length, uint32_t base_hash, uint32_t from, uint32_t num, uint32_t id) { @@ -296,7 +312,11 @@ chash_point_reduce(chash_point_t *old_points, uint32_t old_length, tmp_points = (chash_point_t *) calloc(num, sizeof(chash_point_t)); chash_point_init_crc(tmp_points, 0, base_hash, from, num, id); - chash_point_sort(tmp_points, num); + + if (chash_point_sort(tmp_points, num) != CHASH_OK) { + free(tmp_points); + return CHASH_ERR; + } for (i = 0, j = 0, k = 0; i < old_length; i++) { if (j < num @@ -315,6 +335,8 @@ chash_point_reduce(chash_point_t *old_points, uint32_t old_length, } free(tmp_points); + + return CHASH_OK; } diff --git a/chash.h b/chash.h index b7f864c..8fb97b0 100644 --- a/chash.h +++ b/chash.h @@ -39,12 +39,12 @@ typedef struct { */ void chash_point_init(chash_point_t *points, uint32_t base_hash, uint32_t start, uint32_t num, uint32_t id) LCH_EXPORT; -void chash_point_sort(chash_point_t *points, uint32_t npoints) LCH_EXPORT; +int chash_point_sort(chash_point_t *points, uint32_t npoints) LCH_EXPORT; -void chash_point_add(chash_point_t *old_points, uint32_t old_length, +int chash_point_add(chash_point_t *old_points, uint32_t old_length, uint32_t base_hash, uint32_t from, uint32_t num, uint32_t id, chash_point_t *new_points) LCH_EXPORT; -void chash_point_reduce(chash_point_t *old_points, uint32_t old_length, +int chash_point_reduce(chash_point_t *old_points, uint32_t old_length, uint32_t base_hash, uint32_t from, uint32_t num, uint32_t id) LCH_EXPORT; void chash_point_delete(chash_point_t *old_points, uint32_t old_length, uint32_t id) LCH_EXPORT; diff --git a/lib/resty/chash.lua b/lib/resty/chash.lua index 90e73bc..930ebad 100644 --- a/lib/resty/chash.lua +++ b/lib/resty/chash.lua @@ -19,6 +19,10 @@ local pairs = pairs local tostring = tostring local tonumber = tonumber local bxor = bit.bxor +local error = error + + +local CHASH_OK = 0 ffi.cdef[[ @@ -31,12 +35,12 @@ typedef struct { void chash_point_init(chash_point_t *points, uint32_t base_hash, uint32_t start, uint32_t num, uint32_t id); -void chash_point_sort(chash_point_t *points, uint32_t size); +int chash_point_sort(chash_point_t *points, uint32_t size); -void chash_point_add(chash_point_t *old_points, uint32_t old_length, +int chash_point_add(chash_point_t *old_points, uint32_t old_length, uint32_t base_hash, uint32_t from, uint32_t num, uint32_t id, chash_point_t *new_points); -void chash_point_reduce(chash_point_t *old_points, uint32_t old_length, +int chash_point_reduce(chash_point_t *old_points, uint32_t old_length, uint32_t base_hash, uint32_t from, uint32_t num, uint32_t id); void chash_point_delete(chash_point_t *old_points, uint32_t old_length, uint32_t id); @@ -117,7 +121,9 @@ local function _precompute(nodes) start = start + num end - clib.chash_point_sort(points, npoints) + if clib.chash_point_sort(points, npoints) ~= CHASH_OK then + error("no memory") + end return ids, points, npoints, newnodes end @@ -194,14 +200,21 @@ local function _incr(self, id, weight) local new_npoints = self.npoints + weight * CONSISTENT_POINTS if self.size < new_npoints then new_points = ffi_new(chash_point_t, new_npoints) - self.size = new_npoints end local base_hash = bxor(crc32(tostring(id)), 0xffffffff) - clib.chash_point_add(self.points, self.npoints, base_hash, - old_weight * CONSISTENT_POINTS, - weight * CONSISTENT_POINTS, - index, new_points) + local rc = clib.chash_point_add(self.points, self.npoints, base_hash, + old_weight * CONSISTENT_POINTS, + weight * CONSISTENT_POINTS, + index, new_points) + + if rc ~= CHASH_OK then + error("no memory") + end + + if self.size < new_npoints then + self.size = new_npoints + end self.points = new_points self.npoints = new_npoints @@ -230,10 +243,15 @@ local function _decr(self, id, weight) end local base_hash = bxor(crc32(tostring(id)), 0xffffffff) - clib.chash_point_reduce(self.points, self.npoints, base_hash, - (old_weight - weight) * CONSISTENT_POINTS, - CONSISTENT_POINTS * weight, - index) + local from = (old_weight - weight) * CONSISTENT_POINTS + local num = CONSISTENT_POINTS * weight + + local rc = clib.chash_point_reduce(self.points, self.npoints, base_hash, + from, num, index) + + if rc ~= CHASH_OK then + error("no memory") + end nodes[id] = old_weight - weight self.npoints = self.npoints - CONSISTENT_POINTS * weight diff --git a/lib/resty/roundrobin.lua b/lib/resty/roundrobin.lua index 97743b5..be727d0 100644 --- a/lib/resty/roundrobin.lua +++ b/lib/resty/roundrobin.lua @@ -4,6 +4,7 @@ local next = next local tonumber = tonumber local setmetatable = setmetatable local math_random = math.random +local error = error local utils = require "resty.balancer.utils"