Skip to content

Commit

Permalink
Increase flexibility on kmeans
Browse files Browse the repository at this point in the history
  • Loading branch information
gmrukwa committed Feb 7, 2021
1 parent 4863ae4 commit 9c30dc2
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 12 deletions.
7 changes: 4 additions & 3 deletions divik/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
except ModuleNotFoundError:
import importlib_metadata

if os.environ.get('READTHEDOCS') == 'True':
if os.environ.get("READTHEDOCS") == "True":
import toml

dirname = os.path.dirname(__file__)
with open(os.path.join(dirname, '../pyproject.toml')) as f:
__version__ = toml.load(f)['tool']['poetry']['version']
with open(os.path.join(dirname, "../pyproject.toml")) as f:
__version__ = toml.load(f)["tool"]["poetry"]["version"]
else:
__version__ = importlib_metadata.version(__name__)

Expand Down
7 changes: 3 additions & 4 deletions divik/cluster/_divik/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,9 @@ def _feature_selector(self, n_features):
def _divik(self, X, progress):
full = self.kmeans
fast = self.fast_kmeans
if fast is None:
warn_const = full.kmeans.normalize_rows
else:
warn_const = fast.kmeans.normalize_rows or full.kmeans.normalize_rows
warn_const = getattr(full.kmeans, "normalize_rows", False)
if fast is not None and not warn_const:
warn_const = getattr(fast.kmeans, "normalize_rows", False)
report = DivikReporter(progress, warn_const=warn_const)
select_all = np.ones(shape=(X.shape[0],), dtype=bool)
if self.minimal_size is None:
Expand Down
4 changes: 2 additions & 2 deletions divik/cluster/_two_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class TwoStep(BaseEstimator, ClusterMixin):
>>> kmeans = KMeans(n_clusters=3)
>>> ctr = TwoStep(kmeans).fit(X)
"""

def __init__(self, clusterer, n_subsets: int = 10, random_state: int = 42):
self.clusterer = clusterer
self.n_subsets = n_subsets
Expand Down Expand Up @@ -90,8 +91,7 @@ def fit(self, X, y=None):
final_labels = np.array([to_final[l] for l in initial_labels])
self.labels_ = final_labels
self.n_clusters_ = _get_first_attr(
_get_final_estimator(self.estimator_),
['n_clusters', 'n_clusters_'],
_get_final_estimator(self.estimator_), ["n_clusters", "n_clusters_"],
)
return self

Expand Down
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

# The full version, including alpha/beta/rc tags
dirname = os.path.dirname(__file__)
with open(os.path.join(dirname, '../pyproject.toml')) as f:
release = toml.load(f)['tool']['poetry']['version']
with open(os.path.join(dirname, "../pyproject.toml")) as f:
release = toml.load(f)["tool"]["poetry"]["version"]


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "divik"
version = "3.0.5"
version = "3.0.6"
description = "Divisive iK-means algorithm implementation"
authors = ["Grzegorz Mrukwa <g.mrukwa@gmail.com>"]
license = "Apache-2.0"
Expand Down

0 comments on commit 9c30dc2

Please sign in to comment.