In [13]:
import numpy as np
from numba import njit
from numba.typed import List
import heapq
import time
from tqdm.notebook import tqdm, trange

In [48]:
def non_jit(N):
    heap = [[0, 0], ]
    heapq.heapify(heap)
    for n in trange(N):
        heapq.heappush(heap, [n, n**2])
        if n%2==0:
            heapq.heappop(heap)

In [47]:
@njit
def with_jit(N):
    heap = [[0., 0.], ]
    heapq.heapify(heap)
    for n in range(N):
        heapq.heappush(heap, [float(n), float(n)**2])
        if n%2==0:
            heapq.heappop(heap)

In [42]:
def time_func(func):
    Ns = [100, 1_000, 10_000, 100_000, 1_000_000]
    ts = []
    for N in Ns:
        t = time.time()
        func(N)
        ts.append(time.time() - t)
    print(ts)

In [49]:
time_func(non_jit)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

[0.03206920623779297, 0.03176307678222656, 0.045729637145996094, 0.3442039489746094, 2.4755916595458984]


In [51]:
time_func(with_jit)

[8.845329284667969e-05, 0.0009832382202148438, 0.010970354080200195, 0.12818694114685059, 1.4810428619384766]
