# Implementing an Array API to use with Scikit-learn


In this tutorial, we will create an object that implements the Array API and use it in the `LinearDiscriminantAnalysis` example that is in the [scikit-learn docs](https://scikit-learn.org/stable/modules/array_api.html).

First, let's try LDA with regular NumPy arrays.

We take a set of input vectors and reduce their dimensionality to 2.

In [1]:
import matplotlib.pyplot as plt

from sklearn import datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn import config_context

iris = datasets.load_iris()

X = iris.data
y = iris.target

def fit(X, y):
    """Fit a two-component LDA on ``(X, y)`` and return the transformed ``X``.

    The fit and the transform both run inside
    ``config_context(array_api_dispatch=True)``, so scikit-learn dispatches
    through the Array API of whatever array type is passed in.

    Args:
        X: Feature array.
        y: Class labels for the rows of ``X``.

    Returns:
        The result of ``LinearDiscriminantAnalysis(n_components=2)`` fitted on
        ``(X, y)`` and applied to ``X``.
    """
    with config_context(array_api_dispatch=True):
        lda = LinearDiscriminantAnalysis(n_components=2)
        return lda.fit(X, y).transform(X)


def _plot_lda(X_r2, y):
    """Scatter-plot the first two columns of ``X_r2``, one colour per class.

    This plotting code used to sit after the ``return`` in ``fit`` and was
    therefore never executed; it lives here so ``fit`` only returns the
    projection. Class names are read from the module-level ``iris`` dataset.

    Args:
        X_r2: Projected data; columns 0 and 1 are plotted.
            Presumably a NumPy array, since it is indexed with a boolean mask.
        y: Class labels (0, 1, 2) aligned with the rows of ``X_r2``.
    """
    target_names = iris.target_names
    colors = ["navy", "turquoise", "darkorange"]

    # A single figure: the original snippet called plt.figure() twice, which
    # would have left an extra empty figure behind.
    plt.figure()
    for color, i, target_name in zip(colors, [0, 1, 2], target_names):
        plt.scatter(
            X_r2[y == i, 0], X_r2[y == i, 1], alpha=0.8, color=color, label=target_name
        )
    plt.legend(loc="best", shadow=False, scatterpoints=1)
    plt.title("LDA of IRIS dataset")

    plt.show()
# fit(X, y)

In [2]:
# import torch

# fit(torch.asarray(X), torch.asarray(y))

Now let's try to make an NDArray object that implements the Array API.

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
from egglog.exp.array_api import *


# Symbolic stand-ins for the real iris arrays: named NDArray variables
# rather than concrete data.
X_arr = NDArray.var("X")
y_arr = NDArray.var("y")

# Add values for the constants
# Each rewrite below tells the e-graph what a property of the symbolic arrays
# evaluates to, using the metadata of the real NumPy arrays X and y.
# (`egraph` and `runtime_ruleset` are not defined in this notebook; presumably
# they come from the star import above.)
egraph.register(
    # dtypes of the symbolic arrays match the real arrays
    rewrite(X_arr.dtype, runtime_ruleset).to(convert(X.dtype, DType)),
    rewrite(y_arr.dtype, runtime_ruleset).to(convert(y.dtype, DType)),
    # declare that the sum of each array is finite
    rewrite(isfinite(sum(X_arr)).to_bool(), runtime_ruleset).to(TRUE),
    rewrite(isfinite(sum(y_arr)).to_bool(), runtime_ruleset).to(TRUE),
    # shapes and sizes match the real arrays
    rewrite(X_arr.shape, runtime_ruleset).to(convert(X.shape, TupleInt)),
    rewrite(y_arr.shape, runtime_ruleset).to(convert(y.shape, TupleInt)),
    rewrite(X_arr.size, runtime_ruleset).to(Int(X.size)),
    rewrite(y_arr.size, runtime_ruleset).to(Int(y.size)),
    # y has three unique values (presumably the three iris classes).
    # NOTE(review): these two rules do not pass `runtime_ruleset`, unlike the
    # ones above -- confirm that is intentional.
    rewrite(unique_values(y_arr).shape).to(TupleInt(Int(3))),
    rewrite(unique_values(y_arr).size).to(Int(3)),
)


# Call fit() with the symbolic arrays instead of NumPy data; `res` is what the
# next cell passes to egraph.extract().
res = fit(X_arr, y_arr)

# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)

# X_arr = NDArray(X_obj)
# y_arr = NDArray(y_obj)

isdtype(DType.float32, IsDtypeKind.string("integral"))
  -> FALSE
     -> FALSE

DType.float64 == NDArray.var("X").dtype
  -> DType.float64 == NDArray.var("X").dtype
     -> TRUE

asarray(NDArray.var("X")).ndim == Int(0)
  -> NDArray.var("X").ndim == Int(0)
     -> FALSE

asarray(NDArray.var("X")).ndim == Int(1)
  -> FALSE
     -> FALSE

asarray(NDArray.var("X")).ndim >= Int(3)
  -> FALSE
     -> FALSE

asarray(asarray(NDArray.var("X"))).dtype == DType.object
  -> FALSE
     -> FALSE

isdtype(asarray(asarray(NDArray.var("X"))).dtype, IsDtypeKind.string("real floating") | (IsDtypeKind.string("complex floating") | IsDtypeKind.NULL))
  -> TRUE
     -> TRUE

isfinite(sum(asarray(asarray(NDArray.var("X"))))).to_bool()
  -> isfinite(sum(NDArray.var("X"))).to_bool()
     -> TRUE

asarray(NDArray.var("X")).shape.length()
  -> Int(2)
     -> Int(2)

asarray(NDArray.var("X")).shape[Int(0)] < Int(2)
  -> FALSE
     -> FALSE

asarray(NDArray.var("X")).ndim == Int(2)
  -> TRUE
     -> TRUE

asarray

In [5]:
# Build the schedule first: runtime_ruleset with limit=10, then a run with no
# explicit ruleset and limit=10, combined with `+` and wrapped in .saturate().
schedule = (run(runtime_ruleset, limit=10) + run(limit=10)).saturate()
egraph.run(schedule)
# Kept as the last expression so the notebook displays the extracted result.
egraph.extract(res)

_NDArray_3 = copy(zeros(TupleInt(Int(3)) + TupleInt(Int(4)), OptionalDType.some(DType.float64), OptionalDevice.some(NDArray.var("X").device)))
_NDArray_3[IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice())))] = mean(
    NDArray.var("X")[ndarray_index(unique_inverse(NDArray.var("y"))[Int(1)] == NDArray.scalar_int(Int(0)))], OptionalIntOrTuple.int(Int(0))
)
_NDArray_2 = copy(_NDArray_3)
_NDArray_2[IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice())))] = mean(
    NDArray.var("X")[ndarray_index(unique_inverse(NDArray.var("y"))[Int(1)] == NDArray.scalar_int(Int(1)))], OptionalIntOrTuple.int(Int(0))
)
_NDArray_1 = copy(_NDArray_2)
_NDArray_1[IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(2))) + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice())))] = mean(
    NDArray.var("X")[ndarray_index(unique_inverse(NDArray

In [6]:
Slice?

Signature:      Slice(*args: 'object') -> 'Optional[RuntimeExpr]'
Type:           RuntimeClass
String form:    Slice
File:           ~/p/egg-smol-python/python/egglog/runtime.py
Docstring:      RuntimeClass(__egg_decls__: 'ModuleDeclarations', __egg_name__: 'str')
Call docstring: Create an instance of this kind by calling the __init__ classmethod