In [200]:
from copy import copy
import numpy as np
from numba import prange, guvectorize, float32, float64
LR = np.float(1e-2)
BETA1 = np.float(0.9)
BETA2 = np.float(0.999)
EPS = np.float(1e-8)
DECAY = np.float(0)

def adam_py(lr=LR, beta1=BETA1, beta2=BETA2, eps=EPS, weight_decay=DECAY):
    def aux_init(param):
        return [np.zeros_like(param)]*2
    def update_rule(x, y, aux):
        x = (1-weight_decay) * x
        aux[0] = beta1 * aux[0] + (1-beta1) * y
        aux[1] = beta2 * aux[1] + (1-beta2) * (y**2)
        x += -lr * aux[0] / (np.sqrt(aux[1]) + eps)
        return x, aux
    return aux_init, update_rule

def adam(x, y, aux, lr, beta1, beta2, eps, decay, out, init):
    init[:][:] = aux[:][:]
    for i in prange(aux.shape[1]):
        init[0][i] = beta1 * aux[0][i] + (1-beta1) * y[i]
    for j in prange(aux.shape[1]):
        init[1][j] = beta2 * aux[1][j] + (1-beta2) * y[j] ** 2
    out[:] = (1-decay) * x[:]
    out[:] = out[:] + (-lr * init[0] / (np.sqrt(init[1]) + eps))

def gu_adam(func, *args, **kwargs):
    kwargs_ = {k: v for k, v in kwargs.items() if v is not None}
    return guvectorize([(float32[:], float32[:], float32[:,:], float32, float32, float32, float32, float32, float32[:], float32[:,:]),
                      (float64[:], float64[:], float64[:,:], float64, float64, float64, float64, float64, float64[:], float64[:,:])],
                     '(n),(n),(m,n),(),(),(),(),()->(n),(m,n)', nopython=True, fastmath=True, *args, **kwargs_)(func)

In [201]:
# Defining input values
X = np.arange(0, 10, dtype=np.float32)
Y = np.ones_like(X, dtype=np.float32)
AUX = np.array([np.zeros_like(X, dtype=np.float32)]* 2)
X_ = copy(X)
Y_ = copy(Y)
AUX_ = copy(AUX)

In [202]:
print("Calculating Python method")
aux_init, update = adam_py()
pyaux, pyup = update(X, Y, AUX)
pyaux, pyup

Calculating Python method


(array([-0.03162276,  0.96837723,  1.9683772 ,  2.9683774 ,  3.9683774 ,
         4.968377  ,  5.968377  ,  6.968377  ,  7.968377  ,  8.968377  ],
       dtype=float32),
 array([[0.1  , 0.1  , 0.1  , 0.1  , 0.1  , 0.1  , 0.1  , 0.1  , 0.1  ,
         0.1  ],
        [0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
         0.001]], dtype=float32))

In [203]:
print("Calculating Python method with numba")
s_adam = gu_adam(adam)
p_adam = gu_adam(adam, target='parallel')
aux, up = s_adam(X_, Y_, AUX_, LR, BETA1, BETA2, EPS, DECAY)
aux, up

Calculating Python method with numba


(array([-0.03162297,  0.96837705,  1.968377  ,  2.968377  ,  3.968377  ,
         4.968377  ,  5.968377  ,  6.968377  ,  7.968377  ,  8.968377  ],
       dtype=float32),
 array([[0.10000002, 0.10000002, 0.10000002, 0.10000002, 0.10000002,
         0.10000002, 0.10000002, 0.10000002, 0.10000002, 0.10000002],
        [0.00099999, 0.00099999, 0.00099999, 0.00099999, 0.00099999,
         0.00099999, 0.00099999, 0.00099999, 0.00099999, 0.00099999]],
       dtype=float32))

In [204]:
# Test equality of auxilliary vectors
np.testing.assert_allclose(aux, pyaux)

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 3 / 10 (30%)
Max absolute difference: 2.3841858e-07
Max relative difference: 6.597028e-06
 x: array([-0.031623,  0.968377,  1.968377,  2.968377,  3.968377,  4.968377,
        5.968377,  6.968377,  7.968377,  8.968377], dtype=float32)
 y: array([-0.031623,  0.968377,  1.968377,  2.968377,  3.968377,  4.968377,
        5.968377,  6.968377,  7.968377,  8.968377], dtype=float32)

In [None]:
# Test equality of update vector
np.testing.assert_allclose(up, pyup)