Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
feat(*): add check_vector
Browse files Browse the repository at this point in the history
  • Loading branch information
estripling committed Dec 5, 2022
1 parent 046bece commit fefa861
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/fbench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
__version__ = version("fbench")

from .core import *
from .validation import *

del core
del core, validation
11 changes: 7 additions & 4 deletions src/fbench/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

from fbench.validation import check_vector

__all__ = (
"ackley",
"rastrigin",
Expand Down Expand Up @@ -44,7 +46,7 @@ def ackley(x):
>>> round(ackley([1, 1]), 4)
3.6254
"""
x = np.asarray(x)
x = check_vector(x, 1)
return float(
-20 * np.exp(-0.2 * np.sqrt((x**2).mean()))
- np.exp((np.cos(2 * np.pi * x)).sum() / len(x))
Expand Down Expand Up @@ -87,7 +89,7 @@ def rastrigin(x):
>>> round(rastrigin([4.5, 4.5]), 4)
80.5
"""
x = np.asarray(x)
x = check_vector(x, 1)
return float(10 * len(x) + (x**2 - 10 * np.cos(2 * np.pi * x)).sum())


Expand Down Expand Up @@ -130,7 +132,7 @@ def rosenbrock(x):
>>> round(rosenbrock([3, 3]), 4)
3604.0
"""
x = np.asarray(x)
x = check_vector(x, 2)
return float((100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum())


Expand Down Expand Up @@ -164,4 +166,5 @@ def sphere(x):
>>> sphere([1, 1])
2.0
"""
return float((np.asarray(x) ** 2).sum())
x = check_vector(x, 1)
return float((x**2).sum())
10 changes: 10 additions & 0 deletions src/fbench/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class IncorrectNumberOfElements(ValueError):
"""Raise if number of elements is incorrect."""

pass


class NotAVectorError(TypeError):
"""Raise if object is not vector-like."""

pass
43 changes: 43 additions & 0 deletions src/fbench/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np

__all__ = ("check_vector",)

from fbench.exception import IncorrectNumberOfElements, NotAVectorError


def check_vector(x, min_number_of_elements):
"""Validate an n-dimensional vector.
Parameters
----------
x : array_like
Input data with :math:`n` elements that can be converted to an array.
min_number_of_elements : int
Specify the minimum number of elements ``x`` must have.
Returns
-------
np.ndarray
An n-dimensional vector.
Raises
------
NotAVectorError
If ``x`` is not vector-like.
IncorrectNumberOfElements
If ``x`` does not satisfy the ``min_number_of_elements`` condition.
"""
x = np.asarray(x)

if len(x.shape) != 1:
raise NotAVectorError(
f"input must be vector-like object - it has shape={x.shape}"
)

if not len(x) >= min_number_of_elements:
raise IncorrectNumberOfElements(
f"number of elements must be at least {min_number_of_elements} "
f"- it has {x.shape[0]}"
)

return x
1 change: 0 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def test_rastrigin(arg, expected):
@pytest.mark.parametrize(
"arg, expected",
[
([0], 0.0),
([0, 0], 1.0),
([1, 1], 0.0),
([1, 1, 1], 0.0),
Expand Down
17 changes: 17 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import numpy as np
import numpy.testing as npt
import pytest

from fbench import exception, validation


def test_check_vector():
x = [1, 2, 3]
actual = validation.check_vector(x, 1)
npt.assert_array_equal(actual, np.array(x))

with pytest.raises(exception.NotAVectorError):
validation.check_vector([[1, 2]], 1)

with pytest.raises(exception.IncorrectNumberOfElements):
validation.check_vector([1, 2], 3)

0 comments on commit fefa861

Please sign in to comment.