Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
126 lines (99 sloc) 4.8 KB
"""
---
title: Generative Adversarial Networks (GAN)
summary: A simple PyTorch implementation/tutorial of Generative Adversarial Networks (GAN) loss functions.
---
# Generative Adversarial Networks (GAN)
This is an implementation of
[Generative Adversarial Networks](https://arxiv.org/abs/1406.2661).
The generator, $G(\pmb{z}; \theta_g)$ generates samples that match the
distribution of data, while the discriminator, $D(\pmb{x}; \theta_d)$
gives the probability that $\pmb{x}$ came from data rather than $G$.
We train $D$ and $G$ simultaneously on a two-player minimax game with value
function $V(D, G)$.
$$\min_G \max_D V(D, G) =
\mathop{\mathbb{E}}_{\pmb{x} \sim p_{data}(\pmb{x})}
\big[\log D(\pmb{x})\big] +
\mathop{\mathbb{E}}_{\pmb{z} \sim p_{\pmb{z}}(\pmb{z})}
\big[\log \big(1 - D(G(\pmb{z}))\big)\big]
$$
$p_{data}(\pmb{x})$ is the probability distribution over data,
whilst $p_{\pmb{z}}(\pmb{z})$ is the probability distribution of $\pmb{z}$, which is set to
Gaussian noise.
This file defines the loss functions. [Here](simple_mnist_experiment.html) is an MNIST example
with two multilayer perceptrons for the generator and discriminator.
"""
import torch
import torch.nn as nn
import torch.utils.data
import torch.utils.data
from labml_helpers.module import Module
class DiscriminatorLogitsLoss(Module):
    r"""
    ## Discriminator Loss

    Discriminator should **ascend** on the gradient,

    $$\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^m \Bigg[
        \log D\Big(\pmb{x}^{(i)}\Big) +
        \log \Big(1 - D\Big(G\Big(\pmb{z}^{(i)}\Big)\Big)\Big)
    \Bigg]$$

    $m$ is the mini-batch size and $(i)$ is used to index samples in the mini-batch.
    $\pmb{x}$ are samples from $p_{data}$ and $\pmb{z}$ are samples from $p_z$.
    """

    # NOTE: the docstrings in this class are raw strings because they contain
    # LaTeX backslashes (`\nabla`, `\frac`, `\theta`, ...) that would otherwise
    # be interpreted as string escape sequences.

    def __init__(self, smoothing: float = 0.2):
        """
        `smoothing` is the width of the label-smoothing band: labels for real
        samples are drawn uniformly between `1 - smoothing` and `1`, and labels
        for generated samples between `0` and `smoothing`.
        """
        super().__init__()
        # We use PyTorch Binary Cross Entropy Loss, which is
        # $-\sum\Big[y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})\Big]$,
        # where $y$ are the labels and $\hat{y}$ are the predictions.
        # *Note the negative sign*.
        # We use labels equal to $1$ for $\pmb{x}$ from $p_{data}$
        # and labels equal to $0$ for $\pmb{x}$ from $p_{G}.$
        # Then descending on the sum of these is same as ascending on
        # the above gradient.
        #
        # `BCEWithLogitsLoss` combines a sigmoid layer and binary cross entropy loss,
        # so the discriminator must output raw logits, not probabilities.
        self.loss_true = nn.BCEWithLogitsLoss()
        self.loss_false = nn.BCEWithLogitsLoss()

        # We use label smoothing because it seems to work better in some cases
        self.smoothing = smoothing

        # Labels are registered as buffers (so that they follow the module across
        # `.to(device)` calls) and persistence is set to `False` (so that they are
        # not saved in the `state_dict`). `256` is only the initial cache size;
        # the cache is re-created on demand in `_labels_for`.
        self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), persistent=False)
        self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), persistent=False)

    def _labels_for(self, name: str, logits: torch.Tensor, low: float, high: float) -> torch.Tensor:
        """
        Return `len(logits)` cached labels from buffer `name`, re-creating the buffer
        when it is too small or lives on a different device than `logits`.
        """
        labels = getattr(self, name)
        if len(logits) > len(labels) or labels.device != logits.device:
            # Never shrink the cache: a small batch on a new device must not
            # force another re-allocation for the next full-sized batch.
            size = max(len(logits), len(labels))
            labels = _create_labels(size, low, high, logits.device)
            self.register_buffer(name, labels, persistent=False)
        return labels[:len(logits)]

    def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
        r"""
        `logits_true` are logits from $D(\pmb{x}^{(i)})$ and
        `logits_false` are logits from $D(G(\pmb{z}^{(i)}))$

        The cached labels have shape `[n, 1]`, so the logits are expected to
        have shape `[batch_size, 1]`.

        Returns a tuple `(loss_true, loss_false)`; the discriminator loss to
        descend on is their sum.
        """
        labels_true = self._labels_for('labels_true', logits_true, 1.0 - self.smoothing, 1.0)
        labels_false = self._labels_for('labels_false', logits_false, 0.0, self.smoothing)

        return (self.loss_true(logits_true, labels_true),
                self.loss_false(logits_false, labels_false))
class GeneratorLogitsLoss(Module):
    r"""
    ## Generator Loss

    Generator should **descend** on the gradient,

    $$\nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^m \Bigg[
        \log \Big(1 - D\Big(G\Big(\pmb{z}^{(i)}\Big)\Big)\Big)
    \Bigg]$$
    """

    # NOTE: the class docstring is a raw string because it contains LaTeX
    # backslashes (`\nabla`, `\frac`, `\theta`, ...) that would otherwise be
    # interpreted as string escape sequences.

    def __init__(self, smoothing: float = 0.2):
        """
        `smoothing` is the width of the label-smoothing band: the target labels
        for generated samples are drawn uniformly between `1 - smoothing` and `1`.
        """
        super().__init__()
        # `BCEWithLogitsLoss` combines a sigmoid layer and binary cross entropy loss,
        # so it must be given raw discriminator logits.
        self.loss_true = nn.BCEWithLogitsLoss()
        self.smoothing = smoothing
        # We use labels equal to $1$ for $\pmb{x}$ from $p_{G}.$
        # With these labels the binary cross entropy is $-\log D(G(\pmb{z}))$,
        # so descending on this loss pushes $D(G(\pmb{z}))$ towards $1$.
        # That is the same direction as descending on the above gradient;
        # it is the non-saturating form suggested in the GAN paper, which
        # gives stronger gradients early in training.
        #
        # Labels are registered as a non-persistent buffer; `256` is only the
        # initial cache size and the cache is re-created on demand in `__call__`.
        self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), persistent=False)

    def __call__(self, logits: torch.Tensor):
        r"""
        `logits` are logits from $D(G(\pmb{z}^{(i)}))$.

        The cached labels have shape `[n, 1]`, so `logits` is expected to
        have shape `[batch_size, 1]`.

        Returns the generator loss.
        """
        labels = self.fake_labels
        # Re-create the cache when the batch does not fit in it, or when the
        # logits live on a different device than the cached labels.
        if len(logits) > len(labels) or labels.device != logits.device:
            # Never shrink the cache when only the device changed.
            size = max(len(logits), len(labels))
            labels = _create_labels(size, 1.0 - self.smoothing, 1.0, logits.device)
            self.register_buffer('fake_labels', labels, persistent=False)

        return self.loss_true(logits, labels[:len(logits)])
def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
    """
    Create smoothed labels

    Allocates an `[n, 1]` tensor on `device` (PyTorch's default device when
    `device` is `None`) and fills it in place with values sampled uniformly
    between `r1` and `r2`. The result does not require gradients.
    """
    # Allocate uninitialised storage first; every element is overwritten below.
    smoothed = torch.empty((n, 1), device=device, requires_grad=False)
    smoothed.uniform_(r1, r2)
    return smoothed