1. Rename the old BinaryAUC metric to BinarySparseAUC (used by text_classifier) and create a new BinaryAUC metric which does not expect sparse inputs.

2. Add AUC, precision, and recall metrics to the binary classification and multi-label classification settings. Add three new hparams, `desired_precisions`, `desired_recalls`, and `desired_thresholds`, for the precision and recall metrics. Keep the existing behavior by defaulting `desired_thresholds` to [0.25, 0.5, 0.75].
3. Add a best-model checkpointing mechanism (see the sketch after this list).
4. Add `load_model` and `save_model` APIs.
5. Put `create_model`, the metrics, and the optimizer onto the TPU strategy.
6. Add `**kwargs` to the `evaluate()` method to allow calls such as `evaluate(return_dict=True)`.
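Point 3's implementation isn't among the hunks shown below. As a rough sketch only: best-model checkpointing in Keras is typically wired through `tf.keras.callbacks.ModelCheckpoint`, keyed on the new `best_model_metric_name` hparam (the file path and metric name here are hypothetical, not this commit's code):

```python
import tensorflow as tf

# Hypothetical sketch, not this commit's code: keep only the checkpoint
# with the best monitored validation metric.
best_model_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='/tmp/best_model',   # hypothetical save location
    monitor='val_auc',            # e.g. hparams.best_model_metric_name
    mode='max',                   # higher AUC is better
    save_best_only=True,
)
# The callback would then be appended to the callbacks passed to model.fit().
```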

PiperOrigin-RevId: 635992287
MediaPipe Team authored and copybara-github committed May 22, 2024
1 parent 40f9bf4 commit b41777d
Showing 8 changed files with 239 additions and 59 deletions.
10 changes: 8 additions & 2 deletions mediapipe/model_maker/python/core/tasks/classifier.py
@@ -129,19 +129,25 @@ def _train_model(
         class_weight=self._hparams.class_weights,
     )

-  def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
+  def evaluate(
+      self,
+      data: dataset.Dataset,
+      batch_size: int = 32,
+      **kwargs: dict[str, Any],
+  ) -> Any:
     """Evaluates the classifier with the provided evaluation dataset.

     Args:
       data: Evaluation dataset
       batch_size: Number of samples per evaluation step.
+      **kwargs: Additional arguments to pass to `model.evaluate`.

     Returns:
       The loss value and accuracy.
     """
     ds = data.gen_tf_dataset(
         batch_size, is_training=False, preprocess=self._preprocess)
-    return self._model.evaluate(ds)
+    return self._model.evaluate(ds, **kwargs)

   def export_labels(self, export_dir: str, label_filename: str = 'labels.txt'):
     """Exports classification labels into a label file.
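With the `**kwargs` pass-through above, callers can forward any argument accepted by `tf.keras.Model.evaluate`, such as `return_dict=True` from point 6 of the commit message. A usage sketch (`classifier` and `validation_data` are hypothetical instances, not defined in this diff):

```python
# return_dict=True makes Keras return {'loss': ..., 'auc': ...} instead of
# a bare list of scalars, so metrics can be read by name.
results = classifier.evaluate(validation_data, batch_size=32, return_dict=True)
print(results['auc'])
```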
25 changes: 24 additions & 1 deletion mediapipe/model_maker/python/core/utils/metrics.py
@@ -95,7 +95,30 @@ def update_state(self, y_true, y_pred, sample_weight=None):


 class BinaryAUC(tf.keras.metrics.AUC):
-  """A Binary AUC metric for binary classification tasks.
+  """A Binary AUC metric for multi-label tasks.
+
+  class_id is the index of the class/label that we want to compute Binary AUC
+  for.
+
+  For update state, the shapes of y_true and y_pred are expected to be:
+    - y_true: [batch_size x num_classes] array of one-hot encoded labels (note,
+      these could be in a multi-label setting where the sum of y_true can be > 1)
+    - y_pred: [batch_size x num_classes] array of probabilities where
+      y_pred[:,i] is the probability of the i-th class.
+  """
+
+  def __init__(self, *args, class_id: int = 1, **kwargs):
+    super().__init__(*args, **kwargs)
+    self._class_id = class_id
+
+  def update_state(self, y_true, y_pred, sample_weight=None):
+    super().update_state(
+        y_true[:, self._class_id], y_pred[:, self._class_id], sample_weight
+    )
+
+
+class BinarySparseAUC(tf.keras.metrics.AUC):
+  """A Binary Sparse AUC metric for binary classification tasks.

   For update state, the shapes of y_true and y_pred are expected to be:
     - y_true: [batch_size x 1] array of 0 for negatives and 1 for positives
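Per the two docstrings above, the practical difference between the classes is the label format each expects. An illustrative sketch (the data values are made up, not taken from the repo's tests):

```python
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import metrics

y_true_sparse = tf.constant([0, 0, 1, 1])     # integer class ids
y_true_onehot = tf.one_hot(y_true_sparse, 2)  # [4, 2] one-hot rows
y_pred = tf.constant([[0.9, 0.1], [0.8, 0.2], [0.4, 0.6], [0.3, 0.7]])

# New BinaryAUC: dense one-hot/multi-label y_true; class_id picks the column.
dense_auc = metrics.BinaryAUC(num_thresholds=1000, class_id=1)
dense_auc.update_state(y_true_onehot, y_pred)

# Renamed BinarySparseAUC: sparse integer y_true, as text_classifier provides.
sparse_auc = metrics.BinarySparseAUC(num_thresholds=1000)
sparse_auc.update_state(y_true_sparse, y_pred)
```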
14 changes: 11 additions & 3 deletions mediapipe/model_maker/python/core/utils/metrics_test.py
@@ -34,8 +34,12 @@ def setUp(self):
         [0.3, 0.7],  # 1, 1 y
     ])

-  def _assert_metric_equals(self, metric, value):
-    metric.update_state(self.y_true, self.y_pred)
+  def _assert_metric_equals(self, metric, value, sparse=True):
+    if not sparse:
+      y_true = tf.one_hot(self.y_true, 2)
+      metric.update_state(y_true, self.y_pred)
+    else:
+      metric.update_state(self.y_true, self.y_pred)
     self.assertEqual(metric.result(), value)

   def test_sparse_recall(self):
@@ -70,7 +74,11 @@ def test_binary_sparse_precision_at_recall_class_id_error(self):
     _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2)

   def test_binary_auc(self):
-    metric = metrics.BinaryAUC(num_thresholds=1000)
+    metric = metrics.BinaryAUC(num_thresholds=1000, class_id=1)
+    self._assert_metric_equals(metric, 0.7222222, sparse=False)
+
+  def test_binary_sparse_auc(self):
+    metric = metrics.BinarySparseAUC(num_thresholds=1000)
     self._assert_metric_equals(metric, 0.7222222)


@@ -545,7 +545,7 @@ def _create_metrics(self):
     ]
     if self._num_classes == 2:
       metric_functions.extend([
-          metrics.BinaryAUC(name="auc", num_thresholds=1000),
+          metrics.BinarySparseAUC(name="auc", num_thresholds=1000),
           metrics.SparsePrecision(name="precision", dtype=tf.float32),
           metrics.SparseRecall(name="recall", dtype=tf.float32),
       ])
1 change: 1 addition & 0 deletions mediapipe/model_maker/python/vision/image_classifier/BUILD
@@ -102,6 +102,7 @@ py_library(
         ":model_spec",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
         "//mediapipe/model_maker/python/core/tasks:classifier",
+        "//mediapipe/model_maker/python/core/utils:metrics",
         "//mediapipe/model_maker/python/core/utils:model_util",
         "//mediapipe/model_maker/python/core/utils:quantization",
         "//mediapipe/model_maker/python/vision/core:image_preprocessing",
@@ -14,6 +14,7 @@
"""Hyperparameters for training image classification models."""

import dataclasses
from typing import Optional, Sequence

from mediapipe.model_maker.python.core import hyperparameters as hp

@@ -42,6 +43,19 @@ class HParams(hp.BaseHParams):
     checkpoint_frequency: Frequency to save checkpoint.
     one_hot: Whether the label data is score input or one-hot.
     multi_labels: Whether the model predicts multiple labels.
+    desired_precisions: If specified, adds a RecallAtPrecision metric per
+      desired_precisions[i] entry which tracks the recall given the constraint
+      on precision. Only supported for binary and multi-label classification.
+    desired_recalls: If specified, adds a PrecisionAtRecall metric per
+      desired_recalls[i] entry which tracks the precision given the constraint
+      on recall. Only supported for binary and multi-label classification.
+    desired_thresholds: If specified, adds a Precision and Recall metric per
+      desired_thresholds[i] entry which tracks the precision and recall given
+      the constraint on threshold. Only supported for binary and multi-label
+      classification.
+    best_model_metric_name: If specified, adds a callback that saves the model
+      with the best `best_model_metric_name` metric during training. Typically
+      these will be validation metrics such as `val_accuracy` and `val_auc`.
   """
   # Parameters from BaseHParams class.
   learning_rate: float = 0.001
@@ -59,3 +73,8 @@ class HParams(hp.BaseHParams):
   checkpoint_frequency: int = 1
   one_hot: bool = True
   multi_labels: bool = False
+  # Binary only precision/recalls
+  desired_precisions: Sequence[float] = dataclasses.field(default_factory=list)
+  desired_recalls: Sequence[float] = dataclasses.field(default_factory=list)
+  desired_thresholds: Sequence[float] = (0.25, 0.5, 0.75)
+  best_model_metric_name: Optional[str] = None
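The code that consumes these hparams isn't in the hunks shown on this page. A plausible sketch, under the assumption that each `desired_*` entry maps onto the stock Keras metric named in the docstring above (the function name and metric-naming scheme are invented for illustration):

```python
import tensorflow as tf

def build_threshold_metrics(hparams):
  """Hypothetical expansion of the desired_* hparams into Keras metrics."""
  extra = []
  for precision in hparams.desired_precisions:
    # Recall at the operating point where precision meets the target.
    extra.append(tf.keras.metrics.RecallAtPrecision(
        precision, name=f'recall_at_precision_{precision}'))
  for recall in hparams.desired_recalls:
    extra.append(tf.keras.metrics.PrecisionAtRecall(
        recall, name=f'precision_at_recall_{recall}'))
  for threshold in hparams.desired_thresholds:  # default: (0.25, 0.5, 0.75)
    extra.append(tf.keras.metrics.Precision(
        thresholds=threshold, name=f'precision_at_{threshold}'))
    extra.append(tf.keras.metrics.Recall(
        thresholds=threshold, name=f'recall_at_{threshold}'))
  return extra
```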
