Adds mechanism for calibrating probabilities for category and binary features #1949
Conversation
In the image I can't see the blue points. Do they overlap with the green ones?
That's right: in this example, matrix scaling got rolled back to the original uncalibrated probabilities.
ludwig/models/calibrator.py (outdated):

    self.batch_size = batch_size
    self.skip_save_model = skip_save_model

    def calibration(self, dataset, dataset_name: str, save_path: str):
Not 100% sure, but it seems to me that in some cases "calibration" is the act of calibrating outputs, in some cases it is the object/function that can calibrate, and in other cases it is the act of training the calibrator. We should probably be a bit more explicit to avoid confusion. Wdyt about train_calibrator, get_calibrated_probabilities_from_logits, or something else? Those may be too verbose, but we can brainstorm a bit on this.
I like train_calibrator, that is clear. get_calibrated_probabilities_from_logits is a good name IMO; where are you thinking it should go (which method would be renamed to this)? Right now it happens in the PredictModule, which just calls the forward method of the calibration module.
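To make the discussion concrete, here is a minimal sketch of what a calibration module's forward pass might look like. The class name and signature are my own illustration (not the PR's actual code): the module holds a learned temperature and its forward maps logits to calibrated probabilities.

```python
import numpy as np

class TemperatureCalibration:
    """Hypothetical sketch of a calibration module: forward() maps raw
    logits to calibrated probabilities by dividing by a learned temperature."""

    def __init__(self, temperature: float = 1.0):
        # temperature == 1.0 leaves probabilities unchanged (uncalibrated).
        self.temperature = temperature

    def forward(self, logits: np.ndarray) -> np.ndarray:
        z = logits / self.temperature
        z = z - z.max(axis=-1, keepdims=True)  # numerical stability
        e = np.exp(z)
        return e / e.sum(axis=-1, keepdims=True)
```

A temperature above 1 softens the distribution, which is the usual fix for an overconfident model; the prediction (argmax) is unchanged because scaling by a positive constant preserves the ordering of logits.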
Implements:
- temperature scaling for binary and category outputs
- matrix scaling for category outputs

Example using the Twitter Bots dataset with a categorical output feature:
In this example, matrix scaling does not improve ECE, though it improves NLL (probably due to overfitting), so matrix scaling gets rolled back and we keep the original uncalibrated probabilities. On much larger datasets, matrix scaling may yield better results.
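The core of temperature scaling can be sketched in a few lines of numpy. This is an illustration under my own naming, not the PR's implementation: a single scalar temperature is fit on validation logits to minimize NLL, with grid search standing in for the gradient-based optimization a real implementation would use, and a rollback to the uncalibrated temperature of 1.0 when calibration does not help, mirroring the rollback behavior described above.

```python
import numpy as np

def softmax(logits, temperature=1.0):
    """Softmax with temperature; T > 1 softens, T < 1 sharpens probabilities."""
    z = logits / temperature
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def nll(logits, labels, temperature):
    """Mean negative log-likelihood of the true labels."""
    probs = softmax(logits, temperature)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12))

def fit_temperature(logits, labels, grid=np.linspace(0.1, 5.0, 200)):
    """Pick the temperature minimizing validation NLL (grid search stands in
    for the L-BFGS-style optimization typically used in practice)."""
    best_t = min(grid, key=lambda t: nll(logits, labels, t))
    # Roll back to T = 1.0 (uncalibrated) if calibration does not improve NLL.
    if nll(logits, labels, best_t) >= nll(logits, labels, 1.0):
        return 1.0
    return best_t
```

For an overconfident model (high-confidence logits but imperfect accuracy), the fitted temperature comes out above 1, flattening the predicted distribution toward the model's true accuracy.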
Validation Set (used to determine softmax temperature):
![calibration_1_vs_all_1](https://user-images.githubusercontent.com/687280/168711948-16853a07-16de-4b70-909d-fdbde61312f8.png)
Test set:
![calibration_1_vs_all_1](https://user-images.githubusercontent.com/687280/168711933-82d41bb0-258a-4fea-8453-5c83a140ef3e.png)