Skip to content

fix(dataclass): Create proper __init__ method for PyClass #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions python/mlc/dataclasses/py_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,10 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> type[ClsType]:
)
setattr(type_cls, "_mlc_structure", struct)

# Step 5. Attach methods
# Step 5. Add `__init__` method
type_add_method(type_index, "__init__", _method_new(type_cls), 1) # static
# Step 6. Attach methods
fn: Callable[..., typing.Any]
type_add_method(type_index, "__init__", type_cls, 1) # static
if init:
fn = method_init(super_type_cls, d_fields)
attach_method(super_type_cls, type_cls, "__init__", fn, check_exists=True)
Expand Down Expand Up @@ -185,3 +186,14 @@ def method(self: ClsType) -> str:
return f"{type_key}({', '.join(fields)})"

return method


def _method_new(
type_cls: type[ClsType],
) -> Callable[..., ClsType]:
def method(*args: typing.Any) -> ClsType:
obj = type_cls.__new__(type_cls)
obj._mlc_init(*args) # type: ignore[attr-defined]
return obj

return method
84 changes: 43 additions & 41 deletions tests/python/test_dataclasses_copy.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,34 @@
import copy
from typing import Any, Optional

import mlc
import pytest
from mlc.testing.dataclasses import PyClassForTest


@mlc.py_class
class PyClassForTest(mlc.PyClass):
bool_: bool
i64: int
f64: float
raw_ptr: mlc.Ptr
dtype: mlc.DataType
device: mlc.Device
any: Any
func: mlc.Func
ulist: list[Any]
udict: dict
str_: str
###
list_any: list[Any]
list_list_int: list[list[int]]
dict_any_any: dict[Any, Any]
dict_str_any: dict[str, Any]
dict_any_str: dict[Any, str]
dict_str_list_int: dict[str, list[int]]
###
opt_bool: Optional[bool]
opt_i64: Optional[int]
opt_f64: Optional[float]
opt_raw_ptr: Optional[mlc.Ptr]
opt_dtype: Optional[mlc.DataType]
opt_device: Optional[mlc.Device]
opt_func: Optional[mlc.Func]
opt_ulist: Optional[list]
opt_udict: Optional[dict[Any, Any]]
opt_str: Optional[str]
###
opt_list_any: Optional[list[Any]]
opt_list_list_int: Optional[list[list[int]]]
opt_dict_any_any: Optional[dict]
opt_dict_str_any: Optional[dict[str, Any]]
opt_dict_any_str: Optional[dict[Any, str]]
opt_dict_str_list_int: Optional[dict[str, list[int]]]
@mlc.py_class(init=False)
class CustomInit(mlc.PyClass):
a: int
b: str

def i64_plus_one(self) -> int:
return self.i64 + 1
def __init__(self, *, b: str, a: int) -> None:
self.a = a
self.b = b


@pytest.fixture
def test_obj() -> CustomInit:
return CustomInit(a=1, b="hello")


@pytest.fixture
def mlc_class_for_test() -> PyClassForTest:
return PyClassForTest(
bool_=True,
i8=8,
i16=16,
i32=32,
i64=64,
f32=2,
f64=2.5,
raw_ptr=mlc.Ptr(0xDEADBEEF),
dtype="float8",
Expand All @@ -62,6 +38,7 @@ def mlc_class_for_test() -> PyClassForTest:
ulist=[1, 2.0, "three", lambda: 4],
udict={"1": 1, "2": 2.0, "3": "three", "4": lambda: 4},
str_="world",
str_readonly="world",
###
list_any=[1, 2.0, "three", lambda: 4],
list_list_int=[[1, 2, 3], [4, 5, 6]],
Expand Down Expand Up @@ -95,7 +72,11 @@ def test_copy_shallow(mlc_class_for_test: PyClassForTest) -> None:
dst = copy.copy(src)
assert src != dst
assert src.bool_ == dst.bool_
assert src.i8 == dst.i8
assert src.i16 == dst.i16
assert src.i32 == dst.i32
assert src.i64 == dst.i64
assert src.f32 == dst.f32
assert src.f64 == dst.f64
assert src.raw_ptr.value == dst.raw_ptr.value
assert src.dtype == dst.dtype
Expand Down Expand Up @@ -133,7 +114,12 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None:
src = mlc_class_for_test
dst = copy.deepcopy(src)
assert src != dst
assert src.bool_ == dst.bool_
assert src.i8 == dst.i8
assert src.i16 == dst.i16
assert src.i32 == dst.i32
assert src.i64 == dst.i64
assert src.f32 == dst.f32
assert src.f64 == dst.f64
assert src.raw_ptr.value == dst.raw_ptr.value
assert src.dtype == dst.dtype
Expand Down Expand Up @@ -268,3 +254,19 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None:
and tuple(src.opt_dict_str_list_int["1"]) == tuple(dst.opt_dict_str_list_int["1"]) # type: ignore[index]
and tuple(src.opt_dict_str_list_int["2"]) == tuple(dst.opt_dict_str_list_int["2"]) # type: ignore[index]
)


def test_copy_shallow_dataclass(test_obj: CustomInit) -> None:
src = test_obj
dst = copy.copy(src)
assert src != dst
assert src.a == dst.a
assert src.b == dst.b


def test_copy_deep_dataclass(test_obj: CustomInit) -> None:
src = test_obj
dst = copy.deepcopy(src)
assert src != dst
assert src.a == dst.a
assert src.b == dst.b
Loading