In [1]:
%matplotlib qt5
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation

In [2]:
# Square canvas holding a single 3-D axes for the objective surface.
fig1 = plt.figure(figsize=plt.figaspect(1))
ax1 = fig1.add_subplot(111, projection='3d')

# The same sample coordinates are used along both axes before gridding.
X = Y = np.arange(-4, 4.25, 0.125)
X, Y = np.meshgrid(X, Y)
R = np.square(X) + np.square(Y)  # objective function (convex); only one batch
surf = ax1.plot_surface(X, Y, R, cmap=cm.coolwarm)
fig1.tight_layout()
plt.show()

In [3]:
from sympy import *
from sympy.abc import x, y


# Symbolic objective and its first-order partial derivatives
# (dz_x with respect to x, dz_y with respect to y).
z = x**2 + y**2
dz_x, dz_y = (diff(z, variable) for variable in (x, y))
print(dz_x, dz_y, sep='\n')

2*x
2*y


In [4]:
# Gradient of z on the whole (X, Y) grid: grad[0] holds dz/dx, grad[1] holds dz/dy.
# lambdify (in scope through the sympy star-import) turns each derivative into a
# NumPy-backed function of BOTH variables, so a derivative that mixes x and y is
# still evaluated to numbers, and the grid is processed in one vectorized call
# instead of one symbolic substitution per grid point.
grad = np.empty((2, *X.shape))
# Slice assignment broadcasts, so a constant derivative still fills the grid.
grad[0] = lambdify((x, y), dz_x, 'numpy')(X, Y)
grad[1] = lambdify((x, y), dz_y, 'numpy')(X, Y)

In [5]:
grad # gradient at every grid point

array([[[-8.  , -7.75, -7.5 , ...,  7.75,  8.  ,  8.25],
        [-8.  , -7.75, -7.5 , ...,  7.75,  8.  ,  8.25],
        [-8.  , -7.75, -7.5 , ...,  7.75,  8.  ,  8.25],
        ...,
        [-8.  , -7.75, -7.5 , ...,  7.75,  8.  ,  8.25],
        [-8.  , -7.75, -7.5 , ...,  7.75,  8.  ,  8.25],
        [-8.  , -7.75, -7.5 , ...,  7.75,  8.  ,  8.25]],

       [[-8.  , -8.  , -8.  , ..., -8.  , -8.  , -8.  ],
        [-7.75, -7.75, -7.75, ..., -7.75, -7.75, -7.75],
        [-7.5 , -7.5 , -7.5 , ..., -7.5 , -7.5 , -7.5 ],
        ...,
        [ 7.75,  7.75,  7.75, ...,  7.75,  7.75,  7.75],
        [ 8.  ,  8.  ,  8.  , ...,  8.  ,  8.  ,  8.  ],
        [ 8.25,  8.25,  8.25, ...,  8.25,  8.25,  8.25]]])

In [6]:
def function():
    """Build the objective ``a**2 + b**2`` as a symbolic expression in ``a`` and ``b``."""
    first, second = symbols('a, b')
    objective = first**2 + second**2  # convex function
    return objective

def numerical_gradient(func, point):
    """Evaluate the gradient of the symbolic expression ``func`` at ``point``.

    Despite the name, the derivatives are taken symbolically with ``diff``
    and only then evaluated numerically.

    Args:
        func: Expression in the symbols ``a`` and ``b`` (as built by
            ``function()``).
        point: Indexable pair ``(a_value, b_value)`` at which to evaluate.

    Returns:
        ``np.ndarray`` of two floats: ``[d func/da, d func/db]`` at ``point``.
    """
    a, b = symbols('a, b')
    # Substitute BOTH coordinates into each partial derivative: a derivative
    # of a non-separable expression (e.g. a*b) still contains the other
    # symbol, and float() cannot convert an expression with free symbols.
    at_point = {a: point[0], b: point[1]}
    dz_a = diff(func, a).subs(at_point)
    dz_b = diff(func, b).subs(at_point)
    return np.array([float(dz_a), float(dz_b)])


# Sanity check: gradient of a**2 + b**2 at the point (-3, 4).
numerical_gradient(function(), (-3, 4))

array([-6.,  8.])

In [7]:
def gradient_descent(f, init_x, lr=0.01, step_num=10000):
    """Run fixed-step gradient descent and record every visited point.

    Args:
        f: Objective, passed through unchanged to ``numerical_gradient``.
        init_x: Starting point (array-like). It is copied and never modified.
        lr: Learning rate (step size) multiplied into each gradient.
        step_num: Number of update steps to perform.

    Returns:
        ``np.ndarray`` whose row ``k`` is the point held before the ``k``-th
        update; the first row is the starting point and the point reached
        after the last update is not included.
    """
    # Work on a private float copy: the in-place ``-=`` below would otherwise
    # overwrite the caller's array (so re-running would start from the end
    # point of the previous run) and would fail on integer input.
    x = np.array(init_x, dtype=float)
    x_history = []

    for _ in range(step_num):
        x_history.append(x.copy())
        gradient = numerical_gradient(f, x)
        x -= lr * gradient  # iterate update step

    return np.array(x_history)


# Initial value of the iteration.
init_x = np.array([-4.0, 4.0])

# The learning rate must be neither too large nor too small
# (no one-dimensional search over the learning rate is performed).
lr, step_num = 0.01, 100
x_history = gradient_descent(function(), init_x, lr=lr, step_num=step_num)
print(x_history)

[[-4.          4.        ]
 [-3.92        3.92      ]
 [-3.8416      3.8416    ]
 [-3.764768    3.764768  ]
 [-3.68947264  3.68947264]
 [-3.61568319  3.61568319]
 [-3.54336952  3.54336952]
 [-3.47250213  3.47250213]
 [-3.40305209  3.40305209]
 [-3.33499105  3.33499105]
 [-3.26829123  3.26829123]
 [-3.2029254   3.2029254 ]
 [-3.13886689  3.13886689]
 [-3.07608956  3.07608956]
 [-3.01456777  3.01456777]
 [-2.95427641  2.95427641]
 [-2.89519088  2.89519088]
 [-2.83728706  2.83728706]
 [-2.78054132  2.78054132]
 [-2.7249305   2.7249305 ]
 [-2.67043189  2.67043189]
 [-2.61702325  2.61702325]
 [-2.56468278  2.56468278]
 [-2.51338913  2.51338913]
 [-2.46312135  2.46312135]
 [-2.41385892  2.41385892]
 [-2.36558174  2.36558174]
 [-2.31827011  2.31827011]
 [-2.2719047   2.2719047 ]
 [-2.22646661  2.22646661]
 [-2.18193728  2.18193728]
 [-2.13829853  2.13829853]
 [-2.09553256  2.09553256]
 [-2.05362191  2.05362191]
 [-2.01254947  2.01254947]
 [-1.97229848  1.97229848]
 [-1.93285251  1.93285251]
 

In [8]:
fig2 = plt.figure()
ax2 = fig2.add_subplot(1, 1, 1)
# quiver arguments:
#   X, Y               -> arrow positions
#   -grad[0], -grad[1] -> arrow directions (the negated gradient)
#   color              -> arrow colour
ax2.quiver(X, Y, -grad[0], -grad[1], color='b')
# Overlay the points recorded in x_history as a dashed red line.
ax2.plot(x_history[:, 0], x_history[:, 1], color='red', linestyle='--')
# Configure the axes through ax2 directly instead of pyplot's implicit
# "current axes" state, so the calls cannot land on another figure.
ax2.set_xlim([-4, 4])
ax2.set_ylim([-4, 4])
ax2.set_xlabel('x')
ax2.set_ylabel('y')
ax2.grid()
plt.show()

In [9]:
fig_dynamic = plt.figure()
# Figure.gca() no longer accepts keyword arguments (deprecated in Matplotlib 3.4,
# removed in 3.6), so the 3-D axes is created with add_subplot, exactly as for
# the first figure of this notebook.
ax_dynamic = fig_dynamic.add_subplot(1, 1, 1, projection='3d')
ax_dynamic.plot_wireframe(X, Y, R, cmap=cm.coolwarm)
# Height of the objective at every recorded point of the descent path.
route_z = x_history[:, 0]**2 + x_history[:, 1]**2
ims = []
x_data, y_data, z_data = [], [], []

for i in range(len(route_z)):
    # Frame i shows every point visited up to and including step i.
    x_data.append(x_history[i, 0])
    y_data.append(x_history[i, 1])
    z_data.append(route_z[i])
    # There is no set_data_3d method available here, so the animation is built
    # from pre-drawn artists with ArtistAnimation.
    im = ax_dynamic.scatter(x_data, y_data, z_data,
                            marker='*', color='red')
    ims.append([im])  # each frame must be a list of artists, hence [im]

# Try blit=True or blit=False and keep whichever renders better.
line_ani = animation.ArtistAnimation(fig_dynamic, artists=ims, interval=100,
                                     blit=False)