-
Notifications
You must be signed in to change notification settings - Fork 248
/
plot_quantization_s1.py
52 lines (39 loc) · 1.29 KB
/
plot_quantization_s1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""
Plot the result of optimal quantization of the uniform distribution
on the circle.
"""
import os
import matplotlib.pyplot as plt
import geomstats.visualization as visualization
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.quantization import Quantization
# Intended number of samples drawn from the uniform distribution on S1.
N_POINTS = 1000
# Intended number of quantization centers (clusters) to fit.
N_CENTERS = 5
# NOTE(review): the two values below are not referenced by any code visible
# in this script; they look like intended Quantization settings
# (number of repetitions and convergence tolerance) — confirm before use.
N_REPETITIONS = 20
TOLERANCE = 1.0e-6
def main():
    """Quantize uniform samples on the circle S1 and plot the result.

    Draws ``N_POINTS`` samples from ``circle.random_uniform``, fits a
    ``Quantization`` estimator with ``N_CENTERS`` clusters, and then shows
    two figures:

    * figure 0: the fitted cluster centers, plotted in red on S1;
    * figure 1: the samples, drawn one cluster at a time on a circle.

    Each ``plt.show()`` call displays the current figure before the script
    moves on.
    """
    circle = Hypersphere(dimension=1)

    # Use the module-level settings instead of duplicated literals so the
    # sample count and cluster count are configured in one place.
    data = circle.random_uniform(n_samples=N_POINTS, bound=None)

    # NOTE(review): N_REPETITIONS and TOLERANCE are defined at module level
    # but not passed here, because Quantization's keyword names are not
    # visible from this file — confirm and wire them in if intended.
    clustering = Quantization(metric=circle.metric, n_clusters=N_CENTERS)
    clustering = clustering.fit(data)

    # Figure 0: the fitted centers alone.
    plt.figure(0)
    visualization.plot(points=clustering.cluster_centers_, space='S1',
                       color='red')
    plt.show()

    # Figure 1: the samples, plotted cluster by cluster.
    plt.figure(1)
    ax = plt.axes()
    circle_plot = visualization.Circle()
    circle_plot.draw(ax=ax)
    for i in range(N_CENTERS):
        # Select the rows of `data` whose label equals i.
        cluster = data[clustering.labels_ == i, :]
        circle_plot.draw_points(ax=ax, points=cluster)
    plt.show()
if __name__ == "__main__":
    # Use .get so an unset GEOMSTATS_BACKEND does not raise KeyError; only an
    # explicit tensorflow backend is rejected, as in the original check.
    if os.environ.get('GEOMSTATS_BACKEND') == 'tensorflow':
        print('Examples with visualizations are only implemented '
              'with numpy backend.\n'
              'To change backend, write: '
              'export GEOMSTATS_BACKEND=numpy')
    else:
        main()