
For scripts/setfit/run_fewshot.py, add warning for class imbalance w. accuracy #204

Merged
2 commits merged from enhancement/warn_class_imbalance into huggingface:main on Dec 12, 2022

Conversation

tomaarsen (Member):

Hello!

Pull request overview

  • For scripts/setfit/run_fewshot.py, add a warning when the accuracy metric is used on an imbalanced test set.

Details

This PR is in response to #203, a PR in which AmazonCF was accidentally evaluated using accuracy rather than MCC. If that script were run again in the same way after this PR, the following warning would pop up:

...
Test set: 5000
[sic]\setfit\scripts\setfit\run_fewshot.py:112: UserWarning: The test set has a class imbalance (label 1 w. 503 samples, label 0 w. 4497 samples), but is evaluated using `accuracy`, which may lead to an evaluation that does not correspond with true model performance.
  warnings.warn(
...

This would help people like me realise that accuracy is used by default in scripts/setfit/run_fewshot.py, and that not all SetFit test sets are balanced.

Note that this is only for scripts/setfit/run_fewshot.py, i.e. only for development and testing purposes. I don't believe that we should bother actual users with this warning.

The warning is shown when the largest class is at least 50% larger than the smallest class.
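
As a quick sanity check against the numbers in the warning above (a toy snippet for illustration, not code from this PR):

    counts = {0: 4497, 1: 503}             # label -> number of test samples
    largest, smallest = max(counts.values()), min(counts.values())
    assert largest > smallest * 1.5        # 4497 > 754.5, so the warning fires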

cc: @blakechi

  • Tom Aarsen

@@ -99,6 +101,19 @@ def main():
    for dataset, metric in dataset_to_metric.items():
        few_shot_train_splits, test_data = load_data_splits(dataset, args.sample_sizes, args.add_data_augmentation)

        # Potentially report on naive use of accuracy with an imbalanced test set
        if metric == "accuracy":
blakechi (Contributor):
Do you think it is a good idea to throw a warning whenever there is an unbalanced dataset?

Personally, I would like to know about any unbalanced dataset and its corresponding metric. Specifically, maybe we can remove this if statement to allow throwing warnings on datasets that are unbalanced but use metrics other than "accuracy"?

            if largest_n_samples > smallest_n_samples * 1.5:
                warnings.warn(
                    f"The test set has a class imbalance ({', '.join(f'label {label} w. {n_samples} samples' for label, n_samples in label_samples)})"
                    ", but is evaluated using `accuracy`, which may lead to an evaluation that does not correspond with true model performance.",
blakechi (Contributor):
Same as above, maybe we can print this out no matter which metric the dataset uses? Like:

Suggested change
                    ", but is evaluated using `accuracy`, which may lead to an evaluation that does not correspond with true model performance.",
                    f", but is evaluated using `{metric}`, which may lead to an evaluation that does not correspond with true model performance.",

tomaarsen (Member, Author):
I see; I think those are good ideas. It would be more like a `logging.info` call than a `warnings.warn`, though, as it would just be informative at all times.

blakechi (Contributor):
Maybe it could stay a warning, since it points out that there is an unbalanced dataset?
Just sharing thoughts :)

tomaarsen (Member, Author):
I love the thoughts! My bad, I thought you proposed to always print the test set distribution, even when balanced.

blakechi (Contributor):
Sorry, forgive my limited language skills 😂

blakechi (Contributor):

Hi @tomaarsen,

Nice work! I left some comments; I'd like to hear your thoughts :)

Co-authored-by: blakechi <blakechi.chiaohu@gmail.com>
tomaarsen (Member, Author):

After implementing the changes proposed by @blakechi, the output is now as follows:

> python .\scripts\setfit\run_fewshot.py --datasets amazon_counterfactual_en


============== amazon_counterfactual_en ============
Using custom data configuration SetFit--amazon_counterfactual_en-83c4b3502f70d59a
Reusing dataset json ([sic])
Using custom data configuration SetFit--amazon_counterfactual_en-83c4b3502f70d59a
Reusing dataset json ([sic])
Test set: 5000
Evaluating amazon_counterfactual_en using 'accuracy'.
[sic]\scripts\setfit\run_fewshot.py:112: UserWarning: The test set has a class imbalance (label 1 w. 503 samples, label 0 w. 4497 samples).
  warnings.warn(


======== [sic]\scripts\setfit\results\paraphrase-mpnet-base-v2-CosineSimilarityLoss-logistic_regression-iterations_20-batch_16\amazon_counterfactual_en\train-2-0 =======
config.json not found in HuggingFace Hub
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
...
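
The output above corresponds roughly to the following sketch (my own reconstruction under assumptions, not the exact merged code; the variable names and the toy label list are stand-ins for the loop variables and `test_data["label"]` in run_fewshot.py):

    import warnings
    from collections import Counter

    # Assumed stand-ins for the values available inside the dataset loop
    dataset, metric = "amazon_counterfactual_en", "accuracy"
    test_labels = [1] * 503 + [0] * 4497   # toy stand-in for test_data["label"]

    # Always report which metric is used, regardless of balance
    print(f"Evaluating {dataset} using '{metric}'.")

    # Sort label counts ascending so the smallest class is listed first,
    # e.g. "label 1 w. 503 samples, label 0 w. 4497 samples"
    label_samples = sorted(Counter(test_labels).items(), key=lambda item: item[1])
    smallest_n_samples = label_samples[0][1]
    largest_n_samples = label_samples[-1][1]
    if largest_n_samples > smallest_n_samples * 1.5:
        warnings.warn(
            "The test set has a class imbalance "
            f"({', '.join(f'label {label} w. {n_samples} samples' for label, n_samples in label_samples)})."
        )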

Would love to hear your thoughts.

  • Tom Aarsen

blakechi (Contributor) left a comment:

LGTM! :)

lewtun (Member) left a comment:

Thanks a lot for adding this warning @tomaarsen 🔥!

We originally wrote the few-shot scripts for our paper's experiments, but adding this warning is indeed helpful for general use - thanks!

lewtun merged commit acb99fd into huggingface:main on Dec 12, 2022
tomaarsen deleted the enhancement/warn_class_imbalance branch on December 12, 2022, 20:52