config_record fixed
dingguanglei committed Dec 19, 2018
1 parent 4d97d3a commit c74a7d0
Showing 7 changed files with 125 additions and 111 deletions.
24 changes: 12 additions & 12 deletions jdit/dataset.py
@@ -187,18 +187,18 @@ def _get_samples(dataset, sample_dataset_size=0.1):
@property
def configure(self):
configs = dict()
configs["dataset_name"] = [str(self.dataset_train.__class__.__name__)]
configs["batch_size"] = [str(self.batch_size)]
configs["shuffle"] = [str(self.shuffle)]
configs["root"] = [str(self.root)]
configs["num_workers"] = [str(self.num_workers)]
configs["sample_dataset_size"] = [str(self.sample_dataset_size)]
configs["nsteps_train"] = [str(self.nsteps_train)]
configs["nsteps_valid"] = [str(self.nsteps_valid)]
configs["nsteps_test"] = [str(self.nsteps_test)]
configs["dataset_train"] = [str(self.dataset_train)]
configs["dataset_valid"] = [str(self.dataset_valid)]
configs["dataset_test"] = [str(self.dataset_test)]
configs["dataset_name"] = str(self.dataset_train.__class__.__name__)
configs["batch_size"] = str(self.batch_size)
configs["shuffle"] = str(self.shuffle)
configs["root"] = str(self.root)
configs["num_workers"] = str(self.num_workers)
configs["sample_dataset_size"] = str(self.sample_dataset_size)
configs["nsteps_train"] = str(self.nsteps_train)
configs["nsteps_valid"] = str(self.nsteps_valid)
configs["nsteps_test"] = str(self.nsteps_test)
configs["dataset_train"] = str(self.dataset_train)
configs["dataset_valid"] = str(self.dataset_valid)
configs["dataset_test"] = str(self.dataset_test)
return configs
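For reference, a standalone sketch of the record shape this hunk settles on: every value in `configure` is now a plain string, where before this commit each one was wrapped in a one-element list. The values below are invented for illustration.

configs = dict()
configs["dataset_name"] = str("FashionMNIST")
configs["batch_size"] = str(64)
configs["shuffle"] = str(True)
print(configs)  # {'dataset_name': 'FashionMNIST', 'batch_size': '64', 'shuffle': 'True'}
# Previously each entry was a one-element list such as ['64'] or ['True'].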


20 changes: 11 additions & 9 deletions jdit/model.py
@@ -6,7 +6,7 @@
from typing import Union
from collections import OrderedDict
from types import FunctionType
from typing import List
from typing import List, Optional


class _cached_property(object):
@@ -92,13 +92,14 @@ class Model(object):
>>> output = net(input_tensor)
"""

def __init__(self, proto_model: Module,
gpu_ids_abs: Union[list, tuple] = (),
init_method: Union[str, FunctionType, None] = "kaiming",
show_structure=False, check_point_pos=None, verbose=True):
show_structure=False,
check_point_pos=None, verbose=True):
assert isinstance(proto_model, Module)
self.model: Union[DataParallel, Module] = None
self.model_name: str = "Model"
self.model_name = proto_model.__class__.__name__
self.weights_init = None
self.init_fc = None
self.init_name: str = None
@@ -128,8 +129,7 @@ def define(self, proto_model: Module, gpu_ids_abs: Union[list, tuple], init_meth
:param init_method: init weights method("kaiming") or ``False`` don't use any init.
:param show_structure: If print structure of model.
"""
assert isinstance(proto_model, Module)
self.num_params, self.model_name = self.print_network(proto_model, show_structure)
self.num_params = self.print_network(proto_model, show_structure)
self.model = self._set_device(proto_model, gpu_ids_abs)
self.init_name = self._apply_weight_init(init_method, self.model)
self._print("apply %s weight init!" % self.init_name)
@@ -141,13 +141,12 @@ def print_network(self, proto_model: Module, show_structure=False):
:param show_structure: If show network's structure. default: ``False``
:return: Total number of parameters
"""
model_name = proto_model.__class__.__name__
num_params = self.count_params(proto_model)
if show_structure:
self._print(str(proto_model))
num_params_log = '%s Total number of parameters: %d' % (model_name, num_params)
num_params_log = '%s Total number of parameters: %d' % (self.model_name, num_params)
self._print(num_params_log)
return num_params, model_name
return num_params

def load_weights(self, weights: Union[OrderedDict, dict, str], strict=True):
"""Assemble a model and weights from paths or passing parameters.
@@ -252,6 +251,7 @@ def check_point_epoch(self, model_name: str, epoch: int, logdir="log"):
is_check_point = epoch in self.check_point_pos
if is_check_point:
self.check_point(model_name, epoch, logdir)
return is_check_point

@staticmethod
def count_params(proto_model: Module):
@@ -368,6 +368,7 @@ def configure(self):
return config_dic



if __name__ == '__main__':
from torch.nn import Sequential

@@ -379,3 +380,4 @@ def configure(self):
net = Model(mode, [0, 1], "kaiming", show_structure=False)
if torch.cuda.device_count() > 2:
net = Model(mode, [2, 3], "kaiming", show_structure=False)
net1 = Model(mode, [], "kaiming", show_structure=False)
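A minimal sketch of what these model.py hunks change, assuming only a trivial torch module; none of it is part of the commit itself. The model name is now taken from the prototype's class in `__init__`, and `print_network` returns just the parameter count rather than a (count, name) pair.

import torch
from torch.nn import Sequential, Linear
from jdit.model import Model   # the module edited above

proto = Sequential(Linear(10, 1))
net = Model(proto, gpu_ids_abs=(), init_method="kaiming", show_structure=False)
print(net.model_name)                  # "Sequential", from proto_model.__class__.__name__
num_params = net.print_network(proto)  # now returns an int only
print(num_params)                      # 11 (10 weights + 1 bias)
# check_point_epoch(...) likewise now returns a bool reporting whether a checkpoint fired.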
18 changes: 8 additions & 10 deletions jdit/optimizer.py
@@ -103,7 +103,7 @@ class Optimizer(object):
"""

def __init__(self, params, optimizer: "[Adam,RMSprop,SGD]", lr_decay=0.92, decay_position: Union[int, list] = None,
decay_type: "['epoch','step']" = "epoch", lr_minimum=1e-5, **kwargs):
decay_type: "['epoch','step']" = "epoch", lr_minimum=1e-5,**kwargs):
assert isinstance(decay_position,
(int, tuple, list)), "`decay_position` should be int or tuple/list, get %s instead" % type(
decay_position)
@@ -118,9 +118,9 @@ def __init__(self, params, optimizer: "[Adam,RMSprop,SGD]", lr_decay=0.92, decay
for param_group in self.opt.param_groups:
self.lr = param_group["lr"]

def __str__(self):
def __repr__(self):

string = "(" + str(self.opt)+ "\n %s:%s\n" % ("lr_decay", self.lr_decay)
string = "(" + str(self.opt) + "\n %s:%s\n" % ("lr_decay", self.lr_decay)
string = string + " %s:%s\n" % ("lr_minimum", self.lr_minimum)
string = string + " %s:%s\n" % ("decay_position", self.decay_position)
string = string + " %s:%s\n)" % ("decay_type", self.decay_type) + ")"
@@ -131,10 +131,9 @@ def __getattr__(self, name):
return getattr(self.opt, name)

def use_decay(self, position: Optional[int]) -> bool:
"""Check if this is a position of applying for learning rate decay.
"""Judge if use learning decay on this position.
:param step: The steps of back propagation
:param epoch: The epoch of back propagation
:param position: (int) A position of step or epoch.
:return: bool
"""
assert isinstance(position, int)
@@ -144,9 +143,9 @@ def use_decay(self, position: Optional[int]) -> bool:
is_change_lr = position in self.decay_position
return is_change_lr

def update_state(self, position: int):
if self.use_decay(position):
self.do_lr_decay()
# def update_state(self, position: int):
# if self.use_decay(position):
# self.do_lr_decay()

def do_lr_decay(self, reset_lr_decay: float = None, reset_lr: float = None):
"""Do learning rate decay, or reset them.
@@ -172,7 +171,6 @@ def do_lr_decay(self, reset_lr_decay: float = None, reset_lr: float = None):
self.lr = reset_lr
for param_group in self.opt.param_groups:
param_group["lr"] = self.lr
print(self.lr)

@property
def configure(self):
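A rough usage sketch of the decay API touched in these optimizer.py hunks. It assumes the optimizer is named by a string (as the "[Adam,RMSprop,SGD]" annotation suggests) and that extra keyword arguments such as `lr` are forwarded to the underlying torch optimizer; neither assumption is visible in the hunks themselves.

import torch
from jdit.optimizer import Optimizer   # the module edited above

params = [torch.nn.Parameter(torch.zeros(3))]
opt = Optimizer(params, "Adam", lr_decay=0.92, decay_position=[10, 20],
                decay_type="epoch", lr=1e-3)
for epoch in range(1, 31):
    if opt.use_decay(epoch):   # True only when epoch is in decay_position
        opt.do_lr_decay()      # lr <- lr * lr_decay; reset_lr/reset_lr_decay override it
print(opt.lr)                  # roughly 1e-3 * 0.92 ** 2 after decaying at epochs 10 and 20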
3 changes: 2 additions & 1 deletion jdit/trainer/gan/sup_gan.py
@@ -47,6 +47,7 @@ def get_data_from_batch(self, batch_data: list, device: torch.device):
return input_cpu.to(self.device), ground_truth_cpu.to(self.device)
:param batch_data: one batch data load from ``DataLoader``
:param device: A device variable. ``torch.device``
:return: input Tensor, ground_truth Tensor
"""
input_tensor, ground_truth_tensor = batch_data[0], batch_data[1]
@@ -204,7 +205,7 @@ def valid_epoch(self):
self.netG.eval()
self.netD.eval()
for iteration, batch in enumerate(self.datasets.loader_valid, 1):
self.input, self.ground_truth = self.get_data_from_batch(batch,self.device)
self.input, self.ground_truth = self.get_data_from_batch(batch, self.device)
with torch.no_grad():
self.fake = self.netG(self.input)
dic: dict = self.compute_valid()
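The docstring addition above spells out the contract that `valid_epoch` relies on: the trainer hands each batch plus `self.device` to `get_data_from_batch` and expects back an (input, ground_truth) pair already moved to that device. A schematic override mirroring the docstring's own example; the parent class name is assumed, not shown in this hunk.

import torch
from jdit.trainer.gan.sup_gan import SupGanTrainer   # assumed class name

class MyGanTrainer(SupGanTrainer):
    def get_data_from_batch(self, batch_data, device):
        # unpack the batch and move both tensors to the device the trainer passes in
        input_cpu, ground_truth_cpu = batch_data[0], batch_data[1]
        return input_cpu.to(device), ground_truth_cpu.to(device)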
12 changes: 6 additions & 6 deletions jdit/trainer/instances/cifarPix2pixGan.py
@@ -80,10 +80,10 @@ def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, dataset
super(CifarPix2pixGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD,
datasets)

def get_data_from_batch(self, batch_data):
def get_data_from_batch(self, batch_data, device):
ground_truth_cpu, label = batch_data[0], batch_data[1]
input_cpu = ground_truth_cpu[:, 0, :, :].unsqueeze(1) # only use one channel [?,3,32,32] =>[?,1,32,32]
return input_cpu.to(self.device), ground_truth_cpu.to(self.device)
return input_cpu, ground_truth_cpu

def compute_d_loss(self):
d_fake = self.netD(self.fake.detach())
@@ -136,13 +136,13 @@ def start_cifarPix2pixGanTrainer(gpus=(), nepochs=200, lr=1e-3, depth_G=32, dept
G_net = Generator(input_nc=1, output_nc=image_channel, depth=depth_G)
G = Model(G_net, gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=50)
print('===> Building optimizer')
opt_D = Optimizer(D.parameters(),**D_hprams)
opt_G = Optimizer(G.parameters(),**G_hprams)
opt_D = Optimizer(D.parameters(), **D_hprams)
opt_G = Optimizer(G.parameters(), **G_hprams)
print('===> Training')
Trainer = CifarPix2pixGanTrainer("log/cifar_p2p", nepochs, gpus, G, D, opt_G, opt_D, cifar10)
if run_type=="train":
if run_type == "train":
Trainer.train()
elif run_type=="debug":
elif run_type == "debug":
Trainer.debug()


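A quick standalone check of the single-channel slice in the override above (tensor values invented); it confirms the shape comment [?,3,32,32] => [?,1,32,32].

import torch

gt = torch.randn(8, 3, 32, 32)        # a CIFAR-10-sized batch
inp = gt[:, 0, :, :].unsqueeze(1)     # keep only the first channel
print(inp.shape)                      # torch.Size([8, 1, 32, 32])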
4 changes: 2 additions & 2 deletions jdit/trainer/instances/fashingClassification.py
@@ -59,12 +59,12 @@ def compute_valid(self):
return var_dic


def start_fashingClassTrainer(gpus=(), nepochs=100, run_type="train"):
def start_fashingClassTrainer(gpus=(), nepochs=100, run_type="debug"):
"""" An example of fashing-mnist classification
"""
num_class = 10
depth = 32,
depth = 32
gpus = gpus
batch_size = 64
nepochs = nepochs
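The `depth = 32,` fix above is worth spelling out: a trailing comma in Python builds a one-element tuple, so the depth would have been `(32,)` where an int is expected. A two-line check:

depth = 32,
print(type(depth), depth)   # <class 'tuple'> (32,)
depth = 32
print(type(depth), depth)   # <class 'int'> 32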
