# Optimizing Scikit-Learn with Array API and Numba

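References:
- Array API support in scikit-learn: https://scikit-learn.org/stable/modules/array_api.html
- Linear and Quadratic Discriminant Analysis (user guide): https://scikit-learn.org/stable/modules/lda_qda.html
- Example comparing PCA and LDA 2D projections: https://scikit-learn.org/stable/auto_examples/decomposition/plot_pca_vs_lda.html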


In [12]:
from egglog.exp.array_api import *
from egglog.exp.array_api_numba import *
from egglog.exp.array_api_program_gen import *
from sklearn import config_context
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


def run_lda(x, y):
    """Fit a default-configured LDA on ``(x, y)`` and return ``x`` projected.

    The estimator is fitted and then used to transform the same ``x`` it was
    fitted on. ``x``/``y`` may be NumPy arrays or, when array API dispatch is
    enabled, symbolic array objects (as in the tracing cell of this notebook).
    """
    model = LinearDiscriminantAnalysis()
    fitted = model.fit(x, y)
    return fitted.transform(x)


# Synthetic benchmark data: 1,000,000 rows; random_state=0 keeps it reproducible.
X_np, y_np = make_classification(n_samples=1_000_000, random_state=0)
# Last expression of the cell, so the notebook displays the LDA projection.
run_lda(X_np, y_np)


array([[ 0.64233002],
       [ 0.63661245],
       [-1.603293  ],
       ...,
       [-1.1506433 ],
       [ 0.71687176],
       [-1.51119579]])

In [2]:
# %timeit run_lda(X_np, y_np)


In [3]:
# %%timeit
# with config_context(array_api_dispatch=True):
#     run_lda(X_np, y_np)


In [4]:
# Symbolic stand-ins for the real inputs: egglog NDArray variables that the
# LDA code is traced with (see the tracing cell) instead of NumPy arrays.
X_arr = NDArray.var("X")
# Copy taken *before* the assume_* calls below; later passed to
# ndarray_function_two as the placeholder argument for X.
# NOTE(review): this ordering suggests assume_* update X_arr in place — confirm.
# NOTE(review): `copy` is not imported explicitly in this notebook; presumably
# it comes from one of the egglog star imports — confirm.
X_orig = copy(X_arr)

# Facts about X taken from the concrete data (dtype, shape, finiteness),
# presumably so later rewrites can exploit them.
assume_dtype(X_arr, X_np.dtype)
assume_shape(X_arr, X_np.shape)
assume_isfinite(X_arr)

y_arr = NDArray.var("y")
# Pre-assumption copy of y, used as the placeholder argument for y.
y_orig = copy(y_arr)

assume_dtype(y_arr, y_np.dtype)
assume_shape(y_arr, y_np.shape)
# Labels are assumed to be only 0 or 1.
# NOTE(review): this relies on make_classification's default n_classes=2 above.
assume_value_one_of(y_arr, (0, 1))


In [13]:
# Trace scikit-learn's LDA on the symbolic inputs: with array API dispatch
# enabled, run_lda operates on the egglog NDArray objects, so X_r2 is an
# expression describing the computation rather than a concrete array.
with EGraph([array_api_module]) as egraph:
    with config_context(array_api_dispatch=True):
        X_r2 = run_lda(X_arr, y_arr)
# Optimize the traced expression in a fresh e-graph built from the
# numba-oriented module, then extract the resulting expression.
egraph = EGraph([array_api_numba_module])
egraph.register(X_r2)
# NOTE(review): 10000 is presumably an upper bound on rule iterations — confirm.
egraph.run(10000)
X_r2_optimized = egraph.extract(X_r2)


In [14]:
# Turn the optimized expression into a Python callable of (X, y), using the
# pre-assumption copies X_orig/y_orig as its argument placeholders.
egraph = EGraph([array_api_module_string])
fn_program = ndarray_function_two(X_r2_optimized, X_orig, y_orig)
egraph.register(fn_program)
egraph.run(10000)
# Extract the program's Python-object form and load it as a real callable.
fn = egraph.load_object(egraph.extract(fn_program.py_object))

# Sanity check: the generated function must agree numerically with sklearn.
# NOTE(review): `np` is not imported explicitly in this notebook; presumably it
# comes from one of the egglog star imports — confirm.
assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np))


In [7]:
# egraph.display(n_inline_leaves=2, split_primitive_outputs=True)


In [8]:
# %timeit fn(X_np, y_np)


In [15]:
from numba.core.imputils import lower_builtin, impl_ret_untracked
import operator
from numba.core import types
from numba.core.typing.templates import AbstractTemplate, signature, infer_global

from llvmlite import ir


@infer_global(operator.eq)
class DtypeEq(AbstractTemplate):
    """Numba typing template that lets ``dtype == dtype`` type-check.

    Only the case where both operands are ``types.DType`` is claimed, and the
    result is typed as a boolean. For any other operand types the template
    returns ``None``, i.e. it does not apply.
    """

    def generic(self, args, kws):
        lhs, rhs = args
        both_dtypes = isinstance(lhs, types.DType) and isinstance(rhs, types.DType)
        if not both_dtypes:
            return None
        return signature(types.boolean, lhs, rhs)


@lower_builtin(operator.eq, types.DType, types.DType)
def const_eq_impl(context, builder, sig, args):
    """Lower ``dtype == dtype`` to a compile-time 1-bit constant.

    The answer depends only on the operands' numba *types* (``sig.args``), not
    on their runtime values (``args`` is unused), so the comparison is folded
    into an ``i1`` constant while lowering.
    """
    lhs_ty, rhs_ty = sig.args
    same_dtype = lhs_ty == rhs_ty
    flag = ir.Constant(ir.IntType(1), int(bool(same_dtype)))
    return impl_ret_untracked(context, builder, sig.return_type, flag)


In [17]:
from numba import njit

# Wrap the generated function with numba's nopython-mode JIT. Without an
# explicit signature, compilation happens on the first call, which is the
# check below.
# NOTE(review): the DtypeEq/const_eq_impl extension above is presumably needed
# because the generated code compares dtypes — confirm.
fn_numba = njit(fn)
# The compiled function must agree numerically with sklearn.
assert np.allclose(run_lda(X_np, y_np), fn_numba(X_np, y_np))


In [18]:
# Time the numba-compiled generated function.
%timeit fn_numba(X_np, y_np)


1.12 s ± 36.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
# Baseline: scikit-learn LDA on NumPy arrays, outside config_context
# (array API dispatch not enabled).
%timeit run_lda(X_np, y_np)


2.74 s ± 49.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
%%timeit
# scikit-learn LDA on NumPy arrays with array API dispatch enabled.
with config_context(array_api_dispatch=True):
    run_lda(X_np, y_np)


1.48 s ± 21.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
# Time the generated function as plain Python (not numba-compiled).
%timeit fn(X_np, y_np)


1.56 s ± 17.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
