In [50]:
import time

import numpy as np
from numba import guvectorize, float32, float64, prange

def sgdpy(lr=1e-4, mu=0.9, weight_decay=0):
    """Build a pure-NumPy SGD-with-momentum optimizer (reference implementation).

    Returns the pair ``(aux_init, update_rule)``:

    * ``aux_init(param)`` -> a zero momentum buffer shaped like ``param``.
    * ``update_rule(x, dx, aux)`` -> ``(new_x, new_aux)`` where
      ``new_aux = mu * aux - lr * dx`` and
      ``new_x = (1 - weight_decay) * x + new_aux``.

    :param lr: learning rate (default 1e-4, the previously hard-coded value).
    :param mu: momentum coefficient (default 0.9).
    :param weight_decay: multiplicative shrink applied to ``x`` before the
        momentum step (default 0, i.e. no decay).
    :return: tuple ``(aux_init, update_rule)``.
    """
    def aux_init(param):
        # Momentum starts at zero with the same shape/dtype as the parameters.
        return np.zeros_like(param)

    def update_rule(x, dx, aux):
        new_aux = mu * aux - lr * dx
        new_x = (1 - weight_decay) * x + new_aux
        return new_x, new_aux

    return aux_init, update_rule

def sgd(x, y, aux, lr, mu, decay, out, init):
    """Kernel for one SGD-with-momentum step, compiled by ``gu_optimizer``.

    Copies ``aux`` into the momentum output ``init``, then for every index k:

        init[k] = mu * init[k] - lr * y[k]
        out[k]  = (1 - decay) * x[k] + init[k]

    ``y`` plays the role of ``dx`` in ``sgdpy``'s ``update_rule``. ``aux`` is
    only read. ``out`` and ``init`` are written.
    """
    retain = 1 - decay  # loop-invariant decay factor
    init[:] = aux
    for k in prange(x.shape[0]):
        init[k] = mu * init[k] - lr * y[k]
        # Store the decayed parameter first so it is rounded to out's dtype
        # before the momentum is added (same two-step update as before).
        out[k] = retain * x[k]
        out[k] += init[k]



def gu_optimizer(func, *args, **kwargs):
    """Compile an optimizer update kernel into a generalized ufunc with Numba.

    ``func`` must follow the ``sgd`` kernel convention
    ``func(x, dx, aux, lr, mu, decay, out, new_aux)``:

    * three 1-D input arrays of the same length ``n``;
    * three scalars;
    * two length-``n`` output arrays.

    Signatures are given for float32 and float64, compiled in nopython mode
    with fastmath.

    :param func: the kernel to compile.
    :param args: extra positional arguments forwarded to ``numba.guvectorize``.
    :param kwargs: extra keyword arguments forwarded to ``numba.guvectorize``.
        Entries whose value is ``None`` are dropped. So ``target=None`` keeps
        the default (serial) target, and ``target='parallel'`` selects the
        parallel one.
    :return: the compiled gufunc. Calling it with the six inputs returns the
        tuple ``(out, new_aux)``.
    """
    options = {k: v for k, v in kwargs.items() if v is not None}
    signatures = [
        (float32[:], float32[:], float32[:], float32, float32, float32, float32[:], float32[:]),
        (float64[:], float64[:], float64[:], float64, float64, float64, float64[:], float64[:]),
    ]
    # The kernel copies aux element-wise into the length-n momentum output.
    # aux must therefore share x's core dimension 'n'; the previous '(m)'
    # declared an independent length the kernel cannot handle. Using '(n)'
    # lets the gufunc shape check reject mismatched inputs up front.
    layout = '(n),(n),(n),(),(),()->(n),(n)'
    return guvectorize(signatures, layout, *args,
                       nopython=True, fastmath=True, **options)(func)

# Compile the same kernel for two targets: the default (serial) one, since
# gu_optimizer drops target=None, and Numba's 'parallel' target.
s_sgd, p_sgd = (gu_optimizer(sgd, target=tgt) for tgt in (None, 'parallel'))

In [51]:
# Benchmark inputs: float32 parameter and gradient vectors of N_PARAMS elements
# and a zero momentum buffer. float32 represents integers exactly only up to
# 2**24 (~16.8M), so the upper part of bigy is rounded. That is fine for a
# timing benchmark.
N_PARAMS = 10_000_000
bigx = np.arange(0, N_PARAMS, dtype=np.float32)
bigy = np.arange(N_PARAMS, 2 * N_PARAMS, dtype=np.float32)
init = np.zeros_like(bigx)

def ntime(func1, args=None):
    """Time two consecutive calls of an optimizer gufunc and print both.

    The first call ("Pre") acts as a warm-up; the second ("Post") is the
    steady-state timing. Timing uses ``time.perf_counter``, the monotonic
    high-resolution clock meant for measuring intervals. ``time.time`` is
    wall-clock time and can jump.

    :param func1: the callable to time, e.g. ``s_sgd`` or ``p_sgd``.
    :param args: positional arguments for ``func1``. Defaults to the benchmark
        inputs ``(bigx, bigy, init, 1e-4, 0.9, 0.0)`` from module globals.
    :return: the result of the second call.
    """
    if args is None:
        args = (bigx, bigy, init, 1e-4, 0.9, 0.0)
    result = None
    for label in ("Pre", "Post"):
        start = time.perf_counter()
        result = func1(*args)
        elapsed = time.perf_counter() - start
        print("Elapsed Numba %s = %s" % (label, elapsed))
    return result

def ptime(func1, x=None, dx=None):
    """Time one full step of a pure-Python optimizer factory and print it.

    The measured span covers three steps: building the optimizer, allocating
    its auxiliary state, and applying one update. Timing uses
    ``time.perf_counter``.

    :param func1: factory returning ``(aux_init, update_rule)``, e.g. ``sgdpy``.
    :param x: parameters. Defaults to the module-level ``bigx``.
    :param dx: gradients. Defaults to the module-level ``bigy``.
    :return: whatever ``update_rule(x, dx, aux)`` returns.
    """
    if x is None:
        x = bigx
    if dx is None:
        dx = bigy
    start = time.perf_counter()
    # Named aux_init (not init) so it does not shadow the module-level buffer.
    aux_init, update_rule = func1()
    aux = aux_init(x)
    result = update_rule(x, dx, aux)
    elapsed = time.perf_counter() - start
    print("Elapsed PyTime = %s" % elapsed)
    return result

In [52]:
# Benchmark the serial and parallel gufuncs, then the NumPy reference.
s, p = (ntime(kernel) for kernel in (s_sgd, p_sgd))
py = ptime(sgdpy)

for result in (s, py):
    print(result)

Elapsed Numba Pre = 0.030916690826416016
Elapsed Numba Post = 0.026955366134643555
Elapsed Numba Pre = 0.0298917293548584
Elapsed Numba Post = 0.02892303466796875
Elapsed PyTime = 0.05385589599609375
(array([-1.0000000e+03, -9.9900006e+02, -9.9800018e+02, ...,
        9.9979970e+06,  9.9979980e+06,  9.9979990e+06], dtype=float32), array([-1000.     , -1000.00006, -1000.0002 , ..., -1999.9995 ,
       -1999.9998 , -2000.     ], dtype=float32))
(array([-1.0000000e+03, -9.9900006e+02, -9.9800018e+02, ...,
        9.9979970e+06,  9.9979980e+06,  9.9979990e+06], dtype=float32), array([-1000.     , -1000.00006, -1000.0002 , ..., -1999.9995 ,
       -1999.9998 , -2000.     ], dtype=float32))


In [53]:
# Serial and parallel gufuncs must agree. Both must also match the pure-NumPy
# reference (previously only printed). The tolerance leaves room for
# fastmath reassociation in float32.
np.testing.assert_allclose(s, p)
np.testing.assert_allclose(s, py, rtol=1e-6, atol=1e-3)

In [54]:
# NOTE(review): s and p are not recomputed in this cell, so this re-checks the
# same arrays. Consider deleting this cell or recomputing p first.
np.testing.assert_allclose(actual=s, desired=p)

In [55]:
# NOTE(review): duplicate check of the unchanged s and p (rtol spelled out at
# its default value).
np.testing.assert_allclose(s, p, rtol=1e-07)