# FFI

In [None]:
# Set JAX_ENABLE_X64 in the environment before the `import jax` cell below runs;
# the cells that follow work with float64 / complex128 arrays.
%env JAX_ENABLE_X64=1

## Imports

In [None]:
from pathlib import Path

import jax
import jax.extend as jex
import jax.numpy as jnp
import jax.scipy as jsp
import klujax_cpp
import numpy as np

## Helpers

In [None]:
def _prepare_arguments(Ai, Aj, Ax, x):
    Ai = jnp.asarray(Ai)
    Aj = jnp.asarray(Aj)
    Ax = jnp.asarray(Ax)
    x = jnp.asarray(x)
    shape = x.shape

    if x.ndim < 2:
        x = jnp.atleast_2d(x).T

    n_col, n_rhs, *_ = x.shape[Ax.ndim - 1 :] + (1,)

    *_, n_nz = Ax.shape
    Ax = Ax.reshape(-1, n_nz)
    n_lhs, _ = Ax.shape

    x = x.reshape(-1, n_col, n_rhs)
    return Ai, Aj, Ax, x, shape, n_lhs, n_nz, n_col, n_rhs

## solve_f64

In [None]:
# Register the object returned by klujax_cpp.solve_f64() as the CPU FFI target
# named "_solve_f64"; ffi_call looks targets up by this name.
jex.ffi.register_ffi_target("_solve_f64", klujax_cpp.solve_f64(), platform="cpu")

In [None]:
def solve_f64_impl(Ai, Aj, Ax, b):
    """Invoke the ``_solve_f64`` FFI target on normalized COO inputs.

    The arguments are first passed through ``_prepare_arguments``, which
    flattens ``Ax`` to 2-D and ``b`` to 3-D. The FFI output is declared with
    the same shape and dtype as the normalized ``b``.

    Args:
        Ai: First index array of the COO entries.
        Aj: Second index array of the COO entries.
        Ax: Values of the COO entries.
        b: Dense right-hand side.

    Returns:
        The FFI result reshaped to the shape ``b`` had on entry.
    """
    prepared = _prepare_arguments(Ai, Aj, Ax, b)
    # Only the normalized arrays and the original shape are needed here; the
    # remaining size fields of the tuple are not used by this wrapper.
    Ai, Aj, Ax, b, original_shape = prepared[:5]
    output_spec = jax.ShapeDtypeStruct(b.shape, b.dtype)
    ffi_solve = jax.extend.ffi.ffi_call(
        "_solve_f64",
        output_spec,
        vmap_method="broadcast_all",
    )
    solution = ffi_solve(Ai, Aj, Ax, b)  # type: ignore
    return solution.reshape(*original_shape)  # type: ignore

### single

In [None]:
# Problem size: 8 COO entries, a 5x5 matrix and a single right-hand side.
n_nz, n_col, n_rhs = 8, 5, 1

# One PRNG key each for the values, the two index arrays and the rhs.
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, shape=(n_nz,))
Ai = jax.random.randint(Aikey, shape=(n_nz,), minval=0, maxval=n_col, dtype=jnp.int32)
Aj = jax.random.randint(Ajkey, shape=(n_nz,), minval=0, maxval=n_col, dtype=jnp.int32)
b = jax.random.normal(bkey, shape=(n_col, n_rhs))

# Result from the FFI solver.
x_sp = solve_f64_impl(Ai, Aj, Ax, b)

# Dense reference: scatter-add the COO values into a zero matrix, then solve.
A = jnp.zeros((n_col, n_col), dtype=jnp.float64)
A = A.at[Ai, Aj].add(Ax)
x = jsp.linalg.solve(A, b)

# Displayed as the cell output: difference between FFI and dense results.
x_sp - x

### batched

In [None]:
# Problem size: 3 matrices sharing one sparsity pattern of 8 COO entries,
# each 5x5 with a single right-hand side.
n_nz = 8
n_col = 5
n_rhs = 1
n_lhs = 3

# One PRNG key each for the values, the two index arrays and the rhs.
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, (n_lhs, n_nz))
Ai = jax.random.randint(Aikey, (n_nz,), 0, n_col, jnp.int32)
Aj = jax.random.randint(Ajkey, (n_nz,), 0, n_col, jnp.int32)
b = jax.random.normal(bkey, (n_lhs, n_col, n_rhs))
x_sp = solve_f64_impl(Ai, Aj, Ax, b)

# Dense reference. The data here is real-valued, so the reference matrix is
# float64 (as in the single f64 check); a complex128 matrix would promote the
# reference solution and the displayed residual to complex.
A = jnp.zeros((n_lhs, n_col, n_col), dtype=jnp.float64).at[:, Ai, Aj].add(Ax)
x = jax.vmap(jsp.linalg.solve, (0, 0), 0)(A, b)

# Displayed as the cell output: difference between FFI and dense results.
x_sp - x

## solve_c128

In [None]:
# Register the object returned by klujax_cpp.solve_c128() as the CPU FFI target
# named "_solve_c128"; ffi_call looks targets up by this name.
jex.ffi.register_ffi_target("_solve_c128", klujax_cpp.solve_c128(), platform="cpu")

In [None]:
def solve_c128_impl(Ai, Aj, Ax, b):
    """Invoke the ``_solve_c128`` FFI target on normalized COO inputs.

    The arguments are first passed through ``_prepare_arguments``, which
    flattens ``Ax`` to 2-D and ``b`` to 3-D. The FFI output is declared with
    the same shape and dtype as the normalized ``b``. The function has no
    side effects (it does not print), matching the other ``*_impl`` wrappers.

    Args:
        Ai: First index array of the COO entries.
        Aj: Second index array of the COO entries.
        Ax: Values of the COO entries.
        b: Dense right-hand side.

    Returns:
        The FFI result reshaped to the shape ``b`` had on entry.
    """
    # Only the normalized arrays and the original shape are needed here.
    Ai, Aj, Ax, b, shape, *_ = _prepare_arguments(Ai, Aj, Ax, b)
    call = jax.extend.ffi.ffi_call(
        "_solve_c128",
        jax.ShapeDtypeStruct(b.shape, b.dtype),
        vmap_method="broadcast_all",
    )
    result = call(  # type: ignore
        Ai,
        Aj,
        Ax,
        b,
    )
    return result.reshape(*shape)  # type: ignore

### single

In [None]:
# Problem size: 8 COO entries, a 5x5 matrix and a single right-hand side.
n_nz, n_col, n_rhs = 8, 5, 1

# One PRNG key each for the values, the two index arrays and the rhs.
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)

# Complex values are assembled from two real normal draws (real, imaginary).
Ax_r, Ax_i = jax.random.normal(Axkey, shape=(2, n_nz))
Ax = Ax_r + 1j * Ax_i
Ai = jax.random.randint(Aikey, shape=(n_nz,), minval=0, maxval=n_col, dtype=jnp.int32)
Aj = jax.random.randint(Ajkey, shape=(n_nz,), minval=0, maxval=n_col, dtype=jnp.int32)
b_r, b_i = jax.random.normal(bkey, shape=(2, n_col, n_rhs))
b = b_r + 1j * b_i

# Result from the FFI solver.
x_sp = solve_c128_impl(Ai, Aj, Ax, b)

# Dense reference: scatter-add the COO values into a zero matrix, then solve.
A = jnp.zeros((n_col, n_col), dtype=jnp.complex128)
A = A.at[Ai, Aj].add(Ax)
x = jsp.linalg.solve(A, b)

# Displayed as the cell output: difference between FFI and dense results.
x_sp - x

### batched

In [None]:
# Problem size: 3 matrices sharing one sparsity pattern of 8 COO entries,
# each 5x5 with a single right-hand side.
n_nz, n_col, n_rhs, n_lhs = 8, 5, 1, 3

# One PRNG key each for the values, the two index arrays and the rhs.
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, shape=(n_lhs, n_nz), dtype=jnp.complex128)
Ai = jax.random.randint(Aikey, shape=(n_nz,), minval=0, maxval=n_col, dtype=jnp.int32)
Aj = jax.random.randint(Ajkey, shape=(n_nz,), minval=0, maxval=n_col, dtype=jnp.int32)
b = jax.random.normal(bkey, shape=(n_lhs, n_col, n_rhs), dtype=jnp.complex128)

# Result from the FFI solver.
x_sp = solve_c128_impl(Ai, Aj, Ax, b)

# Dense reference: one matrix per batch entry, solved with a vmapped solve.
A = jnp.zeros((n_lhs, n_col, n_col), dtype=jnp.complex128)
A = A.at[:, Ai, Aj].add(Ax)
x = jax.vmap(jsp.linalg.solve, (0, 0), 0)(A, b)

# Displayed as the cell output: difference between FFI and dense results.
x_sp - x

## coo_mul_vec_f64

In [None]:
# Register the object returned by klujax_cpp.coo_mul_vec_f64() as the CPU FFI
# target named "_coo_mul_vec_f64"; ffi_call looks targets up by this name.
jex.ffi.register_ffi_target(
    "_coo_mul_vec_f64", klujax_cpp.coo_mul_vec_f64(), platform="cpu"
)

In [None]:
def coo_mul_vec_f64_impl(Ai, Aj, Ax, x):
    """Invoke the ``_coo_mul_vec_f64`` FFI target on normalized COO inputs.

    The arguments are first passed through ``_prepare_arguments``, which
    flattens ``Ax`` to 2-D and ``x`` to 3-D. The FFI output is declared with
    the same shape and dtype as the normalized ``x``.

    Args:
        Ai: First index array of the COO entries.
        Aj: Second index array of the COO entries.
        Ax: Values of the COO entries.
        x: Dense operand the matrix is applied to.

    Returns:
        The FFI result reshaped to the shape ``x`` had on entry.
    """
    prepared = _prepare_arguments(Ai, Aj, Ax, x)
    # Only the normalized arrays and the original shape are needed here; the
    # remaining size fields of the tuple are not used by this wrapper.
    Ai, Aj, Ax, x, original_shape = prepared[:5]
    output_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    ffi_matvec = jax.extend.ffi.ffi_call(
        "_coo_mul_vec_f64",
        output_spec,
        vmap_method="broadcast_all",
    )
    product = ffi_matvec(Ai, Aj, Ax, x)  # type: ignore
    return product.reshape(*original_shape)  # type: ignore

### single

In [None]:
# Problem size: 8 COO entries, a 5x5 matrix and a single dense column.
n_nz, n_col, n_rhs = 8, 5, 1

# One PRNG key each for the values, the two index arrays and the operand.
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, shape=(n_nz,))
Ai = jax.random.randint(Aikey, shape=(n_nz,), minval=0, maxval=n_col, dtype=jnp.int32)
Aj = jax.random.randint(Ajkey, shape=(n_nz,), minval=0, maxval=n_col, dtype=jnp.int32)
x = jax.random.normal(bkey, shape=(n_col, n_rhs))

# Result from the FFI matrix-vector product.
b_sp = coo_mul_vec_f64_impl(Ai, Aj, Ax, x)

# Dense reference: scatter-add the COO values into a zero matrix, then multiply.
A = jnp.zeros((n_col, n_col), dtype=jnp.float64)
A = A.at[Ai, Aj].add(Ax)
b = jnp.matmul(A, x)

# Displayed as the cell output: difference between FFI and dense results.
b_sp - b

### batched

In [None]:
# Problem size: 3 matrices sharing one sparsity pattern of 8 COO entries,
# each 5x5 and applied to a single dense column.
n_nz = 8
n_col = 5
n_rhs = 1
n_lhs = 3

# One PRNG key each for the values, the two index arrays and the operand.
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, (n_lhs, n_nz))
Ai = jax.random.randint(Aikey, (n_nz,), 0, n_col, jnp.int32)
Aj = jax.random.randint(Ajkey, (n_nz,), 0, n_col, jnp.int32)
x = jax.random.normal(bkey, (n_lhs, n_col, n_rhs))

b_sp = coo_mul_vec_f64_impl(Ai, Aj, Ax, x)

# Dense reference. The data here is real-valued, so the reference matrix is
# float64 (as in the single f64 check); a complex128 matrix would promote the
# reference product and the displayed residual to complex.
A = jnp.zeros((n_lhs, n_col, n_col), dtype=jnp.float64).at[:, Ai, Aj].add(Ax)
b = jnp.einsum("bij,bjk->bik", A, x)

# Displayed as the cell output: difference between FFI and dense results.
b_sp - b

## coo_mul_vec_c128

In [None]:
# Register the object returned by klujax_cpp.coo_mul_vec_c128() as the CPU FFI
# target named "_coo_mul_vec_c128"; ffi_call looks targets up by this name.
jex.ffi.register_ffi_target(
    "_coo_mul_vec_c128", klujax_cpp.coo_mul_vec_c128(), platform="cpu"
)

In [None]:
def coo_mul_vec_c128_impl(Ai, Aj, Ax, x):
    """Invoke the ``_coo_mul_vec_c128`` FFI target on normalized COO inputs.

    The arguments are first passed through ``_prepare_arguments``, which
    flattens ``Ax`` to 2-D and ``x`` to 3-D. The FFI output is declared with
    the same shape and dtype as the normalized ``x``.

    Args:
        Ai: First index array of the COO entries.
        Aj: Second index array of the COO entries.
        Ax: Values of the COO entries.
        x: Dense operand the matrix is applied to.

    Returns:
        The FFI result reshaped to the shape ``x`` had on entry.
    """
    prepared = _prepare_arguments(Ai, Aj, Ax, x)
    # Only the normalized arrays and the original shape are needed here; the
    # remaining size fields of the tuple are not used by this wrapper.
    Ai, Aj, Ax, x, original_shape = prepared[:5]
    output_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    ffi_matvec = jax.extend.ffi.ffi_call(
        "_coo_mul_vec_c128",
        output_spec,
        vmap_method="broadcast_all",
    )
    product = ffi_matvec(Ai, Aj, Ax, x)  # type: ignore
    return product.reshape(*original_shape)  # type: ignore

### single

In [None]:
# Problem size: 8 COO entries, a 5x5 matrix and a single dense column.
n_nz, n_col, n_rhs = 8, 5, 1

# One PRNG key each for the values, the two index arrays and the operand.
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, shape=(n_nz,), dtype=jnp.complex128)
Ai = jax.random.randint(Aikey, shape=(n_nz,), minval=0, maxval=n_col, dtype=jnp.int32)
Aj = jax.random.randint(Ajkey, shape=(n_nz,), minval=0, maxval=n_col, dtype=jnp.int32)
x = jax.random.normal(bkey, shape=(n_col, n_rhs), dtype=jnp.complex128)

# Result from the FFI matrix-vector product.
b_sp = coo_mul_vec_c128_impl(Ai, Aj, Ax, x)

# Dense reference: scatter-add the COO values into a zero matrix, then multiply.
A = jnp.zeros((n_col, n_col), dtype=jnp.complex128)
A = A.at[Ai, Aj].add(Ax)
b = jnp.matmul(A, x)

# Displayed as the cell output: difference between FFI and dense results.
b_sp - b

### batched

In [None]:
# Problem size: 3 matrices sharing one sparsity pattern of 8 COO entries,
# each 5x5 and applied to a single dense column.
n_nz, n_col, n_rhs, n_lhs = 8, 5, 1, 3

# One PRNG key each for the values, the two index arrays and the operand.
Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(33), 4)
Ax = jax.random.normal(Axkey, shape=(n_lhs, n_nz), dtype=jnp.complex128)
Ai = jax.random.randint(Aikey, shape=(n_nz,), minval=0, maxval=n_col, dtype=jnp.int32)
Aj = jax.random.randint(Ajkey, shape=(n_nz,), minval=0, maxval=n_col, dtype=jnp.int32)
x = jax.random.normal(bkey, shape=(n_lhs, n_col, n_rhs), dtype=jnp.complex128)

# Result from the FFI matrix-vector product.
b_sp = coo_mul_vec_c128_impl(Ai, Aj, Ax, x)

# Dense reference: one matrix per batch entry, multiplied batch-wise.
A = jnp.zeros((n_lhs, n_col, n_col), dtype=jnp.complex128)
A = A.at[:, Ai, Aj].add(Ax)
b = jnp.einsum("bij,bjk->bik", A, x)

# Displayed as the cell output: difference between FFI and dense results.
b_sp - b