In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import random, jit, grad
import scipy

In [3]:
import cr.sparse as crs
from cr.sparse import la
from cr.sparse import dict
from cr.sparse import pursuit
from cr.sparse import data

In [4]:
from functools import partial

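# Dictionary setup

Generate a random Gaussian dictionary `Phi` of size $M \times N$, a $K$-sparse signal `x`, and its measurements $y = \Phi x$.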

In [5]:
# Problem dimensions
M = 128  # number of measurements (rows of Phi; length of y)
N = 256  # number of atoms (columns of Phi; length of x)
K = 8    # sparsity level: number of nonzero entries in x

In [6]:
key = random.PRNGKey(0)
# Random Gaussian dictionary / sensing matrix of shape (M, N).
# NOTE: `dict` here is `cr.sparse.dict` (imported above), which shadows the Python builtin.
Phi = dict.gaussian_mtx(key, M,N)



In [7]:
Phi.shape

(128, 256)

In [8]:
dict.coherence(Phi)

DeviceArray(0.33975706, dtype=float32)

In [9]:
# Draw one K-sparse vector of length N; `omega` holds its support (nonzero indices).
# NOTE(review): this reuses the same `key` that generated Phi above. JAX advises
# never reusing a PRNG key for independent draws; consider
# `key, subkey = random.split(key)` (doing so changes every output below).
x, omega = data.sparse_normal_representations(key, N, K, 1)
# Presumably returned with a trailing batch dimension of size 1 — squeeze to 1-D.
x = jnp.squeeze(x)

In [10]:
omega, omega.shape

(DeviceArray([ 41,  60,  68,  89,  99, 198, 232, 244], dtype=int32), (8,))

In [11]:
crs.nonzero_indices(x)

DeviceArray([ 41,  60,  68,  89,  99, 198, 232, 244], dtype=int32)

In [12]:
crs.nonzero_values(x)

DeviceArray([ 0.08086776, -0.3862472 , -0.37565574,  1.6689737 ,
             -1.2758199 ,  2.1192    , -0.8582123 ,  1.1305931 ],            dtype=float32)

In [13]:
# Noiseless measurements of the sparse signal.
y = Phi @ x

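# Development of the CoSaMP algorithm

The CoSaMP steps are first run by hand, one cell at a time. The result is then compared with the library implementation `cr.sparse.pursuit.cosamp`.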

In [14]:
# The residual starts out equal to the measurements, so its squared norm is ||y||^2.
r = y
y_norm_sqr = y @ y
r_norm_sqr = y_norm_sqr
r_norm_sqr

DeviceArray(11.47407, dtype=float32)

In [15]:
# Boolean mask over the N dictionary columns marking the candidate support (empty at start).
flags = jnp.zeros(N, dtype=bool)

In [16]:
# Number of new indices selected from the proxy on every iteration after the first.
K2 = 2*K
K2

16

In [17]:
# Size of the merged candidate support: K kept + 2K new.
K3 = K + K2
K3

24

In [18]:
# Iteration counter; 0 means "first iteration".
iterations = 0

In [19]:
# Relative tolerance on the residual norm: stop when ||r|| <= res_norm_rtol * ||y||.
res_norm_rtol = 1e-3

In [20]:
# Same criterion in squared form, so it can be compared with r_norm_sqr directly.
max_r_norm_sqr = y_norm_sqr * (res_norm_rtol ** 2)

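## First iteration

The candidate support is empty at the start, so the $3K$ largest-magnitude entries of the proxy $h = \Phi^T r$ are selected directly.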

In [21]:
# Proxy: correlation of the current residual with every column of Phi.
h = Phi.T @ r

In [22]:
# Select the largest-magnitude proxy entries. `iterations` is 0 on the first pass,
# so K3 (= 3K) indices are taken at once because there is no previous support yet.
# Later passes take K2 (= 2K) new indices, which are merged with the K kept ones.
# (On this first pass the name I_2k therefore holds 3K indices.)
I_2k =  pursuit.largest_indices(h, K2 if iterations else K3)
I_2k

DeviceArray([198,  89,  99, 244, 232, 100, 151, 109, 226, 255,  67, 217,
             239, 201,   5,  19,  41,  86,  44, 116, 142,   9, 159, 241],            dtype=int32)

In [23]:
# Add the selected indices to the candidate support mask.
flags = flags.at[I_2k].set(True)

In [24]:
jnp.where(flags)

(DeviceArray([  5,   9,  19,  41,  44,  67,  86,  89,  99, 100, 109, 116,
              142, 151, 159, 198, 201, 217, 226, 232, 239, 241, 244, 255],            dtype=int32),)

In [25]:
# Columns of Phi on the candidate support. Boolean masking keeps them in
# ascending column order, i.e. the same order as jnp.where(flags).
Phi_3I = Phi[:, flags]

In [26]:
# Least-squares fit of y on the candidate columns. lstsq also returns the
# residual sum of squares, the rank and the singular values of Phi_3I.
x_3I, r_3I_norms, rank_3I, s_3I = jnp.linalg.lstsq(Phi_3I, y)

In [27]:
x_3I

DeviceArray([-0.02588557, -0.00669122, -0.03994435,  0.08871117,
             -0.04095322, -0.02504703,  0.04328822,  1.7106531 ,
             -1.1434499 , -0.04441866, -0.05053541,  0.11738732,
              0.03516048, -0.03045028, -0.01221958,  2.1708896 ,
              0.03996569,  0.0356746 ,  0.02685671, -0.8320763 ,
             -0.06286937, -0.05119762,  1.0858326 ,  0.00759119],            dtype=float32)

In [28]:
r_3I_norms

DeviceArray([0.22964859], dtype=float32)

In [29]:
rank_3I

DeviceArray(24, dtype=int32)

In [30]:
s_3I

DeviceArray([1.5750406 , 1.328093  , 1.2966532 , 1.2233667 , 1.1843554 ,
             1.1588342 , 1.1203228 , 1.0971804 , 1.0686163 , 1.0462837 ,
             1.0087969 , 0.9819599 , 0.9384124 , 0.92007697, 0.88111854,
             0.83006454, 0.80708253, 0.77306485, 0.7511992 , 0.72991306,
             0.69474196, 0.6813258 , 0.5978    , 0.5581449 ],            dtype=float32)

In [31]:
# Prune: keep the K largest-magnitude least-squares coefficients. Ia holds
# positions within x_3I (i.e. within the candidate support), not column indices.
Ia = pursuit.largest_indices(x_3I, K)
Ia

DeviceArray([15,  7,  8, 22, 19, 11,  3, 20], dtype=int32)

In [32]:
x_3I[Ia]

DeviceArray([ 2.1708896 ,  1.7106531 , -1.1434499 ,  1.0858326 ,
             -0.8320763 ,  0.11738732,  0.08871117, -0.06286937],            dtype=float32)

In [33]:
# Column indices of the candidate support, in the same order as the columns of Phi_3I.
I_3k, = jnp.where(flags)
I_3k

DeviceArray([  5,   9,  19,  41,  44,  67,  86,  89,  99, 100, 109, 116,
             142, 151, 159, 198, 201, 217, 226, 232, 239, 241, 244, 255],            dtype=int32)

In [34]:
# Map the retained positions back to dictionary column indices.
I = I_3k[Ia]
I

DeviceArray([198,  89,  99, 244, 232, 116,  41, 239], dtype=int32)

In [35]:
# Coefficients of the retained atoms, taken straight from the least-squares
# solution on the larger support (no re-fit on the pruned support).
x_I = x_3I[Ia]
x_I

DeviceArray([ 2.1708896 ,  1.7106531 , -1.1434499 ,  1.0858326 ,
             -0.8320763 ,  0.11738732,  0.08871117, -0.06286937],            dtype=float32)

In [36]:
# Columns of the K retained atoms, ordered like x_I.
Phi_I = Phi[:, I]

In [37]:
jnp.linalg.norm(y - Phi_I @ x_I)

DeviceArray(0.5008783, dtype=float32)

In [38]:
jnp.intersect1d(I, omega)

DeviceArray([ 41,  89,  99, 198, 232, 244], dtype=int32)

In [39]:
I

DeviceArray([198,  89,  99, 244, 232, 116,  41, 239], dtype=int32)

In [40]:
omega

DeviceArray([ 41,  60,  68,  89,  99, 198, 232, 244], dtype=int32)

In [41]:
jnp.setdiff1d(omega, I)

DeviceArray([60, 68], dtype=int32)

In [42]:
jnp.setdiff1d(omega, I_3k)

DeviceArray([60, 68], dtype=int32)

In [43]:
# Update the residual using the pruned K-sparse approximation.
r = y - Phi_I @ x_I

In [44]:
# Squared residual norm, compared against max_r_norm_sqr for the stopping test.
r_norm_sqr = r.T @ r
r_norm_sqr

DeviceArray(0.250879, dtype=float32)

In [45]:
# One iteration completed.
# NOTE: this cell is not idempotent — re-running it increments the counter again.
iterations += 1

In [46]:
# Stopping test: has ||r|| dropped below res_norm_rtol * ||y||?
r_norm_sqr < max_r_norm_sqr

DeviceArray(False, dtype=bool)

In [47]:
flags = flags.at[:].set(False)

In [48]:
flags = flags.at[I].set(True)

In [49]:
jnp.where(flags)

(DeviceArray([ 41,  89,  99, 116, 198, 232, 239, 244], dtype=int32),)

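## Second iteration

From now on, $2K$ new indices are selected from the proxy and merged with the $K$ indices kept by the previous iteration.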

In [50]:
iterations

1

In [51]:
# Proxy for the updated residual.
h = Phi.T @ r

In [52]:
# `iterations` is now 1, so only K2 (= 2K) new indices are selected.
I_2k =  pursuit.largest_indices(h, K2 if iterations else K3)
I_2k

DeviceArray([ 60,  68,  73,   4,  34,  56, 105, 160,  59, 245, 137, 209,
             178, 103, 233, 155], dtype=int32)

In [53]:
jnp.intersect1d(omega, I_2k)

DeviceArray([60, 68], dtype=int32)

In [54]:
# Merge the new indices with the K indices kept from the previous iteration.
flags = flags.at[I_2k].set(True)

In [55]:
# Candidate support: K kept indices plus up to 2K new ones (they may overlap).
I_3k, = jnp.where(flags)

In [56]:
jnp.intersect1d(omega, I_3k)

DeviceArray([ 41,  60,  68,  89,  99, 198, 232, 244], dtype=int32)

In [57]:
jnp.setdiff1d(omega, I_3k)

DeviceArray([], dtype=int32)

In [58]:
# Candidate columns, in the same order as I_3k.
Phi_3I = Phi[:, flags]

In [59]:
# Least-squares fit on the merged candidate support.
x_3I, r_3I_norms, rank_3I, s_3I = jnp.linalg.lstsq(Phi_3I, y)

In [60]:
x_3I

DeviceArray([-8.9406967e-08, -2.3096800e-07,  8.0867290e-02,
             -2.3841858e-07,  4.7683716e-07, -3.8624728e-01,
             -3.7565678e-01,  1.1920929e-07,  1.6689737e+00,
             -1.2758204e+00, -4.0978193e-07,  3.5762787e-07,
              4.7683716e-07,  2.8312206e-07, -4.3213367e-07,
              5.3644180e-07,  4.1723251e-07,  2.1192000e+00,
             -5.9604645e-08, -8.5821259e-01, -8.3446503e-07,
              8.6426735e-07,  1.1305925e+00, -3.5762787e-07],            dtype=float32)

In [61]:
r_3I_norms

DeviceArray([5.3271827e-12], dtype=float32)

In [62]:
s_3I

DeviceArray([1.4674419 , 1.3486443 , 1.3159059 , 1.2481905 , 1.1904742 ,
             1.1661103 , 1.1321163 , 1.075251  , 1.0460968 , 1.026143  ,
             1.0208699 , 1.0056185 , 0.9527196 , 0.9132839 , 0.87185484,
             0.84825575, 0.81520104, 0.8015615 , 0.77109873, 0.7518697 ,
             0.7152713 , 0.6672644 , 0.6339768 , 0.5247762 ],            dtype=float32)

In [63]:
rank_3I

DeviceArray(24, dtype=int32)

In [64]:
# Prune to the K largest-magnitude coefficients (positions within x_3I).
Ia = pursuit.largest_indices(x_3I, K)
Ia

DeviceArray([17,  8,  9, 22, 19,  5,  6,  2], dtype=int32)

In [65]:
# Map retained positions back to dictionary column indices.
I = I_3k[Ia]
I

DeviceArray([198,  89,  99, 244, 232,  60,  68,  41], dtype=int32)

In [66]:
jnp.setdiff1d(omega, I)

DeviceArray([], dtype=int32)

In [67]:
# Coefficients of the retained atoms (taken from the least-squares solution, no re-fit).
x_I = x_3I[Ia]
x_I

DeviceArray([ 2.1192    ,  1.6689737 , -1.2758204 ,  1.1305925 ,
             -0.8582126 , -0.38624728, -0.37565678,  0.08086729],            dtype=float32)

In [68]:
# Columns of the retained atoms, ordered like x_I.
Phi_I = Phi[:, I]

In [69]:
# Residual after the second iteration.
r = y - Phi_I @ x_I

In [70]:
jnp.linalg.norm(r)

DeviceArray(1.3279785e-06, dtype=float32)

In [71]:
# Squared residual norm for the stopping test.
r_norm_sqr = r.T @ r

In [72]:
# Stopping test after the second iteration.
r_norm_sqr < max_r_norm_sqr

DeviceArray(True, dtype=bool)

In [73]:
from cr.sparse.pursuit import cosamp

In [74]:
# Run the library implementation on the same problem, to compare with the manual steps above.
solution =  cosamp.solve(Phi, y, K)

In [75]:
solution.x_I

DeviceArray([ 2.1191995 ,  1.6689733 , -1.2758204 ,  1.1305943 ,
             -0.8582119 , -0.38624743, -0.37565574,  0.08086765],            dtype=float32)

In [76]:
solution.I

DeviceArray([198,  89,  99, 244, 232,  60,  68,  41], dtype=int32)

In [77]:
omega

DeviceArray([ 41,  60,  68,  89,  99, 198, 232, 244], dtype=int32)

In [78]:
jnp.setdiff1d(omega, solution.I)

DeviceArray([], dtype=int32)

In [79]:
solution.r_norm_sqr, solution.iterations

(DeviceArray(2.3235905e-12, dtype=float32), DeviceArray(2, dtype=int32))

In [80]:
def time_solve():
    """Solve once with the non-jitted CoSaMP and wait for the results.

    Each returned array is blocked on, so that ``%timeit`` includes the
    device computation (JAX may dispatch work asynchronously).
    """
    result = cosamp.solve(Phi, y, K)
    for arr in (result.x_I, result.r, result.I, result.r_norm_sqr):
        arr.block_until_ready()

In [81]:
%timeit time_solve()

310 ms ± 3.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [82]:
# Jit-compile the library solver. Positional arguments 2 and 3 (K and the
# optional 4th argument) are static, so each distinct combination of their
# values is compiled separately.
cosamp_solve  = jax.jit(cosamp.solve, static_argnums=(2,3))

In [83]:
# The first call with a given (K, 4th argument) pair triggers compilation.
sol = cosamp_solve(Phi, y, K, 3)
sol.r_norm_sqr, sol.iterations

(DeviceArray(2.3235905e-12, dtype=float32), DeviceArray(2, dtype=int32))

In [84]:
# Check the support recovered by the *jitted* solver (`sol`, from the cell above),
# not the earlier non-jitted `solution`. An empty result means every true
# support index was recovered.
jnp.setdiff1d(omega, sol.I)

DeviceArray([], dtype=int32)

In [85]:
def time_jit_solve():
    """Solve once with the jitted CoSaMP and wait for the results.

    Each returned array is blocked on, so that ``%timeit`` includes the
    device computation rather than only the dispatch.

    NOTE(review): unlike the earlier call ``cosamp_solve(Phi, y, K, 3)``, the
    optional 4th static argument is not passed here. This therefore times the
    solver's default for that argument, which is a separately compiled
    specialisation — confirm this is intended.
    """
    result = cosamp_solve(Phi, y, K)
    for arr in (result.x_I, result.r, result.I, result.r_norm_sqr):
        arr.block_until_ready()

In [86]:
%timeit time_jit_solve()

775 µs ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [88]:
# Speed-up of the jitted solver over the non-jitted one, using the %timeit
# results above: 310 ms (non-jitted) vs. 775 µs (jitted).
310 * 1000 / 775

400.0