Use chunk computation in CVT brute force calculation to reduce memory usage (#394)

## Description

<!-- Provide a brief description of the PR's purpose here. -->

Added a chunked-computation option (`chunk_size`) to `CVTArchive`.

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- `CVTArchive`'s brute force calculation runs out of memory (OOM) for 1e8 or
  more inputs, whereas the chunked calculation succeeds
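For context (not part of the PR), the OOM comes from broadcasting: `np.expand_dims(measures_batch, axis=1) - centroids` materializes an `(N, K, D)` float64 intermediate, so memory scales with inputs times centroids. A back-of-the-envelope sketch with hypothetical sizes (`n_centroids`, `dim`, and `chunk_size` here are illustrative values, not the PR's numbers):

```python
import numpy as np

# Hypothetical sizes for illustration.
n_inputs = 10**8     # N: number of measures to look up (the failing case)
n_centroids = 100    # K: CVT cells
dim = 2              # D: measure dimensions
chunk_size = 10**5

# Full broadcast: (N, 1, D) - (K, D) -> (N, K, D) float64 intermediate.
full_bytes = n_inputs * n_centroids * dim * 8
# Chunked: only one (chunk_size, K, D) intermediate is alive at a time.
chunk_bytes = chunk_size * n_centroids * dim * 8

print(f"full broadcast: {full_bytes / 1e9:.0f} GB")  # 160 GB
print(f"per chunk:      {chunk_bytes / 1e6:.0f} MB")  # 160 MB
```

Chunking caps the peak intermediate at `chunk_size * K * D * 8` bytes regardless of `N`, at the cost of a Python-level loop over chunks.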

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in
      [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog in
      `HISTORY.md`
- [x] This PR is ready to go

---------

Co-authored-by: itsdawei <dhlee@usc.edu>
svott03 and itsdawei committed Oct 21, 2023
1 parent b1cddb7 commit ecd75cc
Showing 4 changed files with 51 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
```diff
@@ -56,7 +56,7 @@ confidence=
 # --enable=similarities". If you want to run only the classes checker, but have
 # no Warning level messages displayed, use"--disable=all --enable=classes
 # --disable=W"
-disable=suppressed-message,arguments-differ,wildcard-import,locally-disabled,duplicate-code
+disable=suppressed-message,arguments-differ,wildcard-import,locally-disabled,duplicate-code,no-else-return


 [REPORTS]
```
2 changes: 2 additions & 0 deletions HISTORY.md
```diff
@@ -10,6 +10,8 @@

 #### Improvements

+- Use chunk computation in CVT brute force calculation to reduce memory usage
+  ({pr}`394`)
 - Test pyribs installation in tutorials ({pr}`384`)
 - Add cron job for testing installation ({pr}`389`)
 - Fix broken cross-refs in docs ({pr}`393`)
```
38 changes: 28 additions & 10 deletions ribs/archives/_cvt_archive.py
```diff
@@ -101,6 +101,9 @@ class CVTArchive(ArchiveBase):
             be used instead.
         ckdtree_kwargs (dict): kwargs for :class:`~scipy.spatial.cKDTree`. By
             default, we do not pass in any kwargs.
+        chunk_size (int): If passed, brute forcing the closest centroid search
+            will chunk the distance calculations to compute chunk_size inputs at
+            a time.
     Raises:
         ValueError: The ``samples`` array or the ``custom_centroids`` array has
             the wrong shape.
@@ -118,6 +121,7 @@ def __init__(self,
                  dtype=np.float64,
                  samples=100_000,
                  custom_centroids=None,
+                 chunk_size=None,
                  k_means_kwargs=None,
                  use_kd_tree=True,
                  ckdtree_kwargs=None):
@@ -157,6 +161,7 @@ def __init__(self,
         self._centroid_kd_tree = None
         self._ckdtree_kwargs = ({} if ckdtree_kwargs is None else
                                 ckdtree_kwargs.copy())
+        self._chunk_size = chunk_size

         if custom_centroids is None:
             if not isinstance(samples, int):
@@ -259,13 +264,26 @@ def index_of(self, measures_batch):
         if self._use_kd_tree:
             _, indices = self._centroid_kd_tree.query(measures_batch)
             return indices.astype(np.int32)
-
-        # Brute force distance calculation -- start by taking the difference
-        # between each measure i and all the centroids.
-        distances = np.expand_dims(measures_batch, axis=1) - self.centroids
-
-        # Compute the total squared distance -- no need to compute actual
-        # distance with a sqrt.
-        distances = np.sum(np.square(distances), axis=2)
-
-        return np.argmin(distances, axis=1).astype(np.int32)
+        else:
+            expanded_measures = np.expand_dims(measures_batch, axis=1)
+            # Compute indices chunks at a time
+            if self._chunk_size is not None and \
+                    self._chunk_size < measures_batch.shape[0]:
+                indices = []
+                chunks = np.array_split(
+                    expanded_measures,
+                    np.ceil(len(expanded_measures) / self._chunk_size))
+                for chunk in chunks:
+                    distances = chunk - self.centroids
+                    distances = np.sum(np.square(distances), axis=2)
+                    current_res = np.argmin(distances, axis=1).astype(np.int32)
+                    indices.append(current_res)
+                return np.concatenate(tuple(indices))
+            else:
+                # Brute force distance calculation -- start by taking the
+                # difference between each measure i and all the centroids.
+                distances = expanded_measures - self.centroids
+                # Compute the total squared distance -- no need to compute
+                # actual distance with a sqrt.
+                distances = np.sum(np.square(distances), axis=2)
+                return np.argmin(distances, axis=1).astype(np.int32)
```
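The chunked path above can also be sanity-checked in isolation. This sketch re-implements the same logic as a standalone function (`nearest_centroid` is a hypothetical name, not the library's API) and confirms the chunked result matches the unchunked brute-force result:

```python
import numpy as np

def nearest_centroid(measures, centroids, chunk_size=None):
    """Index of the closest centroid to each measure, optionally chunked."""
    expanded = np.expand_dims(np.asarray(measures), axis=1)  # (N, 1, D)
    centroids = np.asarray(centroids)                        # (K, D)
    if chunk_size is not None and chunk_size < len(expanded):
        # Split along axis 0 so each chunk broadcasts to (chunk, K, D).
        chunks = np.array_split(expanded,
                                int(np.ceil(len(expanded) / chunk_size)))
        return np.concatenate([
            np.argmin(np.sum(np.square(c - centroids), axis=2),
                      axis=1).astype(np.int32) for c in chunks
        ])
    # Unchunked brute force: squared distances suffice for argmin.
    distances = np.sum(np.square(expanded - centroids), axis=2)
    return np.argmin(distances, axis=1).astype(np.int32)

rng = np.random.default_rng(0)
measures = rng.uniform(-1, 1, size=(1000, 2))
centroids = rng.uniform(-1, 1, size=(9, 2))
assert np.array_equal(nearest_centroid(measures, centroids),
                      nearest_centroid(measures, centroids, chunk_size=64))
```

Note that `np.array_split` (unlike `np.split`) tolerates a final chunk that is smaller than the rest, which is why the division can round up.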
20 changes: 20 additions & 0 deletions tests/archives/cvt_archive_test.py
```diff
@@ -136,3 +136,23 @@ def test_add_single_without_overwrite(data, add_mode):
     assert np.isclose(value, low_objective - data.objective)
     assert_archive_elite(data.archive_with_elite, data.solution, data.objective,
                          data.measures, data.centroid, data.metadata)
+
+
+def test_chunked_calculation_short():
+    """Testing accuracy of chunked computation"""
+    centroids = [[-1, 1], [0, 1], [1, 1], [-1, 0], [0, 0], [1, 0], [-1, -1],
+                 [0, -1], [1, -1]]
+
+    archive = CVTArchive(solution_dim=0,
+                         cells=9,
+                         ranges=[(-1, 1), (-1, 1)],
+                         samples=10,
+                         chunk_size=2,
+                         custom_centroids=centroids,
+                         use_kd_tree=False)
+    measure_batch = [[-1, 1], [-1, .9], [-.1, 1], [.9, .9], [-.9, 0], [.1, 0],
+                     [1, 0], [-1, -.9], [.1, -.9], [.9, -.9]]
+    closest_centroids = archive.index_of(measure_batch)
+    correct_centroids = [0, 0, 1, 2, 3, 4, 5, 6, 7, 8]
+
+    assert np.all(closest_centroids == correct_centroids)
```