In [29]:
import sklearn.datasets as skd
import sklearn.metrics as skm
from sklearn.model_selection import train_test_split
import numpy as np
import jax
import jax.numpy as jnp
import jaxopt
import flax.linen as nn
import optax
from tqdm.auto import trange

In [3]:
def evaluate(Y_test, preds):
    print(f"R2 score: {skm.r2_score(Y_test, preds)}")
    print(f"MAE: {skm.mean_absolute_error(Y_test, preds)}")

In [4]:
X, Y = skd.fetch_california_housing(return_X_y=True)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=42)

In [49]:
class MLPRegressor(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(100)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x


def mean_squared_error(model):
    def _apply(params, X, Y):
        return jnp.mean((model.apply(variables, X) - Y)**2)
    return _apply

def mean_absolute_error(model):
    def _apply(params, X, Y):
        return jnp.mean(jnp.abs(model.apply(variables, X) - Y))
    return _apply


model = MLPRegressor()
variables = model.init(jax.random.PRNGKey(42), X_train[:1])
solver = jaxopt.OptaxSolver(opt=optax.adam(0.001), fun=mean_squared_error(model))
state = solver.init_state(variables, X_train[:1], Y_train[:1])
step = jax.jit(solver.update)
rng = np.random.default_rng(42)

for r in (pbar := trange(5000)):
    idx = rng.choice(len(Y_train), 200, replace=False)
    variables, state = step(params=variables, state=state, X=X_train[idx], Y=Y_train[idx])
    pbar.set_postfix_str(f"Loss: {state.value}")

# model = create_model(X_train[0].shape)
# model.fit(X_train, Y_train, epochs=200, batch_size=200)
evaluate(Y_test, model.apply(variables, X_test))

100%|██████████| 5000/5000 [00:01<00:00, 2664.06it/s, Loss: 33426.890625]   


R2 score: -23226.030187432847
MAE: 152.35350469279385


In [50]:
import sklearn.linear_model as sklm
import sklearn.neural_network as sknn

model = sknn.MLPRegressor().fit(X_train, Y_train)
evaluate(Y_test, model.predict(X_test))

R2 score: 0.4914483094206872
MAE: 0.5884081193797143
