Skip to content

Commit

Permalink
Make ArrayStore use int32 indices (#400)
Browse files Browse the repository at this point in the history
## Description

<!-- Provide a brief description of the PR's purpose here. -->

The current ArchiveBase uses int32 indices (the assumption being that we
will never have to deal with more than INT_MAX archive cells). This PR
makes ArrayStore use int32 indices to be consistent with ArchiveBase.

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x] Add tests
- [x] Ensure consistency across all ArrayStore methods

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in

[CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog
in
      `HISTORY.md`
- [x] This PR is ready to go
  • Loading branch information
btjanaka committed Nov 3, 2023
1 parent 1038582 commit f8847b1
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 6 deletions.
2 changes: 1 addition & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
({pr}`397`)
- **Backwards-incompatible:** Rename `measure_*` columns to `measures_*` in
`as_pandas` ({pr}`396`)
- Add ArrayStore data structure ({pr}`395`, {pr}`398`)
- Add ArrayStore data structure ({pr}`395`, {pr}`398`, {pr}`400`)
- Add GradientOperatorEmitter to support OMG-MEGA and OG-MAP-Elites ({pr}`348`)

#### Improvements
Expand Down
8 changes: 4 additions & 4 deletions ribs/archives/_array_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self, field_desc, capacity):
"capacity": capacity,
"occupied": np.zeros(capacity, dtype=bool),
"n_occupied": 0,
"occupied_list": np.empty(capacity, dtype=int),
"occupied_list": np.empty(capacity, dtype=np.int32),
"updates": np.array([0, 0]),
}

Expand Down Expand Up @@ -168,7 +168,7 @@ def occupied(self):

@property
def occupied_list(self):
"""numpy.ndarray: Integer array listing all occupied indices in the
"""numpy.ndarray: int32 array listing all occupied indices in the
store."""
return readonly(
self._props["occupied_list"][:self._props["n_occupied"]])
Expand Down Expand Up @@ -207,7 +207,7 @@ def retrieve(self, indices, fields=None):
Raises:
ValueError: Invalid field name provided.
"""
indices = np.asarray(indices)
indices = np.asarray(indices, dtype=np.int32)
occupied = readonly(self._props["occupied"][indices])

data = {}
Expand Down Expand Up @@ -363,7 +363,7 @@ def resize(self, capacity):
self._props["occupied"][:cur_capacity] = cur_occupied

cur_occupied_list = self._props["occupied_list"]
self._props["occupied_list"] = np.empty(capacity, dtype=int)
self._props["occupied_list"] = np.empty(capacity, dtype=np.int32)
self._props["occupied_list"][:cur_capacity] = cur_occupied_list

for name, cur_arr in self._fields.items():
Expand Down
24 changes: 23 additions & 1 deletion tests/archives/array_store_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,28 @@ def test_add_duplicate_indices(store):
assert np.all(store.occupied_list == [3])


def test_dtypes(store):
store.add(
[3, 5],
{
"objective": [1.0, 2.0],
"measures": [[1.0, 2.0], [3.0, 4.0]],
"solution": [np.zeros(10), np.ones(10)],
},
{}, # Empty extra_args.
[], # Empty transforms.
)

_, data = store.retrieve([5, 3])

# Index is always int32, and other fields were defined as float32 in the
# `store` fixture.
assert data["index"].dtype == np.int32
assert data["objective"].dtype == np.float32
assert data["measures"].dtype == np.float32
assert data["solution"].dtype == np.float32


def test_retrieve_duplicate_indices(store):
store.add(
[3],
Expand Down Expand Up @@ -400,7 +422,7 @@ def test_as_pandas(store):
"solution_8",
"solution_9",
]).all()
assert (df.dtypes == [int] + [np.float32] * 13).all()
assert (df.dtypes == [np.int32] + [np.float32] * 13).all()
assert len(df) == 2

row0 = np.concatenate(([3, 1.0, 1.0, 2.0], np.zeros(10)))
Expand Down

0 comments on commit f8847b1

Please sign in to comment.