In [113]:
# Dependencies
#%pip install jax[cuda12] flax datasets optax ty pylance

In [2]:
import jax.numpy as jnp
from flax import nnx
import optax

# always reload data
%load_ext autoreload
%autoreload 2
from breast_cancer_data import data

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [139]:
class Model(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.l1 = nnx.Linear(13, 20, rngs=rngs)
        self.l2 = nnx.Linear(20, 20, rngs=rngs)
        self.l3 = nnx.Linear(20, 3, rngs=rngs)
    
    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        return x

In [140]:
rngs = nnx.Rngs(0)
model = Model(rngs)

In [None]:
def compute_loss(model, batch):
    """Compute cross-entropy loss"""
    features, labels = batch["features"], batch["label"]
    
    # Forward pass
    logits = model(features)
    
    # Cross-entropy loss
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    
    return loss

@nnx.jit
def train_step(model, optimizer, batch):
    """Single training step"""
    
    # Compute loss and gradients
    loss, grads = nnx.value_and_grad(compute_loss)(model, batch)
    
    # Update parameters
    optimizer.update(model, grads)
    
    return loss

def train(model, optimizer, train_loader, num_epochs=10):
    """Full training loop"""
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        
        # Loop over batches
        for batch in train_loader.iter(batch_size=128):
            loss = train_step(model, optimizer, batch)
            epoch_loss += loss
            num_batches += 1
        
        # Print progress
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1}/{num_epochs}: Loss = {avg_loss:.4f}")

def evaluate(model, val_loader):
    """Evaluate on validation set"""
    total_loss = 0.0
    correct = 0
    total = 0
    
    for batch in val_loader:
        features, labels = batch
        
        # Forward pass only (no gradients)
        logits = model(features)
        
        # Compute loss
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
        total_loss += loss
        
        # Compute accuracy
        preds = jnp.argmax(logits, axis=-1)
        targets = jnp.argmax(labels, axis=-1)
        correct += jnp.sum(preds == targets)
        total += len(features)
    
    avg_loss = total_loss / len(val_loader)
    accuracy = correct / total
    
    return avg_loss, accuracy

def train_with_validation(model, optimizer, train_loader, val_loader, num_epochs=10):
    for epoch in range(num_epochs):
        # Training
        epoch_loss = 0.0
        for batch in train_loader:
            loss = train_step(model, optimizer, batch)
            epoch_loss += loss
        
        # Validation
        val_loss, val_acc = evaluate(model, val_loader)
        
        print(f"Epoch {epoch + 1}: "
              f"Train Loss = {epoch_loss/len(train_loader):.4f}, "
              f"Val Loss = {val_loss:.4f}, "
              f"Val Acc = {val_acc:.2%}")

In [142]:
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-3), wrt=nnx.Param)
train(model, optimizer, data["train"], 100)

Epoch 1/100: Loss = nan
Epoch 2/100: Loss = nan
Epoch 3/100: Loss = nan
Epoch 4/100: Loss = nan
Epoch 5/100: Loss = nan
Epoch 6/100: Loss = nan
Epoch 7/100: Loss = nan
Epoch 8/100: Loss = nan
Epoch 9/100: Loss = nan
Epoch 10/100: Loss = nan
Epoch 11/100: Loss = nan
Epoch 12/100: Loss = nan
Epoch 13/100: Loss = nan
Epoch 14/100: Loss = nan
Epoch 15/100: Loss = nan
Epoch 16/100: Loss = nan
Epoch 17/100: Loss = nan
Epoch 18/100: Loss = nan
Epoch 19/100: Loss = nan
Epoch 20/100: Loss = nan
Epoch 21/100: Loss = nan
Epoch 22/100: Loss = nan
Epoch 23/100: Loss = nan
Epoch 24/100: Loss = nan
Epoch 25/100: Loss = nan
Epoch 26/100: Loss = nan
Epoch 27/100: Loss = nan
Epoch 28/100: Loss = nan
Epoch 29/100: Loss = nan
Epoch 30/100: Loss = nan
Epoch 31/100: Loss = nan
Epoch 32/100: Loss = nan
Epoch 33/100: Loss = nan
Epoch 34/100: Loss = nan
Epoch 35/100: Loss = nan
Epoch 36/100: Loss = nan
Epoch 37/100: Loss = nan
Epoch 38/100: Loss = nan
Epoch 39/100: Loss = nan
Epoch 40/100: Loss = nan
Epoch 41/