Preserve class __module__ of a jitclass instance #9042

Open
egormkn opened this issue Jun 27, 2023 · 1 comment

Comments

egormkn commented Jun 27, 2023

Feature request

I want to use a custom Dask serializer/deserializer with a jitclass, but this is currently not possible because calling type() on a jitclass instance returns a type object whose __module__ is set to numba.experimental.jitclass.boxing.

from numba import float64
from numba.experimental import jitclass

@jitclass([("value", float64)])
class Example:
    def __init__(self, value):
        self.value = value

instance = Example(0)

print("Class type:", Example)
# Class type: <class '__main__.Example'>

print("Instance type:", type(instance))
# Instance type: <class 'numba.experimental.jitclass.boxing.Example'>

If I register a serializer for the class type, as shown in the Dask documentation, Dask cannot find it when serializing an instance, because it dispatches on the instance type. It then falls back to pickle, which fails:

from distributed.protocol import dask_deserialize, dask_serialize, deserialize, serialize

@dask_serialize.register(Example)
def serialize_example(example: Example) -> tuple[dict, list[bytes]]:
    args = (example.value,)
    return serialize(args)


@dask_deserialize.register(Example)
def deserialize_example(header: dict, frames: list[bytes]) -> Example:
    args = deserialize(header, frames)
    return Example(*args)


instance = Example(0)
serialize(instance)

distributed.protocol.pickle - ERROR - Failed to serialize <numba.experimental.jitclass.boxing.Example object at 0x7fe4ed1cdcc0>.
Traceback (most recent call last):
  File "/home/egor/.cache/pypoetry/virtualenvs/forctool-5SrdLVXS-py3.10/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 63, in dumps
    result = pickle.dumps(x, **dump_kwargs)
TypeError: cannot pickle 'Example' object
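
The lookup miss can be reproduced without Numba or Dask. The following is a minimal sketch, assuming the registry is keyed on the type of the object being serialized and on its base classes. The names `registry`, `register`, `dispatch`, `_Boxed`, and `_BoxingMeta` are hypothetical stand-ins, not Dask's or Numba's actual implementation.

```python
# Simplified class-based registry (an assumption, not Dask's real code).
registry = {}


def register(cls):
    """Register a serializer function for the given class."""
    def decorator(func):
        registry[cls] = func
        return func
    return decorator


def dispatch(typ):
    """Return the serializer registered for typ or one of its bases."""
    for base in typ.__mro__:
        if base in registry:
            return registry[base]
    raise TypeError(f"No dispatch for {typ}")


class _Boxed:
    """Hypothetical stand-in for the type that instances actually have."""
    def __init__(self, value):
        self.value = value


class _BoxingMeta(type):
    # Calling the class returns an instance of a different type,
    # which mimics the behaviour described in this issue.
    def __call__(cls, *args):
        return _Boxed(*args)


class Example(metaclass=_BoxingMeta):
    pass


@register(Example)
def serialize_example(example):
    return (example.value,)


instance = Example(0)

try:
    dispatch(type(instance))
    found = True
except TypeError:
    # type(instance) is _Boxed, which was never registered.
    found = False

print("Serializer found for instance type:", found)
```

In this sketch the serializer is registered for `Example`, but the lookup uses `type(instance)`, so it never sees the registration.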

If I instead register the serializer for the instance type, Dask fails to pickle that type object because the numba.experimental.jitclass.boxing module has no attribute with that name:

_pickle.PicklingError: Can't pickle <class 'numba.experimental.jitclass.boxing.Example'>: attribute lookup Example on numba.experimental.jitclass.boxing failed
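
This error comes from the way pickle serializes classes: by reference, as a module name plus an attribute name that must resolve back to the same object. A minimal stdlib-only sketch follows. The module `fake_boxing` and the class `BoxedExample` are hypothetical stand-ins for numba.experimental.jitclass.boxing and the instance type.

```python
import pickle
import sys
import types

# Hypothetical stand-in for numba.experimental.jitclass.boxing.
fake_boxing = types.ModuleType("fake_boxing")
sys.modules["fake_boxing"] = fake_boxing

# A dynamically created class that claims to live in fake_boxing,
# although the module has no attribute named "Example".
BoxedExample = type("Example", (), {"__module__": "fake_boxing"})

try:
    pickle.dumps(BoxedExample)
    pickled = True
except pickle.PicklingError:
    # pickle could not find fake_boxing.Example, so it refuses to
    # serialize the class by reference.
    pickled = False

print("Pickled:", pickled)
```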

Workaround

I managed to get it working by manually injecting the class into the numba.experimental.jitclass.boxing module and using a modified version of the Dask serializer that recovers the original class type, which allows Dask to dispatch correctly:

import pickle
from importlib import import_module
from typing import Type

from dask.utils import has_keyword, typename
from distributed.protocol.serialize import dask_deserialize, dask_serialize, register_serialization_family
from numba.core import errors
from numba.extending import as_numba_type


def dask_dumps(x, context=None):
    """Serialize object using the class-based registry"""
    typ = type(x)
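    # Instances report the boxing module as their home; swap in the class that
    # fix_dask_serialization registered there under the same name.
    if typ.__module__ == "numba.experimental.jitclass.boxing":
        typ = getattr(import_module(typ.__module__), typ.__name__)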
    type_name = typename(typ)
    try:
        dumps = dask_serialize.dispatch(typ)
    except TypeError:
        raise NotImplementedError(type_name)
    if has_keyword(dumps, "context"):
        sub_header, frames = dumps(x, context=context)
    else:
        sub_header, frames = dumps(x)

    header = {
        "sub-header": sub_header,
        "type": type_name,
        "type-serialized": pickle.dumps(typ),
        "serializer": "dask",
    }
    return header, frames


def dask_loads(header, frames):
    typ = pickle.loads(header["type-serialized"])
    loads = dask_deserialize.dispatch(typ)
    return loads(header["sub-header"], frames)


register_serialization_family("dask", dask_dumps, dask_loads)


def check_dask_serializable(x):
    if type(x) in (list, set, tuple) and len(x):
        return check_dask_serializable(next(iter(x)))
    elif type(x) is dict and len(x):
        return check_dask_serializable(next(iter(x.items()))[1])
    else:
        try:
            typ = type(x)
            if typ.__module__ == "numba.experimental.jitclass.boxing":
                typ = getattr(import_module(typ.__module__), typ.__name__)
            dask_serialize.dispatch(typ)
            return True
        except TypeError:
            pass
    return False


# Monkey-patch distributed so that it uses the jitclass-aware check defined above.
setattr(import_module("distributed.protocol.serialize"), "check_dask_serializable", check_dask_serializable)


def fix_dask_serialization(cls: Type) -> Type:
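    # Reject anything that Numba does not recognize as a type (e.g. a plain class).
    try:
        as_numba_type(cls)
    except errors.TypingError as err:
        raise ValueError(f"{cls} is not a Numba type") from err

    # Expose the class under the module name that its instances report, so that
    # getattr(boxing_module, cls.__name__) resolves to it.
    setattr(import_module("numba.experimental.jitclass.boxing"), cls.__name__, cls)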

    return cls

Usage:

from distributed.protocol.serialize import deserialize, serialize
from numba import float64
from numba.experimental import jitclass

@fix_dask_serialization
@jitclass([("value", float64)])
class Example:
    def __init__(self, value):
        self.value = value

instance = Example(0)

assert deserialize(*serialize(instance)).value == instance.value
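
The core of the workaround, mapping the instance type back to the registered class through a module attribute, can be sketched with the standard library alone. Here `fake_boxing`, `BoxedExample`, `register_alias`, and `resolve` are hypothetical names used only for illustration.

```python
import sys
import types
from importlib import import_module

# Hypothetical stand-in for numba.experimental.jitclass.boxing.
fake_boxing = types.ModuleType("fake_boxing")
sys.modules["fake_boxing"] = fake_boxing


class Example:
    """The class the user defines and registers serializers for."""


# What type(instance) would report: same name, different module and object.
BoxedExample = type("Example", (), {"__module__": "fake_boxing"})


def register_alias(cls):
    """Expose cls in fake_boxing under its own name."""
    setattr(import_module("fake_boxing"), cls.__name__, cls)
    return cls


def resolve(typ):
    """Map a boxed instance type back to the class registered under its name."""
    if typ.__module__ == "fake_boxing":
        return getattr(import_module(typ.__module__), typ.__name__)
    return typ


register_alias(Example)
print(resolve(BoxedExample) is Example)
```

One limitation of this approach, as far as I can tell, is that the lookup is keyed only on the class name, so two classes with the same name defined in different modules would overwrite each other in the shared module.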

However, I think this should be fixed in Numba itself by passing the real module path from the jitclass decorator to the instance type.

esc (Member) commented Jun 29, 2023

@egormkn thank you, I have labelled it accordingly.
