The idea is to benchmark the performance of a loop similar to a Kalman filter loop and see whether there are ways to improve its performance.

TorchScript (`torch.jit.script`) looks useless here: the matmul loop below takes 64.8 ms scripted vs 67.2 ms in eager mode.

In [None]:
import torch  # needed by the decorator; the notebook otherwise imports torch only in a later cell


@torch.jit.script
def loop_mult_jit(A, B, n_iter: int = 1000):
    """TorchScript-compiled loop applying ``A <- A @ B @ A`` ``n_iter`` times.

    Args:
        A: matrix updated every iteration; ``A @ B @ A`` must have ``A``'s
            shape, e.g. ``A`` is ``(m, k)`` and ``B`` is ``(k, m)``.
        B: matrix sandwiched between the two copies of ``A``.
        n_iter: number of iterations (``int`` annotation required by
            TorchScript); defaults to the previously hard-coded 1000.

    Returns:
        The final ``A``. The original version computed it and discarded it.
    """
    for _ in range(n_iter):
        A = A @ B @ A
    return A

In [None]:
def loop_mult(A, B, n_iter=1000):
    """Apply ``A <- A @ B @ A`` ``n_iter`` times eagerly and return the final ``A``.

    Eager-mode baseline for the TorchScript-compiled ``loop_mult_jit``.

    Args:
        A: matrix updated every iteration; ``A @ B @ A`` must have ``A``'s
            shape, e.g. ``A`` is ``(m, k)`` and ``B`` is ``(k, m)``.
        B: matrix sandwiched between the two copies of ``A``.
        n_iter: number of iterations; defaults to the previously hard-coded 1000.

    Returns:
        The final ``A``. The original version computed it and discarded it.

    Note:
        With the ``torch.rand(100, 100)`` operands used in the timing cells,
        the entries presumably overflow to ``inf`` after a handful of
        iterations, so most iterations multiply non-finite values.
    """
    for _ in range(n_iter):
        A = A @ B @ A
    return A

In [None]:
# Time the eager matmul loop.
# NOTE(review): `A` is only created further down, in the
# `torch.rand(100,100, dtype=torch.float64)` cell of the "Julia comparison"
# section, so that cell must run first.
%timeit loop_mult(A,A)

67.2 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Time the TorchScript version of the same matmul loop.
# NOTE(review): like the eager cell above, this relies on `A` from the
# `torch.rand(100,100, dtype=torch.float64)` cell further down.
%timeit loop_mult_jit(A,A)

64.8 ms ± 3.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## pytorch loop

In [None]:
import torch

In [None]:
# Module-level example buffers: 10_000 rows x 20 columns of zeros (state)
# and ones (observations). test_loop_torch builds its own local copies.
out, obs = torch.zeros(10_000, 20), torch.ones(10_000, 20)


In [None]:
def test_loop_torch(n: int = 10_000):
    """Run the recurrence ``out[i] = 0.3 * (out[i-1] + obs[i])`` row by row in PyTorch.

    Each row depends on the previous one, as in a filter update loop, so
    the work is a sequential Python-level loop over rows.

    Args:
        n: number of rows (time steps); each row has 20 columns. The ``int``
            annotation lets ``torch.jit.script`` compile this function, because
            TorchScript assumes unannotated arguments are Tensors. Previously
            ``n`` was ignored and 10_000 was hard-coded.

    Returns:
        ``(n, 20)`` float tensor whose row 0 stays zero. With ``obs`` all ones,
        the rows converge to 3/7 (~0.4286), the fixed point of
        ``x = 0.3 * (x + 1)``.
    """
    out = torch.zeros(n, 20)
    obs = torch.ones(n, 20)

    for i in range(1, n):
        out[i] = (out[i - 1] + obs[i]) * .3
    return out

In [None]:
%timeit test_loop

14.9 ns ± 0.177 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


## torch jit

In [None]:
# Compile test_loop_torch with TorchScript so it can be timed against eager mode.
# NOTE(review): TorchScript assumes unannotated arguments are Tensors, so this
# only compiles if test_loop_torch's `n` parameter carries an `int` annotation.
test_loop_jit = torch.jit.script(test_loop_torch)

In [None]:
%timeit test_loop_jit

14.9 ns ± 0.193 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


In [None]:
# Display one run: row 0 is zero, and later rows approach 0.4286 (see output).
test_loop_torch(n=10_000)

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.3000, 0.3000, 0.3000,  ..., 0.3000, 0.3000, 0.3000],
        [0.3900, 0.3900, 0.3900,  ..., 0.3900, 0.3900, 0.3900],
        ...,
        [0.4286, 0.4286, 0.4286,  ..., 0.4286, 0.4286, 0.4286],
        [0.4286, 0.4286, 0.4286,  ..., 0.4286, 0.4286, 0.4286],
        [0.4286, 0.4286, 0.4286,  ..., 0.4286, 0.4286, 0.4286]])

## Numpy loop

In [None]:
import numpy as np

In [None]:
def test_loop_np(n: int = 10_000):
    """NumPy version of the row-by-row recurrence ``out[i] = 0.3 * (out[i-1] + obs[i])``.

    Args:
        n: number of rows (time steps); each row has 20 columns.

    Returns:
        ``(n, 20)`` float64 ndarray whose row 0 stays zero.
    """
    out = np.zeros((n, 20))
    # Observations are NumPy too: the previous `torch.ones` mixed a torch
    # tensor into every addition, so this was not a pure-NumPy benchmark.
    obs = np.ones((n, 20))

    for i in range(1, n):
        out[i] = (out[i - 1] + obs[i]) * .3
    return out

In [None]:
%timeit test_loop

14.8 ns ± 0.0836 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


## Jax

In [None]:
import jax.numpy as jnp
from jax import grad, jit

In [None]:
def test_loop_jax_1(n: int = 10_000):
    """JAX version of the recurrence ``out[i] = 0.3 * (out[i-1] + obs[i])``.

    JAX arrays are immutable: ``out.at[i].set(row)`` returns a *new* array.
    The previous version discarded that return value, so it always returned
    zeros. It also copied the whole ``(n, 20)`` array on every eager step,
    which is O(n^2) overall. Running the loop with ``lax.fori_loop`` keeps each
    update and compiles the loop, so XLA can update the buffer in place.

    NOTE: the 16.8 s timing recorded below this function's ``%timeit`` cell was
    measured with the earlier eager loop.

    Args:
        n: number of rows (time steps); each row has 20 columns.

    Returns:
        ``(n, 20)`` JAX array whose row 0 stays zero.
    """
    from jax import lax  # the notebook imports only jax.numpy / grad / jit

    out = jnp.zeros((n, 20))
    obs = jnp.ones((n, 20))
    if n < 2:
        # No rows after row 0 to update.
        return out

    def step(i, acc):
        # Keep the functional update's result as the new loop carry.
        return acc.at[i].set((acc[i - 1] + obs[i]) * .3)

    return lax.fori_loop(1, n, step, out)

In [None]:
%timeit test_loop_jax_1()

16.8 s ± 2.3 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
def test_loop_jax_2(n: int = 10_000):
    """Mixed NumPy/JAX variant: mutable NumPy ``out`` buffer, JAX ``obs`` rows.

    Computes ``out[i] = (out[i-1] + obs[i]) * 1.01`` row by row. It writes
    into a NumPy array in place, avoiding JAX's functional ``.at[].set``.

    NOTE(review): the factor here is 1.01, while the other variants use 0.3.
    The operation count is the same but the values differ; confirm intent.

    Args:
        n: number of rows (time steps); each row has 20 columns. Previously
            ignored in favour of a hard-coded 10_000.

    Returns:
        ``(n, 20)`` float64 NumPy array whose row 0 stays zero.
    """
    out = np.zeros((n, 20))
    obs = jnp.ones((n, 20))

    for i in range(1, n):
        # NOTE(review): presumably every step converts between NumPy and JAX
        # arrays (obs[i] is a JAX array), which likely dominates the runtime.
        out[i, :] = (out[i - 1] + obs[i]) * 1.01
    return out

In [None]:
# test_loop_jax_2 returns its NumPy `out` buffer, so no block_until_ready() is needed here.
%timeit test_loop_jax_2()

4.36 s ± 235 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Julia comparison

In [None]:
import torch

In [None]:
def loop_add(A, B, max=100):
    """Return ``A`` with ``B`` added to it ``int(max)`` times, one ``+`` per iteration.

    ``max`` is passed through ``int()`` so float literals such as ``1e7`` work.
    It shadows the builtin inside this function, but the name is kept because
    callers pass it by keyword.
    """
    n_steps = int(max)
    acc = A
    for _ in range(n_steps):
        acc = acc + B
    return acc

In [None]:
# Baseline: 1e7 additions of plain Python ints.
%timeit loop_add(1,2, max=1e7)

428 ms ± 22.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Same 1e7 additions on 0-d torch tensors; every `A + B` is a separate torch op,
# so per-call dispatch overhead presumably dominates (compare with the int baseline above).
%timeit loop_add(torch.tensor(1),torch.tensor(2), max=1e7)

22.6 s ± 2.13 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Random 100x100 float64 operands in [0, 1), drawn A first, then B.
# They are used here and by the `%timeit loop_mult(A,A)` cells near the top.
A, B = (torch.rand(100, 100, dtype=torch.float64) for _ in range(2))

In [None]:
# 1e5 element-wise additions of the 100x100 float64 matrices defined above.
%timeit loop_add(A, B, max=1e5)

686 ms ± 33.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
