In [1]:
from jax import lax, jit
from jax.config import config; config.update('jax_platform_name', 'gpu')
import jax.numpy as np

In [2]:
N, D_X, D_H = 50, 3, 5
X, Y = np.ones((N, D_X)), np.zeros(N)
w = np.ones((D_X, D_H))

In [3]:
def potential_fn(w1):
    z1 = np.matmul(X, w1)
    w2 = np.ones((D_H, D_H))
    z2 = np.matmul(z1, w2)
    w3 = np.ones((D_H, 1))
    z3 = np.matmul(z2, w3)
    return np.sum((z3 - Y) ** 2)

In [4]:
@jit
def loop(w):
    def body_fn(i, state):
        w, f = state
        f = potential_fn(w)
        return w, f

    w, f = lax.fori_loop(0, 1000, body_fn, (w, 0.))
    return f

### CPU

In [5]:
%%time
loop(w)

CPU times: user 103 ms, sys: 4.24 ms, total: 108 ms
Wall time: 125 ms


DeviceArray(14062500., dtype=float32)

In [6]:
%%time
loop(w)

CPU times: user 825 µs, sys: 47 µs, total: 872 µs
Wall time: 532 µs


DeviceArray(14062500., dtype=float32)

### GPU

In [5]:
%%time
loop(w)

CPU times: user 442 ms, sys: 202 ms, total: 645 ms
Wall time: 760 ms


DeviceArray(14062500., dtype=float32)

In [6]:
%%time
loop(w)

CPU times: user 36.9 ms, sys: 0 ns, total: 36.9 ms
Wall time: 35.4 ms


DeviceArray(14062500., dtype=float32)

## PyTorch

In [7]:
import torch

In [8]:
def tpotential_fn(w1):
    z1 = torch.matmul(tX, w1)
    w2 = torch.ones((D_H, D_H))
    z2 = torch.matmul(z1, w2)
    w3 = torch.ones((D_H, 1))
    z3 = torch.matmul(z2, w3)
    return torch.sum((z3 - tY) ** 2)

### CPU

In [9]:
tX, tY = torch.ones((N, D_X)), torch.zeros(N)
w = torch.ones((D_X, D_H))
jtpotential_fn = torch.jit.trace(tpotential_fn, (w,))

In [10]:
%%time
for i in range(1000):
    f = tpotential_fn(w)

CPU times: user 587 ms, sys: 0 ns, total: 587 ms
Wall time: 51.6 ms


In [11]:
%%time
for i in range(1000):
    f = jtpotential_fn(w)

CPU times: user 698 ms, sys: 0 ns, total: 698 ms
Wall time: 61.7 ms


### GPU

In [12]:
torch.set_default_tensor_type(torch.cuda.FloatTensor)
tX, tY = torch.ones((N, D_X)), torch.zeros(N)
w = torch.ones((D_X, D_H))
jtpotential_fn = torch.jit.trace(tpotential_fn, (w,))

In [13]:
%%time
for i in range(1000):
    f = tpotential_fn(w)

CPU times: user 98 ms, sys: 0 ns, total: 98 ms
Wall time: 97.1 ms


In [14]:
%%time
for i in range(1000):
    f = jtpotential_fn(w)

CPU times: user 65.3 ms, sys: 0 ns, total: 65.3 ms
Wall time: 64.5 ms
