Skip to content

Commit

Permalink
Do not clone (when caching) Tensor states, but retain references to a…
Browse files Browse the repository at this point in the history
…void memory leakage
  • Loading branch information
dominicgkerr committed Apr 13, 2024
1 parent 76cc0a1 commit ef27215
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
self.compute_on_cpu = False

# save context before switch
cache = self._copy_state_dict()
cache = self._save_state_dict()

# call reset, update, compute, on single batch
self._enable_grad = True # allow grads for batch computation
Expand Down Expand Up @@ -364,7 +364,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
"""
# store global state and reset to default
global_state = self._copy_state_dict()
global_state = self._save_state_dict()
_update_count = self._update_count
self.reset()

Expand Down Expand Up @@ -531,7 +531,7 @@ def sync(
dist_sync_fn = gather_all_tensors

# cache prior to syncing
self._cache = self._copy_state_dict()
self._cache = self._save_state_dict()

# sync
self._sync_dist(dist_sync_fn, process_group=process_group)
Expand Down Expand Up @@ -876,17 +876,19 @@ def state_dict( # type: ignore[override] # todo
destination[prefix + key] = deepcopy(current_val)
return destination

def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]:
"""Copy the current state values."""
def _save_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]:
"""Save the current state values, retaining references to Tensor values"""

# do not .clone() Tensor values, as new objects leak memory
cache = {}
for attr in self._defaults:
current_value = getattr(self, attr)

if isinstance(current_value, Tensor):
cache[attr] = current_value.detach().clone().to(current_value.device)
cache[attr] = current_value.detach().to(current_value.device)
else:
cache[attr] = [ # safely copy (non-graph leaf) Tensor elements
_.detach().clone().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value
_.detach().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value
]

return cache
Expand Down

0 comments on commit ef27215

Please sign in to comment.