bcebere/jax_tabular_examples

Machine learning models for Tabular Data using JAX and Flax.

🔥 Models

| Example | Code | Test |
| --- | --- | --- |
| MLP | mlp.py | test_mlp.py |

Examples

Train an MLP

from sklearn.datasets import load_diabetes, load_digits
from jax_examples.models.mlp import MLP

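# Regression: the diabetes dataset has 442 samples with 10 numeric features
X, y = load_diabetes(return_X_y=True)
regression = MLP(task_type="regression")
regression.fit(X, y)
# Predictions here are made on the training data itself
y_pred = regression.predict(X)

# Classification: the digits dataset has 1797 samples, 64 features, 10 classes
X, y = load_digits(return_X_y=True)
clf = MLP(task_type="classification")
clf.fit(X, y)
# Predictions here are made on the training data itself
labels_pred = clf.predict(X)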
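
The examples above call predict on the same data the model was fitted on, which says little about how well the model generalizes. A minimal sketch of a held-out evaluation follows, assuming only a fit/predict interface like the one shown. The helpers `train_test_indices`, `accuracy`, and `mean_squared_error`, and the `MajorityClassifier` stand-in, are hypothetical names written for this illustration and are not part of this repository.

```python
import random


def train_test_indices(n_samples, test_fraction=0.2, seed=0):
    """Shuffle range(n_samples) and split it into (train, test) index lists."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_test = max(1, int(n_samples * test_fraction))
    return indices[n_test:], indices[:n_test]


def accuracy(y_true, y_pred):
    """Fraction of predictions that equal the true label (classification)."""
    pairs = list(zip(y_true, y_pred))
    return sum(1 for t, p in pairs if t == p) / len(pairs)


def mean_squared_error(y_true, y_pred):
    """Average squared difference between targets and predictions (regression)."""
    pairs = list(zip(y_true, y_pred))
    return sum((t - p) ** 2 for t, p in pairs) / len(pairs)


class MajorityClassifier:
    """Hypothetical stand-in model: always predicts the most common training label."""

    def fit(self, X, y):
        labels = list(y)
        self.label_ = max(set(labels), key=labels.count)
        return self

    def predict(self, X):
        return [self.label_ for _ in X]


# Toy data so the sketch runs without any third-party package.
X = [[float(i)] for i in range(10)]
y = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]

train_idx, test_idx = train_test_indices(len(X))
model = MajorityClassifier().fit(
    [X[i] for i in train_idx], [y[i] for i in train_idx]
)
y_pred = model.predict([X[i] for i in test_idx])
print("held-out accuracy:", accuracy([y[i] for i in test_idx], y_pred))
```

To try this with the MLP from the examples, replace the stand-in with `MLP(task_type="classification")` and index the NumPy arrays directly (`X[train_idx]`, `y[test_idx]`). This assumes predict returns one label per input row. For the regression task, compare predictions with `mean_squared_error` instead of `accuracy`.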

License

MIT
