In [None]:
import jax.numpy as jnp
import scipy

In [None]:
def func(x):
    """Target function being approximated: elementwise ``jnp.sin`` of ``x``."""
    sine_of_x = jnp.sin(x)
    return sine_of_x

def func_approx(x, a, b, c):
    """Cubic model with no constant term: ``a*x**3 + b*x**2 + c*x``.

    The three coefficients are the free parameters fitted by every solution
    below; with no ``p0``, ``scipy.optimize.curve_fit`` counts them from this
    signature.
    """
    cubic_term = a * x**3
    quadratic_term = b * x**2
    linear_term = c * x
    return cubic_term + quadratic_term + linear_term

In [None]:
import matplotlib.pyplot as plt

# Sample the target on [0, 2*pi) with step 0.1; `x` and `y` are the data that
# the fitting cells below use.
x = jnp.arange(0, jnp.pi*2, 0.1)
y = func(x)
plt.plot(x, y)
# Label the figure so it stands on its own when the notebook is skimmed.
plt.title("Target function")
plt.xlabel("x")
plt.ylabel("func(x)")
plt.show()

In [None]:
from scipy import optimize

# solution 1: curve fitting
# Nonlinear least squares over (a, b, c); with no p0, curve_fit starts from ones.
popt, pcov = optimize.curve_fit(func_approx, x, y)
print(popt)
y_ = func_approx(x, *popt)
plt.plot(x, y, label="exact")
plt.plot(x, y_, label="approx")
plt.title("Solution 1: scipy.optimize.curve_fit")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()

In [None]:
# solution 2: Nelder-Mead optimization
def obj_func(params):
    """Mean-squared error between ``func_approx`` and ``func`` on [0, 2*pi).

    ``params`` must unpack into exactly three coefficients (a, b, c). The
    sample grid is rebuilt locally instead of reading the global ``x``.
    Returns a scalar suitable for ``scipy.optimize.minimize``.
    """
    coef_a, coef_b, coef_c = params

    grid = jnp.arange(0, jnp.pi*2, 0.1)
    target = func(grid)
    prediction = func_approx(grid, coef_a, coef_b, coef_c)
    residual = prediction - target
    return jnp.mean(jnp.square(residual))

initial_guess = [0, 0, 0.5]
result = optimize.minimize(obj_func, initial_guess, method="Nelder-Mead")
# Nelder-Mead can stop at its iteration limit; surface that instead of
# silently plotting a possibly unconverged fit.
if not result.success:
    print("Nelder-Mead did not converge:", result.message)
# Report the fitted coefficients like the curve_fit and SVD cells do.
print(result.x)
y_ = func_approx(x, *result.x)
plt.plot(x, y, label="exact")
plt.plot(x, y_, label="approx")
plt.title("Solution 2: Nelder-Mead")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()


In [None]:
import jax

# solution 3: gradient descent
@jax.jit
def loss(params):
    """JIT-compiled mean-squared error of the cubic fit against ``func``.

    Evaluated on the fixed grid ``arange(0, 2*pi, 0.1)``. ``params`` must
    unpack into exactly three coefficients (a, b, c). This is the objective
    that ``update_parameters_step`` differentiates with ``jax.grad``.
    """
    k3, k2, k1 = params
    grid = jnp.arange(0, jnp.pi*2, 0.1)
    target = func(grid)
    prediction = func_approx(grid, k3, k2, k1)
    return jnp.mean(jnp.square(prediction - target))

@jax.jit
def update_parameters_step(params, learning_rate=0.0001):
    """Take one plain gradient-descent step on ``loss``.

    Args:
        params: sequence of the three coefficients [a, b, c].
        learning_rate: step size that scales each gradient component.

    Returns:
        A new list ``[p - learning_rate * dloss/dp for each p]``; the input
        ``params`` is not modified.
    """
    gradients = jax.grad(loss)(params)
    updated = []
    # Python loop is unrolled at trace time; it produces the same list as a comprehension.
    for current, slope in zip(params, gradients):
        updated.append(current - learning_rate * slope)
    return updated

def optimize_loop(x0, print_loss=False, num_steps=200_000,
                  learning_rate=0.0001, log_every=1000):
    """Run a fixed number of gradient-descent steps starting from ``x0``.

    Args:
        x0: initial coefficients [a, b, c] for ``func_approx``.
        print_loss: if True, print the current parameters and their ``loss``
            every ``log_every`` steps.
        num_steps: number of ``update_parameters_step`` calls (default matches
            the previously hard-coded 50000*4).
        learning_rate: step size forwarded to ``update_parameters_step``
            (default matches that function's own default).
        log_every: logging interval used when ``print_loss`` is True.

    Returns:
        The parameter list after ``num_steps`` updates (``x0`` itself when
        ``num_steps`` is 0).
    """
    params = x0
    for step in range(num_steps):
        params = update_parameters_step(params, learning_rate)
        if print_loss and step % log_every == 0:
            # Use this cell's `loss` rather than `obj_func` from the Nelder-Mead
            # cell, so logging does not depend on that cell having been run.
            print(params, loss(params))
    return params

result = optimize_loop([0.0, 0.0, 1.0], print_loss=False)
# Report the fitted coefficients like the other solutions; convert the 0-d
# JAX arrays to plain floats for a readable printout.
print([float(coef) for coef in result])
y_ = func_approx(x, *result)
plt.plot(x, y, label="exact")
plt.plot(x, y_, label="approx")
plt.title("Solution 3: gradient descent")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()

In [None]:
import numpy as np

# solution 4: SVD
# Least squares via the pseudo-inverse: sample the target at random points,
# build a design matrix with one column per monomial of func_approx
# (x^3, x^2, x), and solve min ||A @ [a, b, c] - b||_2 as V diag(1/S) U^T b.
rng = np.random.default_rng(0)  # fixed seed so Restart & Run All reproduces the fit
n = 40
# Separate name so the global `x` grid used by the other cells is not clobbered.
x_sample = rng.random(n) * np.pi * 2
A = np.column_stack([x_sample**3, x_sample**2, x_sample])
b = np.asarray(func(x_sample))

U, S, VT = np.linalg.svd(A, full_matrices=False)
# Dividing by the singular values applies diag(S)^-1 without building and
# inverting a dense diagonal matrix.
xtilde = VT.T @ ((U.T @ b) / S)

x = jnp.arange(0, jnp.pi*2, 0.1)
y = func(x)
y_ = func_approx(x, *xtilde)
print(xtilde)
plt.plot(x, y, label="exact")
plt.plot(x, y_, label="approx")
plt.title("Solution 4: SVD least squares")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()