-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
266 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
Math utilities | ||
============== | ||
|
||
Solvers | ||
------- | ||
|
||
.. module:: lenskit.math.solve | ||
|
||
.. autofunction:: solve_tri |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
Matrix Utilities | ||
---------------- | ||
|
||
.. module:: lenskit.matrix | ||
|
||
We have some matrix-related utilities, since matrices are used so heavily in recommendation | ||
algorithms. | ||
|
||
Building Ratings Matrices | ||
~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autofunction:: sparse_ratings | ||
.. autoclass:: RatingMatrix | ||
|
||
Compressed Sparse Row Matrices | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
We use CSR-format sparse matrices in quite a few places. Since SciPy's sparse matrices are not | ||
directly usable from Numba, we have implemented a Numba-compiled CSR representation that can | ||
be used from accelerated algorithm implementations. | ||
|
||
.. autofunction:: csr_from_coo | ||
.. autofunction:: csr_from_scipy | ||
.. autofunction:: csr_to_scipy | ||
.. autofunction:: csr_rowinds | ||
.. autofunction:: csr_save | ||
.. autofunction:: csr_load | ||
|
||
.. autoclass:: CSR | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,12 @@ | ||
Utility Functions | ||
================= | ||
|
||
.. automodule:: lenskit.util | ||
:members: | ||
|
||
Matrix Utilities | ||
---------------- | ||
|
||
.. module:: lenskit.matrix | ||
|
||
We have some matrix-related utilities, since matrices are used so heavily in recommendation | ||
algorithms. | ||
|
||
Building Ratings Matrices | ||
~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
.. toctree:: | ||
matrix | ||
math | ||
|
||
.. autofunction:: sparse_ratings | ||
.. autoclass:: RatingMatrix | ||
|
||
Compressed Sparse Row Matrices | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
We use CSR-format sparse matrices in quite a few places. Since SciPy's sparse matrices are not | ||
directly usable from Numba, we have implemented a Numba-compiled CSR representation that can | ||
be used from accelerated algorithm implementations. | ||
|
||
.. autofunction:: csr_from_coo | ||
.. autofunction:: csr_from_scipy | ||
.. autofunction:: csr_to_scipy | ||
.. autofunction:: csr_rowinds | ||
.. autofunction:: csr_save | ||
.. autofunction:: csr_load | ||
|
||
.. autoclass:: CSR | ||
:members: | ||
Miscellaneous | ||
------------- | ||
|
||
.. automodule:: lenskit.util | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
""" | ||
Mathematical helper routines. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
Efficient solver routines. | ||
""" | ||
|
||
|
||
import numpy as np | ||
|
||
import cffi | ||
import numba as n | ||
from numba.extending import get_cython_function_address | ||
|
||
__ffi = cffi.FFI() | ||
|
||
__uplo_U = np.array(ord('U'), dtype=np.int8) | ||
__uplo_L = np.array(ord('L'), dtype=np.int8) | ||
__trans_N = np.array(ord('N'), dtype=np.int8) | ||
__trans_T = np.array(ord('T'), dtype=np.int8) | ||
__trans_C = np.array(ord('C'), dtype=np.int8) | ||
__diag_U = np.array(ord('U'), dtype=np.int8) | ||
__diag_N = np.array(ord('N'), dtype=np.int8) | ||
__inc_1 = np.ones(1, dtype=np.int32) | ||
|
||
__dtrsv = __ffi.cast("void (*) (char*, char*, char*, int*, double*, int*, double*, int*)", | ||
get_cython_function_address("scipy.linalg.cython_blas", "dtrsv")) | ||
__dposv = __ffi.cast("void (*) (char*, int*, int*, double*, int*, double*, int*, int*)", | ||
get_cython_function_address("scipy.linalg.cython_lapack", "dposv")) | ||
|
||
|
||
@n.njit(n.void(n.boolean, n.boolean, n.double[:, ::1], n.double[::1]), nogil=True) | ||
def _dtrsv(lower, trans, a, x): | ||
inc1 = __ffi.from_buffer(__inc_1) | ||
|
||
# dtrsv uses Fortran-layout arrays. Because we use C-layout arrays, we will | ||
# invert the meaning of 'lower' and 'trans', and the function will work fine. | ||
# We also need to swap index orders | ||
uplo = __uplo_U if lower else __uplo_L | ||
tspec = __trans_N if trans else __trans_T | ||
|
||
n_p = np.array([a.shape[0]], dtype=np.intc) | ||
n_p = __ffi.from_buffer(n_p) | ||
lda_p = np.array([a.shape[1]], dtype=np.intc) | ||
lda_p = __ffi.from_buffer(lda_p) | ||
|
||
__dtrsv(__ffi.from_buffer(uplo), __ffi.from_buffer(tspec), __ffi.from_buffer(__diag_N), | ||
n_p, __ffi.from_buffer(a), lda_p, | ||
__ffi.from_buffer(x), inc1) | ||
|
||
|
||
def solve_tri(A, b, transpose=False, lower=True): | ||
""" | ||
Solve the system :math:`Ax = b`, where :math:`A` is triangular. | ||
This is equivalent to :py:fun:`scipy.linalg.solve_triangular`, but does *not* | ||
check for non-singularity. It is a thin wrapper around the BLAS ``dtrsv`` | ||
function. | ||
Args: | ||
A(ndarray): the matrix. | ||
b(ndarray): the taget vector. | ||
transpose(bool): whether to solve :math:`Ax = b` or :math:`A^T x = b`. | ||
lower(bool): whether :math:`A` is lower- or upper-triangular. | ||
""" | ||
x = b.copy() | ||
_dtrsv(lower, transpose, A, x) | ||
return x | ||
|
||
|
||
@n.njit(n.intc(n.float64[:, ::1], n.float64[::1], n.boolean), nogil=True) | ||
def _dposv(A, b, lower): | ||
if A.shape[0] != A.shape[1]: | ||
return -11 | ||
if A.shape[0] != b.shape[0]: | ||
return -12 | ||
|
||
# dposv uses Fortran-layout arrays. Because we use C-layout arrays, we will | ||
# invert the meaning of 'lower' and 'trans', and the function will work fine. | ||
# We also need to swap index orders | ||
uplo = __uplo_U if lower else __uplo_L | ||
n_p = __ffi.from_buffer(np.array([A.shape[0]], dtype=np.intc)) | ||
nrhs_p = __ffi.from_buffer(np.ones(1, dtype=np.intc)) | ||
info = np.zeros(1, dtype=np.intc) | ||
info_p = __ffi.from_buffer(info) | ||
|
||
__dposv(__ffi.from_buffer(uplo), n_p, nrhs_p, | ||
__ffi.from_buffer(A), n_p, | ||
__ffi.from_buffer(b), n_p, | ||
info_p) | ||
|
||
return info[0] | ||
|
||
|
||
def dposv(A, b, lower=False): | ||
info = _dposv(A, b, lower) | ||
if info < 0: | ||
raise ValueError('invalid args to dposv, code ' + str(info)) | ||
elif info > 0: | ||
raise RuntimeError('error in dposv, code ' + str(info)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import numpy as np | ||
import scipy.linalg as sla | ||
|
||
from pytest import approx | ||
|
||
from lenskit.math.solve import solve_tri, dposv | ||
|
||
|
||
def test_solve_ltri(): | ||
for i in range(10): | ||
size = np.random.randint(5, 50) | ||
Af = np.random.randn(size, size) | ||
b = np.random.randn(size) | ||
A = np.tril(Af) | ||
|
||
x = solve_tri(A, b) | ||
assert len(x) == size | ||
|
||
xexp = sla.solve_triangular(A, b, lower=True) | ||
assert x == approx(xexp, rel=1.0e-6) | ||
|
||
|
||
def test_solve_ltri_transpose(): | ||
for i in range(10): | ||
size = np.random.randint(5, 50) | ||
Af = np.random.randn(size, size) | ||
b = np.random.randn(size) | ||
A = np.tril(Af) | ||
|
||
x = solve_tri(A, b, True) | ||
assert len(x) == size | ||
|
||
xexp = sla.solve_triangular(A.T, b, lower=False) | ||
assert x == approx(xexp, rel=1.0e-6) | ||
|
||
|
||
def test_solve_utri(): | ||
for i in range(10): | ||
size = np.random.randint(5, 50) | ||
Af = np.random.randn(size, size) | ||
b = np.random.randn(size) | ||
A = np.triu(Af) | ||
|
||
x = solve_tri(A, b, lower=False) | ||
assert len(x) == size | ||
xexp = sla.solve_triangular(A, b, lower=False) | ||
assert x == approx(xexp, rel=1.0e-6) | ||
|
||
|
||
def test_solve_utri_transpose(): | ||
for i in range(10): | ||
size = np.random.randint(5, 50) | ||
Af = np.random.randn(size, size) | ||
b = np.random.randn(size) | ||
A = np.triu(Af) | ||
|
||
x = solve_tri(A, b, True, lower=False) | ||
assert len(x) == size | ||
xexp = sla.solve_triangular(A.T, b, lower=True) | ||
assert x == approx(xexp, rel=1.0e-6) | ||
|
||
|
||
def test_solve_cholesky(): | ||
for i in range(10): | ||
size = np.random.randint(5, 50) | ||
A = np.random.randn(size, size) | ||
b = np.random.randn(size) | ||
|
||
# square values of A | ||
A = A * A | ||
|
||
# and solve | ||
xexp, resid, rank, s = np.linalg.lstsq(A, b) | ||
|
||
# chol solve | ||
L = np.linalg.cholesky(A.T @ A) | ||
|
||
w = solve_tri(L, A.T @ b) | ||
x = solve_tri(L, w, transpose=True) | ||
|
||
assert x == approx(xexp) | ||
|
||
|
||
def test_solve_dposv(): | ||
for i in range(10): | ||
size = np.random.randint(5, 50) | ||
A = np.random.randn(size, size) | ||
b = np.random.randn(size) | ||
|
||
# square values of A | ||
A = A * A | ||
|
||
# and solve | ||
xexp, resid, rank, s = np.linalg.lstsq(A, b) | ||
|
||
F = A.T @ A | ||
x = A.T @ b | ||
dposv(F, x, True) | ||
|
||
assert x == approx(xexp, rel=1.0e-4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters