Skip to content

Commit

Permalink
Refactor argument validation utilities (#365)
Browse files Browse the repository at this point in the history
## Description

<!-- Provide a brief description of the PR's purpose here. -->

Currently, the argument validation utilities do not perform
preprocessing on the arguments passed into them. However, every call of
these utilities is prefaced with preprocessing (e.g., converting with
asarray). This PR moves the preprocessing into the utilities. This
should make it easier to handle argument validation as there is no
longer a need to repeat preprocessing code.

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x] Refactor validate_batch_args
- [x] Refactor validate_single_args

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog in `HISTORY.md`
- [x] This PR is ready to go
  • Loading branch information
btjanaka committed Sep 10, 2023
1 parent 0a98b7a commit 3387935
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 140 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
- Improve developer workflow with pre-commit ({pr}`351`, {pr}`363`)
- Refactor visualize module into multiple files ({pr}`357`)
- Add GitHub link roles in documentation ({pr}`361`)
- Refactor argument validation utilities ({pr}`365`)

## 0.5.2

Expand Down
153 changes: 95 additions & 58 deletions ribs/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,78 +87,115 @@ def check_solution_batch_dim(array,
" and tell() only took in a single solution.")


def validate_batch_args(archive, solution_batch, **batch_kwargs):
    """Preprocesses and validates batch arguments.

    The arguments are assumed to come directly from users, so they may not be
    arrays. Thus, each argument is first converted into a numpy array and then
    checked. For most arguments, the checks include whether the batch size
    matches the batch size of solution_batch.

    Args:
        archive: Archive whose ``solution_dim`` and ``measure_dim`` attributes
            provide the expected dimensions of the arguments.
        solution_batch (array-like): Batch of solutions. Its first dimension
            defines the batch size that the other arguments are checked
            against.
        batch_kwargs: Other batch arguments to preprocess and validate. The
            supported names are ``objective_batch``, ``measures_batch``,
            ``jacobian_batch``, ``status_batch``, ``value_batch``, and
            ``metadata_batch``. ``metadata_batch`` may be None, in which case
            it is replaced with an object array of length ``batch_size``
            (i.e., an array of None).

    Returns:
        list: The preprocessed arrays, with solution_batch first, followed by
        the entries of batch_kwargs in the order they were passed in. The
        order is guaranteed by PEP 468 (https://peps.python.org/pep-0468/),
        which specifies that kwargs preserve the order in which they are
        listed in the call.

    Raises:
        TypeError: batch_kwargs contains a name that is not supported.
            Without this check, the argument would be silently dropped and
            the returned list would be shorter than the caller expects.

    Errors raised by the ``check_*`` helpers of this module propagate to the
    caller unchanged.
    """
    supported_names = (
        "objective_batch",
        "measures_batch",
        "jacobian_batch",
        "status_batch",
        "value_batch",
        "metadata_batch",
    )

    # Reject unsupported names before doing any work, mirroring the TypeError
    # that Python raises for an unexpected keyword argument.
    unknown_names = [name for name in batch_kwargs if name not in supported_names]
    if unknown_names:
        raise TypeError("validate_batch_args() got unsupported batch "
                        f"argument(s): {', '.join(unknown_names)}")

    # List of args to return.
    returns = []

    # Process and validate solution_batch.
    solution_batch = np.asarray(solution_batch)
    check_batch_shape(solution_batch, "solution_batch", archive.solution_dim,
                      "solution_dim", _BATCH_WARNING)
    returns.append(solution_batch)

    # Process and validate the other batch arguments.
    batch_size = solution_batch.shape[0]
    for name, arg in batch_kwargs.items():
        if name == "measures_batch":
            array = np.asarray(arg)
            check_batch_shape(array, name, archive.measure_dim, "measure_dim",
                              _BATCH_WARNING)
            check_solution_batch_dim(array,
                                     name,
                                     batch_size,
                                     is_1d=False,
                                     extra_msg=_BATCH_WARNING)
            check_finite(array, name)
        elif name == "jacobian_batch":
            # NOTE(review): jacobian_batch is not passed through
            # check_solution_batch_dim (same as before this refactor) --
            # confirm whether its batch dimension should also be checked.
            array = np.asarray(arg)
            check_batch_shape_3d(array, name, archive.measure_dim + 1,
                                 "measure_dim + 1", archive.solution_dim,
                                 "solution_dim")
            check_finite(array, name)
        else:
            # The remaining supported arguments (objective_batch,
            # status_batch, value_batch, metadata_batch) are all 1D arrays
            # with one entry per solution.
            if name == "metadata_batch":
                # Special case -- metadata_batch defaults to None in our
                # methods, but we make it into an array of None if it is not
                # provided.
                array = (np.empty(batch_size, dtype=object)
                         if arg is None else np.asarray(arg, dtype=object))
            else:
                array = np.asarray(arg)

            check_is_1d(array, name, _BATCH_WARNING)
            check_solution_batch_dim(array,
                                     name,
                                     batch_size,
                                     is_1d=True,
                                     extra_msg=_BATCH_WARNING)

            # Only objective_batch and status_batch are checked for finite
            # values; value_batch and metadata_batch are not.
            if name in ("objective_batch", "status_batch"):
                check_finite(array, name)

        returns.append(array)

    return returns


def validate_single_args(archive, solution, objective, measures):
    """Converts and validates the arguments passed to add_single().

    Args:
        archive: Archive providing ``solution_dim``, ``measure_dim``, and the
            ``dtype`` callable used to convert the objective.
        solution (array-like): Single solution; converted with np.asarray and
            checked against ``archive.solution_dim``.
        objective: Objective value; converted with ``archive.dtype`` and
            checked for finiteness.
        measures (array-like): Measures of the solution; converted with
            np.asarray, checked against ``archive.measure_dim``, and checked
            for finiteness.

    Returns:
        tuple: The converted ``(solution, objective, measures)``.
    """
    # Solution: convert, then verify it is 1D with the archive's solution_dim.
    solution_arr = np.asarray(solution)
    check_1d_shape(solution_arr, "solution", archive.solution_dim,
                   "solution_dim")

    # Objective: cast to the archive's dtype before the finiteness check.
    objective_val = archive.dtype(objective)
    check_finite(objective_val, "objective")

    # Measures: convert, then verify both shape and finiteness.
    measures_arr = np.asarray(measures)
    check_1d_shape(measures_arr, "measures", archive.measure_dim,
                   "measure_dim")
    check_finite(measures_arr, "measures")

    return solution_arr, objective_val, measures_arr


def readonly(arr):
"""Sets an array to be readonly."""
Expand Down
25 changes: 12 additions & 13 deletions ribs/archives/_archive_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,22 +547,20 @@ def add(self,
"""
self._state["add"] += 1

## Step 0: Preprocess input. ##
solution_batch = np.asarray(solution_batch)
objective_batch = np.asarray(objective_batch)
measures_batch = np.asarray(measures_batch)
batch_size = solution_batch.shape[0]
metadata_batch = (np.empty(batch_size, dtype=object) if metadata_batch
is None else np.asarray(metadata_batch, dtype=object))

## Step 1: Validate input. ##
validate_batch_args(
(
solution_batch,
objective_batch,
measures_batch,
metadata_batch,
) = validate_batch_args(
archive=self,
solution_batch=solution_batch,
objective_batch=objective_batch,
measures_batch=measures_batch,
metadata_batch=metadata_batch,
)
batch_size = solution_batch.shape[0]

## Step 2: Compute status_batch and value_batch ##

Expand Down Expand Up @@ -749,10 +747,11 @@ def add_single(self, solution, objective, measures, metadata=None):
"""
self._state["add"] += 1

solution = np.asarray(solution)
objective = self.dtype(objective)
measures = np.asarray(measures)
validate_single_args(
(
solution,
objective,
measures,
) = validate_single_args(
self,
solution=solution,
objective=objective,
Expand Down
26 changes: 12 additions & 14 deletions ribs/archives/_sliding_boundaries_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,22 +408,19 @@ def add(self,
See :meth:`ArchiveBase.add` for arguments and return values.
"""
# Preprocess input.
solution_batch = np.array(solution_batch)
batch_size = solution_batch.shape[0]
objective_batch = np.array(objective_batch)
measures_batch = np.array(measures_batch)
metadata_batch = (np.empty(batch_size, dtype=object) if metadata_batch
is None else np.asarray(metadata_batch, dtype=object))

# Validate arguments.
validate_batch_args(
(
solution_batch,
objective_batch,
measures_batch,
metadata_batch,
) = validate_batch_args(
archive=self,
solution_batch=solution_batch,
objective_batch=objective_batch,
measures_batch=measures_batch,
metadata_batch=metadata_batch,
)
batch_size = solution_batch.shape[0]

status_batch = np.empty(batch_size, dtype=np.int32)
value_batch = np.empty(batch_size, dtype=self.dtype)
Expand All @@ -449,10 +446,11 @@ def add_single(self, solution, objective, measures, metadata=None):
See :meth:`ArchiveBase.add_single` for arguments and return values.
"""

solution = np.asarray(solution)
objective = self.dtype(objective)
measures = np.asarray(measures)
validate_single_args(
(
solution,
objective,
measures,
) = validate_single_args(
self,
solution=solution,
objective=objective,
Expand Down
34 changes: 16 additions & 18 deletions ribs/emitters/_evolution_strategy_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,24 +213,22 @@ def tell(self,
metadata_batch (array-like): 1D object array containing a metadata
object for each solution.
"""
# Preprocessing arguments.
solution_batch = np.asarray(solution_batch)
objective_batch = np.asarray(objective_batch)
measures_batch = np.asarray(measures_batch)
status_batch = np.asarray(status_batch)
value_batch = np.asarray(value_batch)
batch_size = solution_batch.shape[0]
metadata_batch = (np.empty(batch_size, dtype=object) if metadata_batch
is None else np.asarray(metadata_batch, dtype=object))

# Validate arguments.
validate_batch_args(archive=self.archive,
solution_batch=solution_batch,
objective_batch=objective_batch,
measures_batch=measures_batch,
status_batch=status_batch,
value_batch=value_batch,
metadata_batch=metadata_batch)
(
solution_batch,
objective_batch,
measures_batch,
status_batch,
value_batch,
metadata_batch,
) = validate_batch_args(
archive=self.archive,
solution_batch=solution_batch,
objective_batch=objective_batch,
measures_batch=measures_batch,
status_batch=status_batch,
value_batch=value_batch,
metadata_batch=metadata_batch,
)

# Increase iteration counter.
self._itrs += 1
Expand Down
71 changes: 34 additions & 37 deletions ribs/emitters/_gradient_arborescence_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,25 +335,24 @@ def tell_dqd(self,
metadata_batch (array-like): 1d object array containing a metadata
object for each solution.
"""
# Preprocessing arguments.
solution_batch = np.asarray(solution_batch)
objective_batch = np.asarray(objective_batch)
measures_batch = np.asarray(measures_batch)
status_batch = np.asarray(status_batch)
value_batch = np.asarray(value_batch)
batch_size = solution_batch.shape[0]
metadata_batch = (np.empty(batch_size, dtype=object) if metadata_batch
is None else np.asarray(metadata_batch, dtype=object))

# Validate arguments.
validate_batch_args(archive=self.archive,
solution_batch=solution_batch,
objective_batch=objective_batch,
measures_batch=measures_batch,
status_batch=status_batch,
value_batch=value_batch,
jacobian_batch=jacobian_batch,
metadata_batch=metadata_batch)
(
solution_batch,
objective_batch,
measures_batch,
status_batch,
value_batch,
jacobian_batch,
metadata_batch,
) = validate_batch_args(
archive=self.archive,
solution_batch=solution_batch,
objective_batch=objective_batch,
measures_batch=measures_batch,
status_batch=status_batch,
value_batch=value_batch,
jacobian_batch=jacobian_batch,
metadata_batch=metadata_batch,
)

if self._normalize_grads:
norms = (np.linalg.norm(jacobian_batch, axis=2, keepdims=True) +
Expand Down Expand Up @@ -393,24 +392,22 @@ def tell(self,
RuntimeError: This method was called without first passing gradients
with calls to ask_dqd() and tell_dqd().
"""
# Preprocessing arguments.
solution_batch = np.asarray(solution_batch)
objective_batch = np.asarray(objective_batch)
measures_batch = np.asarray(measures_batch)
status_batch = np.asarray(status_batch)
value_batch = np.asarray(value_batch)
batch_size = solution_batch.shape[0]
metadata_batch = (np.empty(batch_size, dtype=object) if metadata_batch
is None else np.asarray(metadata_batch, dtype=object))

# Validate arguments.
validate_batch_args(archive=self.archive,
solution_batch=solution_batch,
objective_batch=objective_batch,
measures_batch=measures_batch,
status_batch=status_batch,
value_batch=value_batch,
metadata_batch=metadata_batch)
(
solution_batch,
objective_batch,
measures_batch,
status_batch,
value_batch,
metadata_batch,
) = validate_batch_args(
archive=self.archive,
solution_batch=solution_batch,
objective_batch=objective_batch,
measures_batch=measures_batch,
status_batch=status_batch,
value_batch=value_batch,
metadata_batch=metadata_batch,
)

if self._jacobian_batch is None:
raise RuntimeError("Please call ask_dqd(), tell_dqd(), and ask() "
Expand Down

0 comments on commit 3387935

Please sign in to comment.