In [1]:
import numpy as np

In [2]:
from numba import njit, jit
from numba import double
from typing import Tuple

In [3]:
@njit
def opt_foward(inputs: np.ndarray, w: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    output = inputs @ w.T + b.T
    return output, inputs

In [4]:
def foward(inputs: np.ndarray, w: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    output = inputs @ w.T + b.T
    return output, inputs

In [5]:
def backward(ograds: np.ndarray, s: np.ndarray, w: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    bgrads = ograds.sum(axis=0)
    wgrads = ograds.T @ s
    igrads = ograds @ w
    return igrads, wgrads, bgrads

In [6]:
@njit
def opt_backward(ograds: np.ndarray, s: np.ndarray, w: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    bgrads = ograds.sum(axis=0)
    wgrads = ograds.T @ s
    igrads = ograds @ w
    return igrads, wgrads, bgrads

In [7]:
i = np.random.rand(1000, 128)
w = np.random.rand(128, 128)
b = np.random.rand(128)

In [8]:
o, s = np.zeros((i.shape[0], w.shape[0])), np.zeros_like(i)
# compile code once
opt_foward(i, w, b);

In [9]:
%%timeit
o[:], s[:] = foward(i, w, b)

1.08 ms ± 452 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
%%timeit
o[:], s[:] = opt_foward(i, w, b)

786 µs ± 123 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
o, s = opt_foward(i, w, b)
og = np.random.rand(*o.shape)

In [12]:
ig, wg, bg = np.zeros_like(i), np.zeros_like(w), np.zeros_like(b)
# compile code once
opt_backward(og, s, w);

In [13]:
%%timeit
ig[:], wg[:], bg[:] = backward(og, s, w)

2.23 ms ± 500 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [14]:
%%timeit
ig[:], wg[:], bg[:] = opt_backward(og, s, w)

1.73 ms ± 279 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [15]:
def tanh_func(x):
    return np.tanh(x)

def tanh_grad(x):
    y = np.tanh(x)
    return 1 - (y * y)

In [16]:
a = np.random.rand(128, 128)

In [17]:
%%timeit
tanh_func(a)

129 µs ± 1.97 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [18]:
%%timeit
tanh_grad(a)

142 µs ± 1.45 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [19]:
@njit
def opt_tanh_func(x : np.ndarray) -> np.ndarray:
    return np.tanh(x)

@njit
def opt_tanh_grad(x : np.ndarray) -> np.ndarray:
    y = np.tanh(x)
    return 1 - (y * y)
# compile once
opt_tanh_func(np.ones(1));
opt_tanh_grad(np.ones(1));

In [20]:
%%timeit
tanh_func(a)

131 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [21]:
%%timeit
tanh_grad(a)

145 µs ± 3.56 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [22]:
from ddnn.utils import Parameter

_t = 1
_l2 = 0.1
_eta = 0.1
_eps = 1e-8
_beta1 = 0.9
_beta2 = 0.999

def call_t(
    params: Parameter, grads: Parameter, state: Parameter
) -> Tuple[Parameter, Tuple[Parameter, Parameter]]:
    if _t == 0:
        raise ValueError()
    if state == None:
        old_m = Parameter(
            np.zeros_like(grads.weights), np.zeros_like(grads.bias)
        )
        old_v = Parameter(
            np.zeros_like(grads.weights), np.zeros_like(grads.bias)
        )
    else:
        old_m = state[0]
        old_v = state[1]

    temp = grads.weights
    if _l2 != 0:
        # += here would modify grads.weights
        temp = temp + _l2 * params.weights

    m_w = _beta1 * old_m.weights + (1 - _beta1) * temp
    m_b = _beta1 * old_m.bias + (1 - _beta1) * grads.bias

    v_w = (
        _beta2 * old_v.weights
        + (1 - _beta2) * temp * temp
    )
    v_b = (
        _beta2 * old_v.bias
        + (1 - _beta2) * grads.bias * grads.bias
    )

    old_m = Parameter(m_w, m_b)
    old_v = Parameter(v_w, v_b)

    adj = (1 - _beta2**_t) ** 0.5 / (
        1 - _beta1**_t
    )
    delta_w = (-_eta * adj) * m_w / (np.sqrt(v_w) + _eps)
    delta_b = (-_eta * adj) * m_b / (np.sqrt(v_b) + _eps)

    delta = Parameter(delta_w, delta_b)

    return (delta, (old_m, old_v))

In [23]:
from numba.experimental import jitclass
from numba import double
@jitclass
class Parameter:
    weights: double[:,:]
    bias: double[:]

    def __init__(self, w, b):
        weights = w
        bias = b

    def __iadd__(self, other):
        self.weights += other.weights
        self.bias += other.bias
        return self

    @property
    def shape(self):
        return self.weights.shape


@jit
def opt_call_t(
    params: Parameter, grads: Parameter, state: Parameter
) -> Tuple[Parameter, Tuple[Parameter, Parameter]]:
    if _t == 0:
        raise ValueError()
    if state == None:
        old_m = Parameter(
            np.zeros_like(grads.weights), np.zeros_like(grads.bias)
        )
        old_v = Parameter(
            np.zeros_like(grads.weights), np.zeros_like(grads.bias)
        )
    else:
        old_m = state[0]
        old_v = state[1]

    temp = grads.weights
    if _l2 != 0:
        # += here would modify grads.weights
        temp = temp + _l2 * params.weights

    m_w = _beta1 * old_m.weights + (1 - _beta1) * temp
    m_b = _beta1 * old_m.bias + (1 - _beta1) * grads.bias

    v_w = (
        _beta2 * old_v.weights
        + (1 - _beta2) * temp * temp
    )
    v_b = (
        _beta2 * old_v.bias
        + (1 - _beta2) * grads.bias * grads.bias
    )

    old_m = Parameter(m_w, m_b)
    old_v = Parameter(v_w, v_b)

    adj = (1 - _beta2**_t) ** 0.5 / (
        1 - _beta1**_t
    )
    delta_w = (-_eta * adj) * m_w / (np.sqrt(v_w) + _eps)
    delta_b = (-_eta * adj) * m_b / (np.sqrt(v_b) + _eps)

    delta = Parameter(delta_w, delta_b)

    return (delta, (old_m, old_v))

In [28]:
params = Parameter(np.random.rand(128, 128), np.random.random(128))
grads = Parameter(np.random.rand(128, 128), np.random.random(128))
state = (
    Parameter(np.random.rand(128, 128), np.random.random(128)),
    Parameter(np.random.rand(128, 128), np.random.random(128))
)
# compile once
opt_call_t(params, grads, state);

In [29]:
%%timeit
delta, nstate = call_t(params, grads, state)

50.1 µs ± 3.26 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [31]:
%%timeit
delta, nstate = opt_call_t(params, grads, state)

15.1 µs ± 678 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
