# In [None]:
# Based off https://stevehanov.ca/blog/?id=119
# but modified to allow same-value collisions

import math
import time
import random
import string

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from IPython.display import clear_output

# Install the numba package if you want to speed up the
# calculations (though the script will work without it too)
try:
    from numba import njit
except ImportError:
    def njit(func):
        """No-op stand-in for numba.njit: hand the function back unchanged."""
        return func

@njit
def fnv_hash(data, offset_basis):
    """Return a 32-bit FNV-style hash of ``data`` seeded by ``offset_basis``.

    Each element of ``data`` is folded in by multiplying the running state
    with the 32-bit FNV prime and then XOR-ing the element (multiply first,
    XOR second). A falsy ``offset_basis`` (e.g. 0) is replaced by
    0x01000193 before hashing. The state is masked to 32 bits at each step.
    """
    # https://en.wikipedia.org/wiki/Fowler–Noll–Vo_hash_function
    prime = 0x01000193
    mask = 0xffffffff
    state = (offset_basis if offset_basis else prime) & mask
    for octet in data:
        state = ((state * prime) ^ octet) & mask
    return state

def key_to_bytes(key):
    """Normalize a key to ``bytes``.

    ``str`` keys are UTF-8 encoded, ``bytes`` keys are returned as-is.

    Raises:
        ValueError: if ``key`` is neither ``bytes`` nor ``str``.
    """
    if isinstance(key, str):
        return key.encode("utf-8")
    if isinstance(key, bytes):
        return key
    raise ValueError(f"Not a bytes or str key: {key}")

class PerfectHashMap:
    """Read-only lookup table backed by a two-level seeded hash.

    A key is first hashed with seed 0 to select a bucket in ``G``; the seed
    stored in that bucket is then used to hash the key again and select a
    cell in ``V``; the cell holds an index into ``palette``.

    No keys are stored, so a lookup does not check whether the key was part
    of the data the table was built from.
    """

    def __init__(self, hash, G, V, palette):
        """Store the table components.

        Args:
            hash: callable ``hash(key_bytes, seed)`` returning an integer.
            G: sequence of per-bucket seeds.
            V: sequence of palette indices, one per value cell.
            palette: sequence of the distinct values.
        """
        self.hash = hash
        self.G = G # buckets with seeds
        self.V = V # cells with values
        self.palette = palette

    def __getitem__(self, key):
        """Return the palette value that ``key`` (``str`` or ``bytes``) maps to."""
        key = key_to_bytes(key)
        g = self.G[self.hash(key, 0) % len(self.G)]
        p = self.V[self.hash(key, g) % len(self.V)]
        return self.palette[p]

    def estimate_min_size(self):
        """Estimate the number of bits needed to store ``G`` and ``V``.

        Bit widths are computed with exact integer arithmetic: the number of
        bits needed to distinguish ``n`` states is ``(n - 1).bit_length()``,
        which equals ceil(log2(n)) without floating-point rounding error.

        Returns:
            int: estimated size in bits (the cost of storing the counts
            themselves is assumed to be constant and is not included).
        """
        def min_bits(n):
            # max(..., 0) makes n == 0 cost zero bits instead of failing.
            return max(int(n) - 1, 0).bit_length()

        # Plain Python ints so bit_length() is available even for numpy seeds
        unique_g = {int(g) for g in self.G} # unique seeds
        g_count = len(self.G) # number of buckets
        v_count = len(self.V) # number of value cells
        p_count = len(self.palette) # value palette size
        u_count = len(unique_g) # number of unique seeds
        u_max = max(unique_g) # max value of seeds
        # Buckets, case A: store seeds in palette, buckets store palette indices
        g_bits_a = g_count * min_bits(u_count) + u_count * min_bits(u_max + 1)
        # Buckets, case B: store seeds in buckets directly
        g_bits_b = g_count * min_bits(u_max + 1)
        # Size of buckets is assumed to be the best of the cases A & B
        g_bits = min(g_bits_a, g_bits_b)
        # Value cells: store palette indices in the value cells
        v_bits = v_count * min_bits(p_count)
        # Finally, add buckets size and value cells size to obtain the estimate
        # (the cost of storing the counts themselves is assumed to be constant)
        return g_bits + v_bits

# Tries to create a value-pooled perfect hash map using the given python dictionary.
# If some keys map to identical values, they can potentially be "pooled" together.
def value_pooled_perfect_hash(mapping, hash, size_G=None, size_V=None, attempts=1000):
    """Try to build a value-pooled perfect hash map for ``mapping``.

    Keys that map to the same value are allowed to share a value cell, so
    the value table may be smaller than the number of keys.

    Args:
        mapping: dict with ``str``/``bytes`` keys and hashable values.
        hash: callable ``hash(key_bytes, seed)`` returning an integer.
        size_G: number of seed buckets (at least 1 is always used).
        size_V: number of value cells; defaults to ``size_G`` and is never
            smaller than the number of distinct values.
        attempts: a bucket may fail this many seeds; one more failure
            aborts the whole construction.

    Returns:
        A ``PerfectHashMap``, or ``None`` if some bucket exhausted its
        seed attempts.
    """
    # Build a palette of unique values. dict.fromkeys keeps first-seen order,
    # so the palette (and the indices stored in V) do not depend on set
    # iteration order, which varies between runs for str values.
    palette = list(dict.fromkeys(mapping.values()))
    palette_map = {v: i for i, v in enumerate(palette)}

    size_G = max(size_G or 0, 1)
    size_V = max(size_V or size_G, len(palette))

    G = np.zeros(size_G, dtype=int)
    V = np.zeros(size_V, dtype=int)  # committed cells; 0 means "empty"
    T = np.zeros(size_V, dtype=int)  # scratch copy used while trying a seed

    # Step 1: Place all of the keys into buckets, grouped by value inside
    # each bucket (values are stored as palette index + 1 so 0 stays free).
    buckets = [{} for _ in range(size_G)]

    for key, value in mapping.items():
        key = key_to_bytes(key)
        value = palette_map[value] + 1
        buckets[hash(key, 0) % size_G].setdefault(value, []).append(key)

    buckets = [list(bucket.items()) for bucket in buckets]

    # Step 2: Sort the buckets and process the ones with the most
    # (value, keys) groups first.
    buckets.sort(key=len, reverse=True)

    for bucket in buckets:
        # Sorted descending, so the first empty bucket ends the work.
        if not bucket:
            break

        rehashes = 0
        g = 1
        item = 0
        T[:] = V[:]

        # Repeatedly try different seeds until we find a hash function
        # that places all items in the bucket into free/matching slots
        while item < len(bucket):
            value, keys = bucket[item]

            success = True
            for key in keys:
                slot = hash(key, g) % size_V
                if not T[slot]:
                    T[slot] = value
                elif T[slot] != value:
                    success = False
                    break

            if success:
                item += 1
            else:
                rehashes += 1
                if rehashes > attempts:
                    return None

                # Start over with the next seed and a clean scratch table.
                g += 1
                item = 0
                T[:] = V[:]

        # Buckets were reordered by the sort, so recover this bucket's index
        # from any one of its keys.
        value, keys = bucket[0]
        G[hash(keys[0], 0) % size_G] = g

        V[:] = T[:]

    # Shift back to plain palette indices; never-filled cells become -1.
    V -= 1

    return PerfectHashMap(hash, G, V, palette)

def generate_random_string(length):
    """Return ``length`` characters drawn at random (via the ``random``
    module) from ASCII letters, digits and punctuation."""
    alphabet = string.ascii_letters + string.digits + string.punctuation
    picked = [random.choice(alphabet) for _ in range(length)]
    return ''.join(picked)

def generate_mapping(key_count, value_count, binarize_values=False):
    """Generate a random test mapping from string keys to values.

    Without binarization, ``key_count`` random 10-character keys are
    assigned the values ``1..value_count`` in round-robin order after a
    shuffle. With binarization (only honoured when ``value_count > 2``),
    every key is expanded into ``ceil(log2(value_count))`` keys, each
    prefixed with a distinct character, and the values are 0/1.

    Keys are kept in generation order before shuffling, so the result is
    fully determined by the state of the ``random`` module.

    Raises:
        ValueError: if keys are requested but there are no values to assign.
    """
    if value_count <= 2:
        binarize_values = False

    if binarize_values:
        values = [0, 1]
    else:
        values = [i + 1 for i in range(value_count)]

    if key_count > 0 and not values:
        raise ValueError(f"Cannot assign values with value_count={value_count}")

    # A dict serves as an insertion-ordered set: iterating a real set of
    # strings would feed an order into the shuffle that is not determined
    # by the random seed.
    unique_keys = {}
    while len(unique_keys) < key_count:
        unique_keys[generate_random_string(10)] = None
    keys = list(unique_keys)

    if binarize_values:
        # log2 is used rather than log(x, 2) to avoid float rounding error
        value_bits = math.ceil(math.log2(value_count))
        keys = [chr(b) + k for k in keys for b in range(value_bits)]

    random.shuffle(keys)

    return {k: values[i % len(values)] for i, k in enumerate(keys)}

def scan_parameters(mapping, attempts):
    """Try every (len(G), len(V)) combination and record which ones work.

    ``len(G)`` ranges over ``1..key_count`` and ``len(V)`` over
    ``value_count..key_count``.

    Args:
        mapping: dict to build the hash maps for.
        attempts: per-bucket seed attempts passed to the builder.

    Returns:
        ``(img, smallest)`` where ``img`` is a ``(key_count, key_count, 3)``
        float array indexed as ``img[len(G) - 1, len(V) - 1]`` (blue where a
        map was built, green at the smallest one), and ``smallest`` is
        ``(min_size_bits, size_V, size_G)`` or ``None`` if nothing succeeded.
    """
    key_count = len(mapping)
    value_count = len(set(mapping.values()))

    size_G_min = 1
    size_G_max = key_count
    size_V_min = value_count
    size_V_max = key_count

    img = np.zeros((size_G_max, size_V_max, 3))
    color_exists = (0.0, 0.0, 0.8)
    color_best = (0.0, key_count*0.25, 0.0)

    smallest = None
    for size_G in range(size_G_min, size_G_max + 1):
        for size_V in range(size_V_min, size_V_max + 1):
            result = value_pooled_perfect_hash(mapping, fnv_hash, size_G, size_V, attempts=attempts)
            if result is None:
                continue

            min_size = result.estimate_min_size()
            # Strict comparison: the first configuration found wins ties.
            if (smallest is None) or (min_size < smallest[0]):
                smallest = (min_size, size_V, size_G)
            img[size_G-1, size_V-1] = color_exists

    # Nothing to highlight if no configuration succeeded (e.g. empty mapping)
    if smallest is not None:
        _, best_V, best_G = smallest
        img[best_G-1, best_V-1] = color_best

    return img, smallest

# Note: the parameters found this way are likely suboptimal,
# and an exhaustive search along the boundary region would
# yield a (typically ~2x) smaller size
def find_parameters(mapping, attempts):
    """Binary-search a single table size for which a hash map can be built.

    The same ``size`` is used for the bucket count, and the value table
    takes the builder's default for that bucket count. The search assumes
    that success is roughly monotonic in ``size``; it is a heuristic, not
    an exhaustive scan.

    Args:
        mapping: dict to build the hash map for.
        attempts: per-bucket seed attempts passed to the builder.

    Returns:
        The successful ``PerfectHashMap`` with the fewest buckets, or
        ``None`` if no tested size succeeded (or the mapping is empty).
    """
    key_count = len(mapping)
    value_count = len(set(mapping.values()))

    size_min = max(value_count, 1)
    size_max = key_count

    best_result = None

    # "<=" so that the final remaining candidate (size_min == size_max) is
    # tested too; otherwise the largest size could never be tried.
    while size_min <= size_max:
        size = (size_min + size_max) // 2
        result = value_pooled_perfect_hash(mapping, fnv_hash, size, attempts=attempts)

        if result is None:
            size_min = size + 1
        else:
            size_max = size - 1
            if (best_result is None) or (len(result.G) < len(best_result.G)):
                best_result = result

    return best_result

print()

random.seed(0)

experiments_count = 10
key_count = 32
value_count = 2
binarize_values = True
attempts = 1000

calculate_full_grid = True

if calculate_full_grid:
    min_sizes = []

    # Allocated from the first generated mapping, so the accumulator always
    # matches the image shape returned by scan_parameters.
    img_accum = None
    accum_count = 0

    for i in range(experiments_count):
        mapping = generate_mapping(key_count, value_count, binarize_values)
        img, smallest = scan_parameters(mapping, attempts)
        # None marks an experiment where no configuration succeeded
        min_sizes.append(None if smallest is None else smallest[0])

        if img_accum is None:
            effective_key_count = len(mapping)
            img_accum = np.zeros((effective_key_count, effective_key_count, 3))
        img_accum += img
        accum_count += 1

        clear_output(wait=True)
        fig, ax = plt.subplots()
        # Show the running average of all scans so far
        ax.imshow(np.clip(img_accum/accum_count, 0, 1))
        ax.invert_yaxis()
        ax.set_xlabel('len(V)')
        ax.set_ylabel('len(G)')
        plt.show()
        time.sleep(0.1)  # Pause to make the updates visible

    print(f"Min sizes: {min_sizes}")
else:
    mapping = generate_mapping(key_count, value_count, binarize_values)
    result = find_parameters(mapping, attempts)

    print(f"Keys: {key_count}, values: {value_count}")
    if result is None:
        print("No perfect hash map found for the tested sizes")
    else:
        print(f"Smallest found size: {result.estimate_min_size()} bits")
