Skip to content

Commit

Permalink
Add auto-defined constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
sklam committed Jul 10, 2020
1 parent 3abf6ee commit 72a7d0f
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 17 deletions.
30 changes: 30 additions & 0 deletions numba/core/structref.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
unbox,
NativeValue,
intrinsic,
overload,
)
from numba.core.typing.templates import AttributeTemplate

Expand Down Expand Up @@ -147,6 +148,35 @@ def unbox_struct_ref(typ, obj, c):
return NativeValue(out)


def define_constructor(struct_typeclass, py_class, fields):

def _names(text):
return [name.strip() for name in text.split(',')]

params = ', '.join(fields)

indent = ' ' * 8
init_fields_buf = []
for k in fields:
init_fields_buf.append(f"st.{k} = {k}")
init_fields = f'\n{indent}'.join(init_fields_buf)

source = f"""
def ctor({params}):
struct_type = struct_typeclass(list(zip({list(fields)}, [{params}])))
def impl({params}):
st = new(struct_type)
{init_fields}
return st
return impl
"""

glbs = dict(struct_typeclass=struct_typeclass, new=new)
exec(source, glbs)
ctor = glbs['ctor']
overload(py_class)(ctor)


def register(struct_type):
default_manager.register(struct_type, models.StructRefModel)
define_attributes(struct_type)
Expand Down
2 changes: 1 addition & 1 deletion numba/core/types/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ class StructPayloadType(Type):
def __init__(self, typename, fields):
self._typename = typename
self._fields = tuple(fields)
super().__init__(name=f"mutstruct.payload.{typename}")
super().__init__(name=f"mutstruct.payload.{typename}.{self._fields}")

@property
def fields(self):
Expand Down
55 changes: 39 additions & 16 deletions numba/tests/test_struct_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@


class MyStruct(types.Type):
def __init__(self, typename, fields):
self._typename = typename
def __init__(self, fields):
self._fields = tuple(fields)
super().__init__(name=f"numba.structref.{self.__class__.__name__}")
classname = self.__class__.__name__
super().__init__(name=f"numba.structref.{classname}{self._fields}")

@property
def fields(self):
Expand All @@ -30,14 +30,14 @@ def field_dict(self):

def get_data_type(self):
return types.StructPayloadType(
typename=self._typename, fields=self._fields,
typename=self.__class__.__name__, fields=self._fields,
)


structref.register(MyStruct)

my_struct_ty = MyStruct(
"MyStruct", fields=[("values", types.intp[:]), ("counter", types.intp)]
fields=[("values", types.intp[:]), ("counter", types.intp)]
)


Expand All @@ -51,11 +51,17 @@ def _numba_type_(self):
return self._ty


def object_ctor(ty, mi):
def _my_struct_wrap_ctor(ty, mi):
return MyStructWrap(ty, mi)


structref.define_boxing(MyStruct, object_ctor)
structref.define_boxing(MyStruct, _my_struct_wrap_ctor)

structref.define_constructor(
lambda xs: MyStruct(fields=xs),
MyStructWrap,
['values', 'counter'],
)


@njit
Expand All @@ -72,32 +78,49 @@ def my_struct_init(self, values, counter):


@njit
def foo(vs, ctr):
def ctor_by_intrinsic(vs, ctr):
st = my_struct(vs, counter=ctr)
st.values += st.values
st.counter *= ctr
return st


@njit
def ctor_by_class(vs, ctr):
return MyStructWrap(values=vs, counter=ctr)


@njit
def get_values(st):
return st.values


@njit
def bar(st):
def compute_fields(st):
return st.values + st.counter


class TestStructRef(MemoryLeakMixin, TestCase):
def test_basic(self):
def test_ctor_by_intrinsic(self):
vs = np.arange(10, dtype=np.intp)
ctr = 10

foo_expected = vs + vs
foo_got = foo(vs, ctr)
self.assertPreciseEqual(foo_expected, get_values(foo_got))
first_expected = vs + vs
first_got = ctor_by_intrinsic(vs, ctr)
self.assertPreciseEqual(first_expected, get_values(first_got))

second_expected = first_expected + (ctr * ctr)
second_got = compute_fields(first_got)
self.assertPreciseEqual(second_expected, second_got)

def test_ctor_by_class(self):
vs = np.arange(10, dtype=np.float64)
ctr = 10

first_expected = vs.copy()
first_got = ctor_by_class(vs, ctr)
self.assertPreciseEqual(first_expected, get_values(first_got))

bar_expected = foo_expected + (ctr * ctr)
bar_got = bar(foo_got)
self.assertPreciseEqual(bar_expected, bar_got)
second_expected = first_expected + ctr
second_got = compute_fields(first_got)
self.assertPreciseEqual(second_expected, second_got)

0 comments on commit 72a7d0f

Please sign in to comment.