In [4]:
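# The two functions below compute numerical gradients for one point or a batch of points; each point is the argument of a single- or multi-variable function.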
def _numerical_gradient_1d(f,x):
    """
    Estimate the gradient of ``f`` at the point ``x`` by central differences.

    Args:
        f: Callable that receives the coordinate array and returns a scalar.
        x: 1-D NumPy array of coordinates, e.g. ``np.array([1.0, 2.0])`` for a
            two-variable function evaluated at the point (1, 2).

    Returns:
        An array shaped like ``x`` whose ``i``-th entry approximates the
        partial derivative of ``f`` with respect to ``x[i]``. The result is
        always floating point, even when ``x`` holds integers.

    Notes:
        A floating-point ``x`` is perturbed in place, one coordinate at a
        time, and restored before returning (also when ``f`` raises).
        Non-float input is evaluated on a float copy, because an integer
        array would truncate ``x[idx] +/- h`` back to ``x[idx]`` and the
        gradient would silently come out as zero.
    """
    h = 1e-4

    # Keep working on the caller's array when it can hold the perturbation;
    # otherwise fall back to a float copy so that +/- h is not truncated away.
    if np.issubdtype(x.dtype, np.inexact):
        point = x
    else:
        point = x.astype(np.float64)

    # Same shape and (float) dtype as the working array, e.g. [0., 0.].
    grad = np.zeros_like(point)

    for idx in range(point.size):
        saved = point[idx]
        try:
            point[idx] = saved + h
            fxh1 = f(point)  # f evaluated at x + h along axis idx
            point[idx] = saved - h
            fxh2 = f(point)  # f evaluated at x - h along axis idx
        finally:
            point[idx] = saved  # always restore the coordinate
        grad[idx] = (fxh1 - fxh2) / (2*h)

    return grad
    

def numerical_gradient_2d(f,X):
    """
    Numerical gradient of ``f`` for a single point or a batch of points.

    Args:
        f: Callable that receives a 1-D coordinate array and returns a scalar.
        X: NumPy array. A 1-D array is treated as one point; otherwise each
            entry along the first axis (each row of a 2-D array) is treated
            as a separate point.

    Returns:
        A floating-point array shaped like ``X`` holding, for every point,
        the gradient computed by ``_numerical_gradient_1d``.

    Notes:
        Non-float input is converted to a float working copy first: an
        integer array can neither hold the perturbed coordinates nor store
        the resulting gradients without truncation.
    """
    if not np.issubdtype(X.dtype, np.inexact):
        X = X.astype(np.float64)

    if X.ndim == 1:
        return _numerical_gradient_1d(f,X)

    # e.g. X=[[2,3,4],[1,2,1]] -> grad=[[0,0,0],[0,0,0]]
    grad = np.zeros_like(X)

    # Iterating an ndarray yields one row at a time; each row is handed to
    # the 1-D helper and its gradient stored at the matching position.
    for idx, x in enumerate(X):
        grad[idx] = _numerical_gradient_1d(f,x)

    return grad


# Gradient descent
def gradient_descent(f, init_x, lr=0.01, step_num=300):
    """
    Run plain gradient descent on ``f`` starting from ``init_x``.

    Args:
        f: Callable that receives a coordinate array and returns a scalar,
            e.g. ``lambda x: x[0]**2 + x[1]**2``.
        init_x: Starting point (NumPy array or array-like). It is copied and
            never modified, so the caller's array keeps its initial values.
        lr: Learning rate; each step subtracts ``lr * grad``.
        step_num: Number of update steps to perform.

    Returns:
        A tuple ``(x, history)``: ``x`` is the point after ``step_num``
        updates and ``history`` is an array stacking the point recorded
        before each update (the starting point first).
    """
    start = np.asarray(init_x)
    # Work on a private float copy. Aliasing init_x would overwrite the
    # caller's array, and the in-place float update below cannot be cast
    # back into an integer array.
    if np.issubdtype(start.dtype, np.inexact):
        x = start.copy()
    else:
        x = start.astype(np.float64)

    x_history = []

    for _ in range(step_num):
        x_history.append(x.copy())

        grad = numerical_gradient_2d(f,x)  # e.g. x=[2,2] -> grad=[4,4]
        x -= lr * grad  # e.g. [2,2] - 0.01*[4,4] = [1.96,1.96]

    return x, np.array(x_history)
        

In [9]:
# Test of gradient descent on f(x) = x[0]**2 + x[1]**2, starting from (2, 2)
f = lambda x: x[0]**2 + x[1]**2
x = np.array([2.0,2.0])
x_new,his = gradient_descent(f,x)
# Cell result: display the recorded history together with x
his,x

(array([[2.        , 2.        ],
        [1.96      , 1.96      ],
        [1.9208    , 1.9208    ],
        [1.882384  , 1.882384  ],
        [1.84473632, 1.84473632],
        [1.80784159, 1.80784159],
        [1.77168476, 1.77168476],
        [1.73625107, 1.73625107],
        [1.70152605, 1.70152605],
        [1.66749552, 1.66749552],
        [1.63414561, 1.63414561],
        [1.6014627 , 1.6014627 ],
        [1.56943345, 1.56943345],
        [1.53804478, 1.53804478],
        [1.50728388, 1.50728388],
        [1.47713821, 1.47713821],
        [1.44759544, 1.44759544],
        [1.41864353, 1.41864353],
        [1.39027066, 1.39027066],
        [1.36246525, 1.36246525],
        [1.33521594, 1.33521594],
        [1.30851162, 1.30851162],
        [1.28234139, 1.28234139],
        [1.25669456, 1.25669456],
        [1.23156067, 1.23156067],
        [1.20692946, 1.20692946],
        [1.18279087, 1.18279087],
        [1.15913505, 1.15913505],
        [1.13595235, 1.13595235],
        [1.113