Skip to content

Commit

Permalink
Fix os.environ in super and model
Browse files Browse the repository at this point in the history
  • Loading branch information
308188605@qq.com committed Jun 1, 2019
1 parent 4c273b9 commit 1125c65
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 13 deletions.
21 changes: 12 additions & 9 deletions jdit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ def __init__(self, proto_model: Module,
show_structure=False,
check_point_pos=None, verbose=True):

if not isinstance(proto_model, Module):
raise TypeError(
"The type of `proto_model` must be `torch.nn.Module`, but got %s instead" % type(proto_model))
# if not isinstance(proto_model, Module):
# raise TypeError(
# "The type of `proto_model` must be `torch.nn.Module`, but got %s instead" % type(proto_model))
self.model: Union[DataParallel, Module] = None
self.model_name = proto_model.__class__.__name__
self.weights_init = None
Expand Down Expand Up @@ -404,23 +404,26 @@ def _fix_weights(weights: Union[dict, OrderedDict], fix_type: str = "remove", is
def _set_device(self, proto_model: Module, gpu_ids_abs: list) -> Union[Module, DataParallel]:
if not gpu_ids_abs:
gpu_ids_abs = []
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpu_ids_abs])
gpu_ids = [i for i in range(len(gpu_ids_abs))]
# old_enviroment = os.environ["CUDA_VISIBLE_DEVICES"]
# os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpu_ids_abs])
# gpu_ids = [i for i in range(len(gpu_ids_abs))]
gpu_available = torch.cuda.is_available()
model_name = proto_model.__class__.__name__
if len(gpu_ids) == 1:

if len(gpu_ids_abs) == 1:
if not gpu_available:
raise EnvironmentError("No gpu available! torch.cuda.is_available() is False. "
"CUDA_VISIBLE_DEVICES=%s" % \
os.environ["CUDA_VISIBLE_DEVICES"])
proto_model = proto_model.cuda(gpu_ids[0])

proto_model = proto_model.cuda(gpu_ids_abs[0])
self._print("%s model use GPU %s!" % (model_name, gpu_ids_abs))
elif len(gpu_ids) > 1:
elif len(gpu_ids_abs) > 1:
if not gpu_available:
raise EnvironmentError("No gpu available! torch.cuda.is_available() is False. "
"CUDA_VISIBLE_DEVICES=%s" % \
os.environ["CUDA_VISIBLE_DEVICES"])
proto_model = DataParallel(proto_model.cuda(), gpu_ids)
proto_model = DataParallel(proto_model.cuda(gpu_ids_abs[0]), gpu_ids_abs)
self._print("%s dataParallel use GPUs%s!" % (model_name, gpu_ids_abs))
else:
self._print("%s model use CPU!" % model_name)
Expand Down
7 changes: 4 additions & 3 deletions jdit/trainer/super.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,16 @@ def __new__(cls, *args, **kwargs):
return instance

def __init__(self, nepochs: int, logdir: str, gpu_ids_abs: Union[list, tuple] = ()):
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpu_ids_abs])
self.gpu_ids = [i for i in range(len(gpu_ids_abs))]
# os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpu_ids_abs])
# self.gpu_ids = [i for i in range(len(gpu_ids_abs))]
self.gpu_ids = gpu_ids_abs
self.logdir = logdir
self.performance = Performance(gpu_ids_abs)
self.watcher = Watcher(logdir)
self.loger = Loger(logdir)

self.use_gpu = True if (len(self.gpu_ids) > 0) and torch.cuda.is_available() else False
self.device = torch.device("cuda") if self.use_gpu else torch.device("cpu")
self.device = torch.device("cuda:%d" % self.gpu_ids[0]) if self.use_gpu else torch.device("cpu")
self.input = torch.Tensor()
self.ground_truth = torch.Tensor()
self.nepochs = nepochs
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="jdit", # pypi中的名称,pip或者easy_install安装时使用的名称,或生成egg文件的名称
version="0.0.15",
version="0.0.17",
author="Guanglei Ding",
author_email="dingguanglei.bupt@qq.com",
maintainer='Guanglei Ding',
Expand Down

0 comments on commit 1125c65

Please sign in to comment.