Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always convert b to fortran array if b is not fortran contiguous #40

Merged
merged 3 commits into from
Apr 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pypardiso/pardiso_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ def _check_b(self, A, b):
SparseEfficiencyWarning)
b = b.todense()

# pardiso expects fortran (column-major) order if b is a matrix
if b.ndim == 2:
# pardiso expects fortran (column-major) order for b
if not b.flags.f_contiguous:
b = np.asfortranarray(b)

if b.shape[0] != A.shape[0]:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_input_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,12 @@ def test_input_b_wrong_shape():
b = np.append(b, 1)
with pytest.raises(ValueError):
basic_solve(A, b)


def test_input_b_slice():
A, b = create_test_A_b_rand(matrix=True)
b1 = b[:, 0]
b2 = b[:, 0].copy()
x1 = ps.solve(A, b1)
x2 = ps.solve(A, b2)
np.testing.assert_array_equal(x1, x2)
3 changes: 3 additions & 0 deletions tests/test_scipy_aliases.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# coding: utf-8
import pytest
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import spsolve as scipyspsolve
Expand Down Expand Up @@ -28,6 +29,7 @@ def test_basic_spsolve_matrix():
np.testing.assert_array_almost_equal(xpp, xscipy)


@pytest.mark.filterwarnings("ignore:splu requires CSC matrix format")
def test_basic_factorized():
ps.remove_stored_factorization()
ps.free_memory()
Expand All @@ -39,6 +41,7 @@ def test_basic_factorized():
np.testing.assert_array_almost_equal(xpp, xscipy)


@pytest.mark.filterwarnings("ignore:Changing the sparsity structure")
def test_factorized_modified_A():
ps.remove_stored_factorization()
ps.free_memory()
Expand Down