Example | Code | Test |
---|---|---|
MLP | mlp.py | test_mlp.py |
Train an MLP
from sklearn.datasets import load_diabetes, load_digits
from jax_examples.models.mlp import MLP
# --- Regression: fit an MLP on the diabetes dataset ---
features, targets = load_diabetes(return_X_y=True)
regressor = MLP(task_type="regression")
regressor.fit(features, targets)
regressor.predict(features)  # predicts on the same samples it was trained on

# --- Classification: fit an MLP on the digits dataset ---
features, targets = load_digits(return_X_y=True)
classifier = MLP(task_type="classification")
classifier.fit(features, targets)
classifier.predict(features)  # predicts on the same samples it was trained on