In [1]:
import numpy as np
from typing import Tuple
from numba import jit

In [65]:
def make_x(rng, n, p):
    """Draw an ``n`` by ``p`` array of values sampled from {0, 1, 2, NaN}.

    Args:
        rng: Object providing ``choice(values, size=...)``; the notebook
            passes a NumPy ``Generator``.
        n: Number of rows in the result.
        p: Number of columns in the result.

    Returns:
        Whatever ``rng.choice`` returns for the requested ``(n, p)`` size.
    """
    levels = [0, 1, 2, np.nan]
    shape = (n, p)
    return rng.choice(levels, size=shape)

# Ordered pairs in which exactly one member equals 1 and the other is 0 or 2.
# `clean` tests (a, b) pairs for membership in this set.
v = {(0, 1), (1, 0), (1, 2), (2, 1)}

def clean(arr_a: np.ndarray, arr_b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Filter two sequences in lockstep, keeping selected pairs.

    A pair ``(a, b)`` is kept when either element is NaN, or when the pair
    is a member of the module-level set ``v``. Every other pair is dropped.
    Iteration stops at the end of the shorter input (``zip`` semantics).

    NOTE(review): the name suggests missing data would be removed, yet pairs
    containing NaN are retained here -- confirm this is the intended filter.

    Args:
        arr_a: First sequence of scalar values.
        arr_b: Second sequence of scalar values, paired element-wise with
            ``arr_a``.

    Returns:
        Two arrays holding the kept elements of ``arr_a`` and ``arr_b``
        respectively, in their original order.
    """

    def _keep(left, right):
        # NaN in either slot keeps the pair without consulting `v`.
        if np.isnan(left) or np.isnan(right):
            return True
        return (left, right) in v

    kept = [(left, right) for left, right in zip(arr_a, arr_b) if _keep(left, right)]
    kept_a = [left for left, _ in kept]
    kept_b = [right for _, right in kept]
    return np.asarray(kept_a), np.asarray(kept_b)

In [60]:
@jit(nopython=True, parallel=True)
def cmp_no_expl_loop(xm: np.ndarray, xc: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Recode paired values using whole-array (mask-based) operations.

    Positions where either input is NaN are dropped. For the remaining
    positions both outputs start at zero and are then filled as follows:

    * ``m`` receives ``xm - 1`` wherever ``xc == 1``;
    * ``c`` receives ``xc - 1`` wherever ``xm == 1``.

    NOTE(review): ``cmp_expl_loop`` recodes each input independently of the
    other one, so the two functions do not return the same values. Confirm
    which rule is the intended one.

    Args:
        xm: Float array in which NaN marks a missing entry. Assumed 1-D
            (the notebook passes flattened arrays).
        xc: Float array of the same shape as ``xm``.

    Returns:
        ``(c, m)``: two integer arrays with one entry per position at which
        both inputs are non-NaN.
    """
    valid = ~(np.isnan(xm) | np.isnan(xc))
    xmv = xm[valid].astype(np.int_)
    xcv = xc[valid].astype(np.int_)

    m = np.zeros_like(xmv, dtype=np.int_)
    c = np.zeros_like(xcv, dtype=np.int_)

    # Index with the boolean mask itself, not a list containing it: a list
    # wrapper is interpreted as an index array with an extra leading axis,
    # which does not select element-wise on a 1-D array.
    m_sel = xcv == 1
    m[m_sel] = xmv[m_sel] - 1

    c_sel = xmv == 1
    c[c_sel] = xcv[c_sel] - 1

    return c, m

In [61]:
@jit(nopython=True, parallel=True)
def cmp_expl_loop(xm: np.ndarray, xc: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Recode paired values with a single explicit pass over the inputs.

    Positions where either input is NaN are skipped. For every other
    position, each input is recoded on its own: 0 becomes -1, 2 becomes 1,
    and any other value becomes 0. Results are written contiguously, so the
    outputs contain one entry per non-skipped position.

    NOTE(review): ``cmp_no_expl_loop`` only writes a non-zero value when the
    *other* input equals 1, so the two functions do not return the same
    values. Confirm which rule is the intended one.

    Args:
        xm: Float array in which NaN marks a missing entry. Assumed 1-D
            (the notebook passes flattened arrays).
        xc: Float array with at least ``len(xm)`` entries.

    Returns:
        ``(c, m)``: integer arrays of equal length holding the recoded
        ``xc`` and ``xm`` values for the positions that were not skipped.
    """
    # Allocate for the worst case (nothing skipped); trimmed on return.
    c = np.zeros_like(xm, dtype=np.int_)
    m = np.zeros_like(xm, dtype=np.int_)
    write_ix = 0
    # Plain range: no need to materialise an index array just to iterate.
    for read_ix in range(len(xm)):
        curr_xm = xm[read_ix]
        curr_xc = xc[read_ix]
        if np.isnan(curr_xm) or np.isnan(curr_xc):
            continue
        if curr_xc == 0:
            c[write_ix] = -1
        elif curr_xc == 2:
            c[write_ix] = 1
        if curr_xm == 0:
            m[write_ix] = -1
        elif curr_xm == 2:
            m[write_ix] = 1
        write_ix += 1
    return c[:write_ix], m[:write_ix]

In [62]:
# Fixed seed so that "Restart Kernel -> Run All" regenerates the same inputs.
SEED = 0
rng = np.random.default_rng(SEED)
n = 1000
p = 10000

# Two independent n-by-p draws, flattened to 1-D before being handed to the
# cmp_* functions.
xm = make_x(rng, n, p).flatten()
xc = make_x(rng, n, p).flatten()

# NOTE(review): the two result pairs are never compared. The two
# implementations apply different recoding rules, so an equality check here
# would fail until that is resolved.
c1, m1 = cmp_no_expl_loop(xm, xc)
c2, m2 = cmp_expl_loop(xm, xc)

In [63]:
%%timeit
c1, m1 = cmp_no_expl_loop(xm.flatten(), xc.flatten())

600 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [64]:
%%timeit
c2, m2 = cmp_expl_loop(xm.flatten(), xc.flatten())

451 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
