
# JAX implementations

JAX exposes NumPy and SciPy implementations of the `eigh` function, but at the moment advanced options such as `eigvals` (computing only a selected subset of the eigenvalues) are not available.
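A simple workaround while such options are missing is to compute the full decomposition with `jnp.linalg.eigh` and slice out the eigenpairs you need. The sketch below assumes a dense symmetric input. `top_k_eigh` is a hypothetical helper, not a JAX API, and it still pays the full dense cost, which is exactly what iterative solvers such as LOBPCG avoid. In SciPy, the former `eigvals` option of `scipy.linalg.eigh` is now called `subset_by_index`.

```python
import jax.numpy as jnp

def top_k_eigh(a, k):
    """k largest eigenpairs of a symmetric matrix, largest first.

    Hypothetical helper: computes the full decomposition, then slices it.
    """
    w, v = jnp.linalg.eigh(a)      # eigenvalues in ascending order
    w, v = w[::-1], v[:, ::-1]     # reorder to descending
    return w[:k], v[:, :k]

a = jnp.diag(jnp.array([1.0, 5.0, 3.0, 2.0]))
vals, vecs = top_k_eigh(a, 2)      # eigenvalues 5 and 3 with their eigenvectors
```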

Here are some interesting links:
- [jackd/jju: Jack's Jax Utilities](https://github.com/jackd/jju)

In particular, it contains a JAX implementation of LOBPCG discussed above:
- https://github.com/jackd/jju/blob/master/jju/linalg/lobpcg/basic.py#L11
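To fix ideas, here is a minimal NumPy sketch of the basic LOBPCG iteration (no preconditioner, no constraints): at each step, a Rayleigh-Ritz projection onto the subspace spanned by the current block `X`, the residuals `R` and the previous search directions `P`. It is a simplified illustration written for this note, not the jju or SciPy code, and `lobpcg_sketch` is a hypothetical name.

```python
import numpy as np

def lobpcg_sketch(A, X, max_iters=500, tol=1e-8):
    """Hypothetical helper: basic block LOBPCG for the k largest eigenpairs of a
    symmetric A. Illustrative only: no preconditioner, no constraints Y, and a
    plain QR for orthonormalisation (not the jju or SciPy implementation)."""
    k = X.shape[1]
    X, _ = np.linalg.qr(X)                       # orthonormal starting block
    theta, C = np.linalg.eigh(X.T @ A @ X)       # initial Rayleigh-Ritz
    theta, X = theta[::-1], X @ C[:, ::-1]       # largest Ritz values first
    P = None                                     # no previous direction yet
    for _ in range(max_iters):
        R = A @ X - X * theta                    # residuals A x_i - theta_i x_i
        if np.linalg.norm(R, axis=0).max() < tol:
            break
        blocks = [X, R] if P is None else [X, R, P]
        Q, _ = np.linalg.qr(np.hstack(blocks))   # orthonormal basis of the trial subspace
        w, V = np.linalg.eigh(Q.T @ A @ Q)       # Rayleigh-Ritz on that subspace
        theta, V = w[::-1][:k], V[:, ::-1][:, :k]
        X_new = Q @ V
        P = X_new - X @ (X.T @ X_new)            # part of the step outside span(X)
        X = X_new
    return theta, X

# Diagonal test matrix with eigenvalues 1, ..., 50
A = np.diag(np.arange(1.0, 51.0))
rng = np.random.default_rng(0)
theta, X = lobpcg_sketch(A, rng.standard_normal((50, 3)))  # theta close to [50, 49, 48]
```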

Issue https://github.com/google/jax/issues/3112 on JAX's GitHub also discusses an implementation.

### Benchmarking against NumPy

I found this study, which compares JAX and NumPy algorithms:
https://towardsdatascience.com/turbocharging-svd-with-jax-749ae12f93af?gi=398628dbfc88

According to this study, in the end there is no big difference in speed between JAX and NumPy.
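Benchmarks of JAX code are easy to get wrong: the first call includes compilation, and JAX dispatches work asynchronously, so a timer can stop before the result exists. A minimal timing helper that handles both, assuming a recent JAX version where `jax.block_until_ready` is available (`time_jax` is a hypothetical name):

```python
import time
import jax
import jax.numpy as jnp

def time_jax(fn, *args, repeats=5):
    """Hypothetical helper: median wall time of fn(*args) in seconds,
    excluding the first call (which includes tracing and compilation)."""
    jax.block_until_ready(fn(*args))          # warm-up call
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))      # wait for asynchronous dispatch to finish
        times.append(time.perf_counter() - start)
    return sorted(times)[len(times) // 2]

x = jax.random.uniform(jax.random.PRNGKey(0), (100, 500))
svd_jit = jax.jit(lambda a: jnp.linalg.svd(a, full_matrices=False))
print(f"{time_jax(svd_jit, x) * 1e3:.1f} ms")
```

The same two precautions explain the `block_until_ready()` calls and the faster second runs in the notebook cells that follow.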


Other benchmarking results can be found in this study:
- [Separate your filters! Separability, SVD and low-rank approximation of 2D image processing filters | Bart Wronski](https://arxiv.org/pdf/2009.07542.pdf)

The recent paper below (2019) discusses which implementations of `eigh` can be recommended in differentiable pipelines, and the authors provide an implementation of differentiable SVD in PyTorch:
- [Differentiable Programming Tensor Networks | Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang](https://arxiv.org/abs/1903.09650)
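JAX's `jnp.linalg.eigh` is itself differentiable. As a small check, the gradient of a simple (non-degenerate) eigenvalue with respect to the matrix is the outer product `v v^T` of its unit eigenvector. Gradients of the eigenvectors, on the other hand, involve factors `1 / (lambda_i - lambda_j)`, which is why degenerate or nearly degenerate spectra are delicate in differentiable pipelines. A minimal sketch (`largest_eigenvalue` is an illustrative helper):

```python
import jax
import jax.numpy as jnp

def largest_eigenvalue(a):
    # Illustrative helper; jnp.linalg.eigh returns eigenvalues in ascending order
    return jnp.linalg.eigh(a)[0][-1]

a = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
grad = jax.grad(largest_eigenvalue)(a)

# For a simple eigenvalue, d(lambda)/dA = v v^T with v the unit eigenvector
w, v = jnp.linalg.eigh(a)
expected = jnp.outer(v[:, -1], v[:, -1])
print(jnp.allclose(grad, expected, atol=1e-5))
```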

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import jax
import jax.numpy as jnp
import scipy.sparse.linalg  # provides scipy.sparse.linalg.lobpcg
import numpy as onp

from jju.linalg.lobpcg.basic import lobpcg
from jju.linalg.lobpcg.utils import identity, rayleigh_ritz
from jju.types import as_array_fun



In [4]:
T = 100
N = 5000

In [5]:
rng = jax.random.PRNGKey(42)
X = jax.random.uniform(rng, (T, N))
W = X.T @ X / T
X0 = jax.random.uniform(rng, (N, 3))



## Test LOBPCG

In [7]:
%time ev, _ = scipy.sparse.linalg.lobpcg(onp.array(W).astype(onp.float64), onp.array(X0).astype(onp.float64))
ev

CPU times: user 1.25 s, sys: 312 ms, total: 1.56 s
Wall time: 443 ms


[6.63793672e-05 2.44824104e-03 9.55060172e-03]
not reaching the requested tolerance 0.00015811388300841897.


array([1254.49699212,    5.38454768,    5.27474219])

In [8]:
Y = None

In [10]:
rng = jax.random.PRNGKey(42)
X = jax.random.uniform(rng, (T, N))
W = X.T @ X / T
X0 = jax.random.uniform(rng, (N, 3))

%time ev, X1 = lobpcg(W, X0, None, None, None, True, None, None, 1000); X1.block_until_ready(); ev.block_until_ready()
ev

CPU times: user 2.29 s, sys: 214 ms, total: 2.5 s
Wall time: 787 ms


DeviceArray([1254.4971   ,    5.38453  ,    5.2748566], dtype=float32)

The second run is faster, presumably because compilation already happened during the first call:

In [11]:
%time ev, X1 = lobpcg(W, X0, None, None, None, True, None, None, 1000); X1.block_until_ready(); ev.block_until_ready()
ev

CPU times: user 1.87 s, sys: 204 ms, total: 2.07 s
Wall time: 304 ms


DeviceArray([1254.4971   ,    5.38453  ,    5.2748566], dtype=float32)

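### Warm start

Restarting from the eigenvectors `X1` of the previous run brings the wall time down from about 300 ms to about 28 ms: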

In [12]:
%time ev, X1 = lobpcg(W, X1, None, None, None, True, None, None, 1000); X1.block_until_ready(); ev.block_until_ready()

CPU times: user 145 ms, sys: 10.7 ms, total: 156 ms
Wall time: 27.6 ms


DeviceArray([1254.4971   ,    5.3845315,    5.2748566], dtype=float32)
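The same effect can be reproduced with SciPy's `lobpcg`, which returns the history of Ritz values when called with `retLambdaHistory=True`. A sketch on a diagonal test matrix (chosen only for illustration): the warm-started run, initialized with the eigenvectors of the first run, should record far fewer iterations.

```python
import numpy as np
from scipy.sparse.linalg import lobpcg

rng = np.random.default_rng(0)
n, k = 100, 3
A = np.diag(np.arange(1.0, n + 1.0))   # test matrix: eigenvalues 1, ..., 100
X0 = rng.standard_normal((n, k))

# Cold start from a random block
w_cold, v_cold, hist_cold = lobpcg(A, X0, largest=True, maxiter=500,
                                   retLambdaHistory=True)

# Warm start from the eigenvectors returned by the cold run
w_warm, v_warm, hist_warm = lobpcg(A, v_cold, largest=True, maxiter=500,
                                   retLambdaHistory=True)

# The histories record the Ritz values along the iterations
print(len(hist_cold), len(hist_warm))
```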

## Transpose approach
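The idea behind this approach: with `X` of shape `(T, N)` and `T` much smaller than `N`, the `N x N` matrix `X.T @ X / T` has rank at most `T`, and its nonzero eigenvalues coincide with those of the `T x T` matrix `X @ X.T / T`. If `u` is a unit eigenvector of the small matrix with eigenvalue `lambda`, then `X.T @ u / sqrt(T * lambda)` is a unit eigenvector of the large one. A small NumPy check of these identities (the sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
T, N = 5, 40
X = rng.uniform(size=(T, N))

W_big = X.T @ X / T      # N x N, rank <= T
W_small = X @ X.T / T    # T x T, same nonzero eigenvalues

lam_big = np.linalg.eigvalsh(W_big)[::-1][:T]   # T largest eigenvalues, descending
lam_small, U = np.linalg.eigh(W_small)
lam_small, U = lam_small[::-1], U[:, ::-1]      # descending order

# Map the T-dimensional eigenvectors to N-dimensional unit eigenvectors
V = X.T @ U / np.sqrt(T * lam_small)
```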

In [13]:
rng = jax.random.PRNGKey(42)
X = jax.random.uniform(rng, (T, N))
Wsmall = X @ X.T / T
X0 = jax.random.normal(rng, (T, 3))

%time ev, X1 = lobpcg(Wsmall, X0, None, None, None, True, None, None, 1000); X1.block_until_ready(); ev.block_until_ready()
ev

CPU times: user 378 ms, sys: 3.71 ms, total: 381 ms
Wall time: 380 ms


DeviceArray([1254.4971   ,    5.3842063,    5.2675486], dtype=float32)

In [14]:
%time ev, X1 = lobpcg(Wsmall, X0, None, None, None, True, None, None, 1000); X1.block_until_ready(); ev.block_until_ready()

CPU times: user 2.5 ms, sys: 1.08 ms, total: 3.57 ms
Wall time: 2.03 ms


DeviceArray([1254.4971   ,    5.3842063,    5.2675486], dtype=float32)

In [15]:
nx = len(ev)  # number of computed eigenpairs

In [17]:
%time Uhat, s, V = jax.scipy.linalg.svd(X); Uhat.block_until_ready();

CPU times: user 4.68 s, sys: 736 ms, total: 5.42 s
Wall time: 993 ms


DeviceArray([[-0.10014337, -0.08351295,  0.1439427 , ..., -0.05296591,
               0.05243376, -0.14207326],
             [-0.10096472, -0.04740001, -0.16419022, ...,  0.20665328,
               0.04076387,  0.20304883],
             [-0.09967631, -0.17876193, -0.10932072, ...,  0.04042136,
               0.13983932, -0.13899142],
             ...,
             [-0.09967086,  0.12509172, -0.00724508, ...,  0.02378564,
              -0.04695617, -0.18081056],
             [-0.09961274,  0.01550005,  0.12577657, ...,  0.02889361,
               0.1688753 ,  0.0448981 ],
             [-0.09964427, -0.04367123, -0.07992887, ..., -0.08170576,
              -0.0716786 ,  0.11668825]], dtype=float32)

In [18]:
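Q = X1                # T x nx eigenvectors of Wsmall = X X^T / T (left singular vectors of X)
B = Q.T @ X           # project X onto them: nx x N
# full_matrices=False would avoid forming the N x N matrix of right singular vectors
%time Uhat, s, V = jax.scipy.linalg.svd(B); Uhat.block_until_ready()
U = Q @ Uhat          # left singular vectors of X (T x nx)
U.shape
U = V[:nx].T          # right singular vectors of X (N x nx) = eigenvectors of X^T X / T
U.T @ U               # check orthonormality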

CPU times: user 339 ms, sys: 115 ms, total: 454 ms
Wall time: 90.2 ms


DeviceArray([[ 1.0000007e+00, -5.4685456e-09,  3.6529713e-09],
             [-5.4685456e-09,  1.0000014e+00,  4.6366072e-08],
             [ 3.6529713e-09,  4.6366072e-08,  1.0000000e+00]],            dtype=float32)

In [19]:
Cov = X.T @ X / T
jnp.diag(U.T @ Cov @ U)  # Rayleigh quotients, to compare with ev

DeviceArray([1254.4977  ,    5.384249,    5.267863], dtype=float32)

In [20]:
ev

DeviceArray([1254.4971   ,    5.3842063,    5.2675486], dtype=float32)
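Cell In [18] uses the projection step of randomized SVD (Halko, Martinsson and Tropp): if the columns of `Q` are orthonormal and span the top-k left singular subspace of `X`, the SVD of the small `k x N` matrix `B = Q.T @ X` gives the top-k singular values of `X`, its right singular vectors and, through `Q @ Uhat`, its left singular vectors. The right singular vectors are eigenvectors of `X.T @ X / T` with eigenvalues `s**2 / T`, which is what In [19] checks. A NumPy sketch, where `Q` is taken from a full SVD for illustration instead of from LOBPCG:

```python
import numpy as np

rng = np.random.default_rng(0)
T, N, k = 20, 300, 3
X = rng.standard_normal((T, N))

# Q: orthonormal basis of the top-k left singular subspace. Taken from a full
# SVD here for illustration; in the notebook it comes from LOBPCG on X X^T / T.
U_full, s_full, Vt_full = np.linalg.svd(X, full_matrices=False)
Q = U_full[:, :k]

B = Q.T @ X                                   # small k x N matrix
Uhat, s, Vt = np.linalg.svd(B, full_matrices=False)
U = Q @ Uhat                                  # left singular vectors of X (T x k)
V = Vt.T                                      # right singular vectors of X (N x k)
```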

### Test with constraints
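In SciPy's `lobpcg`, `Y` is a block of constraints: the iterations are performed in the orthogonal complement of the column space of `Y`. With `Y = eye(N, 5)` this amounts to solving the eigenproblem with the first five coordinates removed, so the first five rows of the eigenvectors should be zero. The notebook passes the same `Y` to jju's `lobpcg`; whether jju handles it identically is not verified here. A SciPy sketch on a diagonal test matrix (an assumption made so the constrained answer is known in advance):

```python
import numpy as np
from scipy.sparse.linalg import lobpcg

rng = np.random.default_rng(0)
n, k = 100, 3
A = np.diag(np.arange(float(n), 0.0, -1.0))   # eigenvalues 100, 99, ..., 1
Y = np.eye(n, 5)                              # constrain out the first 5 coordinates
X0 = rng.standard_normal((n, k))

w, V = lobpcg(A, X0, Y=Y, largest=True, maxiter=500)
w = np.sort(w)[::-1]

# Same problem solved densely: drop the constrained rows and columns
w_dense = np.linalg.eigvalsh(A[5:, 5:])[::-1][:k]
print(w, w_dense)  # both close to [95, 94, 93] instead of [100, 99, 98]
```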

In [222]:
N, T = 5000, 100
rng = jax.random.PRNGKey(42)
def generate_wishart(N=1000, T=1100):
    X = jax.random.normal(rng, (T, N))
    W = X.T @ X / T
    return W, X
W, X = generate_wishart(N, T)
X0 = jax.random.normal(rng, (N, 3))
Y = jnp.eye(N, 5)

In [223]:
Y

DeviceArray([[1., 0., 0., 0., 0.],
             [0., 1., 0., 0., 0.],
             [0., 0., 1., 0., 0.],
             ...,
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.]], dtype=float32)

In [224]:
invA = jnp.diag(1.0 / (X**2).mean(0))  # inverse of diag(W): a diagonal (Jacobi) preconditioner

In [249]:
# to try the diagonal preconditioner, add iK=invA to the lobpcg call
%time ev, X1 = lobpcg(jnp.array(W), jnp.array(X0), Y=jnp.array(Y), largest=True, max_iters=39); ev.block_until_ready()
X1

CPU times: user 2.7 s, sys: 265 ms, total: 2.96 s
Wall time: 534 ms


DeviceArray([[-3.6445203e-05, -4.7590547e-05,  7.0459268e-05],
             [ 7.1704976e-06,  5.9620838e-04, -6.5635063e-04],
             [-7.7927485e-05, -7.5356552e-04,  8.6425769e-04],
             ...,
             [ 1.4446663e-02,  3.2763463e-03,  1.8101221e-02],
             [ 3.2827254e-02, -6.8336604e-03, -9.4931778e-03],
             [ 1.6918244e-02,  8.6720763e-03, -8.5902298e-03]],            dtype=float32)

In [230]:
ev

array([282.38647,  65.00391,  63.80316], dtype=float32)

Starting from `max_iters=40`, the float32 run returns NaNs.
This has to be fixed.
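While investigating, one way to find where the NaNs first appear is JAX's `jax_debug_nans` option, which makes JAX raise a `FloatingPointError` at the first operation that produces a NaN. A minimal wrapper that enables the flag for a single call (`run_with_nan_check` is a hypothetical helper):

```python
import jax
import jax.numpy as jnp

def run_with_nan_check(fn, *args):
    """Hypothetical helper: run fn(*args) with jax_debug_nans enabled.

    Returns (result, None) on success, or (None, error message) if JAX
    raised FloatingPointError because some operation produced a NaN.
    """
    jax.config.update("jax_debug_nans", True)
    try:
        return fn(*args), None
    except FloatingPointError as err:
        return None, str(err)
    finally:
        jax.config.update("jax_debug_nans", False)   # restore the default

ok, err_ok = run_with_nan_check(jnp.sqrt, jnp.array(4.0))
bad, err_bad = run_with_nan_check(lambda x: jnp.sqrt(x - 2.0), jnp.array(1.0))
print(err_ok is None, err_bad is not None)
```

The failing solver call can be wrapped the same way, inside a `lambda` with no arguments.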

In [250]:
# jax.config.update("jax_enable_x64", False)
# to try the diagonal preconditioner, add iK=invA to the lobpcg call
%time ev, X1 = lobpcg(jnp.array(W, "float32"), jnp.array(X0, "float32"), Y=jnp.array(Y, "float32"), largest=True, max_iters=40); ev.block_until_ready()
X1

CPU times: user 2.92 s, sys: 302 ms, total: 3.22 s
Wall time: 512 ms


DeviceArray([[nan, nan, nan],
             [nan, nan, nan],
             [nan, nan, nan],
             ...,
             [nan, nan, nan],
             [nan, nan, nan],
             [nan, nan, nan]], dtype=float32)

In [268]:
jax.config.update("jax_enable_x64", True)
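JAX uses 32-bit floats by default: without `jax_enable_x64`, a request such as `jnp.array(W, "float64")` is silently truncated to float32 (with a warning), which is why the flag is needed before the float64 run. A minimal check that the flag takes effect:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Best done at startup: arrays created before this call keep their float32 dtype
jax.config.update("jax_enable_x64", True)

W64 = jnp.array(np.eye(3), dtype=jnp.float64)
print(W64.dtype)            # float64
print(jnp.ones(2).dtype)    # float64: the default float type follows the flag
```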

In [279]:
%%time
# to try the diagonal preconditioner, add iK=invA to the lobpcg call
ev, X1 = lobpcg(jnp.array(W, "float64"), jnp.array(X0, "float64"), Y=jnp.array(Y, "float64"), largest=True, max_iters=40)
ev.block_until_ready()
X1.block_until_ready()

CPU times: user 5.59 s, sys: 664 ms, total: 6.25 s
Wall time: 1.06 s


DeviceArray([[ 5.46214083e-05,  6.46505350e-05,  2.53362722e-05],
             [ 2.68729187e-05,  1.74906617e-04, -1.66615759e-04],
             [-1.01086227e-04, -5.71977816e-04,  4.82890250e-04],
             ...,
             [-1.44547315e-02,  3.21703661e-03,  1.80874004e-02],
             [-3.28315158e-02, -6.82441543e-03, -9.45072788e-03],
             [-1.69018512e-02,  8.80029127e-03, -8.50324976e-03]],            dtype=float64)

In [280]:
ev

DeviceArray([65.00406729, 63.8051286 , 63.10346073], dtype=float64)

## SciPy run

Let's compare with what SciPy gives.

In [227]:
from scipy.sparse.linalg import LinearOperator
from scipy.sparse import issparse, spdiags

In [231]:
if False:  # disabled: alternative data generation with NumPy, kept for reference
    import numpy as np
    print(N, T)
    N, T = 5000, 100

    def generate_wishart(N=1000, T=1100):
        X = np.random.randn(T, N)
        W = X.T @ X / T
        return W, X
    W, X = generate_wishart(N, T)
    rng = np.random.default_rng()
    X0 = rng.random((N, 3))
    Y = np.eye(N, 5)

In [260]:
%time ev, X1 = scipy.sparse.linalg.lobpcg(onp.array(W), X0, Y=Y, largest=True, maxiter=39)
X1

CPU times: user 2.19 s, sys: 486 ms, total: 2.68 s
Wall time: 433 ms


[6.7431797e+03 1.4651953e+00 2.0974681e+00]
not reaching the requested tolerance 0.00015811388300841897.


array([[ 4.4167793e-04,  9.9974408e-05, -3.4029788e-04],
       [-2.5614346e-03, -6.0712595e-05,  3.0222326e-04],
       [-1.5623022e-02,  8.7821954e-06, -1.2032720e-04],
       ...,
       [-8.3930582e-02, -1.4465491e-02,  3.2491998e-03],
       [-4.1133054e-02, -3.2832053e-02, -6.8052039e-03],
       [ 2.8422205e-02, -1.6905002e-02,  8.7821446e-03]], dtype=float32)

In [257]:
ev

array([1451.7709  ,   65.003944,   63.80353 ], dtype=float32)

In [266]:
%time ev, X1 = scipy.sparse.linalg.lobpcg(onp.array(W, float), onp.array(X0, float), Y=onp.array(Y, float), largest=True, maxiter=39)
X1

CPU times: user 3.13 s, sys: 409 ms, total: 3.54 s
Wall time: 559 ms


[1.46537022 2.10026986 1.47411062]
not reaching the requested tolerance 0.00015811388300841897.


array([[ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       ...,
       [-0.01444526,  0.00331638,  0.01821111],
       [-0.03282786, -0.00683802, -0.0093886 ],
       [-0.01691743,  0.0087028 , -0.00847153]])

In [267]:
ev

array([65.00404136, 63.80249923, 63.10427344])

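The computation converges better in float64. The leading value 1451.77 returned by the float32 run disappears, the displayed leading rows of `X1` (the coordinates constrained by `Y`) are exactly zero, and the eigenvalues (65.004, 63.802, 63.104) agree with the float64 JAX run (65.004, 63.805, 63.103). Both SciPy runs still stop before reaching the requested tolerance.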
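To make this comparison quantitative rather than visual, one can compute, for each returned pair, the relative residual `||A v - lambda v|| / (|lambda| ||v||)` and, when constraints are used, the largest entry of `|Y.T @ V|`. The helper below is a hypothetical sketch; it evaluates everything in float64 and assumes no returned eigenvalue is exactly zero.

```python
import numpy as np

def check_eigenpairs(A, w, V, Y=None):
    """Hypothetical diagnostic helper (computed in float64).

    Returns the relative residual of each pair (w[i], V[:, i]) and, if Y is
    given, the largest |Y^T V| entry (0 means the constraints hold exactly).
    """
    A = np.asarray(A, dtype=np.float64)
    w = np.asarray(w, dtype=np.float64)
    V = np.asarray(V, dtype=np.float64)
    R = A @ V - V * w                                  # residual block
    rel_res = np.linalg.norm(R, axis=0) / (np.abs(w) * np.linalg.norm(V, axis=0))
    violation = None if Y is None else np.abs(np.asarray(Y, dtype=np.float64).T @ V).max()
    return rel_res, violation

# Exact eigenpairs of a diagonal matrix give zero residuals and no violation
A = np.diag([4.0, 3.0, 2.0, 1.0])
rel, viol = check_eigenpairs(A, [4.0, 3.0], np.eye(4)[:, :2], Y=np.eye(4)[:, 2:])
```

Applied to the runs above, this would be, for instance, `check_eigenpairs(onp.array(W), ev, X1, onp.array(Y))`.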