Skip to content

Commit

Permalink
[ADD] Forbid subsampling with repeated indices (#16)
Browse files Browse the repository at this point in the history
* [REF] Rename file containing argument checks

* [ADD] Forbid sub-sampling with repeated indices

Co-authored-by: Felix Dangel <fdangel@tue.mpg.de>
  • Loading branch information
f-dangel and f-dangel committed Mar 24, 2022
1 parent 50c3551 commit f1b9766
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 6 deletions.
2 changes: 1 addition & 1 deletion test/linalg/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
PROBLEMS = PROBLEMS_REDUCTION_MEAN
IDS = IDS_REDUCTION_MEAN

SUBSAMPLINGS = [None, [0, 0, 1, 0, 1]]
SUBSAMPLINGS = [None, [1, 0]]
SUBSAMPLINGS_IDS = [f"subsampling={sub}" for sub in SUBSAMPLINGS]

PARAM_GROUPS_FN = PARAM_BLOCKS_FN
Expand Down
14 changes: 12 additions & 2 deletions test/utils/test_param_groups.py → test/utils/test_checks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Test ``vivit.utils.param_groups``."""
"""Test ``vivit.utils.checks``."""

from pytest import raises
from torch import rand

from vivit.utils.param_groups import check_key_exists, check_unique_params
from vivit.utils.checks import (
check_key_exists,
check_subsampling_unique,
check_unique_params,
)


def test_missing_key():
Expand All @@ -24,3 +28,9 @@ def test_unique_params():

with raises(ValueError):
check_unique_params(param_groups)


def test_subsampling_unique():
"""Test detection of dduplicate sub-sampling inddices."""
with raises(ValueError):
check_subsampling_unique([0, 0, 1])
8 changes: 7 additions & 1 deletion vivit/linalg/eigh.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
from torch.nn import Module, Parameter

from vivit.linalg.utils import get_hook_store_batch_size, get_vivit_extension, normalize
from vivit.utils.checks import (
check_key_exists,
check_subsampling_unique,
check_unique_params,
)
from vivit.utils.gram import reshape_as_square
from vivit.utils.hooks import ParameterGroupsHook
from vivit.utils.param_groups import check_key_exists, check_unique_params


class EighComputation:
Expand Down Expand Up @@ -44,6 +48,8 @@ def __init__(
value is smaller. Defaults to ``1e-4``. You can disable the warning by
setting it to ``0`` (not recommended).
"""
check_subsampling_unique(subsampling)

self._subsampling = subsampling
self._mc_samples = mc_samples
self._verbose = verbose
Expand Down
8 changes: 7 additions & 1 deletion vivit/linalg/eigvalsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
from torch.nn import Module, Parameter

from vivit.linalg.utils import get_hook_store_batch_size, get_vivit_extension
from vivit.utils.checks import (
check_key_exists,
check_subsampling_unique,
check_unique_params,
)
from vivit.utils.gram import reshape_as_square
from vivit.utils.hooks import ParameterGroupsHook
from vivit.utils.param_groups import check_key_exists, check_unique_params


class EigvalshComputation:
Expand All @@ -34,6 +38,8 @@ def __init__(
during backpropagation to command line (consider it a debugging tool).
Defaults to ``False``.
"""
check_subsampling_unique(subsampling)

self._subsampling = subsampling
self._mc_samples = mc_samples
self._verbose = verbose
Expand Down
17 changes: 16 additions & 1 deletion vivit/utils/param_groups.py → vivit/utils/checks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Utility functions to deal with parameter groups."""

from typing import Dict, List
from typing import Dict, List, Union


def check_key_exists(param_groups: List[Dict], key: str):
Expand Down Expand Up @@ -32,3 +32,18 @@ def check_unique_params(param_groups: List[Dict]):

if len(set(params_ids)) != len(params_ids):
raise ValueError("At least one parameter is in more than one group.")


def check_subsampling_unique(subsampling: Union[None, List[int]]):
"""Check that sub-sampling contains unique sample indices.
Args:
subsampling: Indices of active samples used for a computation. ``None``
uses the full mini-batch.
Raises:
ValueError: If the same index occurs more than once.
"""
if subsampling is not None:
if len(set(subsampling)) != len(subsampling):
raise ValueError("Detected repeated index in subsampling.")

0 comments on commit f1b9766

Please sign in to comment.