In [25]:
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

In [26]:
# Fetch the Iris dataset, then pull the feature matrix and the class labels
# out of the returned bunch as two separate arrays.
iris = datasets.load_iris()
X = iris.data
y = iris.target

In [27]:
# The task is multi-class, so turn every label into a one-hot row. The labels
# are reshaped into a single column before encoding, and sparse_output=False
# asks for a dense array back.
encoder = OneHotEncoder(sparse_output=False)
y_column = y.reshape(-1, 1)
y_encoded = encoder.fit_transform(y_column)

In [28]:
# Hold out 20% of the samples for testing; the fixed random_state keeps the
# split identical from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_encoded,
    test_size=0.2,
    random_state=44,
)

In [29]:
from flax import linen as nn
import jax
import jax.numpy as jnp
from jax import random
import optax
from flax.training import train_state

In [30]:
class IrisNN(nn.Module):
    """Fully connected classifier that maps a batch of features to class logits.

    Attributes:
        hidden_sizes: Width of each hidden Dense layer, in order. Every hidden
            layer is followed by a ReLU.
        num_classes: Number of output logits.

    The defaults reproduce the original fixed 64-64-3 architecture, so
    ``IrisNN()`` behaves exactly as before for existing callers.
    """

    hidden_sizes: tuple = (64, 64)
    num_classes: int = 3

    @nn.compact
    def __call__(self, x):
        """Return raw logits for ``x`` (no activation on the final layer)."""
        # Layers are created in the same order as before, so the submodule
        # names Flax assigns automatically (Dense_0, Dense_1, ...) and hence
        # the parameter tree layout are unchanged for the default settings.
        for width in self.hidden_sizes:
            x = nn.relu(nn.Dense(width)(x))
        return nn.Dense(self.num_classes)(x)

In [31]:
def create_train_state(rng_key, learning_rate, input_shape):
    """Build a fresh ``TrainState`` for an ``IrisNN`` optimized with Adam.

    Args:
        rng_key: PRNG key used to initialize the network parameters.
        learning_rate: Learning rate passed to ``optax.adam``.
        input_shape: Shape of the all-ones dummy input handed to
            ``model.init`` when creating the parameters.

    Returns:
        A ``train_state.TrainState`` bundling the model's ``apply`` function,
        the initialized parameters, and the optimizer.
    """
    net = IrisNN()
    dummy_batch = jnp.ones(input_shape)
    # ``init`` returns a dict of variable collections; only 'params' is kept.
    initial_params = net.init(rng_key, dummy_batch)['params']
    optimizer = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=net.apply,
        params=initial_params,
        tx=optimizer,
    )

In [32]:
@jax.jit
def train_step(state, X, y):
    """Apply one gradient update computed on the batch ``(X, y)``.

    Args:
        state: Training state providing ``apply_fn``, ``params`` and
            ``apply_gradients``.
        X: Batch of model inputs.
        y: Targets for ``X``, passed to ``optax.softmax_cross_entropy`` as
            ``labels``.

    Returns:
        The state produced by ``state.apply_gradients`` for this batch.
    """
    def mean_cross_entropy(params):
        # Forward pass with the candidate params, then average the softmax
        # cross-entropy values into a single scalar loss.
        per_example = optax.softmax_cross_entropy(
            logits=state.apply_fn({'params': params}, X),
            labels=y,
        )
        return jnp.mean(per_example)

    gradients = jax.grad(mean_cross_entropy)(state.params)
    return state.apply_gradients(grads=gradients)

In [33]:
# Training loop
rng_key = random.PRNGKey(0)
# Size the dummy init input from the data instead of hard-coding 4 features.
state = create_train_state(rng_key, 0.001, (1, X_train.shape[1]))
# Convert the training arrays once, up front: they do not change between
# epochs, so rebuilding them inside the loop only repeats the same copy.
X_train_jnp = jnp.array(X_train)
y_train_jnp = jnp.array(y_train)
for epoch in range(100):
    # Full-batch update: each epoch is a single step over the whole train set.
    state = train_step(state, X_train_jnp, y_train_jnp)

In [34]:
import pickle

# Persist only the trained parameter tree (not the rest of the train state).
with open('jax_model_iris.pkl', 'wb') as params_file:
    trained_params = state.params
    pickle.dump(trained_params, params_file)

In [35]:
# Read the saved parameter tree back from the pickle file.
# NOTE: unpickling can execute arbitrary code, so only load files that come
# from a trusted source.
with open('jax_model_iris.pkl', 'rb') as params_file:
    params = pickle.load(params_file)

In [36]:
# Rebuild the network definition and run the reloaded parameters on a few rows.
model = IrisNN()
input_data = X[:5]  # first five samples of the dataset
variables = {'params': params}
predictions = model.apply(variables, jnp.array(input_data))

In [37]:
# Normalize each row of logits into class probabilities (softmax over the
# last axis, which is also the function's default).
probs = jax.nn.softmax(predictions, axis=-1)

In [38]:
# For every sample, take the three largest probabilities together with the
# class indices they belong to.
top_k_result = jax.lax.top_k(probs, k=3)
top_k_values, top_k_indices = top_k_result

In [39]:
# Copy the (small) result arrays to host memory once instead of reading them
# one element at a time from JAX arrays inside the nested Python loops.
host_indices = jax.device_get(top_k_indices)
host_values = jax.device_get(top_k_values)
for i in range(len(input_data)):
    print(f"Sample {i+1} predictions:")
    # Walk the row directly so the number of entries printed always matches
    # the k used for the top-k selection rather than a separate literal.
    for class_index, probability in zip(host_indices[i], host_values[i]):
        class_name = iris.target_names[class_index]
        print(f"\t{class_name} ({probability*100:.2f}%)")

Sample 1 predictions:
	setosa (96.15%)
	versicolor (3.76%)
	virginica (0.09%)
Sample 2 predictions:
	setosa (93.78%)
	versicolor (6.05%)
	virginica (0.17%)
Sample 3 predictions:
	setosa (94.91%)
	versicolor (4.93%)
	virginica (0.15%)
Sample 4 predictions:
	setosa (93.14%)
	versicolor (6.65%)
	virginica (0.22%)
Sample 5 predictions:
	setosa (96.22%)
	versicolor (3.69%)
	virginica (0.09%)
