In [None]:
# Select matplotlib's interactive "notebook" backend for every figure below.
%matplotlib notebook

In [None]:
import numpy as np

import skimage.io as sio
import skimage.transform as sktransform

import matplotlib.pyplot as plt
import time


In [None]:
# Read the snapshot from disk (indexed below as a 4-channel image; presumably
# RGBA with alpha last -- confirm against the file).
img = sio.imread('snapshot04.png')

# Shrink every leading axis by a factor of four (truncating), keep 4 channels.
small_shape = [int(side / 4) for side in img.shape[:-1]] + [4]
img = sktransform.resize_local_mean(img, small_shape)

# Binary mask: pixels whose fourth channel is strictly positive.
bin_img = img[:, :, 3] > 0

fig, ax = plt.subplots()
ax.imshow(bin_img)
ax.axis('off')

In [None]:
# Rotation loss: for each of 300 evenly spaced angles in [0, 360), count the
# pixels at which the rotated mask differs from the unrotated one.
x_sample = np.linspace(0, 360, 300, endpoint=False)
y_res = [
    (bin_img != sktransform.rotate(bin_img, ang)).sum()
    for ang in x_sample
]

In [None]:
# Plot the rotation loss against the sampled rotation angle.
plt.figure()
plt.plot(x_sample,y_res)
# x_sample holds the angles and y_res the mismatch counts, so the angle
# belongs on the x axis and the loss on the y axis.
plt.xlabel('angle')
plt.ylabel('loss')

In [None]:
import scipy.fftpack as scifft

# Number of leading DCT coefficients that are kept.
# Should be roughly 1 + 2*(# of local minima).
N_FFT = 19

# Forward type-2 DCT of the loss curve; zero every coefficient from N_FFT on.
yf = scifft.dct(y_res, type=2)
yf[N_FFT:] = 0
# The unnormalised dct/idct pair scales the signal by 2*len(yf), so y_recon
# is left in that scaled unit.
y_recon = scifft.idct(yf, type=2)

fig, ax = plt.subplots()
ax.plot(y_res, label='data')
# Divide the transform-pair scale back out so both curves share one axis.
ax.plot(y_recon / (2 * yf.size), label='simplified')
ax.legend()

In [None]:
import jax
import jax.numpy as jnp
import optax

In [None]:
# JAX array versions of the sampled loss and of its truncated DCT coefficients.
y_jax, yf_jax = jnp.array(y_res), jnp.array(yf)

def compute_func(x):
    """Evaluate the truncated DCT reconstruction of the loss at angle ``x``.

    This is the closed form of the unnormalised inverse DCT
    ``y[k] = c[0] + 2 * sum_i c[i] * cos(pi * i * (2k + 1) / (2N))`` with the
    sample index ``k`` allowed to be fractional, so the loss becomes a smooth,
    differentiable function of the angle. The value is in the same
    (unnormalised) scale as ``scifft.idct(yf)``.

    Parameters
    ----------
    x : scalar or array
        Angle in degrees; wrapped into [0, 360) before evaluation.

    Returns
    -------
    jax array
        Reconstruction value(s) built from ``yf_jax[0:N_FFT]``.
    """
    # The coefficients come from N samples spread uniformly over [0, 360), so
    # the DCT grid lives in sample-index units, not in degrees.
    n_samples = yf_jax.shape[0]
    # Convert the wrapped angle to a fractional sample index. The half-sample
    # offset (+1 below) is only correct when applied in these units.
    k = jnp.mod(x, 360) * n_samples / 360
    s = yf_jax[0]
    for i in range(1, N_FFT):
        s += 2 * yf_jax[i] * jnp.cos(jnp.pi * i * (2 * k + 1) / (2 * n_samples))
    return s

# Overlay the closed-form evaluation (solid) on the inverse DCT (dashed).
fig, ax = plt.subplots()
ax.plot(x_sample, [compute_func(angle) for angle in x_sample])
ax.plot(x_sample, y_recon, ls='--')

In [None]:
# Number of independent optimisation runs.
N_EXP = 100
# Fixed seed so the starting angles (and every result derived from them) are
# reproducible after Restart Kernel -> Run All.
SEED = 0
# Starting angles drawn uniformly from [0, 360).
init_g = np.random.default_rng(SEED).uniform(0.0, 360.0, size=N_EXP)

In [None]:
import jaxopt

# First-order baseline: gradient descent from every random starting angle.
# The solver's configuration does not depend on the start point, so it is
# built once instead of once per run.
solver = jaxopt.GradientDescent(fun=compute_func, maxiter=250, tol=1e-3)

first_order2 = []
# perf_counter is monotonic, which suits interval timing better than the
# wall clock.
t1 = time.perf_counter()
for i in range(N_EXP):
    loc = jnp.array(init_g[i])
    res = solver.run(loc)
    # Record the loss reached at the returned parameters.
    first_order2.append(np.array(compute_func(res.params)))
print(time.perf_counter() - t1)
# Number of runs that ended below twice the minimum of the reconstruction.
print((np.array(first_order2) < y_recon.min()*2).sum())

In [None]:
import jaxopt

# L-BFGS comparison: same starting angles, capped at three iterations.
# The solver's configuration does not depend on the start point, so it is
# built once instead of once per run.
solver = jaxopt.LBFGS(fun=compute_func, maxiter=3, tol=1e-3)

second_order = []
# perf_counter is monotonic, which suits interval timing better than the
# wall clock.
t1 = time.perf_counter()
for i in range(N_EXP):
    loc = jnp.array(init_g[i])
    res = solver.run(loc)
    # Record the loss reached at the returned parameters.
    second_order.append(np.array(compute_func(res.params)))
print(time.perf_counter() - t1)
# Number of runs that ended below twice the minimum of the reconstruction.
print((np.array(second_order) < y_recon.min()*2).sum())

In [None]:
# Compare the distributions of final loss values reached by the two solvers.
fig, ax = plt.subplots()
ax.hist(first_order2, alpha=0.5, label='first')
ax.hist(second_order, alpha=0.5, label='second')
ax.legend()