diff --git a/lib/hash.nelua b/lib/hash.nelua new file mode 100644 index 00000000..bd77e06c --- /dev/null +++ b/lib/hash.nelua @@ -0,0 +1,68 @@ +-- Hash module +-- +-- This module contains the hash() and hash_combine() functions, +-- they are used by some containers such as hash map to generate hash for values. + +require 'span' + +-- This is the same simple hash function used in Lua. +global function lhash(data: *[0]byte, len: usize, seed: usize, step: usize): usize + seed = seed ~ len + while len >= step do + seed = seed ~ ((seed << 5) + (seed >> 2) + data[len - 1]) + len = len - step + end + return seed +end + +-- This is the hash function taken from Lua for short strings. +global function lhash_short(data: span(byte)): usize + return lhash(data.data, data.size, 0x9e3779b9_usize, 1) +end + +-- This is the hash function taken from Lua for long strings. +global function lhash_long(data: span(byte)): usize + -- limit up to 32 iterations evenly spaced + return lhash(data.data, data.size, 0x9e3779b9_usize, (data.size >> 5) + 1) +end + +-- Hash combine function, algorithm taken from (C++ boost). +global function hash_combine(seed: usize, value: usize): usize + return seed ~ (value + 0x9e3779b9_usize + (seed<<6) + (seed>>2)) +end + +-- Hash function that can be used generally to hash any value. +-- To customize a hash for a specific type you can define __hash metamethod on it. +global function hash(x: auto): usize + ## if x.type.is_pointer then + if x == nilptr then + return 0 + end + ## end + + ## local type = x.type:implict_deref_type() + ## if type.is_integral or type.is_pointer then + return (@usize)(x) + ## elseif type.is_float then + local u: union{n: #[type]#, h: usize} + u.n = x + return u.h + ## elseif type.is_boolean then + if x then return 1 else return 0 end + ## elseif type.is_stringview or type.is_string then + return lhash_long({data=x.data, size=x.size}) + ## elseif type.is_span then + local T: type = #[type.subtype]# + return lhash_long({data=(@*[0]byte)(x.data), size=(@usize)(#T * #x)}) + ## elseif type.is_record and type.metafields.__hash then + return x:__hash() + ## elseif type.is_record then + local h: usize = 0 + ## for _,field in ipairs(type.fields) do -- hash all fields + h = hash_combine(h, hash(x.#|field.name|#)) + ## end + return h + ## elseif type.is_nilptr or type.is_niltype then + return 0 + ## else static_error("hash: cannot hash type '%s'", type) end +end diff --git a/lib/hashmap.nelua b/lib/hashmap.nelua new file mode 100644 index 00000000..e9d67a1a --- /dev/null +++ b/lib/hashmap.nelua @@ -0,0 +1,354 @@ +-- Hash Map +-- +-- Hash map is an associative container that contains key-value pairs with unique keys. +-- Search, insertion, and removal of elements have average constant-time complexity. +-- +-- The hash map share similarities with Lua tables but should not be used like them, +-- main differences: +-- * There is no array part +-- * The length operator returns number of elements in the map +-- * Indexing automatically inserts a key-value, to avoid this use `get` or `has` methods +-- * Values cannot be nil or set to nil +-- * Can only use pairs() to iterate + +require 'memory' +require 'hash' + +local function ceil_idiv(x: usize, y: usize): usize + return (x + y - 1) // y +end + +-- Maximum load factor (number of elements per bucket) in percent. +-- The container automatically increases the number of buckets if the load factor exceeds this threshold. +local MAX_LOAD_FACTOR: usize = 75 +-- Grow rate in percent. +-- When the maximum load factor is reached the container capacity grows by this factor. +local GROW_RATE: usize = 200 +-- Initial capacity to reserve when inserting an element for the first time in a container. +local INIT_CAPACITY: usize = 16 +-- Constant used to test invalid index. +local INVALID_INDEX: usize = #[primtypes.usize.max]# + +## local make_generic_hashmap = generalize(function(K, V, HashFunc, Allocator) + ## static_assert(traits.is_type(K), "invalid type '%s'", K) + ## static_assert(traits.is_type(V), "invalid type '%s'", V) + ## if not Allocator then + require 'allocators.default' + ## Allocator = DefaultAllocator + ## end + + local Allocator: type = #[Allocator]# + local K: type = @#[K]# + local V: type = @#[V]# + + local hashmap_nodeT: type = @record { + key: K, + value: V, + next: usize + } + + local hashmapT: type = @record{ + buckets: span(usize), + nodes: span(hashmap_nodeT), + size: usize, + allocator: Allocator + } + + ##[[ + local hashmapT = hashmapT.value + hashmapT.is_hashmap = true + hashmapT.is_container = true + ]] + + ## if HashFunc then + local hash_func: auto = #[HashFunc]# + ## else + local hash_func: auto = hash + ## end + + -- Create a hash map using a custom allocator instance. + -- This is only to be used when not using the default allocator. + function hashmapT.make(allocator: Allocator): hashmapT + local m: hashmapT + m.allocator = allocator + return m + end + + -- Resets the container to a zeroed state, freeing all used resources. + -- Complexity: O(1). + function hashmapT:destroy() + self.allocator:spandealloc(self.buckets) + self.allocator:spandealloc(self.nodes) + $self = (@hashmapT){} + end + + -- Remove all elements from the container. + -- Complexity: O(n). + function hashmapT:clear() + memory.spanset(self.buckets, INVALID_INDEX) + memory.spanzero(self.nodes) + self.size = 0 + end + + -- Used internally to find a value at a key returning it's node index. + function hashmapT:_find(key: K): (usize, usize, usize) + if unlikely(self.buckets.size == 0) then -- container is empty + return INVALID_INDEX, INVALID_INDEX, INVALID_INDEX + end + local hash_index: usize = (@usize)(hash_func(key) % self.buckets.size) + local node_index: usize = self.buckets[hash_index] + local prev_node_index: usize = INVALID_INDEX + -- iterate until the key is found + while node_index ~= INVALID_INDEX do + local node: *hashmap_nodeT = &self.nodes[node_index] + if node.key == key then + return node_index, prev_node_index, hash_index + end + prev_node_index = node_index + node_index = node.next + end + return node_index, prev_node_index, hash_index + end + + -- Sets the number of buckets to at least `count` and rehashes the container when needed. + -- Complexity: Average case O(n). + function hashmapT:rehash(count: usize) + -- count should be at least (size * 100) / MAX_LOAD_FACTOR + local min_count: usize = ceil_idiv(self.size * 100, MAX_LOAD_FACTOR) + if count < min_count then + count = min_count + end + + -- reserve number of elements + local nodes_count: usize = (count * MAX_LOAD_FACTOR) // 100 + if nodes_count > self.nodes.size then + self.nodes = self.allocator:spanxrealloc0(self.nodes, nodes_count) + end + + -- only rehash when needed + if count <= self.buckets.size then + return + end + + -- reallocate new buckets + self.buckets = self.allocator:spanxrealloc(self.buckets, count) + + -- reset buckets + memory.spanset(self.buckets, INVALID_INDEX) + for i:usize=0,= (MAX_LOAD_FACTOR * self.buckets.size) // 100) then + self:reserve((self.size * GROW_RATE) // 100) + end + + return node_index + end + end + + -- Checks if there is an element with a key in the container. + -- Complexity: Average case O(1). + function hashmapT:has(key: K): boolean + return self:_find(key) ~= INVALID_INDEX + end + + -- Returns the value that is mapped to a key, + -- performing an insertion if such key does not exist. + -- Complexity: Average case O(1). + function hashmapT:get(key: K): V + return self.nodes[self:_find_or_make(key)].value + end + + -- Returns a reference to the value that is mapped to a key, + -- performing an insertion if such key does not exist. + -- Complexity: Average case O(1). + function hashmapT:get_ptr(key: K): *V + return &self.nodes[self:_find_or_make(key)].value + end + + -- Returns the value that is mapped to a key. + -- If no such element exists, a runtime error is thrown. + -- Complexity: Average case O(1). + function hashmapT:at(key: K): V + local node_index: usize = self:_find(key) + check(node_index ~= INVALID_INDEX, 'hashmap.at: position out of bounds') + return self.nodes[node_index].value + end + + -- Returns a reference to the value that is mapped to a key. + -- If no such element exists, a runtime error is thrown. + -- Complexity: Average case O(1). + function hashmapT:at_ptr(key: K): *V + local node_index: usize = self:_find(key) + check(node_index ~= INVALID_INDEX, 'hashmap.at: position out of bounds') + return &self.nodes[node_index].value + end + + -- Inserts an element or assigns to the current element if the key already exists. + -- Complexity: Average case O(1). + function hashmapT:set(key: K, value: V) + self.nodes[self:_find_or_make(key)].value = value + end + + -- Removes an element with a key from the container (if it exists). + -- Returns true if an element was removed. + -- Complexity: Average case O(1). + function hashmapT:remove(key: K): boolean + local node_index: usize, prev_node_index: usize, hash_index: usize = self:_find(key) + if node_index == INVALID_INDEX then return false end + + -- unlink the removed node + local node: *hashmap_nodeT = &self.nodes[node_index] + if prev_node_index == INVALID_INDEX then + self.buckets[hash_index] = node.next + else + self.nodes[prev_node_index].next = node.next + end + + -- move the last node into the removed node place + local last_node_index: usize = self.size - 1 + local last_node: *hashmap_nodeT = &self.nodes[last_node_index] + if node_index ~= last_node_index then + $node = $last_node + local unused_index: usize, last_prev_node_index: usize, last_hash_index: usize = self:_find(node.key) + if last_prev_node_index == INVALID_INDEX then + self.buckets[last_hash_index] = node_index + else + self.nodes[last_prev_node_index].next = node_index + end + end + + -- clean the last node + $last_node = (@hashmap_nodeT)() + self.size = self.size - 1 + return true + end + + -- Returns the average number of elements per bucket. + function hashmapT:load_factor(): number + if unlikely(self.buckets.size == 0) then + return 0 + else + return self.size / self.buckets.size + end + end + + -- Returns the number of buckets in the container. + function hashmapT:bucket_count(): usize + return self.buckets.size + end + + -- Checks whether the container is empty. + function hashmapT:empty(): boolean + return self.size == 0 + end + + -- Returns the number of elements the container can store before triggering a rehash. + function hashmapT:capacity(): usize + return self.nodes.size + end + + -- Returns the number of elements in the container. + function hashmapT:__len(): isize + return (@isize)(self.size) + end + + -- Same as `get_ptr`, this allows indexing the hash map type. + -- Complexity: Average case O(1). + function hashmapT:__atindex(key: K): *V + return self:get_ptr(key) + end + + -- Hashmap iterator + local hashmap_iteratorT = @record { + container: *hashmapT, + index: usize + } + + -- Advance the container iterator returning its key and value. + -- NOTE: The input key is actually ignored. + function hashmap_iteratorT:next(k: K): (boolean, K, V) + if unlikely(self.index == INVALID_INDEX) then + self.index = 0 + else + self.index = self.index + 1 + end + if unlikely(self.index >= self.container.size) then + return false, (@K)(), (@V)() + end + local node: *hashmap_nodeT = &self.container.nodes[self.index] + return true, node.key, node.value + end + + -- Advance the container iterator returning its key and value by reference. + -- NOTE: The input key is actually ignored. + function hashmap_iteratorT:mnext(k: K): (boolean, K, *V) + if unlikely(self.index == INVALID_INDEX) then + self.index = 0 + else + self.index = self.index + 1 + end + if unlikely(self.index >= self.container.size) then + return false, (@K)(), nilptr + end + local node: *hashmap_nodeT = &self.container.nodes[self.index] + return true, node.key, &node.value + end + + -- Allow using pairs() to iterate the container. + function hashmapT:__pairs() + return hashmap_iteratorT.next, (@hashmap_iteratorT){container=self,index=INVALID_INDEX}, (@K)() + end + + -- Allow using mpairs() to iterate the container. + function hashmapT:__mpairs() + return hashmap_iteratorT.mnext, (@hashmap_iteratorT){container=self,index=INVALID_INDEX}, (@K)() + end + + ## return hashmapT +## end) + +global hashmap: type = #[make_generic_hashmap]# diff --git a/nelua/analyzer.lua b/nelua/analyzer.lua index d9e333ce..499d84ad 100644 --- a/nelua/analyzer.lua +++ b/nelua/analyzer.lua @@ -850,12 +850,17 @@ function visitors.GenericType(context, node) local argnode = argnodes[i] context:traverse_node(argnode) local argattr = argnode.attr - if not (argattr.comptime or argattr.type.is_comptime) then + if not (argattr.comptime or argattr.type.is_comptime or + (argattr.type.is_function and argattr._symbol and argattr.staticstorage)) then node:raisef("in generic evaluation '%s': argument #%d isn't a compile time value", name, i) end local value = argattr.value if bn.isnumeric(value) then value = bn.tonumber(value) + elseif argattr.type.is_function then + value = argattr + elseif argattr.type.is_niltype then + value = nil elseif not (traits.is_type(value) or traits.is_string(value) or traits.is_boolean(value) or diff --git a/nelua/types.lua b/nelua/types.lua index 527589df..1eb4c09a 100644 --- a/nelua/types.lua +++ b/nelua/types.lua @@ -174,8 +174,9 @@ Type.shape = shaper.shape { is_string = shaper.optional_boolean, is_span = shaper.optional_boolean, is_vector = shaper.optional_boolean, - is_list = shaper.optional_boolean, is_sequence = shaper.optional_boolean, + is_list = shaper.optional_boolean, + is_hashmap = shaper.optional_boolean, is_filestream = shaper.optional_boolean, is_time_t = shaper.optional_boolean, diff --git a/tests/all_test.nelua b/tests/all_test.nelua index f1ced4b3..099788e7 100644 --- a/tests/all_test.nelua +++ b/tests/all_test.nelua @@ -54,6 +54,14 @@ print('testing list_test... ') require 'tests.list_test' print 'OK!' +print('testing hash_test... ') +require 'tests.hash_test' +print 'OK!' + +print('testing hashmap_test... ') +require 'tests.hashmap_test' +print 'OK!' + print('testing defer_test... ') require 'tests.defer_test' print 'OK!' diff --git a/tests/hash_test.nelua b/tests/hash_test.nelua new file mode 100644 index 00000000..ed42371c --- /dev/null +++ b/tests/hash_test.nelua @@ -0,0 +1,43 @@ +require 'hash' + +do -- hash primitive types + assert(hash(0) == 0) + assert(hash(1) == 1) + assert(hash(0.0) == 0) + assert(hash(1.0) ~= 1) + assert(hash(true) == 1) + assert(hash(false) == 0) + assert(hash(nilptr) == 0) + assert(hash(nil) == 0) +end + +do -- hash strings + assert(hash('') == 0x9e3779b9_usize) + assert(hash('test') ~= 0) +end + +do -- hash records + local vec2 = @record{x: integer, y: integer} + local a: vec2 = {0,0} + local b: vec2 = {1,0} + assert(hash(a) ~= 0 and hash(b) ~= 0) + assert(hash(a) ~= hash(b)) + assert(hash(a) == hash(&a)) + assert(hash(b) == hash(&b)) +end + +do -- hash records with custom hash function + local vec3 = @record{x: integer, y: integer, z: integer} + function vec3:__hash(): usize + return (@usize)(self.x + self.y * 0xff + self.z * 0xffff) + end + + local v: vec3 + assert(hash(v) == 0) + assert(hash(v) == hash(&v)) + v = {1,1,1} + assert(hash(v) == 1 + 0xff + 0xffff) + + local pv: *vec3 + assert(hash(pv) == 0) +end diff --git a/tests/hashmap_test.nelua b/tests/hashmap_test.nelua new file mode 100644 index 00000000..4303fbab --- /dev/null +++ b/tests/hashmap_test.nelua @@ -0,0 +1,161 @@ +require 'hashmap' + +do -- inserting + local map: hashmap(integer, integer) + for i=1,50 do + local k, v = i*3167, i*10 + map:set(k, v) + end + for i=51,100 do + local k, v = i*3167, i*10 + map[k] = v + end + assert(#map == 100) + assert(map:capacity() == 128) + assert(map:bucket_count() > 128) + assert(map:load_factor() <= 0.75) + for i=1,100 do + local k, v = i*3167, i*10 + assert(map[k] == v) + assert(map:at(k) == v) + assert($map:at_ptr(k) == v) + assert(map:get(k) == v) + assert($map:get_ptr(k) == v) + assert(map:has(k)) + end + map:clear() + assert(#map == 0) + assert(map:bucket_count() > 128) + assert(map:capacity() == 128) + assert(map:load_factor() == 0) + map:set(1, 10) + map:destroy() + assert(#map == 0) + assert(map:capacity() == 0) + assert(map:bucket_count() == 0) + assert(map:load_factor() == 0) +end + +do -- reserve + local map: hashmap(integer, integer) + map:reserve(64) + assert(#map == 0) + assert(map:bucket_count() == 86) + assert(map:capacity() == 64) + assert(map:load_factor() == 0) + for i=1,100 do + local k, v = i*3167, i*10 + map:set(k, v) + end + map:reserve(256) + assert(map:bucket_count() == 342) + assert(map:capacity() == 256) + assert(map:load_factor() < 0.75) + for i=1,100 do + local k, v = i*3167, i*10 + assert(map[k] == v) + end + map:destroy() +end + +do -- rehash + local map: hashmap(integer, integer) + for i=1,100 do + local k, v = i*3167, i*10 + map:set(k, v) + end + local old_factor = map:load_factor() + map:rehash(512) + assert(#map == 100) + assert(map:load_factor() < old_factor) + assert(map:capacity() == 384) + assert(map:bucket_count() == 512) + for i=1,100 do + local k, v = i*3167, i*10 + assert(map[k] == v) + end + map:destroy() +end + +do -- remove + local map: hashmap(integer, integer) + for i=1,100 do + local k, v = i*3167, i*10 + map:set(k, v) + end + for i=1,50 do + local k = i*3167 + map:remove(k) + end + assert(#map == 50) + for i=51,100 do + local k, v = i*3167, i*10 + assert(map:get(k) == v) + map:remove(k) + end + assert(#map == 0) + map:destroy() +end + +require 'iterators' +do -- pairs and mpairs + local map: hashmap(integer, integer) + for i=1,100 do + local k, v = i*3167, i*10 + map[k] = v + end + local i = 1 + for k,v in pairs(map) do + local ek, ev = i*3167, i*10 + assert(k == ek and v == ev) + i = i + 1 + end + i = 1 + for k,v in mpairs(map) do + local ek, ev = i*3167, i*10 + assert(k == ek and $v == ev) + $v = i*100 + i = i + 1 + end + for i=1,100 do + local k, v = i*3167, i*100 + assert(map[k] == v) + end + map:destroy() +end + +do -- string map + local map: hashmap(stringview, stringview) + map['hello'] = 'hello' + map['world'] = 'world' + assert(map['hello'] == 'hello') + assert(map['world'] == 'world') + assert(#map == 2) + map:destroy() +end + +do -- custom hash function + local function hash_integer(x: integer) + return 0 + end + local map: hashmap(integer, integer, (hash_integer)) + map[1] = 1 + map[2] = 2 + assert(map[1] == 1) + assert(map[2] == 2) + map:destroy() +end + +require 'allocators.general' +do -- custom allocator + local map = (@hashmap(integer, integer, nil, GeneralAllocator)).make(general_allocator) + for i=1,100 do + local k, v = i*3167, i*10 + map:set(k, v) + end + for i=1,100 do + local k, v = i*3167, i*10 + assert(map[k] == v) + end + map:destroy() +end