From 5c1ff122313f844b5edb5f19f625eca749b8ffd8 Mon Sep 17 00:00:00 2001 From: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> Date: Tue, 23 Mar 2021 17:55:51 -0500 Subject: [PATCH] Merge pull request #6826 from sklam/fix/iss6821 Fix regression on gufunc serialization (cherry picked from commit 056bf6b71a08bd909b64d27be6cb54cc7b8e58bc) --- numba/np/ufunc/gufunc.py | 30 ++++++- numba/tests/npyufunc/test_gufunc.py | 118 +++++++++++++++++++++++++++- 2 files changed, 146 insertions(+), 2 deletions(-) diff --git a/numba/np/ufunc/gufunc.py b/numba/np/ufunc/gufunc.py index a089906e901..bdf48e95e89 100644 --- a/numba/np/ufunc/gufunc.py +++ b/numba/np/ufunc/gufunc.py @@ -3,9 +3,10 @@ from numba.np.ufunc.ufuncbuilder import GUFuncBuilder from numba.np.ufunc.sigparse import parse_signature from numba.np.numpy_support import ufunc_find_matching_loop +from numba.core import serialize -class GUFunc(object): +class GUFunc(serialize.ReduceMixin): """ Dynamic generalized universal function (GUFunc) intended to act like a normal Numpy gufunc, but capable @@ -18,6 +19,7 @@ def __init__(self, py_func, signature, identity=None, cache=None, self.ufunc = None self._frozen = False self._is_dynamic = is_dynamic + self._identity = identity # GUFunc cannot inherit from GUFuncBuilder because "identity" # is a property of GUFunc. Thus, we hold a reference to a GUFuncBuilder @@ -25,6 +27,32 @@ def __init__(self, py_func, signature, identity=None, cache=None, self.gufunc_builder = GUFuncBuilder( py_func, signature, identity, cache, targetoptions) + def _reduce_states(self): + gb = self.gufunc_builder + dct = dict( + py_func=gb.py_func, + signature=gb.signature, + identity=self._identity, + cache=gb.cache, + is_dynamic=self._is_dynamic, + targetoptions=gb.targetoptions, + typesigs=gb._sigs, + frozen=self._frozen, + ) + return dct + + @classmethod + def _rebuild(cls, py_func, signature, identity, cache, is_dynamic, + targetoptions, typesigs, frozen): + self = cls(py_func=py_func, signature=signature, identity=identity, + cache=cache, is_dynamic=is_dynamic, + targetoptions=targetoptions) + for sig in typesigs: + self.add(sig) + self.build_ufunc() + self._frozen = frozen + return self + def __repr__(self): return f"" diff --git a/numba/tests/npyufunc/test_gufunc.py b/numba/tests/npyufunc/test_gufunc.py index 8fde667f595..655ae641947 100644 --- a/numba/tests/npyufunc/test_gufunc.py +++ b/numba/tests/npyufunc/test_gufunc.py @@ -1,10 +1,12 @@ +import unittest +import pickle + import numpy as np import numpy.core.umath_tests as ut from numba import void, float32, jit, guvectorize from numba.np.ufunc import GUVectorize from numba.tests.support import tag, TestCase -import unittest def matmulcore(A, B, C): @@ -281,5 +283,119 @@ class TestGUVectorizeScalarParallel(TestGUVectorizeScalar): target = 'parallel' +class TestGUVectorizePickling(TestCase): + def test_pickle_gufunc_non_dyanmic(self): + """Non-dynamic gufunc. + """ + @guvectorize(["f8,f8[:]"], "()->()") + def double(x, out): + out[:] = x * 2 + + # pickle + ser = pickle.dumps(double) + cloned = pickle.loads(ser) + + # attributes carried over + self.assertEqual(cloned._frozen, double._frozen) + self.assertEqual(cloned.identity, double.identity) + self.assertEqual(cloned.is_dynamic, double.is_dynamic) + self.assertEqual(cloned.gufunc_builder._sigs, + double.gufunc_builder._sigs) + # expected value of attributes + self.assertTrue(cloned._frozen) + + cloned.disable_compile() + self.assertTrue(cloned._frozen) + + # scalar version + self.assertPreciseEqual(double(0.5), cloned(0.5)) + # array version + arr = np.arange(10) + self.assertPreciseEqual(double(arr), cloned(arr)) + + def test_pickle_gufunc_dyanmic_null_init(self): + """Dynamic gufunc w/o prepopulating before pickling. + """ + @guvectorize("()->()", identity=1) + def double(x, out): + out[:] = x * 2 + + # pickle + ser = pickle.dumps(double) + cloned = pickle.loads(ser) + + # attributes carried over + self.assertEqual(cloned._frozen, double._frozen) + self.assertEqual(cloned.identity, double.identity) + self.assertEqual(cloned.is_dynamic, double.is_dynamic) + self.assertEqual(cloned.gufunc_builder._sigs, + double.gufunc_builder._sigs) + # expected value of attributes + self.assertFalse(cloned._frozen) + + # scalar version + expect = np.zeros(1) + got = np.zeros(1) + double(0.5, out=expect) + cloned(0.5, out=got) + self.assertPreciseEqual(expect, got) + # array version + arr = np.arange(10) + expect = np.zeros_like(arr) + got = np.zeros_like(arr) + double(arr, out=expect) + cloned(arr, out=got) + self.assertPreciseEqual(expect, got) + + def test_pickle_gufunc_dyanmic_initialized(self): + """Dynamic gufunc prepopulated before pickling. + + Once unpickled, we disable compilation to verify that the gufunc + compilation state is carried over. + """ + @guvectorize("()->()", identity=1) + def double(x, out): + out[:] = x * 2 + + # prepopulate scalar + expect = np.zeros(1) + got = np.zeros(1) + double(0.5, out=expect) + # prepopulate array + arr = np.arange(10) + expect = np.zeros_like(arr) + got = np.zeros_like(arr) + double(arr, out=expect) + + # pickle + ser = pickle.dumps(double) + cloned = pickle.loads(ser) + + # attributes carried over + self.assertEqual(cloned._frozen, double._frozen) + self.assertEqual(cloned.identity, double.identity) + self.assertEqual(cloned.is_dynamic, double.is_dynamic) + self.assertEqual(cloned.gufunc_builder._sigs, + double.gufunc_builder._sigs) + # expected value of attributes + self.assertFalse(cloned._frozen) + + # disable compilation + cloned.disable_compile() + self.assertTrue(cloned._frozen) + # scalar version + expect = np.zeros(1) + got = np.zeros(1) + double(0.5, out=expect) + cloned(0.5, out=got) + self.assertPreciseEqual(expect, got) + # array version + expect = np.zeros_like(arr) + got = np.zeros_like(arr) + double(arr, out=expect) + cloned(arr, out=got) + self.assertPreciseEqual(expect, got) + + if __name__ == '__main__': unittest.main()