diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index 7e0ef433031bdf..c887ef1befc71f 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -403,26 +403,19 @@ def log(
 
         # register logged value if it doesn't exist
         if key not in self:
-            self.register_key(key, meta, value)
+            metric = _ResultMetric(meta, isinstance(value, Tensor))
+            self[key] = metric
 
         # check the stored metadata and the current one match
         elif meta != self[key].meta:
             raise MisconfigurationException(
                 f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
             )
 
+        self[key].to(value.device)
         batch_size = self._extract_batch_size(self[key], batch_size, meta)
         self.update_metrics(key, value, batch_size)
 
-    def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
-        """Create one _ResultMetric object per value.
-
-        Value can be provided as a nested collection
-
-        """
-        metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
-        self[key] = metric
-
     def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
         result_metric = self[key]
         # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
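
The patch does two things: it inlines the single-use `register_key` helper into `log`, and it moves the `.to(value.device)` call out of registration so it runs on every `log` call rather than only when the metric is first created. The practical effect is that an already-registered metric now follows the incoming value's device. Below is a minimal sketch of that behavior using toy stand-ins (`ToyResultMetric` and `ToyResultCollection` are illustrative names, not Lightning's actual `_ResultMetric`/`_ResultCollection`, and the metadata/batch-size handling is omitted):

import torch


class ToyResultMetric:
    def __init__(self) -> None:
        # running state starts out on the default (CPU) device
        self.value = torch.zeros(1)

    def to(self, device: torch.device) -> "ToyResultMetric":
        self.value = self.value.to(device)
        return self


class ToyResultCollection(dict):
    def log(self, key: str, value: torch.Tensor) -> None:
        if key not in self:
            # register once, as in the patched code
            self[key] = ToyResultMetric()
        # moved out of registration: runs on *every* call, so the stored
        # metric is kept on the same device as the incoming value
        self[key].to(value.device)
        self[key].value += value.detach()


results = ToyResultCollection()
results.log("loss", torch.tensor([1.0]))  # registers the metric on CPU
if torch.cuda.is_available():
    # with the pre-patch flow the metric would stay on CPU here; with the
    # patched flow it is moved to CUDA before the update
    results.log("loss", torch.tensor([2.0], device="cuda"))
print(results["loss"].value.device)

Calling `.to()` unconditionally is cheap when the devices already match (a same-device `Tensor.to` is a no-op), which is presumably why the patch prefers it over tracking device changes explicitly.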