# In [1]:
import numpy as np
class Optimizer():
    """Gradient-based update rule for a pair of parameters ``c`` and ``w``.

    Four rules are available, chosen by the ``opt`` constructor argument:
    ``"adam"``, ``"momentum"``, ``"rmsprop"``; any other string selects a
    plain constant-step update.  The running first/second moment estimates
    are stored on the instance, so one instance should track one (c, w) pair.

    NOTE(review): ``c``/``w`` look like a bias/intercept and a weight, with
    ``dc``/``dw`` their gradients -- confirm against the caller.  Every rule
    uses only elementwise arithmetic and ``np.sqrt``.
    """

    def __init__(self, eta=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8, opt="adam"):
        """Store hyper-parameters and pick the update rule.

        Args:
            eta: step size (learning rate).
            beta1: decay factor of the first-moment (momentum) average.
            beta2: decay factor of the second-moment (RMS) average.
            epsilon: small constant added to the denominator in the
                RMSprop/Adam updates to avoid division by zero.
            opt: rule name, case-insensitive.  Unrecognised names fall back
                to the constant-step rule.
        """
        # Running first (m) and second (v) moment estimates per parameter.
        self.m_dc, self.v_dc = 0.0, 0.0
        self.m_dw, self.v_dw = 0.0, 0.0
        self.beta1 = beta1  # momentum
        self.beta2 = beta2  # rms
        self.epsilon = epsilon
        self.eta = eta

        rules = {
            "adam": self._adam,
            "momentum": self._momentum,
            "rmsprop": self._rmsprop,
        }
        self.opt = rules.get(opt.lower(), self._constant)

    def update(self, t, c, w, dc, dw):
        """Apply one step of the selected rule and return the new ``(c, w)``.

        Args:
            t: step counter.  Only the Adam rule reads it (for bias
                correction); it must be >= 1 there, because ``t == 0``
                makes the correction denominators zero.
            c, w: current parameter values.
            dc, dw: gradients with respect to ``c`` and ``w``.
        """
        return self.opt(t, c, w, dc, dw)

    def _constant(self, t, c, w, dc, dw):
        """Constant-step gradient descent: ``param -= eta * grad``."""
        c = c - self.eta * dc
        w = w - self.eta * dw
        return c, w

    def _momentum(self, t, c, w, dc, dw):
        """Momentum update: ``velocity = beta1*velocity - eta*grad``.

        The velocity is kept in ``m_dc`` / ``m_dw`` and added to the
        parameter.
        """
        # Use the instance's learning rate (a bare ``eta`` is not defined in
        # this scope and would raise NameError).
        self.m_dc = self.beta1 * self.m_dc - self.eta * dc
        self.m_dw = self.beta1 * self.m_dw - self.eta * dw
        c = c + self.m_dc
        w = w + self.m_dw
        return c, w

    def _rmsprop(self, t, c, w, dc, dw):
        """RMSprop update: scale the gradient by the root of its running
        squared average (kept in ``v_dc`` / ``v_dw``)."""
        self.v_dc = self.beta2 * self.v_dc + (1 - self.beta2) * (dc ** 2)
        self.v_dw = self.beta2 * self.v_dw + (1 - self.beta2) * (dw ** 2)
        c = c - self.eta * (dc / (np.sqrt(self.v_dc) + self.epsilon))
        w = w - self.eta * (dw / (np.sqrt(self.v_dw) + self.epsilon))
        return c, w

    def _adam(self, t, c, w, dc, dw):
        """Adam update: bias-corrected momentum divided by bias-corrected RMS.

        ``t`` must be >= 1; with ``t == 0`` both ``1 - beta**t`` terms are
        zero and the bias correction divides by zero.
        """
        ## momentum beta 1
        self.m_dc = self.beta1 * self.m_dc + (1 - self.beta1) * dc
        self.m_dw = self.beta1 * self.m_dw + (1 - self.beta1) * dw

        ## rms beta 2
        self.v_dc = self.beta2 * self.v_dc + (1 - self.beta2) * (dc ** 2)
        self.v_dw = self.beta2 * self.v_dw + (1 - self.beta2) * (dw ** 2)

        ## bias correction (the averages start at 0 and are biased low early on)
        m_scale = 1 - self.beta1 ** t
        v_scale = 1 - self.beta2 ** t
        m_dc_corr = self.m_dc / m_scale
        m_dw_corr = self.m_dw / m_scale
        v_dc_corr = self.v_dc / v_scale
        v_dw_corr = self.v_dw / v_scale

        ## update weights
        c = c - self.eta * (m_dc_corr / (np.sqrt(v_dc_corr) + self.epsilon))
        w = w - self.eta * (m_dw_corr / (np.sqrt(v_dw_corr) + self.epsilon))
        return c, w