Add field_list and data methods to archives (#412)
## Description

<!-- Provide a brief description of the PR's purpose here. -->

The overall goal of this PR is to make it easier to access the data
contained in each archive.

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x] Introduce a `data()` method that returns the archive data in many
forms -> this method primarily passes calls to `ArrayStore.data`
- [x] Test data() by modifying old as_pandas tests (we do not place too
much emphasis on testing since ArrayStore.data is already tested fairly
thoroughly)
- [x] Add a `field_list` method that shows the list of all fields in the
archive
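The delegation pattern the TODO items describe can be sketched without pyribs. Below is an illustrative stand-in, not the actual pyribs implementation: `SimpleStore` and `SimpleArchive` are invented for the example, mimicking how `data()` forwards to the store's `data` and how `field_list` exposes the store's fields.

```python
import numpy as np


class SimpleStore:
    """Toy stand-in for ArrayStore: arrays keyed by field name."""

    def __init__(self, fields):
        self._fields = dict(fields)  # field name -> np.ndarray

    @property
    def field_list(self):
        return list(self._fields)

    def data(self, fields=None):
        fields = self.field_list if fields is None else fields
        # Return copies so the caller's data does not change with the store.
        return {name: self._fields[name].copy() for name in fields}


class SimpleArchive:
    """Archive that delegates data access to its store, as the PR describes."""

    def __init__(self, store):
        self._store = store

    @property
    def field_list(self):
        return self._store.field_list

    def data(self, fields=None):
        # Primarily passes the call through to the store.
        return self._store.data(fields)


archive = SimpleArchive(SimpleStore({
    "solution": np.array([[1.0, 1.0], [2.0, 2.0]]),
    "objective": np.array([1.5, 3.0]),
}))
print(archive.field_list)           # ['solution', 'objective']
print(archive.data(["objective"]))  # {'objective': array([1.5, 3. ])}
```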

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in
      [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog in
      `HISTORY.md`
- [x] This PR is ready to go
btjanaka committed Nov 10, 2023
1 parent 0bb298f commit c26b63a
Showing 3 changed files with 112 additions and 28 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
@@ -6,6 +6,7 @@

#### API

- Add field_list and data methods to archives ({pr}`412`)
- Include threshold in `archive.best_elite` ({pr}`409`)
- **Backwards-incompatible:** Replace Elite and EliteBatch with dicts
({pr}`397`)
89 changes: 88 additions & 1 deletion ribs/archives/_archive_base.py
@@ -16,7 +16,8 @@
single_entry_with_threshold)


class ArchiveBase(ABC): # pylint: disable = too-many-instance-attributes
class ArchiveBase(ABC):
# pylint: disable = too-many-instance-attributes, too-many-public-methods
"""Base class for archives.

This class composes archives using an :class:`ArrayStore` that has
@@ -110,6 +111,11 @@ def __init__(self,
capacity=self._cells,
)

@property
def field_list(self):
"""list: List of data fields in the archive."""
return self._store.field_list

@property
def cells(self):
"""int: Total number of cells in the archive."""
@@ -640,6 +646,87 @@ def sample_elites(self, n):
_, elites = self._store.retrieve(selected_indices)
return elites

def data(self, fields=None, return_type="dict"):
"""Retrieves data for all elites in the archive.

Args:
    fields (array-like of str): List of fields to include. By default,
        all fields will be included (see :attr:`field_list`), with an
        additional "index" as the last field ("index" can also be placed
        anywhere in this list).
    return_type (str): Type of data to return. See below.

Returns:
    The data for all elites in the archive. This can take the following
    forms, depending on the ``return_type`` argument:

    - ``return_type="dict"``: Dict mapping from the field name to the
      field data. An example is::

          {
              "solution": [[1.0, 1.0, ...], ...],
              "objective": [1.5, ...],
              "measures": [[1.0, 2.0], ...],
              "threshold": [0.8, ...],
              "index": [4, ...],
          }

      Observe that we also return the indices as an ``index`` entry in
      the dict. The keys in this dict can be modified with the
      ``fields`` arg; duplicate fields will be ignored since the dict
      stores unique keys.

    - ``return_type="tuple"``: Tuple of arrays matching the field order
      given in ``fields``. For instance, if ``fields`` was
      ``["objective", "measures"]``, we would receive a tuple of
      ``(objective_arr, measures_arr)``. In this case, the results
      from ``data`` could be unpacked as::

          objective, measures = archive.data(["objective", "measures"])

      Unlike with the ``dict`` return type, duplicate fields will show
      up as duplicate entries in the tuple, e.g.,
      ``fields=["objective", "objective"]`` will result in two
      objective arrays being returned.

      By default (i.e., when ``fields=None``), the fields in the tuple
      will be ordered according to the :attr:`field_list`, along with
      ``index`` as the last field.

    - ``return_type="pandas"``: A
      :class:`~ribs.archives.ArchiveDataFrame` with the following
      columns:

      - For fields that are scalars, a single column with the field
        name. For example, ``objective`` would have a single column
        called ``objective``.
      - For fields that are 1D arrays, multiple columns with the name
        suffixed by its index. For instance, if we have a ``measures``
        field of length 10, we create 10 columns with names
        ``measures_0``, ``measures_1``, ..., ``measures_9``. We do not
        currently support fields with >1D data.
      - 1 column of integers (``np.int32``) for the index, named
        ``index``.

      In short, the dataframe might look like this by default:

      +------------+------+-----------+------------+------+-----------+-------+
      | solution_0 | ...  | objective | measures_0 | ...  | threshold | index |
      +============+======+===========+============+======+===========+=======+
      |            | ...  |           |            | ...  |           |       |
      +------------+------+-----------+------------+------+-----------+-------+

      Like the other return types, the columns can be adjusted with
      the ``fields`` parameter.

All data returned by this method will be a copy, i.e., the data will
not update as the archive changes.
""" # pylint: disable = line-too-long
data = self._store.data(fields, return_type)
if return_type == "pandas":
data = ArchiveDataFrame(data)
return data

def as_pandas(self, include_solutions=True, include_metadata=False):
"""Converts the archive into an :class:`ArchiveDataFrame` (a child class
of :class:`pandas.DataFrame`).
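The three return types described in the `data()` docstring can be sketched without pyribs. The converter below is a minimal, illustrative stand-in (the function name `to_return_type` and the sample fields are invented, and a plain `pandas.DataFrame` stands in for `ArchiveDataFrame`):

```python
import numpy as np
import pandas as pd


def to_return_type(fields_dict, fields=None, return_type="dict"):
    """Convert {name: array} data, mimicking the archive.data() contract."""
    names = list(fields_dict) if fields is None else list(fields)
    if return_type == "dict":
        # Dicts store unique keys, so duplicate names are ignored.
        return {n: fields_dict[n] for n in names}
    if return_type == "tuple":
        # Tuples preserve order, so duplicate names yield duplicate arrays.
        return tuple(fields_dict[n] for n in names)
    if return_type == "pandas":
        cols = {}
        for n in names:
            arr = fields_dict[n]
            if arr.ndim == 1:
                # Scalar field -> one column with the field name.
                cols[n] = arr
            else:
                # 1D field -> one column per component, suffixed by index.
                for i in range(arr.shape[1]):
                    cols[f"{n}_{i}"] = arr[:, i]
        return pd.DataFrame(cols)
    raise ValueError(f"Invalid return_type {return_type!r}")


data = {
    "objective": np.array([1.5, 3.0]),
    "measures": np.array([[1.0, 2.0], [3.0, 4.0]]),
    "index": np.array([4, 7], dtype=np.int32),
}
objective, measures = to_return_type(data, ["objective", "measures"], "tuple")
df = to_return_type(data, return_type="pandas")
print(list(df.columns))  # ['objective', 'measures_0', 'measures_1', 'index']
```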
50 changes: 23 additions & 27 deletions tests/archives/archive_base_test.py
@@ -316,6 +316,12 @@ def test_qd_score_offset_correct(data):
assert data.archive.qd_score_offset == 0.0 # Default value.


def test_field_list_correct(data):
assert data.archive.field_list == [
"solution", "objective", "measures", "metadata", "threshold"
]


def test_basic_stats(data):
assert data.archive.stats.num_elites == 0
assert data.archive.stats.coverage == 0.0
@@ -395,36 +401,27 @@ def test_sample_elites_fails_when_empty(data):

@pytest.mark.parametrize("name", ARCHIVE_NAMES)
@pytest.mark.parametrize("with_elite", [True, False], ids=["nonempty", "empty"])
@pytest.mark.parametrize("include_solutions", [True, False],
ids=["solutions", "no_solutions"])
@pytest.mark.parametrize("include_metadata", [True, False],
ids=["metadata", "no_metadata"])
@pytest.mark.parametrize("dtype", [np.float64, np.float32],
ids=["float64", "float32"])
def test_as_pandas(name, with_elite, include_solutions, include_metadata,
dtype):
def test_pandas_data(name, with_elite, dtype):
data = get_archive_data(name, dtype)

# Set up expected columns and data types.
measure_cols = [f"measures_{i}" for i in range(len(data.measures))]
expected_cols = ["index"] + measure_cols + ["objective"]
expected_dtypes = [np.int32, *[dtype for _ in measure_cols], dtype]
if include_solutions:
solution_cols = [f"solution_{i}" for i in range(len(data.solution))]
expected_cols += solution_cols
expected_dtypes += [dtype for _ in solution_cols]
if include_metadata:
expected_cols.append("metadata")
expected_dtypes.append(object)
solution_dim = len(data.solution)
measure_dim = len(data.measures)
expected_cols = ([f"solution_{i}" for i in range(solution_dim)] +
["objective"] +
[f"measures_{i}" for i in range(measure_dim)] +
["metadata", "threshold", "index"])
expected_dtypes = ([dtype for _ in range(solution_dim)] + [dtype] +
[dtype for _ in range(measure_dim)] +
[object, dtype, np.int32])

# Retrieve the dataframe.
if with_elite:
df = data.archive_with_elite.as_pandas(
include_solutions=include_solutions,
include_metadata=include_metadata)
df = data.archive_with_elite.data(return_type="pandas")
else:
df = data.archive.as_pandas(include_solutions=include_solutions,
include_metadata=include_metadata)
df = data.archive.data(return_type="pandas")

# Check columns and data types.
assert (df.columns == expected_cols).all()
@@ -441,9 +438,8 @@ def test_as_pandas(name, with_elite, include_solutions, include_metadata,
assert df.loc[0, "index"] == data.archive.grid_to_int_index(
[data.grid_indices])[0]

expected_data = [*data.measures, data.objective]
if include_solutions:
expected_data += list(data.solution)
if include_metadata:
expected_data.append(data.metadata)
assert (df.loc[0, "measures_0":] == expected_data).all()
expected_data = [
*data.solution, data.objective, *data.measures, data.metadata,
data.objective
]
assert (df.loc[0, :"threshold"] == expected_data).all()
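The final assertion above relies on pandas label-based slicing being inclusive of its end label. A small illustration (the column names mirror the test's expected layout; the values are made up):

```python
import pandas as pd

df = pd.DataFrame({
    "solution_0": [1.0],
    "objective": [2.0],
    "measures_0": [3.0],
    "metadata": [None],
    "threshold": [2.0],
    "index": [4],
})

# .loc slices by label and, unlike positional slicing, includes the end
# label, so this selects every column up to and including "threshold",
# leaving out only "index".
row = df.loc[0, :"threshold"]
print(list(row.index))
# ['solution_0', 'objective', 'measures_0', 'metadata', 'threshold']
```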
