In [1]:
import numpy as np
from typing import Tuple
from numba import jit

In [7]:
def make_x(rng, n, p):
    return rng.choice([0, 1, 2, np.nan], size=(n, p))

In [8]:
@jit(nopython=True, parallel=True)
def cmp_no_expl_loop(xm: np.ndarray, xc: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    valid = ~(np.isnan(xm) | np.isnan(xc))
    xmv = xm[valid].astype(np.int_)
    xcv = xc[valid].astype(np.int_)

    m = np.zeros_like(xmv, dtype=np.int_)
    c = np.zeros_like(xcv, dtype=np.int_)

    m[xcv == 1] = 1 - xmv[xcv == 1]
    m *= -1

    c[xmv == 1] = 1 - xcv[xmv == 1]
    c *= -1
    
    return c, m

In [9]:
@jit(nopython=True, parallel=True)
def cmp_expl_loop(xm: np.ndarray, xc: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    c = np.zeros_like(xm, dtype=np.int_)
    m = np.zeros_like(xm, dtype=np.int_)
    write_ix = 0
    for read_ix in np.arange(len(xm)):
        if np.isnan(xm[read_ix]) | np.isnan(xc[read_ix]):
            continue
        curr_xc = xc[read_ix]
        curr_xm = xm[read_ix]
        if curr_xc == 0:
            c[write_ix] = -1
        elif curr_xc == 2:
            c[write_ix] = 1
        if curr_xm == 0:
            m[write_ix] = -1
        elif curr_xm == 2:
            m[write_ix] = 1
        write_ix += 1
    return c[:write_ix], m[:write_ix]

In [10]:
rng = np.random.default_rng()
n = 1000
p = 10000
xm = make_x(rng, n, p)
xc = make_x(rng, n, p)

c1, m1 = cmp_no_expl_loop(xm.flatten(), xc.flatten())
c2, m2 = cmp_expl_loop(xm.flatten(), xc.flatten())

In [11]:
%%timeit
c1, m1 = cmp_no_expl_loop(xm.flatten(), xc.flatten())

651 ms ± 39.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%%timeit
c2, m2 = cmp_expl_loop(xm.flatten(), xc.flatten())

466 ms ± 19.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
