```
we want to change from

lookup = {value: {objs}}

to 

consistent_hash(value) = bucket_id
lookup = {bucket_id: {objs}}
```

In [94]:
from bisect import bisect_left
import random
import time
import sys
import numpy as np
from sortedcontainers import SortedDict
from pympler.asizeof import asizeof
import sortednp as snp
from cykhash import Int64Set
from operator import itemgetter
from typing import Callable, Union, List, Any, Tuple
from collections import Counter, namedtuple
from dataclasses import dataclass

In [2]:
_='''
OK, so the implementation of choice is a 
SortedDict of {min_value: Bucket}
and a heap of {size: min_value} containing the splittable buckets that are big.

Bucket knows the keys it contains as well as their counts. 
If asked for a split point, it will give one that best bisects its keys. That's O(log(n)) probably.
You can custom-write a bisection for that. 
'''

In [3]:
PLANETS = ['mercury']*10000 + ['venus']*100 + ['earth', 'mars', 'jupiter', 'saturn', 'uranus', 'neptune']
class Item:
    def __init__(self):
        self.s = random.choice(PLANETS)
        self.x = random.random()
    
    def __str__(self):
        return f'{self.s} {round(self.x, 2)}'

Best idea is:
    
 - hash values to lie in some big range, like uint64
 - Initially, we have like 10 buckets containing even chunks of that range (pretend we have a good hash function...)
 - Maintain a data structure with easy max and min access (sorted deque? heaps? dict?) sorted by the number of elements stored in the bucket. Maybe have another one for n_keys or something too, we don't wanna keep trying to split one bucket that has a single high-card key in it. Ugh.
 - Anyway, split the biggest bucket when it comes time for adding more buckets. When do we add more buckets? Shit.
 - Wait. When a bucket is unsplittable, we could take it outta the list. It's its own thing now. That would work.
 - We might have to put it back in someday.
 

In [4]:
# initialize with 1 bucket spanning whole range
# split when there are >1000 items in a splittable bucket
n_bits_signed = sys.hash_info.hash_bits - 1  # typically 64 bits
HASH_MIN = -2**n_bits_signed
HASH_MAX = 2**n_bits_signed-1

In [5]:
SIZE_THRESH = 300

class HashBucket:
    """
    A HashBucket contains all obj_ids that have value hashes between some min and max value.
    When the number of items in a HashBucket reaches SIZE_THRESH, the bucket will be split
    into two buckets.
    If a bucket ever gets empty, delete it unless it's the leftmost one -- we need at least one
    always.
    """
    def __init__(self):
        self.obj_ids = set()  # uint64
        self.val_hash_counts = dict()  # {int64: int64} - which hashes are stored in this bucket
    
    def add(self, val_hash, obj_id):
        count = self.val_hash_counts.get(val_hash, 0)
        self.val_hash_counts[val_hash] = count+1
        self.obj_ids.add(obj_id)
            
    def update(self, new_val_hash_counts, new_obj_ids):
        for v, c in new_val_hash_counts.items():
            count = self.val_hash_counts.get(v, 0)
            self.val_hash_counts[v] = count + c
        self.obj_ids = self.obj_ids.union(new_obj_ids)

    def get_all_ids(self):
        return self.obj_ids
    
    def remove(self, val_hash, obj_id):
        # todo handle exceptions
        self.val_hash_counts[val_hash] -= 1
        self.obj_ids.remove(obj_id)
    
    def split(self, field, obj_lookup):
        my_hashes = list(sorted(self.val_hash_counts.keys()))
        # dump out the upper half of our hashes
        half_point = len(my_hashes) // 2
        dumped_hash_counts = {v: self.val_hash_counts[v] for v in my_hashes[half_point:]}
        
        # dereference each object 
        # Find the objects with field_vals that hash to any of dumped_hashes
        # we will move their ids to the new bucket
        dumped_obj_ids = set()
        for obj_id in list(self.obj_ids):
            obj = obj_lookup.get(obj_id)
            obj_val = getattr(obj, field, None)
            if hash(obj_val) in dumped_hash_counts:
                dumped_obj_ids.add(obj_id)
                self.obj_ids.remove(obj_id)
        for dh in dumped_hash_counts:
            del self.val_hash_counts[dh]
        return dumped_hash_counts, dumped_obj_ids
        
    def __len__(self):
        return len(self.obj_ids)
    

class DictBucket:
    """
    A DictBucket stores object ids corresponding to only one val_hash. Note that multiple values
    coult have the same val_hash (collision).
    It stores all entries in a dict of {val: obj_id_set}, so it supports lookup by field value.
    This makes finding objects by val very fast. Unlike a HashBucket, we don't have to dereference
    all the objects and check their values during a find(). 
    DictBucket is great when many objects have the same val. 
    """
    def __init__(self, val_hash, obj_ids, obj_lookup, field):
        self.val_hash = val_hash
        self.d = dict()
        for obj_id in obj_ids:
            obj = obj_lookup.get(obj_id)
            val = getattr(obj, field, None)
            if val in self.d:
                self.d[val].add(obj_id)
            else:
                self.d[val] = set([obj_id])
    
    def add(self, val, obj_id):
        obj_ids = self.d.get(val, Int64Set())
        obj_ids.add(obj_id)
        
    def remove(self, val, obj_id):
        if val not in self.d:
            raise KeyError('Object value not in here')
        if obj_id not in self.d[val]:
            raise KeyError('Object ID not in here')
        self.d[val].remove(obj_id)
        if len(self.d[val]) == 0:
            del self.d[val]

    def get_matching_ids(self, val):
        return self.d[val]
    
    def get_all_ids(self):
        return set.union(*self.d.values())
    
    def __len__(self):
        return sum(len(s) for s in self.d.values())
    

class ObjLookup:
    
    def __init__(self):
        self.objs = dict()
        
    def get(self, obj_id):
        return self.objs.get(obj_id)
    
    def add(self, obj_id, obj):
        self.objs[obj_id] = obj

In [6]:
class MutableFieldIndex:
    # Stores the possible values of this field in a set of buckets
    # Several values may be allocated to the same bucket for space efficiency reasons

    def __init__(self, field: Union[Callable, str]):
        self.buckets = SortedDict()  # O(1) add / remove, O(log(n)) find bucket for key
        self.buckets[HASH_MIN] = HashBucket()  # always contains at least one bucket
        self.objs = ObjLookup()  # todo move this to a higher level (?)
        self.field = field
    
    def get_objs(self, val):
        val_hash = hash(val)
        k = self._get_bucket_key_for(val_hash)
        bucket = self.buckets[k]
        
        if isinstance(bucket, DictBucket):
            return [self.objs.get(obj_id) for obj_id in bucket.get_matching_ids(val)]
        else:
            # filter to just the objs that match val
            matched_objs = []
            for obj_id in bucket.get_all_ids():
                obj = self.objs.get(obj_id)
                obj_val = getattr(obj, self.field, None)
                if obj_val is val or obj_val == val:
                    matched_objs.append(obj)
            return matched_objs
    
    def get_obj_ids(self, val):
        val_hash = hash(val)
        k = self._get_bucket_key_for(val_hash)
        bucket = self.buckets[k]
        
        if isinstance(bucket, DictBucket):
            return bucket.get_matching_ids(val)
        else:
            # filter to just the obj_ids that match val
            matched_ids = []
            for obj_id in bucket.get_all_ids():
                obj = self.objs.get(obj_id)
                obj_val = getattr(obj, self.field, None)
                if obj_val is val or obj_val == val:
                    matched_ids.append(obj)
            return matched_ids
    
    def get_all_objs(self, obj_lookup):
        return list(obj_lookup.objs.values())
        
    def _get_bucket_key_for(self, val_hash):
        list_idx = self.buckets.bisect_right(val_hash) - 1
        k, _ = self.buckets.peekitem(list_idx)
        return k
        
    def _handle_big_hash_bucket(self, k):
        # A HashBucket is over threshold. 
        # If it contains values that all hash to the same thing, make it a DictBucket.
        # If it has many val_hashes, split it into two HashBuckets.
        hb = self.buckets[k]
        if len(hb.val_hash_counts) == 1:
            # convert it to a dictbucket
            db = DictBucket(list(hb.val_hash_counts.keys())[0], hb.obj_ids, self.objs, self.field)
            del self.buckets[k]
            self.buckets[db.val_hash] = db
        else:
            # split it into two hashbuckets
            new_hash_counts, new_obj_ids = self.buckets[k].split(self.field, self.objs)
            new_bucket = HashBucket()
            new_bucket.update(new_hash_counts, new_obj_ids)
            self.buckets[min(new_hash_counts.keys())] = new_bucket
            
    
    def add(self, obj):
        val = getattr(obj, self.field, None)
        val_hash = hash(val)
        obj_id = id(obj)
        self.objs.add(obj_id, obj)
        k = self._get_bucket_key_for(val_hash)
        if isinstance(self.buckets[k], DictBucket):
            if val_hash == self.buckets[k].val_hash:
                # add to dictbucket
                self.buckets[k].add(val, obj_id)
            else:
                # can't put it in this dictbucket, the val_hash doesn't match.
                # Make a new hashbucket to hold this item. 
                self.buckets[k+1] = HashBucket()
                self.buckets[k+1].add(val_hash, obj_id)
        else:
            # add to hashbucket
            self.buckets[k].add(val_hash, obj_id)
                
        if isinstance(self.buckets[k], HashBucket) and len(self.buckets[k]) > SIZE_THRESH:
            self._handle_big_hash_bucket(k)
        
    def remove(self, val, obj_id):
        val_hash = hash(val)
        k = self._get_bucket_key_for(val_hash)
        if isinstance(self.buckets[k], HashBucket):
            self.buckets[k].remove(val_hash, obj_id)
        else:
            self.buckets[k].remove(val, obj_id)
        if len(self.buckets[k]) == 0 and k != HASH_MIN:
            del self.buckets[k]
                
    def bucket_report(self):
        ls = []
        for bkey in self.buckets:
            bucket = self.buckets[bkey]
            bset = set()
            for obj_id in bucket.get_all_ids():
                o = self.objs.get(obj_id)
                bset.add(getattr(o, self.field))
            ls.append((bkey, bset, len(bucket), type(self.buckets[bkey]).__name__))
        return ls

In [7]:
n = 10**6
items = [Item() for _ in range(n)]

In [8]:
idx = MutableFieldIndex('s')
print('adding', n, 'items')
t0 = time.time()
for item in items:
    idx.add(item)
t1 = time.time()
print('\n', round(t1-t0,3), 'seconds to build this field thing\n')
for b in idx.bucket_report():
    print(b)


adding 1000000 items

 2.863 seconds to build this field thing

(-3057177238626995881, {'uranus'}, 90, 'HashBucket')
(1902181302418005661, {'mercury'}, 989616, 'DictBucket')
(1902181302418005662, {'neptune', 'saturn', 'mars'}, 288, 'HashBucket')
(8688647582771633844, {'venus'}, 9673, 'DictBucket')
(8688647582771633845, {'earth', 'jupiter'}, 170, 'HashBucket')


In [9]:
# What would improve build time? < 1M items / second means no one's using this for 100M items.
# How fast could it go if all we had to do was hash all the values up front, and add them to sorteddict?
# bout 5x faster, looks like.

In [10]:
%%timeit -n 2 -r 5
_ = sorted([hash(random.random()) for _ in range(10**6)])
s = SortedDict()
for i in range(10*3):
    s[i] = set(range(10**3))

510 ms ± 11 ms per loop (mean ± std. dev. of 5 runs, 2 loops each)


In [318]:
menu = list(x*0.9 for x in range(int(10**7)))
class Inty:
    def __init__(self):
        self.s = random.choice(menu)
        
item_ints = [Inty() for _ in range(10**6)]

In [12]:
%%timeit -n 3 -r 3
# trying to be smart about hashing each value is much slower than just hashing each value
prev_hash = None
prev_val = None
for i, item in enumerate(sorted(item_ints, key=lambda x: x.s)):
    if i > 0 and item.s == prev_val:
        h = prev_hash
    else:
        h = hash(item.s)
        prev_hash = h
        prev_val = item.s

878 ms ± 7 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [13]:
%%timeit -n 3 -r 3
for item in item_ints:
    _ = hash(item.s)

193 ms ± 2.9 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [14]:
# todo sort both by the argsort etc

In [36]:
def run_length_encode(arr: np.ndarray):
    if len(arr) == 0:
        return None, None, None
    mismatch = arr[1:] != arr[:-1]
    i = np.append(np.where(mismatch), len(arr)-1)
    counts = np.diff(np.append(-1, i))
    starts = np.cumsum(np.append(0, counts))[:-1]
    return starts, counts, arr[i]

In [160]:
def get_field(obj, field):
    if callable(field):
        val = field(obj)
    elif isinstance(obj, dict):
        val = obj.get(field, None)
    else:
        val = getattr(obj, field, None)
    return val


In [248]:
def get_sorted_hashes(objs: List[Any], field: Union[Callable, str]) -> Tuple[np.array, List]:
    """
    Hash the given attribute for all objs. Sort objs and hashes by the hashes.

    Takes 350ms for 1M items.
    """
    def _get_field(obj, field):
        if getattr(obj, 'get', None):
            return obj.get(field, None)
        if isinstance(field, str):
            val = getattr(obj, field, None)
        return val

    hashes = np.fromiter((hash(_get_field(item, field)) for item in objs), dtype='int64')
    pos = np.argsort(hashes)
    sorted_hashes = hashes[pos]
    sorted_objs = itemgetter(*pos)(objs)  # slow AF
    return sorted_hashes, sorted_objs


In [246]:
%%timeit -n 3 -r 3  # it's 450ms

sorted_hashes, sorted_objs = get_sorted_hashes(item_ints, 's')


667 ms ± 3.84 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [234]:
%%timeit -n 3 -r 3 
hashes = np.fromiter((getattr(item, 's', None) for item in item_ints), dtype='int64')

214 ms ± 15.6 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [232]:
%%timeit -n 3 -r 3 
np.argsort(hashes)

97.1 ms ± 5.78 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [233]:
%%timeit -n 3 -r 3 
np.zeros((10**6,))

848 µs ± 147 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [171]:

def run_length_encode(arr: np.ndarray):
    if len(arr) == 0:
        return None, None, None
    mismatch = arr[1:] != arr[:-1]
    i = np.append(np.where(mismatch), len(arr)-1)
    counts = np.diff(np.append(-1, i))
    starts = np.cumsum(np.append(0, counts))[:-1]
    return starts, counts, arr[i]


In [172]:
%%timeit -n 3 -r 3 
starts, counts, val_hashes = run_length_encode(sorted_hashes)

9.56 ms ± 707 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [97]:
%%timeit -n 3 -r 3  # it's 500ms

s_hash, s_obj = get_sorted_hashes(item_ints, 's')

550 ms ± 12.6 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [174]:
#%%timeit -n 3 -r 3  # it's 500ms
# compute all hashes and sort by hash
hashes = np.fromiter((hash(item.s) for item in item_ints), dtype='int64')
pos = np.argsort(hashes)
sorted_hashes = hashes[pos]
sorted_items = itemgetter(*pos)(item_ints)  # todo: handle itemgetter scenarios with 0 or 1 objects
starts, counts, val_hashes = run_length_encode(sorted_hashes)

In [198]:
@njit
def find_bucket_starts(counts, limit):
    """
    Find counts of each hash in sorted_hashes via run-length encoding.

    Takes 10ms for 1M objs.
    """
    result = np.empty(len(counts), dtype=np.uint64)
    total = 0
    idx = 0
    for i, count in enumerate(counts):
        total += count
        if total > limit:
            total = 0
            result[idx] = i
            idx += 1
    return result[:idx]



In [331]:
"""
When a large object list is provided in the constructor -- i.e., HashIndex(objs, on={...}) has lots of objs,
adding the objs one at a time is naive and slow. Buckets will be created, overfilled, and split needlessly.
There is a ~10X performance benefit to examining the objs and constructing all the needed buckets just once.

The functions here provide that speedup. They have been squeezed pretty hard for performance. It matters here!
Building the index is most likely going to be the bottleneck when working with large datasets, especially in the
expected data-analysis niche for this library.

Workflow:
 - Hash the given attribute for all objs
 - Sort the hashes
 - Get counts of each unique hash (via run-length encoding)
 - Use a cumulative sum-like algorithm to determine the span of each bucket
 - Return all information that init needs to create buckets
"""

import numpy as np
from numba import njit
from dataclasses import dataclass
from typing import Tuple, List, Union, Callable, Any
from hashindex.utils import get_field
from operator import itemgetter


def get_sorted_hashes(objs: List[Any], field: Union[Callable, str]) -> Tuple[np.array, np.ndarray]:
    """
    Hash the given attribute for all objs. Sort objs and hashes by the hashes.

    Takes 450ms for 1M objs on a numeric field. May take longer if field is a Callable or is hard to hash.
    Breakdown:
     - 100ms to do all the get_field() calls. Cost is the part that inspects each obj to see if it's a dict.
     - 220ms to get and hash the field for each obj. No getting around that.
     - 100ms to sort the hashes
     - 30ms of whatever
    """
    hashes = np.fromiter((hash(get_field(obj, field)) for obj in objs), dtype='int64')
    pos = np.argsort(hashes)
    sorted_hashes = hashes[pos]
    sorted_objs = itemgetter(*pos)(objs)  # todo handle itemgetter len 0, 1 weirdness
    return sorted_hashes, sorted_objs


def run_length_encode(sorted_hashes: np.ndarray):
    """
    Find counts of each hash in sorted_hashes via run-length encoding.

    Takes 10ms for 1M objs.
    """
    mismatch = sorted_hashes[1:] != sorted_hashes[:-1]
    i = np.append(np.where(mismatch), len(sorted_hashes) - 1)
    counts = np.diff(np.append(-1, i))
    starts = np.cumsum(np.append(0, counts))[:-1]
    return starts, counts, sorted_hashes[i]


@njit
def find_bucket_starts(counts, limit):
    """
    Find the start positions for each bucket via a cumulative sum that resets when limit is exceeded.

    Takes about 1ms for a counts length of 1M. 300x slower without numba (noticeable on high-cardinality data)
    """
    result = np.empty(len(counts), dtype=np.uint64)
    total = 0
    idx = 0
    for i, count in enumerate(counts):
        total += count
        if total > limit:
            total = count
            result[idx] = i
            idx += 1
    return result[:idx]


@dataclass
class BucketInfo:
    distinct_hashes: np.ndarray
    distinct_hash_counts: np.ndarray
    obj_arr: np.ndarray
    hash_arr: np.ndarray

    def __str__(self):
        d = dict(zip(self.distinct_hashes, self.distinct_hash_counts))
        l1 = len(self.obj_arr)
        l2 = len(self.hash_arr)
        mh = min(self.hash_arr)
        return f'{mh}: {l1}={l2}; ' + str(d)


def compute_buckets(objs, field, bucket_size_limit):
    sorted_hashes, sorted_objs = get_sorted_hashes(objs, field)
    starts, counts, val_hashes = run_length_encode(sorted_hashes)
    bucket_starts = find_bucket_starts(counts, bucket_size_limit)

    bucket_infos = []
    for i, s in enumerate(bucket_starts):
        if i + 1 == len(bucket_starts):
            distinct_hashes = val_hashes[i:]
            distinct_hash_counts = counts[i:]
            obj_arr = sorted_objs[starts[i]:]
            hash_arr = sorted_hashes[starts[i]:]
        else:
            distinct_hashes = val_hashes[i:i + 1]
            distinct_hash_counts = counts[i:i + 1]
            obj_arr = sorted_objs[starts[i]:starts[i + 1]]
            hash_arr = sorted_hashes[starts[i]:starts[i + 1]]
        bucket_infos.append(
            BucketInfo(
                distinct_hashes=distinct_hashes,
                distinct_hash_counts=distinct_hash_counts,
                obj_arr=obj_arr,
                hash_arr=hash_arr,
            )
        )


def compute_buckets(objs, field, bucket_size_limit):
    sorted_hashes, sorted_objs = get_sorted_hashes(objs, field)
    starts, counts, val_hashes = run_length_encode(sorted_hashes)
    bucket_starts = find_bucket_starts(counts, bucket_size_limit)

    bucket_infos = []
    for i, s in enumerate(bucket_starts):
        if i + 1 == len(bucket_starts):
            distinct_hashes = val_hashes[s:]
            distinct_hash_counts = counts[s:]
            obj_arr = sorted_objs[starts[s]:]
            hash_arr = sorted_hashes[starts[s]:]
        else:
            t = bucket_starts[i + 1]
            distinct_hashes = val_hashes[s:t]
            distinct_hash_counts = counts[s:t]
            obj_arr = sorted_objs[starts[s]:starts[t]]
            hash_arr = sorted_hashes[starts[s]:starts[t]]
        bucket_infos.append(
            BucketInfo(
                distinct_hashes=distinct_hashes,
                distinct_hash_counts=distinct_hash_counts,
                obj_arr=obj_arr,
                hash_arr=hash_arr,
            )
        )
    return bucket_infos

In [329]:
%%timeit -n 5 -r 5
compute_buckets(item_ints, 's', 1000)

738 ms ± 13.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [332]:
bucket_infos = compute_buckets(item_ints, 's', 1000)
print(len(bucket_infos))
for b in bucket_infos[:10]:
    print(b)

1000
90963: 1000=1000; {90963: 1, 90972: 1, 91044: 1, 91107: 2, 91143: 1, 91179: 1, 91224: 2, 91287: 1, 91539: 1, 91593: 1, 91611: 1, 91647: 1, 91791: 1, 91854: 1, 91872: 1, 91944: 1, 92034: 1, 92088: 2, 92106: 2, 92349: 1, 92358: 1, 92367: 1, 92376: 1, 92385: 1, 92592: 1, 92610: 1, 92619: 1, 92745: 1, 92889: 1, 92907: 1, 92970: 1, 93006: 1, 93078: 1, 93510: 1, 93618: 1, 93726: 1, 93744: 1, 93780: 1, 93843: 1, 93852: 1, 93879: 1, 93888: 1, 94059: 1, 94104: 1, 94194: 1, 94221: 1, 94311: 1, 94554: 1, 94617: 1, 94743: 1, 94797: 1, 94932: 1, 95067: 1, 95175: 1, 95328: 2, 95553: 1, 95643: 1, 95706: 1, 95724: 1, 95823: 1, 95922: 1, 95940: 2, 95949: 1, 96039: 1, 96048: 1, 96120: 1, 96183: 1, 96273: 1, 96282: 1, 96291: 1, 96309: 1, 96408: 1, 96435: 1, 96579: 1, 96642: 1, 96696: 2, 96741: 1, 96876: 1, 97119: 1, 97164: 1, 97173: 2, 97209: 1, 97218: 1, 97560: 1, 97587: 1, 97749: 1, 97794: 1, 97839: 2, 97875: 1, 97956: 1, 98019: 1, 98091: 1, 98154: 1, 98244: 1, 98280: 1, 98298: 1, 98334: 1, 98361:

In [292]:
print(sorted_hashes[:20])
print(starts[:10]) 
print(counts[:10])
print(val_hashes[:10])

[0 0 0 0 0 0 0 9 9 9 9 9 9 9 9 9 9 9 9 9]
[ 0  7 24 31 37 51 63 68 74 81]
[ 7 17  7  6 14 12  5  6  7  9]
[ 0  9 18 27 36 45 54 63 72 81]


In [199]:
%%timeit -n 10 -r 10
find_bucket_starts(counts, SIZE_THRESH)

The slowest run took 41.64 times longer than the fastest. This could mean that an intermediate result is being cached.
1.52 ms ± 3.57 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)


In [200]:
len(counts)

631848

In [92]:
len(item_ints)

1000000

In [38]:
len(counts)

632703

In [39]:
hashes[0:10]

array([1844674407377707863, 1844674407478945242,  922337203739473107,
       1844674407478857798,  922337203739557383, 1614090106611372739,
       1614090106476759893,              805977,  461168601870065832,
        691752902872030564])

In [40]:
print('first ten: ', sorted_hashes[:20])
print('starts', starts[:5])
print('counts', counts[:5])
print('vals', val_hashes[:5])

first ten:  [  9  36  36  45  54  63  72  81  90  90 108 117 153 162 162 162 171 189
 189 189]
starts [0 1 3 4 5]
counts [1 2 1 1 1]
vals [ 9 36 45 54 63]


In [41]:
def get_bucket_ranges(counts):
    bin_items = []
    bin_start_pos = []
    start_pos = 0
    csum = 0
    prev_hash = None
    for i, c in enumerate(counts):
        if csum + c > SIZE_THRESH and csum > 0:
            # need to dump current items
            bin_items.append(csum)
            bin_start_pos.append(start_pos)
            csum = 0
            start_pos = i
        csum += c
    if csum > 0:
        bin_items.append(csum)
        bin_start_pos.append(start_pos)
    return bin_items, bin_start_pos


In [None]:
def cumsum_reset():
    v = np.array([1., 1., 1., np.nan, 1., 1., 1., 1., np.nan, 1.])
    n = np.isnan(v)
    a = ~n  # todo change this
    c = np.cumsum(a)
    d = np.diff(np.concatenate(([0.], c[n])))
    v[n] = -d
    np.cumsum(v)


In [54]:
def cumsum_breach(arr, limit):
    total = 0
    for i, y in enumerate(arr):
        total += y
        if total > limit:
            yield i
            total = 0

np.fromiter(cumsum_breach(np.array([1,2,3,1,6,6,3,1,1,4,6]), limit=5), dtype=int)

array([ 2,  4,  5,  9, 10])

In [69]:
%%timeit -n 3 -r 3
np.fromiter(cumsum_breach(counts, SIZE_THRESH), dtype=int)

127 ms ± 5.9 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [62]:
cz = np.fromiter(cumsum_breach(counts, SIZE_THRESH), dtype=int)
len(cz)

3316

In [85]:
from numba import njit

@njit
def cumsum_breach_numba2(x, target, result):
    total = 0
    iterID = 0
    for i,x_i in enumerate(x):
        total += x_i
        if total > target:
            result[iterID] = i
            iterID += 1
            total = 0
    return iterID

def cumsum_breach_array_init(x, target):
    x = np.asarray(x)
    result = np.empty(len(x),dtype=np.uint64)
    idx = cumsum_breach_numba2(x, target, result)
    return result[:idx]

In [90]:
%%timeit -n 5 -r 1
cumsum_breach_array_init(counts, SIZE_THRESH)

474 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 5 loops each)


In [87]:
len(cumsum_breach_array_init(counts, SIZE_THRESH))

3316

In [48]:
# there is a vectorized way to do this.
# First, find the dict buckets.
# 1. make an array of zeros called start_pos_flags. Set first position is 1.
# 2. Find my_counts > SIZE_THRESH. Set start_pos_flags to 1 for each position, and the position after it.
# 3. Set counts to 0 for each of these.
# Save array of dict bucket positions.
# Next we find the hash buckets.
# while True:
# 1. compute cumulative sums of counts for each segment between two start_pos_flags, put it in csums
# 2. Find the first spot where csums > SIZE_THRESH in each segment. 
# 3. break if all cumulative sums are <= SIZE_THRESH

def bucket_ranges_vec(counts_, limit=SIZE_THRESH):
    counts = np.copy(counts_)
    start_pos_flags = np.zeros(len(counts), dtype=bool)
    start_pos_flags[0] = True
    
    # mark and remove dictbuckets
    dict_pos = np.where(counts > limit)
    start_pos_flags[dict_pos] = True
    for dp in dict_pos: 
        if dp+1 < len(counts):
            start_pos_flags[dp+1] = True
    counts[dict_pos] = 0
    
    # iterate adding buckets until there is no position in the
    # cumulative sum greater than limit
    while True:
        csum_seg = np.zeros((len(counts),))
        flag_pos = np.where(start_pos_flags)
        for i in range(flag_pos):
            if i < len(flag_pos)-1:
                
        overs = np.where(csum_seg > limit)
        if not overs:
            break
        
    return start_pos_flags
    
    
bucket_ranges_vec(np.array([1,2,3,1,6,6,3,1,1,4,6]), limit=5)

In [51]:
os = np.array([])
cs = np.cumsum(os)
np.diff(np.where(cs>5))

array([[1, 1, 1, 1]])

In [63]:
#%%timeit -n 3 -r 3  # 150ms when n_bins is big, otherwise fast

bin_items, bin_start_pos = get_bucket_ranges(counts)


158 ms ± 1.54 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [45]:
len(bin_items)

3339

In [23]:
#timeit -n 3 -r 3  # 150ms when n_bins is big, otherwise fast

# todo always have a hashbucket starting at MIN_HASH, even if first thing is a dictbucket
buckets = SortedDict()
for i in range(len(bin_start_pos)):
    start_hash = val_hashes[bin_start_pos[i]]
    b = HashBucket()  # todo fix dictbucket constructor -  DictBucket(val_hash, )
    #b.val_hashes = sorted_hashes[]
    buckets[start_hash] = b 

In [24]:
buckets = SortedDict()

In [33]:
%%timeit -n 3 -r 3
for i in range(len(bin_start_pos)):
    start_hash = val_hashes[bin_start_pos[i]]
    b = HashBucket()
    b.val_hashes = sorted_hashes[i]
    buckets[start_hash] = b 

4.65 ms ± 986 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [34]:
len(buckets)

3339

In [26]:
1614090106476887972-(2**60)

461168601870040996

In [27]:
s = SortedDict()
s[2**61] = 'a'

In [28]:
buckets[start_hash] = b 

In [29]:
val_hashes[:]

array([                  0,                  27,                  36, ...,
       2075258708346911576, 2075258708346911585, 2075258708346911621])

In [30]:
# iteratively
tots = np.zeros(len(counts), dtype=int)  # running total in bin
bin_starts = np.zeros(len(counts), dtype=int)  # flagged 1 if a bin starts at that pos

dict_bin_flag = np.zeros(len(counts), dtype=bool)
counts[dict_bin_flag] = 0

bigs = np.where(tots > SIZE_THRESH)[0]
bin_starts[bigs] = 1
for i in [-1] + list(range(len(bigs)-1)):
    if i == -1:
        lo = 0
        hi = bigs[0]
    else:
        lo = bigs[i]+1
        hi = bigs[i+1]
    tots[lo:hi] = np.cumsum(counts[lo:hi])
tots[:20]

IndexError: index 0 is out of bounds for axis 0 with size 0

In [None]:
counts[:20]

In [None]:
counts.where()

In [None]:
%%timeit -n 3 -r 3 
np.cumsum(counts)

In [None]:
for i in range(len(bin_items[:10])):
    print(bin_items[i], bin_start_pos[i])

In [None]:
print(counts[10000])

In [None]:
# now make buckets for each range of max-1000 elements
buckets = SortedDict()
csum = 0
bucket_min = HASH_MIN
val_hash_counts = dict()
obj_ids = []
for i in range(len(vals)):
    val_hash = val_hashes[i]
    if counts[i] > SIZE_THRESH or csum + counts[i] > SIZE_THRESH:
        # close current bucket, if any
        if len(obj_ids):
            b = HashBucket()
            b.val_hash_counts = val_hash_counts
            b.obj_ids = obj_ids
            buckets[val_hash_min] = b
            
        # handle new element
        if counts[i] > SIZE_THRESH:
            # this thing goes in a new dict bucket
            csum = 0
            
        else:
            # start a new hash bucket to hold this thing
            pass
    else:
        # continue adding to the 
        csum += counts[i]
        val_hash_counts[val_hash] = counts[i]

In [None]:
print(type(sorted_hashes))
print(type(np.asarray(sorted_hashes)))

In [None]:
# build algo goes like...


In [None]:
%%timeit -n 5 -r 5
z = idx.get_obj_ids('saturn')

In [None]:
planet = 'uranus'
for o in idx.get_objs(planet):
    idx.remove(planet, id(o))

In [None]:
for b in idx.bucket_report():
    print(b)

In [None]:
planet = 'venus'
for o in idx.get_objs(planet):
    idx.remove(planet, id(o))

In [None]:
for b in idx.bucket_report():
    print(b)

In [None]:
planet = 'mars'
for o in idx.get_objs(planet):
    idx.remove(planet, id(o))

In [None]:
for b in idx.bucket_report():
    print(b)

In [None]:
class Crap:
    def __init__(self):
        self.s = random.random()
crap_items = [Crap() for _ in range(10**6)]

In [None]:
idx = MutableFieldIndex('s')
print('adding', n, 'items')
t0 = time.time()
for item in crap_items:
    idx.add(item)
t1 = time.time()
print('\n', round(t1-t0,3), 'seconds to build this field thing\n')


In [None]:
len(idx.buckets)

In [None]:
%%timeit -n 5 -r 5
idx.get_obj_ids(crap_items[0].s)

In [None]:
from hashindex import HashIndex

t0 = time.time()
hi = HashIndex(items, on='s')
t1 = time.time()
print(t1-t0, 'seconds to build a HashIndex')

hi.freeze()

In [None]:
t0 = time.time()
d = dict()
for i in items:
    if i.s not in d:
        d[i.s] = list()
    d[i.s].append(i)
t1 = time.time()
print(t1-t0, 'seconds to build a dict')

In [None]:
planet = 'mercury'

In [None]:
%%timeit -n 5 -r 5
v = hi.find(match={'s': planet})

In [None]:
%%timeit -n 5 -r 5
v = idx.get(planet)

In [None]:
%%timeit -n 5 -r 5
v = d.get(planet)
# yikes - how is this 1000x faster? something has gone really wrong here! let's see if it's the deref lookup that's
# costing so much

In [None]:
class DerefDict():
    
    def __init__(self, items):
        self.objs = {id(item): item for item in items}
        self.d = dict()
        for i in items:
            if i.s not in self.d:
                self.d[i.s] = list()
            self.d[i.s].append(id(i))
    
    def get(self, val):
        ids = self.d.get(val)
        return [self.objs.get(i) for i in ids]

t0 = time.time()
dd  = DerefDict(items)
t1 = time.time()
print(t1-t0, 'seconds to build a deref dict')

In [None]:
%%timeit -n 5 -r 5
v = dd.get(planet)
# the difference is that you are doing len(planets) dict lookups instead of just one dict lookup.
# Can we keep the list literal around during processing instead?
# e.g. - most of the time we will want an entire list (simple lookup, no intersection). Detect that 
# case and we've got something as good as dict().

In [None]:
class DerefDictListy():
    def __init__(self, items):
        self.objs = []
        self.d = dict()
        for i in items:
            if i.s not in self.d:
                self.d[i.s] = list()
            self.d[i.s].append(len(self.objs))
            self.objs.append(obj)
    
    def get(self, val):
        ids = self.d.get(val)
        return itemgetter(*ids)(self.objs)

t0 = time.time()
ddl  = DerefDict(items)
t1 = time.time()
print(t1-t0, 'seconds to build a deref dict, listy edition')

In [None]:
%%timeit -n 5 -r 5
v = ddl.get(planet)
# the difference is that you are doing len(planets) dict lookups instead of just one dict lookup.
# Can we keep the list literal around during processing instead?
# e.g. - most of the time we will want an entire list (simple lookup, no intersection). Detect that 
# case and we've got something as good as dict().

In [None]:
a = list('abcdefg')
isect_ids = snp.intersect(np.array([2,3,6]), np.array([1,2,3,4,5]), indices=True)[1][1]
itemgetter(*isect_ids)(a)

In [None]:
v = hi.find({'s': planet})
v