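This notebook demonstrates the basic idea behind gradient descent: repeatedly applying the update $w \leftarrow w - \eta \nabla f(w)$, with learning rate $\eta$, to minimise a function $f$.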

In [25]:
import numpy

def gradient_descent(func, grad_func, w_init, n_epochs=100, lr=0.001, verbose=0, tol=None):
    """Minimise ``func`` with plain gradient descent using a fixed learning rate.

    Parameters
    ----------
    func : callable
        Objective function ``func(w)``. It is only evaluated when
        ``verbose > 0`` or ``tol`` is given.
    grad_func : callable
        Returns the gradient of ``func`` at ``w``. The result must be
        broadcast-compatible with ``w``.
    w_init : array_like
        Starting point. It is converted to a float ndarray, so plain Python
        lists work too. The caller's object is never modified in place.
    n_epochs : int, optional
        Maximum number of update steps (default 100).
    lr : float, optional
        Learning rate, i.e. the step size (default 0.001).
    verbose : int, optional
        If > 0, print the objective value and ``w`` after every step.
    tol : float or None, optional
        If given, stop early once the absolute change of ``func`` between two
        consecutive iterates drops below ``tol``. With ``None`` (the default),
        all ``n_epochs`` steps are always run.

    Returns
    -------
    numpy.ndarray
        The final iterate ``w``.
    """
    # Coerce to a float array: with a plain list, ``2 * w`` would repeat the
    # list and ``lr * list`` would raise, instead of doing vector arithmetic.
    w = numpy.asarray(w_init, dtype=float)

    # Objective at the previous iterate; only needed for the tol criterion.
    f_prev = func(w) if tol is not None else None

    for _ in range(n_epochs):
        # Gradient update step: move against the gradient, scaled by lr.
        # Rebinding (not ``-=``) guarantees w_init is never mutated.
        w = w - lr * grad_func(w)

        # Evaluate the objective at most once per step, and only if needed.
        need_f = verbose > 0 or tol is not None
        f_curr = func(w) if need_f else None

        if verbose > 0:
            print("f={}; w: {}".format(f_curr, w))

        if tol is not None:
            # Alternative stopping criterion: the objective barely changed.
            if abs(f_prev - f_curr) < tol:
                break
            f_prev = f_curr

    return w

In [26]:
# Simple quadratic objective, the squared Euclidean norm of w,
# e.g. f(w_1, w_2) = w_1*w_1 + w_2*w_2 for d=2.
def f(w):
    """Return the sum of the squared entries of ``w``."""
    squared = w * w
    return numpy.sum(squared)

# Gradient of f: each partial derivative of sum_i w_i^2 is 2*w_i.
def grad(w):
    """Return the gradient of ``f`` at ``w``, namely twice ``w``."""
    return w * 2

In [27]:
# Starting point w_0 = (10, 10).
w_init = numpy.full(2, 10)

# Step size; the learning rate usually has a big impact on convergence.
lr = 0.01

# Run 250 gradient descent steps, printing f and w after each one.
w_opt = gradient_descent(
    f,
    grad,
    w_init,
    n_epochs=250,
    lr=lr,
    verbose=1,
)

f=192.08000000000004; w: [9.8 9.8]
f=184.47363200000004; w: [9.604 9.604]
f=177.1684761728; w: [9.41192 9.41192]
f=170.15260451635714; w: [9.2236816 9.2236816]
f=163.41456137750941; w: [9.03920797 9.03920797]
f=156.94334474696007; w: [8.85842381 8.85842381]
f=150.72838829498042; w: [8.68125533 8.68125533]
f=144.75954411849915; w: [8.50763023 8.50763023]
f=139.0270661714066; w: [8.33747762 8.33747762]
f=133.5215943510189; w: [8.17072807 8.17072807]
f=128.23413921471857; w: [8.00731351 8.00731351]
f=123.15606730181571; w: [7.84716724 7.84716724]
f=118.27908703666381; w: [7.69022389 7.69022389]
f=113.5952351900119; w: [7.53641941 7.53641941]
f=109.09686387648745; w: [7.38569103 7.38569103]
f=104.77662806697855; w: [7.23797721 7.23797721]
f=100.6274735955262; w: [7.09321766 7.09321766]
f=96.64262564114335; w: [6.95135331 6.95135331]
f=92.81557766575408; w: [6.81232624 6.81232624]
f=89.14008079019023; w: [6.67607972 6.67607972]
f=85.61013359089868; w: [6.54255812 6.54255812]
f=82.2199723006