# Speeding up Monte-Carlo Inference With MCCachingModule

Running MCDropout is slow and computationally expensive, because every prediction requires several stochastic forward passes through the model.
Baal proposes a new, simple API called `MCCachingModule` that speeds up MCDropout by more than 70%!

**TLDR: MCCachingModule**

```python
>>> from baal.bayesian.caching_utils import MCCachingModule
>>> from baal.bayesian.dropout import MCDropoutModule
>>> # Regular code to perform MCDropout with Baal.
>>> model = MCDropoutModule(original_module)
>>> # To get a ~70% speedup, simply wrap it:
>>> model = MCCachingModule(model)
```

Below, we demonstrate the approach on a toy example: we use a randomly initialized `VGG16` model and run MCDropout on the CIFAR10 test set for 20 iterations in the code below, and for 50 and 100 iterations in the table.

We get the following wall-clock inference times (m:ss) on a GeForce 1060Ti:

| Number of iterations | 20       | 50       | 100      |
|---------------------|----------|----------|----------|
| Regular MC-Dropout  | 2:58     | 7:27     | 13:45    |
| Ours                | **0:50** | **1:46** | **3:32** |
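To check what "more than 70%" means in terms of the table, the short calculation below converts each m:ss entry to seconds and computes the fraction of wall-clock time saved and the resulting speedup factor. It only uses the numbers from the table; the helper names are illustrative.

```python
def to_seconds(mmss: str) -> int:
    """Convert an 'm:ss' duration string into a number of seconds."""
    minutes, seconds = mmss.split(":")
    return int(minutes) * 60 + int(seconds)


# iterations -> (regular MC-Dropout, MCCachingModule), copied from the table above.
RESULTS = {
    20: ("2:58", "0:50"),
    50: ("7:27", "1:46"),
    100: ("13:45", "3:32"),
}

for iterations, (regular, cached) in RESULTS.items():
    t_regular, t_cached = to_seconds(regular), to_seconds(cached)
    time_saved = 1 - t_cached / t_regular  # fraction of wall-clock time removed
    speedup = t_regular / t_cached         # how many times faster
    print(f"{iterations:>3} iterations: {t_regular}s -> {t_cached}s, "
          f"{time_saved:.1%} less time, {speedup:.2f}x faster")
```

With these numbers, caching saves about 72% to 76% of the inference time, which corresponds to a 3.6x to 4.2x speedup.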

We are excited to see how the community uses this new feature!

### Code!

```python
from torchvision.datasets import CIFAR10
from torchvision.models import vgg16
from torchvision.transforms import ToTensor

from baal.bayesian.caching_utils import MCCachingModule
from baal.bayesian.dropout import MCDropoutModule
from baal.modelwrapper import ModelWrapper

ITERATIONS = 20

# Randomly initialized VGG16; MCDropoutModule keeps dropout active even in eval mode.
vgg = vgg16().cuda()
vgg.eval()

ds = CIFAR10('/tmp', train=False, transform=ToTensor(), download=True)

# Baseline MCDropout. Takes ~2:58 (m:ss).
with MCDropoutModule(vgg) as model_2:
    wrapper = ModelWrapper(model_2, None, replicate_in_memory=False)
    wrapper.predict_on_dataset(ds, batch_size=32, iterations=ITERATIONS, use_cuda=True)
```
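To reproduce timings like the ones in the table on your own hardware, a small wall-clock timer around each `predict_on_dataset` call is enough. The `timer` helper below is a hypothetical sketch built only on the standard library; it is not part of Baal.

```python
import time
from contextlib import contextmanager


@contextmanager
def timer(label: str):
    """Hypothetical helper: print the wall-clock time spent in the `with` block as m:ss."""
    result = {}
    start = time.perf_counter()
    try:
        yield result
    finally:
        elapsed = time.perf_counter() - start
        result["seconds"] = elapsed  # exposed so callers can compare runs
        minutes, seconds = divmod(int(round(elapsed)), 60)
        print(f"{label}: {minutes}:{seconds:02d}")
```

For example, nest the `with MCDropoutModule(vgg) as model_2:` block above under `with timer("Regular MC-Dropout"):`. If you time raw GPU code instead, call `torch.cuda.synchronize()` before leaving the block, since CUDA kernels run asynchronously.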

## Introducing MCCachingModule!

By simply wrapping the model with `MCCachingModule`, the same inference runs about 70% faster (~0:50 instead of ~2:58 for 20 iterations)!

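**NOTE**: You should *always* use `ModelWrapper(..., replicate_in_memory=False)` when combining it with `MCCachingModule`. Caching only pays off when the same batch is passed through the model once per iteration; `replicate_in_memory=True` instead stacks the batch `iterations` times and runs a single forward pass, so there is nothing to reuse.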
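The toy, pure-Python sketch below illustrates why the note matters. It assumes the speedup comes from reusing the output of deterministic layers (those before the first dropout) when they receive the same input again; `CachedLayer`, `expensive_features`, `dropout_head` and `mc_predict` are hypothetical names, and this is not Baal's implementation. It counts how many samples go through the expensive deterministic part under both strategies.

```python
import random


class CachedLayer:
    """Hypothetical one-entry cache: reuse the last output if the input is unchanged."""

    def __init__(self, fn):
        self.fn = fn
        self.samples_computed = 0  # samples actually pushed through `fn`
        self._last_input = None
        self._last_output = None

    def __call__(self, batch):
        if self._last_input is None or batch != self._last_input:
            self.samples_computed += len(batch)
            self._last_input = list(batch)
            self._last_output = self.fn(batch)
        return self._last_output


def expensive_features(batch):
    # Deterministic part of the network (e.g. the convolutional stack before dropout).
    return [[2.0 * v for v in sample] for sample in batch]


def dropout_head(features, p=0.5):
    # Stochastic part: a new dropout mask is sampled on every call (inverted dropout).
    return sum(f for f in features if random.random() >= p) / (1 - p)


def mc_predict(batch, iterations, replicate_in_memory):
    """Return (predictions per iteration, samples computed by the deterministic part)."""
    features = CachedLayer(expensive_features)
    n = len(batch)
    if replicate_in_memory:
        # One pass on the batch repeated `iterations` times: the cache never gets a hit.
        feats = features(batch * iterations)
        preds = [[dropout_head(f) for f in feats[i * n:(i + 1) * n]] for i in range(iterations)]
    else:
        # `iterations` passes on the same batch: only the first pass computes the features.
        preds = [[dropout_head(f) for f in features(batch)] for _ in range(iterations)]
    return preds, features.samples_computed
```

With a batch of 2 samples and 20 iterations, the sequential strategy computes the deterministic features for 2 samples, while the replicated strategy computes them for 40.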

```python
# Takes ~50 seconds.
with MCCachingModule(vgg) as model:
    with MCDropoutModule(model) as model_2:
        wrapper = ModelWrapper(model_2, None, replicate_in_memory=False)
        wrapper.predict_on_dataset(ds, batch_size=32, iterations=ITERATIONS, use_cuda=True)
```