diff --git a/delve/torchcallback.py b/delve/torchcallback.py index 5ba5c73..aeafda5 100644 --- a/delve/torchcallback.py +++ b/delve/torchcallback.py @@ -49,6 +49,10 @@ class SaturationTracker(object): Per default, only Conv2D, Linear and LSTM-Cells are recorded + layer_filter (func): A filter function that is used to avoid layers from being tracked. + This is function receiving a dictionary as input and returning + it with undesired entries removed. + The dictionary contains string keys mapping to torch.nn.Module objects. writers_args (dict) : contains additional arguments passed over to the writers. This is only used, when a writer is initialized through a string-key. @@ -169,6 +173,7 @@ def __init__(self, savefile: str, save_to: Union[str, delve.writers.AbstractWriter], modules: torch.nn.Module, + layer_filter: Callable[[Dict[str, nn.Module]], Dict[str, nn.Module]] = lambda x: x, writer_args: Optional[Dict[str, Any]] = None, log_interval=1, max_samples=None, @@ -195,7 +200,7 @@ def __init__(self, self.timeseries_method = timeseries_method self.threshold = sat_threshold - self.layers = self.get_layers_recursive(modules) + self.layers = layer_filter(self.get_layers_recursive(modules)) self.max_samples = max_samples self.log_interval = log_interval self.reset_covariance = reset_covariance