In [74]:
import os
import comptools
from setuptools import setup, Extension
import bindthem
import subprocess
from contextlib import contextmanager,redirect_stderr,redirect_stdout
from os import devnull
from scipy.sparse import csr_matrix

template1 = \
"""
void spmv(const int Ap[], const int Ap_size,
          const int Aj[], const int Aj_size,
          const double Ax[], const int Ax_size,
          const double Xx[], const int Xx_size,
                double Yx[], const int Yx_size)
{{
    for(int i = 0; i < {n}; i++){{
        double sum = Yx[i];
        for(int jj = Ap[i]; jj < Ap[i+1]; jj++){{
            sum += Ax[jj] * Xx[Aj[jj]];
        }}
        Yx[i] = sum;
    }}
}}
"""

def writeit(name, n):
    """
    write a template to a file
    """
    with open(name, 'w') as f:
        f.write(template1.format(n=1))

def bindit(name):
    """
    bind name with bindthem
    """
    subprocess.check_output(['./bindthem.py', name])

def buildit(name):
    """
    use setuptools to build an extension inplace
    """
    get_pybind_include = comptools.get_pybind_include
    BuildExt = comptools.BuildExt
    ext_modules = [Extension(name,
                             sources=[name.replace('.cpp','_bind.cpp')],
                             include_dirs=[get_pybind_include(), get_pybind_include(user=True)],
                             language='c++')]

    setup(
        include_package_data=False,
        zip_safe=False,
        #
        script_args=['build_ext', '--inplace'],
        ext_modules=ext_modules,
        cmdclass={'build_ext': BuildExt},
        install_requires=['numpy>=1.7.0', 'pybind11>=2.2']
        );

# https://stackoverflow.com/questions/11130156/suppress-stdout-stderr-print-from-python-functions
@contextmanager
def suppress_stdout_stderr():
    """A context manager that redirects stdout and stderr to devnull"""
    with open(devnull, 'w') as fnull:
        with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out:
            yield (err, out)

In [75]:
def f(a):
    print("here")
    return

with suppress_stdout_stderr():
    f(1)

In [76]:
import importlib

class myjit(object):
    """
    Provide a callable object
    """
    def __init__(self, name, n):
        self.name = name
        self.n = n
        
    def __call__(self, *args):
        """
        If the function is already compile, then proceed.
        If the function needs to be built, then rebind and build.
        """
        built = importlib.util.find_spec(self.name)
        
        if built:
            print("Using jitted {}!".format(self.name))
        else:
            print('jitting {}!'.format(self.name))
            print(self.n)
            with suppress_stdout_stderr():
                spmvname = self.name+".cpp"
                writeit(spmvname, self.n)
                bindit(spmvname)
                buildit(spmvname)
        
        mod = importlib.import_module(self.name)
        func = getattr(mod, self.name)
        return func(*args)

In [77]:
class csr(csr_matrix):
    """
    New CSR class

    The purpose of this class is to redefine the matvec in scipy.sparse
    """
    def _mul_vector(self, other):
        """
        Identical to scipy.sparse with an in internal call to
        pyamg.amg_core.sparse.csr_matvec
        """
        M, N = self.shape

        # output array
        result = np.zeros(M)

        # call the spmv-N matvec
        print('calling spmv-n')
        g = myjit("spmv-{n}".format(n=N), N)
        g(N, self.indptr, self.indices, self.data, other, result)

        return result

In [78]:
import pyamg
import numpy as np
g = myjit('spmv', 3)
A = pyamg.gallery.poisson((3,), format='csr')

g(A.shape[0], A.indptr, A.indices, A.data, x, y)
print(y)
np.testing.assert_array_almost_equal(y, A*x)

Using jitted spmv!
[2. 0. 1. 0. 0. 0. 0. 0. 0. 1.]


ValueError: dimension mismatch

In [79]:
n = 10
A = pyamg.gallery.poisson((n,), format='csr')
x = np.ones((n,))
y = A * x

A = csr(A)
yn = A * x

calling spmv-n
jitting spmv-10!
10


CalledProcessError: Command '['./bindthem.py', 'spmv-10.cpp']' returned non-zero exit status 1.

In [80]:
"{n}".format(n=10)

'10'