
Add caching mechanism for MCDropout #254

Merged · 13 commits into master on Jul 13, 2023

Conversation

Dref360
Member

@Dref360 Dref360 commented Mar 4, 2023

Example with VGG16; it would be interesting to see the speedup on a segmentation model.

import torch
from torchvision.models import vgg16
from tqdm import tqdm

from baal.bayesian.caching_utils import MCCachingModule
from baal.bayesian.dropout import MCDropoutModule

vgg = vgg16().cuda()
vgg.eval()

inputs = torch.randn(10, 3, 224, 224).cuda()  # renamed to avoid shadowing the built-in `input`

# Timing for 1000 stochastic forward passes:
# Regular: 1:49
# Cached : 20 seconds
with MCCachingModule(vgg) as model:
    with MCDropoutModule(model) as model_2:
        for _ in tqdm(range(1000)):
            model_2(inputs).detach().cpu()
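The speedup comes from reusing the deterministic part of the network across Monte Carlo iterations: only the layers affected by dropout need to be re-run on each pass. The following is an illustrative sketch of that idea, not BaaL's actual implementation; `CachedBackbone` and `noisy_head` are hypothetical names, and plain Python lists stand in for tensors.

```python
import random


class CachedBackbone:
    """Cache the output of a deterministic sub-network so that repeated
    MC-Dropout iterations on the same batch skip the expensive prefix."""

    def __init__(self, fn):
        self.fn = fn
        self._key = None    # identity of the last input seen
        self._value = None  # cached output for that input
        self.calls = 0      # number of real forward passes performed

    def __call__(self, x):
        if id(x) != self._key:       # new batch -> recompute and cache
            self._value = self.fn(x)
            self._key = id(x)
            self.calls += 1
        return self._value           # same batch -> reuse cached features


def noisy_head(features):
    """Stand-in for the stochastic (dropout) part: re-run every iteration."""
    return [f + random.random() for f in features]


backbone = CachedBackbone(lambda x: [v * 2 for v in x])  # deterministic part
batch = [1.0, 2.0, 3.0]
preds = [noisy_head(backbone(batch)) for _ in range(20)]  # 20 MC iterations
# backbone.calls is 1: the deterministic prefix ran once, not 20 times
```

The stochastic head still produces a different sample on every iteration, which is what MC-Dropout needs; only the deterministic prefix is amortized.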

TODO:

  • Invalidate the cache when retraining
  • Test on real data
  • Tests
  • Documentation
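The first TODO item could look like the following hypothetical sketch (not BaaL's code): switching the wrapper back to training mode drops all cached activations, so stale eval-time features never leak into a new training run.

```python
class InvalidatingCache:
    """Hypothetical sketch of 'invalidate the cache when retraining':
    entering training mode clears every stored activation."""

    def __init__(self):
        self._store = {}
        self.training = False

    def train(self):
        self.training = True
        self._store.clear()        # stale eval-time activations are dropped

    def eval(self):
        self.training = False

    def get_or_compute(self, key, compute):
        if self.training:
            return compute()       # never serve cached results while training
        if key not in self._store:
            self._store[key] = compute()
        return self._store[key]


cache = InvalidatingCache()
cache.eval()
cache.get_or_compute("batch-0", lambda: [0.1, 0.2])  # computed and stored
cache.train()                                        # retraining voids the cache
```

Tying invalidation to the `train()` call keeps the rule simple: any code path that can update the weights also clears the cache, with no separate bookkeeping.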

@Dref360 Dref360 marked this pull request as ready for review May 1, 2023 16:58
@Dref360 Dref360 requested a review from parmidaatg June 27, 2023 22:04
@Dref360
Member Author

Dref360 commented Jun 27, 2023

Still need to add documentation, but this is ready for review.

@Dref360 Dref360 merged commit 89eaf09 into master Jul 13, 2023
2 checks passed
@Dref360 Dref360 deleted the feat/caching_montecarlo_2 branch July 13, 2023 21:29
@Dref360 Dref360 restored the feat/caching_montecarlo_2 branch July 13, 2023 21:33