In [1]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

n_one = 20
noise = 1e-1*np.random.normal(size=(n_one*n_one))

# generate cylindrical data

x0 = 0.0
y0 = 0.0
z0 = 0.0

r = 2.0

x = np.linspace(-1, 1, 200)
z = np.linspace(-1, 1, 200)
y = np.sqrt(r**2 - (x**2))
y[49:150] *= -1


fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z, zdir='z', s=20, c='b',rasterized=True)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()


In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.config import config
config.update('jax_enable_x64', True)
key = random.PRNGKey(0)



In [3]:
# Initialize random guess
B = jnp.array([x.mean(), y.mean(), z.mean(),
               1.0,
               1.0,
               1.0])
A = jnp.array([x, y, z],).T

In [4]:
B

DeviceArray([ 3.55271368e-17, -8.57880695e-02,  3.55271368e-17,
              1.00000000e+00,  1.00000000e+00,  1.00000000e+00],            dtype=float64)

In [5]:
def predict(B, A):
    # pedict r^2 for B and A
    u = B[5]*(A[:, 1] - B[1]) - B[4]*(A[:, 2] - B[2])
    v = B[3]*(A[:, 2] - B[2]) - B[5]*(A[:, 0] - B[0])
    w = B[4]*(A[:, 0] - B[0]) - B[3]*(A[:, 1] - B[1])
    r_squared = u**2 + v**2 + w**2
    return r_squared


In [6]:
t = predict(B, A)

In [7]:
random.normal(key, (1,))

DeviceArray([-0.78476578], dtype=float64)

In [8]:
target = jnp.ones(A.shape[0])*(r**2)


In [9]:
n = A.shape[0]

def loss(B):
    # compute mse
    pred = predict(B, A)
    e = pred - target
    return jnp.dot(e, e) / n


In [10]:
loss(B)

DeviceArray(35.55270606, dtype=float64)

In [11]:
grad(loss)(B)

DeviceArray([ 12.63526522, -25.27053044,  12.63526522,  77.29133001,
              50.76427531,  77.29133001], dtype=float64)

In [12]:
from scipy.optimize import fmin_bfgs


In [13]:
res = fmin_bfgs(loss, B, fprime=grad(loss), norm=2.0, args=(), gtol=1e-17, epsilon=1.4901161193847656e-08, maxiter=None, full_output=1, disp=1, retall=0, callback=None)

         Current function value: 0.033471
         Iterations: 16
         Function evaluations: 84
         Gradient evaluations: 73


In [14]:
res

(array([-6.85959248e-01, -6.56914312e-02, -6.85959248e-01,  7.38814753e-01,
         2.78577024e-04,  7.38814753e-01]),
 0.033470533498233944,
 array([-8.59031350e-13,  7.59982090e-09, -2.00655231e-12, -3.32794302e-08,
         9.70232791e-09, -3.32895617e-08]),
 array([[ 5.75114938e-01,  7.17928039e-03, -4.24886214e-01,
         -1.00882024e-02, -1.73861453e-02, -1.00983773e-02],
        [ 7.17928039e-03,  6.92797646e-02,  7.17792101e-03,
         -4.46403874e-04, -4.42111814e-02, -4.58402954e-04],
        [-4.24886214e-01,  7.17792101e-03,  5.75112633e-01,
         -1.00878070e-02, -1.73851403e-02, -1.00979811e-02],
        [-1.00882024e-02, -4.46403874e-04, -1.00878070e-02,
          5.04226857e-01,  2.60678050e-04, -4.95769652e-01],
        [-1.73861453e-02, -4.42111814e-02, -1.73851403e-02,
          2.60678050e-04,  4.76262143e-02,  2.69547907e-04],
        [-1.00983773e-02, -4.58402954e-04, -1.00979811e-02,
         -4.95769652e-01,  2.69547907e-04,  5.04233845e-01]]),
 84,
 73,

In [15]:
x0_hat = res[0][0]
x0_hat

-0.6859592479673602

In [16]:
y0_hat = res[0][1]
y0_hat

-0.06569143119842225

In [17]:
z0_hat = res[0][2]
z0_hat

-0.6859592479624994

In [18]:
a_hat = res[0][3]
a_hat

0.7388147534199707

In [19]:
b_hat = res[0][4]
b_hat

0.00027857702404679934

In [20]:
c_hat = res[0][5]
c_hat

0.7388147534546972