From ef2721575cfe5993ceba94c455903f3bf1356831 Mon Sep 17 00:00:00 2001 From: dominicgkerr Date: Sat, 13 Apr 2024 20:33:42 +0100 Subject: [PATCH] Do not clone (when caching) Tensor states, but retain references to avoid memory leakage --- src/torchmetrics/metric.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 12cf33cdd82..7803c0e5f71 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -331,7 +331,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any: self.compute_on_cpu = False # save context before switch - cache = self._copy_state_dict() + cache = self._save_state_dict() # call reset, update, compute, on single batch self._enable_grad = True # allow grads for batch computation @@ -364,7 +364,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: """ # store global state and reset to default - global_state = self._copy_state_dict() + global_state = self._save_state_dict() _update_count = self._update_count self.reset() @@ -531,7 +531,7 @@ def sync( dist_sync_fn = gather_all_tensors # cache prior to syncing - self._cache = self._copy_state_dict() + self._cache = self._save_state_dict() # sync self._sync_dist(dist_sync_fn, process_group=process_group) @@ -876,17 +876,19 @@ def state_dict( # type: ignore[override] # todo destination[prefix + key] = deepcopy(current_val) return destination - def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]: - """Copy the current state values.""" + def _save_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]: + """Save the current state values, retaining references to Tensor values.""" + + # do not .clone() Tensor values, as new objects leak memory cache = {} for attr in self._defaults: current_value = getattr(self, attr) if isinstance(current_value, Tensor): - cache[attr] = current_value.detach().clone().to(current_value.device) + cache[attr] = 
current_value.detach().to(current_value.device) else: cache[attr] = [ # safely copy (non-graph leaf) Tensor elements - _.detach().clone().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value + _.detach().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value ] return cache