Skip to content

Commit

Permalink
setup 0.0.10v.
Browse files Browse the repository at this point in the history
fix checkpoint
  • Loading branch information
dingguanglei committed Mar 22, 2019
1 parent 3b9eacd commit 897764a
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion jdit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def check_point(self, model_name: str, epoch: int, logdir="log"):
def is_checkpoint(self, model_name: str, epoch: int, logdir="log"):
if not self.check_point_pos:
return False
if isinstance(epoch, int):
if isinstance(self.check_point_pos, int):
is_check_point = epoch > 0 and (epoch % self.check_point_pos) == 0
else:
is_check_point = epoch in self.check_point_pos
Expand Down
1 change: 1 addition & 0 deletions jdit/trainer/instances/fashingClassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def start_fashingClassTrainer(gpus=(), nepochs=10, run_type="train"):

print('===> Build dataset')
mnist = FashionMNIST(batch_size=batch_size)
# mnist.dataset_train = mnist.dataset_test
torch.backends.cudnn.benchmark = True
print('===> Building model')
net = Model(SimpleModel(depth=depth), gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=1)
Expand Down
7 changes: 4 additions & 3 deletions jdit/trainer/super.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ def __setattr__(self, key, value):
if is_change:
super(SupTrainer, self).__getattribute__("_record_configs")("optimizer")
elif key == "current_epoch" and value != 0:
is_change = super(SupTrainer, self).__getattribute__("_change_lr")("epoch", value)
if is_change:
is_change_lr = super(SupTrainer, self).__getattribute__("_change_lr")("epoch", value)
if is_change_lr:
super(SupTrainer, self).__getattribute__("_record_configs")("optimizer")
super(SupTrainer, self).__getattribute__("_check_point")()
super(SupTrainer, self).__getattribute__("_check_point")()

super(SupTrainer, self).__getattribute__("_record_configs")("performance")
elif isinstance(value, Model):
super(SupTrainer, self).__getattribute__("_models").update({key: value})
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="jdit", # pypi中的名称,pip或者easy_install安装时使用的名称,或生成egg文件的名称
version="0.0.9",
version="0.0.10",
author="Guanglei Ding",
author_email="dingguanglei.bupt@qq.com",
maintainer='Guanglei Ding',
Expand Down

0 comments on commit 897764a

Please sign in to comment.