In [66]:
import numpy as np
from numpy import asarray
from numpy.linalg import norm
from numpy import sqrt

class Optimizer:
    """
    A class to implement adam optimimization.

    ...

    Attributes
    ----------
    a : float
        stepsize (alpha)
    b1 : float
        exponential decay rate [0,1)
    b2 : int
        exponential decay rate [0,1)
    eps : float
        regularization parameter

    Methods
    -------
    objective(x,y):
        returns height of paraboloid at (x,y)

    gradient(x,y):
        returns slope of paraboloid at (x,y)

    optimize(th_0,d):
        returns optimal point to precision d 
    """
    def __init__(self, a:float=1e-3, b1:float=0.9, b2:float=0.999, eps:float=1e-9):
        self.a = a
        self.b1 = b1
        self.b2 = b2 
        self.eps = eps

    def objective(self, x:float, y:float) -> float:
        '''given points x and y, objective gives height on a paraboloid with center at (2,4)'''
        return (x-2)**2 + (y-4)**2

    def gradient(self, x:float, y:float) -> np.ndarray[np.float64]:
        '''given points x and y, returns value of the gradient at (x,y)'''
        return asarray([2*(x-2),2*(y-4)])

    def _adam(self, mt, vt, gradient, t):
            ## Compute biased first moment
            mtt = self.b1*mt + (1-self.b1)*gradient
            
            ## Compute biased second moment
            vtt = self.b2*vt + (1-self.b2)*gradient**2
            
            ## Compute bias corrected terms
            mtt_hat = mtt/(1-self.b1**t)
            vtt_hat = vtt/(1-self.b2**t)
            
            ## Calculate update
            update = self.a*(mtt_hat)/(sqrt(vtt_hat)+self.eps)
            return update, mtt, vtt

    def _sgd(self,mt,vt,gradient,t):
        return self.a*gradient,mt,vt

    def _rmsprop(self,mt,vt,gradient,t):
        vtt = self.b2*vt + (1-self.b2)*gradient**2
        return (self.a/norm(vtt))*gradient,mt,vtt


    def optimize(self, th_t:np.array, d: float, optimizer:str="adam") -> np.ndarray[np.float64]:
        '''Give starting point (2d-array) and convergence precision (d) returns optimal point'''
        mt,vt = asarray([0 for _ in range(th_t.shape[0])]),asarray([0 for _ in range(th_t.shape[0])])
        th_tt = th_t
        gradient = self.gradient(th_t[0],th_t[1])
        t = 1
        # while th_tt NOT coverged DO 
        while norm(gradient) >= d:
            ## Compute gradient
            gradient = self.gradient(th_t[0],th_t[1])

            if optimizer == "sgd":
                update, mtt, vtt = self._sgd(mt,vt,gradient,t)
            elif optimizer =="rms":
                update, mtt, vtt = self._rmsprop(mt,vt,gradient,t)
            else:
                update, mtt, vtt = self._adam(mt,vt,gradient,t)
            
            
            ## Apply update
            th_tt = th_t - update
            
            ## Update params recursively
            th_t = th_tt
            mt = mtt
            vt = vtt
            t += 1
            print(f'time: {t}, grad:{gradient}, theta:{th_tt}')
            if t > 100000:
                raise Exception(f"{optimizer} failed to converge")
        return th_tt

In [71]:
th_0 = asarray([45,-90])
f = Optimizer(a=0.1).optimize(th_0,0.1)

time: 2, grad:[  86 -188], theta:[ 44.9 -89.9]
time: 3, grad:[  85.8 -187.8], theta:[ 44.80000614 -89.80000279]
time: 4, grad:[  85.60001227 -187.60000558], theta:[ 44.7000225  -89.70001022]
time: 5, grad:[  85.40004501 -187.40002044], theta:[ 44.60005319 -89.60002415]
time: 6, grad:[  85.20010638 -187.2000483 ], theta:[ 44.50010226 -89.5000464 ]
time: 7, grad:[  85.00020452 -187.00009281], theta:[ 44.40017376 -89.40007881]
time: 8, grad:[  84.80034751 -186.80015763], theta:[ 44.30027168 -89.30012317]
time: 9, grad:[  84.60054337 -186.60024635], theta:[ 44.20040001 -89.20018127]
time: 10, grad:[  84.40080001 -186.40036254], theta:[ 44.10056263 -89.10025485]
time: 11, grad:[  84.20112526 -186.2005097 ], theta:[ 44.0007634  -89.00034564]
time: 12, grad:[  84.00152681 -186.00069127], theta:[ 43.90100611 -88.90045532]
time: 13, grad:[  83.80201222 -185.80091064], theta:[ 43.80129447 -88.80058556]
time: 14, grad:[  83.60258893 -185.60117111], theta:[ 43.7016321  -88.70073796]
time: 15, grad

In [18]:
f

array([1.99989959, 3.95323947])