In [179]:
import time

import numpy as np
from numba import njit

from optimizers import s_sgd

class Dense:

    def __init__(self, input_dim, output_dim, batch_size=256, optimizer=s_sgd):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.batch_size = batch_size
        self.optimizer = optimizer

        # layer params
        self.var = np.sqrt(2.0/(self.input_dim + self.output_dim))
        self.W = np.random.normal(0, self.var, (self.input_dim, self.output_dim))
        self.b = np.abs(np.random.normal(0, self.var, (self.output_dim)))

        # init params
        self.x = np.zeros((batch_size, self.input_dim))
        self.auxW = np.zeros_like(self.W)
        self.auxb = np.zeros_like(self.b)

        # parameter derivatives
        self.de_dW = np.zeros_like(self.W)
        self.de_db = np.zeros_like(self.b)

    def forward(self, x):
        self.x = x
        return self.predict(x)

    def backward(self, de_dy):
        self.de_dW, self.de_db, de_dY = numba_backward(self.x.T, W, de_dy, self.batch_size)
        return de_dY

    def predict(self, x):
        return numba_predict(x, self.W, self.b)

    def update(self):
        self.W, self.aW = self.optimizer(self.W, self.de_dW, self.aW)
        self.b, self.ab = self.optimizer(self.b, self.de_db, self.ab)

In [180]:
pydense = Dense(2, 2)

In [181]:
pydense

<__main__.Dense at 0x22d8026d3a0>

In [182]:
def backward(x, W, de_dy, batch_size=24):
    de_dW = (x.T @ de_dy) / batch_size
    de_db = de_dy.mean(axis=0)
    return de_dW, de_db, de_dy @ W.T

x = np.random.normal(0, 1, (1000, 1000))
W = np.random.normal(0, 1, (1000, 1000))
b = np.random.normal(0, 1, (1000, 1000))
de_dy = np.random.normal(0, 1, (1000, 1000))

# backward(x, W, de_dy)

In [183]:
@njit(fastmath=True)
def matmul(a, b):
    return a @ b

def numba_backward(x, W, de_dy, batch_size=24):
    de_dW = matmul(x.T,de_dy) / batch_size
    de_db = np.mean(de_dy, axis=0)
    return de_dW, de_db, matmul(de_dy, W.T)

@njit(fastmath=True)
def numba_predict(x, W, b):
    return matmul(x, W) + b

def predict(x, W, b):
    return x @ W + b

In [184]:
def ntime():
    start = time.time()
    numba_predict(x, W, b)
    end = time.time()
    print("Elapsed Numba Pre = %s" % (end - start))
    start = time.time()
    f2 = numba_predict(x, W, b)
    end = time.time()
    print("Elapsed Numba Post = %s" % (end - start))
    return f2

def ptime():
    start = time.time()
    f1 = predict(x, W, b)
    end = time.time()
    print("Elapsed PyTime = %s" % (end - start))
    return f1

In [185]:
ntime()

Elapsed Numba Pre = 0.2524232864379883
Elapsed Numba Post = 0.00997304916381836


array([[ 73.02862061, -49.62178094,  -6.8785207 , ...,  41.3734456 ,
        -23.35271519,  -5.86568261],
       [ 20.40711104,  13.3942848 ,   2.3941608 , ..., -23.63026274,
         69.48409687,  -6.62970565],
       [-25.78831455, -10.96572162, -28.02274376, ...,  56.49625951,
         42.70854967,  32.32089384],
       ...,
       [ 53.29977259,  43.51158051,  10.06346471, ..., -21.89942842,
        -82.00427285, -38.66760534],
       [ 56.2271636 ,  38.76335839,  10.15629974, ...,  -9.95671304,
         -1.82992546, -13.65583021],
       [ 20.35099138, -23.20796415,  19.37195767, ...,   6.3885379 ,
        -13.23735478,  -5.33107986]])

In [186]:
ptime()

Elapsed PyTime = 0.011967658996582031


array([[ 73.02862061, -49.62178094,  -6.8785207 , ...,  41.3734456 ,
        -23.35271519,  -5.86568261],
       [ 20.40711104,  13.3942848 ,   2.3941608 , ..., -23.63026274,
         69.48409687,  -6.62970565],
       [-25.78831455, -10.96572162, -28.02274376, ...,  56.49625951,
         42.70854967,  32.32089384],
       ...,
       [ 53.29977259,  43.51158051,  10.06346471, ..., -21.89942842,
        -82.00427285, -38.66760534],
       [ 56.2271636 ,  38.76335839,  10.15629974, ...,  -9.95671304,
         -1.82992546, -13.65583021],
       [ 20.35099138, -23.20796415,  19.37195767, ...,   6.3885379 ,
        -13.23735478,  -5.33107986]])