snake_case EventAccumulator methods
janosh committed Jan 1, 2024
1 parent 09de0a8 commit dec30d2
Showing 9 changed files with 50 additions and 64 deletions.
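In short: this commit renames the vendored EventAccumulator's CamelCase methods to PEP 8 snake_case (Reload -> reload, Scalars -> scalars, _ProcessEvent -> _process_event, _ProcessScalar -> _process_scalar, _GeneratorFromPath -> _generator_from_path, _ParseFileVersion -> _parse_file_version) and makes the scalar reservoir private (scalars -> _scalars), which frees the scalars name for the public accessor. Calls that stay CamelCase (Load, Keys, Items, AddItem, HasField) belong to TensorBoard and protobuf APIs, not to this codebase.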
2 changes: 1 addition & 1 deletion .github/workflows/paper.yml
@@ -13,7 +13,7 @@ jobs:
 
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
 
       - name: Build draft PDF
         uses: openjournals/openjournals-draft-action@master
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
@@ -7,14 +7,14 @@ default_install_hook_types: [pre-commit, commit-msg]
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.1.2
+    rev: v0.1.9
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
 
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.6.1
+    rev: v1.8.0
     hooks:
       - id: mypy

@@ -36,6 +36,7 @@ repos:
       - id: codespell
         stages: [commit, commit-msg]
         exclude_types: [jupyter, bib]
+        args: [--check-filenames]
 
   - repo: https://github.com/kynan/nbstripout
     rev: 0.6.1
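For context: codespell's --check-filenames flag extends the spell check from file contents to file and directory names as well; the version bumps (ruff v0.1.2 -> v0.1.9, mypy v1.6.1 -> v1.8.0) simply track upstream releases.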
6 changes: 3 additions & 3 deletions examples/wandb_integration.ipynb
@@ -262,7 +262,7 @@
 "class ConvNet(nn.Module):\n",
 "    \"\"\"Just your average CNN.\"\"\"\n",
 "\n",
-"    def __init__(self, kernels, classes=10):\n",
+"    def __init__(self, kernels: list, classes: int = 10) -> None:\n",
 "        super().__init__()\n",
 "\n",
 "        self.layer1 = nn.Sequential(\n",
@@ -277,7 +277,7 @@
 "        )\n",
 "        self.fc = nn.Linear(7 * 7 * kernels[-1], classes)\n",
 "\n",
-"    def forward(self, x):\n",
+"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
 "        out = self.layer1(x)\n",
 "        out = self.layer2(out)\n",
 "        out = out.reshape(out.size(0), -1)\n",
@@ -343,7 +343,7 @@
 "    optimizer: torch.optim.Optimizer,\n",
 "    epochs: int,\n",
 "    log_freq: int = 10,\n",
-"):\n",
+") -> None:\n",
 "    # Run training and track with wandb\n",
 "    sample_count = batch_count = 0  # number of examples seen\n",
 "\n",
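Since self.fc consumes 7 * 7 * kernels[-1] features, the class implicitly assumes 28x28 inputs downsampled twice (the usual wandb MNIST example setup). A hypothetical instantiation against the newly annotated signature (kernel counts and input shape are illustrative):

import torch

model = ConvNet(kernels=[16, 32])         # two conv blocks; classes defaults to 10
logits = model(torch.rand(1, 1, 28, 28))  # -> shape (1, 10), assuming 1-channel MNIST-sized input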
54 changes: 19 additions & 35 deletions pyproject.toml
@@ -65,45 +65,29 @@ warn_unused_ignores = true
 [tool.ruff]
 target-version = "py38"
 include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"]
-select = [
-    "B", # flake8-bugbear
-    "C4", # flake8-comprehensions
-    "D", # pydocstyle
-    "E", # pycodestyle error
-    "EXE", # flake8-executable
-    "F", # pyflakes
-    "FA", # flake8-future-annotations
-    "FLY", # flynt
-    "I", # isort
-    "ICN", # flake8-import-conventions
-    "ISC", # flake8-implicit-str-concat
-    "PD", # pandas-vet
-    "PERF", # perflint
-    "PIE", # flake8-pie
-    "PL", # pylint
-    "PT", # flake8-pytest-style
-    "PYI", # flake8-pyi
-    "Q", # flake8-quotes
-    "RET", # flake8-return
-    "RSE", # flake8-raise
-    "RUF", # Ruff-specific rules
-    "SIM", # flake8-simplify
-    "SLOT", # flake8-slots
-    "TCH", # flake8-type-checking
-    "TID", # tidy imports
-    "UP", # pyupgrade
-    "W", # pycodestyle warning
-    "YTT", # flake8-2020
-]
+select = ["ALL"]
 ignore = [
-    "D100", # Missing docstring in public module
-    "D205", # 1 blank line required between summary line and description
-    "PLR", # pylint refactor
-    "PT006", # pytest-parametrize-names-wrong-type
+    "ANN101",
+    "ANN401",
+    "ARG001",
+    "C901",
+    "COM812",
+    "D100", # Missing docstring in public module
+    "D205", # 1 blank line required between summary line and description
+    "DTZ005",
+    "EM101",
+    "EM102",
+    "FBT001",
+    "FBT002",
+    "PLR", # pylint refactor
+    "PT006", # pytest-parametrize-names-wrong-type
+    "PTH",
+    "T201",
+    "TRY003",
 ]
 pydocstyle.convention = "google"
 
 [tool.ruff.per-file-ignores]
-"tests/*" = ["D103", "D104"]
+"tests/*" = ["D103", "D104", "INP001", "S101"]
 "__init__.py" = ["F401"]
 "examples/*" = ["D102", "D103", "D107", "E402", "FA102"]
28 changes: 14 additions & 14 deletions tensorboard_reducer/event_loader.py
@@ -58,15 +58,15 @@ def __init__(self, path: str) -> None:
             path (str): The path to the event file.
         """
         self._first_event_timestamp = None
-        self.scalars = reservoir.Reservoir(size=10000)
+        self._scalars = reservoir.Reservoir(size=10000)
 
         self._generator_mutex = threading.Lock()
         self.path = path
-        self._generator = _GeneratorFromPath(path)
+        self._generator = _generator_from_path(path)
 
         self.file_version: float | None = None
 
-    def Reload(self) -> EventAccumulator:
+    def reload(self) -> EventAccumulator:
         """Synchronously load all events added since last calling Reload. If Reload was
         never called, loads all events in the file.
@@ -75,24 +75,24 @@ def Reload(self) -> EventAccumulator:
         """
         with self._generator_mutex:
             for event in self._generator.Load():
-                self._ProcessEvent(event)
+                self._process_event(event)
         return self
 
-    def _ProcessEvent(self, event: Event) -> None:
+    def _process_event(self, event: Event) -> None:
         """Called whenever an event is loaded."""
         if self._first_event_timestamp is None:
             self._first_event_timestamp = event.wall_time
 
         if event.HasField("file_version"):
-            new_file_version = _ParseFileVersion(event.file_version)
+            new_file_version = _parse_file_version(event.file_version)
             self.file_version = new_file_version
 
         if event.HasField("summary"):
             for value in event.summary.value:
                 if value.HasField("simple_value"):
                     datum = value.simple_value
                     tag = value.tag
-                    self._ProcessScalar(tag, event.wall_time, event.step, datum)
+                    self._process_scalar(tag, event.wall_time, event.step, datum)
 
     @property
     def scalar_tags(self) -> list[str]:
@@ -101,9 +101,9 @@ def scalar_tags(self) -> list[str]:
         Returns:
             list[str]: All scalar tags
         """
-        return self.scalars.Keys()
+        return self._scalars.Keys()
 
-    def Scalars(self, tag: str) -> tuple[ScalarEvent, ...]:
+    def scalars(self, tag: str) -> tuple[ScalarEvent, ...]:
         """Given a summary tag, return all associated ScalarEvents.
 
         Args:
@@ -115,17 +115,17 @@ def Scalars(self, tag: str) -> tuple[ScalarEvent, ...]:
         Returns:
             tuple[ScalarEvent, ...]: An array of ScalarEvents.
         """
-        return self.scalars.Items(tag)
+        return self._scalars.Items(tag)
 
-    def _ProcessScalar(
+    def _process_scalar(
         self, tag: str, wall_time: float, step: int, scalar: float
     ) -> None:
         """Process a simple value by adding it to accumulated state."""
         sv = ScalarEvent(wall_time=wall_time, step=step, value=scalar)
-        self.scalars.AddItem(tag, sv)
+        self._scalars.AddItem(tag, sv)
 
 
-def _GeneratorFromPath(path: str) -> directory_watcher.DirectoryWatcher:
+def _generator_from_path(path: str) -> directory_watcher.DirectoryWatcher:
     """Create an event generator for file or directory at given path string."""
     return directory_watcher.DirectoryWatcher(
         path,
@@ -134,7 +134,7 @@ def _GeneratorFromPath(path: str) -> directory_watcher.DirectoryWatcher:
     )
 
 
-def _ParseFileVersion(file_version: str) -> float:
+def _parse_file_version(file_version: str) -> float:
     """Convert the string file_version in event.proto into a float.
 
     Args:
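A usage sketch of the renamed loader API (run directory hypothetical); since reload() returns self, construction and loading chain naturally:

from tensorboard_reducer.event_loader import EventAccumulator

acc = EventAccumulator("runs/run_1").reload()  # formerly .Reload()
for tag in acc.scalar_tags:                    # property, unchanged
    for event in acc.scalars(tag):             # formerly .Scalars(tag)
        print(tag, event.step, event.value)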
13 changes: 7 additions & 6 deletions tensorboard_reducer/load.py
@@ -63,7 +63,7 @@ def load_tb_events(
     # EventAccumulator that only loads scalars and ignores histograms, images and other
     # time-consuming data.
     accumulators = [
-        EventAccumulator(dirname).Reload()
+        EventAccumulator(dirname).reload()
         for dirname in tqdm(input_dirs, disable=not verbose, desc="Loading runs")
     ]
 
@@ -102,7 +102,7 @@
 
     for tag in accumulator.scalar_tags:
         # accumulator.scalars() returns columns 'step', 'wall_time', 'value'
-        df_scalar = pd.DataFrame(accumulator.Scalars(tag)).set_index("step")
+        df_scalar = pd.DataFrame(accumulator.scalars(tag)).set_index("step")
         df_scalar = df_scalar.drop(columns="wall_time")
 
         if handle_dup_steps is None and not df_scalar.index.is_unique:
@@ -143,10 +143,11 @@
                 "shortest run (same behavior as zip())."
             )
 
-    assert len(load_dict) > 0, (
-        f"Got {len(input_dirs)} input directories but no TensorBoard event files "
-        "found inside them."
-    )
+    if len(load_dict) == 0:
+        raise FileNotFoundError(
+            f"Got {len(input_dirs)} input directories but no TensorBoard event files "
+            "found inside them."
+        )
 
     out_dict: dict[str, pd.DataFrame] = {}
 
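The assert-to-raise change matters beyond style: assertions are stripped when Python runs with -O, whereas the explicit FileNotFoundError always fires and is specific enough for callers to catch. A sketch of the new failure mode (directory name hypothetical):

import tensorboard_reducer as tbr

try:
    events = tbr.load_tb_events(["dir-without-event-files"])
except FileNotFoundError as exc:
    print(exc)  # "Got 1 input directories but no TensorBoard event files found inside them."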
2 changes: 1 addition & 1 deletion tensorboard_reducer/write.py
@@ -31,7 +31,7 @@ def _rm_rf_or_raise(path: str, overwrite: bool) -> None:
     is_data_file = any(ext in path.lower() for ext in _known_extensions)
 
     if overwrite and (is_data_file or is_tb_dir):
-        os.system(f"rm -rf {path}")
+        os.system(f"rm -rf {path}")  # noqa: S605
     elif overwrite:
         raise ValueError(
             f"Received the overwrite flag but the content of '{path}' does not "
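The # noqa: S605 acknowledges ruff's warning about launching a shell from a string rather than addressing it. Not part of this commit, but a stdlib-only alternative that avoids both the shell and the suppression could look like:

import os
import shutil

if os.path.isdir(path):
    shutil.rmtree(path)  # recursively delete a TensorBoard run directory
else:
    os.remove(path)  # delete a single data file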
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -22,6 +22,6 @@ def events_dict() -> dict[str, pd.DataFrame]:
 
 @pytest.fixture(scope="session")
 def reduced_events(
-    events_dict: dict[str, pd.DataFrame]
+    events_dict: dict[str, pd.DataFrame],
 ) -> dict[str, dict[str, pd.DataFrame]]:
     return tbr.reduce_events(events_dict, REDUCE_OPS)
2 changes: 1 addition & 1 deletion tests/test_write.py
@@ -99,7 +99,7 @@ def test_write_data_file(
 
 
 def test_write_data_file_with_bad_ext(
-    reduced_events: dict[str, dict[str, pd.DataFrame]]
+    reduced_events: dict[str, dict[str, pd.DataFrame]],
 ) -> None:
     with pytest.raises(ValueError, match="has unknown extension, should be one of"):
         tbr.write_data_file(reduced_events, "foo.bad_ext")
