/
dcgan.py
120 lines (94 loc) · 4.28 KB
/
dcgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple
import torch
import torch.nn.functional as F
from mmengine.optim import OptimWrapper
from torch import Tensor
from mmgen.registry import MODELS
from mmgen.structures import GenDataSample
from .base_gan import BaseGAN
@MODELS.register_module()
class DCGAN(BaseGAN):
    """Implementation of `Unsupervised Representation Learning with Deep
    Convolutional Generative Adversarial Networks
    <https://arxiv.org/abs/1511.06434>`_ (DCGAN).

    Detailed architecture can be found in
    :class:~`mmgen.models.architectures.dcgan.generator_discriminator.DCGANGenerator` # noqa
    and
    :class:~`mmgen.models.architectures.dcgan.generator_discriminator.DCGANDiscriminator` # noqa
    """

    def disc_loss(self, disc_pred_fake: Tensor,
                  disc_pred_real: Tensor) -> Tuple:
        r"""Get disc loss. DCGAN uses the vanilla GAN loss to train the
        discriminator: fake predictions are pushed toward label 0 and real
        predictions toward label 1 with BCE-with-logits.

        Args:
            disc_pred_fake (Tensor): Discriminator's prediction of the fake
                images.
            disc_pred_real (Tensor): Discriminator's prediction of the real
                images.

        Returns:
            tuple[Tensor, dict]: Loss value and a dict of log variables.
        """
        losses_dict = dict()
        # Fake images should be classified as 0, real images as 1.
        losses_dict['loss_disc_fake'] = F.binary_cross_entropy_with_logits(
            disc_pred_fake, torch.zeros_like(disc_pred_fake))
        losses_dict['loss_disc_real'] = F.binary_cross_entropy_with_logits(
            disc_pred_real, torch.ones_like(disc_pred_real))
        loss, log_var = self.parse_losses(losses_dict)
        return loss, log_var

    def gen_loss(self, disc_pred_fake: Tensor) -> Tuple:
        """Get gen loss. DCGAN uses the vanilla GAN loss to train the
        generator: the generator tries to make the discriminator classify
        fake images as real (label 1).

        Args:
            disc_pred_fake (Tensor): Discriminator's prediction of the fake
                images.

        Returns:
            tuple[Tensor, dict]: Loss value and a dict of log variables.
        """
        losses_dict = dict()
        losses_dict['loss_gen'] = F.binary_cross_entropy_with_logits(
            disc_pred_fake, torch.ones_like(disc_pred_fake))
        loss, log_var = self.parse_losses(losses_dict)
        return loss, log_var

    def train_discriminator(
            self, inputs: dict, data_samples: List[GenDataSample],
            optimizer_wrapper: OptimWrapper) -> Dict[str, Tensor]:
        """Train discriminator for one step.

        Args:
            inputs (dict): Inputs from dataloader. Must contain real images
                under the ``'img'`` key.
            data_samples (List[GenDataSample]): Data samples from dataloader.
            optimizer_wrapper (OptimWrapper): OptimWrapper instance used to
                update model parameters.

        Returns:
            Dict[str, Tensor]: A ``dict`` of tensor for logging.
        """
        real_imgs = inputs['img']
        num_batches = real_imgs.shape[0]

        noise_batch = self.noise_fn(num_batches=num_batches)
        # Generator is frozen while the discriminator trains; no_grad avoids
        # building an unused graph through the generator.
        with torch.no_grad():
            fake_imgs = self.generator(noise=noise_batch, return_noise=False)

        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)

        parsed_losses, log_vars = self.disc_loss(disc_pred_fake,
                                                 disc_pred_real)
        optimizer_wrapper.update_params(parsed_losses)
        return log_vars

    def train_generator(self, inputs: dict, data_samples: List[GenDataSample],
                        optimizer_wrapper: OptimWrapper) -> Dict[str, Tensor]:
        """Train generator for one step.

        Args:
            inputs (dict): Inputs from dataloader. Only the batch size of
                ``inputs['img']`` is used.
            data_samples (List[GenDataSample]): Data samples from dataloader.
                Not used in generator's training.
            optimizer_wrapper (OptimWrapper): OptimWrapper instance used to
                update model parameters.

        Returns:
            Dict[str, Tensor]: A ``dict`` of tensor for logging.
        """
        num_batches = inputs['img'].shape[0]

        noise = self.noise_fn(num_batches=num_batches)
        # Gradients must flow through the generator here, so no no_grad.
        fake_imgs = self.generator(noise=noise, return_noise=False)

        disc_pred_fake = self.discriminator(fake_imgs)
        parsed_loss, log_vars = self.gen_loss(disc_pred_fake)

        optimizer_wrapper.update_params(parsed_loss)
        return log_vars