In [1]:
# Enable 64-bit types in JAX. Set before `import jax` (next cell) so JAX picks it up;
# the *_f64 / *_c128 wrappers below build float64 / complex128 buffers.
%env JAX_ENABLE_X64=1

env: JAX_ENABLE_X64=1


## Imports

In [2]:
from pathlib import Path

import jax
import jax.extend as jex
import jax.numpy as jnp
import jax.scipy as jsp
import klujax_cpp
import numpy as np

## solve_f64

In [3]:
# Register the handler returned by `klujax_cpp.solve_f64()` as the CPU FFI target
# "solve_f64" -- the name the `solve_f64` wrapper passes to `ffi_call`.
jex.ffi.register_ffi_target("solve_f64", klujax_cpp.solve_f64(), platform="cpu")

In [4]:
def solve_f64(Ai, Aj, Ax, b):
    """Solve the sparse real system(s) ``A @ x = b`` via the ``solve_f64`` FFI target.

    ``A`` is ``n_col x n_col`` in COO form: ``Ai[k]``/``Aj[k]`` are the row/column
    of the k-th stored value ``Ax[..., k]``. The pattern ``(Ai, Aj)`` is shared by
    every matrix in the batch; only ``Ax`` may carry batch dimensions. The
    notebook checks the result against the dense ``zeros(...).at[Ai, Aj].add(Ax)``.

    Args:
        Ai: Row indices, shape ``(n_nz,)``. Converted to int32.
        Aj: Column indices, shape ``(n_nz,)``. Converted to int32.
        Ax: Stored values, shape ``(*batch_A, n_nz)``. Converted to float64.
        b: Right-hand side(s), shape ``(*batch_b, n_col, n_rhs)``, or a single
            vector of shape ``(n_col,)``. Converted to float64.

    Returns:
        ``x`` of shape ``(*batch, n_col, n_rhs)`` (``(*batch, n_col)`` for a 1-D
        ``b``), where ``batch`` is ``batch_A`` broadcast against ``batch_b``.

    Raises:
        ValueError: If ``Ai``/``Aj`` are not 1-D arrays of length ``n_nz``, or if
            the batch shapes of ``Ax`` and ``b`` cannot be broadcast together.

    Requires ``JAX_ENABLE_X64=1``; otherwise JAX truncates float64 to float32.
    """
    Ai = jnp.asarray(Ai, dtype=jnp.int32)
    Aj = jnp.asarray(Aj, dtype=jnp.int32)
    Ax = jnp.asarray(Ax, dtype=jnp.float64)
    b = jnp.asarray(b, dtype=jnp.float64)

    if b.ndim == 1:
        # A single vector: solve it as a one-column right-hand side.
        return solve_f64(Ai, Aj, Ax, b[:, None])[..., 0]

    *batch_A, n_nz = Ax.shape
    *batch_b, n_col, n_rhs = b.shape
    if Ai.shape != (n_nz,) or Aj.shape != (n_nz,):
        raise ValueError(
            f"Ai and Aj must have shape ({n_nz},) to match Ax, "
            f"got {Ai.shape} and {Aj.shape}"
        )

    # The kernel gets a single n_lhs and an output buffer shaped like _b, so Ax
    # and b must share the same flattened leading batch axis of size n_lhs.
    batch = np.broadcast_shapes(tuple(batch_A), tuple(batch_b))
    n_lhs = int(np.prod(batch, dtype=np.int64))
    _Ax = jnp.broadcast_to(Ax, (*batch, n_nz)).reshape(n_lhs, n_nz)
    _b = jnp.broadcast_to(b, (*batch, n_col, n_rhs)).reshape(n_lhs, n_col, n_rhs)

    call = jex.ffi.ffi_call(
        "solve_f64",
        jax.ShapeDtypeStruct(_b.shape, _b.dtype),
        # The size attributes are fixed at trace time. "broadcast_all" would add a
        # vmap axis to every buffer (including the shared Ai/Aj) while n_lhs still
        # describes one call; presumably the kernel loops over the attribute sizes.
        # "sequential" invokes the kernel per vmapped element with traced shapes.
        vmap_method="sequential",
    )
    x = call(
        Ai,
        Aj,
        _Ax,
        _b,
        n_col=np.int32(n_col),
        n_rhs=np.int32(n_rhs),
        n_lhs=np.int32(n_lhs),
        n_nz=np.int32(n_nz),
    )
    return x.reshape(*batch, n_col, n_rhs)

### Test Single

In [5]:
n_nz = 8
n_col = 5
n_rhs = 1
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, (n_nz,))
Ai = jax.random.randint(Aikey, (n_nz,), 0, n_col, jnp.int32)
Aj = jax.random.randint(Ajkey, (n_nz,), 0, n_col, jnp.int32)
b = jax.random.normal(bkey, (n_col, n_rhs))
x_sp = solve_f64(Ai, Aj, Ax, b)

# Dense reference: scatter-add the COO entries into a full matrix and solve.
A = jnp.zeros((n_col, n_col), dtype=jnp.float64).at[Ai, Aj].add(Ax)
x = jsp.linalg.solve(A, b)

# Fail loudly on a regression instead of relying on eyeballing the residual.
np.testing.assert_allclose(x_sp, x, rtol=1e-10, atol=1e-12)
x_sp - x

Array([[-1.11022302e-16],
       [ 0.00000000e+00],
       [ 8.88178420e-16],
       [-4.44089210e-16],
       [ 1.77635684e-15]], dtype=float64)

### Test Batched

In [6]:
n_nz = 8
n_col = 5
n_rhs = 1
n_lhs = 3

Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, (n_lhs, n_nz))
Ai = jax.random.randint(Aikey, (n_nz,), 0, n_col, jnp.int32)
Aj = jax.random.randint(Ajkey, (n_nz,), 0, n_col, jnp.int32)
b = jax.random.normal(bkey, (n_lhs, n_col, n_rhs))
x_sp = solve_f64(Ai, Aj, Ax, b)

# Dense reference: one scatter-added matrix per batch entry (shared pattern Ai, Aj).
# NOTE(review): float64 would suffice for this real-valued test; complex128 is
# kept so the saved output of this cell still matches a re-run.
A = jnp.zeros((n_lhs, n_col, n_col), dtype=jnp.complex128).at[:, Ai, Aj].add(Ax)
x = jax.vmap(jsp.linalg.solve, (0, 0), 0)(A, b)

# Fail loudly on a regression instead of relying on eyeballing the residual.
np.testing.assert_allclose(x_sp, x, rtol=1e-10, atol=1e-12)
x_sp - x

Array([[[ 0.00000000e+00+0.j],
        [ 0.00000000e+00+0.j],
        [ 0.00000000e+00+0.j],
        [ 1.11022302e-16+0.j],
        [ 8.88178420e-16+0.j]],

       [[ 4.44089210e-16+0.j],
        [ 0.00000000e+00+0.j],
        [-2.22044605e-16+0.j],
        [ 0.00000000e+00+0.j],
        [ 4.44089210e-16+0.j]],

       [[ 0.00000000e+00+0.j],
        [ 0.00000000e+00+0.j],
        [ 1.38777878e-17+0.j],
        [ 0.00000000e+00+0.j],
        [ 0.00000000e+00+0.j]]], dtype=complex128)

## coo_mul_vec_f64

In [7]:
# Register the handler returned by `klujax_cpp.coo_mul_vec_f64()` as the CPU FFI
# target "coo_mul_vec_f64" -- the name the `coo_mul_vec_f64` wrapper passes to `ffi_call`.
jex.ffi.register_ffi_target("coo_mul_vec_f64", klujax_cpp.coo_mul_vec_f64(), platform="cpu")

In [8]:
def coo_mul_vec_f64(Ai, Aj, Ax, x):
    """Compute the sparse real product(s) ``b = A @ x`` via the ``coo_mul_vec_f64`` FFI target.

    ``A`` is ``n_col x n_col`` in COO form: ``Ai[k]``/``Aj[k]`` are the row/column
    of the k-th stored value ``Ax[..., k]``. The pattern ``(Ai, Aj)`` is shared by
    every matrix in the batch; only ``Ax`` may carry batch dimensions. The
    notebook checks the result against ``zeros(...).at[Ai, Aj].add(Ax) @ x``.

    Args:
        Ai: Row indices, shape ``(n_nz,)``. Converted to int32.
        Aj: Column indices, shape ``(n_nz,)``. Converted to int32.
        Ax: Stored values, shape ``(*batch_A, n_nz)``. Converted to float64.
        x: Vector(s) to multiply, shape ``(*batch_x, n_col, n_rhs)``, or a single
            vector of shape ``(n_col,)``. Converted to float64.

    Returns:
        ``b`` of shape ``(*batch, n_col, n_rhs)`` (``(*batch, n_col)`` for a 1-D
        ``x``), where ``batch`` is ``batch_A`` broadcast against ``batch_x``.

    Raises:
        ValueError: If ``Ai``/``Aj`` are not 1-D arrays of length ``n_nz``, or if
            the batch shapes of ``Ax`` and ``x`` cannot be broadcast together.

    Requires ``JAX_ENABLE_X64=1``; otherwise JAX truncates float64 to float32.
    """
    Ai = jnp.asarray(Ai, dtype=jnp.int32)
    Aj = jnp.asarray(Aj, dtype=jnp.int32)
    Ax = jnp.asarray(Ax, dtype=jnp.float64)
    x = jnp.asarray(x, dtype=jnp.float64)

    if x.ndim == 1:
        # A single vector: multiply it as a one-column block.
        return coo_mul_vec_f64(Ai, Aj, Ax, x[:, None])[..., 0]

    *batch_A, n_nz = Ax.shape
    *batch_x, n_col, n_rhs = x.shape
    if Ai.shape != (n_nz,) or Aj.shape != (n_nz,):
        raise ValueError(
            f"Ai and Aj must have shape ({n_nz},) to match Ax, "
            f"got {Ai.shape} and {Aj.shape}"
        )

    # The kernel gets a single n_lhs and an output buffer shaped like _x, so Ax
    # and x must share the same flattened leading batch axis of size n_lhs.
    batch = np.broadcast_shapes(tuple(batch_A), tuple(batch_x))
    n_lhs = int(np.prod(batch, dtype=np.int64))
    _Ax = jnp.broadcast_to(Ax, (*batch, n_nz)).reshape(n_lhs, n_nz)
    _x = jnp.broadcast_to(x, (*batch, n_col, n_rhs)).reshape(n_lhs, n_col, n_rhs)

    call = jex.ffi.ffi_call(
        "coo_mul_vec_f64",
        jax.ShapeDtypeStruct(_x.shape, _x.dtype),
        # The size attributes are fixed at trace time. "broadcast_all" would add a
        # vmap axis to every buffer (including the shared Ai/Aj) while n_lhs still
        # describes one call; presumably the kernel loops over the attribute sizes.
        # "sequential" invokes the kernel per vmapped element with traced shapes.
        vmap_method="sequential",
    )
    b = call(
        Ai,
        Aj,
        _Ax,
        _x,
        n_col=np.int32(n_col),
        n_rhs=np.int32(n_rhs),
        n_lhs=np.int32(n_lhs),
        n_nz=np.int32(n_nz),
    )
    return b.reshape(*batch, n_col, n_rhs)

### Test Single

In [9]:
n_nz = 8
n_col = 5
n_rhs = 1
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, (n_nz,))
Ai = jax.random.randint(Aikey, (n_nz,), 0, n_col, jnp.int32)
Aj = jax.random.randint(Ajkey, (n_nz,), 0, n_col, jnp.int32)
x = jax.random.normal(bkey, (n_col, n_rhs))
b_sp = coo_mul_vec_f64(Ai, Aj, Ax, x)

# Dense reference: scatter-add the COO entries into a full matrix and multiply.
A = jnp.zeros((n_col, n_col), dtype=jnp.float64).at[Ai, Aj].add(Ax)
b = A @ x

# Fail loudly on a regression instead of relying on eyeballing the residual.
np.testing.assert_allclose(b_sp, b, rtol=1e-10, atol=1e-12)
b_sp - b

Array([[0.],
       [0.],
       [0.],
       [0.],
       [0.]], dtype=float64)

### Test Batched

In [10]:
n_nz = 8
n_col = 5
n_rhs = 1
n_lhs = 3
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, (n_lhs, n_nz))
Ai = jax.random.randint(Aikey, (n_nz,), 0, n_col, jnp.int32)
Aj = jax.random.randint(Ajkey, (n_nz,), 0, n_col, jnp.int32)
x = jax.random.normal(bkey, (n_lhs, n_col, n_rhs))

b_sp = coo_mul_vec_f64(Ai, Aj, Ax, x)

# Dense reference: one scatter-added matrix per batch entry (shared pattern Ai, Aj).
# NOTE(review): float64 would suffice for this real-valued test; complex128 is
# kept so the saved output of this cell still matches a re-run.
A = jnp.zeros((n_lhs, n_col, n_col), dtype=jnp.complex128).at[:, Ai, Aj].add(Ax)
b = jnp.einsum("bij,bjk->bik", A, x)

# Fail loudly on a regression instead of relying on eyeballing the residual.
np.testing.assert_allclose(b_sp, b, rtol=1e-10, atol=1e-12)
b_sp - b

Array([[[0.00000000e+00+0.j],
        [0.00000000e+00+0.j],
        [0.00000000e+00+0.j],
        [0.00000000e+00+0.j],
        [0.00000000e+00+0.j]],

       [[0.00000000e+00+0.j],
        [0.00000000e+00+0.j],
        [2.22044605e-16+0.j],
        [0.00000000e+00+0.j],
        [0.00000000e+00+0.j]],

       [[0.00000000e+00+0.j],
        [0.00000000e+00+0.j],
        [4.44089210e-16+0.j],
        [0.00000000e+00+0.j],
        [0.00000000e+00+0.j]]], dtype=complex128)

## solve_c128

In [11]:
# Register the handler returned by `klujax_cpp.solve_c128()` as the CPU FFI target
# "solve_c128" -- the name the `solve_c128` wrapper passes to `ffi_call`.
jex.ffi.register_ffi_target("solve_c128", klujax_cpp.solve_c128(), platform="cpu")

In [12]:
def solve_c128(Ai, Aj, Ax, b):
    """Solve the sparse complex system(s) ``A @ x = b`` via the ``solve_c128`` FFI target.

    ``A`` is ``n_col x n_col`` in COO form: ``Ai[k]``/``Aj[k]`` are the row/column
    of the k-th stored value ``Ax[..., k]``. The pattern ``(Ai, Aj)`` is shared by
    every matrix in the batch; only ``Ax`` may carry batch dimensions. The
    notebook checks the result against the dense ``zeros(...).at[Ai, Aj].add(Ax)``.

    Args:
        Ai: Row indices, shape ``(n_nz,)``. Converted to int32.
        Aj: Column indices, shape ``(n_nz,)``. Converted to int32.
        Ax: Stored values, shape ``(*batch_A, n_nz)``. Converted to complex128
            (real inputs are accepted).
        b: Right-hand side(s), shape ``(*batch_b, n_col, n_rhs)``, or a single
            vector of shape ``(n_col,)``. Converted to complex128.

    Returns:
        complex128 ``x`` of shape ``(*batch, n_col, n_rhs)`` (``(*batch, n_col)``
        for a 1-D ``b``), where ``batch`` is ``batch_A`` broadcast against ``batch_b``.

    Raises:
        ValueError: If ``Ai``/``Aj`` are not 1-D arrays of length ``n_nz``, or if
            the batch shapes of ``Ax`` and ``b`` cannot be broadcast together.

    Requires ``JAX_ENABLE_X64=1``; otherwise JAX truncates complex128 to complex64.
    """
    Ai = jnp.asarray(Ai, dtype=jnp.int32)
    Aj = jnp.asarray(Aj, dtype=jnp.int32)
    # Promote to complex128 *before* the float64 view below. Viewing a real
    # float64 array is a no-op (n values instead of 2*n), and complex64 would
    # yield float32 pairs, so the kernel would get buffers of the wrong size or
    # precision.
    Ax = jnp.asarray(Ax, dtype=jnp.complex128)
    b = jnp.asarray(b, dtype=jnp.complex128)

    if b.ndim == 1:
        # A single vector: solve it as a one-column right-hand side.
        return solve_c128(Ai, Aj, Ax, b[:, None])[..., 0]

    *batch_A, n_nz = Ax.shape
    *batch_b, n_col, n_rhs = b.shape
    if Ai.shape != (n_nz,) or Aj.shape != (n_nz,):
        raise ValueError(
            f"Ai and Aj must have shape ({n_nz},) to match Ax, "
            f"got {Ai.shape} and {Aj.shape}"
        )

    # The kernel gets a single n_lhs and an output buffer shaped like _b, so Ax
    # and b must share the same flattened leading batch axis of size n_lhs.
    batch = np.broadcast_shapes(tuple(batch_A), tuple(batch_b))
    n_lhs = int(np.prod(batch, dtype=np.int64))
    # Viewing complex128 as float64 interleaves (real, imag) along the last axis:
    # _Ax becomes (n_lhs, 2*n_nz) and _b becomes (n_lhs, n_col, 2*n_rhs).
    _Ax = jnp.broadcast_to(Ax, (*batch, n_nz)).reshape(n_lhs, n_nz).view(np.float64)
    _b = (
        jnp.broadcast_to(b, (*batch, n_col, n_rhs))
        .reshape(n_lhs, n_col, n_rhs)
        .view(np.float64)
    )

    call = jex.ffi.ffi_call(
        "solve_c128",
        jax.ShapeDtypeStruct(_b.shape, _b.dtype),
        # The size attributes are fixed at trace time. "broadcast_all" would add a
        # vmap axis to every buffer (including the shared Ai/Aj) while n_lhs still
        # describes one call; presumably the kernel loops over the attribute sizes.
        # "sequential" invokes the kernel per vmapped element with traced shapes.
        vmap_method="sequential",
    )
    x = call(
        Ai,
        Aj,
        _Ax,
        _b,
        n_col=np.int32(n_col),
        n_rhs=np.int32(n_rhs),
        n_lhs=np.int32(n_lhs),
        n_nz=np.int32(n_nz),
    )
    return x.view(jnp.complex128).reshape(*batch, n_col, n_rhs)

### Test Single

In [13]:
n_nz = 8
n_col = 5
n_rhs = 1
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
# Complex values assembled from independent real and imaginary normal draws.
Ax_r, Ax_i = jax.random.normal(Axkey, (2, n_nz))
Ax = Ax_r + 1j * Ax_i
Ai = jax.random.randint(Aikey, (n_nz,), 0, n_col, jnp.int32)
Aj = jax.random.randint(Ajkey, (n_nz,), 0, n_col, jnp.int32)
b_r, b_i = jax.random.normal(bkey, (2, n_col, n_rhs))
b = b_r + 1j * b_i
x_sp = solve_c128(Ai, Aj, Ax, b)

# Dense reference: scatter-add the COO entries into a full matrix and solve.
A = jnp.zeros((n_col, n_col), dtype=jnp.complex128).at[Ai, Aj].add(Ax)
x = jsp.linalg.solve(A, b)

# Fail loudly on a regression instead of relying on eyeballing the residual.
np.testing.assert_allclose(x_sp, x, rtol=1e-10, atol=1e-12)
x_sp - x

Array([[ 0.00000000e+00+0.00000000e+00j],
       [-4.16333634e-17-1.11022302e-16j],
       [ 0.00000000e+00-5.55111512e-17j],
       [-1.66533454e-16+4.44089210e-16j],
       [ 3.33066907e-16-1.22124533e-15j]], dtype=complex128)

### Test Batched

In [14]:
n_nz = 8
n_col = 5
n_rhs = 1
n_lhs = 3

Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, (n_lhs, n_nz), dtype=jnp.complex128)
Ai = jax.random.randint(Aikey, (n_nz,), 0, n_col, jnp.int32)
Aj = jax.random.randint(Ajkey, (n_nz,), 0, n_col, jnp.int32)
b = jax.random.normal(bkey, (n_lhs, n_col, n_rhs), dtype=jnp.complex128)
x_sp = solve_c128(Ai, Aj, Ax, b)

# Dense reference: one scatter-added matrix per batch entry (shared pattern Ai, Aj).
A = jnp.zeros((n_lhs, n_col, n_col), dtype=jnp.complex128).at[:, Ai, Aj].add(Ax)
x = jax.vmap(jsp.linalg.solve, (0, 0), 0)(A, b)

# Fail loudly on a regression instead of relying on eyeballing the residual.
np.testing.assert_allclose(x_sp, x, rtol=1e-10, atol=1e-12)
x_sp - x

Array([[[ 5.55111512e-17-5.55111512e-17j],
        [-1.11022302e-16+0.00000000e+00j],
        [ 0.00000000e+00+1.24900090e-16j],
        [ 0.00000000e+00+2.77555756e-17j],
        [-1.24900090e-16+4.44089210e-16j]],

       [[-6.93889390e-18-5.55111512e-17j],
        [-5.55111512e-17-1.11022302e-16j],
        [ 4.85722573e-17+2.22044605e-16j],
        [ 8.32667268e-17-2.22044605e-16j],
        [ 2.22044605e-16-8.32667268e-17j]],

       [[-2.77555756e-17-1.38777878e-17j],
        [-4.44089210e-16+0.00000000e+00j],
        [ 2.22044605e-16+3.33066907e-16j],
        [ 3.46944695e-17+0.00000000e+00j],
        [ 5.55111512e-17+0.00000000e+00j]]], dtype=complex128)

## coo_mul_vec_c128

In [15]:
# Register the handler returned by `klujax_cpp.coo_mul_vec_c128()` as the CPU FFI
# target "coo_mul_vec_c128" -- the name the `coo_mul_vec_c128` wrapper passes to `ffi_call`.
jex.ffi.register_ffi_target("coo_mul_vec_c128", klujax_cpp.coo_mul_vec_c128(), platform="cpu")

In [16]:
def coo_mul_vec_c128(Ai, Aj, Ax, x):
    """Compute the sparse complex product(s) ``b = A @ x`` via the ``coo_mul_vec_c128`` FFI target.

    ``A`` is ``n_col x n_col`` in COO form: ``Ai[k]``/``Aj[k]`` are the row/column
    of the k-th stored value ``Ax[..., k]``. The pattern ``(Ai, Aj)`` is shared by
    every matrix in the batch; only ``Ax`` may carry batch dimensions. The
    notebook checks the result against ``zeros(...).at[Ai, Aj].add(Ax) @ x``.

    Args:
        Ai: Row indices, shape ``(n_nz,)``. Converted to int32.
        Aj: Column indices, shape ``(n_nz,)``. Converted to int32.
        Ax: Stored values, shape ``(*batch_A, n_nz)``. Converted to complex128
            (real inputs are accepted).
        x: Vector(s) to multiply, shape ``(*batch_x, n_col, n_rhs)``, or a single
            vector of shape ``(n_col,)``. Converted to complex128.

    Returns:
        complex128 ``b`` of shape ``(*batch, n_col, n_rhs)`` (``(*batch, n_col)``
        for a 1-D ``x``), where ``batch`` is ``batch_A`` broadcast against ``batch_x``.

    Raises:
        ValueError: If ``Ai``/``Aj`` are not 1-D arrays of length ``n_nz``, or if
            the batch shapes of ``Ax`` and ``x`` cannot be broadcast together.

    Requires ``JAX_ENABLE_X64=1``; otherwise JAX truncates complex128 to complex64.
    """
    Ai = jnp.asarray(Ai, dtype=jnp.int32)
    Aj = jnp.asarray(Aj, dtype=jnp.int32)
    # Promote to complex128 *before* the float64 view below. Viewing a real
    # float64 array is a no-op (n values instead of 2*n), and complex64 would
    # yield float32 pairs, so the kernel would get buffers of the wrong size or
    # precision.
    Ax = jnp.asarray(Ax, dtype=jnp.complex128)
    x = jnp.asarray(x, dtype=jnp.complex128)

    if x.ndim == 1:
        # A single vector: multiply it as a one-column block.
        return coo_mul_vec_c128(Ai, Aj, Ax, x[:, None])[..., 0]

    *batch_A, n_nz = Ax.shape
    *batch_x, n_col, n_rhs = x.shape
    if Ai.shape != (n_nz,) or Aj.shape != (n_nz,):
        raise ValueError(
            f"Ai and Aj must have shape ({n_nz},) to match Ax, "
            f"got {Ai.shape} and {Aj.shape}"
        )

    # The kernel gets a single n_lhs and an output buffer shaped like _x, so Ax
    # and x must share the same flattened leading batch axis of size n_lhs.
    batch = np.broadcast_shapes(tuple(batch_A), tuple(batch_x))
    n_lhs = int(np.prod(batch, dtype=np.int64))
    # Viewing complex128 as float64 interleaves (real, imag) along the last axis:
    # _Ax becomes (n_lhs, 2*n_nz) and _x becomes (n_lhs, n_col, 2*n_rhs).
    _Ax = jnp.broadcast_to(Ax, (*batch, n_nz)).reshape(n_lhs, n_nz).view(np.float64)
    _x = (
        jnp.broadcast_to(x, (*batch, n_col, n_rhs))
        .reshape(n_lhs, n_col, n_rhs)
        .view(np.float64)
    )

    call = jex.ffi.ffi_call(
        "coo_mul_vec_c128",
        jax.ShapeDtypeStruct(_x.shape, _x.dtype),
        # The size attributes are fixed at trace time. "broadcast_all" would add a
        # vmap axis to every buffer (including the shared Ai/Aj) while n_lhs still
        # describes one call; presumably the kernel loops over the attribute sizes.
        # "sequential" invokes the kernel per vmapped element with traced shapes.
        vmap_method="sequential",
    )
    y = call(
        Ai,
        Aj,
        _Ax,
        _x,
        n_col=np.int32(n_col),
        n_rhs=np.int32(n_rhs),
        n_lhs=np.int32(n_lhs),
        n_nz=np.int32(n_nz),
    )
    return y.view(jnp.complex128).reshape(*batch, n_col, n_rhs)

### Test Single

In [17]:
n_nz = 8
n_col = 5
n_rhs = 1
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, (n_nz,), dtype=jnp.complex128)
Ai = jax.random.randint(Aikey, (n_nz,), 0, n_col, jnp.int32)
Aj = jax.random.randint(Ajkey, (n_nz,), 0, n_col, jnp.int32)
x = jax.random.normal(bkey, (n_col, n_rhs), dtype=jnp.complex128)

b_sp = coo_mul_vec_c128(Ai, Aj, Ax, x)

# Dense reference: scatter-add the COO entries into a full matrix and multiply.
A = jnp.zeros((n_col, n_col), dtype=jnp.complex128).at[Ai, Aj].add(Ax)
b = A @ x

# Fail loudly on a regression instead of relying on eyeballing the residual.
np.testing.assert_allclose(b_sp, b, rtol=1e-10, atol=1e-12)
b_sp - b

Array([[ 0.00000000e+00+0.j],
       [-2.77555756e-17+0.j],
       [ 1.11022302e-16+0.j],
       [ 0.00000000e+00+0.j],
       [ 0.00000000e+00+0.j]], dtype=complex128)

### Test Batched

In [18]:
n_nz = 8
n_col = 5
n_rhs = 1
n_lhs = 3
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, (n_lhs, n_nz), dtype=jnp.complex128)
Ai = jax.random.randint(Aikey, (n_nz,), 0, n_col, jnp.int32)
Aj = jax.random.randint(Ajkey, (n_nz,), 0, n_col, jnp.int32)
x = jax.random.normal(bkey, (n_lhs, n_col, n_rhs), dtype=jnp.complex128)

b_sp = coo_mul_vec_c128(Ai, Aj, Ax, x)

# Dense reference: one scatter-added matrix per batch entry (shared pattern Ai, Aj).
A = jnp.zeros((n_lhs, n_col, n_col), dtype=jnp.complex128).at[:, Ai, Aj].add(Ax)
b = jnp.einsum("bij,bjk->bik", A, x)

# Fail loudly on a regression instead of relying on eyeballing the residual.
np.testing.assert_allclose(b_sp, b, rtol=1e-10, atol=1e-12)
b_sp - b

Array([[[ 0.00000000e+00+1.11022302e-16j],
        [ 0.00000000e+00+0.00000000e+00j],
        [ 0.00000000e+00+1.11022302e-16j],
        [ 0.00000000e+00+0.00000000e+00j],
        [ 0.00000000e+00+0.00000000e+00j]],

       [[ 0.00000000e+00-6.93889390e-18j],
        [ 0.00000000e+00+0.00000000e+00j],
        [-5.55111512e-17+0.00000000e+00j],
        [ 0.00000000e+00+0.00000000e+00j],
        [ 2.77555756e-17+5.55111512e-17j]],

       [[ 0.00000000e+00+0.00000000e+00j],
        [ 3.46944695e-18+0.00000000e+00j],
        [-1.11022302e-16+5.55111512e-17j],
        [ 0.00000000e+00+0.00000000e+00j],
        [-2.77555756e-17+0.00000000e+00j]]], dtype=complex128)

In [22]:
# Scratch check: np.atleast_2d promotes a 1-D input to a single row.
np.atleast_2d([0, 0]).shape

(1, 2)