# Iris Classification using Flax Neural Network

This notebook implements a 2-layer feedforward neural network using Flax to classify the Iris dataset. We'll use JAX for efficient numerical computations and Flax for neural network layers.

## Setup Environment

First, let's install the required packages: Flax and Optax (pip pulls in JAX as a dependency of Flax). scikit-learn, pandas, and matplotlib are assumed to be installed already.

In [None]:
!pip install --upgrade pip
!pip install flax optax

## Import Dependencies

Now let's import all the necessary libraries for our neural network implementation.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, StandardScaler

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

## Data Preparation

Let's load and preprocess the Iris dataset:

In [None]:
# Load the Iris dataset: 4 numeric features per sample, 3 target classes.
iris = load_iris()
X_raw = iris['data']
labels = iris['target']

# Split BEFORE fitting any preprocessing so that test-set statistics never
# influence the transforms (avoids data leakage). Stratify so each class keeps
# the same proportion in train and test on this small dataset.
X_train_raw, X_test_raw, labels_train, labels_test = train_test_split(
    X_raw, labels, test_size=0.2, random_state=42, stratify=labels
)

# Standardize features using statistics from the training split only,
# then apply the identical transform to the test split.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

# One-hot encode the labels (dense output). `sparse_output` replaced the old
# `sparse` keyword in scikit-learn 1.2; `sparse` was removed in 1.4.
encoder = OneHotEncoder(sparse_output=False)
y_train = encoder.fit_transform(labels_train.reshape(-1, 1))
y_test = encoder.transform(labels_test.reshape(-1, 1))

## Model Definition

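Now we'll define our 2-layer feedforward neural network (one hidden ReLU layer followed by an output layer) using Flax, together with helpers that create the training state and compute the loss and accuracy: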

In [None]:
class FeedForwardNN(nn.Module):
    """Two-layer MLP: Dense(hidden_dim) -> ReLU -> Dense(output_dim).

    Attributes:
        hidden_dim: width of the hidden layer.
        output_dim: number of output logits (one per class).
    """
    hidden_dim: int = 32
    output_dim: int = 3

    @nn.compact
    def __call__(self, x):
        """Map a batch of inputs `x` to unnormalized logits."""
        hidden = nn.relu(nn.Dense(self.hidden_dim)(x))
        logits = nn.Dense(self.output_dim)(hidden)
        return logits

def create_train_state(rng, learning_rate, model, input_dim=4):
    """Initialize `model` parameters and bundle them with an Adam optimizer.

    Args:
        rng: JAX PRNG key used for parameter initialization.
        learning_rate: Adam step size.
        model: Flax module; it is initialized on a dummy batch of shape
            (1, input_dim), so it must accept inputs with `input_dim` features.
        input_dim: number of input features (defaults to 4, matching the
            original hard-coded dummy input).

    Returns:
        A `flax.training.train_state.TrainState` holding the initialized
        params, the Adam optimizer, and `model.apply` as `apply_fn`.
    """
    # Only the shape of the dummy input matters: Flax infers Dense kernel
    # shapes from it during `init`.
    dummy_input = jnp.ones((1, input_dim))
    params = model.init(rng, dummy_input)['params']
    optimizer = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

def cross_entropy_loss(logits, labels):
    """Mean softmax cross-entropy between `logits` and one-hot `labels`.

    `labels` must have the same shape as `logits` (as required by
    `optax.softmax_cross_entropy`); the per-example losses are averaged.
    """
    per_example_loss = optax.softmax_cross_entropy(logits=logits, labels=labels)
    return jnp.mean(per_example_loss)

def compute_metrics(logits, labels):
    """Return a dict with the mean cross-entropy `'loss'` and the `'accuracy'`.

    Accuracy is the fraction of rows where the arg-max of `logits` matches the
    arg-max of the one-hot `labels` along the last axis.
    """
    predicted_class = jnp.argmax(logits, axis=-1)
    true_class = jnp.argmax(labels, axis=-1)
    return {
        'loss': cross_entropy_loss(logits, labels),
        'accuracy': jnp.mean(predicted_class == true_class),
    }

## Training Functions

Let's define the training and evaluation steps:

In [None]:
@jax.jit
def train_step(state, batch):
    """Apply one gradient update on `batch`; return `(new_state, metrics)`.

    `batch` is a dict with inputs under 'X' and one-hot labels under 'y'.
    The returned metrics use the logits from the forward pass that produced
    the gradients, i.e. the parameters *before* this update.
    """
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['X'])
        # Return logits as auxiliary output so metrics need no second forward pass.
        return cross_entropy_loss(logits, batch['y']), logits

    (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, compute_metrics(logits, batch['y'])

@jax.jit
def eval_step(state, batch):
    """Compute loss and accuracy of the current `state.params` on `batch` (no update)."""
    current_params = state.params
    logits = state.apply_fn({'params': current_params}, batch['X'])
    metrics = compute_metrics(logits, batch['y'])
    return metrics

## Training Loop

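Now let's train the model, recording the average training loss/accuracy and the test loss/accuracy after every epoch (these curves are plotted in the next section):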

In [None]:
# Training hyperparameters.
learning_rate = 1e-3
num_epochs = 100
batch_size = 32
seed = 0

# Derive independent PRNG streams from one root key: one for parameter
# initialization and one for per-epoch shuffling, so no key is used twice.
rng = jax.random.PRNGKey(seed)
init_rng, shuffle_rng = jax.random.split(rng)

model = FeedForwardNN()
state = create_train_state(init_rng, learning_rate, model)

# Convert the data to JAX arrays once, outside the loop (loop-invariant work).
X_train_jnp = jnp.asarray(X_train)
y_train_jnp = jnp.asarray(y_train)
test_batch = {'X': jnp.asarray(X_test), 'y': jnp.asarray(y_test)}
num_train = X_train_jnp.shape[0]

# Per-epoch metrics (Python floats) for plotting.
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []

for epoch in range(num_epochs):
    # A fresh key per epoch, split off the shuffle stream.
    shuffle_rng, epoch_rng = jax.random.split(shuffle_rng)
    perm = jax.random.permutation(epoch_rng, num_train)

    loss_total = 0.0
    correct_total = 0.0
    for start in range(0, num_train, batch_size):
        idx = perm[start:start + batch_size]
        batch = {'X': X_train_jnp[idx], 'y': y_train_jnp[idx]}
        state, metrics = train_step(state, batch)
        # The final batch may be smaller than batch_size, so weight each
        # batch's mean by its size to get true per-example epoch averages.
        batch_len = idx.shape[0]
        loss_total += float(metrics['loss']) * batch_len
        correct_total += float(metrics['accuracy']) * batch_len

    avg_train_loss = loss_total / num_train
    avg_train_accuracy = correct_total / num_train
    train_losses.append(avg_train_loss)
    train_accuracies.append(avg_train_accuracy)

    # Evaluate on the full held-out set after each epoch.
    test_metrics = eval_step(state, test_batch)
    test_loss = float(test_metrics['loss'])
    test_accuracy = float(test_metrics['accuracy'])
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}")
        print(f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy:.4f}")
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}\n")

## Visualize Results

Let's plot the training and testing metrics:

In [None]:
def plot_train_test_curve(ax, train_values, test_values, metric_name):
    """Plot per-epoch training and testing curves for one metric on `ax`."""
    ax.plot(train_values, label=f'Training {metric_name}')
    ax.plot(test_values, label=f'Testing {metric_name}')
    ax.set(xlabel='Epoch', ylabel=metric_name, title=f'Training and Testing {metric_name}')
    ax.legend()
    return ax


# Side-by-side panels: loss on the left, accuracy on the right.
fig, (ax_loss, ax_accuracy) = plt.subplots(1, 2, figsize=(12, 4))
plot_train_test_curve(ax_loss, train_losses, test_losses, 'Loss')
plot_train_test_curve(ax_accuracy, train_accuracies, test_accuracies, 'Accuracy')
fig.tight_layout()
plt.show()