Will this change the current api? How?
Maybe. To do this cleanly, Keras may need to support metrics whose values are non-scalar. That could also open the door to logging other kinds of data, e.g. images and audio, in which case a name other than "metric" might be more appropriate.
Who will benefit from this feature?
Anyone training classifier models using TensorFlow and TensorBoard.
Basically, I had to generate "metrics" that are actually tensors holding the curve data, then override the TensorBoard callback to handle those metrics specially. For now, the curve metrics are simply matched by name. I also added the ability to pull in an AUC metric by name and display it in the description on the graph.
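To make the idea of a "metric" whose value is curve data concrete, here is a minimal, framework-free sketch in plain Python. It is not the gist's implementation: `CurveMetricSketch` and `auc_style_thresholds` are hypothetical names, the class only mimics the `update_state`/`result`/`reset_state` shape of a Keras metric, and the threshold spacing and strict "score > threshold" rule reflect our understanding of `tf.keras.metrics.AUC`, not anything confirmed in this issue.

```python
from typing import List, Sequence


def auc_style_thresholds(num_thresholds: int = 200, epsilon: float = 1e-7) -> List[float]:
    """Thresholds spaced like tf.keras.metrics.AUC's (our understanding): evenly
    spaced interior points plus end points nudged just outside [0, 1]."""
    if num_thresholds < 2:
        raise ValueError("num_thresholds must be at least 2")
    interior = [(i + 1) / (num_thresholds - 1) for i in range(num_thresholds - 2)]
    return [0.0 - epsilon] + interior + [1.0 + epsilon]


def _safe_div(num: float, den: float) -> float:
    # Report undefined ratios (0/0) as 0.0, a convention chosen for this sketch.
    return num / den if den else 0.0


class CurveMetricSketch:
    """Hypothetical Keras-style metric whose result is a curve, not a scalar."""

    def __init__(self, num_thresholds: int = 200):
        self.thresholds = auc_style_thresholds(num_thresholds)
        self.reset_state()

    def reset_state(self) -> None:
        # One confusion-matrix cell per threshold, accumulated across batches.
        n = len(self.thresholds)
        self.tp, self.fp, self.tn, self.fn = [0] * n, [0] * n, [0] * n, [0] * n

    def update_state(self, y_true: Sequence[int], y_pred: Sequence[float]) -> None:
        # A score counts as a positive prediction when it is strictly above the threshold.
        for label, score in zip(y_true, y_pred):
            for i, threshold in enumerate(self.thresholds):
                predicted_positive = score > threshold
                if label and predicted_positive:
                    self.tp[i] += 1
                elif label:
                    self.fn[i] += 1
                elif predicted_positive:
                    self.fp[i] += 1
                else:
                    self.tn[i] += 1

    def result(self) -> List[List[float]]:
        # Non-scalar result: one [precision, recall] pair per threshold.
        return [
            [_safe_div(tp, tp + fp), _safe_div(tp, tp + fn)]
            for tp, fp, fn in zip(self.tp, self.fp, self.fn)
        ]
```

For example, after `m = CurveMetricSketch(num_thresholds=3)` and `m.update_state([1, 0, 1, 0], [0.9, 0.8, 0.3, 0.1])`, `m.result()` returns `[[0.5, 1.0], [0.5, 0.5], [0.0, 0.0]]`: a whole curve rather than the single number that scalar-only logging code expects.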
Rough example of usage:
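```python
import tensorflow as tf
from keras_tensorboard_pr_curves import PRCurve, ROCCurve, TensorBoardPRCurves  # proof-of-concept module from the gist

n_thresholds = 200  # threshold count passed to the AUC metrics

metrics = [
    # Scalar AUC metrics; their values are shown in the curves' descriptions.
    tf.keras.metrics.AUC(name='AUC', curve='ROC', num_thresholds=n_thresholds),
    tf.keras.metrics.AUC(name='AUPR', curve='PR', num_thresholds=n_thresholds),
    # Non-scalar "metrics" whose values are the curve data.
    PRCurve(name='pr_curve'),
    ROCCurve(name='roc_curve'),
]

# The callback is told by name which metrics hold curves and which AUC
# metrics to display in the curve descriptions.
callbacks = [TensorBoardPRCurves(log_dir='/logs',
                                 pr_curve_names=['pr_curve', 'roc_curve'],
                                 auc_names=['AUPR', 'AUC'])]

# Set up the model as usual:
# model = ...
# model.compile(..., metrics=metrics)
# model.fit(..., callbacks=callbacks)
```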
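The usage example passes parallel lists to the callback. Assuming, as the example suggests but we have not checked against the gist, that the i-th entry of `auc_names` belongs to the i-th entry of `pr_curve_names`, the description lookup could work like the hypothetical `curve_descriptions` helper below (plain Python, no TensorFlow).

```python
from typing import Dict, Mapping, Sequence


def curve_descriptions(
    logs: Mapping[str, float],
    curve_names: Sequence[str],
    auc_names: Sequence[str],
) -> Dict[str, str]:
    """Build a text description for each curve from its paired AUC value.

    Hypothetical helper: assumes the two name lists are paired by position.
    """
    if len(curve_names) != len(auc_names):
        raise ValueError("curve_names and auc_names must have the same length")
    descriptions: Dict[str, str] = {}
    for curve_name, auc_name in zip(curve_names, auc_names):
        auc = logs.get(auc_name)
        # Leave the description empty if the paired AUC metric was not logged.
        descriptions[curve_name] = f"{auc_name}: {auc:.4f}" if auc is not None else ""
    return descriptions
```

Checking the list lengths up front turns a silent mis-pairing into an immediate error, which is in the same spirit as the name-matching caveat about `pr_curve_names`.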
Limitations:
- Some of the threshold logic was duplicated from the `AUC` metric implementation and might be better factored out as shared code.
- At the moment this implementation only generates curves once per epoch. For completeness it should probably also respect the `update_freq` parameter, with the caveat that the first few batches will likely produce a very erratic curve shape until the confusion matrix has tallied more examples.
- If `pr_curve_names` is set incorrectly in the callback, the tensors produced by the `PRCurve`/`ROCCurve` metrics may cause unexpected errors in other code that expects scalar metrics.
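Matching curve metrics by name, and the failure mode in the last point, can be illustrated with a small framework-free helper. `split_logs` is hypothetical (not part of Keras or the gist); it assumes the callback receives a `logs` dict like the one Keras passes to `on_epoch_end`, and it raises a descriptive error as soon as an unlisted entry turns out to be non-scalar, instead of letting scalar-only code fail later.

```python
import numbers
from typing import Dict, Iterable, Tuple

VAL_PREFIX = "val_"


def split_logs(
    logs: Dict[str, object], curve_names: Iterable[str]
) -> Tuple[Dict[str, float], Dict[str, object]]:
    """Separate curve-valued log entries (matched by name) from scalar ones."""
    curve_names = set(curve_names)
    scalars: Dict[str, float] = {}
    curves: Dict[str, object] = {}
    for name, value in logs.items():
        # Keras prefixes validation metrics with "val_"; match on the base name.
        base = name[len(VAL_PREFIX):] if name.startswith(VAL_PREFIX) else name
        if base in curve_names:
            curves[name] = value
        elif isinstance(value, numbers.Real):
            scalars[name] = float(value)
        else:
            # Fail early and clearly rather than inside scalar-only code.
            raise TypeError(
                f"log entry {name!r} is not a scalar; "
                "is it missing from the curve names?"
            )
    return scalars, curves
```

Stripping the `val_` prefix lets one entry in the names list cover both the training and validation versions of a curve; whether the gist does the same has not been verified.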
We're currently in the process of migrating the new Keras 3 code base from keras-team/keras-core to keras-team/keras.
Consequently, this issue may not be relevant to the Keras 3 code base. After the migration is successfully completed, feel free to reopen this issue at keras-team/keras if you believe it remains relevant to the Keras 3 code base. If instead this issue is a bug or security issue in legacy tf.keras, you can report a new issue at keras-team/tf-keras, which hosts the TensorFlow-only, legacy version of Keras.
System information.
TensorFlow version (you are using): 2.8.0
Are you willing to contribute it (Yes/No): Yes
Describe the feature and the current behavior/state.
TensorBoard has a nifty plugin that can plot precision-recall curves for classifiers; it would be great if Keras supported this out of the box.
https://github.com/tensorflow/tensorboard/tree/master/tensorboard/plugins/pr_curve
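For reference, the plugin is fed raw per-threshold data rather than a finished plot. As far as we understand its summary module (`tensorboard/plugins/pr_curve/summary.py`), each curve is stored as a 6-by-num_thresholds tensor whose rows are true positives, false positives, true negatives, false negatives, precision and recall. The hypothetical `pack_pr_curve_rows` below builds that layout in plain Python from per-threshold counts; the row order and the small-denominator clamp are our assumptions and should be checked against your TensorBoard version.

```python
from typing import List, Sequence

# Assumed row order of the PR-curve plugin's stored tensor (our reading of
# tensorboard/plugins/pr_curve/summary.py); verify against your version.
ROW_ORDER = ("true_positives", "false_positives", "true_negatives",
             "false_negatives", "precision", "recall")


def pack_pr_curve_rows(
    tp: Sequence[float],
    fp: Sequence[float],
    tn: Sequence[float],
    fn: Sequence[float],
    min_count: float = 1e-7,
) -> List[List[float]]:
    """Stack per-threshold counts plus derived precision/recall into 6 rows."""
    n = len(tp)
    if not (len(fp) == len(tn) == len(fn) == n):
        raise ValueError("all count sequences must have one entry per threshold")
    # Clamp denominators to a small positive value so thresholds with no
    # predicted (or actual) positives yield 0.0 instead of a division error.
    precision = [t / max(min_count, t + f) for t, f in zip(tp, fp)]
    recall = [t / max(min_count, t + m) for t, m in zip(tp, fn)]
    counts = [[float(x) for x in row] for row in (tp, fp, tn, fn)]
    return counts + [precision, recall]
```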
Contributing
Yes (I may need assistance if this involves significant internal changes to Keras though)
Proof-of-concept implementation here:
https://gist.github.com/edbordin/8fe7398af2495d13d81164db885a3531