In [13]:
import time

import numpy as np
from numba import njit, guvectorize, float32, float64, prange

from optimizers import s_sgd

class Dense:

    def __init__(self, input_dim, output_dim, batch_size=256, optimizer=s_sgd):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.batch_size = batch_size
        self.optimizer = optimizer

        # layer params
        self.var = np.sqrt(2.0/(self.input_dim + self.output_dim))
        self.W = np.random.normal(0, self.var, (self.input_dim, self.output_dim))
        self.b = np.abs(np.random.normal(0, self.var, (self.output_dim)))

        # init params
        self.x = np.zeros((batch_size, self.input_dim))
        self.auxW = np.zeros_like(self.W)
        self.auxb = np.zeros_like(self.b)

        # parameter derivatives
        self.de_dW = np.zeros_like(self.W)
        self.de_db = np.zeros_like(self.b)

    def forward(self, x):
        self.x = x
        return self.predict(x)

    def backward(self, de_dy):
        self.de_dW, self.de_db, de_dY = numba_backward(self.x.T, W, de_dy, self.batch_size)
        return de_dY

    def predict(self, x):
        return numba_predict(x, self.W, self.b)

    def update(self):
        self.W, self.aW = self.optimizer(self.W, self.de_dW, self.aW)
        self.b, self.ab = self.optimizer(self.b, self.de_db, self.ab)

In [14]:
pydense = Dense(2, 2)

In [15]:
pydense

<__main__.Dense at 0x21060b10820>

In [16]:
def backward(x, W, de_dy, batch_size=24):
    de_dW = (x.T @ de_dy) / batch_size
    de_db = de_dy.mean(axis=0)
    return de_dW, de_db, de_dy @ W.T

x = np.random.normal(0, 1, (2, 2))
W = np.random.normal(0, 1, (2, 2))
b = np.random.normal(0, 1, (2, 2))
de_dy = np.random.normal(0, 1, (2, 2))

# backward(x, W, de_dy)

In [17]:
@njit(fastmath=True)
def matmul(a, b):
    return a @ b

def numba_backward(x, W, de_dy, batch_size=24):
    de_dW = matmul(x.T,de_dy) / batch_size
    de_db = np.mean(de_dy, axis=0)
    return de_dW, de_db, matmul(de_dy, W.T)

@njit(fastmath=True)
def numba_predict(x, W, b):
    return matmul(x, W) + b

def predict(x, W, b):
    return x @ W + b

In [18]:
def ntime():
    start = time.time()
    numba_predict(x, W, b)
    end = time.time()
    print("Elapsed Numba Pre = %s" % (end - start))
    start = time.time()
    f2 = numba_predict(x, W, b)
    end = time.time()
    print("Elapsed Numba Post = %s" % (end - start))
    return f2

def ptime():
    start = time.time()
    f1 = predict(x, W, b)
    end = time.time()
    print("Elapsed PyTime = %s" % (end - start))
    return f1

In [19]:
prednumba = ntime()

Elapsed Numba Pre = 0.23038411140441895
Elapsed Numba Post = 0.0


In [20]:
pred = ptime()

Elapsed PyTime = 0.0


In [21]:
np.testing.assert_allclose(pred, prednumba)

In [22]:
def sgdpy():
    lr=1e-4
    mu=0.9
    weight_decay=0
    def aux_init(param):
        return np.zeros_like(param)
    def update_rule(x, dx, aux):
        x = (1-weight_decay) * x
        aux = mu * aux - lr * dx
        x = np.add(x, aux)
        return x, aux
    return aux_init, update_rule

def sgd(x, y, aux, lr, mu, decay, out, init):
    init[:] = aux[:]
    for i in prange(x.shape[0]):
        out[i] = (1-decay) * x[i]
        init[i] = mu * init[i] - lr * y[i]
        out[i] = np.add(out[i], init[i])

def gu_optimizer(func, *args, **kwargs):
    """
    The guvectorize decorated method, runs in nopython mode
    :param func: a loss error method
    :param target: choose between | 1: None -> serial | 2: 'parallel' -> parallel | execution method
    :return: the decorated loss, error function
    """
    kwargs_ = {k: v for k, v in kwargs.items() if v is not None}
    return guvectorize([(float32[:], float32[:], float32[:], float32, float32, float32, float32[:], float32[:]),
                      (float64[:], float64[:], float64[:], float64, float64, float64, float64[:], float64[:])],
                     '(n),(n),(m),(),(),()->(n),(n)', nopython=True, fastmath=True, *args, **kwargs_)(func)

s_sgd = gu_optimizer(sgd)
p_sgd = gu_optimizer(sgd, target='parallel')

## UPDATE
def update(W, de_dW, aW, b, de_db, ab):
    W, auxW = s_sgd(W, de_dW, aW, 1e-4, 0.9, 0)
    b, auxb = s_sgd(b, de_db, ab, 1e-4, 0.9, 0)
    return W,auxW, b, auxb

In [23]:
bigx = np.arange(0, 10000000, dtype=np.float32)
bigy = np.arange(10000000, 20000000, dtype=np.float32)
init = np.zeros_like(bigx)

update(bigx, bigy, init,bigx, bigy, init)

(array([-1.0000000e+03, -9.9900006e+02, -9.9800018e+02, ...,
         9.9979970e+06,  9.9979980e+06,  9.9979990e+06], dtype=float32),
 array([-1000.     , -1000.00006, -1000.0002 , ..., -1999.9995 ,
        -1999.9998 , -2000.     ], dtype=float32),
 array([-1.0000000e+03, -9.9900006e+02, -9.9800018e+02, ...,
         9.9979970e+06,  9.9979980e+06,  9.9979990e+06], dtype=float32),
 array([-1000.     , -1000.00006, -1000.0002 , ..., -1999.9995 ,
        -1999.9998 , -2000.     ], dtype=float32))

In [24]:
_, update_rule = sgdpy()
def update_py(W, de_dW, aW, b, de_db, ab):
    W, auxW = update_rule(W, de_dW, aW)
    b, auxb = update_rule(b, de_db, ab)
    return W,auxW, b, auxb

update_py(bigx, bigy, init, bigx, bigy, init)


(array([-1.0000000e+03, -9.9900006e+02, -9.9800018e+02, ...,
         9.9979970e+06,  9.9979980e+06,  9.9979990e+06], dtype=float32),
 array([-1000.     , -1000.00006, -1000.0002 , ..., -1999.9995 ,
        -1999.9998 , -2000.     ], dtype=float32),
 array([-1.0000000e+03, -9.9900006e+02, -9.9800018e+02, ...,
         9.9979970e+06,  9.9979980e+06,  9.9979990e+06], dtype=float32),
 array([-1000.     , -1000.00006, -1000.0002 , ..., -1999.9995 ,
        -1999.9998 , -2000.     ], dtype=float32))