Skip to content

Commit

Permalink
Merge pull request #6826 from sklam/fix/iss6821
Browse files Browse the repository at this point in the history
Fix regression on gufunc serialization

(cherry picked from commit 056bf6b)
  • Loading branch information
sklam authored and esc committed Mar 25, 2021
1 parent f52e1f6 commit 5c1ff12
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 2 deletions.
30 changes: 29 additions & 1 deletion numba/np/ufunc/gufunc.py
Expand Up @@ -3,9 +3,10 @@
from numba.np.ufunc.ufuncbuilder import GUFuncBuilder
from numba.np.ufunc.sigparse import parse_signature
from numba.np.numpy_support import ufunc_find_matching_loop
from numba.core import serialize


class GUFunc(object):
class GUFunc(serialize.ReduceMixin):
"""
Dynamic generalized universal function (GUFunc)
intended to act like a normal Numpy gufunc, but capable
Expand All @@ -18,13 +19,40 @@ def __init__(self, py_func, signature, identity=None, cache=None,
self.ufunc = None
self._frozen = False
self._is_dynamic = is_dynamic
self._identity = identity

# GUFunc cannot inherit from GUFuncBuilder because "identity"
# is a property of GUFunc. Thus, we hold a reference to a GUFuncBuilder
# object here
self.gufunc_builder = GUFuncBuilder(
py_func, signature, identity, cache, targetoptions)

def _reduce_states(self):
gb = self.gufunc_builder
dct = dict(
py_func=gb.py_func,
signature=gb.signature,
identity=self._identity,
cache=gb.cache,
is_dynamic=self._is_dynamic,
targetoptions=gb.targetoptions,
typesigs=gb._sigs,
frozen=self._frozen,
)
return dct

@classmethod
def _rebuild(cls, py_func, signature, identity, cache, is_dynamic,
targetoptions, typesigs, frozen):
self = cls(py_func=py_func, signature=signature, identity=identity,
cache=cache, is_dynamic=is_dynamic,
targetoptions=targetoptions)
for sig in typesigs:
self.add(sig)
self.build_ufunc()
self._frozen = frozen
return self

def __repr__(self):
return f"<numba._GUFunc '{self.__name__}'>"

Expand Down
118 changes: 117 additions & 1 deletion numba/tests/npyufunc/test_gufunc.py
@@ -1,10 +1,12 @@
import unittest
import pickle

import numpy as np
import numpy.core.umath_tests as ut

from numba import void, float32, jit, guvectorize
from numba.np.ufunc import GUVectorize
from numba.tests.support import tag, TestCase
import unittest


def matmulcore(A, B, C):
Expand Down Expand Up @@ -281,5 +283,119 @@ class TestGUVectorizeScalarParallel(TestGUVectorizeScalar):
target = 'parallel'


class TestGUVectorizePickling(TestCase):
def test_pickle_gufunc_non_dyanmic(self):
"""Non-dynamic gufunc.
"""
@guvectorize(["f8,f8[:]"], "()->()")
def double(x, out):
out[:] = x * 2

# pickle
ser = pickle.dumps(double)
cloned = pickle.loads(ser)

# attributes carried over
self.assertEqual(cloned._frozen, double._frozen)
self.assertEqual(cloned.identity, double.identity)
self.assertEqual(cloned.is_dynamic, double.is_dynamic)
self.assertEqual(cloned.gufunc_builder._sigs,
double.gufunc_builder._sigs)
# expected value of attributes
self.assertTrue(cloned._frozen)

cloned.disable_compile()
self.assertTrue(cloned._frozen)

# scalar version
self.assertPreciseEqual(double(0.5), cloned(0.5))
# array version
arr = np.arange(10)
self.assertPreciseEqual(double(arr), cloned(arr))

def test_pickle_gufunc_dyanmic_null_init(self):
"""Dynamic gufunc w/o prepopulating before pickling.
"""
@guvectorize("()->()", identity=1)
def double(x, out):
out[:] = x * 2

# pickle
ser = pickle.dumps(double)
cloned = pickle.loads(ser)

# attributes carried over
self.assertEqual(cloned._frozen, double._frozen)
self.assertEqual(cloned.identity, double.identity)
self.assertEqual(cloned.is_dynamic, double.is_dynamic)
self.assertEqual(cloned.gufunc_builder._sigs,
double.gufunc_builder._sigs)
# expected value of attributes
self.assertFalse(cloned._frozen)

# scalar version
expect = np.zeros(1)
got = np.zeros(1)
double(0.5, out=expect)
cloned(0.5, out=got)
self.assertPreciseEqual(expect, got)
# array version
arr = np.arange(10)
expect = np.zeros_like(arr)
got = np.zeros_like(arr)
double(arr, out=expect)
cloned(arr, out=got)
self.assertPreciseEqual(expect, got)

def test_pickle_gufunc_dyanmic_initialized(self):
"""Dynamic gufunc prepopulated before pickling.
Once unpickled, we disable compilation to verify that the gufunc
compilation state is carried over.
"""
@guvectorize("()->()", identity=1)
def double(x, out):
out[:] = x * 2

# prepopulate scalar
expect = np.zeros(1)
got = np.zeros(1)
double(0.5, out=expect)
# prepopulate array
arr = np.arange(10)
expect = np.zeros_like(arr)
got = np.zeros_like(arr)
double(arr, out=expect)

# pickle
ser = pickle.dumps(double)
cloned = pickle.loads(ser)

# attributes carried over
self.assertEqual(cloned._frozen, double._frozen)
self.assertEqual(cloned.identity, double.identity)
self.assertEqual(cloned.is_dynamic, double.is_dynamic)
self.assertEqual(cloned.gufunc_builder._sigs,
double.gufunc_builder._sigs)
# expected value of attributes
self.assertFalse(cloned._frozen)

# disable compilation
cloned.disable_compile()
self.assertTrue(cloned._frozen)
# scalar version
expect = np.zeros(1)
got = np.zeros(1)
double(0.5, out=expect)
cloned(0.5, out=got)
self.assertPreciseEqual(expect, got)
# array version
expect = np.zeros_like(arr)
got = np.zeros_like(arr)
double(arr, out=expect)
cloned(arr, out=got)
self.assertPreciseEqual(expect, got)


if __name__ == '__main__':
unittest.main()

0 comments on commit 5c1ff12

Please sign in to comment.