In [34]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [35]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import random, jit, grad
import scipy

In [36]:
import cr.sparse as crs
from cr.sparse import la
from cr.sparse import dict
from cr.sparse import pursuit
from cr.sparse import data

In [37]:
from cr.sparse.pursuit import omp

In [38]:
M = 256
N = 1024
K = 16
S = 32

In [39]:
key = random.PRNGKey(0)
Phi = dict.gaussian_mtx(key, M,N)

In [40]:
dict.coherence(Phi)

DeviceArray(0.30067617, dtype=float32)

In [41]:
X, omega = data.sparse_normal_representations(key, N, K, S)
X.shape

(1024, 32)

In [42]:
omega

DeviceArray([  9,  12, 102, 121, 136, 199, 257, 291, 306, 352, 531, 556,
             596, 749, 760, 902], dtype=int32)

In [43]:
Y = Phi @ X
Y.shape

(256, 32)

In [44]:
solution = omp.solve_multi(Phi, Y, K)

In [45]:
jnp.max(solution.r_norm_sqr)

DeviceArray(3.2312356e-12, dtype=float32)

In [46]:
def time_solve_multi():
    solution = omp.solve_multi(Phi, Y, K)
    solution.x_I.block_until_ready()
    solution.r.block_until_ready()
    solution.I.block_until_ready()
    solution.r_norm_sqr.block_until_ready()

In [47]:
%timeit time_solve_multi()

270 ms ± 2.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [48]:
solve_multi = jax.jit(omp.solve_multi, static_argnums=(2,))

In [49]:
def time_solve_multi_jit():
    solution = solve_multi(Phi, Y, K)
    solution.x_I.block_until_ready()
    solution.r.block_until_ready()
    solution.I.block_until_ready()
    solution.r_norm_sqr.block_until_ready()

In [52]:
%timeit time_solve_multi_jit()

6.11 ms ± 72.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
