```
we want to change from

lookup = {value: {objs}}

to 

consistent_hash(value) = bucket_id
lookup = {bucket_id: {objs}}
```

In [1]:
from bisect import bisect_left
import random
import time
import sys
import numpy as np
from sortedcontainers import SortedDict
from pympler.asizeof import asizeof
import sortednp as snp
from cykhash import Int64Set
from operator import itemgetter
from typing import Callable, Union

In [2]:
_='''
OK, so the implementation of choice is a 
SortedDict of {min_value: Bucket}
and a heap of {size: min_value} containing the splittable buckets that are big.

Bucket knows the keys it contains as well as their counts. 
If asked for a split point, it will give one that best bisects its keys. That's O(log(n)) probably.
You can custom-write a bisection for that. 
'''

In [3]:
PLANETS = ['mercury']*10000 + ['venus']*100 + ['earth', 'mars', 'jupiter', 'saturn', 'uranus', 'neptune']
class Item:
    def __init__(self):
        self.s = random.choice(PLANETS)
        self.x = random.random()
    
    def __str__(self):
        return f'{self.s} {round(self.x, 2)}'

Best idea is:
    
 - hash values to lie in some big range, like uint64
 - Initially, we have like 10 buckets containing even chunks of that range (pretend we have a good hash function...)
 - Maintain a data structure with easy max and min access (sorted deque? heaps? dict?) sorted by the number of elements stored in the bucket. Maybe have another one for n_keys or something too, we don't wanna keep trying to split one bucket that has a single high-card key in it. Ugh.
 - Anyway, split the biggest bucket when it comes time for adding more buckets. When do we add more buckets? Shit.
 - Wait. When a bucket is unsplittable, we could take it outta the list. It's its own thing now. That would work.
 - We might have to put it back in someday.
 

In [42]:
# initialize with 1 bucket spanning whole range
# split when there are >1000 items in a splittable bucket
n_bits_signed = sys.hash_info.hash_bits - 1  # typically 64 bits
HASH_MIN = -2**n_bits_signed
HASH_MAX = 2**n_bits_signed-1

In [158]:
SIZE_THRESH = 300

class HashBucket:
    """
    A HashBucket contains all obj_ids that have value hashes between some min and max value.
    When the number of items in a HashBucket reaches SIZE_THRESH, the bucket will be split
    into two buckets.
    If a bucket ever gets empty, delete it unless it's the leftmost one -- we need at least one
    always.
    """
    def __init__(self):
        self.obj_ids = set()  # uint64
        self.val_hash_counts = dict()  # {int64: int64} - which hashes are stored in this bucket
    
    def add(self, val_hash, obj_id):
        count = self.val_hash_counts.get(val_hash, 0)
        self.val_hash_counts[val_hash] = count+1
        self.obj_ids.add(obj_id)
            
    def update(self, new_val_hash_counts, new_obj_ids):
        for v, c in new_val_hash_counts.items():
            count = self.val_hash_counts.get(v, 0)
            self.val_hash_counts[v] = count + c
        self.obj_ids = self.obj_ids.union(new_obj_ids)

    def get_all_ids(self):
        return self.obj_ids
    
    def remove(self, val_hash, obj_id):
        # todo handle exceptions
        self.val_hash_counts[val_hash] -= 1
        self.obj_ids.remove(obj_id)
    
    def split(self, field, obj_lookup):
        my_hashes = list(sorted(self.val_hash_counts.keys()))
        # dump out the upper half of our hashes
        half_point = len(my_hashes) // 2
        dumped_hash_counts = {v: self.val_hash_counts[v] for v in my_hashes[half_point:]}
        
        # dereference each object 
        # Find the objects with field_vals that hash to any of dumped_hashes
        # we will move their ids to the new bucket
        dumped_obj_ids = set()
        for obj_id in list(self.obj_ids):
            obj = obj_lookup.get(obj_id)
            obj_val = getattr(obj, field, None)
            if hash(obj_val) in dumped_hash_counts:
                dumped_obj_ids.add(obj_id)
                self.obj_ids.remove(obj_id)
        for dh in dumped_hash_counts:
            del self.val_hash_counts[dh]
        return dumped_hash_counts, dumped_obj_ids
        
    def __len__(self):
        return len(self.obj_ids)
    

class DictBucket:
    """
    A DictBucket stores object ids corresponding to only one val_hash. Note that multiple values
    coult have the same val_hash (collision).
    It stores all entries in a dict of {val: obj_id_set}, so it supports lookup by field value.
    This makes finding objects by val very fast. Unlike a HashBucket, we don't have to dereference
    all the objects and check their values during a find(). 
    DictBucket is great when many objects have the same val. 
    """
    def __init__(self, val_hash, obj_ids, obj_lookup, field):
        self.val_hash = val_hash
        self.d = dict()
        for obj_id in obj_ids:
            obj = obj_lookup.get(obj_id)
            val = getattr(obj, field, None)
            if val in self.d:
                self.d[val].add(obj_id)
            else:
                self.d[val] = set([obj_id])
    
    def add(self, val, obj_id):
        obj_ids = self.d.get(val, Int64Set())
        obj_ids.add(obj_id)
        
    def remove(self, val, obj_id):
        if val not in self.d:
            raise KeyError('Object value not in here')
        if obj_id not in self.d[val]:
            raise KeyError('Object ID not in here')
        self.d[val].remove(obj_id)
        if len(self.d[val]) == 0:
            del self.d[val]

    def get_matching_ids(self, val):
        return self.d[val]
    
    def get_all_ids(self):
        return set.union(*self.d.values())
    
    def __len__(self):
        return sum(len(s) for s in self.d.values())
    

class ObjLookup:
    
    def __init__(self):
        self.objs = dict()
        
    def get(self, obj_id):
        return self.objs.get(obj_id)
    
    def add(self, obj_id, obj):
        self.objs[obj_id] = obj

In [159]:
class MutableFieldIndex:
    # Stores the possible values of this field in a set of buckets
    # Several values may be allocated to the same bucket for space efficiency reasons

    def __init__(self, field: Union[Callable, str]):
        self.buckets = SortedDict()  # O(1) add / remove, O(log(n)) find bucket for key
        self.buckets[HASH_MIN] = HashBucket()  # always contains at least one bucket
        self.objs = ObjLookup()
        self.field = field
    
    def get_objs(self, val):
        val_hash = hash(val)
        k = self._get_bucket_key_for(val_hash)
        bucket = self.buckets[k]
        
        if isinstance(bucket, DictBucket):
            return [self.objs.get(obj_id) for obj_id in bucket.get_matching_ids(val)]
        else:
            # filter to just the objs that match val
            matched_objs = []
            for obj_id in bucket.get_all_ids():
                obj = self.objs.get(obj_id)
                obj_val = getattr(obj, self.field, None)
                if obj_val is val or obj_val == val:
                    matched_objs.append(obj)
            return matched_objs
    
    def get_obj_ids(self, val):
        val_hash = hash(val)
        k = self._get_bucket_key_for(val_hash)
        bucket = self.buckets[k]
        
        if isinstance(bucket, DictBucket):
            return bucket.get_matching_ids(val)
        else:
            # filter to just the obj_ids that match val
            matched_ids = []
            for obj_id in bucket.get_all_ids():
                obj = self.objs.get(obj_id)
                obj_val = getattr(obj, self.field, None)
                if obj_val is val or obj_val == val:
                    matched_ids.append(obj)
            return matched_ids
    
    def get_all_objs(self, obj_lookup):
        return list(obj_lookup.objs.values())
        
    def _get_bucket_key_for(self, val_hash):
        list_idx = self.buckets.bisect_right(val_hash) - 1
        k, _ = self.buckets.peekitem(list_idx)
        return k
        
    def _handle_big_hash_bucket(self, k):
        # A HashBucket is over threshold. 
        # If it contains values that all hash to the same thing, make it a DictBucket.
        # If it has many val_hashes, split it into two HashBuckets.
        hb = self.buckets[k]
        if len(hb.val_hash_counts) == 1:
            # convert it to a dictbucket
            db = DictBucket(list(hb.val_hash_counts.keys())[0], hb.obj_ids, self.objs, self.field)
            del self.buckets[k]
            self.buckets[db.val_hash] = db
        else:
            # split it into two hashbuckets
            new_hash_counts, new_obj_ids = self.buckets[k].split(self.field, self.objs)
            new_bucket = HashBucket()
            new_bucket.update(new_hash_counts, new_obj_ids)
            self.buckets[min(new_hash_counts.keys())] = new_bucket
            
    
    def add(self, obj):
        val = getattr(obj, self.field, None)
        val_hash = hash(val)
        obj_id = id(obj)
        self.objs.add(obj_id, obj)
        k = self._get_bucket_key_for(val_hash)
        if isinstance(self.buckets[k], DictBucket):
            if val_hash == self.buckets[k].val_hash:
                # add to dictbucket
                self.buckets[k].add(val, obj_id)
            else:
                # can't put it in this dictbucket, the val_hash doesn't match.
                # Make a new hashbucket to hold this item. 
                self.buckets[k+1] = HashBucket()
                self.buckets[k+1].add(val_hash, obj_id)
        else:
            # add to hashbucket
            self.buckets[k].add(val_hash, obj_id)
                
        if isinstance(self.buckets[k], HashBucket) and len(self.buckets[k]) >= SIZE_THRESH:
            self._handle_big_hash_bucket(k)
        
    def remove(self, val, obj_id):
        val_hash = hash(val)
        k = self._get_bucket_key_for(val_hash)
        if isinstance(self.buckets[k], HashBucket):
            self.buckets[k].remove(val_hash, obj_id)
        else:
            self.buckets[k].remove(val, obj_id)
        if len(self.buckets[k]) == 0 and k != HASH_MIN:
            del self.buckets[k]
                
    def bucket_report(self):
        ls = []
        for bkey in self.buckets:
            bucket = self.buckets[bkey]
            bset = set()
            for obj_id in bucket.get_all_ids():
                o = self.objs.get(obj_id)
                bset.add(getattr(o, self.field))
            ls.append((bkey, bset, len(bucket), type(self.buckets[bkey]).__name__))
        return ls

In [160]:

n = 10**6
items = [Item() for _ in range(n)]

In [161]:
idx = MutableFieldIndex('s')
print('adding', n, 'items')
t0 = time.time()
for item in items:
    idx.add(item)
t1 = time.time()
print('\n', round(t1-t0,3), 'seconds to build this field thing\n')
for b in idx.bucket_report():
    print(b)


adding 1000000 items

 2.665 seconds to build this field thing

(-9138651267913408432, {'mercury'}, 989456, 'DictBucket')
(-9138651267913408431, {'jupiter', 'neptune'}, 165, 'HashBucket')
(-4331169577938643553, {'uranus', 'earth', 'saturn'}, 294, 'HashBucket')
(8896414948515954251, {'venus'}, 9988, 'DictBucket')
(9066641044505047935, {'mars'}, 97, 'HashBucket')


In [151]:
%%timeit -n 5 -r 5
z = idx.get_obj_ids('saturn')

169 µs ± 17.6 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [152]:
planet = 'uranus'
for o in idx.get_objs(planet):
    idx.remove(planet, id(o))

In [153]:
for b in idx.bucket_report():
    print(b)

(-9138651267913408432, {'mercury'}, 989564, 'DictBucket')
(-9138651267913408431, {'jupiter', 'neptune'}, 200, 'HashBucket')
(-4331169577938643553, {'saturn', 'earth'}, 190, 'HashBucket')
(8896414948515954251, {'venus'}, 9833, 'DictBucket')
(9066641044505047935, {'mars'}, 106, 'HashBucket')


In [154]:
planet = 'venus'
for o in idx.get_objs(planet):
    idx.remove(planet, id(o))

In [155]:
for b in idx.bucket_report():
    print(b)

(-9138651267913408432, {'mercury'}, 989564, 'DictBucket')
(-9138651267913408431, {'jupiter', 'neptune'}, 200, 'HashBucket')
(-4331169577938643553, {'saturn', 'earth'}, 190, 'HashBucket')
(9066641044505047935, {'mars'}, 106, 'HashBucket')


In [156]:
planet = 'mars'
for o in idx.get_objs(planet):
    idx.remove(planet, id(o))

In [157]:
for b in idx.bucket_report():
    print(b)

(-9138651267913408432, {'mercury'}, 989564, 'DictBucket')
(-9138651267913408431, {'jupiter', 'neptune'}, 200, 'HashBucket')
(-4331169577938643553, {'saturn', 'earth'}, 190, 'HashBucket')


In [167]:
class Crap:
    def __init__(self):
        self.s = random.random()
crap_items = [Crap() for _ in range(10**6)]

In [168]:
idx = MutableFieldIndex('s')
print('adding', n, 'items')
t0 = time.time()
for item in crap_items:
    idx.add(item)
t1 = time.time()
print('\n', round(t1-t0,3), 'seconds to build this field thing\n')


adding 1000000 items

 6.597 seconds to build this field thing



In [169]:
len(idx.buckets)

4757

In [174]:
%%timeit -n 5 -r 5
idx.get_objs(crap_items[0].s)

196 µs ± 32.1 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [None]:
from hashindex import HashIndex

t0 = time.time()
hi = HashIndex(items, on='s')
t1 = time.time()
print(t1-t0, 'seconds to build a HashIndex')

hi.freeze()

In [None]:
t0 = time.time()
d = dict()
for i in items:
    if i.s not in d:
        d[i.s] = list()
    d[i.s].append(i)
t1 = time.time()
print(t1-t0, 'seconds to build a dict')


In [None]:
planet = 'mercury'

In [None]:
%%timeit -n 5 -r 5
v = hi.find(match={'s': planet})

In [None]:
%%timeit -n 5 -r 5
v = idx.get(planet)

In [None]:
%%timeit -n 5 -r 5
v = d.get(planet)
# yikes - how is this 1000x faster? something has gone really wrong here! let's see if it's the deref lookup that's
# costing so much

In [None]:
class DerefDict():
    
    def __init__(self, items):
        self.objs = {id(item): item for item in items}
        self.d = dict()
        for i in items:
            if i.s not in self.d:
                self.d[i.s] = list()
            self.d[i.s].append(id(i))
    
    def get(self, val):
        ids = self.d.get(val)
        return [self.objs.get(i) for i in ids]

t0 = time.time()
dd  = DerefDict(items)
t1 = time.time()
print(t1-t0, 'seconds to build a deref dict')

In [None]:
%%timeit -n 5 -r 5
v = dd.get(planet)
# the difference is that you are doing len(planets) dict lookups instead of just one dict lookup.
# Can we keep the list literal around during processing instead?
# e.g. - most of the time we will want an entire list (simple lookup, no intersection). Detect that 
# case and we've got something as good as dict().

In [None]:
class DerefDictListy():
    def __init__(self, items):
        self.objs = []
        self.d = dict()
        for i in items:
            if i.s not in self.d:
                self.d[i.s] = list()
            self.d[i.s].append(len(self.objs))
            self.objs.append(obj)
    
    def get(self, val):
        ids = self.d.get(val)
        return itemgetter(*ids)(self.objs)

t0 = time.time()
ddl  = DerefDict(items)
t1 = time.time()
print(t1-t0, 'seconds to build a deref dict, listy edition')

In [None]:
%%timeit -n 5 -r 5
v = ddl.get(planet)
# the difference is that you are doing len(planets) dict lookups instead of just one dict lookup.
# Can we keep the list literal around during processing instead?
# e.g. - most of the time we will want an entire list (simple lookup, no intersection). Detect that 
# case and we've got something as good as dict().

In [None]:
a = list('abcdefg')
isect_ids = snp.intersect(np.array([2,3,6]), np.array([1,2,3,4,5]), indices=True)[1][1]
itemgetter(*isect_ids)(a)

In [None]:
v = hi.find({'s': planet})
v