Skip to content

Commit

Permalink
Merge pull request #4 from dingguanglei/development
Browse files Browse the repository at this point in the history
* d_turn remove from passing parameters.
  • Loading branch information
dingguanglei committed Oct 24, 2018
2 parents 0f64c94 + f8b211a commit 520630e
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 93 deletions.
38 changes: 15 additions & 23 deletions generate_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@

class GenerateGanTrainer(GanTrainer):
mode = "RGB"
every_epoch_checkpoint = 20 # 2
every_epoch_checkpoint = 50 # 2
every_epoch_changelr = 2 # 1

def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, dataset, latent_shape,
d_turn=1):
d_turn = 5
def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, dataset, latent_shape):
super(GenerateGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, dataset,
latent_shape=latent_shape,
d_turn=d_turn)
latent_shape=latent_shape)

self.watcher.graph(netG, (4, *self.latent_shape), self.use_gpu)

Expand Down Expand Up @@ -71,39 +69,33 @@ def valid(self):
self.netG.eval()
with torch.no_grad():
fake = self.netG(self.fixed_input).detach()
self.watcher.images([fake], ["Fixedfake"], self.current_epoch, tag="Valid",
show_imgs_num=-1,
mode=self.mode)
self.watcher.image(fake, self.current_epoch, tag="Valid/Fixed_fake",grid_size=(4,4),shuffle=False)
self.watcher.set_training_progress_images(fake, grid_size=(4,4))

var_dic = {}
var_dic["FID_SCORE"] = self.metric.evaluate_model_fid(self.netG, (256, *self.latent_shape), amount=8)
self.watcher.scalars(var_dic, self.step, tag="Valid")

# var_dic["FID_SCORE"] = self.metric.evaluate_model_fid(self.netG, (256, *self.latent_shape), amount=8)
# self.watcher.scalars(var_dic, self.step, tag="Valid")
self.netG.train()


if __name__ == '__main__':
# m_fid =Metric([0,1])
# m_fid._get_cifar10_mu_sigma()
#
# exit(1)
gpus = [2, 3]
batch_shape = (256, 3, 32, 32)

gpus = [3]
batch_shape = (128, 3, 32, 32)
image_channel = batch_shape[1]
nepochs = 100
d_turn = 5
nepochs = 200
mid_channel = 8

opt_G_name = "Adam"
depth_G = 8
lr = 1e-3
lr_decay = 0.94 # 0.94
weight_decay = 2e-5 # 2e-5
weight_decay = 0 # 2e-5
betas = (0.9, 0.999)
G_mid_channel = 8

opt_D_name = "RMSprop"
depth_D = 16
depth_D = 64
momentum = 0
D_mid_channel = 16

Expand All @@ -126,5 +118,5 @@ def valid(self):
opt_G = Optimizer(G.parameters(), lr, lr_decay, weight_decay, momentum, betas, opt_G_name)

print('===> Training')
Trainer = GenerateGanTrainer("log", nepochs, gpus, G, D, opt_G, opt_D, cifar10, latent_shape, d_turn=d_turn)
Trainer = GenerateGanTrainer("log", nepochs, gpus, G, D, opt_G, opt_D, cifar10, latent_shape)
Trainer.train()
11 changes: 7 additions & 4 deletions jdit/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from torch.optim import Adam, RMSprop


class Optimizer(object):
def __init__(self, params, lr=1e-3, lr_decay=0.92, weight_decay=2e-5, momentum=0., betas=(0.9, 0.999),
opt_name="Adam"):
opt_name="Adam", lr_minimum=1e-5):
self.lr = lr
self.lr_decay = lr_decay
self.lr_minimum = 1e-5
self.momentum = momentum
self.betas = betas
self.weight_decay = weight_decay
Expand All @@ -24,7 +26,8 @@ def do_lr_decay(self, reset_lr_decay=None, reset_lr=None):
:param reset_lr: if not None, use this value to reset `self.lr`. defaule: None.
:return:
"""
self.lr = self.lr * self.lr_decay
if self.lr > self.lr_minimum:
self.lr = self.lr * self.lr_decay
if reset_lr_decay is not None:
self.lr_decay = reset_lr_decay
if reset_lr is not None:
Expand All @@ -36,7 +39,8 @@ def _init_method(self, params):
if self.opt_name == "Adam":
opt = Adam(filter(lambda p: p.requires_grad, params), self.lr, self.betas, weight_decay=self.weight_decay)
elif self.opt_name == "RMSprop":
opt = RMSprop(filter(lambda p: p.requires_grad, params), self.lr, weight_decay=self.weight_decay, momentum=self.momentum)
opt = RMSprop(filter(lambda p: p.requires_grad, params), self.lr, weight_decay=self.weight_decay,
momentum=self.momentum)
else:
raise ValueError('%s is not a optimizer method!' % self.opt_name)
return opt
Expand All @@ -51,7 +55,6 @@ def configure(self):
config_dic["lr_decay"] = self.lr_decay
return config_dic


# def test_opt():
# import torch
# param = [torch.ones(3, 3, requires_grad=True)] * 5
Expand Down
46 changes: 21 additions & 25 deletions jdit/trainer/gan/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@


class GanTrainer(SupTrainer):

def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets, latent_shape,
d_turn=1):
self.d_turn = d_turn
d_turn = 1
def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets, latent_shape):
super(GanTrainer, self).__init__(nepochs, logdir, gpu_ids_abs=gpu_ids_abs)
self.netG = netG
self.netD = netD
Expand Down Expand Up @@ -41,27 +39,26 @@ def train_epoch(self):
self.train_iteration(self.optG, self.compute_g_loss, tag="LOSS_G")

if iteration == 1:
self._watch_images(show_imgs_num=6, tag="Train")
self._watch_images("Train")

def get_data_from_loader(self, batch_data):
ground_truth_cpu = batch_data[0]
input_cpu = Variable(torch.randn((len(ground_truth_cpu), *self.latent_shape)))
return input_cpu, ground_truth_cpu

def _watch_images(self, show_imgs_num, tag):

show_list = [self.input, self.fake, self.ground_truth]
show_title = ["input", "fake", "real"]

if self.input.size() != self.ground_truth.size():
show_list.pop(0)
show_title.pop(0)

self.watcher.images(show_list, show_title,
self.current_epoch,
tag=tag,
show_imgs_num=show_imgs_num,
mode=self.mode)
def _watch_images(self, tag, grid_size=(3, 3), shuffle=False, save_file = True):
self.watcher.image(self.fake,
self.current_epoch,
tag="%s/fake" % tag,
grid_size=grid_size,
shuffle=shuffle,
save_file = save_file)
self.watcher.image(self.ground_truth,
self.current_epoch,
tag="%s/real" % tag,
grid_size=grid_size,
shuffle=shuffle,
save_file=save_file)

def valid(self):
avg_dic = {}
Expand All @@ -83,8 +80,8 @@ def valid(self):
for key in avg_dic.keys():
avg_dic[key] = avg_dic[key] / self.datasets.valid_nsteps

self.watcher.scalars(avg_dic, self.step, tag="Valid" )
self._watch_images(show_imgs_num=4, tag="Valid")
self.watcher.scalars(avg_dic, self.step, tag="Valid")
self._watch_images(tag="Valid")
self.netG.train()
self.netD.train()

Expand Down Expand Up @@ -148,13 +145,12 @@ def compute_valid(self):

def test(self):

self.mv_inplace(Variable(torch.randn((32, *self.latent_shape))), self.input)
self.mv_inplace(Variable(torch.randn((16, *self.latent_shape))), self.input)
self.netG.eval()
with torch.no_grad():
fake = self.netG(self.input).detach()
self.watcher.images([fake], ["fake"], self.current_epoch, tag="Test",
show_imgs_num=-1,
mode=self.mode)
self.watcher.image(fake, self.current_epoch, tag="Test/fake", grid_size=(4, 4), shuffle=False)

self.netG.train()

def change_lr(self):
Expand Down
96 changes: 55 additions & 41 deletions jdit/trainer/super.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from abc import ABCMeta, abstractmethod
from tqdm import tqdm
import pandas as pd

import numpy as np

class SupTrainer(object):
every_epoch_checkpoint = 10
Expand All @@ -21,7 +21,7 @@ def __init__(self, nepochs, logdir, gpu_ids_abs=()):
self.gpu_ids = [i for i in range(len(gpu_ids_abs))]
self.logdir = logdir
self.performance = Performance(gpu_ids_abs)
self.watcher = Watcher(logdir)
self.watcher = Watcher(logdir, self.mode)
self.loger = Loger(logdir)

self.use_gpu = True if (len(self.gpu_ids) > 0) and torch.cuda.is_available() else False
Expand Down Expand Up @@ -182,10 +182,12 @@ def clear_regist(self):


class Watcher(object):
def __init__(self, logdir):
def __init__(self, logdir, mode="L"):
self.logdir = logdir
self.writer = SummaryWriter(log_dir=logdir)
self.mode = mode
self._buildDir(logdir)
self.training_progress_images = []

def netParams(self, network, global_step):
for name, param in network.named_parameters():
Expand All @@ -194,49 +196,61 @@ def netParams(self, network, global_step):
self.writer.add_histogram(name, param.clone().cpu().data.numpy(), global_step)

def scalars(self, var_dict, global_step, tag="Train"):
# if var_dict is None:
# value_list = list(map(self._torch_to_np, value_list))
# for key, scalar in zip(key_list, value_list):
# self.writer.add_scalars(key, {tag: scalar}, global_step)
# else:

for key, scalar in var_dict.items():
self.writer.add_scalars(key, {tag: scalar}, global_step)

def images(self, imgs_torch_list, title_list, global_step, tag="Train", show_imgs_num=3, mode="L",
mean=(-1, -1, -1), std=(2, 2, 2)):
# :param mode: color mode ,default :'L'
# :param mean: do Normalize. if input is (-1, 1).this should be -1. to convert to (0,1)
# :param std: do Normalize. if input is (-1, 1).this should be 2. to convert to (0,1)

self._buildDir(os.path.join(self.logdir, "plots", tag))
# ["%s/plots/%s" % (self.logdir, i) for i in title_list])

out = None
batchSize = len(imgs_torch_list[0])
show_nums = batchSize if show_imgs_num == -1 else min(show_imgs_num, batchSize)
columns_num = len(title_list)
imgs_stack = []

randindex_list = random.sample(list(range(batchSize)), show_nums)
for randindex in randindex_list:
for imgs_torch in imgs_torch_list:
img_torch = imgs_torch[randindex].cpu().detach()
img_torch = transforms.Normalize(mean, std)(
img_torch) # (-1,1)=>(0,1) mean = -1,std = 2
imgs_stack.append(img_torch)
out_1 = torch.stack(imgs_stack)
if out is None:
out = out_1
def _sample(self, tensor, num_samples, shuffle=True):
total = len(tensor)
assert num_samples <= total
if shuffle:
rand_index = random.sample(list(range(total)), num_samples)
sampled_tensor = tensor[rand_index]
else:
out = torch.cat((out_1, out))
out = make_grid(out, nrow=columns_num)
self.writer.add_image('%s:%s' % (tag, "-".join(title_list)), out, global_step)
sampled_tensor = tensor[:num_samples]
return sampled_tensor

for img, title in zip(imgs_stack, title_list):
img = transforms.ToPILImage()(img).convert(mode)
filename = "%s/plots/%s/E%03d_%s_.png" % (self.logdir, tag, global_step, title)
def image(self, img_tensors, global_step, tag="Train/input", grid_size=(3, 1), shuffle=True, save_file=False):
# if input is (-1, 1).this should be -1. to convert to (0,1) mean=(-1, -1, -1), std=(2, 2, 2)
assert len(img_tensors.size()) == 4, "img_tensors rank should be 4, got %d instead" % len(img_tensors.size())
self._buildDir(os.path.join(self.logdir, "plots", tag))
rows, columns = grid_size[0], grid_size[1]
batchSize = len(img_tensors) # img_tensors =>(batchsize, 3, 256, 256)
num_samples = min(batchSize, rows * columns)
assert len(img_tensors) >= num_samples, "you want to show grid %s, but only have %d tensors to show." % (
grid_size, len(img_tensors))

sampled_tensor = self._sample(img_tensors, num_samples,
shuffle).detach().cpu() # (sample_num, 3, 32,32) tensors
# sampled_images = map(transforms.Normalize(mean, std), sampled_tensor) # (sample_num, 3, 32,32) images
sampled_images = make_grid(sampled_tensor, nrow=rows, normalize=True, scale_each=True)
self.writer.add_image(tag, sampled_images, global_step)

if save_file:
img = transforms.ToPILImage()(sampled_images).convert(self.mode)
filename = "%s/plots/%s/E%03d.png" % (self.logdir, tag, global_step)
img.save(filename)

def set_training_progress_images(self,img_tensors, grid_size=(3, 1)):
assert len(img_tensors.size()) == 4, "img_tensors rank should be 4, got %d instead" % len(img_tensors.size())
rows, columns = grid_size[0], grid_size[1]
batchSize = len(img_tensors) # img_tensors =>(batchsize, 3, 256, 256)
num_samples = min(batchSize, rows * columns)
assert len(img_tensors) >= num_samples, "you want to show grid %s, but only have %d tensors to show." % (
grid_size, len(img_tensors))
sampled_tensor = self._sample(img_tensors, num_samples ,False).detach().cpu() # (sample_num, 3, 32,32) tensors
sampled_images = make_grid(sampled_tensor, nrow=rows, normalize=True, scale_each=True)
img_grid = np.transpose(sampled_images.numpy(), (1, 2, 0))
self.training_progress_images.append(img_grid)

def save_in_gif(self):
import imageio,warnings
filename = "%s/plots/training.gif" % (self.logdir)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
imageio.mimsave(filename, self.training_progress_images)
self.training_progress_images = None

def graph(self, net, input_shape=None, use_gpu=False, *input):
if hasattr(net, 'module'):
net = net.module
Expand All @@ -254,6 +268,8 @@ def graph(self, net, input_shape=None, use_gpu=False, *input):

def close(self):
# self.writer.export_scalars_to_json("%s/scalers.json" % self.logdir)
if self.training_progress_images:
self.save_in_gif()
self.writer.close()

def _buildDir(self, dirs):
Expand All @@ -262,8 +278,6 @@ def _buildDir(self, dirs):
os.makedirs(dirs)




class Performance(object):

def __init__(self, gpu_ids_abs=()):
Expand Down

0 comments on commit 520630e

Please sign in to comment.