add a small discriminator for heavily downsized image
lucidrains committed Nov 22, 2020
1 parent 4cd81d0 commit f9de11d
Showing 2 changed files with 72 additions and 28 deletions.
98 changes: 71 additions & 27 deletions lightweight_gan/lightweight_gan.py
@@ -38,7 +38,6 @@

# constants

-EPS = 1e-8
NUM_CORES = multiprocessing.cpu_count()
EXTS = ['jpg', 'jpeg', 'png']
CALC_FID_NUM_IMAGES = 12800
@@ -100,6 +99,9 @@ def gradient_penalty(images, outputs, weight = 10):
gradients = gradients.reshape(batch_size, -1)
return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

+def hinge_loss(real, fake):
+    return (F.relu(1 + real) + F.relu(1 - fake)).mean()

def evaluate_in_chunks(max_batch_size, model, *args):
split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
chunked_outputs = [model(*i) for i in split_args]
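
Note: the new hinge_loss helper factors out the objective that previously appeared inline in train() (see the divergence lines further down). The sign convention is flipped relative to the textbook hinge GAN loss: real outputs are pushed below -1 and fake outputs above +1, which is consistent with the generator step below, where the generator minimizes the raw discriminator output on fakes. A minimal sketch of its behavior; the example tensors are illustrative, not from the repo:

import torch
import torch.nn.functional as F

def hinge_loss(real, fake):
    return (F.relu(1 + real) + F.relu(1 - fake)).mean()

real = torch.tensor([-2.0])  # beyond the -1 margin: contributes nothing
fake = torch.tensor([2.0])   # beyond the +1 margin: contributes nothing
assert hinge_loss(real, fake).item() == 0.0

real = torch.tensor([0.5])   # inside the margin: relu(1 + 0.5) = 1.5
fake = torch.tensor([0.5])   # inside the margin: relu(1 - 0.5) = 0.5
assert hinge_loss(real, fake).item() == 2.0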
@@ -148,6 +150,21 @@ def __init__(self, fn):
def forward(self, x):
return self.g * self.fn(x)

+class Residual(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x):
+        return self.fn(x) + x
+
+class SumBranches(nn.Module):
+    def __init__(self, branches):
+        super().__init__()
+        self.branches = nn.ModuleList(branches)
+    def forward(self, x):
+        return sum(map(lambda fn: fn(x), self.branches))

# dataset

def convert_image_to(img_type, image):
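
Note: two new combinator modules. Residual adds a module's input to its output; SumBranches fans the same input out to parallel branches and sums their results, so the branches must agree on output shape. A minimal usage sketch with illustrative shapes:

import torch
from torch import nn

# assumes Residual and SumBranches from this commit are in scope
res = Residual(nn.Conv2d(8, 8, 3, padding = 1))          # fn(x) + x, so fn must preserve shape
two_path = SumBranches([
    nn.Conv2d(8, 16, 4, stride = 2, padding = 1),        # strided-conv downsample
    nn.Sequential(nn.AvgPool2d(2), nn.Conv2d(8, 16, 1))  # pool-then-project downsample
])

x = torch.randn(2, 8, 32, 32)
assert res(x).shape == (2, 8, 32, 32)
assert two_path(x).shape == (2, 16, 16, 16)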
@@ -331,14 +348,14 @@ def __init__(
self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
self.sle_map = dict(self.sle_map)

-        self.num_layers_spatial_res = 2
+        self.num_layers_spatial_res = 1

for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
image_width = 2 ** res

attn = None
if image_width in attn_res_layers:
-                attn = Rezero(GSA(dim = chan_in, rel_pos_length = image_width))
+                attn = Rezero(GSA(dim = chan_in, norm_queries = True))

sle = None
if res in self.sle_map:
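
Note: here in the generator, and mirrored in the discriminator below, the Rezero-wrapped GSA attention layers swap the learned relative positional bias (rel_pos_length) for query normalization (norm_queries = True); both keyword arguments belong to the GSA module this file imports. A small sketch of the pattern as used in forward(); the gsa_pytorch import path is an assumption:

import torch
from gsa_pytorch import GSA  # assumption: GSA comes from the gsa-pytorch package

x = torch.randn(2, 64, 32, 32)
attn = Rezero(GSA(dim = 64, norm_queries = True))  # Rezero (defined above) gates the branch with a learned scalar
out = attn(x) + x                                  # attention is applied residually in forward()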
@@ -481,20 +498,22 @@ def __init__(

attn = None
if image_width in attn_res_layers:
-                attn = Rezero(GSA(dim = chan_in, batch_norm = False, rel_pos_length = image_width))
+                attn = Rezero(GSA(dim = chan_in, batch_norm = False, norm_queries = True))

self.residual_layers.append(nn.ModuleList([
-                nn.Sequential(
-                    nn.Conv2d(chan_in, chan_out, 4, stride = 2, padding = 1),
-                    nn.LeakyReLU(0.1),
-                    nn.Conv2d(chan_out, chan_out, 3, padding = 1),
-                    nn.LeakyReLU(0.1)
-                ),
-                nn.Sequential(
-                    nn.AvgPool2d(2),
-                    nn.Conv2d(chan_in, chan_out, 1),
-                    nn.LeakyReLU(0.1),
-                ),
+                SumBranches([
+                    nn.Sequential(
+                        nn.Conv2d(chan_in, chan_out, 4, stride = 2, padding = 1),
+                        nn.LeakyReLU(0.1),
+                        nn.Conv2d(chan_out, chan_out, 3, padding = 1),
+                        nn.LeakyReLU(0.1)
+                    ),
+                    nn.Sequential(
+                        nn.AvgPool2d(2),
+                        nn.Conv2d(chan_in, chan_out, 1),
+                        nn.LeakyReLU(0.1),
+                    )
+                ]),
attn
]))

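Note: this hunk is a behavior-preserving refactor. The strided-conv branch and the avg-pool branch that forward() previously summed by hand are fused into one SumBranches module, so each residual_layers entry shrinks from a (layer, residual_layer, attn) triple to a (net, attn) pair (see the simplified loop in the forward() hunk below). A quick equivalence check, reusing the new class:

import torch
from torch import nn

conv_path = nn.Sequential(nn.Conv2d(8, 16, 4, stride = 2, padding = 1), nn.LeakyReLU(0.1))
pool_path = nn.Sequential(nn.AvgPool2d(2), nn.Conv2d(8, 16, 1), nn.LeakyReLU(0.1))
net = SumBranches([conv_path, pool_path])  # SumBranches as defined in this commit

x = torch.randn(2, 8, 16, 16)
assert torch.allclose(net(x), conv_path(x) + pool_path(x))  # same result as the old two-module sum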
@@ -512,6 +531,27 @@ def __init__(
nn.Conv2d(last_chan, 1, 4)
)

+        self.to_shape_disc_out = nn.Sequential(
+            nn.Conv2d(init_channel, 64, 3, padding = 1),
+            Residual(Rezero(GSA(dim = 64, norm_queries = True, batch_norm = False))),
+            SumBranches([
+                nn.Sequential(
+                    nn.Conv2d(64, 32, 4, stride = 2, padding = 1),
+                    nn.LeakyReLU(0.1),
+                    nn.Conv2d(32, 32, 3, padding = 1),
+                    nn.LeakyReLU(0.1)
+                ),
+                nn.Sequential(
+                    nn.AvgPool2d(2),
+                    nn.Conv2d(64, 32, 1),
+                    nn.LeakyReLU(0.1),
+                )
+            ]),
+            Residual(Rezero(GSA(dim = 32, norm_queries = True, batch_norm = False))),
+            nn.AdaptiveAvgPool2d((4, 4)),
+            nn.Conv2d(32, 1, 4)
+        )

self.decoder1 = SimpleDecoder(chan_in = last_chan, chan_out = init_channel)
self.decoder2 = SimpleDecoder(chan_in = features[-2][-1], chan_out = init_channel) if resolution >= 9 else None

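Note: to_shape_disc_out is the commit's small discriminator for heavily downsized images: a shallow conv/attention stack that always scores a 32x32 version of the input, independent of the main discriminator's resolution pyramid. A rough shape trace, assuming a 3-channel input (init_channel = 3) and batch size B:

(B, 3, 32, 32)   -> Conv2d(3, 64, 3, padding = 1)      -> (B, 64, 32, 32)
                 -> Residual(Rezero(GSA(dim = 64)))    -> (B, 64, 32, 32)
                 -> SumBranches (stride-2 downsample)  -> (B, 32, 16, 16)
                 -> Residual(Rezero(GSA(dim = 32)))    -> (B, 32, 16, 16)
                 -> AdaptiveAvgPool2d((4, 4))          -> (B, 32, 4, 4)
                 -> Conv2d(32, 1, 4)                   -> (B, 1, 1, 1)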
@@ -523,17 +563,20 @@ def forward(self, x, calc_aux_loss = False):

layer_outputs = []

-        for (layer, residual_layer, attn) in self.residual_layers:
+        for (net, attn) in self.residual_layers:
if exists(attn):
x = attn(x) + x

-            x = layer(x) + residual_layer(x)
+            x = net(x)
layer_outputs.append(x)

out = self.to_logits(x).flatten(1)

+        img_32x32 = F.interpolate(orig_img, size = (32, 32))
+        out_32x32 = self.to_shape_disc_out(img_32x32)

if not calc_aux_loss:
-            return out, None
+            return out, out_32x32, None

# self-supervised auto-encoding loss

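Note: forward() now always scores a 32x32 downsample of its input (orig_img) through the auxiliary head and returns an (out, out_32x32, aux_loss) triple instead of a pair; every D_aug call site in train() below is updated to unpack three values. The downsizing step in isolation, with F.interpolate given a target size (nearest-neighbor by default):

import torch
import torch.nn.functional as F

orig_img = torch.randn(2, 3, 256, 256)               # illustrative batch
img_32x32 = F.interpolate(orig_img, size = (32, 32))
assert img_32x32.shape == (2, 3, 32, 32)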
@@ -561,7 +604,7 @@ def forward(self, x, calc_aux_loss = False):

aux_loss = aux_loss + aux_loss_16x16

-        return out, aux_loss
+        return out, out_32x32, aux_loss

class LightweightGAN(nn.Module):
def __init__(
@@ -861,23 +904,24 @@ def train(self):
latents = torch.randn(batch_size, latent_dim).cuda(self.rank)

generated_images = G(latents)
-            fake_output, fake_aux_loss = D_aug(generated_images.detach(), calc_aux_loss = True, detach = True, **aug_kwargs)
+            fake_output, fake_output_32x32, _ = D_aug(generated_images.detach(), detach = True, **aug_kwargs)

image_batch = next(self.loader).cuda(self.rank)
image_batch.requires_grad_()
-            real_output, real_aux_loss = D_aug(image_batch, calc_aux_loss = True, **aug_kwargs)
+            real_output, real_output_32x32, real_aux_loss = D_aug(image_batch, calc_aux_loss = True, **aug_kwargs)

real_output_loss = real_output
fake_output_loss = fake_output

-            divergence = (F.relu(1 + real_output_loss) + F.relu(1 - fake_output_loss)).mean()
-            disc_loss = divergence
+            divergence = hinge_loss(real_output_loss, fake_output_loss)
+            divergence_32x32 = hinge_loss(real_output_32x32, fake_output_32x32)
+            disc_loss = divergence + divergence_32x32

-            aux_loss = real_aux_loss + fake_aux_loss
+            aux_loss = real_aux_loss
disc_loss = disc_loss + aux_loss

if apply_gradient_penalty:
-                gp = gradient_penalty(image_batch, (real_output,))
+                gp = gradient_penalty(image_batch, (real_output, real_output_32x32))
self.last_gp_loss = gp.clone().detach().item()
disc_loss = disc_loss + gp

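Note: the discriminator loss is now the sum of the hinge losses from both heads, the auxiliary reconstruction loss is computed on real images only, and the gradient penalty receives the logits of both heads so one backward pass penalizes the input gradient through each. The grad call itself is outside this diff; a sketch of the presumed mechanics, consistent with the gradient_penalty fragment in the first hunk:

import torch

def gradient_penalty(images, outputs, weight = 10):
    # assumption: one autograd.grad call over all discriminator heads,
    # w.r.t. an input batch that had requires_grad_() set (as train() does above)
    gradients = torch.autograd.grad(
        outputs = list(outputs),
        inputs = images,
        grad_outputs = [torch.ones_like(o) for o in outputs],
        create_graph = True
    )[0]
    batch_size = images.shape[0]
    gradients = gradients.reshape(batch_size, -1)                  # shown in the diff
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()   # shown in the diff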
@@ -899,8 +943,8 @@ def train(self):
for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[G, D_aug]):
latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
generated_images = G(latents)
-            fake_output, _ = D_aug(generated_images, **aug_kwargs)
-            fake_output_loss = fake_output.mean(dim = 1)
+            fake_output, fake_output_32x32, _ = D_aug(generated_images, **aug_kwargs)
+            fake_output_loss = fake_output.mean(dim = 1) + fake_output_32x32.mean(dim = 1)

epochs = (self.steps * batch_size * self.gradient_accumulate_every) / len(self.dataset)
k_frac = max(self.generator_top_k_gamma ** epochs, self.generator_top_k_frac)
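
Note: the generator loss now averages the per-sample logits of both heads (mean(dim = 1) keeps the batch dimension), and k_frac decays from 1 toward generator_top_k_frac as training progresses. The selection step is cut off by this hunk; presumably the standard top-k GAN update, along these lines:

import math

# assumption: keep only the k samples the discriminator rates most realistic
# (lowest output under this file's sign convention); batch_size, k_frac and
# fake_output_loss continue from the hunk above
k = math.ceil(batch_size * k_frac)
fake_output_loss, _ = fake_output_loss.topk(k = k, largest = False)
generator_loss = fake_output_loss.mean()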
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
@@ -1 +1 @@
-__version__ = '0.9.2'
+__version__ = '0.10.0'
