Skip to content

Commit

Permalink
setup
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Nov 21, 2018
1 parent f2254da commit baa3533
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions jdit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __call__(self, *args, **kwargs):
def __getattr__(self, item):
return getattr(self.model, item)

def define(self, proto_model: Module, gpu_ids: Union[list, tuple], init_method: Union[str, FunctionType, None] ,
def define(self, proto_model: Module, gpu_ids: Union[list, tuple], init_method: Union[str, FunctionType, None],
show_structure: bool):
"""Define and wrap a pytorch module, according to CPU, GPU and multi-GPUs.
Expand Down Expand Up @@ -356,7 +356,7 @@ def _extract_module(self, data_parallel_model: DataParallel, extract_weights=Tru
weights = self._fix_weights(weights)
return model, weights

def _fix_weights(self, weights: OrderedDict):
def _fix_weights(self, weights: Union[dict, OrderedDict]):
# fix params' key
from collections import OrderedDict
new_state_dict = OrderedDict()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
'Intended Audience :: Developers',
"Programming Language :: Python :: 3 :: Only",
'Programming Language :: Python :: Implementation',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
],
# 此项需要,否则卸载时报windows error
zip_safe=False
Expand Down

0 comments on commit baa3533

Please sign in to comment.