REF: Incidental changes to SpectralClustering #228

Merged · 25 commits · Jun 26, 2018
108 changes: 67 additions & 41 deletions dask_ml/cluster/spectral.py
@@ -14,7 +14,7 @@

from .k_means import KMeans
from ..metrics.pairwise import PAIRWISE_KERNEL_FUNCTIONS, pairwise_kernels
-from ..utils import check_array
+from ..utils import check_array, _log_array, _format_bytes


logger = logging.getLogger(__name__)
@@ -163,7 +163,10 @@ def __init__(self, n_clusters=8, eigen_solver=None, random_state=None,
self.kmeans_params = kmeans_params

def _check_array(self, X):
-        return check_array(X, accept_dask_dataframe=False).astype(float)
+        logger.info("Starting check array")
+        result = check_array(X, accept_dask_dataframe=False).astype(float)
+        logger.info("Finished check array")
+        return result

def fit(self, X, y=None):
X = self._check_array(X)
@@ -198,7 +201,13 @@ def fit(self, X, y=None):
" Got {} components and {} samples".format(n_components, n))
raise ValueError(msg)

-        inds = rng.permutation(np.arange(n))
+        params = self.kernel_params or {}
+        params['gamma'] = self.gamma
+        params['degree'] = self.degree
+        params['coef0'] = self.coef0
+
+        inds = np.arange(n)
+        inds = rng.permutation(inds)
keep = inds[:n_components]
rest = inds[n_components:]
# distributed slice perf.
@@ -208,58 +217,45 @@ def fit(self, X, y=None):
# recover the original order.
inds_idx = np.argsort(inds)

-        params = self.kernel_params or {}
-        params['gamma'] = self.gamma
-        params['degree'] = self.degree
-        params['coef0'] = self.coef0

# compute the exact blocks
# these are done in parallel for dask arrays
if isinstance(X, da.Array):
-            X_keep = X[keep].rechunk(self.n_components).persist()
+            X_keep = X[keep].rechunk(X.shape).persist()
[Inline review comment from the author] This rechunk is to ensure a single block. X_keep is never large.

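Aside: a minimal sketch of the rechunk-to-one-block pattern (array sizes made up for illustration). Requesting chunks of the full array shape guarantees a single block, since no chunk can exceed the array's extent:

    import dask.array as da

    x = da.ones((1000, 4), chunks=(100, 4))   # ten row blocks
    x_small = x[:20].rechunk(x.shape)         # requested chunks >= result shape, so one block
    assert x_small.numblocks == (1, 1)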
else:
X_keep = X[keep]

-        if isinstance(metric, six.string_types):
-            if metric not in PAIRWISE_KERNEL_FUNCTIONS:
-                msg = ("Unknown affinity metric name '{}'. Expected one "
-                       "of '{}'".format(metric,
-                                        PAIRWISE_KERNEL_FUNCTIONS.keys()))
-                raise ValueError(msg)
-            A = pairwise_kernels(X_keep,
-                                 metric=metric, filter_params=True, **params)
-            B = pairwise_kernels(X_keep, X[rest],
-                                 metric=metric, filter_params=True, **params)
-
-        elif callable(metric):
-            A = metric(X_keep, **params)
-            B = metric(X_keep, X[rest], **params)
-        else:
-            msg = ("Unexpected type for 'affinity' '{}'. Must be string "
-                   "kernel name, array, or callable")
-            raise TypeError(msg)
-        if isinstance(A, da.Array):
-            A = A.rechunk((n_components, n_components))
-            B = B.rechunk((B.shape[0], B.chunks[1]))
+        X_rest = X[rest]

+        A, B = embed(X_keep, X_rest, n_components, metric, params)
+        _log_array(logger, A, 'A')
+        _log_array(logger, B, 'B')

# now the approximation of C
a = A.sum(0) # (l,)
b1 = B.sum(1) # (l,)
b2 = B.sum(0) # (m,)

+        # TODO: I think we have some unnecessary delayed wrapping of A here.
A_inv = da.from_delayed(delayed(pinv)(A), A.shape, A.dtype)

+        inner = A_inv.dot(b1)
        d1_si = 1 / da.sqrt(a + b1)
-        d2_si = 1 / da.sqrt(b2 + B.T.dot(A_inv.dot(b1)))  # (m,), dask array
+        d2_si = 1 / da.sqrt(b2 + B.T.dot(inner))  # (m,), dask array

# d1, d2 are diagonal, so we can avoid large matrix multiplies
# Equivalent to diag(d1_si) @ A @ diag(d1_si)
A2 = d1_si.reshape(-1, 1) * A * d1_si.reshape(1, -1) # (n, n)
+        _log_array(logger, A2, 'A2')
+        # A2 = A2.rechunk(A2.shape)
# Equivalent to diag(d1_si) @ B @ diag(d2_si)
-        B2 = d1_si.reshape(-1, 1) * B * d2_si  # (m, m), so this is dask.
+        B2 = da.multiply(da.multiply(d1_si.reshape(-1, 1), B),
+                         d2_si.reshape(1, -1))
+        _log_array(logger, B2, 'B2')

# U_A, S_A, V_A = svd(A2)
U_A, S_A, V_A = delayed(svd, pure=True, nout=3)(A2)

U_A = da.from_delayed(U_A, (n_components, n_components), A2.dtype)
S_A = da.from_delayed(S_A, (n_components,), A2.dtype)
V_A = da.from_delayed(V_A, (n_components, n_components), A2.dtype)
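Aside: the broadcasting expressions for A2 and B2 above rest on the identity that scaling row i by d[i] and column j by d[j] is the same as diag(d) @ A @ diag(d), without ever materializing the dense diagonal matrices. A quick numpy check with toy sizes:

    import numpy as np

    d = np.array([1.0, 2.0, 3.0])
    A = np.arange(9.0).reshape(3, 3)
    left = d.reshape(-1, 1) * A * d.reshape(1, -1)   # broadcasting form used above
    right = np.diag(d) @ A @ np.diag(d)              # explicit diagonal form
    assert np.allclose(left, right)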
@@ -269,25 +265,29 @@ def fit(self, X, y=None):
da.vstack([A2, B2.T]).dot(
U_A[:, :n_clusters]).dot(
da.diag(1.0 / da.sqrt(S_A[:n_clusters])))) # (n, k)
+        _log_array(logger, V2, 'V2.1')

if isinstance(B2, da.Array):
V2 = V2.rechunk((B2.chunks[1][0], n_clusters))
+        _log_array(logger, V2, 'V2.2')

# normalize (Eq. 4)
U2 = (V2.T / da.sqrt((V2 ** 2).sum(1))).T # (n, k)

+        _log_array(logger, U2, 'U2.2')

+        # Recover original indices
+        U2 = U2[inds_idx]  # (n, k)
+
+        _log_array(logger, U2, 'U2.3')

if self.persist_embedding and isinstance(U2, da.Array):
logger.info("Persisting array for k-means")
U2 = U2.persist()
elif isinstance(U2, da.Array):
-            # We can still persist the small things...
-            # TODO: we would need to update the task graphs
-            # for V2 to replace references to, e.g.
-            # U_A, A2, etc. with references to persisted
-            # versions of those.
-            pass
+            logger.info("Consider persist_embedding. This will require %s",
+                        _format_bytes(U2.nbytes))

-        # Recover the original order so that labels match
-        U2 = U2[inds_idx]  # (n, k)

logger.info("k-means for assign_labels[starting]")
km.fit(U2)
logger.info("k-means for assign_labels[finished]")
@@ -298,3 +298,29 @@ def fit(self, X, y=None):
self.labels_ = km.labels_
self.eigenvalues_ = S_A[:n_clusters] # TODO: better name
return self


+def embed(X_keep, X_rest, n_components, metric, kernel_params):
+    if isinstance(metric, six.string_types):
+        if metric not in PAIRWISE_KERNEL_FUNCTIONS:
+            msg = ("Unknown affinity metric name '{}'. Expected one "
+                   "of '{}'".format(metric,
+                                    PAIRWISE_KERNEL_FUNCTIONS.keys()))
+            raise ValueError(msg)
+        A = pairwise_kernels(X_keep,
+                             metric=metric, filter_params=True,
+                             **kernel_params)
+        B = pairwise_kernels(X_keep, X_rest,
+                             metric=metric, filter_params=True,
+                             **kernel_params)
+    elif callable(metric):
+        A = metric(X_keep, **kernel_params)
+        B = metric(X_keep, X_rest, **kernel_params)
+    else:
+        msg = ("Unexpected type for 'affinity' '{}'. Must be string "
+               "kernel name, array, or callable")
+        raise TypeError(msg)
+    if isinstance(A, da.Array):
+        A = A.rechunk((n_components, n_components))
+        B = B.rechunk((B.shape[0], B.chunks[1]))
+    return A, B
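Aside: embed computes the two exact blocks of the Nyström method. Writing the full affinity matrix as K = [[A, B], [B.T, C]], with A the kernel among the n_components sampled points, the large block C is never formed; it is implicitly approximated by B.T @ pinv(A) @ B, which is what the A_inv / b2 + B.T.dot(inner) arithmetic in fit relies on. A toy numpy sketch of that approximation, exact here because the linear kernel matrix is low-rank and the sample spans its range:

    import numpy as np
    from numpy.linalg import pinv

    rng = np.random.RandomState(0)
    X = rng.rand(50, 3)
    K = X @ X.T                    # rank-3 kernel matrix
    A = K[:5, :5]                  # kernel among the 5 "kept" points
    B = K[:5, 5:]                  # kept vs. remaining points
    C_approx = B.T @ pinv(A) @ B   # Nystrom estimate of the unseen block
    assert np.allclose(C_approx, K[5:, 5:])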
10 changes: 5 additions & 5 deletions dask_ml/metrics/pairwise.py
@@ -79,12 +79,12 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
YY = Y_norm_squared
if YY.shape != (1, Y.shape[0]):
raise ValueError(
"Incompatiable dimensions for Y and Y_norm_squared")
"Incompatible dimensions for Y and Y_norm_squared")
else:
YY = row_norms(Y, squared=True)[np.newaxis, :]

# TODO: this often emits a warning. Silence it here?
-    distances = -2 * X.dot(Y.T) + XX + YY
+    distances = -2 * da.dot(X, Y.T) + XX + YY
distances = da.maximum(distances, 0)
# TODO: scikit-learn sets the diagonal to 0 when X is Y.

@@ -116,7 +116,7 @@ def check_pairwise_arrays(X, Y, precomputed=False):
@doc_wraps(metrics.pairwise.linear_kernel)
def linear_kernel(X, Y=None):
X, Y = check_pairwise_arrays(X, Y)
-    return X.dot(Y.T)
+    return da.dot(X, Y.T)


@doc_wraps(metrics.pairwise.rbf_kernel)
@@ -136,7 +136,7 @@ def polynomial_kernel(X, Y=None, degree=3, gamma=None, coef0=1):
if gamma is None:
gamma = 1.0 / X.shape[1]

-    K = (gamma * X.dot(Y.T) + coef0)**degree
+    K = (gamma * da.dot(X, Y.T) + coef0)**degree
return K


@@ -146,7 +146,7 @@ def sigmoid_kernel(X, Y=None, gamma=None, coef0=1):
if gamma is None:
gamma = 1.0 / X.shape[1]

-    K = X.dot(Y.T)
+    K = da.dot(X, Y.T)
K *= gamma
K += coef0
K = da.tanh(K)
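Aside: a plausible motivation for the X.dot(Y.T) -> da.dot(X, Y.T) changes above is that the function form dispatches for either operand type, whereas a numpy array's .dot method does not know about dask. A small sketch:

    import numpy as np
    import dask.array as da

    X = np.random.rand(10, 3)                 # plain numpy operand
    Y = da.random.random((8, 3), chunks=4)    # dask operand
    K = da.dot(X, Y.T)                        # lazy dask array of shape (10, 8)
    K.compute()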
27 changes: 27 additions & 0 deletions dask_ml/utils.py
@@ -214,6 +214,33 @@ def check_chunks(n_samples, n_features, chunks=None):
return chunks


+def _log_array(logger, arr, name):
+    logger.info("%s: %s, %s blocks", name, _format_bytes(arr.nbytes),
+                getattr(arr, 'numblocks', 'No'))
+
+
+def _format_bytes(n):
+    # TODO: just import from distributed if / when required
+    """ Format bytes as text
+
+    >>> _format_bytes(1)
+    '1 B'
+    >>> _format_bytes(1234)
+    '1.23 kB'
+    >>> _format_bytes(12345678)
+    '12.35 MB'
+    >>> _format_bytes(1234567890)
+    '1.23 GB'
+    """
+    if n > 1e9:
+        return '%0.2f GB' % (n / 1e9)
+    if n > 1e6:
+        return '%0.2f MB' % (n / 1e6)
+    if n > 1e3:
+        return '%0.2f kB' % (n / 1000)
+    return '%d B' % n


__all__ = ['assert_estimator_equal',
'check_array',
'check_random_state',
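Aside: a quick sketch of what the new logging helpers produce (output illustrative; _log_array is a private helper, imported here only for demonstration):

    import logging
    import dask.array as da
    from dask_ml.utils import _log_array

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("demo")

    x = da.ones((1000, 1000), chunks=250)     # 8 MB of float64 in 4 x 4 blocks
    _log_array(logger, x, 'x')
    # INFO:demo:x: 8.00 MB, (4, 4) blocks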
1 change: 1 addition & 0 deletions docs/source/changelog.rst
@@ -8,6 +8,7 @@ Enhancements
------------

- Added ``sample_weight`` support for :meth:`dask_ml.metrics.accuracy_score`. (:pr:`217`)
+- Improved performance of training on :class:`dask_ml.cluster.SpectralClustering` (:pr:`152`)


Version 0.6.0
6 changes: 4 additions & 2 deletions tests/metrics/test_classification.py
@@ -61,7 +61,8 @@ def test_sample_weight(metric_pairs, normalize):
sample_weight_np = np.random.random_sample(size[0])
sample_weight_da = da.from_array(sample_weight_np, chunks=25)

-    result = m1(a, b, sample_weight=sample_weight_da, normalize=normalize, compute=True)
+    result = m1(a, b, sample_weight=sample_weight_da, normalize=normalize,
+                compute=True)
expected = m2(a, b, sample_weight=sample_weight_np, normalize=normalize)
assert abs(result - expected) < 1e-5

@@ -79,4 +80,5 @@ def test_sample_weight_raises(metric_pairs, normalize):
sample_weight_da = da.from_array(sample_weight_np, chunks=25)

with pytest.raises(NotImplementedError):
-        m1(a, b, sample_weight=sample_weight_da, normalize=normalize, compute=True)
+        m1(a, b, sample_weight=sample_weight_da, normalize=normalize,
+           compute=True)
26 changes: 11 additions & 15 deletions tests/test_spectral_clustering.py
@@ -3,7 +3,6 @@
import pytest
import sklearn.cluster
import numpy as np
-from numpy.testing import assert_array_equal

from dask_ml.datasets import make_blobs
from dask_ml.cluster import SpectralClustering
@@ -75,19 +74,16 @@ def test_affinity_raises():
assert m.match("Unexpected type for affinity 'ndarray'")


-def test_spectral_clustering():
-    S = np.array([[1.0, 1.0, 1.0, 0.2, 0.0, 0.0, 0.0],
-                  [1.0, 1.0, 1.0, 0.2, 0.0, 0.0, 0.0],
-                  [1.0, 1.0, 1.0, 0.2, 0.0, 0.0, 0.0],
-                  [0.2, 0.2, 0.2, 1.0, 1.0, 1.0, 1.0],
-                  [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
-                  [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
-                  [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]])
-
-    model = SpectralClustering(random_state=0, n_clusters=2,
-                               n_components=4).fit(S)
+def test_spectral_clustering(Xl_blobs_easy):
+    X, y = Xl_blobs_easy
+    X = (X - X.mean(0)) / X.std(0)
+    model = SpectralClustering(random_state=0, n_clusters=3,
+                               n_components=5, gamma=None).fit(X)
    labels = model.labels_.compute()
-    if labels[0] == 0:
-        labels = 1 - labels
-
-    assert_array_equal(labels, [1, 1, 1, 0, 0, 0, 0])
+    y = y.compute()
+
+    idx = [(y == i).argmax() for i in range(3)]
+    grouped_idx = [np.where(y == y[idx[i]])[0] for i in range(3)]
+
+    for indices in grouped_idx:
+        assert len(set(labels[indices])) == 1
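Aside: the loop asserts that all members of each true blob receive a single predicted label, i.e. agreement up to a permutation of label ids. A stricter, related check, assuming the clustering is perfect, could use the adjusted Rand index:

    from sklearn.metrics import adjusted_rand_score

    assert adjusted_rand_score(y, labels) == 1.0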