# Implenting an Array API to use with Scikit-learn


In this tutorial, we will create an object that implements the Array API and use it in the `LinearDiscriminantAnalysis` example that is in the [scikit-learn docs](https://scikit-learn.org/stable/modules/array_api.html).

First, let's try LDA with normal numpy arrays.

We take a set of input vector and reduce the dimensionality to 1.

In [1]:
import matplotlib.pyplot as plt

from sklearn import datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn import config_context

iris = datasets.load_iris()

X = iris.data
y = iris.target

def fit_and_plot(X, y):
    target_names = iris.target_names
    with config_context(array_api_dispatch=True):
        lda = LinearDiscriminantAnalysis(n_components=2)
        X_r2 = lda.fit(X, y).transform(X)


    plt.figure()
    colors = ["navy", "turquoise", "darkorange"]

    plt.figure()
    for color, i, target_name in zip(colors, [0, 1, 2], target_names):
        plt.scatter(
            X_r2[y == i, 0], X_r2[y == i, 1], alpha=0.8, color=color, label=target_name
        )
    plt.legend(loc="best", shadow=False, scatterpoints=1)
    plt.title("LDA of IRIS dataset")

    plt.show()
# fit_and_plot(X, y)

In [2]:
import torch

# fit_and_plot(torch.asarray(X), torch.asarray(y))

Now let's try to make a NDArray object that implements the Array API.

In [10]:
from __future__ import annotations
import sys
from typing import TypeVar

from egglog import *

egraph = EGraph()

T = TypeVar("T", bound=BaseExpr)

def simplify(expr: T) -> T:
    egraph.register(expr)
    egraph.run(run(limit=10).saturate())
    return egraph.extract(expr)

@egraph.class_
class Bool(BaseExpr):
    @egraph.method(preserve=True)
    def __bool__(self) -> bool:
        egraph.register(self)
        egraph.run(run(limit=10).saturate())
        res = egraph.load_object(egraph.extract(self.to_py()))
        assert isinstance(res, bool)
        return res
    
    def to_py(self) -> PyObject:
        ...

TRUE = egraph.constant("TRUE", Bool)
FALSE = egraph.constant("FALSE", Bool)

egraph.register(
    set_(TRUE.to_py()).to(egraph.save_object(True)),
    set_(FALSE.to_py()).to(egraph.save_object(False)),
)

@egraph.class_
class DType(BaseExpr):
    ...

float64 = egraph.constant("float64", DType)
float32 = egraph.constant("float32", DType)

@egraph.function
def isdtype(dtype: DType, kind: StringLike) -> Bool:
    ...

egraph.register(
    rewrite(isdtype(float64, "integral")).to(FALSE),
    rewrite(isdtype(float32, "integral")).to(FALSE),
)

assert not bool(isdtype(float64, "integral"))



@egraph.class_
class NDArray(BaseExpr):
    def __init__(self, py_array: PyObject) -> None: ...

    @egraph.method(preserve=True)
    def __array_namespace__(self, api_version=None):
        return sys.modules[__name__]

@egraph.function
def asarray(a: NDArray, dtype: DType, copy: )


X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)

fit_and_plot(NDArray(X_obj), NDArray(y_obj))

AttributeError: module '__main__' has no attribute 'asarray'

'__main__'

TypeError: The input is not a supported array type