New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] lowering of constant jitclass #5111
Conversation
@luk-f-a thanks for submitting this, your efforts to improve Numba are appreciated as always. There are some remaining flake8 issues to be resolved:
FWIW: I have been able to largely eliminate such nuisances using the pre-commit hooks as described here: https://numba.pydata.org/numba-doc/dev/developer/contributing.html#coding-conventions -- maybe it could be helpful for you too? |
hi @esc , thanks! sorry, I should have said at the top that this is very early wip. I don't even know if this is the right way to lower a constant jitclass. It's also missing tests and flake8 polish. |
@luk-f-a no problem, I will label as 'in progress' for now. Please just ping us as usual if you get it into a reviewable state, thanks! |
Below is our attempt at lowering constant jitclasses. It is based on the unboxing code from
from llvmlite import ir
from numba import cgutils, jit, jitclass, types
from numba.targets.imputils import lower_constant
from numba.jitclass import _box
@lower_constant(types.ClassInstanceType)
def _lower_constant_class_instance(context, builder, typ, pyval):
def access_member(obj, member_offset):
# Access member by byte offset
offset = context.get_constant(types.uintp, member_offset)
llvoidptr = ir.IntType(8).as_pointer()
ptr = cgutils.pointer_add(builder, obj, offset)
casted = builder.bitcast(ptr, llvoidptr.as_pointer())
return builder.load(casted)
struct_cls = cgutils.create_struct_proxy(typ)
inst = struct_cls(context, builder)
# get a pointer to pyval
obj = context.add_dynamic_addr(builder, id(pyval), '')
# load from Python object
ptr_meminfo = access_member(obj, _box.box_meminfoptr_offset)
ptr_dataptr = access_member(obj, _box.box_dataptr_offset)
# store to native structure
inst.meminfo = builder.bitcast(ptr_meminfo, inst.meminfo.type)
inst.data = builder.bitcast(ptr_dataptr, inst.data.type)
return inst._getvalue()
if __name__ == '__main__':
spec = [
('value', types.float64)
]
@jitclass(spec)
class TestClass:
def __init__(self, value):
self.value = value
const = TestClass(42.)
@jit
def get_const_jitclass():
return const
const2 = get_const_jitclass()
assert const.value == const2.value |
bump... I've run into this as well, more than a few times. Any chance of this getting merged? |
@nelson2005 thank you for asking about this. I believe this PR or the suggested patch will need some more work (specifically: tests) before it can be scheduled for review. |
since structref is the future of jitclasses, I'm closing this PR. Anyone needing a work around for the current jitclass can use the one provided above. |
What's the timeframe for for jitclasses meeting their structref incarnation? |
closes #3781
While working on numba-scipy, I wanted to test using jitclasses to implement statistical distributions. However, to achieve jit transparency, already instantiated scipy distributions must be available within jitted code (without being passed as arguments). This happens because Scipy.stats api is based not on classes that the user instantiates as needed, but but pre-instantiating an object (for each distribution!) and directing the user to call on the methods on that object.
I re-used the code from the constructor of the jitclass, but I don't know how to copy the attributes from
pyval
into the newly created structure.EDIT: This is early WIP in order to gather feedback. I am not even sure this is the right way to lower a constant jitclass. Once the main approach is solid, I'll create tests and fix the flake8 warnings.