In [1]:
from IPython.core.display import HTML
with open ("../style.css", "r") as file:
    css = file.read()
HTML(css)

# Minimization in PyTorch

In [2]:
import torch

In [3]:
def f_and_fs(t):
    # declare x to be a tensor with value t that keeps track of its gradient
    x = torch.tensor(t, requires_grad=True, dtype=torch.double)
    y = torch.exp(x) - 2 * x * x + 1
    # use backpropagation to compute the gradient of y w.r.t. x
    y.backward()
    # the gradient is stored in x.grad
    # the function item returns the value stored in a tensor
    return y.item(), x.grad.item()

In [4]:
def findMinimum(start, eps):
    t     = start
    f, fs = f_and_fs(t)
    print(f'cnt = 0, f({t}) = {f}, fs = {fs}')
    alpha = 0.1   # learning rate
    cnt   = 0     # number of iterations
    while True:
        cnt += 1
        tOld, fOld = t, f
        t    -= alpha * fs
        f, fs = f_and_fs(t)
        print(f'cnt = {cnt}, f({t}) = {f}, fs = {fs}')
        if abs(t - tOld) <= abs(t) * eps:
            return t, f, fs, cnt            
        if f >= fOld:     # f didn't decrease, learning rate is too high
            alpha *= 0.5  # decrease the learning rate
            t, f = tOld, fOld    # reset t
            continue
        else:             # f has decreased
            alpha *= 1.2  # increase the learning rate

In [5]:
findMinimum(1.0, 1e-12)

cnt = 0, f(1.0) = 1.718281828459045, fs = -1.281718171540955
cnt = 1, f(1.1281718171540955) = 1.5444589460545912, fs = -1.4226850245202431
cnt = 2, f(1.2988940200965247) = 1.2909893916591888, fs = -1.5303353378418878
cnt = 3, f(1.5192623087457564) = 0.9525376228223665, fs = -1.508195686609687
cnt = 4, f(1.7798785233919103) = 0.5932010074338105, fs = -1.1903779700702968
cnt = 5, f(2.026715299265687) = 0.3739675854092086, fs = -0.517743803098333
cnt = 6, f(2.1555465252782513) = 0.3398451825377595, fs = 0.010420726703040728
cnt = 7, f(2.152434912957886) = 0.3398351407176765, fs = -0.00395240207382308
cnt = 8, f(2.1538511300803664) = 0.3398341661805162, fs = 0.0025790269559937684
cnt = 9, f(2.1527421956957986) = 0.3398341436997523, fs = -0.0025367168251051453
cnt = 10, f(2.1540510858590753) = 0.3398347742021759, fs = 0.003502591747198025
cnt = 11, f(2.151838565491885) = 0.3398383163972465, fs = -0.006697521693945774
cnt = 12, f(2.1536061392168864) = 0.3398336729082505, fs = 0.0014479217860

(2.1532923585154133, 0.33983344576952135, -2.5810388493141545e-08, 33)

In [6]:
torch.cuda.is_available()

True