In [None]:
# 2021/12/28
# keyword: meshgrid, matplotlib 3D

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

import numpy as np
import torch
from torch.optim import SGD, Adam
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline
from IPython import display


def f(x):
    """Evaluate Himmelblau's function at the pair ``x = (x[0], x[1])``.

    f(a, b) = (a**2 + b - 11)**2 + (a + b**2 - 7)**2, which is zero at e.g. (3, 2).
    Only ``x[0]``/``x[1]`` indexing and element-wise arithmetic are used, so the
    same function accepts plain floats, NumPy meshgrid arrays (vectorized
    surface evaluation) and torch tensors (autograd-tracked objective).
    """
    a, b = x[0], x[1]
    first_term = a ** 2 + b - 11
    second_term = a + b ** 2 - 7
    return first_term ** 2 + second_term ** 2


if __name__ == '__main__':
    # Start the descent from (x, y) = (-6, 6) and record (x, y, f(x, y)) after
    # every optimizer step so the path can be drawn over the surface later.
    x_ori = -6.
    y_ori = 6.
    z_ori = f((x_ori, y_ori))
    point_hist = [[x_ori], [y_ori], [z_ori]]    # parallel lists: xs, ys, zs

    # The 2-D point itself is the only parameter being optimized.
    point = torch.tensor([x_ori, y_ori], requires_grad=True)

    lr = 1e-4
    momentum = 0.9
    optim = SGD([point], lr=lr, momentum=momentum)
    # optim = Adam([point], lr=lr)

    n_epoch = 100
    for _ in range(n_epoch):
        # f(point) is the scalar objective being minimized.
        loss = f(point)
        optim.zero_grad()
        loss.backward()
        optim.step()

        # Snapshot the updated coordinates once, detached from autograd,
        # instead of converting the tensor to a list several times per step.
        x_new, y_new = point.detach().tolist()
        point_hist[0].append(x_new)
        point_hist[1].append(y_new)
        point_hist[2].append(f((x_new, y_new)))

    # visualization
    # Name the figure after the optimizer actually in use, so toggling the
    # commented-out Adam line above does not leave a stale "SGD" label.
    optim_name = type(optim).__name__
    fig = plt.figure(f"Convergence ({optim_name})", figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(f"epoch {n_epoch}, lr {lr:.5f}", size=10)

    # Surface of f over a grid that covers the start point; f is evaluated
    # element-wise on the meshgrid arrays in a single vectorized call.
    x, y = np.meshgrid(np.arange(-6.5, 6.5, 0.1), np.arange(-6.5, 6.5, 0.1))
    ax.plot_surface(x, y, f((x, y)), alpha=0.7)
    ax.view_init(60, -30)

    # Optimization trajectory, labelled with its starting point.
    ax.plot(point_hist[0], point_hist[1], point_hist[2],
            label=f"({x_ori:.1f}, {y_ori:.1f})", ls='-', lw=2.)
    ax.legend()
    fig.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.0)

    plt.show()