In [None]:
#default_exp vision.gan.loss

In [None]:
#hide
from IPython.display import clear_output
from nbdev.export import notebook2script
%reload_ext autoreload
%autoreload 2

In [None]:
#export
import torch
import logging
from torch import nn


# Module-level logger.
# NOTE(review): getLogger() with no name returns the *root* logger, so importing
# this module sets the process-wide root level to INFO — confirm this is intended.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# vision.gan.loss

> 生成對抗網路（Generative Adversarial Network, GAN）常見的損失函數。

## 建立跟 logits 同 device 的 label tensor

In [None]:
#export
def create_like(t, func, on_gpu=False):
    """Create a new tensor from ``t`` with ``func``, optionally moving it to CUDA.

    Args:
        t: reference tensor passed to ``func`` (e.g. discriminator outputs).
        func: a factory such as ``torch.ones_like`` / ``torch.zeros_like``; it is
            called as ``func(t)``.
        on_gpu: if true, the newly created tensor is moved with
            ``.cuda(t.device.index)``. NOTE(review): for a CPU ``t`` the index is
            ``None``, so the result goes to the current CUDA device.

    Returns:
        The tensor produced by ``func(t)``, moved to CUDA when requested.
    """
    t2 = func(t)
    if on_gpu:
        # Move the *new* tensor. Moving ``t`` here would hand back the reference
        # tensor itself (e.g. the logits) instead of the ones/zeros labels.
        t2 = t2.cuda(t.device.index)
    return t2


def ones_like(t, on_gpu=False):
    """Return a ones tensor built from ``t`` through ``create_like``.

    Used below as the "real" (label 1) target for the GAN / LSGAN losses.
    ``on_gpu`` is forwarded unchanged to ``create_like``.
    """
    factory = torch.ones_like
    return create_like(t, func=factory, on_gpu=on_gpu)


def zeros_like(t, on_gpu=False):
    """Return a zeros tensor built from ``t`` through ``create_like``.

    Used below as the "fake" (label 0) target for the GAN / LSGAN losses.
    ``on_gpu`` is forwarded unchanged to ``create_like``.
    """
    factory = torch.zeros_like
    return create_like(t, func=factory, on_gpu=on_gpu)

## GAN

In [None]:
#export
def get_gan_loss_fns(is_logits=True, on_gpu=False):
    """Return ``(g_loss_fn, d_loss_fn)`` for the standard binary-cross-entropy GAN loss.

    Args:
        is_logits: if true, the closures expect raw discriminator outputs and use
            ``nn.BCEWithLogitsLoss``; otherwise they expect probabilities and use
            ``nn.BCELoss``.
        on_gpu: accepted for a signature uniform with the other factories and for
            backward compatibility. Targets are built with ``torch.ones_like`` /
            ``torch.zeros_like``, which already place them on the outputs' device
            with the outputs' dtype, so no explicit device move is needed.

    Returns:
        g_loss_fn(fake_logits, on_gpu=False): BCE of the fake outputs against 1.
        d_loss_fn(real_logits, fake_logits, on_gpu=False): tuple
            ``(real_loss, fake_loss)`` — BCE of real outputs against 1 and of fake
            outputs against 0. The closures' ``on_gpu`` is accepted and ignored.
    """
    bce = nn.BCEWithLogitsLoss() if is_logits else nn.BCELoss()

    # Targets come straight from torch.*_like so they always share the device and
    # dtype of the outputs; routing through an on_gpu device move is unnecessary.
    def g_loss_fn(fake_logits, on_gpu=False):
        # The generator is rewarded when its samples are scored as "real" (1).
        return bce(fake_logits, torch.ones_like(fake_logits))

    def d_loss_fn(real_logits, fake_logits, on_gpu=False):
        real_loss = bce(real_logits, torch.ones_like(real_logits))
        fake_loss = bce(fake_logits, torch.zeros_like(fake_logits))
        return real_loss, fake_loss

    return g_loss_fn, d_loss_fn

## LSGAN

In [None]:
#export
def get_lsgan_loss_fns(is_logits=True, on_gpu=False):
    """Return ``(g_loss_fn, d_loss_fn)`` for the least-squares GAN (LSGAN) loss.

    Both closures apply ``nn.MSELoss`` against constant targets: 1 for "real"
    and 0 for "fake".

    Args:
        is_logits: accepted for a signature uniform with the other factories;
            unused, because MSE is applied to the discriminator outputs as given.
        on_gpu: accepted for the same reason and for backward compatibility.
            Targets come from ``torch.ones_like`` / ``torch.zeros_like``, which
            already place them on the outputs' device with the outputs' dtype.

    Returns:
        g_loss_fn(fake_logits, on_gpu=False): MSE of the fake outputs against 1.
        d_loss_fn(real_logits, fake_logits, on_gpu=False): tuple
            ``(real_loss, fake_loss)`` — MSE of real outputs against 1 and of fake
            outputs against 0. The closures' ``on_gpu`` is accepted and ignored.
    """
    mse = nn.MSELoss()

    def g_loss_fn(fake_logits, on_gpu=False):
        # The generator pushes the scores of its samples towards the "real" target.
        return mse(fake_logits, torch.ones_like(fake_logits))

    def d_loss_fn(real_logits, fake_logits, on_gpu=False):
        real_loss = mse(real_logits, torch.ones_like(real_logits))
        fake_loss = mse(fake_logits, torch.zeros_like(fake_logits))
        return real_loss, fake_loss

    return g_loss_fn, d_loss_fn

## WGAN

In [None]:
#export
def get_wgan_loss_fns(is_logits=True, **kwargs):
    """Return ``(g_loss_fn, d_loss_fn)`` for the Wasserstein GAN (WGAN) objective.

    No targets are built; the critic scores are averaged directly.

    Args:
        is_logits: accepted for a signature uniform with the other factories; unused.
        **kwargs: extra options (e.g. ``on_gpu`` forwarded by
            ``get_adversarial_loss_fns``) are accepted and ignored.

    Returns:
        g_loss_fn(fake_logits, on_gpu=False): ``-mean(fake_logits)``.
        d_loss_fn(real_logits, fake_logits, on_gpu=False): tuple
            ``(-mean(real_logits), mean(fake_logits))``.
        ``on_gpu`` is accepted by both closures (so they can be called the same
        way as the GAN / LSGAN closures) and ignored.
    """

    def g_loss_fn(fake_logits, on_gpu=False):
        # Minimising this raises the critic's average score for generated samples.
        return -fake_logits.mean()

    def d_loss_fn(real_logits, fake_logits, on_gpu=False):
        # Minimising these raises real scores and lowers fake scores.
        real_loss = -real_logits.mean()
        fake_loss = fake_logits.mean()
        return real_loss, fake_loss

    return g_loss_fn, d_loss_fn

## 依據名稱取得對應 GAN 損失函數的 Helper

In [None]:
#export
def get_adversarial_loss_fns(_type, is_logits=True, on_gpu=False):
    """Look up an adversarial loss factory by name and build its loss functions.

    Args:
        _type: one of ``"gan"``, ``"lsgan"`` or ``"wgan"`` (case-sensitive).
        is_logits: forwarded to the selected factory.
        on_gpu: forwarded to the selected factory.

    Returns:
        The ``(g_loss_fn, d_loss_fn)`` pair returned by the selected factory.

    Raises:
        NotImplementedError: if ``_type`` is not one of the supported names.
    """
    # Names are resolved lazily per branch so an unknown type fails fast with a
    # descriptive message instead of a bare NotImplementedError.
    if _type == "gan":
        factory = get_gan_loss_fns
    elif _type == "lsgan":
        factory = get_lsgan_loss_fns
    elif _type == "wgan":
        factory = get_wgan_loss_fns
    else:
        raise NotImplementedError(
            f"Unknown adversarial loss type {_type!r}; "
            "expected one of 'gan', 'lsgan', 'wgan'."
        )

    return factory(is_logits=is_logits, on_gpu=on_gpu)

In [None]:
#hide
# Export this notebook's `#export` cells to the library module named by
# `#default_exp` (nbdev), then clear this cell's output.
notebook2script()
clear_output()