In [1]:
import os
import comptools
from setuptools import setup, Extension
import bindthem
import subprocess

def bindit(name):
    """Run the ``./bindthem.py`` script with ``name`` as its only argument.

    The script is launched as a child process through
    ``subprocess.check_output``; whatever it writes to stdout is captured and
    discarded. ``check_output`` raises ``subprocess.CalledProcessError`` when
    the script exits with a non-zero status.
    """
    command = ['./bindthem.py', name]
    subprocess.check_output(command)

def buildit(name):
    """Compile ``<name>_bind.cpp`` into an extension module called ``name``.

    Drives setuptools programmatically: ``setup`` is handed
    ``build_ext --inplace`` through ``script_args``, together with the
    ``BuildExt`` command class and the pybind11 include paths that
    ``comptools`` provides.
    """
    # Both include locations reported by comptools for the pybind11 headers.
    pybind_includes = [
        comptools.get_pybind_include(),
        comptools.get_pybind_include(user=True),
    ]
    extension = Extension(
        name,
        sources=[name + '_bind.cpp'],
        include_dirs=pybind_includes,
        language='c++',
    )

    setup(
        include_package_data=False,
        zip_safe=False,
        # Same effect as running ``python setup.py build_ext --inplace``.
        script_args=['build_ext', '--inplace'],
        ext_modules=[extension],
        cmdclass={'build_ext': comptools.BuildExt},
        install_requires=['numpy>=1.7.0', 'pybind11>=2.2'],
    )

# https://stackoverflow.com/questions/11130156/suppress-stdout-stderr-print-from-python-functions
from contextlib import contextmanager,redirect_stderr,redirect_stdout
from os import devnull

@contextmanager
def suppress_stdout_stderr():
    """Point ``sys.stdout`` and ``sys.stderr`` at the null device.

    A single write handle on ``os.devnull`` serves as the target for both
    streams while the ``with`` body runs. The context yields an
    ``(err, out)`` tuple holding the two redirect targets (both are that same
    handle). On exit the previous streams are restored and the handle is
    closed.
    """
    with open(devnull, 'w') as sink, \
            redirect_stderr(sink) as err, \
            redirect_stdout(sink) as out:
        yield (err, out)

In [2]:
def f(a):
    """Print ``here`` and return ``None``; the argument ``a`` is not used."""
    print("here")

# Quick check of the context manager: the "here" printed by f is sent to the
# null device, so this cell should produce no output.
with suppress_stdout_stderr():
    f(1)

In [None]:
import importlib

class myjit(object):
    """Callable that builds an extension module on demand and runs it.

    An instance is created with a module ``name``. Each call looks for an
    importable module of that name; if none is found, ``bindit(name + '.cpp')``
    and ``buildit(name)`` are run (with Python-level stdout/stderr silenced)
    to create it. The module is then imported and its attribute called
    ``name`` is invoked with the positional arguments of the call.
    """

    def __init__(self, name):
        """Remember the module/function ``name``; nothing is built yet."""
        self.name = name

    def __call__(self, *args):
        """Build the module if needed, then return ``module.name(*args)``.

        Raises whatever the build helpers raise on failure, ``ImportError``
        if the module still cannot be imported afterwards, and
        ``AttributeError`` if the module has no attribute called ``name``.
        """
        # A bare ``import importlib`` does not guarantee that the ``util``
        # submodule is loaded, so import it explicitly before using it.
        import importlib.util

        spec = importlib.util.find_spec(self.name)

        if spec is not None:
            print("Using jitted {}!".format(self.name))
        else:
            print('jitting {}!'.format(self.name))
            with suppress_stdout_stderr():
                bindit(self.name+'.cpp')
                buildit(self.name)
            # The module file was created after the failed lookup above;
            # drop the finders' cached directory listings so the import
            # below is guaranteed to see it.
            importlib.invalidate_caches()

        mod = importlib.import_module(self.name)
        func = getattr(mod, self.name)
        return func(*args)
    

In [4]:
# Wrap the 'spmv' kernel. The constructor only stores the name; any build or
# import happens when g is actually called.
g = myjit('spmv')

In [5]:
import pyamg
import numpy as np
# Small CSR test matrix from pyamg's Poisson gallery (grid of 3 points).
A = pyamg.gallery.poisson((3,), format='csr')

# Input vector of ones and a zero-initialised output vector, one entry per row.
x = np.ones(A.shape[0])
y = np.zeros(A.shape[0])

# Run the jitted kernel on the raw CSR arrays; the result is read back from y,
# so the kernel is expected to fill y in place.
g(A.shape[0], A.indptr, A.indices, A.data, x, y)
print(y)

# Cross-check against the reference sparse matrix-vector product.
np.testing.assert_array_almost_equal(y, A*x)

Using jitted spmv!
[1. 0. 1.]
