Skip to content
Permalink
Browse files

More modifs on quantization

  • Loading branch information
AliceLeBrigant committed Jan 14, 2020
1 parent 6de8982 commit d44ba7d8d08ce0d8548649e97ac4c3e620f5d153
@@ -3,9 +3,10 @@
on the circle.
"""

import matplotlib.pyplot as plt
import os

import matplotlib.pyplot as plt

import geomstats.visualization as visualization

from geomstats.geometry.hypersphere import Hypersphere
@@ -26,22 +27,17 @@ def main():
clustering = Quantization(metric=circle.metric, n_clusters=n_clusters)
clustering = clustering.fit(data)

#points = CIRCLE.random_uniform(n_samples=N_POINTS, bound=None)
#centers, weights, clusters, n_iterations = METRIC.optimal_quantization(
# points=points, n_centers=N_CENTERS,
# n_repetitions=N_REPETITIONS, tolerance=TOLERANCE
# )

plt.figure(0)
visualization.plot(points=clustering.cluster_centers_, space='S1', color='red')
visualization.plot(points=clustering.cluster_centers_, space='S1',
color='red')
plt.show()

plt.figure(1)
ax = plt.axes()
circle_plot = visualization.Circle()
circle_plot.draw(ax=ax)
for i in range(n_clusters):
cluster = data[clustering.labels_==i, :]
cluster = data[clustering.labels_ == i, :]
circle_plot.draw_points(ax=ax, points=cluster)
plt.show()

@@ -3,9 +3,10 @@
on the sphere
"""

import matplotlib.pyplot as plt
import os

import matplotlib.pyplot as plt

import geomstats.visualization as visualization

from geomstats.geometry.hypersphere import Hypersphere
@@ -24,15 +25,15 @@ def main():
plt.figure(0)
ax = plt.subplot(111, projection="3d")
visualization.plot(points=clustering.cluster_centers_, ax=ax,
space='S2', c='r')
space='S2', c='r')
plt.show()

plt.figure(1)
ax = plt.subplot(111, projection="3d")
sphere_plot = visualization.Sphere()
sphere_plot.draw(ax=ax)
for i in range(n_clusters):
cluster = data[clustering.labels_==i, :]
cluster = data[clustering.labels_ == i, :]
sphere_plot.draw_points(ax=ax, points=cluster)
plt.show()

@@ -48,7 +48,7 @@ def get_mask_i_float(i, n):
return mask_i_float


def gather(x, indices, axis=0):
    """Select entries of ``x`` at ``indices`` along ``axis``.

    Parameters
    ----------
    x : array-like
        Array to pick entries from.
    indices : int or array-like of ints
        Positions to select along ``axis``.
    axis : int, optional
        Axis along which to select. Defaults to 0.

    Returns
    -------
    gathered : array-like
        The selected entries.
    """
    if axis == 0:
        # Plain indexing is exactly what this function did before ``axis``
        # existed, so every existing caller keeps the same result.
        return x[indices]
    # Previously ``axis`` was accepted but silently ignored; honour it.
    return np.take(x, indices, axis=axis)


@@ -390,10 +390,6 @@ def nonzero(x):
return np.nonzero(x)


def copy(x):
    """Return the result of ``np.copy`` applied to ``x``."""
    duplicate = np.copy(x)
    return duplicate


def ix_(*args):
    """Forward the given sequences to ``np.ix_`` and return its result."""
    open_mesh = np.ix_(*args)
    return open_mesh

@@ -398,10 +398,6 @@ def nonzero(*args, **kwargs):
return torch.nonzero(*args, **kwargs)


def copy(x):
    """Return ``x.clone()``, a new tensor holding the same values as ``x``."""
    cloned = x.clone()
    return cloned


def seed(x):
    """Seed the torch random number generator with ``x``.

    The value returned by ``torch.manual_seed`` is discarded; this function
    returns ``None``.
    """
    _ = torch.manual_seed(x)

@@ -419,3 +415,19 @@ def mean(x, axis=None):

def argmin(*args, **kwargs):
    """Forward all arguments to ``torch.argmin`` and return its result."""
    index_of_min = torch.argmin(*args, **kwargs)
    return index_of_min


def arange(*args, **kwargs):
    """Forward all arguments to ``torch.arange`` and return its result."""
    values = torch.arange(*args, **kwargs)
    return values


def gather(x, indices, axis=0):
    """Select entries of ``x`` at ``indices`` along ``axis``.

    Parameters
    ----------
    x : torch.Tensor
        Tensor to pick entries from.
    indices : int or tensor of ints
        Positions to select along ``axis``.
    axis : int, optional
        Axis along which to select; negative values count from the last
        axis. Defaults to 0.

    Returns
    -------
    gathered : torch.Tensor
        The selected entries.

    Raises
    ------
    IndexError
        If ``axis`` is non-zero and outside ``[-x.dim(), x.dim())``.
    """
    if axis == 0:
        # Plain indexing is exactly what this function did before, so every
        # existing caller keeps the same result.
        return x[indices]

    n_dims = x.dim()
    if not -n_dims <= axis < n_dims:
        raise IndexError(
            'axis {} is out of range for a tensor with {} dimension(s)'
            .format(axis, n_dims))
    if axis < 0:
        axis += n_dims

    # Previously ``axis`` was accepted but silently ignored. Keep every
    # leading axis whole and index only the requested one.
    selector = (slice(None),) * axis + (indices,)
    return x[selector]


def get_mask_i_float(i, n):
    """Build a float mask marking where ``arange(n)`` equals ``i``.

    ``i`` is wrapped in an array and cast to ``int32``, compared against
    ``arange(n)`` with ``equal``, and the comparison result is cast to
    ``float32``.

    NOTE(review): ``cast``, ``array``, ``equal``, ``int32`` and ``float32``
    are defined elsewhere in this module; presumably the result is 1.0 at
    position ``i`` and 0.0 elsewhere -- confirm against those helpers.
    """
    target = cast(array([i]), int32)[0]
    positions = arange(n)
    return cast(equal(positions, target), float32)
@@ -319,6 +319,10 @@ def stack(*args, **kwargs):
return tf.stack(*args, **kwargs)


def unstack(*args, **kwargs):
    """Forward all arguments to ``tf.unstack`` and return its result."""
    pieces = tf.unstack(*args, **kwargs)
    return pieces


def arctan2(*args, **kwargs):
    """Forward all arguments to ``tf.atan2`` and return its result."""
    angle = tf.atan2(*args, **kwargs)
    return angle

@@ -329,3 +333,7 @@ def diagonal(*args, **kwargs):

def mean(x, axis=None):
    """Return ``tf.reduce_mean(x, axis)``.

    Parameters
    ----------
    x : tensor
        Input handed to ``tf.reduce_mean``.
    axis : optional
        Passed positionally as the second argument of ``tf.reduce_mean``.
        Defaults to None.
    """
    averaged = tf.reduce_mean(x, axis)
    return averaged


def argmin(*args, **kwargs):
    """Forward all arguments to ``tf.argmin`` and return its result."""
    index_of_min = tf.argmin(*args, **kwargs)
    return index_of_min
@@ -4,9 +4,6 @@
import geomstats.backend as gs

from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted



def quantization(X, metric, n_clusters, n_repetitions=20,
@@ -57,11 +54,10 @@ def quantization(X, metric, n_clusters, n_repetitions=20,
"""
n_samples = X.shape[0]
n_features = X.shape[-1]

random_indices = gs.random.randint(low=0, high=n_samples,
size=(n_clusters,))
cluster_centers = X[gs.cast(random_indices, gs.int32), :]
size=(n_clusters,))
cluster_centers = gs.gather(X, gs.cast(random_indices, gs.int32), axis=0)

gap = 1.0
iteration = 0
@@ -71,10 +67,11 @@ def quantization(X, metric, n_clusters, n_repetitions=20,
step_size = gs.floor(gs.array(iteration / n_repetitions)) + 1

random_index = gs.random.randint(low=0, high=n_samples, size=(1,))
point = X[gs.cast(random_index, gs.int32), :]
point = gs.gather(X, gs.cast(random_index, gs.int32), axis=0)

index_to_update = metric.closest_neighbor_index(point, cluster_centers)
center_to_update = gs.copy(cluster_centers[index_to_update, :])
center_to_update = gs.copy(gs.gather(cluster_centers, index_to_update,
axis=0))

tangent_vec_update = metric.log(
point=point, base_point=center_to_update
@@ -93,7 +90,7 @@ def quantization(X, metric, n_clusters, n_repetitions=20,

if iteration == n_max_iterations-1:
print('Maximum number of iterations {} reached. The'
'quantization may be inaccurate'.format(n_max_iterations))
'quantization may be inaccurate'.format(n_max_iterations))

labels = gs.zeros(n_samples)
for i in range(n_samples):
@@ -161,7 +158,7 @@ class Quantization(BaseEstimator, ClusterMixin):
173 (2019), 685 - 703.
"""
def __init__(self, metric, n_clusters, n_repetitions=20,
tolerance=1e-5, n_max_iterations=5e4):
tolerance=1e-5, n_max_iterations=5e4):
self.metric = metric
self.n_clusters = n_clusters
self.n_repetitions = n_repetitions
@@ -178,10 +175,10 @@ def fit(self, X):
"""
self.cluster_centers_, self.labels_ = \
quantization(X=X, metric=self.metric,
n_clusters=self.n_clusters,
n_repetitions=self.n_repetitions,
tolerance=self.tolerance,
n_max_iterations=self.n_max_iterations)
n_clusters=self.n_clusters,
n_repetitions=self.n_repetitions,
tolerance=self.tolerance,
n_max_iterations=self.n_max_iterations)

return self

@@ -190,7 +187,7 @@ def predict(self, point):
Parameters
----------
X : {array-like}, shape=[n_features]
X : array-like, shape=[n_features]
New data to predict.
Returns
@@ -2,12 +2,10 @@

import geomstats.backend as gs

from sklearn.utils.testing import assert_allclose
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.quantization import Quantization



class TestQuantizationMethods(geomstats.tests.TestCase):
_multiprocess_can_split_ = True

@@ -23,7 +21,7 @@ def setUp(self):
def test_fit(self):
X = self.data
clustering = Quantization(metric=self.metric, n_clusters=1,
n_repetitions=1)
n_repetitions=1)
clustering.fit(X)

center = clustering.cluster_centers_
@@ -35,7 +33,7 @@ def test_fit(self):
def test_predict(self):
X = self.data
clustering = Quantization(metric=self.metric, n_clusters=3,
n_repetitions=1)
n_repetitions=1)
clustering.fit(X)

point = self.data[0, :]
@@ -47,4 +45,4 @@ def test_predict(self):


if __name__ == '__main__':
geomstats.tests.main()
geomstats.tests.main()

0 comments on commit d44ba7d

Please sign in to comment.
You can’t perform that action at this time.