pyamgx: Python interface to NVIDIA's AMGX library
See demo.py for the full example.
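import numpy as np
import scipy.sparse as sparse
import scipy.sparse.linalg as splinalg
import pyamgx
pyamgx.initialize()
# Create config (BiCGStab without a preconditioner) and resources:
cfg = pyamgx.Config().create_from_dict({
    "config_version": 2,
    "solver": {
        "solver": "BICGSTAB",
        "convergence": "RELATIVE_INI_CORE",
        "preconditioner": {"solver": "NOSOLVER"},
    },
})
rsc = pyamgx.Resources().create_simple(cfg)
# Create matrices and vectors:
A = pyamgx.Matrix().create(rsc)
x = pyamgx.Vector().create(rsc)
b = pyamgx.Vector().create(rsc)
# Create solver:
solver = pyamgx.Solver().create(rsc, cfg)
# Upload system:
M = sparse.csr_matrix(np.random.rand(5, 5))
rhs = np.random.rand(5)
sol = np.zeros(5, dtype=np.float64)
A.upload_CSR(M)
b.upload(rhs)
x.upload(sol)
# Setup and solve:
solver.setup(A)
solver.solve(b, x)
# Download solution:
x.download(sol)
print("pyamgx solution: ", sol)
print("scipy solution: ", splinalg.spsolve(M, rhs))
# Release AMGX objects, then shut the library down:
for obj in (A, x, b, solver, rsc, cfg):
    obj.destroy()
pyamgx.finalize()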
Sample output (the matrix and right-hand side are random, so the values change from run to run):
pyamgx solution: [-0.52114365 0.72874012 0.17712795 1.37890116 -1.03672993]
scipy solution: [-0.52114365 0.72874012 0.17712795 1.37890116 -1.03672993]
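upload_CSR receives the matrix in compressed sparse row (CSR) form. The sketch below uses only NumPy and SciPy (it does not call pyamgx) to show the three arrays a scipy.sparse.csr_matrix stores, plus a solver-independent way to check a solution through the relative residual ||b - Ax|| / ||b||. The helper names csr_arrays and relative_residual are hypothetical, and we assume, without having checked the pyamgx source here, that upload_CSR forwards these same row-pointer, column-index and value arrays to AMGX.

```python
import numpy as np
import scipy.sparse as sparse
import scipy.sparse.linalg as splinalg


def csr_arrays(M):
    """Hypothetical helper: return the (indptr, indices, data) arrays of a CSR matrix."""
    M = sparse.csr_matrix(M)
    return M.indptr, M.indices, M.data


def relative_residual(M, x, b):
    """Hypothetical helper: return ||b - M x||_2 / ||b||_2."""
    r = b - M @ x
    return np.linalg.norm(r) / np.linalg.norm(b)


# A strictly diagonally dominant 3x3 matrix (so it is nonsingular) with one
# zero per row, so CSR stores only 6 of its 9 entries.
dense = np.array([[4.0, 1.0, 0.0],
                  [0.0, 3.0, 2.0],
                  [1.0, 0.0, 5.0]])
M = sparse.csr_matrix(dense)
indptr, indices, data = csr_arrays(M)
print("indptr: ", indptr)   # row i's entries are data[indptr[i]:indptr[i+1]]
print("indices:", indices)  # column index of each stored entry
print("data:   ", data)     # stored nonzero values, row by row

# Check a candidate solution. SciPy's direct solver stands in here for the
# vector you would get back from x.download(sol).
rhs = np.array([1.0, 2.0, 3.0])
sol = splinalg.spsolve(M, rhs)
print("relative residual:", relative_residual(M, sol, rhs))
```

For the 5x5 system in the README, comparing against spsolve is cheap. For large systems, where a direct solve is too expensive, the relative residual is the more practical check, and it matches the quantity iterative solvers such as BiCGStab drive toward zero.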
Installation:
- Set the environment variable AMGX_DIR to the AMGX root directory.
- Clone this repository:
$ git clone https://github.com/shwina/pyamgx
- Build and install pyamgx:
$ cd pyamgx
$ python setup.py build_ext
$ pip install .
Note: If you do not have administrative privileges and are not installing inside a virtualenv or conda environment, replace the last command above with:
$ pip install . --user
- Run the demo:
$ python demo.py
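If the build or the demo fails, two common causes are an unset or mistyped AMGX_DIR and installing pyamgx into a different interpreter than the one running demo.py. The sketch below uses only the standard library to check both. It is an illustration under stated assumptions: check_amgx_dir and module_available are hypothetical helpers, not part of pyamgx; the first only confirms that AMGX_DIR names an existing directory (not that AMGX was built correctly), and the second locates pyamgx without importing it.

```python
import importlib.util
import os


def check_amgx_dir(env=None):
    """Hypothetical helper: return (ok, message) describing AMGX_DIR.

    Only checks that the variable is set and names an existing directory;
    it does not verify the contents of the AMGX installation.
    """
    env = os.environ if env is None else env
    path = env.get("AMGX_DIR")
    if not path:
        return False, "AMGX_DIR is not set"
    if not os.path.isdir(path):
        return False, f"AMGX_DIR={path!r} is not an existing directory"
    return True, f"AMGX_DIR={path!r}"


def module_available(name):
    """Hypothetical helper: True if `name` is findable on sys.path (not imported)."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    ok, msg = check_amgx_dir()
    print(("OK:   " if ok else "FAIL: ") + msg)
    print("pyamgx importable:", module_available("pyamgx"))
```

Run it with the same python executable you used for pip install; if pyamgx is reported as not importable there, the package most likely went into another environment.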