# Implementing an Array API to use with Scikit-learn


In this tutorial, we will create an object that implements the Array API and use it in the `LinearDiscriminantAnalysis` example that is in the [scikit-learn docs](https://scikit-learn.org/stable/modules/array_api.html).

First, let's try LDA with normal numpy arrays.

We take a set of input vectors and reduce their dimensionality to 2.

In [1]:
import matplotlib.pyplot as plt

from sklearn import datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn import config_context

# Load the iris dataset, then pull out its feature matrix and class labels.
iris = datasets.load_iris()
X, y = iris.data, iris.target

def fit(X, y, *, plot=False):
    """Fit a two-component LDA on ``(X, y)`` and return the projected ``X``.

    Fitting and transforming run inside
    ``config_context(array_api_dispatch=True)``.

    Parameters
    ----------
    X : array
        Samples handed to ``LinearDiscriminantAnalysis.fit`` and ``transform``.
    y : array
        Class labels for ``X``.
    plot : bool, default False
        When true, additionally draw a scatter plot of the two components,
        one colour per class, labelled with ``iris.target_names``. Plotting
        indexes the result with boolean masks, so leave this off when the
        inputs are symbolic placeholders rather than concrete arrays.

    Returns
    -------
    The result of ``lda.fit(X, y).transform(X)``.
    """
    with config_context(array_api_dispatch=True):
        lda = LinearDiscriminantAnalysis(n_components=2)
        X_r2 = lda.fit(X, y).transform(X)

    # The plotting code used to sit after an unconditional ``return`` and was
    # therefore never executed; it is now reachable, but only on request.
    if plot:
        target_names = iris.target_names
        colors = ["navy", "turquoise", "darkorange"]

        # A single figure: the previous code called plt.figure() twice, which
        # would have left an empty extra figure behind.
        plt.figure()
        for color, i, target_name in zip(colors, [0, 1, 2], target_names):
            plt.scatter(
                X_r2[y == i, 0], X_r2[y == i, 1], alpha=0.8, color=color, label=target_name
            )
        plt.legend(loc="best", shadow=False, scatterpoints=1)
        plt.title("LDA of IRIS dataset")

        plt.show()

    return X_r2
# fit(X, y)

In [2]:
# import torch

# fit(torch.asarray(X), torch.asarray(y))

Now let's try to make an NDArray object that implements the Array API.

In [3]:
from egglog.exp.array_api import *


# NDArray variables named "X" and "y"; they are passed to `fit` below in place
# of the concrete arrays `X` and `y`.
X_arr = NDArray.var("X")
y_arr = NDArray.var("y")

# Register rewrite rules that replace properties of the two variables with the
# corresponding values read from the concrete arrays `X` and `y`.
egraph.register(
    # dtype of each variable -> dtype of the concrete array
    rewrite(X_arr.dtype, runtime_ruleset).to(convert(X.dtype, DType)),
    rewrite(y_arr.dtype, runtime_ruleset).to(convert(y.dtype, DType)),
    # `isfinite(sum(...))` converted to a bool is rewritten to TRUE for both
    rewrite(isfinite(sum(X_arr)).to_bool(), runtime_ruleset).to(TRUE),
    rewrite(isfinite(sum(y_arr)).to_bool(), runtime_ruleset).to(TRUE),
    # shape and size of each variable -> shape and size of the concrete array
    rewrite(X_arr.shape, runtime_ruleset).to(convert(X.shape, TupleInt)),
    rewrite(y_arr.shape, runtime_ruleset).to(convert(y.shape, TupleInt)),
    rewrite(X_arr.size, runtime_ruleset).to(Int(X.size)),
    rewrite(y_arr.size, runtime_ruleset).to(Int(y.size)),
    # Hard-coded 3: presumably the number of distinct iris classes in `y`
    # (TODO confirm).
    # NOTE(review): unlike the rules above, these two are not given
    # `runtime_ruleset` -- confirm this is intentional.
    rewrite(unique_values(y_arr).shape).to(TupleInt(Int(3))),
    rewrite(unique_values(y_arr).size).to(Int(3)),
)


# Call `fit` on the variables instead of real data; `res` is later passed to
# `egraph.extract`.
res = fit(X_arr, y_arr)

# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)

# X_arr = NDArray(X_obj)
# y_arr = NDArray(y_obj)

In [4]:
@egraph.register
def _optimizations(i: Int):
    """Yield two simplification rewrites: ``sqrt(0) -> 0`` and ``i * 0 -> 0``."""
    zero = Int(0)
    zero_scalar = NDArray.scalar_int(Int(0))

    # The square root of the scalar zero array is the scalar zero array.
    yield rewrite(sqrt(zero_scalar)).to(zero_scalar)
    # Any Int multiplied by zero is zero.
    yield rewrite(i * zero).to(zero)

In [5]:
# Compose the schedule up front: `run(runtime_ruleset)` ten times, then `run()`
# (no ruleset argument) ten times, with the whole sequence wrapped in
# `.saturate()`.
saturating_schedule = (run(runtime_ruleset) * 10 + run() * 10).saturate()
egraph.run(saturating_schedule)

# Replace `res` with what the e-graph extracts for it, then display it.
res = egraph.extract(res)
res