From 1795db005470f4610bcd2d544708fdd3c73387c4 Mon Sep 17 00:00:00 2001
From: HansBug
Date: Fri, 13 Oct 2023 13:29:17 +0800
Subject: [PATCH 1/7] dev(hansbug): add register custom dicts

---
 test/tree/general/base.py         | 44 ++++++++++++++++++++++++++++
 test/tree/tree/base.py            | 45 ++++++++++++++++++++++++++++-
 treevalue/tree/integration/jax.py |  9 ++++++
 treevalue/tree/tree/__init__.py   |  2 +-
 treevalue/tree/tree/tree.pxd      |  1 +
 treevalue/tree/tree/tree.pyx      | 48 +++++++++++++++++++++++--------
 6 files changed, 135 insertions(+), 14 deletions(-)

diff --git a/test/tree/general/base.py b/test/tree/general/base.py
index 169e9257c3..7da685e0b8 100644
--- a/test/tree/general/base.py
+++ b/test/tree/general/base.py
@@ -1,3 +1,4 @@
+import collections.abc
 import unittest
 from functools import reduce
 from operator import __mul__
@@ -7,10 +8,25 @@
 import pytest
 from hbutils.testing import cmdv
 
+from treevalue import register_dict_type
 from treevalue.tree import func_treelize, TreeValue, raw, mapping, delayed, FastTreeValue
 from ..tree.base import get_treevalue_test
 
 
+class CustomMapping(collections.abc.Mapping):
+    def __init__(self, **kwargs):
+        self._kwargs = kwargs
+
+    def __getitem__(self, __key):
+        return self._kwargs[__key]
+
+    def __len__(self):
+        return len(self._kwargs)
+
+    def __iter__(self):
+        yield from self._kwargs
+
+
 def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]):
     class Container:
         def __init__(self, value):
@@ -813,4 +829,32 @@ def test_unpack(self):
             assert y == pytest.approx(7.7)
             assert z is None
 
+        def test_init_with_custom_mapping_type(self):
+            origin_t = CustomMapping(a=1, b=2, c={'x': 15, 'y': CustomMapping(z=100)})
+            t = treevalue_class(origin_t)
+            assert t == treevalue_class({'a': 1, 'b': 2, 'c': {'x': 15, 'y': {'z': 100}}})
+
+        def test_init_with_custom_type(self):
+            class _CustomMapping:
+                def __init__(self, **kwargs):
+                    self._kwargs = kwargs
+
+                def __getitem__(self, __key):
+                    return self._kwargs[__key]
+
+                def __len__(self):
+                    return len(self._kwargs)
+
+                def __iter__(self):
+                    yield from self._kwargs
+
+                def iter_items(self):
+                    yield from self._kwargs.items()
+
+            register_dict_type(_CustomMapping, _CustomMapping.iter_items)
+
+            origin_t = _CustomMapping(a=1, b=2, c={'x': 15, 'y': _CustomMapping(z=100)})
+            t = treevalue_class(origin_t)
+            assert t == treevalue_class({'a': 1, 'b': 2, 'c': {'x': 15, 'y': {'z': 100}}})
+
     return _TestClass
diff --git a/test/tree/tree/base.py b/test/tree/tree/base.py
index c67a9c6f56..6b1913391c 100644
--- a/test/tree/tree/base.py
+++ b/test/tree/tree/base.py
@@ -1,3 +1,4 @@
+import collections.abc
 import pickle
 import re
 import unittest
@@ -7,7 +8,7 @@
 from hbutils.testing import OS, cmdv
 
 from test.tree.tree.test_constraint import GreaterThanConstraint
-from treevalue import raw, TreeValue, delayed, ValidationError
+from treevalue import raw, TreeValue, delayed, ValidationError, register_dict_type
 from treevalue.tree.common import create_storage
 from treevalue.tree.tree.constraint import cleaf
 
@@ -43,6 +44,20 @@ def __hash__(self):
         return hash((self.__value,))
 
 
+class CustomMapping(collections.abc.Mapping):
+    def __init__(self, **kwargs):
+        self._kwargs = kwargs
+
+    def __getitem__(self, __key):
+        return self._kwargs[__key]
+
+    def __len__(self):
+        return len(self._kwargs)
+
+    def __iter__(self):
+        yield from self._kwargs
+
+
 def get_treevalue_test(treevalue_class: Type[TreeValue]):
 
     # noinspection DuplicatedCode,PyMethodMayBeStatic
@@ -787,4 +802,32 @@ def test_repr_jpeg(self):
             assert min_size <= len(_repr_jpeg_) <= max_size, \
                 f'Size within [{min_size!r}, {max_size!r}] required, but {len(_repr_jpeg_)!r} found.'
 
+        def test_init_with_custom_mapping_type(self):
+            origin_t = CustomMapping(a=1, b=2, c={'x': 15, 'y': CustomMapping(z=100)})
+            t = treevalue_class(origin_t)
+            assert t == treevalue_class({'a': 1, 'b': 2, 'c': {'x': 15, 'y': {'z': 100}}})
+
+        def test_init_with_custom_type(self):
+            class _CustomMapping:
+                def __init__(self, **kwargs):
+                    self._kwargs = kwargs
+
+                def __getitem__(self, __key):
+                    return self._kwargs[__key]
+
+                def __len__(self):
+                    return len(self._kwargs)
+
+                def __iter__(self):
+                    yield from self._kwargs
+
+                def iter_items(self):
+                    yield from self._kwargs.items()
+
+            register_dict_type(_CustomMapping, _CustomMapping.iter_items)
+
+            origin_t = _CustomMapping(a=1, b=2, c={'x': 15, 'y': _CustomMapping(z=100)})
+            t = treevalue_class(origin_t)
+            assert t == treevalue_class({'a': 1, 'b': 2, 'c': {'x': 15, 'y': {'z': 100}}})
+
     return _TestClass
diff --git a/treevalue/tree/integration/jax.py b/treevalue/tree/integration/jax.py
index bb057622eb..e956bba60c 100644
--- a/treevalue/tree/integration/jax.py
+++ b/treevalue/tree/integration/jax.py
@@ -1,6 +1,8 @@
 import warnings
 from functools import wraps
 
+from ..tree import register_dict_type
+
 try:
     import jax
     from jax.tree_util import register_pytree_node
@@ -20,3 +22,10 @@ def register_for_jax(cls):
 
     register_for_jax(TreeValue)
     register_for_jax(FastTreeValue)
+
+try:
+    from torch.nn import ModuleDict
+except (ModuleNotFoundError, ImportError):
+    pass
+else:
+    register_dict_type(ModuleDict, ModuleDict.items)
diff --git a/treevalue/tree/tree/__init__.py b/treevalue/tree/tree/__init__.py
index 461edb95f1..63c7a7c155 100644
--- a/treevalue/tree/tree/__init__.py
+++ b/treevalue/tree/tree/__init__.py
@@ -5,4 +5,4 @@
 from .io import loads, load, dumps, dump
 from .service import jsonify, clone, typetrans, walk
 from .structural import subside, union, rise
-from .tree import TreeValue, delayed, ValidationError
+from .tree import TreeValue, delayed, ValidationError, register_dict_type
diff --git a/treevalue/tree/tree/tree.pxd b/treevalue/tree/tree/tree.pxd
index 446e154ed1..47af63565a 100644
--- a/treevalue/tree/tree/tree.pxd
+++ b/treevalue/tree/tree/tree.pxd
@@ -14,6 +14,7 @@ cdef class _SimplifiedConstraintProxy:
     cdef readonly Constraint cons
 
 cdef Constraint _c_get_constraint(object cons)
+cpdef register_dict_type(object type_, object f_items)
 
 cdef class ValidationError(Exception):
     cdef readonly TreeValue _object
diff --git a/treevalue/tree/tree/tree.pyx b/treevalue/tree/tree/tree.pyx
index 5a744c9a51..427fc6ef5c 100644
--- a/treevalue/tree/tree/tree.pyx
+++ b/treevalue/tree/tree/tree.pyx
@@ -31,24 +31,47 @@ cdef inline object _item_unwrap(object v):
     return v
 
 _GET_NO_DEFAULT = SingletonMark('get_no_default')
+_KNOWN_DICT_TYPES = {
+    Mapping: Mapping.items,
+}
 
-cdef inline TreeStorage _dict_unpack(dict d):
+cpdef inline register_dict_type(object type_, object f_items):
+    if isinstance(type_, type):
+        _KNOWN_DICT_TYPES[type_] = f_items
+    else:
+        raise TypeError(f'Not a type - {type_!r}.')
+
+_DEFAULT_STORAGE = create_storage({})
+
+cdef inline TreeStorage _generic_dict_unpack(object d):
     cdef str k
     cdef object v
     cdef dict result = {}
 
-    for k, v in d.items():
+    cdef object d_items = None
+    if isinstance(d, dict):
+        d_items = d.items()
+    else:
+        for d_type, df_items in _KNOWN_DICT_TYPES.items():
+            if isinstance(d, d_type):
+                d_items = df_items(d)
+                break
+        if d_items is None:
+            raise TypeError(f'Unknown dict type - {d!r}.')
+
+    for k, v in d_items:
         if isinstance(v, dict):
-            result[k] = _dict_unpack(v)
+            result[k] = _generic_dict_unpack(v)
         elif isinstance(v, TreeValue):
             result[k] = v._detach()
         else:
-            result[k] = v
+            try:
+                result[k] = _generic_dict_unpack(v)
+            except TypeError:
+                result[k] = v
 
     return create_storage(result)
 
-_DEFAULT_STORAGE = create_storage({})
-
 cdef class _SimplifiedConstraintProxy:
     def __cinit__(self, Constraint cons):
         self.cons = cons
@@ -131,13 +154,14 @@ cdef class TreeValue:
                 self._child_constraints = data._child_constraints
             else:
                 self.constraint = _c_get_constraint(constraint)
-        elif isinstance(data, dict):
-            self._st = _dict_unpack(data)
-            self.constraint = _c_get_constraint(constraint)
         else:
-            raise TypeError(
-                "Unknown initialization type for tree value - {type}.".format(
-                    type=repr(type(data).__name__)))
+            try:
+                self._st = _generic_dict_unpack(data)
+            except TypeError:
+                raise TypeError(
+                    "Unknown initialization type for tree value - {type}.".format(
+                        type=repr(type(data).__name__)))
+            self.constraint = _c_get_constraint(constraint)
 
     def __getnewargs_ex__(self):  # for __cinit__, when pickle.loads
         return ({},), {}

From c6a49deebf043ec9f64500279039701b03e363c0 Mon Sep 17 00:00:00 2001
From: HansBug
Date: Fri, 20 Oct 2023 12:18:54 +0800
Subject: [PATCH 2/7] dev(hansbug): limit the versions of torch and enum_tools

---
 requirements-doc.txt        | 2 +-
 requirements-test-extra.txt | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements-doc.txt b/requirements-doc.txt
index cfa7f693be..88151b1d49 100644
--- a/requirements-doc.txt
+++ b/requirements-doc.txt
@@ -1,7 +1,7 @@
 Jinja2~=3.0.0
 sphinx~=3.2.0
 sphinx_rtd_theme~=0.4.3
-enum_tools
+enum_tools~=0.9.0
 sphinx-toolbox
 plantumlcli>=0.0.4
 packaging
diff --git a/requirements-test-extra.txt b/requirements-test-extra.txt
index 7744e14635..11a83aedc7 100644
--- a/requirements-test-extra.txt
+++ b/requirements-test-extra.txt
@@ -1,2 +1,2 @@
 jax[cpu]>=0.3.25; platform_system != 'Windows'
-torch>=1.1.0; python_version < '3.12'
+torch>=1.1.0,<2.1.0; python_version < '3.12'

From a6dd66f8414fd69bcb611c48c68fc9c0604f0dd8 Mon Sep 17 00:00:00 2001
From: HansBug
Date: Fri, 20 Oct 2023 13:42:57 +0800
Subject: [PATCH 3/7] dev(hansbug): timeout set to 60 on unittest

---
 pytest.ini | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytest.ini b/pytest.ini
index 0731c72161..07837139d0 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,5 +1,5 @@
 [pytest]
-timeout = 20
+timeout = 60
 markers =
     unittest
     benchmark

From 1a23d12bdef4f369ec3b4651700353628e67d7a9 Mon Sep 17 00:00:00 2001
From: HansBug
Date: Fri, 20 Oct 2023 14:33:52 +0800
Subject: [PATCH 4/7] dev(hansbug): add docs for it

---
 docs/source/api_doc/tree/tree.rst   |  8 ++++++++
 test/tree/integration/test_torch.py | 17 +++++++++++++++++
 treevalue/tree/tree/tree.pyx        | 13 +++++++++++++
 3 files changed, 38 insertions(+)

diff --git a/docs/source/api_doc/tree/tree.rst b/docs/source/api_doc/tree/tree.rst
index 414ca935f8..a50ce50905 100644
--- a/docs/source/api_doc/tree/tree.rst
+++ b/docs/source/api_doc/tree/tree.rst
@@ -220,3 +220,11 @@ loads
 
 .. autofunction:: loads
 
+
+register_dict_type
+----------------------------
+
+.. autofunction:: register_dict_type
+
+
+
diff --git a/test/tree/integration/test_torch.py b/test/tree/integration/test_torch.py
index a9a0a8e1cb..2ce35a0b78 100644
--- a/test/tree/integration/test_torch.py
+++ b/test/tree/integration/test_torch.py
@@ -161,3 +161,20 @@ def forward(self, x):
             'c': torch.Size([14]),
             'd': torch.Size([2, 5, 3]),
         })
+
+    @skipUnless(vpip('torch') and OS.linux and vpython < '3.11', 'torch required')
+    def test_moduledict(self):
+        with torch.no_grad():
+            md = torch.nn.ModuleDict({
+                'a': torch.nn.Linear(3, 5),
+                'b': torch.nn.Linear(3, 6),
+            })
+            t = FastTreeValue(md)
+
+            input_ = torch.randn(2, 3)
+            output_ = t(input_)
+
+            assert output_.shape == FastTreeValue({
+                'a': (2, 5),
+                'b': (2, 6),
+            })
diff --git a/treevalue/tree/tree/tree.pyx b/treevalue/tree/tree/tree.pyx
index 427fc6ef5c..9aea3c5c58 100644
--- a/treevalue/tree/tree/tree.pyx
+++ b/treevalue/tree/tree/tree.pyx
@@ -35,7 +35,20 @@ _KNOWN_DICT_TYPES = {
     Mapping: Mapping.items,
 }
 
+@cython.binding(True)
 cpdef inline register_dict_type(object type_, object f_items):
+    """
+    Overview:
+        Register dict-like type for TreeValue.
+
+    :param type_: Type to register.
+    :param f_items: Function to get items, the format should be like ``dict.items``.
+    :raises TypeError: If ``type_`` is not a type.
+
+    .. note::
+        If torch is detected, the ``torch.nn.ModuleDict`` is registered by default.
+
+    """
     if isinstance(type_, type):
         _KNOWN_DICT_TYPES[type_] = f_items
     else:

From 4f9a9f1a0037b166c681b47373045f459fa0cedb Mon Sep 17 00:00:00 2001
From: HansBug
Date: Fri, 20 Oct 2023 15:28:22 +0800
Subject: [PATCH 5/7] dev(hansbug): fix bug in unittest

---
 test/tree/integration/test_torch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/tree/integration/test_torch.py b/test/tree/integration/test_torch.py
index 2ce35a0b78..d56da148ed 100644
--- a/test/tree/integration/test_torch.py
+++ b/test/tree/integration/test_torch.py
@@ -162,7 +162,7 @@ def forward(self, x):
             'd': torch.Size([2, 5, 3]),
         })
 
-    @skipUnless(vpip('torch') and OS.linux and vpython < '3.11', 'torch required')
+    @skipUnless(torch is not None and vpip('torch') and OS.linux and vpython < '3.11', 'torch required')
     def test_moduledict(self):
         with torch.no_grad():
             md = torch.nn.ModuleDict({

From a3236418b4463bc49dd433af581edb47969202a0 Mon Sep 17 00:00:00 2001
From: HansBug
Date: Sun, 22 Oct 2023 21:50:58 +0800
Subject: [PATCH 6/7] dev(hansbug): fix MoductDict register

---
 treevalue/tree/integration/jax.py   | 9 ---------
 treevalue/tree/integration/torch.py | 9 +++++++++
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/treevalue/tree/integration/jax.py b/treevalue/tree/integration/jax.py
index e956bba60c..bb057622eb 100644
--- a/treevalue/tree/integration/jax.py
+++ b/treevalue/tree/integration/jax.py
@@ -1,8 +1,6 @@
 import warnings
 from functools import wraps
 
-from ..tree import register_dict_type
-
 try:
     import jax
     from jax.tree_util import register_pytree_node
@@ -22,10 +20,3 @@ def register_for_jax(cls):
 
     register_for_jax(TreeValue)
     register_for_jax(FastTreeValue)
-
-try:
-    from torch.nn import ModuleDict
-except (ModuleNotFoundError, ImportError):
-    pass
-else:
-    register_dict_type(ModuleDict, ModuleDict.items)
diff --git a/treevalue/tree/integration/torch.py b/treevalue/tree/integration/torch.py
index 98689398b5..ec5316d276 100644
--- a/treevalue/tree/integration/torch.py
+++ b/treevalue/tree/integration/torch.py
@@ -1,6 +1,8 @@
 import warnings
 from functools import wraps
 
+from ..tree import register_dict_type
+
 try:
     import torch
     from torch.utils._pytree import _register_pytree_node
@@ -20,3 +22,10 @@ def register_for_torch(cls):
 
     register_for_torch(TreeValue)
     register_for_torch(FastTreeValue)
+
+try:
+    from torch.nn import ModuleDict
+except (ModuleNotFoundError, ImportError):
+    pass
+else:
+    register_dict_type(ModuleDict, ModuleDict.items)

From ab350d69dda1fb4400637b3a55e35cf2248e93ba Mon Sep 17 00:00:00 2001
From: HansBug
Date: Sun, 22 Oct 2023 21:58:28 +0800
Subject: [PATCH 7/7] dev(hansbug): move CustomMapping for test.testings

---
 test/testings/__init__.py |  1 +
 test/testings/mapping.py  | 15 +++++++++++++++
 test/tree/general/base.py | 14 +-------------
 test/tree/tree/base.py    | 16 +---------------
 4 files changed, 18 insertions(+), 28 deletions(-)
 create mode 100644 test/testings/__init__.py
 create mode 100644 test/testings/mapping.py

diff --git a/test/testings/__init__.py b/test/testings/__init__.py
new file mode 100644
index 0000000000..4937e98b70
--- /dev/null
+++ b/test/testings/__init__.py
@@ -0,0 +1 @@
+from .mapping import CustomMapping
diff --git a/test/testings/mapping.py b/test/testings/mapping.py
new file mode 100644
index 0000000000..4826bf9ab4
--- /dev/null
+++ b/test/testings/mapping.py
@@ -0,0 +1,15 @@
+import collections.abc
+
+
+class CustomMapping(collections.abc.Mapping):
+    def __init__(self, **kwargs):
+        self._kwargs = kwargs
+
+    def __getitem__(self, __key):
+        return self._kwargs[__key]
+
+    def __len__(self):
+        return len(self._kwargs)
+
+    def __iter__(self):
+        yield from self._kwargs
diff --git a/test/tree/general/base.py b/test/tree/general/base.py
index 7da685e0b8..2d8b93b16c 100644
--- a/test/tree/general/base.py
+++ b/test/tree/general/base.py
@@ -11,21 +11,9 @@
 from treevalue import register_dict_type
 from treevalue.tree import func_treelize, TreeValue, raw, mapping, delayed, FastTreeValue
 from ..tree.base import get_treevalue_test
+from ...testings import CustomMapping
 
 
-class CustomMapping(collections.abc.Mapping):
-    def __init__(self, **kwargs):
-        self._kwargs = kwargs
-
-    def __getitem__(self, __key):
-        return self._kwargs[__key]
-
-    def __len__(self):
-        return len(self._kwargs)
-
-    def __iter__(self):
-        yield from self._kwargs
-
 
 def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]):
     class Container:
diff --git a/test/tree/tree/base.py b/test/tree/tree/base.py
index 6b1913391c..19a593b9e2 100644
--- a/test/tree/tree/base.py
+++ b/test/tree/tree/base.py
@@ -1,4 +1,3 @@
-import collections.abc
 import pickle
 import re
 import unittest
@@ -11,6 +10,7 @@
 from treevalue import raw, TreeValue, delayed, ValidationError, register_dict_type
 from treevalue.tree.common import create_storage
 from treevalue.tree.tree.constraint import cleaf
+from ...testings import CustomMapping
 
 try:
     _ = reversed({}.keys())
@@ -44,20 +44,6 @@ def __hash__(self):
         return hash((self.__value,))
 
 
-class CustomMapping(collections.abc.Mapping):
-    def __init__(self, **kwargs):
-        self._kwargs = kwargs
-
-    def __getitem__(self, __key):
-        return self._kwargs[__key]
-
-    def __len__(self):
-        return len(self._kwargs)
-
-    def __iter__(self):
-        yield from self._kwargs
-
-
 def get_treevalue_test(treevalue_class: Type[TreeValue]):
 
     # noinspection DuplicatedCode,PyMethodMayBeStatic