In [8]:
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

In [9]:
# Load the Iris dataset bundle and pull out the feature matrix and the labels.
iris = datasets.load_iris()
X = iris.data
y = iris.target

In [12]:
# Multi-class problem: turn the class labels into dense one-hot rows.
# The encoder expects a 2-D input, so the 1-D label vector is first viewed as
# a single column.
y_column = y[:, None]
encoder = OneHotEncoder(sparse_output=False)
y_encoded = encoder.fit_transform(y_column)

In [13]:
# Hold out 20% of the samples for testing; the fixed seed keeps the split
# reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_encoded,
    test_size=0.2,
    random_state=44,
)

In [19]:
from flax import linen as nn
import jax
import jax.numpy as jnp
from jax import random
import optax
from flax.training import train_state

In [20]:
class IrisNN(nn.Module):
    """Small fully connected classifier: two ReLU hidden layers and a linear head.

    The layer sizes are module attributes so the same network can be reused
    with a different width or number of classes; the defaults reproduce the
    original 64-64-3 architecture, so ``IrisNN()`` behaves exactly as before.

    Attributes:
        hidden_dim: number of units in each of the two hidden layers.
        num_classes: number of output logits (one per class).
    """

    hidden_dim: int = 64
    num_classes: int = 3

    @nn.compact
    def __call__(self, x):
        """Map a batch of feature vectors to unnormalized class logits."""
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        # No softmax here: the raw logits are returned and normalization is
        # left to the loss function.
        x = nn.Dense(self.num_classes)(x)
        return x

In [21]:
def create_train_state(rng_key, learning_rate, input_shape):
    """Build a fresh ``TrainState`` for an ``IrisNN`` optimized with Adam.

    Args:
        rng_key: JAX PRNG key used to initialize the network parameters.
        learning_rate: learning rate passed to ``optax.adam``.
        input_shape: shape of the all-ones dummy input used to trace the
            model during initialization.

    Returns:
        A ``train_state.TrainState`` holding the model's apply function, its
        initial parameters and the Adam optimizer.
    """
    network = IrisNN()
    # Parameter shapes are inferred by running the model once on a dummy input.
    dummy_input = jnp.ones(input_shape)
    variables = network.init(rng_key, dummy_input)
    return train_state.TrainState.create(
        apply_fn=network.apply,
        params=variables['params'],
        tx=optax.adam(learning_rate),
    )

In [22]:
@jax.jit
def train_step(state, X, y):
    """Run one gradient update on the batch ``(X, y)`` and return the new state.

    Args:
        state: current ``TrainState`` (parameters, optimizer state, apply_fn).
        X: input features fed to the model.
        y: target labels passed to the softmax cross-entropy loss.

    Returns:
        The ``TrainState`` after applying the gradients of the mean loss.
    """
    def loss_fn(params):
        # Forward pass with the candidate parameters.
        predictions = state.apply_fn({'params': params}, X)
        per_example = optax.softmax_cross_entropy(logits=predictions, labels=y)
        return per_example.mean()

    gradient_fn = jax.grad(loss_fn)
    return state.apply_gradients(grads=gradient_fn(state.params))

In [23]:
# Training loop: full-batch gradient descent for a fixed number of epochs.
rng_key = random.PRNGKey(0)
# Derive the dummy init shape from the data instead of hard-coding the feature
# count (a single row with as many columns as X_train).
state = create_train_state(rng_key, 0.001, (1, X_train.shape[1]))

# Convert the training arrays to JAX arrays once, outside the loop; they do
# not change between epochs, so re-converting them every step is wasted work.
X_train_jax = jnp.array(X_train)
y_train_jax = jnp.array(y_train)

for epoch in range(100):
    state = train_step(state, X_train_jax, y_train_jax)

In [24]:
import pickle

# Persist the trained parameters to disk with pickle.
# NOTE: pickle files can execute arbitrary code when loaded, so only load this
# file back from a trusted location.
with open('jax_model_iris.pkl', mode='wb') as params_file:
    pickle.dump(state.params, params_file)