
Commit

bug fixes
dingguanglei committed Dec 3, 2018
1 parent cd19884 commit 04c3102
Showing 9 changed files with 332 additions and 196 deletions.
2 changes: 1 addition & 1 deletion generate_trainer.py
@@ -115,7 +115,7 @@ def valid_epoch(self):
norm_type="batch",
active_type="LeakyReLU")
G = Model(G_net, gpu_ids_abs=gpus, init_method="kaiming")
G.load_model()
G.load_weights()
print('===> Building optimizer')
opt_D = Optimizer(D.parameters(), lr, lr_decay, weight_decay, momentum, betas, opt_D_name)
opt_G = Optimizer(G.parameters(), lr, lr_decay, weight_decay, momentum, betas, opt_G_name)
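For orientation, a minimal sketch of the renamed weights API as this script uses it (the checkpoint path is illustrative, not from the repo; ``save_weights`` and ``load_weights`` are defined in jdit/model.py below):

G.save_weights("log/checkpoint/Weights_G_10.pth")  # writes a state dict, stripping `module.` prefixes by default
G.load_weights("log/checkpoint/Weights_G_10.pth")  # loads it back, fixing prefixes to match the wrapped model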
194 changes: 92 additions & 102 deletions jdit/model.py
@@ -32,7 +32,7 @@ class Model(object):
r"""A warapper of pytorch ``module`` .
In the simplest case, we use a raw pytorch ``module`` to assemble a ``Model`` of this class.
It can be more convenient to use some feature methods, such as ``check_point`` , ``load_model`` and so on.
It can be more convenient to use some feature methods, such as ``check_point`` , ``load_weights`` and so on.
* :attr:`proto_model` is the core model in this class.
It is not necessary to pass a ``module`` when you init a ``Model`` .
@@ -91,30 +91,25 @@ class Model(object):
"""

def __init__(self, proto_model: Module = None, gpu_ids_abs: Union[list, tuple] = (),
def __init__(self, proto_model: Module, gpu_ids_abs: Union[list, tuple] = (),
init_method: Union[str, FunctionType, None] = "kaiming",
show_structure=False, verbose=True):
if not gpu_ids_abs:
gpu_ids_abs = []
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpu_ids_abs])
self.gpu_ids = [i for i in range(len(gpu_ids_abs))]
self.model: Union[DataParallel, Module] = None
# self.model_name :str= None
self.model_name: str = "Model"
self.weights_init = None
self.init_fc = None
self.init_name: str = None
self.num_params: int = 0
self.verbose = verbose
if proto_model is not None:
self.define(proto_model, self.gpu_ids, init_method, show_structure)
self.define(proto_model, gpu_ids_abs, init_method, show_structure)

def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)

def __getattr__(self, item):
return getattr(self.model, item)
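Both hooks delegate to the wrapped network, so a ``Model`` behaves like the underlying ``Module``. A sketch (``net`` as in the ``__main__`` demo below; the input shape is illustrative):

x = torch.randn(4, 10, 32, 32)
y = net(x)                 # __call__ forwards to self.model(x)
params = net.parameters()  # __getattr__ resolves missing attributes on self.model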

def define(self, proto_model: Module, gpu_ids: Union[list, tuple], init_method: Union[str, FunctionType, None],
def define(self, proto_model: Module, gpu_ids_abs: Union[list, tuple], init_method: Union[str, FunctionType, None],
show_structure: bool):
"""Define and wrap a pytorch module, according to CPU, GPU and multi-GPUs.
@@ -125,12 +120,13 @@ def define(self, proto_model: Module, gpu_ids: Union[list, tuple], init_method:
* Apply weight init method.
:param proto_model: Network, type of ``module``.
:param gpu_ids: IDs of the GPUs to use, type of ``tuple`` or ``list``. Pass ``()`` to run on CPU.
:param gpu_ids_abs: Absolute IDs of the GPUs to use, type of ``tuple`` or ``list``. Pass ``()`` to run on CPU.
:param init_method: weight init method (such as ``"kaiming"``), or ``False`` to skip weight init.
"""
self.num_params = self.print_network(proto_model, show_structure)
self.model = self._set_device(proto_model, gpu_ids)
self.init_name = self._apply_weight_init(init_method, proto_model)
assert isinstance(proto_model, Module)
self.num_params, self.model_name = self.print_network(proto_model, show_structure)
self.model = self._set_device(proto_model, gpu_ids_abs)
self.init_name = self._apply_weight_init(init_method, self.model)
self._print("apply %s weight init!" % self.init_name)

def print_network(self, proto_model: Module, show_structure=False):
@@ -146,12 +142,9 @@ def print_network(self, proto_model: Module, show_structure=False):
self._print(str(proto_model))
num_params_log = '%s Total number of parameters: %d' % (model_name, num_params)
self._print(num_params_log)
return num_params
return num_params, model_name

def load_model(self, model_or_path: Union[Module, DataParallel, str, None] = None,
weights_or_path: Union[OrderedDict, str, None] = None,
gpu_ids: list = (),
strict=True):
def load_weights(self, weights: Union[OrderedDict, dict, str], strict=True):
"""Assemble a model and weights from paths or passing parameters.
You can load a model from a file, passing parameters or both.
@@ -171,50 +164,33 @@ def load_model(self, model_or_path: Union[Module, DataParallel, str, None] = Non
ResNet Total number of parameters: 11689512
ResNet model use CPU!
apply kaiming weight init!
>>> resnet.save_model("model.pth", "weights.pth", True)
>>> resnet.save_weights("model.pth", "weights.pth", True)
move to cpu...
>>> resnet_load = Model()
>>> # only load module structure
>>> resnet_load.load_model("model.pth", None)
>>> resnet_load.load_weights("model.pth", None)
ResNet model use CPU!
>>> # only load weights
>>> resnet_load.load_model(None, "weights.pth")
>>> resnet_load.load_weights(None, "weights.pth")
ResNet model use CPU!
>>> # load both
>>> resnet_load.load_model("model.pth", "weights.pth")
>>> resnet_load.load_weights("model.pth", "weights.pth")
ResNet model use CPU!
"""

assert self.model or model_or_path, "You must use `self.define()` or pass a model to load."
# if isinstance(model_or_path, str):

if model_or_path:
if isinstance(model_or_path, Module):
model = model_or_path.cpu()
elif isinstance(model_or_path, DataParallel):
model, _ = self._extract_module(model_or_path.cpu(), extract_weights=False)
else:
model = load(model_or_path, map_location=lambda storage, loc: storage)
if isinstance(weights, str):
weights = load(weights, map_location=lambda storage, loc: storage)
else:
# model_or_path is None
model = self.model.cpu()
if isinstance(model, DataParallel):
model, _ = self._extract_module(model, extract_weights=False)

if weights_or_path:
if isinstance(weights_or_path, dict):
weights = weights_or_path
elif isinstance(weights_or_path, str):
weights = load(weights_or_path, map_location=lambda storage, loc: storage)
else:
raise TypeError("`weights_or_path` must be a `dict` or a path of weights file, such as '.pth'")
weights = self._fix_weights(weights)
model.load_state_dict(weights, strict=strict)

self.model = self._set_device(model, gpu_ids)
raise TypeError("`weights` must be a `dict` or a path of weights file.")
if isinstance(self.model, DataParallel):
self._print("Try to add `moudle.` to keys of weights dict")
weights = self._fix_weights(weights, "add", False)
else:
self._print("Try to remove `moudle.` to keys of weights dict")
weights = self._fix_weights(weights, "remove", False)
self.model.load_state_dict(weights, strict=strict)
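For example, restoring weights from a file (a sketch; the path is illustrative):

net = Model(resnet18(), init_method=None)
net.load_weights("weights.pth", strict=True)  # fixes `module.` prefixes to match the wrapped model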

def save_model(self, model_path: str = None, weights_path: str = None, to_cpu=False):
def save_weights(self, weights_path: str, fix_weights=True):
"""Save a model and weights to files.
You can save a model, weights or both to file.
@@ -237,52 +213,50 @@ def save_model(self, model_path: str = None, weights_path: str = None, to_cpu=Fa
ResNet Total number of parameters: 11689512
ResNet model use CPU!
apply kaiming weight init!
>>> model.save_model("model.pth", "weights.pth")
>>> model.save_weights("model.pth", "weights.pth")
>>> #you have had the model. Only get weights from path.
>>> model.load_model(None, "weights.pth")
>>> model.load_weights(None, "weights.pth")
ResNet model use CPU!
>>> model.load_model("model.pth", None)
>>> model.load_weights("model.pth", None)
ResNet model use CPU!
"""
assert self.model is not None, "Model.model is `None`. You need to `define` a model before you save it."
if to_cpu:
if fix_weights:
import copy
model = copy.deepcopy(self.model).cpu()
weights = model.state_dict()
print("move to cpu...")
if hasattr(self.model, "module"):
print("extract `module` from `DataParallel`...")
model, weights = self._extract_module(self.model.cpu(), True)
weights = copy.deepcopy(self.model.state_dict())
self._print("try to remove 'module.' in keys of weights dict...")
weights = self._fix_weights(weights, "remove", False)
else:
model = self.model
weights = self.model.state_dict()

if weights_path:
save(weights, weights_path)

if model_path:
save(model, model_path)
save(weights, weights_path)
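A save/load round trip under the new API (a sketch; the path is illustrative):

net.save_weights("weights.pth")  # state dict written with `module.` prefixes stripped
net.load_weights("weights.pth")  # prefixes are re-added automatically if the model is a DataParallel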

def load_point(self, model_name: str, epoch: int, logdir="log"):
"""load model and weights from a certain checkpoint.
this method is cooperate with method `self.chechPoint()`
"""
dir = os.path.join(logdir, "checkpoint")
if logdir.endswith("checkpoint"):
dir = logdir
else:
dir = os.path.join(logdir, "checkpoint")

model_weights_path = os.path.join(dir, "Weights_%s_%d.pth" % (model_name, epoch))
model_path = os.path.join(dir, "Model_%s_%d.pth" % (model_name, epoch))
self.load_model(model_path, model_weights_path)

self.load_weights(model_weights_path, True)

def check_point(self, model_name: str, epoch: int, logdir="log"):
dir = os.path.join(logdir, "checkpoint")
if logdir.endswith("checkpoint"):
dir = logdir
else:
dir = os.path.join(logdir, "checkpoint")

if not os.path.exists(dir):
os.makedirs(dir)

model_weights_path = os.path.join(dir, "Weights_%s_%d.pth" % (model_name, epoch))
model_path = os.path.join(dir, "Model_%s_%d.pth" % (model_name, epoch))
save(self.model.state_dict(), model_weights_path)
save(self.model, model_path)
weights = self._fix_weights(self.model.state_dict(), "remove", False)  # try to remove 'module.' prefix in keys.
save(weights, model_weights_path)
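Together with ``load_point`` above, this fixes the checkpoint naming convention. Typical use (model name and epoch are illustrative):

net.check_point("G", epoch=10, logdir="log")  # writes log/checkpoint/Weights_G_10.pth
net.load_point("G", epoch=10, logdir="log")   # reads the same file back via load_weights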

def count_params(self, proto_model: Module):
"""count the total parameters of model.
@@ -295,19 +269,6 @@ def count_params(self, proto_model: Module):
num_params += param.numel()
return num_params
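A quick sanity check of the count (a sketch, assuming ``from torch.nn import Linear``):

net = Model(Linear(10, 2), init_method=None)
print(net.num_params)  # 22 = 10 * 2 weights + 2 biases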

def reset_device(self, gpu_ids_abs: list = None):
assert self.model is not None, "You must have a `model` before you reset device!"

if gpu_ids_abs is None:
gpu_ids_abs = []
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpu_ids_abs])
self.gpu_ids = [i for i in range(len(gpu_ids_abs))]
if isinstance(self.model, DataParallel):
proto_model = self.model.module
else:
proto_model = self.model
self.define(proto_model, self.gpu_ids, None, False)

def _apply_weight_init(self, init_method: Union[str, FunctionType], proto_model: Module):
init_name = "No"
if init_method:
@@ -348,32 +309,50 @@ def _weight_init(self, m):
else:
pass

def _extract_module(self, data_parallel_model: DataParallel, extract_weights=True):
self._print("from `DataParallel` extract `module`...")
model: Module = data_parallel_model.module
weights = self.model.state_dict()
if extract_weights:
weights = self._fix_weights(weights)
return model, weights
# def _extract_module(self, data_parallel_model: DataParallel, extract_weights=True):
# self._print("from `DataParallel` extract `module`...")
# model: Module = data_parallel_model.module
# weights = self.model.state_dict()
# if extract_weights:
# weights = self._fix_weights(weights)
# return model, weights

def _fix_weights(self, weights: Union[dict, OrderedDict]):
def _fix_weights(self, weights: Union[dict, OrderedDict], fix_type: str = "remove", is_strict=True):
# fix params' key
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in weights.items():
name = k.replace("module.", "", 1) # remove `module.`
k: str
if fix_type == "remove":
if is_strict:
                    assert k.startswith(
                        "module."), "The key of weights dict doesn't start with 'module.'. Got %s instead" % k
name = k.replace("module.", "", 1) # remove `module.`
elif fix_type == "add":
if is_strict:
assert not k.startswith("module."), "The key of weights dict is %s. Can not add 'module.'" % k
if not k.startswith("module."):
name = "module.".join(k) # add `module.`
else:
raise TypeError("`fix_type` should be 'remove' or 'add'.")
new_state_dict[name] = v
return new_state_dict
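Concretely, the two fix directions map keys like this (keys illustrative; ``net`` is any defined ``Model``):

w = net._fix_weights({"module.conv.weight": torch.zeros(1)}, "remove", False)
list(w.keys())  # ['conv.weight']
w = net._fix_weights({"conv.weight": torch.zeros(1)}, "add", False)
list(w.keys())  # ['module.conv.weight']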

def _set_device(self, proto_model: Module, gpu_ids: list):
def _set_device(self, proto_model: Module, gpu_ids_abs: list) -> Union[Module, DataParallel]:
if not gpu_ids_abs:
gpu_ids_abs = []
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpu_ids_abs])
gpu_ids = [i for i in range(len(gpu_ids_abs))]
gpu_available = torch.cuda.is_available()
model_name = proto_model.__class__.__name__
if (len(gpu_ids) == 1):
assert gpu_available, "No gpu available! torch.cuda.is_available() is False"
assert gpu_available, "No gpu available! torch.cuda.is_available() is False. CUDA_VISIBLE_DEVICES=%s" % \
os.environ["CUDA_VISIBLE_DEVICES"]
proto_model = proto_model.cuda(gpu_ids[0])
self._print("%s model use GPU(%d)!" % (model_name, gpu_ids[0]))
elif (len(gpu_ids) > 1):
assert gpu_available, "No gpu available! torch.cuda.is_available() is False"
assert gpu_available, "No gpu available! torch.cuda.is_available() is False. CUDA_VISIBLE_DEVICES=%s" % \
os.environ["CUDA_VISIBLE_DEVICES"]
proto_model = DataParallel(proto_model.cuda(), gpu_ids)
self._print("%s dataParallel use GPUs%s!" % (model_name, gpu_ids))
else:
@@ -392,9 +371,20 @@ def configure(self):
elif isinstance(self.model, Module):
config_dic["model_name"] = str(self.model.__class__.__name__)
else:
config_dic["model_name"] = 'None'
raise TypeError("Type of `self.model` is wrong!")
config_dic["init_method"] = str(self.init_name)
config_dic["gpus"] = len(self.gpu_ids)
config_dic["total_params"] = self.num_params
config_dic["structure"] = str(self.model)
return config_dic


if __name__ == '__main__':
from torch.nn import Sequential, Conv2d
mode = Sequential(Conv2d(10, 1, 3, 1, 0))
net = Model(mode, [], "kaiming", show_structure=False)
if torch.cuda.is_available():
net = Model(mode, [0], "kaiming", show_structure=False)
if torch.cuda.device_count() > 1:
net = Model(mode, [0, 1], "kaiming", show_structure=False)
if torch.cuda.device_count() > 3:
net = Model(mode, [2, 3], "kaiming", show_structure=False)
26 changes: 13 additions & 13 deletions jdit/optimizer.py
@@ -138,16 +138,16 @@ def configure(self):
return config_dic


if __name__ == '__main__':
import torch

param = torch.nn.Linear(10, 1).parameters()
opt = Optimizer(param, lr=0.999, weight_decay=0.03, momentum=0.5, betas=(0.1, 0.4), opt_name="RMSprop")

print(opt.configure['lr'])
opt.do_lr_decay()
print(opt.configure['lr'])
opt.do_lr_decay(reset_lr=0.232, reset_lr_decay=0.3)
print(opt.configure['lr_decay'])
opt.do_lr_decay(reset_lr=0.2)
print(opt.configure)
# if __name__ == '__main__':
# import torch
#
# param = torch.nn.Linear(10, 1).parameters()
# opt = Optimizer(param, lr=0.999, weight_decay=0.03, momentum=0.5, betas=(0.1, 0.4), opt_name="RMSprop")
#
# print(opt.configure['lr'])
# opt.do_lr_decay()
# print(opt.configure['lr'])
# opt.do_lr_decay(reset_lr=0.232, reset_lr_decay=0.3)
# print(opt.configure['lr_decay'])
# opt.do_lr_decay(reset_lr=0.2)
# print(opt.configure)
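For reference, a live version of the commented demo (same values as above; ``do_lr_decay`` is assumed to multiply ``lr`` by ``lr_decay`` unless reset values are passed):

import torch
from jdit.optimizer import Optimizer

param = torch.nn.Linear(10, 1).parameters()
opt = Optimizer(param, lr=0.999, weight_decay=0.03, momentum=0.5, betas=(0.1, 0.4), opt_name="RMSprop")
print(opt.configure['lr'])   # 0.999
opt.do_lr_decay()            # lr <- lr * lr_decay
print(opt.configure['lr'])
opt.do_lr_decay(reset_lr=0.232, reset_lr_decay=0.3)
print(opt.configure['lr_decay'])  # 0.3
opt.do_lr_decay(reset_lr=0.2)
print(opt.configure)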

