/
plot_quantization_s2.py
43 lines (32 loc) · 1.04 KB
/
plot_quantization_s2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""
Plot the result of optimal quantization of the von Mises-Fisher distribution
on the sphere.
"""
import matplotlib.pyplot as plt
import geomstats.visualization as visualization
from geomstats.hypersphere import Hypersphere
# Experiment parameters.
N_POINTS = 1000  # number of samples drawn from the von Mises-Fisher law
N_CENTERS = 4  # number of centers passed to optimal_quantization
N_REPETITIONS = 20  # forwarded as `n_repetitions` to optimal_quantization
KAPPA = 10  # `kappa` argument of random_von_mises_fisher

# The 2-dimensional hypersphere (S2) and the metric used for quantization.
SPHERE2 = Hypersphere(dimension=2)
METRIC = SPHERE2.metric
def main():
    """Sample points on S2, quantize them, and plot the result.

    Draws ``N_POINTS`` samples from a von Mises-Fisher distribution on
    ``SPHERE2`` (with ``kappa=KAPPA``), runs ``METRIC.optimal_quantization``
    with ``N_CENTERS`` centers and ``N_REPETITIONS`` repetitions, then shows
    two figures:

    * figure 0: the fitted centers, plotted in red;
    * figure 1: a sphere with the points of every cluster drawn on it.

    ``plt.show()`` is called after each figure, so in a non-interactive
    session the second figure only appears once the first window is closed.
    """
    points = SPHERE2.random_von_mises_fisher(kappa=KAPPA, n_samples=N_POINTS)

    # Only the centers and the cluster assignments are plotted below; the
    # weights and step count returned alongside them are unused here.
    centers, _weights, clusters, _n_steps = METRIC.optimal_quantization(
        points=points, n_centers=N_CENTERS, n_repetitions=N_REPETITIONS
    )

    # Figure 0: the quantization centers on S2.
    plt.figure(0)
    # NOTE(review): ``aspect="equal"`` on 3D axes raised NotImplementedError
    # in some matplotlib releases (3.1-3.5); confirm the pinned version.
    ax = plt.subplot(111, projection="3d", aspect="equal")
    visualization.plot(points=centers, ax=ax, space='S2', c='r')
    plt.show()

    # Figure 1: the sphere with each cluster's points drawn on it.
    plt.figure(1)
    ax = plt.subplot(111, projection="3d", aspect="equal")
    sphere = visualization.Sphere()
    sphere.draw(ax=ax)
    # Assumes ``clusters`` is indexable by center index 0..N_CENTERS-1
    # (as the original loop did) - TODO confirm the return type.
    for i in range(N_CENTERS):
        sphere.draw_points(ax=ax, points=clusters[i])
    plt.show()
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()