From 1795db005470f4610bcd2d544708fdd3c73387c4 Mon Sep 17 00:00:00 2001 From: HansBug Date: Fri, 13 Oct 2023 13:29:17 +0800 Subject: [PATCH] dev(hansbug): add register custom dicts --- test/tree/general/base.py | 44 +++++++++++++++++++++++++++ test/tree/tree/base.py | 45 +++++++++++++++++++++++++++- treevalue/tree/integration/jax.py | 9 ++++++ treevalue/tree/tree/__init__.py | 2 +- treevalue/tree/tree/tree.pxd | 1 + treevalue/tree/tree/tree.pyx | 50 +++++++++++++++++++++++-------- 6 files changed, 136 insertions(+), 15 deletions(-) diff --git a/test/tree/general/base.py b/test/tree/general/base.py index 169e9257c3..7da685e0b8 100644 --- a/test/tree/general/base.py +++ b/test/tree/general/base.py @@ -1,3 +1,4 @@ +import collections.abc import unittest from functools import reduce from operator import __mul__ @@ -7,10 +8,25 @@ import pytest from hbutils.testing import cmdv +from treevalue import register_dict_type from treevalue.tree import func_treelize, TreeValue, raw, mapping, delayed, FastTreeValue from ..tree.base import get_treevalue_test +class CustomMapping(collections.abc.Mapping): + def __init__(self, **kwargs): + self._kwargs = kwargs + + def __getitem__(self, __key): + return self._kwargs[__key] + + def __len__(self): + return len(self._kwargs) + + def __iter__(self): + yield from self._kwargs + + def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]): class Container: def __init__(self, value): @@ -813,4 +829,32 @@ def test_unpack(self): assert y == pytest.approx(7.7) assert z is None + def test_init_with_custom_mapping_type(self): + origin_t = CustomMapping(a=1, b=2, c={'x': 15, 'y': CustomMapping(z=100)}) + t = treevalue_class(origin_t) + assert t == treevalue_class({'a': 1, 'b': 2, 'c': {'x': 15, 'y': {'z': 100}}}) + + def test_init_with_custom_type(self): + class _CustomMapping: + def __init__(self, **kwargs): + self._kwargs = kwargs + + def __getitem__(self, __key): + return self._kwargs[__key] + + def __len__(self): + return len(self._kwargs) + + def __iter__(self): + yield from self._kwargs + + def iter_items(self): + yield from self._kwargs.items() + + register_dict_type(_CustomMapping, _CustomMapping.iter_items) + + origin_t = _CustomMapping(a=1, b=2, c={'x': 15, 'y': _CustomMapping(z=100)}) + t = treevalue_class(origin_t) + assert t == treevalue_class({'a': 1, 'b': 2, 'c': {'x': 15, 'y': {'z': 100}}}) + return _TestClass diff --git a/test/tree/tree/base.py b/test/tree/tree/base.py index c67a9c6f56..6b1913391c 100644 --- a/test/tree/tree/base.py +++ b/test/tree/tree/base.py @@ -1,3 +1,4 @@ +import collections.abc import pickle import re import unittest @@ -7,7 +8,7 @@ from hbutils.testing import OS, cmdv from test.tree.tree.test_constraint import GreaterThanConstraint -from treevalue import raw, TreeValue, delayed, ValidationError +from treevalue import raw, TreeValue, delayed, ValidationError, register_dict_type from treevalue.tree.common import create_storage from treevalue.tree.tree.constraint import cleaf @@ -43,6 +44,20 @@ def __hash__(self): return hash((self.__value,)) +class CustomMapping(collections.abc.Mapping): + def __init__(self, **kwargs): + self._kwargs = kwargs + + def __getitem__(self, __key): + return self._kwargs[__key] + + def __len__(self): + return len(self._kwargs) + + def __iter__(self): + yield from self._kwargs + + def get_treevalue_test(treevalue_class: Type[TreeValue]): # noinspection DuplicatedCode,PyMethodMayBeStatic @@ -787,4 +802,32 @@ def test_repr_jpeg(self): assert min_size <= len(_repr_jpeg_) <= max_size, \ f'Size within [{min_size!r}, {max_size!r}] required, but {len(_repr_jpeg_)!r} found.' + def test_init_with_custom_mapping_type(self): + origin_t = CustomMapping(a=1, b=2, c={'x': 15, 'y': CustomMapping(z=100)}) + t = treevalue_class(origin_t) + assert t == treevalue_class({'a': 1, 'b': 2, 'c': {'x': 15, 'y': {'z': 100}}}) + + def test_init_with_custom_type(self): + class _CustomMapping: + def __init__(self, **kwargs): + self._kwargs = kwargs + + def __getitem__(self, __key): + return self._kwargs[__key] + + def __len__(self): + return len(self._kwargs) + + def __iter__(self): + yield from self._kwargs + + def iter_items(self): + yield from self._kwargs.items() + + register_dict_type(_CustomMapping, _CustomMapping.iter_items) + + origin_t = _CustomMapping(a=1, b=2, c={'x': 15, 'y': _CustomMapping(z=100)}) + t = treevalue_class(origin_t) + assert t == treevalue_class({'a': 1, 'b': 2, 'c': {'x': 15, 'y': {'z': 100}}}) + return _TestClass diff --git a/treevalue/tree/integration/jax.py b/treevalue/tree/integration/jax.py index bb057622eb..e956bba60c 100644 --- a/treevalue/tree/integration/jax.py +++ b/treevalue/tree/integration/jax.py @@ -1,6 +1,8 @@ import warnings from functools import wraps +from ..tree import register_dict_type + try: import jax from jax.tree_util import register_pytree_node @@ -20,3 +22,10 @@ def register_for_jax(cls): register_for_jax(TreeValue) register_for_jax(FastTreeValue) + +try: + from torch.nn import ModuleDict +except (ModuleNotFoundError, ImportError): + pass +else: + register_dict_type(ModuleDict, ModuleDict.items) diff --git a/treevalue/tree/tree/__init__.py b/treevalue/tree/tree/__init__.py index 461edb95f1..63c7a7c155 100644 --- a/treevalue/tree/tree/__init__.py +++ b/treevalue/tree/tree/__init__.py @@ -5,4 +5,4 @@ from .io import loads, load, dumps, dump from .service import jsonify, clone, typetrans, walk from .structural import subside, union, rise -from .tree import TreeValue, delayed, ValidationError +from .tree import TreeValue, delayed, ValidationError, register_dict_type diff --git a/treevalue/tree/tree/tree.pxd b/treevalue/tree/tree/tree.pxd index 446e154ed1..47af63565a 100644 --- a/treevalue/tree/tree/tree.pxd +++ b/treevalue/tree/tree/tree.pxd @@ -14,6 +14,7 @@ cdef class _SimplifiedConstraintProxy: cdef readonly Constraint cons cdef Constraint _c_get_constraint(object cons) +cpdef register_dict_type(object type_, object f_items) cdef class ValidationError(Exception): cdef readonly TreeValue _object diff --git a/treevalue/tree/tree/tree.pyx b/treevalue/tree/tree/tree.pyx index 5a744c9a51..427fc6ef5c 100644 --- a/treevalue/tree/tree/tree.pyx +++ b/treevalue/tree/tree/tree.pyx @@ -31,24 +31,47 @@ cdef inline object _item_unwrap(object v): return v _GET_NO_DEFAULT = SingletonMark('get_no_default') +_KNOWN_DICT_TYPES = { + Mapping: Mapping.items, +} -cdef inline TreeStorage _dict_unpack(dict d): +cpdef inline register_dict_type(object type_, object f_items): + if isinstance(object, type): + _KNOWN_DICT_TYPES[type_] = f_items + else: + raise TypeError(f'Not a type - {type_!r}.') + +_DEFAULT_STORAGE = create_storage({}) + +cdef inline TreeStorage _generic_dict_unpack(object d): cdef str k cdef object v cdef dict result = {} - for k, v in d.items(): + cdef object d_items = None + if isinstance(d, dict): + d_items = d.items() + else: + for d_type, df_items in _KNOWN_DICT_TYPES.items(): + if isinstance(d, d_type): + d_items = df_items(d) + break + if d_items is None: + raise TypeError(f'Unknown dict type - {d!r}.') + + for k, v in d_items: if isinstance(v, dict): - result[k] = _dict_unpack(v) - elif isinstance(v, TreeValue): + result[k] = _generic_dict_unpack(v) + if isinstance(v, TreeValue): result[k] = v._detach() else: - result[k] = v + try: + result[k] = _generic_dict_unpack(v) + except TypeError: + result[k] = v return create_storage(result) -_DEFAULT_STORAGE = create_storage({}) - cdef class _SimplifiedConstraintProxy: def __cinit__(self, Constraint cons): self.cons = cons @@ -131,13 +154,14 @@ cdef class TreeValue: self._child_constraints = data._child_constraints else: self.constraint = _c_get_constraint(constraint) - elif isinstance(data, dict): - self._st = _dict_unpack(data) - self.constraint = _c_get_constraint(constraint) else: - raise TypeError( - "Unknown initialization type for tree value - {type}.".format( - type=repr(type(data).__name__))) + try: + self._st = _generic_dict_unpack(data) + self.constraint = _c_get_constraint(constraint) + except TypeError: + raise TypeError( + "Unknown initialization type for tree value - {type}.".format( + type=repr(type(data).__name__))) def __getnewargs_ex__(self): # for __cinit__, when pickle.loads return ({},), {}