setup 0.0.9v.
add lr reset
dingguanglei committed Feb 22, 2019
1 parent f87bd5b commit 3b9eacd
Showing 9 changed files with 78 additions and 58 deletions.
90 changes: 53 additions & 37 deletions jdit/optimizer.py
@@ -1,5 +1,5 @@
# coding=utf-8
from typing import Optional, Union
from typing import Optional, Union, Dict
import torch.optim as optim
from inspect import signature

@@ -26,15 +26,12 @@ class Optimizer(object):
* :attr:`learning rate reset` . Reset the learning rate; it can change the learning rate and decay directly.
* :attr:`minimum learning rate` . When you do a learning rate decay, it will stop when the learning rate is
smaller than the minimum.
:param params: parameters of the model, which need to be updated.
:param optimizer: An optimizer class in pytorch, such as ``torch.optim.Adam``.
:param lr_decay: learning rate decay. Default: 0.92.
:param decay_at_epoch: The position of applying lr decay. Default: None.
:param decay_at_step: The step position of applying lr decay. Default: None
:param lr_minimum: minimum learning rate. Default: 1e-5.
:param kwargs: pass hyper-parameters to optimizer, such as ``lr`` , ``betas`` , ``weight_decay`` .
:return:
@@ -46,12 +43,11 @@ class Optimizer(object):
lr_decay (float, optional): learning rate decay. Default: 0.92
decay_at_epoch (int, list, optional): The position of applying lr decay.
Default: None
decay_position (int, list, optional): The decay position of lr. Default: None
decay_at_step (int, list, optional): The step position of applying lr decay. Default: None
lr_reset (Dict[position(int), lr(float)]): Reset the learning rate at a certain position. Default: None
lr_minimum (float, optional): minimum learning rate. Default: 1e-5
position_type ('epoch','step'): Position type. Default: 'epoch'
**kwargs : pass hyper-parameters to optimizer, such as ``lr`` , ``betas`` , ``weight_decay`` .
@@ -60,7 +56,8 @@ class Optimizer(object):
>>> from torch.nn import Sequential, Conv3d
>>> from torch.optim import Adam
>>> module = Sequential(Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)))
>>> opt = Optimizer(module.parameters() ,"Adam", 0.5, 10, "epoch", lr=1.0, betas=(0.9, 0.999),weight_decay=1e-5)
>>> opt = Optimizer(module.parameters() ,"Adam", 0.5, 10, {4:0.99},"epoch", lr=1.0, betas=(0.9, 0.999),
weight_decay=1e-5)
>>> print(opt)
(Adam (
Parameter Group 0
@@ -71,9 +68,9 @@ class Optimizer(object):
weight_decay: 1e-05
)
lr_decay:0.5
lr_minimum:1e-05
decay_position:10
decay_type:epoch
lr_reset:{4: 0.99}
position_type:epoch
))
>>> opt.lr
1.0
@@ -94,31 +91,37 @@ class Optimizer(object):
lr: 1
weight_decay: 1e-05
)
>>> opt.is_lrdecay(1)
>>> opt.is_decay_lr(1)
False
>>> opt.is_lrdecay(10)
>>> opt.is_decay_lr(10)
True
>>> opt.is_lrdecay(20)
>>> opt.is_decay_lr(20)
True
>>> opt.is_reset_lr(4)
0.99
>>> opt.is_reset_lr(5)
False
"""

def __init__(self, params: "parameters of model",
optimizer: "[Adam,RMSprop,SGD...]",
lr_decay=0.92,
decay_position: Union[int, list] = None,
decay_type: "['epoch','step']" = "epoch",
lr_minimum=1e-5,
lr_decay: float = 1.0,
decay_position: Union[int, tuple, list] = -1,
lr_reset: Dict[int, float] = None,
position_type: "('epoch','step')" = "epoch",
**kwargs):
if not isinstance(decay_position, (int, tuple, list)):
raise TypeError("`decay_position` should be int or tuple/list, get %s instead" % type(
decay_position))
if decay_type not in ['epoch', 'step']:
raise AttributeError("You need to set `decay_type` 'step' or 'epoch', get %s instead" % decay_type)
if position_type not in ('epoch', 'step'):
raise AttributeError("You need to set `position_type` 'step' or 'epoch', get %s instead" % position_type)
if lr_reset and any(lr <= 0 for lr in lr_reset.values()):
raise AttributeError("The learning rate in `lr_reset={position:lr,}` should be greater than 0!")
self.lr_decay = lr_decay
self.lr_minimum = lr_minimum
self.decay_position = decay_position
self.decay_type = decay_type
self.position_type = position_type
self.lr_reset = lr_reset
self.opt_name = optimizer

try:
@@ -136,33 +139,44 @@ def __init__(self, params: "parameters of model",
self.lr = param_group["lr"]

def __repr__(self):

string = "(" + str(self.opt) + "\n %s:%s\n" % ("lr_decay", self.lr_decay)
string = string + " %s:%s\n" % ("lr_minimum", self.lr_minimum)
string = string + " %s:%s\n" % ("decay_position", self.decay_position)
string = string + " %s:%s\n)" % ("decay_type", self.decay_type) + ")"
string = string + " %s:%s\n" % ("lr_reset", self.lr_reset)
string = string + " %s:%s\n)" % ("position_type", self.position_type) + ")"
return string

def __getattr__(self, name):

return getattr(self.opt, name)

def is_lrdecay(self, position: Optional[int]) -> bool:
def is_decay_lr(self, position: Optional[int]) -> bool:
"""Judge if use learning decay on this position.
:param position: (int) A position of step or epoch.
:return: bool
"""
if not self.decay_position:
return False
# assert isinstance(position, int)

if isinstance(self.decay_position, int):
is_change_lr = position > 0 and (position % self.decay_position) == 0
else:
is_change_lr = position in self.decay_position
return is_change_lr

def is_reset_lr(self, position: Optional[int]) -> bool:
"""Judge if use learning decay on this position.
:param position: (int) A position of step or epoch.
:return: bool
"""
if not self.lr_reset:
return False
if isinstance(self.lr_reset, (tuple, list)):
reset_lr = position > 0 and (position % self.decay_position) == 0
else:
reset_lr = self.lr_reset.get(position, False)
return reset_lr

def do_lr_decay(self, reset_lr_decay: float = None, reset_lr: float = None):
"""Do learning rate decay, or reset them.
@@ -179,8 +193,7 @@ def do_lr_decay(self, reset_lr_decay: float = None, reset_lr: float = None):
:return:
"""

if self.lr > self.lr_minimum:
self.lr = self.lr * self.lr_decay
self.lr = self.lr * self.lr_decay
if reset_lr_decay is not None:
self.lr_decay = reset_lr_decay
if reset_lr is not None:
@@ -197,7 +210,7 @@ def configure(self):
config_dic.update(opt_config)
config_dic["lr_decay"] = str(self.lr_decay)
config_dic["decay_position"] = str(self.decay_position)
config_dic["decay_decay_typeposition"] = self.decay_type
config_dic["decay_decay_typeposition"] = self.position_type
return config_dic


@@ -207,7 +220,7 @@ def configure(self):

adam, rmsprop, sgd = Adam, RMSprop, SGD
param = torch.nn.Linear(10, 1).parameters()
opt = Optimizer(param, "Adam", 0.1, 10, "step", lr=0.9, betas=(0.9, 0.999), weight_decay=1e-5)
opt = Optimizer(param, "Adam", 0.1, 10, {2: 0.01, 4: 0.1}, "step", lr=0.9, betas=(0.9, 0.999), weight_decay=1e-5)
print(opt)
print(opt.configure['lr'])
opt.do_lr_decay()
@@ -216,11 +229,14 @@ def configure(self):
print(opt.configure['lr_decay'])
opt.do_lr_decay(reset_lr=0.2)
print(opt.configure)
print(opt.is_lrdecay(1))
print(opt.is_lrdecay(2))
print(opt.is_lrdecay(40))
print(opt.is_lrdecay(10))
print(opt.is_decay_lr(1))
print(opt.is_decay_lr(2))
print(opt.is_decay_lr(40))
print(opt.is_decay_lr(10))
print(opt.is_reset_lr(2))
print(opt.is_reset_lr(3))
print(opt.is_reset_lr(4))
param = torch.nn.Linear(10, 1).parameters()
hpd = {"optimizer": "Adam", "lr_decay": 0.1, "decay_position": [1, 3, 5], "decay_type": "epoch",
hpd = {"optimizer": "Adam", "lr_decay": 0.1, "decay_position": [1, 3, 5], "position_type": "epoch",
"lr": 0.9, "betas": (0.9, 0.999), "weight_decay": 1e-5}
opt = Optimizer(param, **hpd)
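Taken together, the changes above give the optimizer the following usage pattern. This is a minimal sketch, assuming `Optimizer` is importable from `jdit.optimizer` (the file changed above); the concrete numbers are illustrative only.

```python
# Minimal sketch of the lr_reset behaviour added in this commit.
import torch
from jdit.optimizer import Optimizer  # assumed import path for the file above

params = torch.nn.Linear(10, 1).parameters()
# Decay lr by 0.5 every 10 epochs, but force lr back to 0.99 at epoch 4.
opt = Optimizer(params, "Adam", lr_decay=0.5, decay_position=10,
                lr_reset={4: 0.99}, position_type="epoch",
                lr=1.0, betas=(0.9, 0.999), weight_decay=1e-5)

print(opt.is_decay_lr(10))  # True  -> the trainer then calls opt.do_lr_decay()
print(opt.is_reset_lr(4))   # 0.99  -> the trainer then calls opt.do_lr_decay(reset_lr=0.99)
print(opt.is_reset_lr(5))   # False -> lr is left untouched at this position
```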
4 changes: 2 additions & 2 deletions jdit/trainer/instances/cifarPix2pixGan.py
@@ -116,12 +116,12 @@ def start_cifarPix2pixGanTrainer(gpus=(), nepochs=200, lr=1e-3, depth_G=32, dept
depth_D = depth_D

G_hprams = {"optimizer": "Adam", "lr_decay": 0.9,
"decay_position": 10, "decay_type": "epoch",
"decay_position": 10, "position_type": "epoch",
"lr": lr, "weight_decay": 2e-5,
"betas": (0.9, 0.99)
}
D_hprams = {"optimizer": "RMSprop", "lr_decay": 0.9,
"decay_position": 10, "decay_type": "epoch",
"decay_position": 10, "position_type": "epoch",
"lr": lr, "weight_decay": 2e-5,
"momentum": 0
}
5 changes: 3 additions & 2 deletions jdit/trainer/instances/fashingClassification.py
@@ -71,8 +71,9 @@ def start_fashingClassTrainer(gpus=(), nepochs=10, run_type="train"):
opt_hpm = {"optimizer": "Adam",
"lr_decay": 0.94,
"decay_position": 10,
"decay_type": "epoch",
"lr": 1e-3,
"position_type": "epoch",
"lr_reset": {2: 5e-4, 3: 1e-3},
"lr": 1e-4,
"weight_decay": 2e-5,
"betas": (0.9, 0.99)}

4 changes: 2 additions & 2 deletions jdit/trainer/instances/fashingGenerateGan.py
@@ -103,12 +103,12 @@ def start_fashingGenerateGanTrainer(gpus=(), nepochs=50, lr=1e-3, depth_G=32, de
depth_D = depth_D

G_hprams = {"optimizer": "Adam", "lr_decay": 0.94,
"decay_position": 2, "decay_type": "epoch",
"decay_position": 2, "position_type": "epoch",
"lr": lr, "weight_decay": 2e-5,
"betas": (0.9, 0.99)
}
D_hprams = {"optimizer": "RMSprop", "lr_decay": 0.94,
"decay_position": 2, "decay_type": "epoch",
"decay_position": 2, "position_type": "epoch",
"lr": lr, "weight_decay": 2e-5,
"momentum": 0
}
4 changes: 2 additions & 2 deletions jdit/trainer/instances/fashionClassParallelTrainer.py
@@ -67,15 +67,15 @@ def build_task_trainer(unfixed_params):
opt_name = "RMSprop"
lr_decay = 0.94
decay_position= 1
decay_type = "epoch"
position_type = "epoch"
weight_decay = 2e-5
momentum = 0
nepochs = 100
num_class = 10
torch.backends.cudnn.benchmark = True
mnist = FashionMNIST(root="datasets/fashion_data", batch_size=batch_size, num_workers=2)
net = Model(SimpleModel(depth), gpu_ids_abs=gpu_ids_abs, init_method="kaiming", verbose=False)
opt = Optimizer(net.parameters(), opt_name, lr_decay, decay_position, decay_type,
opt = Optimizer(net.parameters(), opt_name, lr_decay, decay_position, position_type=position_type,
lr=lr, weight_decay=weight_decay, momentum=momentum)
Trainer = FashingClassTrainer(logdir, nepochs, gpu_ids_abs, net, opt, mnist, num_class)
return Trainer
21 changes: 13 additions & 8 deletions jdit/trainer/super.py
@@ -80,10 +80,8 @@ def __setattr__(self, key, value):
super(SupTrainer, self).__setattr__(key, value)

if key == "step" and value != 0:
# is_change = self._change_lr("step", value)
is_change = super(SupTrainer, self).__getattribute__("_change_lr")("step", value)
if is_change:
# self._record_configs("optimizer")
super(SupTrainer, self).__getattribute__("_record_configs")("optimizer")
elif key == "current_epoch" and value != 0:
is_change = super(SupTrainer, self).__getattribute__("_change_lr")("epoch", value)
@@ -160,7 +158,7 @@ def debug(self):
item.check_point_pos = 1
if isinstance(item, Optimizer):
item.check_point_pos = 1
item.decay_type = "step"
item.position_type = "step"
# the tested functions
debug_fcs = [self._record_configs, self.train_epoch, self.valid_epoch,
self._change_lr, self._check_point, self.test]
@@ -265,6 +263,8 @@ def _train_iteration(self, opt: Optimizer, compute_loss_fc: FunctionType, csv_fi
loss.backward()
opt.step()
self.watcher.scalars(var_dict=var_dic, global_step=self.step, tag="Train")
opt_name = list(self._opts.keys())[list(self._opts.values()).index(opt)]
self.watcher.scalars(var_dict={"Learning rate": opt.lr}, global_step=self.step, tag=opt_name)
self.loger.write(self.step, self.current_epoch, var_dic, csv_filename, header=self.step <= 1)

def _record_configs(self, configs_names=None):
@@ -321,13 +321,18 @@ def _check_point(self):
for name, model in _models.items():
model.is_checkpoint(name, current_epoch, logdir)

def _change_lr(self, decay_type="step", position=2):
is_change = False
def _change_lr(self, position_type="step", position=2):
is_change = True
_opts = super(SupTrainer, self).__getattribute__("_opts")
for opt in _opts.values():
if opt.decay_type == decay_type and opt.is_lrdecay(position):
opt.do_lr_decay()
is_change = True
if opt.position_type == position_type:
reset_lr = opt.is_reset_lr(position)
if reset_lr:
opt.do_lr_decay(reset_lr=reset_lr)
elif opt.is_decay_lr(position):
opt.do_lr_decay()
else:
is_change = False
return is_change

def valid_epoch(self):
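The rewritten _change_lr above gives lr_reset priority over the regular decay when both match the same position. A self-contained sketch of that decision, again assuming `jdit.optimizer.Optimizer` and illustrative values:

```python
# Sketch of the reset-before-decay priority implemented in SupTrainer._change_lr.
import torch
from jdit.optimizer import Optimizer  # assumed import path

opt = Optimizer(torch.nn.Linear(4, 1).parameters(), "Adam",
                lr_decay=0.5, decay_position=2,
                lr_reset={2: 0.123}, position_type="epoch", lr=1.0)

position = 2                            # both a reset and a decay position
reset_lr = opt.is_reset_lr(position)    # 0.123 -> reset wins
if reset_lr:
    opt.do_lr_decay(reset_lr=reset_lr)  # lr becomes 0.123
elif opt.is_decay_lr(position):
    opt.do_lr_decay()                   # would have halved lr instead
print(opt.lr)                           # 0.123
```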
2 changes: 0 additions & 2 deletions requirements.txt
@@ -1,7 +1,5 @@
numpy
psutil
torch
torchvision
tqdm
mock
scipy
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name="jdit", # pypi中的名称,pip或者easy_install安装时使用的名称,或生成egg文件的名称
version="0.0.8",
version="0.0.9",
author="Guanglei Ding",
author_email="dingguanglei.bupt@qq.com",
maintainer='Guanglei Ding',
4 changes: 2 additions & 2 deletions unittests/test_optimizer.py
@@ -16,8 +16,8 @@ def test_do_lr_decay(self):
self.assertEqual(self.opt.lr, 0.6)

def test_is_lrdecay(self):
self.assert_(not self.opt.is_lrdecay(2))
self.assert_(self.opt.is_lrdecay(3))
self.assert_(not self.opt.is_decay_lr(2))
self.assert_(self.opt.is_decay_lr(3))

def test_configure(self):
self.opt.do_lr_decay(reset_lr=2, reset_lr_decay=0.3)
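The updated test only covers the renamed is_decay_lr; a hedged sketch of a companion test for the new is_reset_lr is shown below. The fixture is hypothetical and not part of the repository's test file.

```python
# Hypothetical companion test for the new is_reset_lr API (not in the repository).
import unittest
import torch
from jdit.optimizer import Optimizer  # assumed import path

class TestLrReset(unittest.TestCase):
    def setUp(self):
        params = torch.nn.Linear(10, 1).parameters()
        self.opt = Optimizer(params, "Adam", lr_decay=0.5, decay_position=3,
                             lr_reset={2: 1e-3}, position_type="epoch", lr=0.9)

    def test_is_reset_lr(self):
        self.assertEqual(self.opt.is_reset_lr(2), 1e-3)  # reset scheduled at epoch 2
        self.assertFalse(self.opt.is_reset_lr(1))        # no reset at epoch 1

if __name__ == "__main__":
    unittest.main()
```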
