Skip to content

Commit

Permalink
Make ParallelClustering picklable, fixing memory caching in Mapper pipelines (#597)
Browse files Browse the repository at this point in the history

* Remove lambdas from ParallelClustering so it can be pickled

* Add regression test
  • Loading branch information
ulupo committed Jul 9, 2021
1 parent 51b2a8d commit b760d6a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
26 changes: 20 additions & 6 deletions gtda/mapper/cluster.py
Expand Up @@ -16,6 +16,22 @@
from ..utils.validation import validate_params


def _sample_weight_computer(rel_indices, sample_weight):
    """Return fit keyword arguments with the selected sample weights.

    Kept as a module-level function (rather than a lambda) so that objects
    holding a reference to it can be pickled.

    Parameters
    ----------
    rel_indices : index object
        Anything accepted by ``sample_weight[...]``; selects the entries of
        `sample_weight` to keep.

    sample_weight : indexable
        Full collection of sample weights.

    Returns
    -------
    dict
        ``{"sample_weight": sample_weight[rel_indices]}``.

    """
    return {"sample_weight": sample_weight[rel_indices]}


def _empty_dict(*args):
    """Ignore all positional arguments and return a new empty dict.

    Kept as a module-level function (rather than a lambda) so that objects
    holding a reference to it can be pickled.

    """
    return {}


def _indices_computer_precomputed(rel_indices):
    """Return an open mesh selecting rows and columns in `rel_indices`.

    The result of ``np.ix_(rel_indices, rel_indices)`` indexes the square
    sub-block of a 2D array whose rows and columns are both given by
    `rel_indices`.

    Kept as a module-level function (rather than a lambda) so that objects
    holding a reference to it can be pickled.

    """
    return np.ix_(rel_indices, rel_indices)


def _indices_computer_not_precomputed(rel_indices):
    """Return `rel_indices` unchanged.

    Kept as a module-level function (rather than a lambda) so that objects
    holding a reference to it can be pickled.

    """
    return rel_indices


class ParallelClustering(BaseEstimator):
"""Employ joblib parallelism to cluster different portions of a dataset.
Expand Down Expand Up @@ -129,16 +145,14 @@ def fit(self, X, y=None, sample_weight=None):

fit_params = signature(self.clusterer.fit).parameters
if sample_weight is not None and "sample_weight" in fit_params:
self._sample_weight_computer = lambda rel_indices, sample_weight: \
{"sample_weight": sample_weight[rel_indices]}
self._sample_weight_computer = _sample_weight_computer
else:
self._sample_weight_computer = lambda *args: {}
self._sample_weight_computer = _empty_dict

if self._precomputed:
self._indices_computer = lambda rel_indices: \
np.ix_(rel_indices, rel_indices)
self._indices_computer = _indices_computer_precomputed
else:
self._indices_computer = lambda rel_indices: rel_indices
self._indices_computer = _indices_computer_not_precomputed

# This seems necessary to avoid large overheads when running fit a
# second time. Probably due to refcounts. NOTE: Only works if done
Expand Down
15 changes: 14 additions & 1 deletion gtda/mapper/tests/test_cluster.py
Expand Up @@ -2,6 +2,9 @@
for ParallelClustering."""
# License: GNU AGPLv3

from shutil import rmtree
from tempfile import mkdtemp

import numpy as np
import pytest
import sklearn as sk
Expand All @@ -11,7 +14,8 @@
from numpy.testing import assert_almost_equal
from scipy.spatial import distance_matrix

from gtda.mapper import ParallelClustering, FirstHistogramGap, FirstSimpleGap
from gtda.mapper import ParallelClustering, FirstHistogramGap, \
FirstSimpleGap, make_mapper_pipeline


def test_parallel_clustering_bad_input():
Expand Down Expand Up @@ -233,3 +237,12 @@ def get_partition_from_preds(preds):

assert get_partition_from_preds(preds) == \
get_partition_from_preds(preds_mat)


def test_mapper_pipeline_picklable():
    """Fit a default Mapper pipeline with on-disk caching enabled."""
    # Regression test for issue #596
    X = np.random.random((100, 2))
    cachedir = mkdtemp()
    try:
        pipe = make_mapper_pipeline(memory=cachedir)
        pipe.fit_transform(X)
    finally:
        # Remove the cache directory even when fitting fails, so a failing
        # run does not leave temporary files behind.
        rmtree(cachedir)

0 comments on commit b760d6a

Please sign in to comment.