Skip to content

Commit

Permalink
docs of model
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Oct 29, 2018
1 parent b35f6d5 commit b273be6
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 65 deletions.
6 changes: 3 additions & 3 deletions jdit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .model import Model
from .optimizer import Optimizer
from jdit.model import Model
from jdit.optimizer import Optimizer
# import dataset
from .trainer import *
from jdit.trainer import *
# import .metric
# from .metric import *
204 changes: 142 additions & 62 deletions jdit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,33 @@
from torch.nn import init, Conv2d, Linear, ConvTranspose2d, InstanceNorm2d, BatchNorm2d, DataParallel
from torch import save, load


class Model(object):
r"""A warapper of pytorch ``module`` .
In the simplest case, we use a raw pytorch ``module`` to assemble a ``Model`` of this class.
It can be more convenient to use some feather method, such ``checkPoint`` , ``loadModel`` and so on.
* :attr:`proto_model` is the core model in this class. It is no necessary to passing a ``module``
when you init a ``Model`` . you can build a model later by using ``Model.define(module)`` or load a model from a file.
when you init a ``Model`` . you can build a model later by using ``Model.define(module)`` or load a model from
a file.
* :attr:`gpu_ids_abs` controls the gpus which you want to use. you should use a absolute id of gpus.
* :attr:`init_method` controls the weights init method.
* At init_method="xavier", it will use ``init.xavier_normal_``, in ``pytorch.nn.init``, to init the Conv layers of model.
* At init_method="kaiming", it will use ``init.kaiming_normal_``, in ``pytorch.nn.init``, to init the Conv layers of model.
* At init_method="xavier", it will use ``init.xavier_normal_``, in ``pytorch.nn.init``, to init the Conv
layers of model.
* At init_method="kaiming", it will use ``init.kaiming_normal_``, in ``pytorch.nn.init``, to init the Conv
layers of model.
* At init_method=your_own_method, it will be used on weights, just like what ``pytorch.nn.init`` method does.
* :attr:`show_structure` controls whether to show your network structure.
.. note::
Don't try to pass a :attr:``DataParallel`` model. Only :attr:``module`` is accessable.
Don't try to pass a ``DataParallel`` model. Only ``module`` is accessible.
It will change to ``DataParallel`` class automatically by passing a muti-gpus ids, like ``[0, 1]``.
.. note::
Expand All @@ -40,19 +45,24 @@ class Model(object):
show_structure (bool): Is the structure shown. Default: ``False``
Attributes:
num_params (int): the totals amount of weights in this model.
num_params (int): The totals amount of weights in this model.
gpu_ids (list or tuple): which device is this model on.
gpu_ids (list or tuple): Which device is this model on.
Examples::
>>> from torch.nn import Sequential, Conv3d
>>> # using a square kernels and equal stride
>>> module = Sequential(Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)))
>>> # using cpu to init a Model by module.
>>> net = Model(module, [], show_structure=False)
Sequential Total number of parameters: 15873
Sequential model use CPU!
apply kaiming weight init!
>>> input = torch.randn(20, 16, 10, 50, 100)
>>> output = net(input)
"""

def __init__(self, proto_model=None, gpu_ids_abs=(), init_method="kaiming", show_structure=False):
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpu_ids_abs])
self.gpu_ids = [i for i in range(len(gpu_ids_abs))]
Expand All @@ -71,24 +81,27 @@ def __getattr__(self, item):
return getattr(self.model, item)

def define(self, proto_model, gpu_ids, init_method, show_structure):
"""define network, according to CPU, GPU and multi-GPUs.
"""Define and wrap a pytorch module, according to CPU, GPU and multi-GPUs.
* Print the module's info.
* Move this module to specify device.
* Apply weight init method.
:param proto_model: Network, type of module.
:param gpu_ids: Using GPUs' id, type of tuple. If not use GPU, pass '()'.
:param init_method: init weights method("kaiming") or `False` don't use any init.
:return: Network
:param proto_model: Network, type of ``module``.
:param gpu_ids: Be used GPUs' id, type of ``tuple`` or ``list``. If not use GPU, pass ``()``.
:param init_method: init weights method("kaiming") or ``False`` don't use any init.
"""
self.num_params = self.print_network(proto_model, show_structure)
self.model = self._set_device(proto_model, gpu_ids)
init_name = self._apply_weight_init(init_method, proto_model)
print("apply %s weight init!" % init_name)

def print_network(self, net, show_structure=False):
"""print total number of parameters and structure of network
"""Print total number of parameters and structure of network
:param net: network
:param show_structure: if show network's structure. default: false
:return:
:param net: Pytorch module
:param show_structure: If show network's structure. default: ``False``
:return: Total number of parameters
"""
model_name = net.__class__.__name__
num_params = self.countParams(net)
Expand All @@ -98,61 +111,112 @@ def print_network(self, net, show_structure=False):
print(num_params_log)
return num_params

def loadModel(self, model_or_path, weights_or_path=None, gpu_ids=(), is_eval=True):
"""to assemble a model and weights from paths or passing parameters.
This method deal well with different devices model loading.
You don' need to care about which devices your model have saved.
loadModel(m_path, w_path) #both using a file from paths.
loadModel(model, w_path) #you have had the model. Only get weight from path.
loadModel(model, weight) #you get model and weight. So, you don't need to do any file reading.
loadModel(m_path, None)/loadModel(model, None) #you only load the model without weights.
:param model_or_path: pytorch model or model file path.
:param weights_or_path: pytorch weights or weights file path.
:param gpu_ids:using gpus. default:() using cpu
:param is_eval: if using only for evaluating. model.eval()
:return: model
def loadModel(self, model_or_path, weights_or_path=None, gpu_ids=()):
r"""Assemble a model and weights from paths or passing parameters.
You can load a model from a file, passing parameters or both.
:param model_or_path: Pytorch model or model file path.
:param weights_or_path: Pytorch weights or weights file path.
:param gpu_ids: If using gpus. default:``()``
:return: ``module``
Example::
>>> from torchvision.models.resnet import resnet18
>>> resnet = Model(resnet18())
ResNet Total number of parameters: 11689512
ResNet model use CPU!
apply kaiming weight init!
>>> resnet.saveModel("model.pth", "weights.pth", True)
move to cpu...
>>> resnet_load = Model()
>>> # only load module structure
>>> resnet_load.loadModel("model.pth", None)
ResNet model use CPU!
>>> # only load weights
>>> resnet_load.loadModel(None, "weights.pth")
ResNet model use CPU!
>>> # load both
>>> resnet_load.loadModel("model.pth", "weights.pth")
ResNet model use CPU!
"""
is_path = isinstance(model_or_path, str) and os.path.exists(model_or_path)
model = model_or_path
if is_path:

assert self.model or model_or_path, "You must use `self.define()` or passing a model to load."

model_is_path = isinstance(model_or_path, str)
weights_is_path = isinstance(weights_or_path, str)

if model_is_path:
model = load(model_or_path, map_location=lambda storage, loc: storage)
else:
if model_or_path:
model = model_or_path
else:
model = self.model

is_path = isinstance(weights_or_path, str) and os.path.exists(weights_or_path)
weights = weights_or_path
if is_path:
if weights_is_path:
weights = load(weights_or_path, map_location=lambda storage, loc: storage)
else:
weights = weights_or_path

if hasattr(model, "module"):
print("deal with `dataparallel` and extract `module`...")
model = model.module
if weights is not None:
# fix params' key
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in weights.items():
# name = k[7:] # remove `module.`
name = k.replace("module.", "", 1) # remove `module.`
new_state_dict[name] = v
weights = new_state_dict

model, _ = self._extract_module(model, extract_weights=False)
weights = self._fix_weights(weights)
if weights is not None:
model.load_state_dict(weights)

model = self._set_device(model, gpu_ids)
self.model = model
# if torch.cuda.is_available() and (len(gpu_ids) == 1):
# print("convert to GPU %s" % str(gpu_ids))
# model = model.cuda()
# elif torch.cuda.is_available() and (len(gpu_ids) > 1):
# print("convert to GPUs %s" % str(gpu_ids))
# model = DataParallel(model, gpu_ids).cuda()
#
#
if is_eval:
return model.eval()
self.model = self._set_device(model, gpu_ids)

def saveModel(self, model_path=None, weights_path=None, to_cpu=False):
r"""Save a model and weights to files.
You can save a model, weights or both to file.
.. note::
This method deal well with different devices on model saving.
You don' need to care about which devices your model have saved.
:param model_or_path: Pytorch model or model file path.
:param weights_or_path: Pytorch weights or weights file path.
:param to_cpu: If this is true, it will keep the location of module
without any moving operation. Otherwise, it will move to cpu, especially in ``DataParallel``.
default:``False``
Example::
>>> from torchvision.models.resnet import resnet18
>>> model = Model(resnet18())
ResNet Total number of parameters: 11689512
ResNet model use CPU!
apply kaiming weight init!
>>> model.saveModel("model.pth", "weights.pth")
>>> #you have had the model. Only get weights from path.
>>> model.loadModel(None, "weights.pth")
ResNet model use CPU!
>>> model.loadModel("model.pth", None)
ResNet model use CPU!
"""
assert self.model is not None, "Model.model is `None`. You need to `define` a model before you save it."
if to_cpu:
import copy
model = copy.deepcopy(self.model).cpu()
weights = model.state_dict()
print("move to cpu...")
if hasattr(self.model, "module"):
print("extract `module` from `DataParallel`...")
model, weights = self._extract_module(self.model.cpu(), True)
else:
return model
model = self.model
weights = self.model.state_dict()

if weights_path:
save(weights, weights_path)

if model_path:
save(model, model_path)

def loadPoint(self, model_name, epoch, logdir="log"):
"""load model and weights from a certain checkpoint.
Expand All @@ -177,7 +241,7 @@ def checkPoint(self, model_name, epoch, logdir="log"):
def countParams(self, proto_model):
"""count the total parameters of model.
:param proto_model: pytorch model
:param proto_model: pytorch module
:return: number of parameters
"""
num_params = 0
Expand Down Expand Up @@ -225,6 +289,23 @@ def _weight_init(self, m):
else:
pass

def _extract_module(self, data_parallel_model, extract_weights=True):
print("from `DataParallel` extract `module`...")
model = data_parallel_model.module
weights = self.model.state_dict()
if extract_weights:
weights = self._fix_weights(weights)
return model, weights

def _fix_weights(self, weights):
# fix params' key
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in weights.items():
name = k.replace("module.", "", 1) # remove `module.`
new_state_dict[name] = v
return new_state_dict

def _set_device(self, proto_model, gpu_ids):
gpu_available = torch.cuda.is_available()
model_name = proto_model.__class__.__name__
Expand All @@ -249,4 +330,3 @@ def configure(self):
for item in self.model._modules.items():
config_dic["structure"].append(str(item))
return config_dic

0 comments on commit b273be6

Please sign in to comment.