In this exercise, I'm going to code an Adam optimizer from scratch (i.e., without deep-learning libraries such as Keras, TensorFlow, or PyTorch; NumPy is the only dependency). The goal is to understand what's happening behind the scenes.

In [1]:
# the only library I can use, welp
import numpy as np

In [2]:
class AdamOptim():
    """Adam optimizer for one weight/bias parameter pair.

    Keeps exponentially decaying averages of the gradients (first moment,
    ``m_*``) and of the squared gradients (second moment, ``v_*``) for the
    weights and the biases, and uses their bias-corrected values to scale
    each update step.

    Parameters
    ----------
    eta : float
        Step size (learning rate).
    beta1 : float
        Decay rate of the first-moment (momentum) average.
    beta2 : float
        Decay rate of the second-moment (RMS) average.
    epsilon : float
        Small constant added to the denominator to avoid division by zero.
    """

    def __init__(self, eta=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8):
        # Moment estimates start at zero; the bias correction in `update`
        # compensates for this initialisation.
        self.m_dw, self.v_dw = 0, 0
        self.m_db, self.v_db = 0, 0
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.eta = eta

    def update(self, t, w, b, dw, db):
        """Apply one Adam step and return the updated ``(w, b)``.

        Parameters
        ----------
        t : int
            1-based step counter. Must be >= 1, otherwise the bias-correction
            denominator ``1 - beta**t`` is zero.
        w, b :
            Current weights and biases.
        dw, db :
            Gradients of the loss w.r.t. ``w`` and ``b`` for the current
            minibatch.

        Returns
        -------
        tuple
            ``(w, b)`` after the update. The optimizer's moment estimates are
            updated in place as a side effect.

        Raises
        ------
        ValueError
            If ``t < 1``.
        """
        if t < 1:
            raise ValueError("t must be >= 1 (1-based step counter), got " + str(t))

        ## momentum beta 1 (first moment: running mean of the gradients)
        self.m_dw = self.beta1*self.m_dw + (1-self.beta1)*dw
        self.m_db = self.beta1*self.m_db + (1-self.beta1)*db

        ## rms beta 2 (second moment: running mean of the SQUARED gradients)
        self.v_dw = self.beta2*self.v_dw + (1-self.beta2)*(dw**2)
        # The bias gradient must be squared as well; using the raw gradient
        # makes v_db negative for negative gradients and sqrt() yields NaN.
        self.v_db = self.beta2*self.v_db + (1-self.beta2)*(db**2)

        ## bias correction (undo the pull towards the zero initialisation)
        m_scale = 1-self.beta1**t
        v_scale = 1-self.beta2**t
        m_dw_corr = self.m_dw/m_scale
        m_db_corr = self.m_db/m_scale
        v_dw_corr = self.v_dw/v_scale
        v_db_corr = self.v_db/v_scale

        ## update weights and biases
        w = w - self.eta*(m_dw_corr/(np.sqrt(v_dw_corr)+self.epsilon))
        b = b - self.eta*(m_db_corr/(np.sqrt(v_db_corr)+self.epsilon))
        return w, b

Let's check and see if it works lol

In [3]:
## define the loss function and its gradient (derivative). The loss function itself is not called below; only its gradient is used.
def loss_function(m):
    """Evaluate the quadratic loss ``m**2 - 2*m + 1`` at ``m``."""
    squared_term = m**2
    linear_term = 2*m
    return squared_term - linear_term + 1
## take derivative
def grad_function(m):
    """Return ``2*m - 2``, the derivative of ``m**2 - 2*m + 1`` at ``m``."""
    doubled = 2*m
    return doubled - 2
def check_convergence(w0, w1, tol=0.0):
    """Report whether two successive parameter values have converged.

    Parameters
    ----------
    w0, w1 :
        The two values to compare (e.g. the weight after and before a step).
    tol : float, optional
        Largest absolute difference still treated as converged. The default
        of 0.0 keeps the strict behaviour and returns ``w0 == w1``.

    Returns
    -------
    Result of ``w0 == w1`` when ``tol`` is 0, otherwise
    ``abs(w0 - w1) <= tol``.

    Raises
    ------
    ValueError
        If ``tol`` is negative.
    """
    if tol < 0:
        raise ValueError("tol must be non-negative, got " + str(tol))
    if tol == 0:
        # Strict comparison: identical to the behaviour without a tolerance.
        return (w0 == w1)
    return abs(w0 - w1) <= tol
## initialize weights and biases, and our optimizer
# Minimise the quadratic loss with Adam, starting both parameters at 0.
w_0 = 0
b_0 = 0
adam = AdamOptim()
t = 1  # 1-based step counter passed to adam.update
converged = False
# Safety cap so the cell cannot spin forever if the convergence check
# (evaluated on the weight only) is never satisfied.
MAX_ITERS = 100_000

while not converged and t <= MAX_ITERS:
    dw = grad_function(w_0)
    db = grad_function(b_0)
    w_0_old = w_0
    w_0, b_0 = adam.update(t, w=w_0, b=b_0, dw=dw, db=db)
    # Keep the flag in sync with the check so the loop condition is meaningful.
    converged = check_convergence(w_0, w_0_old)
    if converged:
        print('converged after '+str(t)+' iterations')
    else:
        print('iteration '+str(t)+': weight='+str(w_0))
        t+=1

if not converged:
    print('stopped after '+str(MAX_ITERS)+' iterations without converging')

9
iteration 257: weight=0.9983202346112277
iteration 258: weight=0.9983944097882684
iteration 259: weight=0.9984656928634669
iteration 260: weight=0.998534181613794
iteration 261: weight=0.9985999710648938
iteration 262: weight=0.9986631535490955
iteration 263: weight=0.9987238187628614
iteration 264: weight=0.9987820538236523
iteration 265: weight=0.9988379433261909
iteration 266: weight=0.9988915693981099
iteration 267: weight=0.9989430117549652
iteration 268: weight=0.9989923477546039
iteration 269: weight=0.9990396524508698
iteration 270: weight=0.9990849986466379
iteration 271: weight=0.9991284569461627
iteration 272: weight=0.9991700958067327
iteration 273: weight=0.9992099815896199
iteration 274: weight=0.9992481786103167
iteration 275: weight=0.9992847491880505
iteration 276: weight=0.9993197536945726
iteration 277: weight=0.9993532506022106
iteration 278: weight=0.9993852965311832
iteration 279: weight=0.999415946296171
iteration 280: weight=0.9994452529521384
iteration 281: w

Yeeet I guess building an Adam optimizer from scratch ain't that hard ¯\_(ツ)_/¯