Skip to content

Commit

Permalink
dev(hansbug): fix MoductDict register
Browse files Browse the repository at this point in the history
  • Loading branch information
HansBug committed Oct 22, 2023
1 parent 4f9a9f1 commit a323641
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
9 changes: 0 additions & 9 deletions treevalue/tree/integration/jax.py
@@ -1,8 +1,6 @@
import warnings
from functools import wraps

from ..tree import register_dict_type

try:
import jax
from jax.tree_util import register_pytree_node
Expand All @@ -22,10 +20,3 @@ def register_for_jax(cls):

register_for_jax(TreeValue)
register_for_jax(FastTreeValue)

try:
from torch.nn import ModuleDict
except (ModuleNotFoundError, ImportError):
pass
else:
register_dict_type(ModuleDict, ModuleDict.items)
9 changes: 9 additions & 0 deletions treevalue/tree/integration/torch.py
@@ -1,6 +1,8 @@
import warnings
from functools import wraps

from ..tree import register_dict_type

try:
import torch
from torch.utils._pytree import _register_pytree_node
Expand All @@ -20,3 +22,10 @@ def register_for_torch(cls):

register_for_torch(TreeValue)
register_for_torch(FastTreeValue)

try:
from torch.nn import ModuleDict
except (ModuleNotFoundError, ImportError):
pass

Check warning on line 29 in treevalue/tree/integration/torch.py

View check run for this annotation

Codecov / codecov/patch

treevalue/tree/integration/torch.py#L28-L29

Added lines #L28 - L29 were not covered by tests
else:
register_dict_type(ModuleDict, ModuleDict.items)

0 comments on commit a323641

Please sign in to comment.