# geomstats / geomstats

More modifs on quantization

AliceLeBrigant committed Jan 14, 2020
1 parent 6de8982 commit d44ba7d8d08ce0d8548649e97ac4c3e620f5d153
 @@ -3,9 +3,10 @@ on the circle. """ import matplotlib.pyplot as plt import os import matplotlib.pyplot as plt import geomstats.visualization as visualization from geomstats.geometry.hypersphere import Hypersphere @@ -26,22 +27,17 @@ def main(): clustering = Quantization(metric=circle.metric, n_clusters=n_clusters) clustering = clustering.fit(data) #points = CIRCLE.random_uniform(n_samples=N_POINTS, bound=None) #centers, weights, clusters, n_iterations = METRIC.optimal_quantization( # points=points, n_centers=N_CENTERS, # n_repetitions=N_REPETITIONS, tolerance=TOLERANCE # ) plt.figure(0) visualization.plot(points=clustering.cluster_centers_, space='S1', color='red') visualization.plot(points=clustering.cluster_centers_, space='S1', color='red') plt.show() plt.figure(1) ax = plt.axes() circle_plot = visualization.Circle() circle_plot.draw(ax=ax) for i in range(n_clusters): cluster = data[clustering.labels_==i, :] cluster = data[clustering.labels_ == i, :] circle_plot.draw_points(ax=ax, points=cluster) plt.show()
 @@ -3,9 +3,10 @@ on the sphere """ import matplotlib.pyplot as plt import os import matplotlib.pyplot as plt import geomstats.visualization as visualization from geomstats.geometry.hypersphere import Hypersphere @@ -24,15 +25,15 @@ def main(): plt.figure(0) ax = plt.subplot(111, projection="3d") visualization.plot(points=clustering.cluster_centers_, ax=ax, space='S2', c='r') space='S2', c='r') plt.show() plt.figure(1) ax = plt.subplot(111, projection="3d") sphere_plot = visualization.Sphere() sphere_plot.draw(ax=ax) for i in range(n_clusters): cluster = data[clustering.labels_==i, :] cluster = data[clustering.labels_ == i, :] sphere_plot.draw_points(ax=ax, points=cluster) plt.show()
 @@ -48,7 +48,7 @@ def get_mask_i_float(i, n): return mask_i_float def gather(x, indices): def gather(x, indices, axis=0): return x[indices] @@ -390,10 +390,6 @@ def nonzero(x): return np.nonzero(x) def copy(x): return np.copy(x) def ix_(*args): return np.ix_(*args)