# Implementing an Array API to use with Scikit-learn


In this tutorial, we will create an object that implements the Array API and use it in the `LinearDiscriminantAnalysis` example that is in the [scikit-learn docs](https://scikit-learn.org/stable/modules/array_api.html).

First, let's try LDA with normal numpy arrays.

We take a set of input vectors and reduce the dimensionality to 2.

In [1]:
import matplotlib.pyplot as plt

from sklearn import datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn import config_context

# Load the Iris dataset; `fit_and_plot` also reads `iris.target_names` from
# this module-level object.
iris = datasets.load_iris()

# X is the dataset's `data` attribute and y its `target` attribute. Both are
# reused further down as the inputs that get wrapped for the e-graph.
X = iris.data
y = iris.target

def fit_and_plot(X, y):
    """Fit a 2-component LDA on ``(X, y)`` and scatter-plot the projection.

    Fitting and transforming both happen inside
    ``config_context(array_api_dispatch=True)``. The legend labels come from
    the module-level ``iris.target_names``.

    Args:
        X: Feature data handed to ``fit`` and ``transform``.
        y: Class labels handed to ``fit``; also compared against 0, 1 and 2
            to select the rows belonging to each class when plotting.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    target_names = iris.target_names
    with config_context(array_api_dispatch=True):
        lda = LinearDiscriminantAnalysis(n_components=2)
        X_r2 = lda.fit(X, y).transform(X)

    colors = ["navy", "turquoise", "darkorange"]

    # Open exactly one figure; a second plt.figure() call here would leave an
    # extra, empty figure behind on every invocation.
    plt.figure()
    for i, (color, target_name) in enumerate(zip(colors, target_names)):
        mask = y == i
        plt.scatter(
            X_r2[mask, 0], X_r2[mask, 1], alpha=0.8, color=color, label=target_name
        )
    plt.legend(loc="best", shadow=False, scatterpoints=1)
    plt.title("LDA of IRIS dataset")

    plt.show()
# fit_and_plot(X, y)

In [2]:
import torch

# fit_and_plot(torch.asarray(X), torch.asarray(y))

Now let's try to make an NDArray object that implements the Array API.

In [3]:
from __future__ import annotations
import sys
from typing import TypeVar, ClassVar

from egglog import *

# The single e-graph on which every class, function and rule below is
# registered.
egraph = EGraph()

# Type variable bounded by egglog's expression base class.
# NOTE(review): T is not referenced anywhere else in the visible source.
T = TypeVar("T", bound=BaseExpr)

# A separate ruleset named "runtime". Bool.__bool__ runs it explicitly in its
# second saturation pass, and the rewrites that supply concrete `ndim` values
# are attached to it.
runtime_ruleset = egraph.ruleset("runtime")

@egraph.class_
class Bool(BaseExpr):
    # Symbolic boolean expression. Converting an instance with bool() forces
    # evaluation through the e-graph (see __bool__).

    @egraph.method(preserve=True)
    def __bool__(self) -> bool:
        """Resolve this symbolic expression to a concrete Python ``bool``.

        Steps, in order:

        1. Register ``self`` on the e-graph and saturate ``run()`` (called
           with no ruleset argument).
        2. Extract the resulting expression and print it.
        3. Saturate the ``runtime_ruleset`` together with ``run()``.
        4. Extract ``to_py()`` of the expression from step 2 and load the
           Python object it refers to.

        Returns:
            The loaded Python ``bool``.

        Raises:
            AssertionError: If the loaded object is not a ``bool``.
        """
        egraph.register(self)
        egraph.run(run().saturate())
        final_object = egraph.extract(self)
        print(f"Extracting {final_object}")
        # The second pass brings in the "runtime" rules, which is where
        # concrete values are supplied.
        egraph.run((run(runtime_ruleset) + run()).saturate())
        py_object = egraph.extract(final_object.to_py())
        res = egraph.load_object(py_object)
        assert isinstance(res, bool)
        return res

    # Declaration only: maps a Bool expression to the PyObject holding its
    # Python value. Values are provided by facts registered on the e-graph.
    def to_py(self) -> PyObject:
        ...

# Allow plain Python bools to be converted to Bool expressions. The lambda
# body only runs when it is called, so naming TRUE/FALSE before they are
# assigned below is fine.
converter(bool, Bool, lambda x: TRUE if x else FALSE)

# The two Bool constants that the rewrite rules reduce to.
TRUE = egraph.constant("TRUE", Bool)
FALSE = egraph.constant("FALSE", Bool)

# Set to_py() of each constant to a saved Python object, so that
# Bool.__bool__ can load a real True/False back out of the e-graph.
egraph.register(
    set_(TRUE.to_py()).to(egraph.save_object(True)),
    set_(FALSE.to_py()).to(egraph.save_object(False)),
)

@egraph.class_
class DType(BaseExpr):
    # Expression class for array dtypes. It declares no methods; the values
    # used in this file are created with egraph.constant.
    ...

# The DType constants this notebook knows about.
float64 = egraph.constant("float64", DType)
float32 = egraph.constant("float32", DType)

# Declaration only: symbolic predicate taking a DType and a string `kind` and
# producing a Bool. Its results come from rewrites registered on the e-graph.
# NOTE(review): presumably mirrors the Array API's `isdtype` — confirm.
@egraph.function
def isdtype(dtype: DType, kind: StringLike) -> Bool:
    ...

# Neither float dtype is of kind "integral": rewrite both queries to FALSE.
egraph.register(
    rewrite(isdtype(float64, "integral")).to(FALSE),
    rewrite(isdtype(float32, "integral")).to(FALSE),
)

# Sanity check: bool() goes through Bool.__bool__, which must resolve the
# expression to a Python False.
assert not bool(isdtype(float64, "integral"))


@egraph.class_
class Int(BaseExpr):
    # Symbolic integer built from an i64-like value.
    def __init__(self, value: i64Like) -> None: ...

    # Comparisons are declarations that produce symbolic Bool expressions;
    # they are reduced to TRUE/FALSE by rules registered on the e-graph.
    def __eq__(self, other: Int) -> Bool: ...
    def __ge__(self, other: Int) -> Bool: ...

# Rules that evaluate `==` and `>=` between two Int expressions wrapping i64
# values `i` and `j`; `r` stands for the Bool result being matched.
@egraph.register
def _int_eq(i: i64, j: i64, r: Bool):
    # Equality: same wrapped value -> TRUE, differing values -> FALSE.
    yield rewrite(Int(i) == Int(i)).to(TRUE)
    yield rule(eq(r).to(Int(i) == Int(j)), i != j).then(union(r).with_(FALSE))

    # Greater-or-equal: same value or i > j -> TRUE, i < j -> FALSE.
    yield rewrite(Int(i) >= Int(i)).to(TRUE)
    yield rule(eq(r).to(Int(i) >= Int(j)), i > j).then(union(r).with_(TRUE))
    yield rule(eq(r).to(Int(i) >= Int(j)), i < j).then(union(r).with_(FALSE))

# Allow plain Python ints to be converted to Int expressions.
converter(int, Int, lambda x: Int(x))

# Sanity checks for the comparison rules: each comparison must simplify to the
# expected Bool constant (compared structurally via expr_parts).
assert expr_parts(egraph.simplify(Int(1) == Int(1), 10)) == expr_parts(TRUE)
assert expr_parts(egraph.simplify(Int(1) == Int(2), 10)) == expr_parts(FALSE)
assert expr_parts(egraph.simplify(Int(1) >= Int(2), 10)) == expr_parts(FALSE)
assert expr_parts(egraph.simplify(Int(1) >= Int(1), 10)) == expr_parts(TRUE)
assert expr_parts(egraph.simplify(Int(2) >= Int(1), 10)) == expr_parts(TRUE)

@egraph.class_
class NDArray(BaseExpr):
    # Symbolic array expression constructed from a PyObject reference.
    def __init__(self, py_array: PyObject) -> None: ...

    # Preserved as a regular Python method: returns the module this class is
    # defined in, so module-level functions such as `asarray` and `isdtype`
    # act as the array namespace. The `api_version` argument is ignored.
    @egraph.method(preserve=True)
    def __array_namespace__(self, api_version=None):
        return sys.modules[__name__]

    # Declaration only: number of dimensions as a symbolic Int. Concrete
    # values are supplied by rewrites registered on the e-graph.
    @property
    def ndim(self) -> Int: ...

@egraph.class_
class OptionalBool(BaseExpr):
    # Optional Bool expression: either the `none` class-level constant or a
    # Bool wrapped with `some`.
    none: ClassVar[OptionalBool]
    @classmethod
    def some(cls, value: Bool) -> OptionalBool: ...

# Python None converts to OptionalBool.none; a Python bool is first converted
# to Bool and then wrapped with OptionalBool.some.
converter(type(None), OptionalBool, lambda x: OptionalBool.none)
converter(bool, OptionalBool, lambda x: OptionalBool.some(convert(x, Bool)))

# Declaration only: symbolic `asarray` taking an array, a dtype and an
# optional `copy` flag, and producing a new NDArray expression.
# NOTE(review): all three parameters are required (no defaults). The recorded
# notebook output ends with "TypeError: missing a required argument: 'dtype'",
# presumably from a caller that omits `dtype` — confirm.
@egraph.function
def asarray(a: NDArray, dtype: DType, copy: OptionalBool) -> NDArray: ...

# Rules about `asarray`, for any array `a`, dtype `d` and optional bool `ob`.
@egraph.register
def _assarray(a: NDArray, d: DType, ob: OptionalBool):
    yield rewrite(asarray(a, d, ob).ndim).to(a.ndim) # asarray doesn't change ndim


# Save the concrete input arrays on the e-graph so they can be referred to
# from NDArray expressions.
X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)

# Add values for the constants
# These rewrites belong to the "runtime" ruleset, so the concrete `ndim`
# values only apply when that ruleset is run.
# NOTE(review): the parameters a, d and ob are not used in the body.
@egraph.register
def _ndim(a: NDArray, d: DType, ob: OptionalBool):
    yield rewrite(NDArray(X_obj).ndim, runtime_ruleset).to(Int(X.ndim))
    yield rewrite(NDArray(y_obj).ndim, runtime_ruleset).to(Int(y.ndim))


# Run the LDA example on the symbolic arrays instead of the raw inputs.
fit_and_plot(NDArray(X_obj), NDArray(y_obj))

Extracting FALSE
Extracting FALSE
Extracting NDArray(PyObject(5305820688)).ndim == Int(0)
Extracting FALSE
Extracting FALSE


TypeError: missing a required argument: 'dtype'