In [8]:
import numpy as np
from numba import njit
from numba.typed import List
import heapq
import time
from tqdm.notebook import tqdm, trange

In [9]:
def non_jit(N):
    heap = [[0, 0], ]
    heapq.heapify(heap)
    for n in trange(N):
        heapq.heappush(heap, [n, n**2])
        if n%2==0:
            heapq.heappop(heap)

In [10]:
@njit
def with_jit(N):
    heap = [[0., 0.], ]
    heapq.heapify(heap)
    for n in range(N):
        heapq.heappush(heap, [float(n), float(n)**2])
        if n%2==0:
            heapq.heappop(heap)

In [15]:
def time_func(func):
    Ns = [100, 1_000, 10_000, 100_000, 1_000_000, 10_000_000]
    ts = []
    for N in Ns:
        t = time.time()
        func(N)
        ts.append(time.time() - t)
    print(ts)

In [16]:
time_func(non_jit)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

[0.04896879196166992, 0.04239010810852051, 0.05488181114196777, 0.3565330505371094, 2.676194667816162, 28.587865829467773]


In [18]:
time_func(with_jit)

[9.751319885253906e-05, 0.0007112026214599609, 0.05316877365112305, 0.09879612922668457, 1.1171901226043701, 11.232038497924805]
