<a href="https://colab.research.google.com/github/mohsenh17/jaxLearning/blob/main/flax/Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install --upgrade flax orbax jax treescope



In [2]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

from flax import nnx
import jax
import jax.numpy as jnp


from sklearn.preprocessing import MinMaxScaler

import orbax.checkpoint as ocp

import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import io



# Dataset Class for Custom  Data
This section defines a CustomDataset class, which inherits from torch.utils.data.Dataset. It is designed to load and preprocess data from a CSV file for machine learning tasks.

python
Copy code


In [3]:
class CustomDataset(Dataset):
    # Initialize the dataset by loading, scaling, and organizing features and labels
    def __init__(self, dataset, transform=None, target_transform=None):
        # Load the CSV file into a DataFrame, using the first column as the index
        completeDF = pd.read_csv(dataset, index_col=0)

        # Extract the 'Class' column as labels
        self.labels = pd.DataFrame(completeDF['Class'])

        # Extract all other columns as features, excluding 'Class'
        features_raw = pd.DataFrame(completeDF.drop('Class', axis=1))

        # Scale the feature values to the range [0, 1] using Min-Max Scaling
        scaler = MinMaxScaler()
        self.features = pd.DataFrame(scaler.fit_transform(features_raw))

        # Assign optional transformations for features and labels
        self.transform = transform
        self.target_transform = target_transform

    # Return the total number of data samples in the dataset
    def __len__(self):
        return len(self.labels)

    # Retrieve a specific sample by index and apply transformations if needed
    def __getitem__(self, idx):
        # Extract feature values as a NumPy array for the given index
        features = np.array(self.features.iloc[idx, :])

        # Extract the corresponding label for the given index
        label = self.labels.iloc[idx, 0]

        # Apply transformation to the features, if provided
        if self.transform:
            features = self.transform(features)

        # Apply transformation to the label, if provided
        if self.target_transform:
            label = self.target_transform(label)

        # Return the processed features and label
        return features, label


In [4]:
import kagglehub
path = kagglehub.dataset_download("mssmartypants/rice-type-classification")
print(path)

/root/.cache/kagglehub/datasets/mssmartypants/rice-type-classification/versions/2


In [5]:
! ls /root/.cache/kagglehub/datasets/mssmartypants/rice-type-classification/versions/2/
! mv /root/.cache/kagglehub/datasets/mssmartypants/rice-type-classification/versions/2/riceClassification.csv riceClassification.csv

mv: cannot stat '/root/.cache/kagglehub/datasets/mssmartypants/rice-type-classification/versions/2/riceClassification.csv': No such file or directory


In [6]:
dataset = CustomDataset(dataset="riceClassification.csv")
train_set, val_set, test_set = torch.utils.data.random_split(dataset, [.9, 0.05,0.05])
data_loader = DataLoader(train_set, batch_size=4, shuffle=True)

for features, labels in data_loader:
    print("Batch of features has shape: ",features.shape)
    print("Batch of labels has shape: ", labels.shape)
    print(features)
    print(labels)
    break

Batch of features has shape:  torch.Size([4, 10])
Batch of labels has shape:  torch.Size([4])
tensor([[0.7227, 0.7519, 0.6698, 0.7848, 0.6703, 0.7803, 0.4570, 0.5472, 0.7905,
         0.3856],
        [0.6951, 0.7011, 0.6827, 0.7515, 0.6480, 0.7569, 0.7587, 0.5097, 0.8304,
         0.3448],
        [0.8252, 0.7812, 0.7760, 0.7451, 0.7741, 0.8645, 0.4369, 0.5966, 0.8019,
         0.3376],
        [0.7967, 0.7569, 0.7604, 0.7402, 0.7485, 0.8415, 0.3559, 0.5847, 0.7961,
         0.3322]], dtype=torch.float64)
tensor([0, 0, 0, 0])


# MLP Implementation
This section implements a customizable MLP model with dropout and batch normalization at each layer. The architecture is parameterized by input dimensions, hidden dimensions, and the number of classes

In [7]:
class MLP(nnx.Module):
    def __init__(self, din, dm1, dm2, dm3, num_classes: int,*, rngs: nnx.Rngs):
        # Define dropout, linear, and batch normalization layers for each stage
        init_fn = nnx.initializers.lecun_normal()

        self.dp1 = nnx.Dropout(rate=0.4, rngs=rngs)
        self.linear1 = nnx.Linear(din, dm1,kernel_init=nnx.with_partitioning(init_fn, (None, 'model')), rngs=rngs)
        self.bn1 = nnx.BatchNorm(dm1, rngs=rngs)

        self.dp2 = nnx.Dropout(rate=0.2, rngs=rngs)
        self.linear2 = nnx.Linear(dm1, dm2, rngs=rngs)
        self.bn2 = nnx.BatchNorm(dm2, rngs=rngs)

        self.dp3 = nnx.Dropout(rate=0.1, rngs=rngs)
        self.linear3 = nnx.Linear(dm2, dm3, rngs=rngs)
        self.bn3 = nnx.BatchNorm(dm3, rngs=rngs)

        # Output layer without batch normalization
        self.linear4 = nnx.Linear(dm3, num_classes, rngs=rngs)

    def __call__(self, x):
        # Apply dropout, linear transformation, activation, and batch normalization for each layer
        x = self.dp1(x)
        x = self.linear1(x)
        x = nnx.gelu(x)
        x = self.bn1(x)

        x = self.dp2(x)
        x = self.linear2(x)
        x = nnx.gelu(x)
        x = self.bn2(x)

        x = self.dp3(x)
        x = self.linear3(x)
        x = nnx.gelu(x)
        x = self.bn3(x)

        # Final linear transformation
        x = self.linear4(x)

        # Sigmoid activation for binary classification (uncomment if needed)
        # return nnx.sigmoid(x)
        return x
# Instantiate the model with given dimensions
model = MLP(din=10, dm1=16, dm2=32, dm3=16, num_classes=1, rngs=nnx.Rngs(0))

# Test the model with a sample input tensor of shape (3, 10)
y = model(x=jnp.ones((3, 10)))

# Display the model architecture and output
nnx.display(model)
nnx.display(y)

# Output the predictions
print(y)


[[-0.46538347]
 [-0.11259228]
 [ 0.57797575]]


# Custom Metrics

In [8]:

class CustomMetrics(nnx.metrics.Metric):
    def __init__(self):
        # Initialize counters for true positives, false positives, and false negatives
        self.true_positives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.false_positives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.false_negatives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))

    def update(self, loss, logits, labels):
        """
        Update the metric counters based on the predictions and labels.
        Assumes logits are probabilities (e.g., from a sigmoid activation).
        """
        # Convert logits to binary predictions
        predictions = jnp.where(jnp.array(logits) > 0.5, 1, 0)

        predictions = predictions.ravel()
        labels = jnp.array(labels).ravel()


        # Compute metrics
        tp = jnp.sum((labels == 1) & (predictions == 1))
        fp = jnp.sum((labels == 0) & (predictions == 1))
        fn = jnp.sum((labels == 1) & (predictions == 0))

        # Update counters
        self.true_positives += tp
        self.false_positives += fp
        self.false_negatives += fn

    def compute(self):
        """
        Compute precision, recall, and F1-score from the accumulated counters.
        """
        precision = self.true_positives / (self.true_positives + self.false_positives + 1e-7)
        recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-7)
        f1_score = 2 * (precision * recall) / (precision + recall + 1e-7)
        return {"f1_score": f1_score, "precision": precision, "recall": recall}

    def reset(self):
        """
        Reset the metric counters.
        """
        self.true_positives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.false_positives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.false_negatives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))


In [9]:
class CustomAccuracy(nnx.metrics.Metric):
    def __init__(self):
        self.correct_count= nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.total_count= nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))

    def update(self, loss, logits, labels):
        # Convert logits to binary predictions (0 or 1) based on a 0.5 threshold
        predictions = jnp.where(jnp.array(logits) > 0.5, 1, 0)
        # Flatten if necessary
        predictions = predictions.ravel()
        labels = jnp.array(labels).ravel()

        # Calculate number of correct predictions in the current batch
        self.correct_count += jnp.sum(predictions == labels)
        self.total_count += len(labels)

    def compute(self):
        # Calculate accuracy over all batches seen so far
        if self.total_count == 0:
            return 0  # Avoid division by zero if no samples are seen
        return self.correct_count / self.total_count
    def reset(self):
        # Reset counters
        self.correct_count = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.total_count = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))

## test class

In [10]:
metric = CustomMetrics()
logits = jnp.array([0.9, 0.7, 0.2, 0.4, 0.8, 0.1])
labels = jnp.array([1, 1, 0, 0, 1, 0])

metric.update(None, logits, labels)
metrics = metric.compute()

print(metrics["precision"], 1)
print(metrics["recall"], 1)
print(metrics["f1_score"], 1)

metric.reset()

logits = jnp.array([0.9, 0.7, 0.2, 0.8, 0.1, 0.3])
labels = jnp.array([1, 0, 1, 1, 0, 0])

# Expected binary predictions: [1, 1, 0, 1, 0, 0]
# TP = 2 (Correct positive predictions: [1, 1])
# FP = 1 (False positive prediction: [1])
# FN = 1 (False negative prediction: [1])
metric.update(None, logits, labels)
metrics = metric.compute()

expected_precision = 2 / (2 + 1)  # TP / (TP + FP)
expected_recall = 2 / (2 + 1)     # TP / (TP + FN)
expected_f1 = 2 * (expected_precision * expected_recall) / (expected_precision + expected_recall)

print(metrics["precision"], expected_precision)
print(metrics["recall"], expected_recall)
print(metrics["f1_score"], expected_f1)


1.0 1
1.0 1
1.0 1
0.6666667 0.6666666666666666
0.6666667 0.6666666666666666
0.6666666 0.6666666666666666


# Setting Up an Optimizer and Metrics

In [11]:
# Import Optax for optimizer configuration
import optax

# Learning rate and momentum for the optimizer
learning_rate = 0.005
momentum = 0.9

# Instantiate the MLP model
model = MLP(10, 16, 32, 16, 1, rngs=nnx.Rngs(0))

# Set up the optimizer using Adam with the specified learning rate
optimizer = nnx.Optimizer(model, optax.adam(learning_rate))

# Initialize metrics for tracking training performance
metrics = nnx.MultiMetric(
    accuracy=CustomAccuracy(),
    f1_precision_recall=CustomMetrics(),
    loss=nnx.metrics.Average('loss'),  # Tracks the average loss
)


#  Defining Loss, Training, and Evaluation Functions

In [13]:
class ModelOperations:
    @staticmethod
    def loss_fn(model, batch):
        """
        Compute the binary cross-entropy loss for a given model and batch.

        Args:
            model (MLP): The model used to compute predictions.
            batch (dict): A batch of input data containing 'features' and 'labels'.

        Returns:
            tuple: The computed loss and logits.
        """
        # Forward pass: Compute logits (raw predictions) using the model
        logits = model(batch['features'])

        # Compute binary cross-entropy loss for classification
        loss = optax.sigmoid_binary_cross_entropy(
            logits=logits, labels=batch['labels'].reshape(-1, 1)
        ).mean()

        # Return the computed loss and logits
        return loss, logits

    @staticmethod
    @nnx.jit
    def train_step(model, optimizer, metrics, batch):
        """
        Perform a single training step.

        Args:
            model (MLP): The neural network model.
            optimizer (nnx.Optimizer): The optimizer for updating model parameters.
            metrics (nnx.MultiMetric): The metric tracker for performance monitoring.
            batch (dict): A batch of input data containing 'features' and 'labels'.

        Returns:
            tuple: The computed loss and the sigmoid-transformed logits.
        """
        # Compute loss and gradients using a differentiable function
        grad_fn = nnx.value_and_grad(ModelOperations.loss_fn, has_aux=True)
        (loss, logits), grads = grad_fn(model, batch)

        # Update metrics in-place with loss and logits
        metrics.update(loss=loss, logits=logits, labels=batch['labels'])

        # Apply the computed gradients to update model parameters
        optimizer.update(grads)

        # Return the loss and logits after applying sigmoid for interpretation
        return loss, nnx.sigmoid(logits)

    @staticmethod
    @nnx.jit
    def eval_step(model, metrics, batch):
        """
        Perform a single evaluation step.

        Args:
            model (MLP): The neural network model.
            metrics (nnx.MultiMetric): The metric tracker for performance monitoring.
            batch (dict): A batch of input data containing 'features' and 'labels'.

        Returns:
            float: The computed loss for the batch.
        """
        # Compute loss and logits for the batch
        loss, logits = ModelOperations.loss_fn(model, batch)

        # Update metrics in-place with evaluation results
        metrics.update(loss=loss, logits=logits, labels=batch['labels'])

        # Return the loss for further aggregation or monitoring
        return loss

    @staticmethod
    @nnx.jit
    def pred_step(model, batch):
        """
        Perform a prediction step to compute sigmoid-transformed logits.

        Args:
            model (MLP): The model used to compute predictions.
            batch (dict): A batch of input data containing 'features'.

        Returns:
            jax.numpy.ndarray: The sigmoid-transformed logits.
        """
        logits = model(batch['features'])
        return nnx.sigmoid(logits)


In [14]:
import jax
from orbax.checkpoint.type_handlers import TypeHandler
from orbax.checkpoint.type_handlers import register_type_handler


class ModelCheckpointManager:
    def __init__(self, model_class, model_args, ckpt_dir, learning_rate):
        """
        Initialize the checkpoint manager with the model class, arguments, checkpoint directory, and learning rate.

        Args:
            model_class: The model class to instantiate (e.g., MLP).
            model_args (tuple): Arguments for the model class initialization.
            ckpt_dir (str): Directory of the checkpoint files.
            learning_rate (float): Learning rate for the optimizer.
        """
        self.model_class = model_class
        self.model_args = model_args
        self.ckpt_dir = ckpt_dir
        self.learning_rate = learning_rate

    def process_and_save_model_state(self, model):
        """
        Processes the model state, modifies the PRNG key, and saves the state to a checkpoint directory.

        Args:
            model: The model whose state is being processed.

        Returns:
            prng_key_value: The processed PRNG key value before modification.
        """
        # Retrieve and split the model's state
        _, state = nnx.split(model)

        # Modify the PRNG key for 'dp1'
        prng_key_value = state["dp1"]["rngs"]["default"]["key"].value
        state["dp1"]["rngs"]["default"]["key"] = nnx.VariableState(
            type=nnx.Param,
            value=jax.random.key_data(prng_key_value),
            tag='default'
        )

        # Create a new empty checkpoint directory
        self.ckpt_dir = ocp.test_utils.erase_and_create_empty(self.ckpt_dir)

        # Initialize the checkpointing system and save the state
        checkpointer = ocp.PyTreeCheckpointer()
        checkpointer.save(f'{self.ckpt_dir}/state', state)

        return prng_key_value

    def restore_model_and_initialize_optimizer(self):
        """
        Restores a model's state from a checkpoint and initializes the optimizer and metrics.

        Returns:
            tuple: The restored model, optimizer, and metrics object.
        """
        # Step 1: Re-initialize the model
        new_model = self.model_class(*self.model_args, rngs=nnx.Rngs(0))

        # Step 2: Evaluate the abstract shape of the model and split into graph/state
        abstract_model = nnx.eval_shape(lambda: new_model)
        graph_def, abstract_state = nnx.split(abstract_model)

        # Step 3: Restore the state from checkpoint
        checkpointer = ocp.PyTreeCheckpointer()
        state_restored = checkpointer.restore(f"{self.ckpt_dir}/state", abstract_state)

        # Step 4: Modify the PRNG key value and ensure correct type
        prng_key_value = state_restored["dp1"]["rngs"]["default"]["key"].value
        state_restored["dp1"]["rngs"]["default"]["key"].value = jax.random.wrap_key_data(prng_key_value)

        # Step 5: Merge the graph definition with the restored state
        restored_model = nnx.merge(graph_def, state_restored)

        # Step 6: Initialize the optimizer
        optimizer = nnx.Optimizer(restored_model, optax.adam(self.learning_rate))

        # Step 7: Initialize metrics
        metrics = nnx.MultiMetric(
            loss=nnx.metrics.Average('loss')
        )

        return restored_model, optimizer, metrics

# Example usage
"""
model_args = (10, 16, 32, 16, 1)
ckpt_dir = '/content/my-checkpoints/'  # Specify your checkpoint directory
manager = ModelCheckpointManager(MLP, model_args, ckpt_dir, learning_rate)
model = MLP(*model_args, rngs=nnx.Rngs(0))

# Process and save model state
prng_key_value = manager.process_and_save_model_state(model)

# Restore model and initialize optimizer
restored_model, new_optimizer, new_metrics = manager.restore_model_and_initialize_optimizer()
"""


"\nmodel_args = (10, 16, 32, 16, 1)\nckpt_dir = '/content/my-checkpoints/'  # Specify your checkpoint directory\nmanager = ModelCheckpointManager(MLP, model_args, ckpt_dir, learning_rate)\nmodel = MLP(*model_args, rngs=nnx.Rngs(0))\n\n# Process and save model state\nprng_key_value = manager.process_and_save_model_state(model)\n\n# Restore model and initialize optimizer\nrestored_model, new_optimizer, new_metrics = manager.restore_model_and_initialize_optimizer()\n"

# Dataset Preparation and DataLoader Creation

In [15]:
def custom_collate_fn(batch):
    # Transpose batch to group features and labels separately
    transposed_data = list(zip(*batch))

    # Convert labels and features into NumPy arrays
    labels = np.array(transposed_data[1])
    features = np.array(transposed_data[0])

    # Return a dictionary with features and labels
    return {"features": features, "labels": labels}

train_set, val_set, test_set = torch.utils.data.random_split(dataset, [0.7, 0.1, 0.2])

train_ds = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True, collate_fn=custom_collate_fn)
val_ds = DataLoader(val_set, batch_size=64, shuffle=False, drop_last=True, collate_fn=custom_collate_fn)
test_ds = DataLoader(test_set, batch_size=64, shuffle=False, drop_last=True, collate_fn=custom_collate_fn)

# Fetch a single batch from the training DataLoader
batch_data = next(iter(train_ds))
imgs = batch_data['features']
lbls = batch_data['labels']

# Print shapes and data types for verification
print(imgs.shape, imgs[0].dtype, lbls.shape, lbls[0].dtype)
print(lbls)



(64, 10) float64 (64,) int64
[0 1 1 1 0 0 1 0 0 1 1 1 1 0 1 0 0 1 0 0 0 1 0 1 1 0 1 1 1 1 1 1 1 1 0 1 1
 1 0 0 0 1 0 0 0 1 0 0 0 1 1 1 0 1 1 0 0 1 1 0 1 0 1 1]


# Checkpoint! save model

# Training Loop for Neural Network with Metrics Tracking

In [17]:
import jax

import numpy as np

class TrainerWithEarlyStopping:
    def __init__(self, model, optimizer, train_ds, val_ds, metrics, num_epochs, tracking_metric, patience, manager, mode='min'):
        """
        Initializes the trainer with necessary parameters.

        Args:
            model: The model to train and evaluate.
            optimizer: Optimizer for training the model.
            train_ds: Training dataset.
            val_ds: Validation dataset.
            metrics: Metrics object for tracking performance.
            num_epochs (int): Maximum number of epochs to train.
            tracking_metric (str): Metric to track for early stopping.
            patience (int): Number of epochs to wait without improvement before stopping.
            manager: Object managing checkpointing and state saving.
            mode (str): 'min' for metrics to minimize, 'max' for metrics to maximize.
        """
        self.model = model
        self.optimizer = optimizer
        self.train_ds = train_ds
        self.val_ds = val_ds
        self.metrics = metrics
        self.num_epochs = num_epochs
        self.tracking_metric = tracking_metric
        self.patience = patience
        self.manager = manager
        self.mode = mode
        self.metrics_history = {
            'train_loss': [], 'train_accuracy': [], 'train_precision': [],
            'train_recall': [], 'train_f1_score': [], 'val_loss': [],
            'val_accuracy': [], 'val_precision': [], 'val_recall': [],
            'val_f1_score': []
        }
        self.best_value = float('inf') if mode == 'min' else float('-inf')
        self.best_epoch = 0
        self.prng_key_value = None

    def update_metrics_history(self, prefix):
        """
        Updates the metrics history with current metrics.

        Args:
            prefix (str): Prefix for metric keys (e.g., 'train', 'val').
        """
        for metric, value in self.metrics.compute().items():
            if metric == 'f1_precision_recall':
                self.metrics_history[f'{prefix}_f1_score'].append(value['f1_score'])
                self.metrics_history[f'{prefix}_precision'].append(value['precision'])
                self.metrics_history[f'{prefix}_recall'].append(value['recall'])
            else:
                self.metrics_history[f'{prefix}_{metric}'].append(value)

    def log_metrics(self, epoch, prefix):
        """
        Logs metrics for the specified phase (train, val).

        Args:
            epoch (int): Current epoch number.
            prefix (str): Prefix for metric keys (e.g., 'train', 'val').
        """
        print(
            f"[{prefix}] epoch: {epoch + 1}/{self.num_epochs}, "
            f"loss: {self.metrics_history[f'{prefix}_loss'][-1]:.4f}, "
            f"accuracy: {self.metrics_history[f'{prefix}_accuracy'][-1]:.4f}, "
            f"precision: {self.metrics_history[f'{prefix}_precision'][-1]:.4f}, "
            f"recall: {self.metrics_history[f'{prefix}_recall'][-1]:.4f}, "
            f"f1_score: {self.metrics_history[f'{prefix}_f1_score'][-1]:.4f}"
        )

    def train_and_evaluate(self, trainer):
        """
        Trains and evaluates the model, implementing early stopping and model saving.

        Args:
            trainer: Trainer object with `train_step` and `eval_step` methods.

        Returns:
            dict: Metrics history.
            float: Best value for the tracking metric.
            any: Final PRNG key value after training.
        """
        for epoch in range(self.num_epochs):
            # Training phase
            for batch in self.train_ds:
                trainer.train_step(self.model, self.optimizer, self.metrics, batch)
            self.update_metrics_history('train')
            self.metrics.reset()
            self.log_metrics(epoch, 'train')

            # Validation phase
            for batch in self.val_ds:
                trainer.eval_step(self.model, self.metrics, batch)
            self.update_metrics_history('val')
            self.metrics.reset()
            self.log_metrics(epoch, 'val')

            # Early stopping and model saving
            current_value = self.metrics_history[f'val_{self.tracking_metric}'][-1]
            if (self.mode == 'min' and current_value < self.best_value) or (self.mode == 'max' and current_value > self.best_value):
                self.best_value = current_value
                self.best_epoch = epoch
                print(f"Epoch {epoch + 1}: Best {self.tracking_metric} {self.best_value:.4f}. Saving model...")
                self.prng_key_value = self.manager.process_and_save_model_state(self.model)
            elif epoch - self.best_epoch >= self.patience:
                print("Early stopping triggered.")
                break

        return self.metrics_history, self.best_value, self.prng_key_value

# Example usage
num_epochs = 50
tracking_metric = 'val_accuracy'
patience = 3
mode = 'max'
ModelOperations = ModelOperations()


model_args = (10, 16, 32, 16, 1)
ckpt_dir = '/content/my-checkpoints/'  # Specify your checkpoint directory

manager = ModelCheckpointManager(MLP, model_args, ckpt_dir, learning_rate)

trainer = TrainerWithEarlyStopping(
    model=model,
    optimizer=optimizer,
    train_ds=train_ds,
    val_ds=val_ds,
    metrics=metrics,
    num_epochs=50,
    tracking_metric='loss',
    patience=5,
    manager=manager,
    mode='min'
)
metrics_history, best_value, prng_key_value = trainer.train_and_evaluate(ModelOperations)
#metrics_history, state_org, prng_key_value = train_and_evaluate(model, optimizer, train_ds, val_ds, metrics, num_epochs, tracking_metric, patience, trainer, mode=mode)


[train] epoch: 1/50, loss: 0.3698, accuracy: 0.8116, precision: 0.8583, recall: 0.7865, f1_score: 0.8208
[val] epoch: 1/50, loss: 0.2205, accuracy: 0.9029, precision: 0.9352, recall: 0.8862, f1_score: 0.9100
Epoch 1: Best loss 0.2205. Saving model...
[train] epoch: 2/50, loss: 0.1840, accuracy: 0.9233, precision: 0.9406, recall: 0.9184, f1_score: 0.9294
[val] epoch: 2/50, loss: 0.1418, accuracy: 0.9375, precision: 0.9593, recall: 0.9265, f1_score: 0.9426
Epoch 2: Best loss 0.1418. Saving model...
[train] epoch: 3/50, loss: 0.1547, accuracy: 0.9350, precision: 0.9494, recall: 0.9311, f1_score: 0.9402
[val] epoch: 3/50, loss: 0.1300, accuracy: 0.9481, precision: 0.9527, recall: 0.9537, f1_score: 0.9532
Epoch 3: Best loss 0.1300. Saving model...
[train] epoch: 4/50, loss: 0.1371, accuracy: 0.9455, precision: 0.9560, recall: 0.9442, f1_score: 0.9501
[val] epoch: 4/50, loss: 0.1212, accuracy: 0.9425, precision: 0.9744, recall: 0.9204, f1_score: 0.9467
Epoch 4: Best loss 0.1212. Saving model

# Prediction and Evaluation Pipeline

In [19]:
model.eval()

test_ds = DataLoader(test_set, batch_size=32, shuffle=False, drop_last=True, collate_fn=custom_collate_fn)

ypred = []  # List to store predicted probabilities
label = []  # List to store true labels

for test_batch in test_ds:
    logits = ModelOperations.pred_step(model, test_batch)
    ypred.extend(np.ravel(logits))  # Flatten and collect predictions
    label.extend(np.ravel(test_batch["labels"]))  # Flatten and collect true labels

binary_ypred = np.where(np.array(ypred) > 0.5, 1, 0)

accuracy = sum([1 for pred, true in zip(binary_ypred, label) if pred == true]) / len(label)
print(f"Accuracy: {accuracy:.4f}")


Accuracy: 0.7848
