# Implementing an Array API to use with Scikit-learn


In this tutorial, we will create an object that implements the Array API and use it in the `LinearDiscriminantAnalysis` example that is in the [scikit-learn docs](https://scikit-learn.org/stable/modules/array_api.html).

First, let's try LDA with normal numpy arrays.

We take a set of input vectors and reduce the dimensionality to 2.

In [1]:
import matplotlib.pyplot as plt

from sklearn import datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn import config_context

# Load the iris dataset; `fit_and_plot` below also reads `iris.target_names` from it.
iris = datasets.load_iris()

# Feature data and class labels taken from the dataset.
X = iris.data
y = iris.target

def fit_and_plot(X, y):
    """Fit a 2-component LDA on ``(X, y)`` and scatter-plot the projection.

    The fit and transform run with scikit-learn's ``array_api_dispatch``
    enabled, so ``X`` and ``y`` may be any array objects that scikit-learn
    accepts in that mode.

    Args:
        X: Feature array passed to ``fit`` and ``transform``.
        y: Class labels; compared against 0, 1 and 2 to select the points
            of each class for plotting.

    The class names for the legend come from the module-level ``iris``
    dataset. The function shows the plot and returns nothing.
    """
    target_names = iris.target_names
    with config_context(array_api_dispatch=True):
        lda = LinearDiscriminantAnalysis(n_components=2)
        X_r2 = lda.fit(X, y).transform(X)

    colors = ["navy", "turquoise", "darkorange"]

    # Open exactly one figure so that no empty extra figure is displayed.
    plt.figure()
    for color, i, target_name in zip(colors, [0, 1, 2], target_names):
        plt.scatter(
            X_r2[y == i, 0], X_r2[y == i, 1], alpha=0.8, color=color, label=target_name
        )
    plt.legend(loc="best", shadow=False, scatterpoints=1)
    plt.title("LDA of IRIS dataset")

    plt.show()
# fit_and_plot(X, y)

In [2]:
import torch

# fit_and_plot(torch.asarray(X), torch.asarray(y))

Now let's try to make an NDArray object that implements the Array API.

In [6]:
from __future__ import annotations
import sys
from typing import TypeVar, ClassVar
import itertools
import numpy as np

from egglog import *

# The single e-graph on which every class, function and rule below is registered.
egraph = EGraph()

# Type variable bounded by egglog expressions (not used elsewhere in this cell).
T = TypeVar("T", bound=BaseExpr)

# Separate named ruleset; Bool.__bool__ runs it explicitly after the default rules,
# and the rewrites for the concrete input arrays are registered into it.
runtime_ruleset = egraph.ruleset("runtime")


@egraph.class_
class Bool(BaseExpr):
    """Symbolic boolean expression that can be resolved to a Python ``bool``."""

    @egraph.method(preserve=True)
    def __bool__(self) -> bool:
        """Resolve this expression to a Python ``bool`` by running the e-graph.

        Registers the expression, saturates the rules, prints the original
        and the extracted expression, then runs the runtime ruleset and loads
        the Python object stored for the extracted expression's ``to_py()``.

        Returns:
            The Python ``bool`` stored for the simplified expression.

        Raises:
            AssertionError: If the loaded object is not exactly a ``bool``.
        """
        egraph.register(self)
        # First pass: `run()` without a named ruleset, repeated until saturation.
        egraph.run(run().saturate())
        final_object = egraph.extract(self)
        print(f"{self} -> {final_object}")
        # Second pass: the runtime ruleset followed by `run()`, repeated until saturation.
        egraph.run((run(runtime_ruleset) + run()).saturate())
        # Read back the Python object that `to_py` was set to for the extracted expression.
        res = egraph.load_object(egraph.extract(final_object.to_py()))
        assert type(res) == bool
        return res

    def to_py(self) -> PyObject:
        """Python object associated with this boolean (declaration only)."""
        ...

    def __or__(self, other: Bool) -> Bool:
        """Symbolic logical OR (declaration only)."""
        ...

# Convert Python bools to the matching constant. TRUE and FALSE are looked up
# when the lambda is called, so defining them after this line is fine.
converter(bool, Bool, lambda x: TRUE if x else FALSE)

# The two Bool constants.
TRUE = egraph.constant("TRUE", Bool)
FALSE = egraph.constant("FALSE", Bool)

egraph.register(
    # Python values behind the two constants; Bool.__bool__ reads them back via to_py().
    set_(TRUE.to_py()).to(egraph.save_object(True)),
    set_(FALSE.to_py()).to(egraph.save_object(False)),
    # Truth table for `|` over the two constants.
    rewrite(TRUE | FALSE).to(TRUE),
    rewrite(FALSE | TRUE).to(TRUE),
    rewrite(FALSE | FALSE).to(FALSE),
    rewrite(TRUE | TRUE).to(TRUE),
)


@egraph.class_
class DType(BaseExpr):
    """Symbolic array dtype; the known dtypes are the class-level constants."""

    # One constant per dtype handled in this notebook.
    float64: ClassVar[DType]
    float32: ClassVar[DType]
    int64: ClassVar[DType]
    object: ClassVar[DType]

    def __eq__(self, other: DType) -> Bool:
        """Symbolic equality; returns a ``Bool`` expression, not a Python bool."""
        ...


# Module-level aliases for three of the dtype constants.
float64 = DType.float64
float32 = DType.float32
int64 = DType.int64

# A Python/NumPy type is first turned into an np.dtype and then converted again.
converter(type, DType, lambda x: convert(np.dtype(x), DType))
# Keyed on type(np.dtype); the DType constant is looked up by the value's `.name`.
converter(type(np.dtype), DType, lambda x: getattr(DType, x.name))
# Equality rules for every ordered pair of dtype constants. `l == r` builds a
# symbolic Bool, so the actual comparison is done in Python on expr_parts().
egraph.register(
    *(
        rewrite(l == r).to(TRUE if expr_parts(l) == expr_parts(r) else FALSE)
        for l, r in itertools.product([DType.float64, DType.float32, DType.object, DType.int64], repeat=2)
    )
)


@egraph.class_
class IsDtypeKind(BaseExpr):
    """Symbolic ``kind`` argument of ``isdtype``."""

    # Empty kind; the tuple converter uses it for an empty tuple.
    NULL: ClassVar[IsDtypeKind]

    @classmethod
    def string(cls, s: StringLike) -> IsDtypeKind:
        """Kind given as a string (declaration only)."""
        ...

    @classmethod
    def dtype(cls, d: DType) -> IsDtypeKind:
        """Kind given as a specific dtype (declaration only)."""
        ...

    def __or__(self, other: IsDtypeKind) -> IsDtypeKind:
        """Combine two kinds (declaration only)."""
        ...


# TODO: Make kind more generic to support tuples.
@egraph.function
def isdtype(dtype: DType, kind: IsDtypeKind) -> Bool:
    """Symbolic test of whether ``dtype`` matches ``kind`` (declaration only).

    The result is a ``Bool`` expression; it is resolved by rewrite rules
    registered separately on the e-graph.
    """
    ...


# Conversions into IsDtypeKind: dtypes are wrapped with IsDtypeKind.dtype,
# strings with IsDtypeKind.string.
converter(np.dtype, IsDtypeKind, lambda x: IsDtypeKind.dtype(convert(x, DType)))
converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
converter(
    # A tuple becomes `first | rest`, recursing on the tail; the empty tuple is NULL.
    tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL
)

@egraph.register
def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind):
    """Rewrite rules that resolve ``isdtype`` expressions to ``TRUE``/``FALSE``.

    Args:
        d: Pattern variable matching any dtype.
        k1: Pattern variable matching any kind.
        k2: Pattern variable matching any kind.

    Returns:
        The list of rewrite rules to register on the e-graph.
    """
    # NOTE(review): "integral" is the only kind string with rules here; other
    # kind strings are not resolved by this block.
    return [
        # Which of the known dtypes count as "integral".
        rewrite(isdtype(DType.float32, IsDtypeKind.string("integral"))).to(FALSE),
        rewrite(isdtype(DType.float64, IsDtypeKind.string("integral"))).to(FALSE),
        rewrite(isdtype(DType.object, IsDtypeKind.string("integral"))).to(FALSE),
        rewrite(isdtype(DType.int64, IsDtypeKind.string("integral"))).to(TRUE),
        # The empty kind matches nothing; a dtype kind matches that same dtype.
        rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE),
        rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE),
        # A combined kind matches when either side matches.
        rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2)),
        # NULL on the right of `|` is dropped.
        rewrite(k1 | IsDtypeKind.NULL).to(k1),
    ]


# Sanity check: bool() goes through Bool.__bool__, which runs the e-graph.
assert not bool(isdtype(DType.float32, IsDtypeKind.string("integral")))


@egraph.class_
class Int(BaseExpr):
    """Symbolic integer built from an ``i64``-like value."""

    def __init__(self, value: i64Like) -> None:
        """Wrap ``value`` as an ``Int`` expression (declaration only)."""
        ...

    def __eq__(self, other: Int) -> Bool:
        """Symbolic equality; returns a ``Bool`` expression (declaration only)."""
        ...

    def __ge__(self, other: Int) -> Bool:
        """Symbolic ``>=``; returns a ``Bool`` expression (declaration only)."""
        ...


@egraph.register
def _int_eq(i: i64, j: i64, r: Bool):
    """Rules that evaluate ``==`` and ``>=`` between two ``Int`` literals.

    ``i`` and ``j`` match the wrapped ``i64`` values; ``r`` matches the
    ``Bool`` expression produced by the comparison.
    """
    # The two comparison patterns that the conditional rules match on.
    equal_expr = Int(i) == Int(j)
    at_least_expr = Int(i) >= Int(j)

    return [
        # Equality: identical literals are equal, differing ones are not.
        rewrite(Int(i) == Int(i)).to(TRUE),
        rule(eq(r).to(equal_expr), i != j).then(union(r).with_(FALSE)),
        # Greater-or-equal: decided by comparing the underlying i64 values.
        rewrite(Int(i) >= Int(i)).to(TRUE),
        rule(eq(r).to(at_least_expr), i > j).then(union(r).with_(TRUE)),
        rule(eq(r).to(at_least_expr), i < j).then(union(r).with_(FALSE)),
    ]


# Python ints convert to Int literals.
converter(int, Int, lambda x: Int(x))

# Sanity checks for the Int comparison rules. Results are compared through
# expr_parts() because `==` on these expressions builds a symbolic Bool.
assert expr_parts(egraph.simplify(Int(1) == Int(1), 10)) == expr_parts(TRUE)
assert expr_parts(egraph.simplify(Int(1) == Int(2), 10)) == expr_parts(FALSE)
assert expr_parts(egraph.simplify(Int(1) >= Int(2), 10)) == expr_parts(FALSE)
assert expr_parts(egraph.simplify(Int(1) >= Int(1), 10)) == expr_parts(TRUE)
assert expr_parts(egraph.simplify(Int(2) >= Int(1), 10)) == expr_parts(TRUE)


@egraph.class_
class NDArray(BaseExpr):
    """Symbolic array expression wrapping a saved Python object."""

    def __init__(self, py_array: PyObject) -> None:
        """Wrap ``py_array`` as an ``NDArray`` expression (declaration only)."""
        ...

    @egraph.method(preserve=True)
    def __array_namespace__(self, api_version=None):
        """Return this module as the array namespace; ``api_version`` is ignored."""
        return sys.modules[__name__]

    @property
    def ndim(self) -> Int:
        """Symbolic number of dimensions (declaration only)."""
        ...

    @property
    def dtype(self) -> DType:
        """Symbolic dtype of the array (declaration only)."""
        ...


@egraph.class_
class OptionalBool(BaseExpr):
    """Optional ``Bool``: either ``none`` or ``some(value)``."""

    # The absent value.
    none: ClassVar[OptionalBool]

    @classmethod
    def some(cls, value: Bool) -> OptionalBool:
        """Wrap a present ``Bool`` value (declaration only)."""
        ...


# None maps to OptionalBool.none; Bool expressions and Python bools are wrapped in some().
converter(type(None), OptionalBool, lambda x: OptionalBool.none)
converter(Bool, OptionalBool, lambda x: OptionalBool.some(x))
converter(bool, OptionalBool, lambda x: OptionalBool.some(convert(x, Bool)))


@egraph.class_
class OptionalDType(BaseExpr):
    """Optional ``DType``: either ``none`` or ``some(value)``."""

    # The absent value.
    none: ClassVar[OptionalDType]

    @classmethod
    def some(cls, value: DType) -> OptionalDType:
        """Wrap a present ``DType`` value (declaration only)."""
        ...


# None maps to OptionalDType.none; DType expressions are wrapped in some().
converter(type(None), OptionalDType, lambda x: OptionalDType.none)
converter(DType, OptionalDType, lambda x: OptionalDType.some(x))


@egraph.function
def asarray(a: NDArray, dtype: OptionalDType = OptionalDType.none, copy: OptionalBool = OptionalBool.none) -> NDArray:
    """Symbolic ``asarray`` call (declaration only).

    ``dtype`` and ``copy`` default to the respective ``none`` values.
    """
    ...


@egraph.register
def _assarray(a: NDArray, d: OptionalDType, ob: OptionalBool):
    """Simplification rules for ``asarray``.

    ``a``, ``d`` and ``ob`` are pattern variables for the array, the
    optional dtype and the optional copy flag.
    """
    return [
        # The number of dimensions is the same whatever dtype/copy are passed.
        rewrite(asarray(a, d, ob).ndim).to(a.ndim),
        # With both optional arguments left at their defaults, asarray is the array itself.
        rewrite(asarray(a)).to(a),
    ]


# Save the real arrays on the e-graph so they can be referenced from expressions.
X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)

# Add values for the constants: ndim and dtype of the two wrapped arrays are
# taken from the real arrays. These rewrites go into the runtime ruleset, which
# Bool.__bool__ runs only in its second pass.
egraph.register(
    rewrite(NDArray(X_obj).ndim, runtime_ruleset).to(Int(X.ndim)),
    rewrite(NDArray(y_obj).ndim, runtime_ruleset).to(Int(y.ndim)),
    rewrite(NDArray(X_obj).dtype, runtime_ruleset).to(convert(X.dtype, DType)),
    rewrite(NDArray(y_obj).dtype, runtime_ruleset).to(convert(y.dtype, DType)),
)


# Run the LDA example on the symbolic arrays instead of the NumPy ones.
fit_and_plot(NDArray(X_obj), NDArray(y_obj))

isdtype(DType.float32, IsDtypeKind.string("integral")) -> FALSE
DType.float64 == NDArray(PyObject(5326120464)).dtype -> DType.float64 == NDArray(PyObject(5326120464)).dtype
asarray(NDArray(PyObject(5326120464)), OptionalDType.none, OptionalBool.none).ndim == Int(0) -> NDArray(PyObject(5326120464)).ndim == Int(0)
asarray(NDArray(PyObject(5326120464)), OptionalDType.none, OptionalBool.none).ndim == Int(1) -> FALSE
asarray(NDArray(PyObject(5326120464)), OptionalDType.none, OptionalBool.none).ndim >= Int(3) -> FALSE
asarray(asarray(NDArray(PyObject(5326120464)), OptionalDType.none, OptionalBool.none), OptionalDType.none, OptionalBool.none).dtype == DType.object -> FALSE
isdtype(
    asarray(asarray(NDArray(PyObject(5326120464)), OptionalDType.none, OptionalBool.none), OptionalDType.none, OptionalBool.none).dtype,
    (IsDtypeKind.string("real floating") | (IsDtypeKind.string("complex floating") | IsDtypeKind.NULL)),
) -> isdtype(DType.float64, (IsDtypeKind.string("real floating") | (IsDtyp

EggSmolError: Not found: fake expression Bool.to_py [Value { tag: "Bool", bits: 33 }]

In [None]:
# Scratch check: the type of a built-in class such as `str` is `type`.
type(type("df")) == type

True

In [None]:
# Scratch check: build a dtype instance from a NumPy scalar type.
np.dtype(np.float64)

dtype('float64')

In [None]:
# Scratch check: inspect the type of `np.dtype` itself (used as a converter key above).
type(np.dtype)

numpy._DTypeMeta