From a3236418b4463bc49dd433af581edb47969202a0 Mon Sep 17 00:00:00 2001 From: HansBug Date: Sun, 22 Oct 2023 21:50:58 +0800 Subject: [PATCH] dev(hansbug): fix MoductDict register --- treevalue/tree/integration/jax.py | 9 --------- treevalue/tree/integration/torch.py | 9 +++++++++ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/treevalue/tree/integration/jax.py b/treevalue/tree/integration/jax.py index e956bba60c..bb057622eb 100644 --- a/treevalue/tree/integration/jax.py +++ b/treevalue/tree/integration/jax.py @@ -1,8 +1,6 @@ import warnings from functools import wraps -from ..tree import register_dict_type - try: import jax from jax.tree_util import register_pytree_node @@ -22,10 +20,3 @@ def register_for_jax(cls): register_for_jax(TreeValue) register_for_jax(FastTreeValue) - -try: - from torch.nn import ModuleDict -except (ModuleNotFoundError, ImportError): - pass -else: - register_dict_type(ModuleDict, ModuleDict.items) diff --git a/treevalue/tree/integration/torch.py b/treevalue/tree/integration/torch.py index 98689398b5..ec5316d276 100644 --- a/treevalue/tree/integration/torch.py +++ b/treevalue/tree/integration/torch.py @@ -1,6 +1,8 @@ import warnings from functools import wraps +from ..tree import register_dict_type + try: import torch from torch.utils._pytree import _register_pytree_node @@ -20,3 +22,10 @@ def register_for_torch(cls): register_for_torch(TreeValue) register_for_torch(FastTreeValue) + +try: + from torch.nn import ModuleDict +except (ModuleNotFoundError, ImportError): + pass +else: + register_dict_type(ModuleDict, ModuleDict.items)