Skip to content

Commit

Permalink
added filter function for removing layers from tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
MLRichter committed Mar 27, 2023
1 parent 55bec2c commit 7179079
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion delve/torchcallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class SaturationTracker(object):
Per default, only Conv2D,
Linear and LSTM-Cells
are recorded
layer_filter (func): A filter function that is used to avoid layers from being tracked.
This is function receiving a dictionary as input and returning
it with undesired entries removed.
The dictionary contains string keys mapping to torch.nn.Module objects.
writers_args (dict) : contains additional arguments passed over to the
writers. This is only used, when a writer is
initialized through a string-key.
Expand Down Expand Up @@ -169,6 +173,7 @@ def __init__(self,
savefile: str,
save_to: Union[str, delve.writers.AbstractWriter],
modules: torch.nn.Module,
layer_filter: Callable[[Dict[str, nn.Module]], Dict[str, nn.Module]] = lambda x: x,
writer_args: Optional[Dict[str, Any]] = None,
log_interval=1,
max_samples=None,
Expand All @@ -195,7 +200,7 @@ def __init__(self,

self.timeseries_method = timeseries_method
self.threshold = sat_threshold
self.layers = self.get_layers_recursive(modules)
self.layers = layer_filter(self.get_layers_recursive(modules))
self.max_samples = max_samples
self.log_interval = log_interval
self.reset_covariance = reset_covariance
Expand Down

0 comments on commit 7179079

Please sign in to comment.