<a href="https://colab.research.google.com/github/mohsenh17/jaxLearning/blob/main/flax/CustomMetrics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install --upgrade flax orbax jax



In [2]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

from flax import nnx
import jax.numpy as jnp

from sklearn.preprocessing import MinMaxScaler

import orbax.checkpoint as ocp

import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import io



In [3]:

# Generate a dataset with 100 samples, 10 features, 5 informative, 5 redundant, and 2 classes
X, y = make_classification(n_samples=10000, n_features=10, n_informative=8,
                          n_classes=2, random_state=42)
df = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])])
df['Class'] = y
csv_data = df.to_csv(index=True)


# Dataset Class for Custom  Data
This section defines a CustomDataset class, which inherits from torch.utils.data.Dataset. It is designed to load and preprocess data from a CSV file for machine learning tasks.

python
Copy code


In [4]:
class CustomImageDataset(Dataset):
    # Initialize the dataset by loading, scaling, and organizing features and labels
    def __init__(self, dataset, transform=None, target_transform=None):
        # Load the CSV file into a DataFrame, using the first column as the index
        completeDF = pd.read_csv(dataset, index_col=0)

        # Extract the 'Class' column as labels
        self.labels = pd.DataFrame(completeDF['Class'])

        # Extract all other columns as features, excluding 'Class'
        features_raw = pd.DataFrame(completeDF.drop('Class', axis=1))

        # Scale the feature values to the range [0, 1] using Min-Max Scaling
        scaler = MinMaxScaler()
        self.features = pd.DataFrame(scaler.fit_transform(features_raw))

        # Assign optional transformations for features and labels
        self.transform = transform
        self.target_transform = target_transform

    # Return the total number of data samples in the dataset
    def __len__(self):
        return len(self.labels)

    # Retrieve a specific sample by index and apply transformations if needed
    def __getitem__(self, idx):
        # Extract feature values as a NumPy array for the given index
        features = np.array(self.features.iloc[idx, :])

        # Extract the corresponding label for the given index
        label = self.labels.iloc[idx, 0]

        # Apply transformation to the features, if provided
        if self.transform:
            features = self.transform(features)

        # Apply transformation to the label, if provided
        if self.target_transform:
            label = self.target_transform(label)

        # Return the processed features and label
        return features, label


In [5]:
import kagglehub
path = kagglehub.dataset_download("mssmartypants/rice-type-classification")
print(path)

/root/.cache/kagglehub/datasets/mssmartypants/rice-type-classification/versions/2


In [6]:
! ls /root/.cache/kagglehub/datasets/mssmartypants/rice-type-classification/versions/2/
! mv /root/.cache/kagglehub/datasets/mssmartypants/rice-type-classification/versions/2/riceClassification.csv riceClassification.csv

mv: cannot stat '/root/.cache/kagglehub/datasets/mssmartypants/rice-type-classification/versions/2/riceClassification.csv': No such file or directory


In [7]:
#dataset = CustomImageDataset(dataset=io.StringIO(csv_data))
dataset = CustomImageDataset(dataset="riceClassification.csv")
train_set, val_set, test_set = torch.utils.data.random_split(dataset, [.9, 0.05,0.05])
data_loader = DataLoader(train_set, batch_size=4, shuffle=True)

for features, labels in data_loader:
    print("Batch of features has shape: ",features.shape)
    print("Batch of labels has shape: ", labels.shape)
    print(features)
    print(labels)
    break

Batch of features has shape:  torch.Size([4, 10])
Batch of labels has shape:  torch.Size([4])
tensor([[0.4464, 0.7199, 0.3409, 0.9179, 0.4184, 0.5301, 0.7653, 0.4536, 0.6562,
         0.6445],
        [0.7972, 0.7674, 0.7494, 0.7517, 0.7394, 0.8419, 0.3968, 0.5646, 0.8317,
         0.3451],
        [0.1432, 0.4239, 0.1024, 0.9252, 0.1427, 0.1962, 0.3097, 0.2244, 0.6361,
         0.6663],
        [0.7786, 0.7595, 0.7281, 0.7587, 0.7265, 0.8267, 0.6643, 0.5578, 0.8260,
         0.3531]], dtype=torch.float64)
tensor([1, 0, 1, 0])


# MLP Implementation
This section implements a customizable MLP model with dropout and batch normalization at each layer. The architecture is parameterized by input dimensions, hidden dimensions, and the number of classes

In [8]:
class MLP(nnx.Module):
    def __init__(self, din, dm1, dm2, dm3, num_classes: int, rngs: nnx.Rngs):
        # Define dropout, linear, and batch normalization layers for each stage
        self.dp1 = nnx.Dropout(rate=0.4, rngs=rngs)
        self.linear1 = nnx.Linear(din, dm1, rngs=rngs)
        self.bn1 = nnx.BatchNorm(dm1, rngs=rngs)

        self.dp2 = nnx.Dropout(rate=0.2, rngs=rngs)
        self.linear2 = nnx.Linear(dm1, dm2, rngs=rngs)
        self.bn2 = nnx.BatchNorm(dm2, rngs=rngs)

        self.dp3 = nnx.Dropout(rate=0.1, rngs=rngs)
        self.linear3 = nnx.Linear(dm2, dm3, rngs=rngs)
        self.bn3 = nnx.BatchNorm(dm3, rngs=rngs)

        # Output layer without batch normalization
        self.linear4 = nnx.Linear(dm3, num_classes, rngs=rngs)

    def __call__(self, x):
        # Apply dropout, linear transformation, activation, and batch normalization for each layer
        x = self.dp1(x)
        x = self.linear1(x)
        x = nnx.gelu(x)
        x = self.bn1(x)

        x = self.dp2(x)
        x = self.linear2(x)
        x = nnx.gelu(x)
        x = self.bn2(x)

        x = self.dp3(x)
        x = self.linear3(x)
        x = nnx.gelu(x)
        x = self.bn3(x)

        # Final linear transformation
        x = self.linear4(x)

        # Sigmoid activation for binary classification (uncomment if needed)
        # return nnx.sigmoid(x)
        return x
# Instantiate the model with given dimensions
model = MLP(din=10, dm1=16, dm2=32, dm3=16, num_classes=1, rngs=nnx.Rngs(0))

# Test the model with a sample input tensor of shape (3, 10)
y = model(x=jnp.ones((3, 10)))

# Display the model architecture and output
nnx.display(model)
nnx.display(y)

# Output the predictions
print(y)


MLP(
  dp1=Dropout(rate=0.4, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(
    default=RngStream(
      key=RngKey(
        value=Array((), dtype=key<fry>) overlaying:
        [0 0],
        tag='default'
      ),
      count=RngCount(
        value=Array(17, dtype=uint32),
        tag='default'
      )
    )
  )),
  linear1=Linear(
    kernel=Param(
      value=Array(shape=(10, 16), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(16,), dtype=float32)
    ),
    in_features=10,
    out_features=16,
    use_bias=True,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x7ebf078f3f40>,
    bias_init=<function zeros at 0x7ebf07ff8dc0>,
    dot_general=<function dot_general at 0x7ebf08925ea0>
  ),
  bn1=BatchNorm(
    mean=BatchStat(
      value=Array(shape=(16,), dtype=float32)
    ),
    var=BatchStat(
      value=Array(shape=(16,), dtype=float32)
    ),

# Custom Metrics

In [9]:

class CustomMetrics(nnx.metrics.Metric):
    def __init__(self):
        # Initialize counters for true positives, false positives, and false negatives
        self.true_positives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.false_positives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.false_negatives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))

    def update(self, loss, logits, labels):
        """
        Update the metric counters based on the predictions and labels.
        Assumes logits are probabilities (e.g., from a sigmoid activation).
        """
        # Convert logits to binary predictions
        predictions = jnp.where(jnp.array(logits) > 0.5, 1, 0)

        predictions = predictions.ravel()
        labels = jnp.array(labels).ravel()


        # Compute metrics
        tp = jnp.sum((labels == 1) & (predictions == 1))
        fp = jnp.sum((labels == 0) & (predictions == 1))
        fn = jnp.sum((labels == 1) & (predictions == 0))

        # Update counters
        self.true_positives += tp
        self.false_positives += fp
        self.false_negatives += fn

    def compute(self):
        """
        Compute precision, recall, and F1-score from the accumulated counters.
        """
        precision = self.true_positives / (self.true_positives + self.false_positives + 1e-7)
        recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-7)
        f1_score = 2 * (precision * recall) / (precision + recall + 1e-7)
        return {"f1_score": f1_score, "precision": precision, "recall": recall}

    def reset(self):
        """
        Reset the metric counters.
        """
        self.true_positives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.false_positives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.false_negatives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))


In [10]:
class CustomAccuracy(nnx.metrics.Metric):
    def __init__(self):
        self.correct_count= nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.total_count= nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))

    def update(self, loss, logits, labels):
        # Convert logits to binary predictions (0 or 1) based on a 0.5 threshold
        predictions = jnp.where(jnp.array(logits) > 0.5, 1, 0)
        # Flatten if necessary
        predictions = predictions.ravel()
        labels = jnp.array(labels).ravel()

        # Calculate number of correct predictions in the current batch
        self.correct_count += jnp.sum(predictions == labels)
        self.total_count += len(labels)

    def compute(self):
        # Calculate accuracy over all batches seen so far
        if self.total_count == 0:
            return 0  # Avoid division by zero if no samples are seen
        return self.correct_count / self.total_count
    def reset(self):
        # Reset counters
        self.correct_count = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.total_count = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))

## test class

In [11]:
metric = CustomMetrics()
logits = jnp.array([0.9, 0.7, 0.2, 0.4, 0.8, 0.1])
labels = jnp.array([1, 1, 0, 0, 1, 0])

metric.update(None, logits, labels)
metrics = metric.compute()

print(metrics["precision"], 1)
print(metrics["recall"], 1)
print(metrics["f1_score"], 1)

metric.reset()

logits = jnp.array([0.9, 0.7, 0.2, 0.8, 0.1, 0.3])
labels = jnp.array([1, 0, 1, 1, 0, 0])

# Expected binary predictions: [1, 1, 0, 1, 0, 0]
# TP = 2 (Correct positive predictions: [1, 1])
# FP = 1 (False positive prediction: [1])
# FN = 1 (False negative prediction: [1])
metric.update(None, logits, labels)
metrics = metric.compute()

expected_precision = 2 / (2 + 1)  # TP / (TP + FP)
expected_recall = 2 / (2 + 1)     # TP / (TP + FN)
expected_f1 = 2 * (expected_precision * expected_recall) / (expected_precision + expected_recall)

print(metrics["precision"], expected_precision)
print(metrics["recall"], expected_recall)
print(metrics["f1_score"], expected_f1)


1.0 1
1.0 1
1.0 1
0.6666667 0.6666666666666666
0.6666667 0.6666666666666666
0.6666666 0.6666666666666666


# Setting Up an Optimizer and Metrics

In [22]:
# Import Optax for optimizer configuration
import optax

# Learning rate and momentum for the optimizer
learning_rate = 0.005
momentum = 0.9

# Instantiate the MLP model
model = MLP(10, 16, 32, 16, 1, rngs=nnx.Rngs(0))

# Set up the optimizer using Adam with the specified learning rate
optimizer = nnx.Optimizer(model, optax.adam(learning_rate))

# Initialize metrics for tracking training performance
metrics = nnx.MultiMetric(
    accuracy=CustomAccuracy(),
    f1_precision_recall=CustomMetrics(),
    loss=nnx.metrics.Average('loss'),  # Tracks the average loss
)


#  Defining Loss, Training, and Evaluation Functions

In [13]:
def loss_fn(model: MLP, batch):
    # Forward pass: Compute logits (raw predictions) using the model
    logits = model(batch['features'])

    # Compute binary cross-entropy loss for classification
    loss = optax.sigmoid_binary_cross_entropy(
        logits=logits, labels=batch['labels'].reshape(-1, 1)
    ).mean()

    # Optionally, use a custom loss function like mean squared error
    # loss = (logits - batch['labels'])**2

    # Return the computed loss and logits
    return loss, logits
@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Train for a single step."""
    # Compute loss and gradients using a differentiable function
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)

    # Update metrics in-place with loss and logits
    metrics.update(loss=loss, logits=logits, labels=batch['labels'])

    # Apply the computed gradients to update model parameters
    optimizer.update(grads)

    # Return the loss and logits after applying sigmoid for interpretation
    return loss, nnx.sigmoid(logits)
@nnx.jit
def eval_step(model: MLP, metrics: nnx.MultiMetric, batch):
    # Compute loss and logits for the batch
    loss, logits = loss_fn(model, batch)

    # Update metrics in-place with evaluation results
    metrics.update(loss=loss, logits=logits, labels=batch['labels'])

    # Return the loss for further aggregation or monitoring
    return loss


# Dataset Preparation and DataLoader Creation

In [19]:
def custom_collate_fn(batch):
    # Transpose batch to group features and labels separately
    transposed_data = list(zip(*batch))

    # Convert labels and features into NumPy arrays
    labels = np.array(transposed_data[1])
    features = np.array(transposed_data[0])

    # Return a dictionary with features and labels
    return {"features": features, "labels": labels}

train_set, val_set, test_set = torch.utils.data.random_split(dataset, [0.7, 0.1, 0.2])

train_ds = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True, collate_fn=custom_collate_fn)
val_ds = DataLoader(val_set, batch_size=64, shuffle=False, drop_last=True, collate_fn=custom_collate_fn)
test_ds = DataLoader(test_set, batch_size=64, shuffle=False, drop_last=True, collate_fn=custom_collate_fn)

# Fetch a single batch from the training DataLoader
batch_data = next(iter(train_ds))
imgs = batch_data['features']
lbls = batch_data['labels']

# Print shapes and data types for verification
print(imgs.shape, imgs[0].dtype, lbls.shape, lbls[0].dtype)
print(lbls)



(64, 10) float64 (64,) int64
[1 0 0 1 1 1 1 0 1 0 0 1 1 1 0 1 0 1 0 0 0 1 1 1 1 1 1 0 1 0 0 1 0 1 1 1 1
 1 1 1 1 1 1 0 0 0 1 0 1 0 0 1 0 0 1 0 0 0 1 1 1 1 1 1]


# Checkpoint! save model

In [20]:
import jax
from orbax.checkpoint.type_handlers import TypeHandler
from orbax.checkpoint.type_handlers import register_type_handler
from copy import deepcopy

def process_and_save_model_state(model, ckpt_dir):
    """
    Processes the model state, modifies the PRNG key, and saves the state to a checkpoint directory.

    Args:
        model: The model whose state is being processed.
        ckpt_dir (str): Directory to save the checkpoint.
    """
    # Retrieve the model's state
    state = nnx.state(model)

    # Split the model into parameters and state
    _, state = nnx.split(model)

    # Deep copy the state
    state_org = deepcopy(state)

    # Display the current state (for inspection)
    #nnx.display(state)

    # Modify the PRNG key for 'dp1'
    prng_key_value = state["dp1"]["rngs"]["default"]["key"].value
    state["dp1"]["rngs"]["default"]["key"] = nnx.VariableState(
        type=nnx.Param,
        value=jax.random.key_data(prng_key_value),
        tag='default'
    )

    # Create a new empty checkpoint directory
    ckpt_dir = ocp.test_utils.erase_and_create_empty(ckpt_dir)

    # Initialize the checkpointing system
    checkpointer = ocp.PyTreeCheckpointer()

    # Save the modified state to the checkpoint directory
    checkpointer.save(f'{ckpt_dir}/state', state)
    #return state_org, prng_key_value

# Example usage
#model = ...  # Initialize your model
ckpt_dir = '/content/my-checkpoints/'  # Specify your checkpoint directory
#state_org, prng_key_value = process_and_save_model_state(model, ckpt_dir)
process_and_save_model_state(model, ckpt_dir)



# Training Loop for Neural Network with Metrics Tracking

In [27]:
import jax

def update_metrics_history(metrics_history, prefix, metrics):
    """
    Updates the metrics history dictionary with the computed metrics for a given phase.

    Args:
        metrics_history (dict): Dictionary to store metrics over epochs.
        prefix (str): Prefix for metric keys (e.g., 'train', 'val', 'test').
        metrics (object): Metrics object with `compute` method returning metric values.
    """
    for metric, value in metrics.compute().items():
        if metric == 'f1_precision_recall':
            metrics_history[f'{prefix}_f1_score'].append(value['f1_score'])
            metrics_history[f'{prefix}_precision'].append(value['precision'])
            metrics_history[f'{prefix}_recall'].append(value['recall'])
        else:
            metrics_history[f'{prefix}_{metric}'].append(value)

def log_metrics(epoch, num_epochs, metrics_history, prefix):
    """
    Logs the metrics for a given phase (train, val) in the current epoch.

    Args:
        epoch (int): Current epoch number.
        num_epochs (int): Total number of epochs.
        metrics_history (dict): Dictionary containing historical metrics.
        prefix (str): Prefix for metric keys (e.g., 'train', 'val').
    """
    print(
        f"[{prefix}] epoch: {epoch + 1}/{num_epochs}, "
        f"loss: {metrics_history[f'{prefix}_loss'][-1]:.4f}, "
        f"accuracy: {metrics_history[f'{prefix}_accuracy'][-1]:.4f}, "
        f"precision: {metrics_history[f'{prefix}_precision'][-1]:.4f}, "
        f"recall: {metrics_history[f'{prefix}_recall'][-1]:.4f}, "
        f"f1_score: {metrics_history[f'{prefix}_f1_score'][-1]:.4f}, "
    )

def train_and_evaluate(model, optimizer, train_ds, val_ds, metrics, num_epochs, tracking_metric):
    """
    Trains and evaluates the model while tracking and logging metrics.

    Args:
        model: Model to train.
        optimizer: Optimizer for model training.
        train_ds: Training dataset.
        val_ds: Validation dataset.
        metrics: Metrics object to track performance.
        num_epochs (int): Number of training epochs.
    """
    metrics_history = {
        'train_loss': [], 'train_accuracy': [], 'train_precision': [],
        'train_recall': [], 'train_f1_score': [], 'val_loss': [],
        'val_accuracy': [], 'val_precision': [], 'val_recall': [],
        'val_f1_score': []
    }

    for epoch in range(num_epochs):
        # Training phase
        for batch in train_ds:
            loss, logits = train_step(model, optimizer, metrics, batch)
        update_metrics_history(metrics_history, 'train', metrics)
        metrics.reset()  # Reset metrics after training
        log_metrics(epoch, num_epochs, metrics_history, 'train')

        # Validation phase
        for batch in val_ds:
            loss = eval_step(model, metrics, batch)
        update_metrics_history(metrics_history, 'val', metrics)
        metrics.reset()  # Reset metrics after validation
        log_metrics(epoch, num_epochs, metrics_history, 'val')
        if metrics_history[tracking_metric][-1] == max(metrics_history[tracking_metric]):
            print('saving model')
            ckpt_dir = '/content/my-checkpoints/'  # Specify your checkpoint directory
            process_and_save_model_state(model, ckpt_dir)

    return metrics_history

# Example usage
num_epochs = 5
tracking_metric = 'val_accuracy'
metrics_history = train_and_evaluate(model, optimizer, train_ds, val_ds, metrics, num_epochs, tracking_metric)


[train] epoch: 1/5, loss: 0.1069, accuracy: 0.9592, precision: 0.9669, recall: 0.9590, f1_score: 0.9629, 
[val] epoch: 1/5, loss: 0.0994, accuracy: 0.9654, precision: 0.9560, recall: 0.9801, f1_score: 0.9679, 
saving model
[train] epoch: 2/5, loss: 0.1039, accuracy: 0.9616, precision: 0.9684, recall: 0.9619, f1_score: 0.9651, 
[val] epoch: 2/5, loss: 0.1052, accuracy: 0.9632, precision: 0.9586, recall: 0.9727, f1_score: 0.9656, 
[train] epoch: 3/5, loss: 0.1009, accuracy: 0.9605, precision: 0.9673, recall: 0.9610, f1_score: 0.9641, 
[val] epoch: 3/5, loss: 0.0988, accuracy: 0.9643, precision: 0.9684, recall: 0.9643, f1_score: 0.9664, 
[train] epoch: 4/5, loss: 0.1020, accuracy: 0.9610, precision: 0.9684, recall: 0.9609, f1_score: 0.9646, 
[val] epoch: 4/5, loss: 0.1026, accuracy: 0.9637, precision: 0.9674, recall: 0.9643, f1_score: 0.9658, 
[train] epoch: 5/5, loss: 0.0966, accuracy: 0.9644, precision: 0.9715, recall: 0.9639, f1_score: 0.9677, 
[val] epoch: 5/5, loss: 0.1056, accuracy:

# Prediction and Evaluation Pipeline

In [17]:
model.eval()
@nnx.jit
def pred_step(model: MLP, batch):
    logits = model(batch['features'])
    return nnx.sigmoid(logits)

test_ds = DataLoader(test_set, batch_size=32, shuffle=False, drop_last=True, collate_fn=custom_collate_fn)

ypred = []  # List to store predicted probabilities
label = []  # List to store true labels

for test_batch in test_ds:
    logits = pred_step(model, test_batch)
    ypred.extend(np.ravel(logits))  # Flatten and collect predictions
    label.extend(np.ravel(test_batch["labels"]))  # Flatten and collect true labels

binary_ypred = np.where(np.array(ypred) > 0.5, 1, 0)

accuracy = sum([1 for pred, true in zip(binary_ypred, label) if pred == true]) / len(label)
print(f"Accuracy: {accuracy:.4f}")


Accuracy: 0.6944


# Initializing a New Model and Load!

In [18]:
newModel = MLP(10, 16, 32, 16, 1, rngs=nnx.Rngs(0))  # Re-initialize the model
abstract_model = nnx.eval_shape(lambda: newModel)  # Evaluate the abstract shape of the model
graphdef, abstract_state = nnx.split(abstract_model)  # Split the model into graph and state
print('The abstract NNX state (all leaves are abstract arrays):')
#nnx.display(abstract_state)  # Uncomment this to display the abstract state (optional for debugging)
checkpointer = ocp.PyTreeCheckpointer()
state_restored = checkpointer.restore(ckpt_dir + 'state', abstract_state)  # Restore the state from checkpoint
#state_restored = checkpointer.restore(optdir+ 'state', abstract_state)  # Restore the state from checkpoint
#nnx.display(state_restored['dp1'])  # Optional: Display restored state of dp1

#prng_key_value = state["dp1"]["rngs"]["default"]["key"].value  # Access the PRNG key from the original state
state_restored["dp1"]["rngs"]["default"]["key"].type(nnx.RngKey)  # Ensure the restored key has the correct type
state_restored["dp1"]["rngs"]["default"]["key"].value = jax.random.wrap_key_data(prng_key_value)  # Modify the PRNG key value

nnx.display(state_restored['dp1'])  # Display the restored dp1 state
nnx.display(state_org['dp1'])  # Display the original dp1 state
jax.tree.map(np.testing.assert_array_equal, state_org, state_restored)  # Ensure both states are equal
print('NNX State restored: ')
#nnx.display(state_restored)  # Uncomment this to display the full restored state

newModel = nnx.merge(graphdef, state_restored)  # Merge the graph definition with the restored state
optimizer = nnx.Optimizer(newModel, optax.adam(learning_rate))  # Reinitialize the optimizer
metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),  # Initialize metrics for tracking the loss
)


The abstract NNX state (all leaves are abstract arrays):




NameError: name 'prng_key_value' is not defined

# Train more if needed

In [None]:
import jax
#model.train()
metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'train_precision': [],
  'train_recall': [],
  'train_f1': [],
  'val_loss': [],
  'val_accuracy': [],
  'val_precision': [],
  'val_recall': [],
  'val_f1': [],
  'test_loss': [],
  'test_accuracy': [],
  'test_precision': [],
  'test_recall': [],
  'test_f1': [],
}

num_epochs = 40  # Number of training epochs
for epoch in range(num_epochs):
    # Train on each batch in the training dataset
    for batch in train_ds:
        loss, logits = train_step(model, optimizer, metrics, batch)

    # Compute and log training metrics for this epoch
    for metric, value in metrics.compute().items():
        metrics_history[f'train_{metric}'].append(value)

    # Reset metrics for the next epoch
    metrics.reset()

    # Log training performance
    print(
        f"[train] epoch: {epoch + 1}/{num_epochs}, "
        f"loss: {metrics_history['train_loss'][-1]:.4f}, "
    )

# later stuff
"""
for batch in val_ds:
    loss = eval_step(model, metrics, batch)

for metric, value in metrics.compute().items():
    metrics_history[f'val_{metric}'].append(value)

metrics.reset()
print(
    f"[val] epoch: {epoch + 1}/{num_epochs}, "
    f"loss: {metrics_history['val_loss'][-1]:.4f}, "
)

print(
    f"[train] epoch: {epoch + 1}/{num_epochs}, "
    f"loss: {metrics_history['train_loss'][-1]:.4f}, "
    f"accuracy: {metrics_history['train_accuracy'][-1]:.2f}"
)
"""

# Validate saved model

In [None]:
@nnx.jit
def pred_step(model: MLP, batch):
    logits = model(batch['features'])
    return nnx.sigmoid(logits)

test_ds = DataLoader(test_set, batch_size=32, shuffle=False, drop_last=True, collate_fn=custom_collate_fn)

ypred = []  # List to store predicted probabilities
label = []  # List to store true labels

for test_batch in test_ds:
    logits = pred_step(newModel, test_batch)
    ypred.extend(np.ravel(logits))  # Flatten and collect predictions
    label.extend(np.ravel(test_batch["labels"]))  # Flatten and collect true labels

binary_ypred = np.where(np.array(ypred) > 0.5, 1, 0)

accuracy = sum([1 for pred, true in zip(binary_ypred, label) if pred == true]) / len(label)
print(f"Accuracy: {accuracy:.4f}")
