In [1]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import minimize

In [2]:
w = [5, 6]
n_sample = 100
x0 = -5 + 12 * np.random.random_sample(n_sample)
x1 = -6 + 15 * np.random.random_sample(n_sample)
b = -1 + 2 * np.random.random_sample(n_sample)
y = np.array([w[0], w[1] ** 2]).dot(np.stack([x0, x1])) + b

X0, X1 = np.meshgrid(x0, x1)
X0 = np.ravel(X0)
X1 = np.ravel(X1)
B = -1 + 2 * np.random.random_sample(n_sample ** 2)
Y = np.array([w[0], w[1] ** 2]).dot(np.stack([X0, X1])) + B

In [3]:
def loss(args):
    w0, w1, b = args
    y_hat = np.array([w0, w1 ** 2]).dot(np.stack([X0, X1])) + b
    return np.sum((Y - y_hat) ** 2) / (n_sample ** 2)

In [4]:
def sparse_loss(args):
    w0, w1, b = args
    y_hat = np.array([w0, w1 ** 2]).dot(np.stack([x0, x1])) + b
    return np.sum((y - y_hat) ** 2) / n_sample

In [5]:
def jac(args):
    w0, w1, b = args
    y_hat = np.array([w0, w1 ** 2]).dot(np.stack([x0, x1])) + b
    dy = -2 * np.sum(y - y_hat) / n_sample
    return dy * np.array([np.sum(x0), 2 * w1 * np.sum(x1), 1])

In [6]:
def hess(args):
    w0, w1, b = args
    y_hat = np.array([w0, w1 ** 2]).dot(np.stack([x0, x1])) + b
    dy = -2 * np.sum(y - y_hat) / n_sample
    
    return np.array([[2 * np.sum(x0) / n_sample, 4 * w1 * np.sum(x1) / n_sample, 2 / n_sample]]).T.dot(
        np.array([[np.sum(x0), 2 * w1 * np.sum(x1), 1]])
    ) + dy * np.array([
        [0, 0, 0],
        [0, 2 * np.sum(x1), 0],
        [0, 0, 0]
    ])

In [7]:
minimize(loss, [1, 1, 0])

      fun: 0.3327056693478594
 hess_inv: array([[ 4.17662281e-02, -4.78359440e-06, -4.33940299e-02],
       [-4.78359440e-06,  2.03464016e-04, -4.21146315e-03],
       [-4.33940299e-02, -4.21146315e-03,  6.31998953e-01]])
      jac: array([ 2.08616257e-07, -7.46920705e-06,  6.70552254e-08])
  message: 'Optimization terminated successfully.'
     nfev: 55
      nit: 7
     njev: 11
   status: 0
  success: True
        x: array([ 4.99941132e+00,  5.99987634e+00, -2.66286520e-03])

In [9]:
minimize(sparse_loss, [1, 2, 0], jac=jac, hess=hess, method='Newton-CG')

     fun: 155.5809260120686
     jac: array([-3.41754980e+02, -6.85069832e+03, -3.26931063e+00])
    nfev: 31
    nhev: 2
     nit: 1
    njev: 20
  status: 2
 success: False
       x: array([1.62042121e+00, 6.08659849e+00, 5.93509904e-03])