From 6f0eec8ef0a147a80f64f25103089feb33553a06 Mon Sep 17 00:00:00 2001 From: junyanz Date: Sat, 27 Jun 2020 20:56:58 -0400 Subject: [PATCH] fix runtime error #87 and replace scipy by PIL --- models/bicycle_gan_model.py | 23 +++++++++++------------ util/util.py | 9 ++++++++- util/visualizer.py | 9 +-------- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/models/bicycle_gan_model.py b/models/bicycle_gan_model.py index f6ad9fa..01e0af8 100755 --- a/models/bicycle_gan_model.py +++ b/models/bicycle_gan_model.py @@ -76,7 +76,7 @@ def get_z_random(self, batch_size, nz, random_type='gauss'): z = torch.rand(batch_size, nz) * 2.0 - 1.0 elif random_type == 'gauss': z = torch.randn(batch_size, nz) - return z.to(self.device) + return z.detach().to(self.device) def encode(self, input_image): mu, logvar = self.netE.forward(input_image) @@ -184,25 +184,24 @@ def update_D(self): def backward_G_alone(self): # 3, reconstruction |(E(G(A, z_random)))-z_random| if self.opt.lambda_z > 0.0: - self.loss_z_L1 = torch.mean(torch.abs(self.mu2 - self.z_random)) * self.opt.lambda_z + self.loss_z_L1 = self.criterionZ(self.mu2, self.z_random) * self.opt.lambda_z self.loss_z_L1.backward() else: self.loss_z_L1 = 0.0 def update_G_and_E(self): # update G and E - self.set_requires_grad([self.netD, self.netD2], False) - self.optimizer_E.zero_grad() - self.optimizer_G.zero_grad() - self.backward_EG() - self.optimizer_G.step() - self.optimizer_E.step() - # update G only - if self.opt.lambda_z > 0.0: - self.optimizer_G.zero_grad() + with torch.autograd.set_detect_anomaly(True): + self.set_requires_grad([self.netD, self.netD2], False) self.optimizer_E.zero_grad() - self.backward_G_alone() + self.optimizer_G.zero_grad() + self.backward_EG() + + # update G only + if self.opt.lambda_z > 0.0: + self.backward_G_alone() self.optimizer_G.step() + self.optimizer_E.step() def optimize_parameters(self): self.forward() diff --git a/util/util.py b/util/util.py index ce22b87..bdeeffd 100755 --- a/util/util.py +++ b/util/util.py @@ -90,13 +90,20 @@ def interp_z(z0, z1, num_frames, interp_mode='linear'): return zs -def save_image(image_numpy, image_path): +def save_image(image_numpy, image_path, aspect_ratio=1.0): """Save a numpy image to the disk Parameters: image_numpy (numpy array) -- input numpy array image_path (str) -- the path of the image """ + image_pil = Image.fromarray(image_numpy) + h, w, _ = image_numpy.shape + + if aspect_ratio > 1.0: + image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) + if aspect_ratio < 1.0: + image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) image_pil.save(image_path) diff --git a/util/visualizer.py b/util/visualizer.py index 8ef2c02..62fa460 100755 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -6,7 +6,6 @@ from . import util from . import html from subprocess import Popen, PIPE -from scipy.misc import imresize if sys.version_info[0] == 2: @@ -36,13 +35,7 @@ def save_images(webpage, images, names, image_path, aspect_ratio=1.0, width=256) im = util.tensor2im(im_data) image_name = '%s_%s.png' % (name, label) save_path = os.path.join(image_dir, image_name) - h, w, _ = im.shape - if aspect_ratio > 1.0: - im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') - if aspect_ratio < 1.0: - im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') - util.save_image(im, save_path) - + util.save_image(im, save_path, aspect_ratio=aspect_ratio) ims.append(image_name) txts.append(label) links.append(image_name)