Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add predict with Monte Carlo dropout #2132

Merged
merged 5 commits into from Jun 2, 2019

Conversation

Projects
None yet
2 participants
@mrdbourke
Copy link
Contributor

commented Jun 1, 2019

I've added some functionality to use Monte Carlo dropout at prediction time (forum link).

This technique can be used to measure model uncertainty in neural networks.

The predict_wth_mc_dropout(n_times:int=10) can be called on an instance of Learner() and returns a list of predictions (with length n_times) all made with dropout turned on.

The variance of these predictions can be used to measure how certain the model is about a prediction.

High variance = high uncertainty
Low variance = low uncertainty

Inspiration for this was gathered from two main resources:

We've been using it on a text classification problem at Max Kelsen and it has proven very valuable.

Examples:

# without updated code
classifier.predict("I really loved that movie, it was awesome!")

(Category pos, tensor(1), tensor([0.0313, 0.9687]))
# without updated code
classifier.predict_with_mc_dropout("I really loved that movie, it was awesome!")

AttributeError: 'RNNLearner' object has no attribute 'predict_with_mc_dropout'
# with updated code
classifier.predict_with_mc_dropout("I really loved that movie, it was awesome!", n_times=10)

[(Category pos, tensor(1), tensor([0.2963, 0.7037])),
 (Category pos, tensor(1), tensor([0.0274, 0.9726])),
 (Category pos, tensor(1), tensor([0.0191, 0.9809])),
 (Category pos, tensor(1), tensor([0.0320, 0.9680])),
 (Category pos, tensor(1), tensor([0.0429, 0.9571])),
 (Category neg, tensor(0), tensor([0.7288, 0.2712])),
 (Category pos, tensor(1), tensor([0.0134, 0.9866])),
 (Category pos, tensor(1), tensor([0.0539, 0.9461])),
 (Category pos, tensor(1), tensor([0.0962, 0.9038])),
 (Category pos, tensor(1), tensor([0.0204, 0.9796]))]

mrdbourke and others added some commits Jun 1, 2019

@sgugger

This comment has been minimized.

Copy link
Collaborator

commented Jun 1, 2019

Thanks for your PR!
I've just restyled some if statements that fit on one line (our usual style guide), used a list comprehension to compute the predictions instead of a for loop and removed unused arguments in your new predict method. Let me know if anything feels off.

sgugger added some commits Jun 1, 2019

@mrdbourke

This comment has been minimized.

Copy link
Contributor Author

commented Jun 2, 2019

Thanks for the restyle @sgugger

All looks great to me!

@sgugger sgugger merged commit 937d67f into fastai:master Jun 2, 2019

5 of 8 checks passed

fastai.fastai Build #20190601.4 failed
Details
fastai.fastai (Linux_PyPI Python36) Linux_PyPI Python36 failed
Details
fastai.fastai (Linux_PyPI Python37) Linux_PyPI Python37 failed
Details
fastai.fastai (Linux_Conda Python36) Linux_Conda Python36 succeeded
Details
fastai.fastai (Linux_Conda Python37) Linux_Conda Python37 succeeded
Details
fastai.fastai (MacOS_Conda Python36) MacOS_Conda Python36 succeeded
Details
fastai.fastai (MacOS_Conda Python37) MacOS_Conda Python37 succeeded
Details
fastai.fastai (nbstripout_config) nbstripout_config succeeded
Details
@sgugger

This comment has been minimized.

Copy link
Collaborator

commented Jun 2, 2019

Good then, thanks again!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.