In [1]:
%matplotlib notebook

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

In [3]:
def f(x:float,y:float):
    return x**2 + y**2 / 5.0

In [4]:
def get_numerical_gradients(func:callable, x:float,y:float, e:float=0.000000001):
    return (func(x+e, y) - func(x-e, y)) / (2*e),  (func(x, y+e) - func(x, y-e)) / (2*e)

In [5]:
iterations=100
init_x, init_y = 10,10
LR=0.01

In [6]:
def sgd_step(func:callable, x:float,y:float, lr:float=LR):
    x_grad, y_grad = get_numerical_gradients(func,x,y)
    return x - lr* x_grad, y - lr * y_grad

def momentum_step(func:callable, x:float,y:float, v_x:float, v_y:float, lr:float=LR, gamma:float=0.9, nesterov:bool=False):
    if nesterov:
        x_grad, y_grad = get_numerical_gradients(func,x-gamma*v_x,y-gamma*v_y)
    else:
        x_grad, y_grad = get_numerical_gradients(func,x,y)
    v_x = gamma * v_x  + lr * x_grad
    v_y = gamma * v_y + lr * y_grad
    print("vx:{} vy:{}".format(v_x,v_y))
    x = x - v_x
    y = y - v_y
    return x, y, v_x, v_y

In [7]:
class Simulator(object):
    def __init__(self, iterations:int, init_x:float, init_y:float, func:callable, lr:float):
        self.init_x  = init_x
        self.init_y = init_y
        self.iterations = iterations
        self.func = func
        self.f_values = [self.func(self.init_x, self.init_y)]
        self.x_values = [init_x]
        self.y_values = [init_y]
        self.lr = lr
        self.x = init_x
        self.y = init_y
    def step(self):
        raise NotImplementedError
    
    def simulate(self):
        for it in range(self.iterations):
            self.step()


In [8]:
class SGDSimulator(Simulator):
    def __init__(self, iterations:int, init_x:float, init_y:float, func:callable, lr:float):
        super(SGDSimulator,self).__init__(iterations, init_x, init_y, func, lr)
    
    def step(self):
        self.x,self.y = sgd_step(self.func,self.x,self.y)
        self.x_values.append(self.x)
        self.y_values.append(self.y )
        loss = self.func(self.x,self.y)
        self.f_values.append(loss)
        

In [9]:
class MomentumSimulator(Simulator):
    def __init__(self, iterations:int, init_x:float, init_y:float, func:callable, lr, gamma:float=0.9, nesterov:bool=False):
        super(MomentumSimulator,self).__init__(iterations=iterations,init_x=init_x, init_y=init_y, func=func, lr=lr)
        self.gamma = gamma
        self.v_x = 0.0
        self.v_y = 0.0
        self.v_x_history = []
        self.v_y_history = []
        self.nesterov = nesterov
    
    
    def step(self):
        self.x,self.y, self.v_x, self.v_y = momentum_step(func=self.func,x=self.x,y=self.y, v_x=self.v_x, 
                                                          v_y=self.v_y,gamma=self.gamma, nesterov=self.nesterov)
        self.x_values.append(self.x)
        self.y_values.append(self.y )
        loss = self.func(self.x,self.y)
        self.f_values.append(loss)
        self.v_x_history.append(self.v_x)
        self.v_y_history.append(self.v_y)

In [10]:
class NesterovSimulator(Simulator):
    def __init__(self, iterations, init_x, init_y, func, lr, gamma=0.9):
        super(NesterovSimulator,self).__init__(iterations, init_x, init_y, func, lr)
        self.gamma = gamma
        self.v_x = 0.0
        self.v_y = 0.0
    
    def step(self):
        self.x,self.y, self.v_x, self.v_y = nesterov_step(self.func,self.x,self.y, self.v_x, self.v_y)
        self.x_values.append(self.x)
        self.y_values.append(self.y )
        loss = self.func(self.x,self.y)
        self.f_values.append(loss)

In [11]:
momentum = MomentumSimulator(iterations=iterations,init_x=init_x,init_y=init_y,func=f,lr=LR)
momentum.simulate()

vx:0.2000000165480742 vy:0.04000000330961484
vx:0.3760000311103795 vy:0.07583999206417502
vx:0.5268800862268108 vy:0.10779258019510962
vx:0.652134573897456 vy:0.1361188232351651
vx:0.7518208903434243 vy:0.16106795307280208
vx:0.8265020697706404 vy:0.1828778497738881
vx:0.8771851361046856 vy:0.2017752530584005
vx:0.905256169333768 vy:0.21797588157518533
vx:0.9124149729528597 vy:0.23168450472210367
vx:0.9006095955617042 vy:0.24309552242764668
vx:0.8719725686124192 vy:0.25239307839231606
vx:0.8287597900329773 vy:0.25975130411441383
vx:0.7732930947612516 vy:0.26533468950358396
vx:0.7079072004308918 vy:0.2692983897876481
vx:0.6349017542075752 vy:0.2717885283741093
vx:0.5564988233118793 vy:0.27294250495264294
vx:0.4748062085656699 vy:0.2728893083144753
vx:0.3917867180841737 vy:0.27174986894409714
vx:0.30923345346941183 vy:0.26963737635180685
vx:0.2287508447626599 vy:0.2666575913183511
vx:0.1517414810049236 vy:0.26290914467867044
vx:0.07939821901711919 vy:0.2584839128086803
vx:0.0127013179132

In [12]:
sgd = SGDSimulator(iterations=iterations,init_x=init_x,init_y=init_y,func=f,lr=LR)
sgd.simulate()

In [13]:
nesterov = MomentumSimulator(iterations=iterations,init_x=init_x,init_y=init_y,func=f,lr=LR, nesterov=True)
nesterov.simulate()

vx:0.2000000165480742 vy:0.04000000330961484
vx:0.37239999528537737 vy:0.07569603610591003
vx:0.5170087717942806 vy:0.10739116618196931
vx:0.6342135662862347 vy:0.13537307594901904
vx:0.7249039552448266 vy:0.15991460621478382
vx:0.7903947772831543 vy:0.18127396136691232
vx:0.8323497140423797 vy:0.19969539141797554
vx:0.8527070418213611 vy:0.21540959268314452
vx:0.8536080474069934 vy:0.22863413944437383
vx:0.8373305772747159 vy:0.23957408561098745
vx:0.8062272423493111 vy:0.24842236410565896
vx:0.7626695535996642 vy:0.25536027065677236
vx:0.7089982835625104 vy:0.2605579702658205
vx:0.6474802572081992 vy:0.2641749619098479
vx:0.5802717463384208 vy:0.2663605234190015
vx:0.5093884022422401 vy:0.26725422732952153
vx:0.43668152575560104 vy:0.26698631633359077
vx:0.3638204406642429 vy:0.2656782188198831
vx:0.29228054663315656 vy:0.26344292404221525
vx:0.22333675016345111 vy:0.2603854441163118
vx:0.15806159366534409 vy:0.2566031765261891
vx:0.0973276661263672 vy:0.2521863280795262
vx:0.0418137

In [14]:
sgd.x_values[-10:]

[1.5906433400891729,
 1.5588304677527276,
 1.527653859790803,
 1.4971007708430761,
 1.4671587575298872,
 1.4378155807526127,
 1.409059258984371,
 1.3808780682700217,
 1.35326050669903,
 1.3261952944054656]

In [15]:
nesterov.x_values[-10:]

[0.02072409982451237,
 0.02285673935755267,
 0.02428059263896489,
 0.025050819381195254,
 0.025229142980213125,
 0.024881841535559605,
 0.024077884830782954,
 0.022887237320650004,
 0.021379341470838597,
 0.01962179050256997]

In [16]:
plt.plot(sgd.f_values, label='SGD')
plt.plot(momentum.f_values, label='Momentum')
plt.plot(nesterov.f_values, label='Nesterov')

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x1d0af6ae278>]

In [17]:
X,Y = np.meshgrid(np.arange(-15,15,0.1), np.arange(-15,15,0.1))
# X,Y = np.meshgrid(np.arange(-x_max,x_max,0.1), np.arange(-y_max,y_max,0.1))
Z = f(X,Y)

x_sgd_anim = []
y_sgd_anim = []
x_momentum_anim = []
y_momentum_anim = []

x_nesterov_anim = []
y_nesterov_anim = []
fig, ax = plt.subplots()
fig.set_size_inches(10, 10)
ax.set_xlim(-15, 15)
ax.set_ylim(-15, 15)
line_sgd, = ax.plot(0, 0, marker='<', label='SGD', color='r')
line_mom, = ax.plot(0, 0, marker='<', label='Momentum', color='m')
line_nesterov, = ax.plot(0, 0, marker='<', label='Nesterov', color='g')
plt.legend()
ax.contourf(X,Y, Z)

def animation_frame(step):
    x_sgd_anim.append(step[0])
    y_sgd_anim.append(step[1])
    x_momentum_anim.append(step[2])
    y_momentum_anim.append(step[3])
    x_nesterov_anim.append(step[4])
    y_nesterov_anim.append(step[5])
    line_sgd.set_xdata(x_sgd_anim)
    line_sgd.set_ydata(x_sgd_anim)
    line_mom.set_xdata(x_momentum_anim)
    line_mom.set_ydata(y_momentum_anim)
    line_nesterov.set_xdata(x_nesterov_anim)
    line_nesterov.set_ydata(y_nesterov_anim)    
    return line_sgd,line_mom, line_nesterov,

# frames = np.array([[x,y,z,w] for x,y,z,w in zip(x_sgd, y_sgd, x_mom,y_mom)] )
frames = np.array([[x1,y1,x2,y2, x3,y3] for x1,y1,x2,y2,x3,y3 in 
                   zip(sgd.x_values, sgd.y_values, momentum.x_values,momentum.y_values, nesterov.x_values,nesterov.y_values)] )
animation = FuncAnimation(fig, func=animation_frame, frames=frames, interval=100)

plt.imshow(Z)
plt.colorbar()
plt.grid()
plt.show()



<IPython.core.display.Javascript object>