Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to find probability for each class in case of multi-class classification #693

Merged

Conversation

amansrivastava17
Copy link
Contributor

@amansrivastava17 amansrivastava17 commented Apr 26, 2019

Added feature to get the confidence score of each class in case of Multi-class text classification
Added new param multi_class_prob to function predict to return the confidence score of all classes instead of just returning class with the highest confidence.

Example to use this method

>>> from flair.models import TextClassifier
>>> from flair.data import Sentence

# Loading English pretrained sentiment classification model
>>> classifier = TextClassifier.load('en-sentiment')

>>> sentence = Sentence('I have added a new feature in flair')
>>> classifier.predict(sentence)
>>> print('Sentiment Score: ', sentence.labels)
Sentiment Score:  [POSITIVE (0.9674275517463684)]

# Now if you want confidence score for both the class, you need to pass `multi_class_prob` as `True` to predict method

>>> classifier.predict(sentence, multi_class_prob=True)
>>> print('Sentiment Score: ', sentence.labels)
Sentiment Score:  [POSITIVE (0.9674275517463684), NEGATIVE (0.032572414726018906)]

@amansrivastava17
Copy link
Contributor Author

amansrivastava17 commented Apr 26, 2019

@alanakbik @tabergma @kashif @pranjalsrajput @khituras @stefan-it
Please have a look

@amansrivastava17
Copy link
Contributor Author

@alanakbik Please have a look.

@alanakbik
Copy link
Collaborator

@amansrivastava17 thanks for adding this - lots of people will find this useful. Just to clarify: this only modifies the private function _obtain_labels so most users that just use the predict function will not be able to set a flag to get these probabilities, correct? Could you paste a quick code example of how to use this functionality?

@amansrivastava17
Copy link
Contributor Author

amansrivastava17 commented Apr 30, 2019

@alanakbik I have added code usage in the description above. You were right, the user who uses just predict function will not be able to use this feature, for this, I have added one multi_class_prob, which when set as True will return probability for each class.

@amansrivastava17
Copy link
Contributor Author

@stefan-it Please have a look

@stefan-it
Copy link
Member

Looks good so far, I'll add a review now :)

Copy link
Member

@stefan-it stefan-it left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be awesome if you could also add a relevant integration test for that new method.

For example you could build such an integration test on-top of the imdb test:

https://github.com/zalandoresearch/flair/blob/master/tests/test_model_integration.py#L378

flair/models/text_classification_model.py Outdated Show resolved Hide resolved
flair/models/text_classification_model.py Outdated Show resolved Hide resolved
flair/models/text_classification_model.py Outdated Show resolved Hide resolved
flair/models/text_classification_model.py Outdated Show resolved Hide resolved
flair/models/text_classification_model.py Outdated Show resolved Hide resolved
flair/models/text_classification_model.py Outdated Show resolved Hide resolved
@amansrivastava17
Copy link
Contributor Author

@alanakbik
Copy link
Collaborator

@amansrivastava17 it looks like the new integration test is throwing an error. Could you check? In particular, it looks like you are using the method load_from_file in the test, which you should replace with the load method.

@amansrivastava17
Copy link
Contributor Author

amansrivastava17 commented May 2, 2019

@alanakbik Have fixed integration issue.
Please have a look.

@alanakbik
Copy link
Collaborator

Looks good, thanks!

@alanakbik
Copy link
Collaborator

👍

1 similar comment
@kashif
Copy link
Contributor

kashif commented May 2, 2019

👍

@amansrivastava17
Copy link
Contributor Author

@stefan-it have made changes, Please approve.

@stefan-it stefan-it merged commit a9d6b9a into flairNLP:master May 2, 2019
@stefan-it
Copy link
Member

Thanks for adding this feature :)

@Rajat-Mehta
Copy link

hi @stefan-it, @amansrivastava17,

Is this feature already available to be used in flair?

I tried but failed, it says: "TypeError: predict() got an unexpected keyword argument 'multi_class_prob'".

I am using flair version 0.4.1.

Thanks

@alanakbik
Copy link
Collaborator

It's merged into the master branch and so will be part of upcoming v0.4.2, but currently the only way to use the feature is to use the master branch.

@tsu3010
Copy link

tsu3010 commented May 24, 2019

Hi @alanakbik, When is v0.4.2 scheduled for release?

@alanakbik
Copy link
Collaborator

I am hoping to get v0.4.2 out before NAACL, i.e. end of next week.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants