Skip to content

Commit

Permalink
Merge pull request #7388 from sklam/fix/iss7356
Browse files Browse the repository at this point in the history
Patch cloudpickle to not reset dynamic class each time it is unpickled
  • Loading branch information
sklam committed Sep 22, 2021
1 parent fed8949 commit d8aa6ce
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 3 deletions.
6 changes: 6 additions & 0 deletions numba/cloudpickle/cloudpickle.py
Expand Up @@ -93,6 +93,7 @@ def g():
_DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary()
_DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary()
_DYNAMIC_CLASS_TRACKER_LOCK = threading.Lock()
_DYNAMIC_CLASS_TRACKER_REUSING = weakref.WeakSet()

PYPY = platform.python_implementation() == "PyPy"

Expand All @@ -117,9 +118,14 @@ def _get_or_create_tracker_id(class_def):
def _lookup_class_or_track(class_tracker_id, class_def):
if class_tracker_id is not None:
with _DYNAMIC_CLASS_TRACKER_LOCK:
orig_class_def = class_def
class_def = _DYNAMIC_CLASS_TRACKER_BY_ID.setdefault(
class_tracker_id, class_def)
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
# Check if we are reusing a previous class_def
if orig_class_def is not class_def:
# Remember the class_def is being reused
_DYNAMIC_CLASS_TRACKER_REUSING.add(class_def)
return class_def


Expand Down
5 changes: 5 additions & 0 deletions numba/cloudpickle/cloudpickle_fast.py
Expand Up @@ -36,6 +36,7 @@
parametrized_type_hint_getinitargs, _create_parametrized_type_hint,
builtin_code_type,
_make_dict_keys, _make_dict_values, _make_dict_items,
_DYNAMIC_CLASS_TRACKER_REUSING,
)


Expand Down Expand Up @@ -460,6 +461,10 @@ def _function_setstate(obj, state):


def _class_setstate(obj, state):
# Check if class is being reused and needs bypass setstate logic.
if obj in _DYNAMIC_CLASS_TRACKER_REUSING:
return obj

state, slotstate = state
registry = None
for attrname, attr in state.items():
Expand Down
6 changes: 6 additions & 0 deletions numba/tests/cloudpickle_main_class.py
@@ -0,0 +1,6 @@
# Expected to run this module as __main__


# Cloudpickle will think this is a dynamic class when this module is __main__
class Klass:
classvar = None
122 changes: 119 additions & 3 deletions numba/tests/test_serialize.py
@@ -1,14 +1,18 @@
import contextlib
import gc
import pickle
import runpy
import subprocess
import sys
import unittest
from multiprocessing import get_context

import numba
from numba.core.errors import TypingError
from numba.tests.support import TestCase, tag
from .serialize_usecases import *
import unittest
from numba.tests.support import TestCase
from numba.core.target_extension import resolve_dispatcher_from_str
from numba.cloudpickle import dumps, loads
from .serialize_usecases import *


class TestDispatcherPickling(TestCase):
Expand Down Expand Up @@ -205,5 +209,117 @@ def test_numba_unpickle(self):
self.assertIs(got1, got2)



class TestCloudPickleIssues(TestCase):
"""This test case includes issues specific to the cloudpickle implementation.
"""
_numba_parallel_test_ = False

def test_dynamic_class_reset_on_unpickle(self):
# a dynamic class
class Klass:
classvar = None

def mutator():
Klass.classvar = 100

def check():
self.assertEqual(Klass.classvar, 100)

saved = dumps(Klass)
mutator()
check()
loads(saved)
# Without the patch, each `loads(saved)` will reset `Klass.classvar`
check()
loads(saved)
check()

@unittest.skipIf(__name__ == "__main__",
"Test cannot run as when module is __main__")
def test_main_class_reset_on_unpickle(self):
mp = get_context('spawn')
proc = mp.Process(target=check_main_class_reset_on_unpickle)
proc.start()
proc.join(timeout=10)
self.assertEqual(proc.exitcode, 0)

def test_dynamic_class_reset_on_unpickle_new_proc(self):
# a dynamic class
class Klass:
classvar = None

# serialize Klass in this process
saved = dumps(Klass)

# Check the reset problem in a new process
mp = get_context('spawn')
proc = mp.Process(target=check_unpickle_dyn_class_new_proc, args=(saved,))
proc.start()
proc.join(timeout=10)
self.assertEqual(proc.exitcode, 0)

def test_dynamic_class_issue_7356(self):
cfunc = numba.njit(issue_7356)
self.assertEqual(cfunc(), (100, 100))


class DynClass(object):
# For testing issue #7356
a = None


def issue_7356():
with numba.objmode(before="intp"):
DynClass.a = 100
before = DynClass.a
with numba.objmode(after="intp"):
after = DynClass.a
return before, after


def check_main_class_reset_on_unpickle():
# Load module and get its global dictionary
glbs = runpy.run_module(
"numba.tests.cloudpickle_main_class",
run_name="__main__",
)
# Get the Klass and check it is from __main__
Klass = glbs['Klass']
assert Klass.__module__ == "__main__"
assert Klass.classvar != 100
saved = dumps(Klass)
# mutate
Klass.classvar = 100
# check
_check_dyn_class(Klass, saved)


def check_unpickle_dyn_class_new_proc(saved):
Klass = loads(saved)
assert Klass.classvar != 100
# mutate
Klass.classvar = 100
# check
_check_dyn_class(Klass, saved)


def _check_dyn_class(Klass, saved):
def check():
if Klass.classvar != 100:
raise AssertionError("Check failed. Klass reset.")

check()
loaded = loads(saved)
if loaded is not Klass:
raise AssertionError("Expected reuse")
# Without the patch, each `loads(saved)` will reset `Klass.classvar`
check()
loaded = loads(saved)
if loaded is not Klass:
raise AssertionError("Expected reuse")
check()


if __name__ == '__main__':
unittest.main()

0 comments on commit d8aa6ce

Please sign in to comment.