Skip to content

Commit

Permalink
Add warnings when using replicate_in_memory=True and MCCachingModule (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Dref360 committed Sep 8, 2023
1 parent da99ee3 commit 9d33c5d
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 1 deletion.
3 changes: 3 additions & 0 deletions baal/modelwrapper.py
Expand Up @@ -18,6 +18,7 @@
from baal.utils.cuda_utils import to_cuda
from baal.utils.iterutils import map_on_tensor
from baal.utils.metrics import Loss
from baal.utils.warnings import raise_warnings_cache_replicated

log = structlog.get_logger("ModelWrapper")

Expand Down Expand Up @@ -49,6 +50,8 @@ def __init__(self, model, criterion, replicate_in_memory=True):
self.replicate_in_memory = replicate_in_memory
self._active_dataset_size = -1

raise_warnings_cache_replicated(self.model, replicate_in_memory=replicate_in_memory)

def train_on_dataset(
self,
dataset,
Expand Down
19 changes: 19 additions & 0 deletions baal/utils/warnings.py
@@ -0,0 +1,19 @@
import warnings

from torch import nn

from baal.bayesian.caching_utils import LRUCacheModule

WARNING_CACHE_REPLICATED = """
To use MCCachingModule at maximum effiency, we recommend using
`replicate_in_memory=False`, but it is `True`.
"""


def raise_warnings_cache_replicated(module, replicate_in_memory):
if (
isinstance(module, nn.Module)
and replicate_in_memory
and any(isinstance(m, LRUCacheModule) for m in module.modules())
):
warnings.warn(WARNING_CACHE_REPLICATED, UserWarning)
4 changes: 3 additions & 1 deletion notebooks/mccaching_layer.ipynb
Expand Up @@ -83,7 +83,9 @@
"source": [
"## Introducing MCCachingModule!\n",
"\n",
"By simply wrapping the module with `MCCachingModule` we run the same inference 70% faster!"
"By simply wrapping the module with `MCCachingModule` we run the same inference 70% faster!\n",
"\n",
"**NOTE**: You should *always* use `ModelWrapper(..., replicate_in_memory=False)` when in combination with `MCCachingModule`."
],
"metadata": {
"collapsed": false
Expand Down
13 changes: 13 additions & 0 deletions tests/bayesian/test_caching.py
@@ -1,7 +1,10 @@
import warnings

import pytest
import torch
from torch.nn import Sequential, Linear

from baal import ModelWrapper
from baal.bayesian.caching_utils import MCCachingModule


Expand Down Expand Up @@ -50,3 +53,13 @@ def test_caching(my_model):
assert LinearMocked.call_count == 20


def test_caching_warnings(my_model):
my_model = MCCachingModule(my_model)
with warnings.catch_warnings(record=True) as tape:
ModelWrapper(my_model, criterion=None, replicate_in_memory=True)
assert len(tape) == 1 and "MCCachingModule" in str(tape[0].message)

with warnings.catch_warnings(record=True) as tape:
ModelWrapper(my_model, criterion=None, replicate_in_memory=False)
assert len(tape) == 0

0 comments on commit 9d33c5d

Please sign in to comment.