diff --git a/detectron2/utils/events.py b/detectron2/utils/events.py index a37d689d08..e9c2640c32 100644 --- a/detectron2/utils/events.py +++ b/detectron2/utils/events.py @@ -6,6 +6,7 @@ import time from collections import defaultdict from contextlib import contextmanager +from functools import cached_property from typing import Optional import torch from fvcore.common.history_buffer import HistoryBuffer @@ -142,10 +143,14 @@ def __init__(self, log_dir: str, window_size: int = 20, **kwargs): kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)` """ self._window_size = window_size + self._writer_args = {"log_dir": log_dir, **kwargs} + self._last_write = -1 + + @cached_property + def _writer(self): from torch.utils.tensorboard import SummaryWriter - self._writer = SummaryWriter(log_dir, **kwargs) - self._last_write = -1 + return SummaryWriter(**self._writer_args) def write(self): storage = get_event_storage() @@ -174,7 +179,7 @@ def write(self): storage.clear_histograms() def close(self): - if hasattr(self, "_writer"): # doesn't exist when the code fails at import + if "_writer" in self.__dict__: self._writer.close() diff --git a/tests/utils/test_tensorboardx.py b/tests/utils/test_tensorboardx.py new file mode 100644 index 0000000000..885fb8d357 --- /dev/null +++ b/tests/utils/test_tensorboardx.py @@ -0,0 +1,23 @@ +import os +import tempfile +import unittest + +from detectron2.utils.events import TensorboardXWriter + + +# TODO Fix up capitalization +class TestTensorboardXWriter(unittest.TestCase): + def test_no_files_created(self) -> None: + with tempfile.TemporaryDirectory() as tmp_dir: + writer = TensorboardXWriter(tmp_dir) + writer.close() + + self.assertFalse(os.listdir(tmp_dir)) + + def test_single_write(self) -> None: + with tempfile.TemporaryDirectory() as tmp_dir: + writer = TensorboardXWriter(tmp_dir) + writer._writer.add_scalar("testing", 1, 1) + writer.close() + + self.assertTrue(os.listdir(tmp_dir))