In [2]:
!pip install jax

Collecting jax
  Downloading jax-0.5.0-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.0,>=0.5.0 (from jax)
  Downloading jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl.metadata (978 bytes)
Collecting ml_dtypes>=0.4.0 (from jax)
  Downloading ml_dtypes-0.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)
Downloading jax-0.5.0-py3-none-any.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl (102.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m102.0/102.0 MB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hDownloading ml_dtypes-0.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31m35.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling

In [1]:
import jax
print(jax.devices())  


[CudaDevice(id=0)]


In [2]:
import numpy as np
from geomstats.geometry.hypersphere import Hypersphere
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from geomstats.learning.frechet_mean import FrechetMean
import time

In [3]:
dim = 32 * 32 * 3
space = Hypersphere(dim - 1)
print(dim)

3072


We use Voelker--Gosmann--Stewart for universal sampling on an $n$-ball, see [here](https://compneuro.uwaterloo.ca/files/publications/voelker.2017.pdf)

In [4]:
def uniform_ball(n,r):
    s = np.random.normal(0, 1, n+2)
    s = s / np.linalg.norm(s)
    b = s[:n]
    b = r * b
    return b

In [5]:
def random_walk(start):
    I = np.eye(dim)
    ref_frame = I[:dim - 1,:]
    ref_pt = I[dim - 1,:].flatten()
    frame = space.metric.parallel_transport(ref_frame, ref_pt, end_point=start)
    b = uniform_ball(dim - 1, np.pi / 2)
    direction = b @ frame
    end = space.metric.exp(direction, start)
    return end

In [6]:
def jexp(v, p):
    a = jnp.cos(jnp.linalg.norm(v)) * p + jnp.sin(jnp.linalg.norm(v)) * v / jnp.linalg.norm(v)
    return a

In [7]:
def jdist(A, B):
    norm_a = jnp.linalg.norm(A, axis=1)
    norm_b = jnp.linalg.norm(B, axis=1)
    inner_prod = jnp.sum(A * B, axis=1)
    cos_angle = inner_prod / (norm_a * norm_b)
    cos_angle = jnp.clip(cos_angle, -1, 1)
    dist = jnp.arccos(cos_angle)
    return dist

In [8]:
def jto_tangent(v, p):
    coef = jnp.sum(v * p) / jnp.sum(p * p)
    tangent = v - coef * p
    return tangent

In [9]:
def proj_to_2sphere(X, tangent_1, tangent_2, base_point):
    p1 = base_point
    p2 = jexp(tangent_1, base_point)
    p3 = jexp(tangent_2, base_point)

    A = jnp.hstack((p1.reshape(-1,1),p2.reshape(-1,1),p3.reshape(-1,1)))
    proj = A @ (jnp.linalg.inv(A.T @ A)) @ (A.T)
    projected_vec = (proj @ (X.T)).T
    row_norm = jnp.linalg.norm(projected_vec, axis=1)
    D = jnp.diag(jnp.reciprocal(row_norm))
    sphere_vec = D @ projected_vec

    return sphere_vec

In [10]:
def loss(X, param):
    intercept, coef1, coef2 = jnp.split(param, 3)
    intercept = jnp.reshape(intercept, (dim,))
    coef1 = jnp.reshape(coef1, (dim,))
    coef2 = jnp.reshape(coef2, (dim,))

    base_point = intercept / jnp.linalg.norm(intercept)
    penalty = jnp.sum(jnp.square(base_point - intercept))

    tangent_1 = jto_tangent(coef1, base_point)
    tangent_2 = jto_tangent(coef2, base_point)
    distances = jdist(X, proj_to_2sphere(X, tangent_1, tangent_2, base_point)) ** 2

    return jnp.sum(distances) / 2 + penalty

In [None]:
list_rss = []
list_fs = []
for i in range(1000):
    print(i)
    start_time = time.time()
    I = np.eye(dim)
    start = I[dim - 1,:].flatten()
    points = I[dim - 1,:].reshape(1,-1)
    for j in range(99):
        end = random_walk(start)
        points = np.concatenate((points, end.reshape(1,-1)), axis=0)
        start = end
    
    jpoints = jnp.asarray(points)
    
    intercept_init, coef1_init, coef2_init = np.random.normal(size=(3,) + (dim,))
    intercept_init = jnp.asarray(intercept_init)
    coef1_init = jnp.asarray(coef1_init)
    coef2_init = jnp.asarray(coef2_init)
    intercept_hat = intercept_init / jnp.linalg.norm(intercept_init)
    coef1_hat = jto_tangent(coef1_init, intercept_hat)
    coef2_hat = jto_tangent(coef2_init, intercept_hat)
    initial_guess = jnp.hstack([intercept_hat.flatten(), coef1_hat.flatten(), coef2_hat.flatten()])

    objective_with_grad = lambda param: loss(jpoints, param)

    result = minimize(objective_with_grad, initial_guess, method="BFGS", tol=1e-5)

    ans = np.array(result.x)
    intercept_fin, coef1_fin, coef2_fin = np.split(result.x, 3)
    intercept_fin = np.reshape(intercept_fin, space.shape)
    coef1_fin = np.reshape(coef1_fin, space.shape)
    coef2_fin = np.reshape(coef2_fin, space.shape)

    intercept_ = space.projection(intercept_fin)
    coef1_ = space.to_tangent(coef1_fin, intercept_)
    coef2_ = space.to_tangent(coef2_fin, intercept_)

    sphere_vec = proj_to_2sphere(points, coef1_, coef2_, intercept_)
    rss = np.sum(space.metric.squared_dist(points, sphere_vec))

    point1 = space.projection(coef1_)
    point2 = space.projection(coef2_)
    basis = np.vstack((intercept_, point1, point2)).T
    Q, R = np.linalg.qr(basis)

    sphere_data = sphere_vec @ Q

    sphere = Hypersphere(dim=2)
    sphere_mean = FrechetMean(sphere)
    sphere_mean.set(max_iter=10000)
    sphere_mean.fit(sphere_data)
    sphere_mean_estimate = sphere_mean.estimate_ 

    sphere_variance = np.sum(sphere.metric.squared_dist(sphere_data, sphere_mean_estimate))
    mixed_variance = rss + sphere_variance
    fitting_score = 1 - rss / mixed_variance

    list_rss.append(rss)
    list_fs.append(fitting_score)
    end_time = time.time()
    print('time for iteration:', end_time-start_time)

0
time for iteration: 108.83073711395264
1




time for iteration: 103.75569248199463
2




time for iteration: 83.52461123466492
3
time for iteration: 49.64082098007202
4
time for iteration: 54.90971112251282
5




time for iteration: 104.1040608882904
6




time for iteration: 66.59885430335999
7
time for iteration: 101.81522917747498
8
time for iteration: 51.326016426086426
9




time for iteration: 140.60948085784912
10




time for iteration: 79.66631960868835
11
time for iteration: 60.72897410392761
12
time for iteration: 54.97362399101257
13
time for iteration: 62.78596878051758
14
time for iteration: 108.20573592185974
15
time for iteration: 45.770297050476074
16


In [None]:
rss_array = np.array(list_rss)
fs_array = np.array(list_fs)
np.savetxt("rss.csv", rss_array)
np.savetxt("fitting_score.csv", fs_array)