-
Notifications
You must be signed in to change notification settings - Fork 3
/
model.py
547 lines (488 loc) · 24.5 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
from functools import reduce
import networks
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import numpy as np
import cv2
import torch.nn.functional as F
class DCDA(nn.Module):
    """Two-domain image translation model coupled with a pair of segmenters.

    Sub-networks (all built in ``__init__``):

    * ``enc_c`` -- content encoder for domains A and B.
    * ``enc_a`` -- attribute encoder; when ``opts.concat`` is set it returns
      ``mu``/``logvar`` pairs that are sampled via the reparameterisation
      trick, otherwise it returns attribute codes directly.
    * ``gen`` -- generator with ``forward_a`` / ``forward_b`` heads.
    * ``disA`` / ``disB`` -- discriminators for encoded cross translations;
      ``disA2`` / ``disB2`` -- discriminators for translations produced from
      random attribute codes; ``disContent`` -- discriminator on content codes.
    * ``src_seg_model`` / ``tar_seg_model`` -- U-Nets applied to domain-A and
      domain-B images respectively.

    Every sub-network has its own Adam optimizer. Tensors are placed on
    ``self.gpu`` explicitly, so :meth:`setgpu` must be called before use.
    """

    def __init__(self, opts):
        """Build all sub-networks, optimizers and loss functions.

        Args:
            opts: options object; reads ``concat``, ``no_ms``, ``dis_scale``,
                ``dis_norm``, ``dis_spectral_norm``, ``input_dim_a`` and
                ``input_dim_b``.
        """
        super(DCDA, self).__init__()
        # parameters
        lr = 0.0001
        lr_dcontent = lr / 2.5
        self.nz = 8  # length of the attribute code fed to the generator
        self.concat = opts.concat
        self.no_ms = opts.no_ms  # True disables the mode-seeking branch
        # discriminators (construction order A, B, A2, B2 kept as before)
        if opts.dis_scale > 1:
            def make_dis(input_dim):
                return networks.MultiScaleDis(input_dim, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        else:
            def make_dis(input_dim):
                return networks.Dis(input_dim, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disA = make_dis(opts.input_dim_a)
        self.disB = make_dis(opts.input_dim_b)
        self.disA2 = make_dis(opts.input_dim_a)
        self.disB2 = make_dis(opts.input_dim_b)
        self.disContent = networks.Dis_content()
        # encoders
        self.enc_c = networks.E_content(opts.input_dim_a, opts.input_dim_b)
        if self.concat:
            self.enc_a = networks.E_attr_concat(opts.input_dim_a, opts.input_dim_b, self.nz,
                                                norm_layer=None, nl_layer=networks.get_non_linearity(layer_type='lrelu'))
        else:
            self.enc_a = networks.E_attr(opts.input_dim_a, opts.input_dim_b, self.nz)
        # generator
        if self.concat:
            self.gen = networks.G_concat(opts.input_dim_a, opts.input_dim_b, nz=self.nz)
        else:
            self.gen = networks.G(opts.input_dim_a, opts.input_dim_b, nz=self.nz)
        # optimizers
        self.disA_opt = self._make_adam(self.disA.parameters(), lr)
        self.disB_opt = self._make_adam(self.disB.parameters(), lr)
        self.disA2_opt = self._make_adam(self.disA2.parameters(), lr)
        self.disB2_opt = self._make_adam(self.disB2.parameters(), lr)
        self.disContent_opt = self._make_adam(self.disContent.parameters(), lr_dcontent)
        self.enc_c_opt = self._make_adam(self.enc_c.parameters(), lr)
        self.enc_a_opt = self._make_adam(self.enc_a.parameters(), lr)
        self.gen_opt = self._make_adam(self.gen.parameters(), lr)
        # segmentation models, one per domain
        self.src_seg_model = self._build_seg_model()
        self.src_seg_optimizer = self._make_seg_optimizer(self.src_seg_model)
        self.tar_seg_model = self._build_seg_model()
        self.tar_seg_optimizer = self._make_seg_optimizer(self.tar_seg_model)
        # Setup the loss function for training
        self.criterionL1 = torch.nn.L1Loss()
        self.dice_loss = smp.losses.DiceLoss(mode='binary', from_logits=False)
        self.ce_loss = nn.BCELoss()

    @staticmethod
    def _make_adam(params, lr):
        """Return the Adam optimizer used for the translation networks."""
        return torch.optim.Adam(params, lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)

    @staticmethod
    def _build_seg_model():
        """Return a 1-channel-in / 1-class-out ResNet-34 U-Net with sigmoid output."""
        return smp.Unet(encoder_name="resnet34", encoder_depth=5, encoder_weights='imagenet',
                        decoder_channels=[256, 128, 64, 32, 16], in_channels=1, classes=1, activation='sigmoid')

    @staticmethod
    def _make_seg_optimizer(model):
        """Return an Adam optimizer (lr 1e-4) over all parameters of ``model``."""
        return torch.optim.Adam([
            dict(params=model.parameters(), lr=1e-4)
        ])

    def initialize(self):
        """Apply ``networks.gaussian_weights_init`` to the translation networks.

        The segmentation models are left with their constructed weights.
        """
        for net in (self.disA, self.disB, self.disA2, self.disB2, self.disContent,
                    self.gen, self.enc_c, self.enc_a):
            net.apply(networks.gaussian_weights_init)

    def set_scheduler(self, opts, last_ep=0):
        """Create learning-rate schedulers for the translation optimizers."""
        self.disA_sch = networks.get_scheduler(self.disA_opt, opts, last_ep)
        self.disB_sch = networks.get_scheduler(self.disB_opt, opts, last_ep)
        self.disA2_sch = networks.get_scheduler(self.disA2_opt, opts, last_ep)
        self.disB2_sch = networks.get_scheduler(self.disB2_opt, opts, last_ep)
        self.disContent_sch = networks.get_scheduler(self.disContent_opt, opts, last_ep)
        self.enc_c_sch = networks.get_scheduler(self.enc_c_opt, opts, last_ep)
        self.enc_a_sch = networks.get_scheduler(self.enc_a_opt, opts, last_ep)
        self.gen_sch = networks.get_scheduler(self.gen_opt, opts, last_ep)

    def setgpu(self, gpu):
        """Record ``gpu`` as the target device and move every sub-network onto it."""
        self.gpu = gpu
        for net in (self.disA, self.disB, self.disA2, self.disB2, self.disContent,
                    self.enc_c, self.enc_a, self.gen, self.src_seg_model, self.tar_seg_model):
            net.cuda(self.gpu)

    def get_z_random(self, batchSize, nz, random_type='gauss'):
        """Return a ``(batchSize, nz)`` standard-normal tensor on ``self.gpu``.

        ``random_type`` is accepted for API compatibility but ignored.
        """
        z = torch.randn(batchSize, nz).cuda(self.gpu)
        return z

    def _reparameterize(self, mu, logvar):
        """Sample ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
        std = logvar.mul(0.5).exp_()
        eps = self.get_z_random(std.size(0), std.size(1), 'gauss')
        return eps.mul(std).add_(mu)

    def test_forward(self, image, a2b=True):
        """Translate ``image`` to the other domain using a random attribute code."""
        self.z_random = self.get_z_random(image.size(0), self.nz, 'gauss')
        if a2b:
            self.z_content = self.enc_c.forward_a(image)
            output = self.gen.forward_b(self.z_content, self.z_random)
        else:
            self.z_content = self.enc_c.forward_b(image)
            output = self.gen.forward_a(self.z_content, self.z_random)
        return output

    def test_forward_transfer(self, image_a, image_b, a2b=True):
        """Translate using the content of one image and the attributes of the other.

        With ``a2b`` the content of ``image_a`` is rendered with the attribute
        code of ``image_b`` through ``gen.forward_b``; otherwise the reverse.
        """
        self.z_content_a, self.z_content_b = self.enc_c.forward(image_a, image_b)
        if self.concat:
            self.mu_a, self.logvar_a, self.mu_b, self.logvar_b = self.enc_a.forward(image_a, image_b)
            self.z_attr_a = self._reparameterize(self.mu_a, self.logvar_a)
            self.z_attr_b = self._reparameterize(self.mu_b, self.logvar_b)
        else:
            self.z_attr_a, self.z_attr_b = self.enc_a.forward(image_a, image_b)
        if a2b:
            output = self.gen.forward_b(self.z_content_a, self.z_attr_b)
        else:
            output = self.gen.forward_a(self.z_content_b, self.z_attr_a)
        return output

    def forward(self, half_size=1):
        """Run the full translation cycle on ``self.input_A`` / ``self.input_B``.

        The first ``half_size`` images of each batch are encoded; the remainder
        (``real_A_random`` / ``real_B_random``) serve as real samples for
        ``disA2`` / ``disB2``. Every intermediate is cached on ``self`` for the
        backward methods, and ``self.image_display`` is filled with one CPU
        sample of each real/translated/reconstructed image.
        """
        real_A = self.input_A
        real_B = self.input_B
        self.real_A_encoded = real_A[0:half_size]
        self.real_A_random = real_A[half_size:]
        self.real_B_encoded = real_B[0:half_size]
        self.real_B_random = real_B[half_size:]
        # get encoded z_c
        self.z_content_a, self.z_content_b = self.enc_c.forward(self.real_A_encoded, self.real_B_encoded)
        # get encoded z_a
        if self.concat:
            self.mu_a, self.logvar_a, self.mu_b, self.logvar_b = self.enc_a.forward(self.real_A_encoded, self.real_B_encoded)
            self.z_attr_a = self._reparameterize(self.mu_a, self.logvar_a)
            self.z_attr_b = self._reparameterize(self.mu_b, self.logvar_b)
        else:
            self.z_attr_a, self.z_attr_b = self.enc_a.forward(self.real_A_encoded, self.real_B_encoded)
        # get random z_a
        self.z_random = self.get_z_random(self.real_A_encoded.size(0), self.nz, 'gauss')
        if not self.no_ms:
            self.z_random2 = self.get_z_random(self.real_A_encoded.size(0), self.nz, 'gauss')
        # first cross translation, batched as [cross, self, random(, random2)]
        if not self.no_ms:
            input_content_forA = torch.cat((self.z_content_b, self.z_content_a, self.z_content_b, self.z_content_b), 0)
            input_content_forB = torch.cat((self.z_content_a, self.z_content_b, self.z_content_a, self.z_content_a), 0)
            input_attr_forA = torch.cat((self.z_attr_a, self.z_attr_a, self.z_random, self.z_random2), 0)
            input_attr_forB = torch.cat((self.z_attr_b, self.z_attr_b, self.z_random, self.z_random2), 0)
            output_fakeA = self.gen.forward_a(input_content_forA, input_attr_forA)
            output_fakeB = self.gen.forward_b(input_content_forB, input_attr_forB)
            self.fake_A_encoded, self.fake_AA_encoded, self.fake_A_random, self.fake_A_random2 = torch.split(output_fakeA, self.z_content_a.size(0), dim=0)
            self.fake_B_encoded, self.fake_BB_encoded, self.fake_B_random, self.fake_B_random2 = torch.split(output_fakeB, self.z_content_a.size(0), dim=0)
        else:
            input_content_forA = torch.cat((self.z_content_b, self.z_content_a, self.z_content_b), 0)
            input_content_forB = torch.cat((self.z_content_a, self.z_content_b, self.z_content_a), 0)
            input_attr_forA = torch.cat((self.z_attr_a, self.z_attr_a, self.z_random), 0)
            input_attr_forB = torch.cat((self.z_attr_b, self.z_attr_b, self.z_random), 0)
            output_fakeA = self.gen.forward_a(input_content_forA, input_attr_forA)
            output_fakeB = self.gen.forward_b(input_content_forB, input_attr_forB)
            self.fake_A_encoded, self.fake_AA_encoded, self.fake_A_random = torch.split(output_fakeA, self.z_content_a.size(0), dim=0)
            self.fake_B_encoded, self.fake_BB_encoded, self.fake_B_random = torch.split(output_fakeB, self.z_content_a.size(0), dim=0)
        # get reconstructed encoded z_c
        self.z_content_recon_b, self.z_content_recon_a = self.enc_c.forward(self.fake_A_encoded, self.fake_B_encoded)
        # get reconstructed encoded z_a
        if self.concat:
            self.mu_recon_a, self.logvar_recon_a, self.mu_recon_b, self.logvar_recon_b = self.enc_a.forward(self.fake_A_encoded, self.fake_B_encoded)
            self.z_attr_recon_a = self._reparameterize(self.mu_recon_a, self.logvar_recon_a)
            self.z_attr_recon_b = self._reparameterize(self.mu_recon_b, self.logvar_recon_b)
        else:
            self.z_attr_recon_a, self.z_attr_recon_b = self.enc_a.forward(self.fake_A_encoded, self.fake_B_encoded)
        # second cross translation
        self.fake_A_recon = self.gen.forward_a(self.z_content_recon_a, self.z_attr_recon_a)
        self.fake_B_recon = self.gen.forward_b(self.z_content_recon_b, self.z_attr_recon_b)
        # for display
        display = (self.real_A_encoded, self.fake_B_encoded, self.fake_B_random, self.fake_AA_encoded, self.fake_A_recon,
                   self.real_B_encoded, self.fake_A_encoded, self.fake_A_random, self.fake_BB_encoded, self.fake_B_recon)
        self.image_display = torch.cat([t[0:1].detach().cpu() for t in display], dim=0)
        # for latent regression
        if self.concat:
            self.mu2_a, _, self.mu2_b, _ = self.enc_a.forward(self.fake_A_random, self.fake_B_random)
        else:
            self.z_attr_random_a, self.z_attr_random_b = self.enc_a.forward(self.fake_A_random, self.fake_B_random)

    def forward_content(self):
        """Encode only the content codes of the first image of each input batch."""
        half_size = 1
        self.real_A_encoded = self.input_A[0:half_size]
        self.real_B_encoded = self.input_B[0:half_size]
        # get encoded z_c
        self.z_content_a, self.z_content_b = self.enc_c.forward(self.real_A_encoded, self.real_B_encoded)

    def update_dual_seg(self, image_a, image_b, gt, src_seg, pretrained=False):
        """Take one optimizer step on both segmentation models.

        Uses the images cached by the last :meth:`forward`. The loss is the Dice
        loss of ``src_seg_model`` on real A and of ``tar_seg_model`` on the A->B
        translation (both against ``gt``), plus BCE cross-supervision in which
        each model is trained on the other's rounded, detached predictions.

        Args:
            image_a, image_b: unused; kept for interface compatibility.
            gt: masks for the domain-A batch; only the first
                ``real_A_encoded.size(0)`` entries are used.
            src_seg: source segmentation model to install when ``pretrained``.
            pretrained: if truthy, ``src_seg`` replaces ``src_seg_model``.
        """
        if pretrained and src_seg is not self.src_seg_model:
            self.src_seg_model = src_seg
            # The previous optimizer still holds the replaced model's
            # parameters; without a new one, step() would never update src_seg.
            self.src_seg_optimizer = self._make_seg_optimizer(src_seg)
        self.src_seg_model.train()
        self.tar_seg_model.train()
        # Only the segmentation models are optimized here, so cut the graph to
        # the translation networks: back-propagating into it wastes memory and
        # can fail if those parameters were stepped in place after forward().
        real_a = self.real_A_encoded.detach().to(self.gpu)
        real_b = self.real_B_encoded.detach().to(self.gpu)
        fake_b = self.fake_B_encoded.detach().to(self.gpu)
        fake_a = self.fake_A_encoded.detach().to(self.gpu)
        gt = gt[0:real_a.size(0)].to(self.gpu)
        # Scoped grad enabling: the caller's grad mode is restored on exit.
        with torch.enable_grad():
            self.src_seg_optimizer.zero_grad()
            self.tar_seg_optimizer.zero_grad()
            outa_real = self.src_seg_model(real_a)
            outb_real = self.tar_seg_model(real_b)
            outa2b = self.tar_seg_model(fake_b)
            outb2a = self.src_seg_model(fake_a)
            dice_a_real = self.dice_loss(outa_real, gt)
            dice_a2b = self.dice_loss(outa2b, gt)
            bce_a = self.ce_loss(outa2b, torch.round(outa_real).detach())
            bce_b = self.ce_loss(outb_real, torch.round(outb2a).detach())
            all_loss = dice_a_real + dice_a2b + bce_a + bce_b
            all_loss.backward()
            self.src_seg_optimizer.step()
            self.tar_seg_optimizer.step()
        self.seg_loss = all_loss.item()

    def update_D_content(self, image_a, image_b):
        """Update only the content discriminator on a fresh content encoding."""
        self.input_A = image_a
        self.input_B = image_b
        self.forward_content()
        self.disContent_opt.zero_grad()
        loss_D_Content = self.backward_contentD(self.z_content_a, self.z_content_b)
        self.disContent_loss = loss_D_Content.item()
        nn.utils.clip_grad_norm_(self.disContent.parameters(), 5)
        self.disContent_opt.step()

    def update_D(self, image_a, image_b):
        """Run :meth:`forward` and update all five discriminators."""
        self.input_A = image_a
        self.input_B = image_b
        self.forward()
        # update disA
        self.disA_opt.zero_grad()
        loss_D1_A = self.backward_D(self.disA, self.real_A_encoded, self.fake_A_encoded)
        self.disA_loss = loss_D1_A.item()
        self.disA_opt.step()
        # update disA2
        self.disA2_opt.zero_grad()
        loss_D2_A = self.backward_D(self.disA2, self.real_A_random, self.fake_A_random)
        self.disA2_loss = loss_D2_A.item()
        if not self.no_ms:
            loss_D2_A2 = self.backward_D(self.disA2, self.real_A_random, self.fake_A_random2)
            self.disA2_loss += loss_D2_A2.item()
        self.disA2_opt.step()
        # update disB
        self.disB_opt.zero_grad()
        loss_D1_B = self.backward_D(self.disB, self.real_B_encoded, self.fake_B_encoded)
        self.disB_loss = loss_D1_B.item()
        self.disB_opt.step()
        # update disB2
        self.disB2_opt.zero_grad()
        loss_D2_B = self.backward_D(self.disB2, self.real_B_random, self.fake_B_random)
        self.disB2_loss = loss_D2_B.item()
        if not self.no_ms:
            loss_D2_B2 = self.backward_D(self.disB2, self.real_B_random, self.fake_B_random2)
            self.disB2_loss += loss_D2_B2.item()
        self.disB2_opt.step()
        # update disContent
        self.disContent_opt.zero_grad()
        loss_D_Content = self.backward_contentD(self.z_content_a, self.z_content_b)
        self.disContent_loss = loss_D_Content.item()
        nn.utils.clip_grad_norm_(self.disContent.parameters(), 5)
        self.disContent_opt.step()

    def backward_D(self, netD, real, fake):
        """Back-propagate the BCE real/fake loss of ``netD``, summed over its outputs.

        ``fake`` is detached so only ``netD`` receives gradients from it.
        """
        pred_fake = netD.forward(fake.detach())
        pred_real = netD.forward(real)
        loss_D = 0
        for out_a, out_b in zip(pred_fake, pred_real):
            out_fake = torch.sigmoid(out_a)
            out_real = torch.sigmoid(out_b)
            ad_fake_loss = F.binary_cross_entropy(out_fake, torch.zeros_like(out_fake))
            ad_true_loss = F.binary_cross_entropy(out_real, torch.ones_like(out_real))
            loss_D += ad_true_loss + ad_fake_loss
        loss_D.backward()
        return loss_D

    def backward_contentD(self, imageA, imageB):
        """Back-propagate the content-discriminator loss (A codes -> 0, B codes -> 1).

        Both code tensors are detached, so only ``disContent`` gets gradients.
        The loss is summed over every discriminator output, as in
        :meth:`backward_D`.
        """
        pred_fake = self.disContent.forward(imageA.detach())
        pred_real = self.disContent.forward(imageB.detach())
        loss_D = 0
        for out_a, out_b in zip(pred_fake, pred_real):
            out_fake = torch.sigmoid(out_a)
            out_real = torch.sigmoid(out_b)
            ad_true_loss = F.binary_cross_entropy(out_real, torch.ones_like(out_real))
            ad_fake_loss = F.binary_cross_entropy(out_fake, torch.zeros_like(out_fake))
            loss_D += ad_true_loss + ad_fake_loss
        loss_D.backward()
        return loss_D

    def update_EG(self):
        """Update the encoders and generator in two phases.

        Phase 1 steps ``enc_c``, ``enc_a`` and ``gen`` on the losses of
        :meth:`backward_EG`; phase 2 steps ``enc_c`` and ``gen`` again on the
        losses of :meth:`backward_G_alone`. Requires a preceding
        :meth:`forward` (e.g. via :meth:`update_D`).
        """
        # update G, Ec, Ea
        self.enc_c_opt.zero_grad()
        self.enc_a_opt.zero_grad()
        self.gen_opt.zero_grad()
        self.backward_EG()
        self.enc_c_opt.step()
        self.enc_a_opt.step()
        self.gen_opt.step()
        # update G, Ec
        # NOTE(review): phase 2 back-propagates through the graph built by
        # forward() after phase 1 modified the parameters in place; newer
        # PyTorch versions may reject this with an "inplace operation"
        # RuntimeError -- confirm against the PyTorch version in use.
        self.enc_c_opt.zero_grad()
        self.gen_opt.zero_grad()
        self.backward_G_alone()
        self.enc_c_opt.step()
        self.gen_opt.step()

    def backward_EG(self):
        """Back-propagate adversarial, cycle/self reconstruction and latent-regularisation losses.

        Individual loss values are stored on ``self`` for logging.
        """
        # content Ladv for generator
        loss_G_GAN_Acontent = self.backward_G_GAN_content(self.z_content_a)
        loss_G_GAN_Bcontent = self.backward_G_GAN_content(self.z_content_b)
        # Ladv for generator
        loss_G_GAN_A = self.backward_G_GAN(self.fake_A_encoded, self.disA)
        loss_G_GAN_B = self.backward_G_GAN(self.fake_B_encoded, self.disB)
        # KL loss - z_a
        if self.concat:
            kl_element_a = self.mu_a.pow(2).add_(self.logvar_a.exp()).mul_(-1).add_(1).add_(self.logvar_a)
            loss_kl_za_a = torch.sum(kl_element_a).mul_(-0.5) * 0.01
            kl_element_b = self.mu_b.pow(2).add_(self.logvar_b.exp()).mul_(-1).add_(1).add_(self.logvar_b)
            loss_kl_za_b = torch.sum(kl_element_b).mul_(-0.5) * 0.01
        else:
            loss_kl_za_a = self._l2_regularize(self.z_attr_a) * 0.01
            loss_kl_za_b = self._l2_regularize(self.z_attr_b) * 0.01
        # KL loss - z_c
        loss_kl_zc_a = self._l2_regularize(self.z_content_a) * 0.01
        loss_kl_zc_b = self._l2_regularize(self.z_content_b) * 0.01
        # cross cycle consistency loss
        loss_G_L1_A = self.criterionL1(self.fake_A_recon, self.real_A_encoded) * 10
        loss_G_L1_B = self.criterionL1(self.fake_B_recon, self.real_B_encoded) * 10
        loss_G_L1_AA = self.criterionL1(self.fake_AA_encoded, self.real_A_encoded) * 10
        loss_G_L1_BB = self.criterionL1(self.fake_BB_encoded, self.real_B_encoded) * 10
        loss_G = loss_G_GAN_A + loss_G_GAN_B + \
            loss_G_GAN_Acontent + loss_G_GAN_Bcontent + \
            loss_G_L1_AA + loss_G_L1_BB + \
            loss_G_L1_A + loss_G_L1_B + \
            loss_kl_zc_a + loss_kl_zc_b + \
            loss_kl_za_a + loss_kl_za_b
        # retain_graph: backward_G_alone reuses parts of the same graph
        loss_G.backward(retain_graph=True)
        self.gan_loss_a = loss_G_GAN_A.item()
        self.gan_loss_b = loss_G_GAN_B.item()
        self.gan_loss_acontent = loss_G_GAN_Acontent.item()
        self.gan_loss_bcontent = loss_G_GAN_Bcontent.item()
        self.kl_loss_za_a = loss_kl_za_a.item()
        self.kl_loss_za_b = loss_kl_za_b.item()
        self.kl_loss_zc_a = loss_kl_zc_a.item()
        self.kl_loss_zc_b = loss_kl_zc_b.item()
        self.l1_recon_A_loss = loss_G_L1_A.item()
        self.l1_recon_B_loss = loss_G_L1_B.item()
        self.l1_recon_AA_loss = loss_G_L1_AA.item()
        self.l1_recon_BB_loss = loss_G_L1_BB.item()
        self.G_loss = loss_G.item()

    def backward_G_GAN_content(self, data):
        """Return the encoder's content-adversarial loss (targets all 0.5).

        Summed over every content-discriminator output, consistently with
        :meth:`backward_G_GAN`. Does not call ``backward``.
        """
        outs = self.disContent.forward(data)
        ad_loss = 0
        for out in outs:
            outputs_fake = torch.sigmoid(out)
            ad_loss += F.binary_cross_entropy(outputs_fake, torch.full_like(outputs_fake, 0.5))
        return ad_loss

    def backward_G_GAN(self, fake, netD=None):
        """Return the generator's adversarial BCE loss (targets all 1) summed over ``netD`` outputs."""
        outs_fake = netD.forward(fake)
        loss_G = 0
        for out_a in outs_fake:
            outputs_fake = torch.sigmoid(out_a)
            loss_G += F.binary_cross_entropy(outputs_fake, torch.ones_like(outputs_fake))
        return loss_G

    def backward_G_alone(self):
        """Back-propagate losses for translations from random attribute codes.

        Includes the ``disA2``/``disB2`` adversarial terms, the latent
        regression L1 loss and, unless ``no_ms``, the mode-seeking terms.
        """
        # Ladv for generator
        loss_G_GAN2_A = self.backward_G_GAN(self.fake_A_random, self.disA2)
        loss_G_GAN2_B = self.backward_G_GAN(self.fake_B_random, self.disB2)
        if not self.no_ms:
            loss_G_GAN2_A2 = self.backward_G_GAN(self.fake_A_random2, self.disA2)
            loss_G_GAN2_B2 = self.backward_G_GAN(self.fake_B_random2, self.disB2)
        # mode seeking loss for A-->B and B-->A
        if not self.no_ms:
            lz_AB = torch.mean(torch.abs(self.fake_B_random2 - self.fake_B_random)) / torch.mean(torch.abs(self.z_random2 - self.z_random))
            lz_BA = torch.mean(torch.abs(self.fake_A_random2 - self.fake_A_random)) / torch.mean(torch.abs(self.z_random2 - self.z_random))
            eps = 1 * 1e-5
            loss_lz_AB = 1 / (lz_AB + eps)
            loss_lz_BA = 1 / (lz_BA + eps)
        # latent regression loss
        if self.concat:
            loss_z_L1_a = torch.mean(torch.abs(self.mu2_a - self.z_random)) * 10
            loss_z_L1_b = torch.mean(torch.abs(self.mu2_b - self.z_random)) * 10
        else:
            loss_z_L1_a = torch.mean(torch.abs(self.z_attr_random_a - self.z_random)) * 10
            loss_z_L1_b = torch.mean(torch.abs(self.z_attr_random_b - self.z_random)) * 10
        loss_z_L1 = loss_z_L1_a + loss_z_L1_b + loss_G_GAN2_A + loss_G_GAN2_B
        if not self.no_ms:
            loss_z_L1 += (loss_G_GAN2_A2 + loss_G_GAN2_B2)
            loss_z_L1 += (loss_lz_AB + loss_lz_BA)
        loss_z_L1.backward(retain_graph=True)
        self.l1_recon_z_loss_a = loss_z_L1_a.item()
        self.l1_recon_z_loss_b = loss_z_L1_b.item()
        if not self.no_ms:
            self.gan2_loss_a = loss_G_GAN2_A.item() + loss_G_GAN2_A2.item()
            self.gan2_loss_b = loss_G_GAN2_B.item() + loss_G_GAN2_B2.item()
            self.lz_AB = loss_lz_AB.item()
            self.lz_BA = loss_lz_BA.item()
        else:
            self.gan2_loss_a = loss_G_GAN2_A.item()
            self.gan2_loss_b = loss_G_GAN2_B.item()

    def update_lr(self):
        """Advance every translation-network scheduler by one step."""
        for sch in (self.disA_sch, self.disB_sch, self.disA2_sch, self.disB2_sch,
                    self.disContent_sch, self.enc_c_sch, self.enc_a_sch, self.gen_sch):
            sch.step()

    def _l2_regularize(self, mu):
        """Return the mean of the squared entries of ``mu``."""
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def resume(self, model_dir, train=True):
        """Load weights (and, when ``train``, optimizer states) from a checkpoint.

        Discriminator weights and all optimizer states are only loaded when
        ``train`` is true. Returns ``(ep, total_it)`` stored in the checkpoint.
        """
        checkpoint = torch.load(model_dir)
        # weight
        if train:
            self.disA.load_state_dict(checkpoint['disA'])
            self.disA2.load_state_dict(checkpoint['disA2'])
            self.disB.load_state_dict(checkpoint['disB'])
            self.disB2.load_state_dict(checkpoint['disB2'])
            self.disContent.load_state_dict(checkpoint['disContent'])
        self.enc_c.load_state_dict(checkpoint['enc_c'])
        self.enc_a.load_state_dict(checkpoint['enc_a'])
        self.gen.load_state_dict(checkpoint['gen'])
        self.src_seg_model.load_state_dict(checkpoint['src_seg'])
        self.tar_seg_model.load_state_dict(checkpoint['tar_seg'])
        # optimizer
        if train:
            self.disA_opt.load_state_dict(checkpoint['disA_opt'])
            self.disA2_opt.load_state_dict(checkpoint['disA2_opt'])
            self.disB_opt.load_state_dict(checkpoint['disB_opt'])
            self.disB2_opt.load_state_dict(checkpoint['disB2_opt'])
            self.disContent_opt.load_state_dict(checkpoint['disContent_opt'])
            self.enc_c_opt.load_state_dict(checkpoint['enc_c_opt'])
            self.enc_a_opt.load_state_dict(checkpoint['enc_a_opt'])
            self.gen_opt.load_state_dict(checkpoint['gen_opt'])
            # Older checkpoints may not contain segmentation optimizer state.
            if 'src_seg_opt' in checkpoint:
                self.src_seg_optimizer.load_state_dict(checkpoint['src_seg_opt'])
            if 'tar_seg_opt' in checkpoint:
                self.tar_seg_optimizer.load_state_dict(checkpoint['tar_seg_opt'])
        return checkpoint['ep'], checkpoint['total_it']

    def save(self, filename, ep, total_it):
        """Write all weights, optimizer states, ``ep`` and ``total_it`` to ``filename``."""
        state = {
            'disA': self.disA.state_dict(),
            'disA2': self.disA2.state_dict(),
            'disB': self.disB.state_dict(),
            'disB2': self.disB2.state_dict(),
            'disContent': self.disContent.state_dict(),
            'enc_c': self.enc_c.state_dict(),
            'enc_a': self.enc_a.state_dict(),
            'gen': self.gen.state_dict(),
            'disA_opt': self.disA_opt.state_dict(),
            'disA2_opt': self.disA2_opt.state_dict(),
            'disB_opt': self.disB_opt.state_dict(),
            'disB2_opt': self.disB2_opt.state_dict(),
            'disContent_opt': self.disContent_opt.state_dict(),
            'enc_c_opt': self.enc_c_opt.state_dict(),
            'enc_a_opt': self.enc_a_opt.state_dict(),
            'gen_opt': self.gen_opt.state_dict(),
            'src_seg': self.src_seg_model.state_dict(),
            'tar_seg': self.tar_seg_model.state_dict(),
            'src_seg_opt': self.src_seg_optimizer.state_dict(),
            'tar_seg_opt': self.tar_seg_optimizer.state_dict(),
            'ep': ep,
            'total_it': total_it
        }
        torch.save(state, filename)
        return

    def assemble_outputs(self):
        """Return a 2-row image grid of the first sample of each cached image.

        Row 1: real A, A->B, A->B (random attr), A->A, A cycle reconstruction.
        Row 2: real B, B->A, B->A (random attr), B->B, B cycle reconstruction.
        Images are concatenated along dim 3, rows along dim 2.
        """
        images_a = self.normalize_image(self.real_A_encoded).detach()
        images_b = self.normalize_image(self.real_B_encoded).detach()
        images_a1 = self.normalize_image(self.fake_A_encoded).detach()
        images_a2 = self.normalize_image(self.fake_A_random).detach()
        images_a3 = self.normalize_image(self.fake_A_recon).detach()
        images_a4 = self.normalize_image(self.fake_AA_encoded).detach()
        images_b1 = self.normalize_image(self.fake_B_encoded).detach()
        images_b2 = self.normalize_image(self.fake_B_random).detach()
        images_b3 = self.normalize_image(self.fake_B_recon).detach()
        images_b4 = self.normalize_image(self.fake_BB_encoded).detach()
        row1 = torch.cat((images_a[0:1, ::], images_b1[0:1, ::], images_b2[0:1, ::], images_a4[0:1, ::], images_a3[0:1, ::]), 3)
        row2 = torch.cat((images_b[0:1, ::], images_a1[0:1, ::], images_a2[0:1, ::], images_b4[0:1, ::], images_b3[0:1, ::]), 3)
        return torch.cat((row1, row2), 2)

    def normalize_image(self, x):
        """Return at most the first three channels of ``x`` (no value scaling)."""
        return x[:, 0:3, :, :]