In [1]:
import heapq
import torch
import bisect
import random
import collections

In [2]:
# Seed torch's RNG so the benchmark matrix (and every result below) is
# reproducible after Restart & Run All.
torch.manual_seed(0)
embed = torch.rand(size=(400003, 768))
embed.shape

torch.Size([400003, 768])

In [3]:
# Query vector; its length must equal embed's row width for the dot products.
other = torch.rand(768)
other.shape

torch.Size([768])

In [4]:
class Element:
    """Lazily computed dot-product score for one row of a matrix.

    The score is ``other.dot(dists[ix, :].ravel()).item()``. It is computed
    on the first ``get()`` and cached until ``clear()`` is called.
    Instances compare, order and hash by that score, so they can be fed
    directly to ``sorted``, ``heapq`` and ``bisect``. Comparing against a
    non-Element returns ``NotImplemented``, so e.g. ``elem == None`` is
    simply False instead of raising.
    """

    def __init__(self, dists, ix, other):
        # dists is indexed as dists[ix, :] (2-D); other is dotted with that row.
        self._dists = dists
        self._ix = ix
        self._other = other
        self._cache = None

    def clear(self):
        """Drop the cached score so the next ``get()`` recomputes it."""
        self._cache = None

    def get(self):
        """Return the score as a Python number, computing and caching it on first use."""
        if self._cache is None:
            self._cache = self._other.dot(self._dists[self._ix, :].ravel()).item()
        return self._cache

    def index(self):
        """Return the row index this element scores."""
        return self._ix

    # Rich comparisons compare scores directly (not via subtraction, which
    # would make inf == inf False because inf - inf is NaN). Non-Element
    # operands yield NotImplemented so Python can fall back gracefully.

    def __lt__(self, other):
        if not isinstance(other, Element):
            return NotImplemented
        return self.get() < other.get()

    def __le__(self, other):
        if not isinstance(other, Element):
            return NotImplemented
        return self.get() <= other.get()

    def __eq__(self, other):
        if not isinstance(other, Element):
            return NotImplemented
        return self.get() == other.get()

    def __ne__(self, other):
        if not isinstance(other, Element):
            return NotImplemented
        return self.get() != other.get()

    def __ge__(self, other):
        if not isinstance(other, Element):
            return NotImplemented
        return self.get() >= other.get()

    def __gt__(self, other):
        if not isinstance(other, Element):
            return NotImplemented
        return self.get() > other.get()

    def __hash__(self):
        # Consistent with __eq__: equal scores hash equal.
        return hash(self.get())

    def __repr__(self):
        return f"{self._ix}[{self.get()}]"

In [5]:
# One Element per embedding row, each scored against ``other``.
elems = [Element(embed, ix, other) for ix in range(embed.shape[0])]
# Seed ``random`` so the shuffled order is reproducible across kernel restarts.
random.seed(0)
random.shuffle(elems)
elems[0], elems[1]

(331624[207.1761016845703], 176049[209.71820068359375])

In [6]:
def clear_cache(elems, clear=False):
    """Warm every element's cached score and return a shallow copy of ``elems``.

    Despite the name, by default the cache is NOT cleared. Each element's
    ``get()`` is called so its score is computed and cached up front, which
    keeps that cost out of the selection loops that call this first. Pass
    ``clear=True`` to drop cached scores before recomputing them.

    Args:
        elems: list of Element-like objects (needs ``.copy()`` and iteration).
        clear: if True, call ``clear()`` on each element before ``get()``.

    Returns:
        ``elems.copy()``, so callers may reorder the result without mutating
        the input list.
    """
    for elem in elems:
        if clear:
            elem.clear()
        elem.get()
    return elems.copy()

In [7]:
# Pre-compute every score so the timed cells below measure selection only.
elems = clear_cache(elems=elems)

In [8]:
def largest_sort(count, elems):
    """Return the ``count`` largest elements, largest first.

    Keeps a descending list of the best candidates seen so far. When a new
    element beats the smallest kept one (``candidates[-1]``), it replaces it
    and the list is re-sorted.

    Args:
        count: number of elements to return; ``count <= 0`` yields ``[]``.
        elems: list of Elements (anything with ``get()`` and ordering).

    Returns:
        A list of at most ``count`` elements in descending order.
    """
    elems = clear_cache(elems)
    if count <= 0:
        # Without this guard ``candidates[-1]`` below raises IndexError.
        return []

    candidates = []
    for cur in elems:
        if len(candidates) < count:
            candidates.append(cur)
        elif cur.get() > candidates[-1].get():
            candidates[-1] = cur
        else:
            continue
        # Only reached when candidates changed.
        candidates.sort(reverse=True)
    return candidates

In [9]:
def largest_heap(count, elems):
    """Top-``count`` selection with a bounded min-heap.

    The heap root is always the smallest kept element. Once the heap holds
    ``count`` items, each new element is pushed and the smallest popped in a
    single ``heappushpop``. The survivors are returned largest first.
    """
    pool = clear_cache(elems)
    top = []
    for item in pool:
        if len(top) >= count:
            heapq.heappushpop(top, item)
        else:
            heapq.heappush(top, item)
    return heapq.nlargest(count, top)

In [10]:
def largest_bisect(count, elems):
    """Return the ``count`` largest elements, largest first, via sorted insertion.

    Keeps an ascending deque of at most ``count`` candidates. New elements
    are placed with ``bisect_left``. Once the deque is full, an element must
    beat the current minimum (``candidates[0]``), which is then evicted.

    Args:
        count: number of elements to return; ``count <= 0`` yields ``[]``.
        elems: list of Elements (anything with ``get()`` and ordering).

    Returns:
        A list of at most ``count`` elements in descending order.
    """
    elems = clear_cache(elems)
    if count <= 0:
        return []

    candidates = collections.deque(maxlen=count)
    for cur in elems:
        full = len(candidates) >= count
        # Only skip against the minimum once the deque is full; before that
        # every element must be kept (the old unconditional skip dropped them).
        if full and cur.get() <= candidates[0].get():
            continue
        pos = bisect.bisect_left(candidates, cur)
        if full:
            # cur > candidates[0] here, so pos >= 1 and pos - 1 is valid.
            candidates.popleft()
            pos -= 1
        candidates.insert(pos, cur)
    return list(candidates)[::-1]

In [11]:
def largest_bisect_replace(count, elems):
    """Return the ``count`` largest elements, largest first.

    Keeps an ascending list of the best candidates seen so far. A new element
    that beats the smallest (``candidates[0]``) replaces it and the list is
    re-sorted. Despite the name, no bisection is performed. The ascending
    list is reversed on return.

    Args:
        count: number of elements to return; ``count <= 0`` yields ``[]``.
        elems: list of Elements (anything with ``get()`` and ordering).

    Returns:
        A list of at most ``count`` elements in descending order.
    """
    elems = clear_cache(elems)
    if count <= 0:
        # Without this guard ``candidates[0]`` below raises IndexError.
        return []

    candidates = []
    for cur in elems:
        if len(candidates) < count:
            candidates.append(cur)
        elif cur.get() > candidates[0].get():
            candidates[0] = cur
        else:
            continue
        # Only reached when candidates changed.
        candidates.sort()
    return candidates[::-1]

In [12]:
def largest_sort_all(count, elems):
    """Baseline: fully sort a warmed copy by score (descending) and keep the first ``count``."""
    ranked = sorted(clear_cache(elems), key=lambda item: item.get(), reverse=True)
    return ranked[:count]

In [13]:
def largest_recursive(count, elems):
    """Return the ``count`` largest elements via divide and conquer.

    The range is halved until a piece has at most ``2 * count`` elements.
    Each piece is sorted descending by score and truncated to ``count``.
    Each merge step re-sorts the two kept halves and truncates again.

    Args:
        count: number of elements to return; ``count <= 0`` yields ``[]``.
        elems: list of Elements (anything with ``get()``).

    Returns:
        A list of at most ``count`` elements in descending order.
    """
    elems = clear_cache(elems)
    if count <= 0:
        # With batch <= 0 a one-element range splits into an empty half and an
        # identical copy of itself, recursing until RecursionError.
        return []
    batch = count * 2

    def key(entry):
        return entry.get()

    def top_sorted(arr):
        # Descending by score, truncated to the requested size.
        return sorted(arr, key=key, reverse=True)[:count]

    def inner(arr, start, end):
        size = end - start
        if size <= batch:
            return top_sorted(arr[start:end])
        mid = start + size // 2
        return top_sorted(inner(arr, start, mid) + inner(arr, mid, end))

    return inner(elems, 0, len(elems))

In [14]:
def largest_fast(count, elems):
    """Top-``count`` selection that batches promising elements before sorting.

    Elements not strictly better than the current ``count``-th best are
    skipped. The rest are buffered, and once ``count`` of them accumulate
    they are merged into the kept list with a single sort. Returns at most
    ``count`` elements in descending order.
    """
    pool = clear_cache(elems)
    batch_size = count
    best = []
    pending = []
    for item in pool:
        if best and best[-1] >= item:
            continue
        pending.append(item)
        if len(pending) >= batch_size:
            best = sorted(best + pending, reverse=True)[:count]
            pending = []
    return sorted(best + pending, reverse=True)[:count]

In [15]:
def largest_spec(count, elems):
    """Return the ``count`` largest elements, largest first.

    Same replace-the-smallest strategy as ``largest_sort``, but re-sorts with
    an explicit score key (``get()``) instead of the Element comparisons.

    Args:
        count: number of elements to return; ``count <= 0`` yields ``[]``.
        elems: list of Elements (anything with ``get()``).

    Returns:
        A list of at most ``count`` elements in descending order.
    """
    elems = clear_cache(elems)
    if count <= 0:
        # Without this guard ``candidates[-1]`` below raises IndexError.
        return []

    # Defined once instead of building a new lambda on every re-sort.
    def score(entry):
        return entry.get()

    candidates = []
    for cur in elems:
        if len(candidates) < count:
            candidates.append(cur)
        elif cur.get() > candidates[-1].get():
            candidates[-1] = cur
        else:
            continue
        candidates.sort(key=score, reverse=True)
    return candidates

In [16]:
def largest_manual(count, elems):
    """Return the ``count`` largest elements using a quickselect-style partition.

    The warmed copy of ``elems`` is partitioned in place around a middle
    pivot. Only the side(s) that can contain the answer are recursed into,
    and the final pieces are sorted descending.

    Args:
        count: number of elements to return; ``count <= 0`` yields ``[]``.
        elems: list of Elements (anything with ordering).

    Returns:
        A list of at most ``count`` elements in descending order.
    """
    elems = clear_cache(elems)
    if count <= 0:
        # choose() requires remain >= 1; with 0 a one-element range used to
        # recurse on identical arguments forever.
        return []

    def pivot(arr, left, right):
        # Middle index of the current range.
        return left + (right - left) // 2

    def choose(arr, left, right, remain):
        # Return the ``remain`` largest of arr[left:right] in descending order,
        # partitioning arr[left:right] in place. Requires remain >= 1.
        if right - left <= remain:
            return sorted(arr[left:right], reverse=True)
        pivot_ix = pivot(arr, left, right)
        last_ix = right - 1
        arr[pivot_ix], arr[last_ix] = arr[last_ix], arr[pivot_ix]
        pivot_ix = last_ix
        ix = left
        # Partition. Afterwards arr[left:pivot_ix] <= pivot, arr[pivot_ix] is
        # the pivot, and arr[pivot_ix + 1:right] > pivot.
        while ix < pivot_ix:
            if arr[ix] > arr[pivot_ix]:
                # Move the larger element past the pivot and shift the pivot
                # one slot left; the element pulled into ``ix`` is examined next.
                arr[ix], arr[pivot_ix] = arr[pivot_ix], arr[ix]
                pivot_ix -= 1
                arr[ix], arr[pivot_ix] = arr[pivot_ix], arr[ix]
            else:
                ix += 1
        size = right - pivot_ix  # the pivot plus everything greater than it
        if size == remain:
            return sorted(arr[pivot_ix:right], reverse=True)
        if size > remain:
            # size - 1 >= remain elements are strictly greater than the pivot,
            # so the answer lies entirely right of it. Excluding the pivot
            # guarantees the range shrinks on every call.
            return choose(arr, pivot_ix + 1, right, remain)
        # Everything from the pivot up is in the answer; fill the rest from the left.
        return (sorted(arr[pivot_ix:right], reverse=True)
                + choose(arr, left, pivot_ix, remain - size))

    return choose(elems, 0, len(elems), count)

In [17]:
def largest_fast_swap(count, elems, use_manual):
    """Top-``count`` selection with a fixed-size buffer merged into a result array.

    Elements that cannot beat the current ``count``-th best are skipped. The
    rest are written into a ``count``-slot buffer, which is merged into
    ``res`` whenever it fills, and once more at the end for a partial buffer.

    Args:
        count: number of elements to return; ``count <= 0`` yields ``[]``.
        elems: list of Elements (anything with ordering).
        use_manual: if True, merge by insertion into a ``count``-slot result
            (``merge_manual``). Otherwise append into the spare half of a
            ``2 * count``-slot result and re-sort (``merge_sort``).

    Returns:
        At most ``count`` elements in descending order. Fewer are returned
        when ``elems`` is shorter than ``count``; never ``None`` padding.
    """
    elems = clear_cache(elems)
    if count <= 0:
        # buff/res would be empty and ``res[0]`` below would raise IndexError.
        return []

    def merge_manual(arr, candidates):
        # Insert each element of arr (sorted descending) into the descending
        # ``candidates``, evicting its smallest entry each time.
        arr.sort(reverse=True)
        if candidates[0] is None:
            candidates[:] = arr
            return
        for arr_ix in range(len(arr) - 1, -1, -1):
            last = candidates[-1]
            can_ix = len(candidates) - 2
            while can_ix >= 0 and candidates[can_ix] < arr[arr_ix]:
                candidates[can_ix + 1] = candidates[can_ix]
                can_ix -= 1
            if last < arr[arr_ix]:
                candidates[can_ix + 1] = arr[arr_ix]
                arr[arr_ix] = last

    def merge_sort(arr, candidates):
        # The first merge fills the leading ``count`` slots. Later merges put
        # arr in the spare tail and re-sort so the best ``count`` lead.
        arr.sort(reverse=True)
        if candidates[0] is None:
            candidates[:count] = arr
            return
        candidates[count:] = arr
        candidates.sort(reverse=True)

    merge = merge_manual if use_manual else merge_sort
    buff = [None] * count
    res = [None] * (count if use_manual else count * 2)
    buff_ix = 0
    for cur in elems:
        # res[count - 1] is the current count-th best once res has been filled.
        if res[0] is not None and res[count - 1] > cur:
            continue
        buff[buff_ix] = cur
        buff_ix += 1
        if buff_ix < len(buff):
            continue
        merge(buff, res)
        buff_ix = 0
    if buff_ix > 0:
        merge(buff[:buff_ix], res)
    # With fewer than ``count`` inputs the unfilled slots are still None.
    return [elem for elem in res[:count] if elem is not None]

In [18]:
def largest_heapq(count, elems):
    """Delegate top-``count`` selection to ``heapq.nlargest`` on a warmed copy of ``elems``."""
    warmed = clear_cache(elems)
    top = heapq.nlargest(count, warmed)
    return top

In [19]:
def largest_loop_heapq(count, elems):
    """Batched top-``count`` selection that merges batches with ``heapq.nlargest``.

    Elements not strictly better than the current ``count``-th best are
    skipped. The rest are buffered, and every ``count`` of them are merged
    into the kept list by ``heapq.nlargest``. Returns at most ``count``
    elements in descending order.
    """
    pool = clear_cache(elems)
    batch_size = count
    best = []
    pending = []
    for item in pool:
        if best and best[-1] >= item:
            continue
        pending.append(item)
        if len(pending) >= batch_size:
            best = heapq.nlargest(count, best + pending)
            pending = []
    return heapq.nlargest(count, best + pending)

In [20]:
def largest_nocopy(count, elems):
    """Top-``count`` selection that reuses one fixed-size candidate list.

    ``candidates`` holds the current best ``count`` elements (sorted
    descending) followed by ``buff_size`` scratch slots. Elements strictly
    better than the current ``count``-th best are written into the scratch
    slots, and when those are full the whole list is re-sorted.

    Args:
        count: number of elements to return; ``count <= 0`` yields ``[]``.
        elems: list of Elements (anything with ordering).

    Returns:
        A list of at most ``count`` elements in descending order.
    """
    elems = clear_cache(elems)
    if count <= 0:
        return []
    buff_size = count

    candidates = elems[:count + buff_size]
    candidates.sort(reverse=True)
    if len(candidates) < count:
        # Fewer than ``count`` elements in total: everything is the answer, and
        # ``candidates[count - 1]`` below would raise IndexError.
        return candidates
    can_ix = count  # next free scratch slot
    cur_cmp = candidates[count - 1]  # current count-th best
    for cur_ix in range(count + buff_size, len(elems)):
        cur = elems[cur_ix]
        if cur <= cur_cmp:
            continue
        candidates[can_ix] = cur
        can_ix += 1
        if can_ix < len(candidates):
            continue
        can_ix = count
        candidates.sort(reverse=True)
        cur_cmp = candidates[count - 1]
    if can_ix > count:
        # Slots from can_ix onward were not refilled since the last sort; drop
        # them and sort the live part once more.
        candidates = candidates[:can_ix]
        candidates.sort(reverse=True)
    return candidates[:count]

In [21]:
def largest_bisect_2(count, elems):
    """Return the ``count`` largest elements via a pre-filled sorted deque.

    The deque starts as the first ``count`` elements in ascending order. Each
    later element that beats the minimum is inserted with ``bisect_left``
    after evicting that minimum.

    Args:
        count: number of elements to return; ``count <= 0`` yields ``[]``.
        elems: list of Elements (anything with ordering).

    Returns:
        A list of at most ``count`` elements in descending order.
    """
    elems = clear_cache(elems)
    if count <= 0 or not elems:
        # An empty deque would make ``candidates[0]`` raise IndexError.
        return []

    candidates = collections.deque(sorted(elems[:count]), maxlen=count)
    cur_cmp = candidates[0]  # current minimum of the kept elements
    for cur_ix in range(count, len(elems)):
        cur = elems[cur_ix]
        if cur <= cur_cmp:
            continue
        pos = bisect.bisect_left(candidates, cur)
        if pos == 0:
            continue
        candidates.popleft()
        candidates.insert(pos - 1, cur)
        cur_cmp = candidates[0]
    return list(candidates)[::-1]

In [22]:
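# Single-run timings: each cell below runs one implementation once under %%time and shows its top-10 result.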

In [23]:
%%time

# One timed run of largest_sort for the top 10 of elems; the list shown is its return value.
largest_sort(10, elems)

CPU times: user 233 ms, sys: 4.91 ms, total: 238 ms
Wall time: 240 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [24]:
%%time

# One timed run of largest_heap (bounded min-heap) for the top 10 of elems.
largest_heap(10, elems)

CPU times: user 237 ms, sys: 4.42 ms, total: 241 ms
Wall time: 242 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [25]:
%%time

# One timed run of largest_bisect (sorted deque + bisect_left) for the top 10 of elems.
largest_bisect(10, elems)

CPU times: user 203 ms, sys: 2.7 ms, total: 206 ms
Wall time: 206 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [26]:
%%time

# One timed run of largest_bisect_replace for the top 10 of elems.
largest_bisect_replace(10, elems)

CPU times: user 238 ms, sys: 4.37 ms, total: 242 ms
Wall time: 244 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [27]:
%%time

# One timed run of largest_recursive (divide and conquer) for the top 10 of elems.
largest_recursive(10, elems)

CPU times: user 259 ms, sys: 2.89 ms, total: 262 ms
Wall time: 262 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [28]:
%%time

# One timed run of largest_fast (batched sort) for the top 10 of elems.
largest_fast(10, elems)

CPU times: user 215 ms, sys: 3.47 ms, total: 219 ms
Wall time: 219 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [29]:
%%time

# One timed run of largest_spec (key-based re-sort) for the top 10 of elems.
largest_spec(10, elems)

CPU times: user 230 ms, sys: 4.14 ms, total: 234 ms
Wall time: 234 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [30]:
%%time

# One timed run of largest_manual (in-place partition) for the top 10 of elems.
largest_manual(10, elems)

CPU times: user 887 ms, sys: 8.8 ms, total: 896 ms
Wall time: 898 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [31]:
%%time

# One timed run of the full-sort baseline largest_sort_all for the top 10 of elems.
largest_sort_all(10, elems)

CPU times: user 247 ms, sys: 4.97 ms, total: 252 ms
Wall time: 253 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [32]:
%%time

# One timed run of largest_fast_swap merging via merge_sort (use_manual=False).
largest_fast_swap(10, elems, use_manual=False)

CPU times: user 224 ms, sys: 3.23 ms, total: 227 ms
Wall time: 227 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [33]:
%%time

# One timed run of largest_fast_swap merging via merge_manual (use_manual=True).
largest_fast_swap(10, elems, use_manual=True)

CPU times: user 229 ms, sys: 3.09 ms, total: 232 ms
Wall time: 232 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [34]:
%%time

# One timed run of largest_heapq (heapq.nlargest) for the top 10 of elems.
largest_heapq(10, elems)

CPU times: user 223 ms, sys: 5.48 ms, total: 228 ms
Wall time: 233 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [35]:
%%time

# One timed run of largest_loop_heapq (batched heapq.nlargest) for the top 10 of elems.
largest_loop_heapq(10, elems)

CPU times: user 221 ms, sys: 4.67 ms, total: 226 ms
Wall time: 230 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [36]:
%%time

# One timed run of largest_nocopy (reused candidate list) for the top 10 of elems.
largest_nocopy(10, elems)

CPU times: user 244 ms, sys: 3.39 ms, total: 247 ms
Wall time: 247 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [37]:
%%time

# One timed run of largest_bisect_2 (pre-filled sorted deque) for the top 10 of elems.
largest_bisect_2(10, elems)

CPU times: user 237 ms, sys: 2.42 ms, total: 239 ms
Wall time: 239 ms


[392627[223.7050323486328],
 299264[223.30430603027344],
 125886[222.54617309570312],
 182242[222.53273010253906],
 22261[222.20208740234375],
 47451[221.66619873046875],
 381434[221.57395935058594],
 346655[221.5247802734375],
 31258[221.4044189453125],
 805[221.12020874023438]]

In [38]:
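# Repeated timings: each cell below benchmarks one implementation with %%timeit (mean ± std. dev. over runs).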

In [39]:
%%timeit

# Repeated timing of largest_sort for the top 10 of elems (timing only, no result shown).
largest_sort(10, elems)

214 ms ± 11.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [40]:
%%timeit

# Repeated timing of largest_heap for the top 10 of elems.
largest_heap(10, elems)

232 ms ± 9.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [41]:
%%timeit

# Repeated timing of largest_bisect for the top 10 of elems.
largest_bisect(10, elems)

189 ms ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [42]:
%%timeit

# Repeated timing of largest_bisect_replace for the top 10 of elems.
largest_bisect_replace(10, elems)

214 ms ± 2.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [43]:
%%timeit

# Repeated timing of largest_recursive for the top 10 of elems.
largest_recursive(10, elems)

259 ms ± 9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [44]:
%%timeit

# Repeated timing of largest_fast for the top 10 of elems.
largest_fast(10, elems)

211 ms ± 5.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [45]:
%%timeit

# Repeated timing of largest_spec for the top 10 of elems.
largest_spec(10, elems)

210 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [46]:
%%timeit

# Repeated timing of largest_manual for the top 10 of elems.
largest_manual(10, elems)

861 ms ± 5.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [47]:
%%timeit

# Repeated timing of the full-sort baseline largest_sort_all for the top 10 of elems.
largest_sort_all(10, elems)

226 ms ± 5.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [48]:
%%timeit

# Repeated timing of largest_fast_swap with merge_sort (use_manual=False).
largest_fast_swap(10, elems, use_manual=False)

214 ms ± 2.55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [49]:
%%timeit

# Repeated timing of largest_fast_swap with merge_manual (use_manual=True).
largest_fast_swap(10, elems, use_manual=True)

236 ms ± 34.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [50]:
%%timeit

# Repeated timing of largest_heapq for the top 10 of elems.
largest_heapq(10, elems)

208 ms ± 12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [51]:
%%timeit

# Repeated timing of largest_loop_heapq for the top 10 of elems.
largest_loop_heapq(10, elems)

205 ms ± 4.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [52]:
%%timeit

# Repeated timing of largest_nocopy for the top 10 of elems.
largest_nocopy(10, elems)

235 ms ± 2.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [53]:
%%timeit

# Repeated timing of largest_bisect_2 for the top 10 of elems.
largest_bisect_2(10, elems)

235 ms ± 1.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [54]:
# along = torch.rand(size=(1000000, 768))
# olong = torch.rand(size=(768,))
# elong = [Element(along, ix, olong) for ix in range(along.shape[0])]
# elong = clear_cache(elong)

In [55]:
# %%time

# largest_heapq(10, elong)

In [56]:
# %%time

# largest_heapq(50, elong)