Skip to content

Commit

Permalink
Merge branch 'dev' into sparse-multigraph-lcc
Browse files Browse the repository at this point in the history
  • Loading branch information
bdpedigo committed Aug 24, 2021
2 parents 2eb5002 + 29e1892 commit ab14d8b
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 8 deletions.
46 changes: 38 additions & 8 deletions graspologic/cluster/autogmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_estimate_gaussian_parameters,
)
from sklearn.model_selection import ParameterGrid
from sklearn.utils import check_scalar

from .base import BaseCluster

Expand Down Expand Up @@ -116,6 +117,12 @@ class AutoGMMCluster(BaseCluster):
If provided, min_components and ``max_components`` must match the number of
unique labels given here.
kmeans_n_init : int, optional (default = 1)
If ``kmeans_n_init`` is larger than 1 and ``label_init`` is None, additional
``kmeans_n_init``-1 runs of :class:`sklearn.mixture.GaussianMixture`
initialized with k-means will be performed
for all covariance parameters in ``covariance_type``.
max_iter : int, optional (default = 100).
The maximum number of EM iterations to perform.
Expand Down Expand Up @@ -215,6 +222,7 @@ def __init__(
covariance_type="all",
random_state=None,
label_init=None,
kmeans_n_init=1,
max_iter=100,
verbose=0,
selection_criteria="bic",
Expand Down Expand Up @@ -362,20 +370,23 @@ def __init__(
if max_agglom_size is not None and max_agglom_size < 2:
raise ValueError("Must use at least 2 points for `max_agglom_size`")

check_scalar(kmeans_n_init, name="kmeans_n_init", target_type=int, min_val=1)

self.min_components = min_components
self.max_components = max_components
self.affinity = affinity
self.linkage = linkage
self.covariance_type = new_covariance_type
self.random_state = random_state
self.label_init = labels_init
self.kmeans_n_init = kmeans_n_init
self.max_iter = max_iter
self.verbose = verbose
self.selection_criteria = selection_criteria
self.max_agglom_size = max_agglom_size
self.n_jobs = n_jobs

def _fit_cluster(self, X, X_subset, y, params, agg_clustering):
def _fit_cluster(self, X, X_subset, y, params, agg_clustering, seed):
label_init = self.label_init
if label_init is not None:
onehot = _labels_to_onehot(label_init)
Expand All @@ -400,6 +411,7 @@ def _fit_cluster(self, X, X_subset, y, params, agg_clustering):
gm_params["init_params"] = "kmeans"
gm_params["reg_covar"] = 0
gm_params["max_iter"] = self.max_iter
gm_params["random_state"] = seed

criter = np.inf # if none of the iterations converge, bic/aic is set to inf
# below is the regularization scheme
Expand Down Expand Up @@ -514,10 +526,18 @@ def fit(self, X, y=None):
linkage=self.linkage,
covariance_type=self.covariance_type,
n_components=range(lower_ncomponents, upper_ncomponents + 1),
random_state=[self.random_state],
)
param_grid = list(ParameterGrid(param_grid))
param_grid_ag, param_grid = _process_paramgrid(param_grid)

param_grid_ag, param_grid = _process_paramgrid(
param_grid, self.kmeans_n_init, self.label_init
)

if isinstance(self.random_state, int):
np.random.seed(self.random_state)
seeds = np.random.randint(np.iinfo(np.int32).max, size=len(param_grid))
else:
seeds = [self.random_state] * len(param_grid)

n = X.shape[0]
if self.max_agglom_size is None or n <= self.max_agglom_size:
Expand All @@ -539,17 +559,17 @@ def fit(self, X, y=None):
)
ag_labels.append(hierarchical_labels)

def _fit_for_data(p):
def _fit_for_data(p, seed):
n_clusters = p[1]["n_components"]
if (p[0]["affinity"] != "none") and (self.label_init is None):
index = param_grid_ag.index(p[0])
agg_clustering = ag_labels[index][:, n_clusters - self.min_components]
else:
agg_clustering = []
return self._fit_cluster(X, X_subset, y, p, agg_clustering)
return self._fit_cluster(X, X_subset, y, p, agg_clustering, seed)

results = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
delayed(_fit_for_data)(p) for p in param_grid
delayed(_fit_for_data)(p, seed) for p, seed in zip(param_grid, seeds)
)
results = pd.DataFrame(results)

Expand Down Expand Up @@ -646,7 +666,7 @@ def _labels_to_onehot(labels):
return onehot


def _process_paramgrid(paramgrid):
def _process_paramgrid(paramgrid, kmeans_n_init, label_init):
"""
Removes combinations of affinity and linkage that are not possible.
Expand All @@ -664,7 +684,7 @@ def _process_paramgrid(paramgrid):
ag_paramgrid_processed : list of dicts
options for AgglomerativeClustering
"""
gm_keys = ["covariance_type", "n_components", "random_state"]
gm_keys = ["covariance_type", "n_components"]
ag_keys = ["affinity", "linkage"]
ag_params_processed = []
paramgrid_processed = []
Expand All @@ -686,6 +706,16 @@ def _process_paramgrid(paramgrid):
ag_params = {key: params[key] for key in ag_keys}
if ag_params not in ag_params_processed:
ag_params_processed.append(ag_params)
if (
ag_params["affinity"] == "none"
and kmeans_n_init > 1
and label_init is None
):
more_kmeans_init = gm_params.copy()
more_kmeans_init.update({"n_init": 1})
paramgrid_processed += [
[{"affinity": "none", "linkage": "none"}, more_kmeans_init]
] * (kmeans_n_init - 1)

paramgrid_processed.append([ag_params, gm_params])
return ag_params_processed, paramgrid_processed
Expand Down
25 changes: 25 additions & 0 deletions tests/cluster/test_autogmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,31 @@ def test_two_class(self):
# Asser that we get perfect clustering
assert_allclose(AutoGMM.ari_, 1)

def test_two_class_multiple_kmeans_inits(self):
"""
Easily separable two gaussian problem.
"""
np.random.seed(1)

n = 100
d = 3

X1 = np.random.normal(2, 0.5, size=(n, d))
X2 = np.random.normal(-2, 0.5, size=(n, d))
X = np.vstack((X1, X2))
y = np.repeat([0, 1], n)

AutoGMM = AutoGMMCluster(max_components=5, kmeans_n_init=2)
AutoGMM.fit(X, y)

n_components = AutoGMM.n_components_

# Assert that the two cluster model is the best
assert_equal(n_components, 2)

# Asser that we get perfect clustering
assert_allclose(AutoGMM.ari_, 1)

def test_two_class_parallel(self):
"""
Easily separable two gaussian problem.
Expand Down

0 comments on commit ab14d8b

Please sign in to comment.