In [1]:
import heapq
import torch
import bisect
import random
import collections

import numpy as np

In [2]:
# Random 400003 x 768 matrix; each row is one item to be ranked below.
embed = torch.rand(size=(400003, 768))
embed.shape

torch.Size([400003, 768])

In [3]:
# Random 768-vector; a row's score is its dot product with this vector
# (computed in Element.get).
other = torch.rand(size=(768,))
other.shape

torch.Size([768])

In [4]:
class Element:
    """A lazily scored row of a matrix that orders itself by its score.

    The score is the dot product of ``other`` with row ``ix`` of ``dists``.
    It is computed on the first call to ``get`` and cached until ``clear``
    is called. Ordering, equality and hashing are all based on the score.
    """

    def __init__(self, dists, ix, other):
        self._dists = dists
        self._ix = ix
        self._other = other
        self._cache = None

    def clear(self):
        """Forget the cached score so the next ``get`` recomputes it."""
        self._cache = None

    def get(self):
        """Return the score, computing and caching it on first use."""
        if self._cache is None:
            self._cache = self._other.dot(self._dists[self._ix, :].ravel()).item()
        return self._cache

    def index(self):
        """Return the row index this element was created with."""
        return self._ix

    def _cmp(self, other):
        """Signed score difference (kept for callers of the old helper).

        The rich comparisons below no longer use it: a difference turns into
        ``nan`` for two infinite scores, which made equal values compare as
        unequal.
        """
        return self.get() - other.get()

    # Each comparison returns NotImplemented for foreign types so that
    # e.g. ``elem == None`` evaluates to False instead of raising
    # AttributeError.

    def __lt__(self, other):
        if not isinstance(other, Element):
            return NotImplemented
        return self.get() < other.get()

    def __le__(self, other):
        if not isinstance(other, Element):
            return NotImplemented
        return self.get() <= other.get()

    def __eq__(self, other):
        if not isinstance(other, Element):
            return NotImplemented
        return self.get() == other.get()

    def __ne__(self, other):
        if not isinstance(other, Element):
            return NotImplemented
        return self.get() != other.get()

    def __ge__(self, other):
        if not isinstance(other, Element):
            return NotImplemented
        return self.get() >= other.get()

    def __gt__(self, other):
        if not isinstance(other, Element):
            return NotImplemented
        return self.get() > other.get()

    def __hash__(self):
        # Hash by score so that it stays consistent with __eq__.
        return hash(self.get())

    def __repr__(self):
        return f"{self._ix}[{self.get()}]"

In [5]:
# One lazily scored Element per row of `embed`, all scored against `other`.
elems = (
    # [Element(embed, ix, other) for ix in range(embed.shape[0])] +
    [Element(embed, ix, other) for ix in range(embed.shape[0])])
# Shuffle in place so the selection functions see the rows in random order.
random.shuffle(elems)
# Show the first two elements; the repr is "index[score]".
elems[0], elems[1]

(215020[200.6805419921875], 258108[201.31385803222656])

In [6]:
def clear_cache(elems):
    """Touch every element's score and hand back a shallow copy of ``elems``.

    Despite the name, nothing is cleared: ``get()`` is called on each element
    (for ``Element`` this computes and caches the score). The copy is taken
    with ``elems.copy()``, so the caller's container is not the one returned.
    """
    for entry in elems:
        entry.get()
    snapshot = elems.copy()
    return snapshot

In [7]:
# Call get() on every element once so the scores are cached before timing;
# `elems` is rebound to the shallow copy that clear_cache returns.
elems = clear_cache(elems)

In [8]:
def largest_sort(count, elems):
    """Return the ``count`` largest elements, largest first.

    Keeps a descending list of at most ``count`` candidates. An element that
    beats the smallest candidate replaces it, after which the list is
    re-sorted. Returns ``[]`` when ``count`` is not positive.
    """
    elems = clear_cache(elems)
    if count <= 0:
        # Nothing to keep; also avoids indexing an empty candidate list.
        return []

    candidates = []
    for cur in elems:
        if len(candidates) < count:
            candidates.append(cur)
        elif cur.get() > candidates[-1].get():
            # candidates[-1] is the smallest kept element.
            candidates[-1] = cur
        else:
            continue
        candidates.sort(reverse=True)
    return candidates

In [9]:
def largest_heap(count, elems):
    """Return the ``count`` largest elements, largest first, via a min-heap.

    The heap holds at most ``count`` items; once it is that large every new
    item is pushed and the smallest is popped in a single step.
    """
    pool = clear_cache(elems)

    smallest_first = []
    for item in pool:
        if len(smallest_first) >= count:
            # Heap is at capacity: push the item, discard the minimum.
            heapq.heappushpop(smallest_first, item)
            continue
        heapq.heappush(smallest_first, item)
    return heapq.nlargest(count, smallest_first)

In [10]:
def largest_bisect(count, elems):
    """Return the ``count`` largest elements, largest first.

    Candidates are kept in ascending order in a bounded deque and new
    elements are placed with ``bisect``. Returns ``[]`` when ``count`` is
    not positive.
    """
    elems = clear_cache(elems)
    if count <= 0:
        return []

    candidates = collections.deque(maxlen=count)
    for cur in elems:
        full = len(candidates) >= count
        # Only reject against the current minimum once the deque is full.
        # Rejecting earlier would drop elements that still belong in the
        # result (e.g. count=3 on [5, 3, 4] used to return just [5]).
        if full and cur.get() <= candidates[0].get():
            continue
        pos = bisect.bisect_left(candidates, cur)
        if full:
            # cur is strictly larger than candidates[0], so pos >= 1 here.
            candidates.popleft()
            pos -= 1
        candidates.insert(pos, cur)
    return list(candidates)[::-1]

In [11]:
def largest_bisect_replace(count, elems):
    """Return the ``count`` largest elements, largest first.

    Keeps an ascending list of at most ``count`` candidates. An element that
    beats the smallest candidate (index 0) overwrites it and the list is
    re-sorted. Returns ``[]`` when ``count`` is not positive.
    """
    elems = clear_cache(elems)
    if count <= 0:
        # Avoids indexing candidates[0] on an empty list.
        return []

    candidates = []
    for cur in elems:
        if len(candidates) < count:
            candidates.append(cur)
        elif cur.get() > candidates[0].get():
            candidates[0] = cur
        else:
            continue
        candidates.sort()
    # Stored ascending; callers expect largest first.
    return candidates[::-1]

In [12]:
def largest_sort_all(count, elems):
    """Sort everything by score, descending, and keep the first ``count``."""
    ranked = sorted(clear_cache(elems), key=lambda item: item.get(), reverse=True)
    return ranked[:count]

In [13]:
def largest_recursive(count, elems):
    """Return the ``count`` largest elements, largest first, by divide and conquer.

    Ranges of at most ``2 * count`` elements are sorted directly. Larger
    ranges are split in half, and the top ``count`` of each half are merged
    and re-sorted. Returns ``[]`` when ``count`` is not positive.
    """
    elems = clear_cache(elems)
    if count <= 0:
        # With count == 0 the batch size would be 0 and a one-element range
        # could never shrink, so the recursion would not terminate.
        return []
    batch = count * 2

    def key(entry):
        return entry.get()

    def top(arr):
        """Sort ``arr`` by score, descending, and keep the first ``count``."""
        return sorted(arr, key=key, reverse=True)[:count]

    def inner(arr, start, end):
        size = end - start
        if size <= batch:
            return top(arr[start:end])
        mid = start + size // 2
        return top(inner(arr, start, mid) + inner(arr, mid, end))

    return inner(elems, 0, len(elems))

In [14]:
def largest_fast(count, elems):
    """Return the ``count`` largest elements, largest first, merging in batches.

    Elements that could still make the cut are collected in a pending batch.
    Whenever the batch reaches ``count`` items it is merged into the running
    best list, which is then sorted descending and truncated to ``count``.
    """
    pool = clear_cache(elems)

    pending = []
    best = []
    for item in pool:
        # best[-1] is the smallest element kept by the last merge.
        if best and best[-1] >= item:
            continue
        pending.append(item)
        if len(pending) >= count:
            best = sorted(best + pending, reverse=True)[:count]
            pending = []
    # Fold in whatever is left in the final, partially filled batch.
    return sorted(best + pending, reverse=True)[:count]

In [15]:
def largest_spec(count, elems):
    """Return the ``count`` largest elements, largest first.

    Same replace-the-smallest scheme as a descending candidate list, but the
    re-sort uses an explicit ``key`` (the cached score) instead of the
    elements' own comparison operators. Returns ``[]`` when ``count`` is not
    positive.
    """
    elems = clear_cache(elems)
    if count <= 0:
        # Avoids indexing candidates[-1] on an empty list.
        return []

    def score(entry):
        return entry.get()

    candidates = []
    for cur in elems:
        if len(candidates) < count:
            candidates.append(cur)
        elif score(cur) > score(candidates[-1]):
            candidates[-1] = cur
        else:
            continue
        candidates.sort(key=score, reverse=True)
    return candidates

In [16]:
def largest_manual(count, elems):
    """Return the ``count`` largest elements, largest first, via quickselect.

    The working copy is repeatedly partitioned around the middle element of
    the active range. The selection runs as a loop rather than recursion, so
    ``count <= 0`` simply yields ``[]`` instead of recursing without end.
    """
    elems = clear_cache(elems)

    def partition(arr, left, right):
        """Partition ``arr[left:right]`` in place; return the pivot's index.

        On return, everything left of the pivot is not greater than it and
        everything right of it is strictly greater.
        """
        pivot_ix = left + (right - left) // 2
        last_ix = right - 1
        # Park the pivot at the end of the range.
        arr[pivot_ix], arr[last_ix] = arr[last_ix], arr[pivot_ix]
        pivot_ix = last_ix
        ix = left
        while ix < pivot_ix:
            if arr[ix] > arr[pivot_ix]:
                # Move the larger element into the pivot's slot, shift the
                # pivot one step left, and bring the displaced (not yet
                # examined) element to position ix.
                arr[ix], arr[pivot_ix] = arr[pivot_ix], arr[ix]
                pivot_ix -= 1
                arr[ix], arr[pivot_ix] = arr[pivot_ix], arr[ix]
            else:
                ix += 1
        return pivot_ix

    result = []
    left, right, remain = 0, len(elems), count
    while remain > 0 and left < right:
        if right - left <= remain:
            # The whole range is needed.
            result.extend(sorted(elems[left:right], reverse=True))
            break
        pivot_ix = partition(elems, left, right)
        size = right - pivot_ix  # pivot plus everything greater than it
        if size > remain:
            # More than enough on the upper side: narrow down to it.
            left = pivot_ix
        else:
            # The upper side is taken entirely; continue below the pivot.
            result.extend(sorted(elems[pivot_ix:right], reverse=True))
            remain -= size
            right = pivot_ix
    return result

In [17]:
def largest_fast_swap(count, elems, use_manual):
    """Return the ``count`` largest elements, largest first, using a fixed buffer.

    Promising elements are written into a reusable buffer of ``count`` slots.
    Each time the buffer fills up it is merged into the result list:

    * ``use_manual=True``: each buffered element is inserted into the sorted
      result by shifting smaller entries down (the result stays ``count``
      long).
    * ``use_manual=False``: the buffer overwrites the tail of a
      ``2 * count`` result list, which is then re-sorted.

    The result starts out empty instead of being pre-filled with ``None``
    placeholders, so inputs with fewer than ``count`` elements (or none at
    all) no longer leak ``None`` into the returned list. Returns ``[]`` when
    ``count`` is not positive.
    """
    elems = clear_cache(elems)
    if count <= 0:
        return []

    def merge_manual(arr, candidates):
        arr.sort(reverse=True)
        if not candidates:
            candidates[:] = arr
            return
        # Insert from the smallest buffered element to the largest.
        for arr_ix in range(len(arr) - 1, -1, -1):
            item = arr[arr_ix]
            last = candidates[-1]
            can_ix = len(candidates) - 2
            # Shift entries smaller than item one slot down (overwriting
            # the current last entry).
            while can_ix >= 0 and candidates[can_ix] < item:
                candidates[can_ix + 1] = candidates[can_ix]
                can_ix -= 1
            if last < item:
                candidates[can_ix + 1] = item

    def merge_sort(arr, candidates):
        arr.sort(reverse=True)
        if not candidates:
            candidates[:] = arr
            return
        # Everything past the first `count` entries is discardable.
        candidates[count:] = arr
        candidates.sort(reverse=True)

    merge = merge_manual if use_manual else merge_sort
    buff = [None] * count
    res = []  # stays empty until the first merge
    buff_ix = 0
    for cur in elems:
        # A non-empty res inside the loop always came from a full buffer,
        # so it has at least `count` entries and res[count - 1] is the
        # current cut-off.
        if res and res[count - 1] > cur:
            continue
        buff[buff_ix] = cur
        buff_ix += 1
        if buff_ix < count:
            continue
        merge(buff, res)
        buff_ix = 0
    if buff_ix > 0:
        # Merge the partially filled buffer.
        merge(buff[:buff_ix], res)
    return res[:count]

In [18]:
def largest_heapq(count, elems):
    """Return the ``count`` largest elements by delegating to ``heapq.nlargest``."""
    return heapq.nlargest(count, clear_cache(elems))

In [19]:
def largest_loop_heapq(count, elems):
    """Return the ``count`` largest elements, merging batches with ``heapq.nlargest``.

    Elements not already ruled out by the current cut-off are gathered into a
    batch. Once the batch holds ``count`` items it is combined with the
    running best list and reduced again through ``heapq.nlargest``.
    """
    pool = clear_cache(elems)

    pending, best = [], []
    for item in pool:
        # best[-1] is the smallest element kept by the last reduction.
        if best and best[-1] >= item:
            continue
        pending.append(item)
        if len(pending) >= count:
            best = heapq.nlargest(count, best + pending)
            pending = []
    return heapq.nlargest(count, best + pending)

In [20]:
def largest_nocopy(count, elems):
    """Return the ``count`` largest elements, largest first, reusing one list.

    The first ``2 * count`` elements seed a descending candidate list. Its
    upper half holds the current best and its lower half is overwritten slot
    by slot with better elements. When the lower half is full the list is
    re-sorted and the cut-off refreshed.

    Returns ``[]`` when ``count`` is not positive. Inputs with fewer than
    ``count`` elements are simply returned sorted instead of raising
    IndexError.
    """
    elems = clear_cache(elems)
    if count <= 0:
        return []
    buff_size = count
    window = count + buff_size

    candidates = elems[:window]
    candidates.sort(reverse=True)
    if len(elems) <= window:
        # Every element is already in the candidate list, and
        # candidates[count - 1] might not exist.
        return candidates[:count]

    can_ix = count
    cur_cmp = candidates[count - 1]  # smallest element of the current best
    for cur_ix in range(window, len(elems)):
        cur = elems[cur_ix]
        if cur <= cur_cmp:
            continue
        candidates[can_ix] = cur
        can_ix += 1
        if can_ix < len(candidates):
            continue
        # Lower half is full: re-rank and start overwriting it again.
        can_ix = count
        candidates.sort(reverse=True)
        cur_cmp = candidates[count - 1]
    if can_ix > count:
        # Slots from can_ix on still hold entries discarded by the last
        # re-sort; drop them before the final ranking.
        candidates = candidates[:can_ix]
        candidates.sort(reverse=True)
    return candidates[:count]

In [21]:
def largest_bisect_2(count, elems):
    """Return the ``count`` largest elements, largest first.

    Candidates are kept in ascending order in a bounded deque and new
    elements are placed with ``bisect``. The smallest kept element is cached
    in ``cur_cmp`` so the rejection test does not index the deque. Returns
    ``[]`` when ``count`` is not positive.
    """
    elems = clear_cache(elems)
    if count <= 0:
        # A deque with maxlen=0 cannot accept the first insert.
        return []

    candidates = collections.deque(maxlen=count)
    cur_cmp = None  # smallest kept element; set after the first insert
    for cur in elems:
        full = len(candidates) >= count
        # Reject against the minimum only once the deque is full. Doing it
        # earlier drops elements that still belong in the result
        # (e.g. count=3 on [5, 3, 4] used to return just [5]).
        if full and cur <= cur_cmp:
            continue
        pos = bisect.bisect_left(candidates, cur)
        if full:
            # cur is strictly larger than the minimum, so pos >= 1 here.
            candidates.popleft()
            pos -= 1
        candidates.insert(pos, cur)
        cur_cmp = candidates[0]
    return list(candidates)[::-1]

In [22]:
def largest_numpy(count, elems):
    """Return the ``count`` largest elements, largest first, using numpy.

    Scores are gathered into an array. ``np.argpartition`` picks the indices
    of the ``count`` largest scores, and ``np.argsort`` on just those scores
    puts them in descending order.

    ``count`` is clamped to ``len(elems)``, and a non-positive ``count``
    yields ``[]``. Without this, ``count == 0`` selected every element
    (``[-0:]`` is the whole array) and an oversized ``count`` made
    ``argpartition`` raise.
    """
    elems = clear_cache(elems)
    count = min(count, len(elems))
    if count <= 0:
        return []

    arr = np.array([elem.get() for elem in elems])
    ind = np.argpartition(arr, -count)[-count:]
    return [elems[ix] for ix in ind[np.argsort(arr[ind])[::-1]]]

In [23]:
# Single-run timings (%%time) of every variant with count=10

In [24]:
%%time

largest_sort(10, elems)

CPU times: user 241 ms, sys: 10.6 ms, total: 252 ms
Wall time: 258 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [25]:
%%time

largest_heap(10, elems)

CPU times: user 265 ms, sys: 10.1 ms, total: 275 ms
Wall time: 290 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [26]:
%%time

largest_bisect(10, elems)

CPU times: user 207 ms, sys: 2.53 ms, total: 209 ms
Wall time: 209 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [27]:
%%time

largest_bisect_replace(10, elems)

CPU times: user 229 ms, sys: 3.57 ms, total: 233 ms
Wall time: 232 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [28]:
%%time

largest_recursive(10, elems)

CPU times: user 276 ms, sys: 4.05 ms, total: 280 ms
Wall time: 283 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [29]:
%%time

largest_fast(10, elems)

CPU times: user 227 ms, sys: 4.24 ms, total: 231 ms
Wall time: 234 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [30]:
%%time

largest_spec(10, elems)

CPU times: user 231 ms, sys: 3.09 ms, total: 234 ms
Wall time: 235 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [31]:
%%time

largest_manual(10, elems)

CPU times: user 685 ms, sys: 6.88 ms, total: 692 ms
Wall time: 693 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [32]:
%%time

largest_sort_all(10, elems)

CPU times: user 247 ms, sys: 5.14 ms, total: 252 ms
Wall time: 252 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [33]:
%%time

largest_fast_swap(10, elems, use_manual=False)

CPU times: user 232 ms, sys: 6.74 ms, total: 239 ms
Wall time: 240 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [34]:
%%time

largest_fast_swap(10, elems, use_manual=True)

CPU times: user 239 ms, sys: 5.06 ms, total: 244 ms
Wall time: 247 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [35]:
%%time

largest_heapq(10, elems)

CPU times: user 220 ms, sys: 3.97 ms, total: 224 ms
Wall time: 228 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [36]:
%%time

largest_loop_heapq(10, elems)

CPU times: user 220 ms, sys: 3.32 ms, total: 223 ms
Wall time: 223 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [37]:
%%time

largest_nocopy(10, elems)

CPU times: user 255 ms, sys: 3.71 ms, total: 259 ms
Wall time: 258 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [38]:
%%time

largest_bisect_2(10, elems)

CPU times: user 249 ms, sys: 4.46 ms, total: 254 ms
Wall time: 256 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [39]:
%%time

largest_numpy(10, elems)

CPU times: user 196 ms, sys: 4.06 ms, total: 200 ms
Wall time: 203 ms


[298074[223.73533630371094],
 84684[223.57077026367188],
 217015[222.9144744873047],
 185183[222.31655883789062],
 376916[222.1710968017578],
 15974[221.45262145996094],
 10551[221.3866424560547],
 293900[221.2071533203125],
 337383[221.08334350585938],
 357838[221.02090454101562]]

In [40]:
# Repeated timings (%%timeit) of every variant with count=10

In [41]:
%%timeit

largest_sort(10, elems)

222 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [42]:
%%timeit

largest_heap(10, elems)

247 ms ± 6.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [43]:
%%timeit

largest_bisect(10, elems)

203 ms ± 4.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [44]:
%%timeit

largest_bisect_replace(10, elems)

226 ms ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [45]:
%%timeit

largest_recursive(10, elems)

265 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [46]:
%%timeit

largest_fast(10, elems)

220 ms ± 3.96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [47]:
%%timeit

largest_spec(10, elems)

223 ms ± 2.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [48]:
%%timeit

largest_manual(10, elems)

697 ms ± 11.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [49]:
%%timeit

largest_sort_all(10, elems)

263 ms ± 72.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [50]:
%%timeit

largest_fast_swap(10, elems, use_manual=False)

251 ms ± 23.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [51]:
%%timeit

largest_fast_swap(10, elems, use_manual=True)

240 ms ± 6.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [52]:
%%timeit

largest_heapq(10, elems)

211 ms ± 2.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [53]:
%%timeit

largest_loop_heapq(10, elems)

232 ms ± 15.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [54]:
%%timeit

largest_nocopy(10, elems)

310 ms ± 52.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [55]:
%%timeit

largest_bisect_2(10, elems)

248 ms ± 8.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [56]:
%%timeit

largest_numpy(10, elems)

191 ms ± 5.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [57]:
# Larger data set: 1,000,000 rows of 768 values, scored against a fresh
# random 768-vector.
along = torch.rand(size=(1000000, 768))
olong = torch.rand(size=(768,))
elong = [Element(along, ix, olong) for ix in range(along.shape[0])]
# Call get() on every element once so scores are cached before timing.
elong = clear_cache(elong)

In [58]:
%%time

largest_numpy(10, elong)

CPU times: user 307 ms, sys: 87.8 ms, total: 395 ms
Wall time: 886 ms


[792112[213.24227905273438],
 602395[212.86349487304688],
 385839[212.81141662597656],
 673029[212.56121826171875],
 476611[212.2786407470703],
 715213[212.15164184570312],
 227910[211.85107421875],
 638866[211.81922912597656],
 218035[211.71377563476562],
 785145[211.5546112060547]]

In [59]:
%%time

largest_numpy(50, elong)

CPU times: user 291 ms, sys: 33.3 ms, total: 324 ms
Wall time: 368 ms


[792112[213.24227905273438],
 602395[212.86349487304688],
 385839[212.81141662597656],
 673029[212.56121826171875],
 476611[212.2786407470703],
 715213[212.15164184570312],
 227910[211.85107421875],
 638866[211.81922912597656],
 218035[211.71377563476562],
 785145[211.5546112060547],
 564156[211.40933227539062],
 784878[211.39935302734375],
 867588[211.39801025390625],
 498661[210.9582061767578],
 276785[210.75697326660156],
 106645[210.74668884277344],
 288094[210.57296752929688],
 974268[210.541748046875],
 445894[210.4384765625],
 336998[210.4301300048828],
 150145[210.41493225097656],
 409274[210.3786163330078],
 844095[210.130126953125],
 893539[210.0985107421875],
 216788[210.07603454589844],
 76082[210.06900024414062],
 218754[210.050537109375],
 178360[209.95828247070312],
 448742[209.88134765625],
 780556[209.84066772460938],
 315921[209.82101440429688],
 127181[209.75350952148438],
 365113[209.74267578125],
 414989[209.72650146484375],
 751845[209.71609497070312],
 550913[209