diff --git a/divisi2/algorithms/svd.py b/divisi2/algorithms/svd.py index ff5a57a..5fb8cc3 100644 --- a/divisi2/algorithms/svd.py +++ b/divisi2/algorithms/svd.py @@ -3,6 +3,7 @@ from divisi2.reconstructed import ReconstructedMatrix from divisi2._svdlib import svd_llmat, svd_ndarray from divisi2 import operators +import numpy as np def svd(matrix, k=50): """ @@ -22,6 +23,13 @@ def svd(matrix, k=50): if isinstance(matrix, DenseMatrix): Ut, S, Vt = svd_ndarray(matrix, k) elif isinstance(matrix, SparseMatrix): + if matrix.nnz == 0: + # don't let svdlib touch a matrix of all zeros. It explodes and + # corrupts its state. Just return a zero result instead. + U = DenseMatrix((matrix.shape[0], k)) + S = np.zeros((k,)) + V = DenseMatrix((matrix.shape[1], k)) + return U, S, V if matrix.shape[1] >= matrix.shape[0] * 1.2: # transpose the matrix for speed V, S, U = matrix.T.svd(k)