Skip to content

Commit

Permalink
fix runtime error #87 and replace scipy by PIL
Browse files Browse the repository at this point in the history
  • Loading branch information
junyanz committed Jun 28, 2020
1 parent 0d7459a commit 6f0eec8
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 21 deletions.
23 changes: 11 additions & 12 deletions models/bicycle_gan_model.py
Expand Up @@ -76,7 +76,7 @@ def get_z_random(self, batch_size, nz, random_type='gauss'):
z = torch.rand(batch_size, nz) * 2.0 - 1.0
elif random_type == 'gauss':
z = torch.randn(batch_size, nz)
return z.to(self.device)
return z.detach().to(self.device)

def encode(self, input_image):
mu, logvar = self.netE.forward(input_image)
Expand Down Expand Up @@ -184,25 +184,24 @@ def update_D(self):
def backward_G_alone(self):
# 3, reconstruction |(E(G(A, z_random)))-z_random|
if self.opt.lambda_z > 0.0:
self.loss_z_L1 = torch.mean(torch.abs(self.mu2 - self.z_random)) * self.opt.lambda_z
self.loss_z_L1 = self.criterionZ(self.mu2, self.z_random) * self.opt.lambda_z
self.loss_z_L1.backward()
else:
self.loss_z_L1 = 0.0

def update_G_and_E(self):
# update G and E
self.set_requires_grad([self.netD, self.netD2], False)
self.optimizer_E.zero_grad()
self.optimizer_G.zero_grad()
self.backward_EG()
self.optimizer_G.step()
self.optimizer_E.step()
# update G only
if self.opt.lambda_z > 0.0:
self.optimizer_G.zero_grad()
with torch.autograd.set_detect_anomaly(True):
self.set_requires_grad([self.netD, self.netD2], False)
self.optimizer_E.zero_grad()
self.backward_G_alone()
self.optimizer_G.zero_grad()
self.backward_EG()

# update G only
if self.opt.lambda_z > 0.0:
self.backward_G_alone()
self.optimizer_G.step()
self.optimizer_E.step()

def optimize_parameters(self):
self.forward()
Expand Down
9 changes: 8 additions & 1 deletion util/util.py
Expand Up @@ -90,13 +90,20 @@ def interp_z(z0, z1, num_frames, interp_mode='linear'):
return zs


def save_image(image_numpy, image_path):
def save_image(image_numpy, image_path, aspect_ratio=1.0):
"""Save a numpy image to the disk
Parameters:
image_numpy (numpy array) -- input numpy array
image_path (str) -- the path of the image
"""

image_pil = Image.fromarray(image_numpy)
h, w, _ = image_numpy.shape

if aspect_ratio > 1.0:
image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
if aspect_ratio < 1.0:
image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
image_pil.save(image_path)


Expand Down
9 changes: 1 addition & 8 deletions util/visualizer.py
Expand Up @@ -6,7 +6,6 @@
from . import util
from . import html
from subprocess import Popen, PIPE
from scipy.misc import imresize


if sys.version_info[0] == 2:
Expand Down Expand Up @@ -36,13 +35,7 @@ def save_images(webpage, images, names, image_path, aspect_ratio=1.0, width=256)
im = util.tensor2im(im_data)
image_name = '%s_%s.png' % (name, label)
save_path = os.path.join(image_dir, image_name)
h, w, _ = im.shape
if aspect_ratio > 1.0:
im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
if aspect_ratio < 1.0:
im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
util.save_image(im, save_path)

util.save_image(im, save_path, aspect_ratio=aspect_ratio)
ims.append(image_name)
txts.append(label)
links.append(image_name)
Expand Down

0 comments on commit 6f0eec8

Please sign in to comment.