Skip to content

Commit

Permalink
md fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Dec 9, 2018
1 parent bed0217 commit b7da5f7
Show file tree
Hide file tree
Showing 12 changed files with 281 additions and 388 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
[![](http://img.shields.io/travis/dingguanglei/jdit.svg)](https://github.com/dingguanglei/jdit)
[![Documentation Status](https://readthedocs.org/projects/jdit/badge/?version=latest)](https://jdit.readthedocs.io/en/latest/?badge=latest)
[![codebeat badge](https://codebeat.co/badges/f8c6cfa5-5e6b-499c-b318-2656bc91cab0)](https://codebeat.co/projects/github-com-dingguanglei-jdit-master)
![Packagist](https://img.shields.io/packagist/l/doctrine/orm.svg)
![Packagist](https://img.shields.io/hexpm/l/plug.svg)

**Jdit** is a research processing oriented framework based on pytorch.
Only care about your ideas. You don't need to build a long boring code
Expand Down
1 change: 1 addition & 0 deletions jdit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
import jdit.assessment
import jdit.parallel


6 changes: 4 additions & 2 deletions jdit/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Union
from abc import ABCMeta, abstractmethod


class DataLoadersFactory(metaclass=ABCMeta):
"""This is a super class of dataloader.
Expand Down Expand Up @@ -161,13 +162,14 @@ def samples_valid(self):
def samples_test(self):
return self._get_samples(self.dataset_train, self.sample_dataset_size)

def _get_samples(self, dataset, sample_dataset_size=0.1):
@staticmethod
def _get_samples(dataset, sample_dataset_size=0.1):
import math
assert len(dataset) > 10, "Dataset is (%d) to small" % len(dataset)
size_is_prop = isinstance(sample_dataset_size, float)
size_is_amount = isinstance(sample_dataset_size, int)
if size_is_prop:
assert sample_dataset_size <= 1 and sample_dataset_size > 0, \
assert 0 < sample_dataset_size <= 1, \
"sample_dataset_size proportion should between 0. and 1."
subdata_size = math.floor(sample_dataset_size * len(dataset))
elif size_is_amount:
Expand Down
106 changes: 43 additions & 63 deletions jdit/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# coding=utf-8
import torch, os
import torch
import os
from torch.nn import init, Conv2d, Linear, ConvTranspose2d, InstanceNorm2d, BatchNorm2d, DataParallel, Module
from torch import save, load
from typing import Union
Expand Down Expand Up @@ -74,7 +75,7 @@ class Model(object):
Attributes:
num_params (int): The total number of weights in this model.
gpu_ids (list or tuple): Which device is this model on.
gpu_ids_abs (list or tuple): Which device is this model on.
Examples::
Expand All @@ -86,8 +87,8 @@ class Model(object):
Sequential Total number of parameters: 15873
Sequential model use CPU!
apply kaiming weight init!
>>> input = torch.randn(20, 16, 10, 50, 100)
>>> output = net(input)
>>> input_tensor = torch.randn(20, 16, 10, 50, 100)
>>> output = net(input_tensor)
"""

Expand Down Expand Up @@ -122,6 +123,7 @@ def define(self, proto_model: Module, gpu_ids_abs: Union[list, tuple], init_meth
:param proto_model: Network, type of ``module``.
:param gpu_ids_abs: IDs of the GPUs to use, type of ``tuple`` or ``list``. To run without a GPU, pass ``()``.
:param init_method: Weight init method (e.g. "kaiming"), or ``False`` to skip weight initialization.
:param show_structure: Whether to print the structure of the model.
"""
assert isinstance(proto_model, Module)
self.num_params, self.model_name = self.print_network(proto_model, show_structure)
Expand Down Expand Up @@ -149,33 +151,22 @@ def load_weights(self, weights: Union[OrderedDict, dict, str], strict=True):
You can load a model from a file, passing parameters or both.
:param model_or_path: Pytorch model or model file path.
:param weights_or_path: Pytorch weights or weights file path.
:param gpu_ids: If using gpus. default:``()``
:param weights: Pytorch weights or weights file path.
:param strict: The same function in pytorch ``model.load_state_dict(weights,strict = strict)`` .
default:``True``
:return: ``module``
Example::
>>> from torchvision.models.resnet import resnet18
>>> resnet = Model(resnet18())
>>> model = Model(resnet18())
ResNet Total number of parameters: 11689512
ResNet model use CPU!
apply kaiming weight init!
>>> resnet.save_weights("model.pth", "weights.pth", True)
move to cpu...
>>> resnet_load = Model()
>>> # only load module structure
>>> resnet_load.load_weights("model.pth", None)
ResNet model use CPU!
>>> # only load weights
>>> resnet_load.load_weights(None, "weights.pth")
ResNet model use CPU!
>>> # load both
>>> resnet_load.load_weights("model.pth", "weights.pth")
ResNet model use CPU!
>>> model.save_weights("model.pth",)
try to remove 'module.' in keys of weights dict...
>>> model.load_weights("model.pth", True)
Try to remove `moudle.` to keys of weights dict
"""
if isinstance(weights, str):
Expand All @@ -200,25 +191,22 @@ def save_weights(self, weights_path: str, fix_weights=True):
This method deals well with different devices when saving a model.
You don't need to care about which device your model was saved on.
:param model_or_path: Pytorch model or model file path.
:param weights_or_path: Pytorch weights or weights file path.
:param to_cpu: If this is true, it will keep the location of module.
:param weights_path: Pytorch weights or weights file path.
:param fix_weights: If this is true, it will remove the '.module' in keys, when you save a ``DataParallel``.
without any moving operation. Otherwise, it will move to cpu, especially in ``DataParallel``.
default:``False``
Example::
>>> from torchvision.models.resnet import resnet18
>>> model = Model(resnet18())
ResNet Total number of parameters: 11689512
ResNet model use CPU!
>>> from torch.nn import Linear
>>> model = Model(Linear(10,1))
Linear Total number of parameters: 11
Linear model use CPU!
apply kaiming weight init!
>>> model.save_weights("model.pth", "weights.pth")
>>> #you have had the model. Only get weights from path.
>>> model.load_weights(None, "weights.pth")
ResNet model use CPU!
>>> model.load_weights("model.pth", None)
ResNet model use CPU!
>>> model.save_weights("weights.pth")
try to remove 'module.' in keys of weights dict...
>>> model.load_weights("weights.pth")
Try to remove `moudle.` to keys of weights dict
"""
if fix_weights:
Expand All @@ -236,29 +224,26 @@ def load_point(self, model_name: str, epoch: int, logdir="log"):
This method cooperates with the method `self.check_point()`.
"""
if logdir.endswith("checkpoint"):
dir = logdir
else:
dir = os.path.join(logdir, "checkpoint")
if not logdir.endswith("checkpoint"):
logdir = os.path.join(logdir, "checkpoint")

model_weights_path = os.path.join(dir, "Weights_%s_%d.pth" % (model_name, epoch))
model_weights_path = os.path.join(logdir, "Weights_%s_%d.pth" % (model_name, epoch))

self.load_weights(model_weights_path, True)

def check_point(self, model_name: str, epoch: int, logdir="log"):
if logdir.endswith("checkpoint"):
dir = logdir
else:
dir = os.path.join(logdir, "checkpoint")
if not logdir.endswith("checkpoint"):
logdir = os.path.join(logdir, "checkpoint")

if not os.path.exists(dir):
os.makedirs(dir)
if not os.path.exists(logdir):
os.makedirs(logdir)

model_weights_path = os.path.join(dir, "Weights_%s_%d.pth" % (model_name, epoch))
model_weights_path = os.path.join(logdir, "Weights_%s_%d.pth" % (model_name, epoch))
weights = self._fix_weights(self.model.state_dict(), "remove", False) # try to remove '.module' in keys.
save(weights, model_weights_path)

def count_params(self, proto_model: Module):
@staticmethod
def count_params(proto_model: Module):
"""count the total parameters of model.
:param proto_model: pytorch module
Expand Down Expand Up @@ -309,20 +294,12 @@ def _weight_init(self, m):
else:
pass

# def _extract_module(self, data_parallel_model: DataParallel, extract_weights=True):
# self._print("from `DataParallel` extract `module`...")
# model: Module = data_parallel_model.module
# weights = self.model.state_dict()
# if extract_weights:
# weights = self._fix_weights(weights)
# return model, weights

def _fix_weights(self, weights: Union[dict, OrderedDict], fix_type: str = "remove", is_strict=True):
@staticmethod
def _fix_weights(weights: Union[dict, OrderedDict], fix_type: str = "remove", is_strict=True):
# fix params' key
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in weights.items():
k: str
if fix_type == "remove":
if is_strict:
assert k.startswith(
Expand All @@ -333,6 +310,8 @@ def _fix_weights(self, weights: Union[dict, OrderedDict], fix_type: str = "remov
assert not k.startswith("module."), "The key of weights dict is %s. Can not add 'module.'" % k
if not k.startswith("module."):
name = "module.".join(k) # add `module.`
else:
name = k
else:
raise TypeError("`fix_type` should be 'remove' or 'add'.")
new_state_dict[name] = v
Expand All @@ -345,23 +324,23 @@ def _set_device(self, proto_model: Module, gpu_ids_abs: list) -> Union[Module, D
gpu_ids = [i for i in range(len(gpu_ids_abs))]
gpu_available = torch.cuda.is_available()
model_name = proto_model.__class__.__name__
if (len(gpu_ids) == 1):
if len(gpu_ids) == 1:
assert gpu_available, "No gpu available! torch.cuda.is_available() is False. CUDA_VISIBLE_DEVICES=%s" % \
os.environ["CUDA_VISIBLE_DEVICES"]
proto_model = proto_model.cuda(gpu_ids[0])
self._print("%s model use GPU(%d)!" % (model_name, gpu_ids[0]))
elif (len(gpu_ids) > 1):
elif len(gpu_ids) > 1:
assert gpu_available, "No gpu available! torch.cuda.is_available() is False. CUDA_VISIBLE_DEVICES=%s" % \
os.environ["CUDA_VISIBLE_DEVICES"]
proto_model = DataParallel(proto_model.cuda(), gpu_ids)
self._print("%s dataParallel use GPUs%s!" % (model_name, gpu_ids))
else:
self._print("%s model use CPU!" % (model_name))
self._print("%s model use CPU!" % model_name)
return proto_model

def _print(self, str: str):
def _print(self, str_msg: str):
if self.verbose:
print(str)
print(str_msg)

@property
def configure(self):
Expand All @@ -380,11 +359,12 @@ def configure(self):

if __name__ == '__main__':
from torch.nn import Sequential

mode = Sequential(Conv2d(10, 1, 3, 1, 0))
net = Model(mode, [], "kaiming", show_structure=False)
if torch.cuda.is_available():
net = Model(mode, [0], "kaiming", show_structure=False)
if torch.cuda.device_count() > 1:
net = Model(mode, [0, 1], "kaiming", show_structure=False)
if torch.cuda.device_count() > 2:
net = Model(mode, [2, 3], "kaiming", show_structure=False)
net = Model(mode, [2, 3], "kaiming", show_structure=False)
6 changes: 3 additions & 3 deletions jdit/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
import torch

from torch.optim import Adam, RMSprop, SGD


Expand All @@ -26,7 +26,7 @@ class Optimizer(object):
* :attr:`learning rate reset` . Reset learning rate, it can change learning rate and decay directly.
* :attr:`minimum learning rate` . When you do a learning rate decay, it will stop,
when the learning rate is smaller than the minmimum
when the learning rate is smaller than the minimum
Args:
params (dict): parameters of model, which need to be updated.
Expand All @@ -37,7 +37,7 @@ class Optimizer(object):
weight_decay (float, optional): weight_decay in pytorch ``optimizer`` . Default: 2e-5
moemntum (float, optional): moemntum in pytorch ``moemntum`` . Default: 0
momentum (float, optional): momentum in pytorch ``momentum`` . Default: 0
betas (tuple, list, optional): betas in pytorch ``betas`` . Default: (0.9, 0.999)
Expand Down

0 comments on commit b7da5f7

Please sign in to comment.