In [None]:
import torch.nn as nn
import torch
from functools import partial

In [None]:
class PixelDiscriminator(nn.Module):
  """Defines a 1x1 PatchGAN discriminator (pixelGAN).

  Every conv uses a 1x1 kernel with stride 1 and no padding, so each output
  score depends on exactly one input pixel and the output keeps the input's
  spatial size with a single channel.
  """
  def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
    """
    Parameters:
      input_nc (int)  -- the number of channels in input images
      ndf (int)       -- the number of filters in the first conv layer (the hidden conv uses ndf * 2)
      norm_layer      -- normalization layer class, or a functools.partial wrapping one
    """
    super(PixelDiscriminator, self).__init__()
    # BatchNorm2d has a learnable shift by default, so a conv bias in front of it is
    # redundant; InstanceNorm2d is non-affine by default, so keep the bias there.
    # `partial` (imported at the top of the notebook) lets callers pass e.g.
    # partial(nn.InstanceNorm2d, affine=False).
    if isinstance(norm_layer, partial):
      use_bias = norm_layer.func == nn.InstanceNorm2d   # https://pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html
    else:
      use_bias = norm_layer == nn.InstanceNorm2d
    self.net = nn.Sequential(
      nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),   # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
      nn.LeakyReLU(0.2, True),
      nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
      norm_layer(ndf * 2),
      nn.LeakyReLU(0.2, True),
      nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias))

  def forward(self, input):
    """Standard forward: returns a 1-channel per-pixel score map (assumes NCHW input)."""
    return self.net(input)

In [None]:
class NLayerDiscriminator(nn.Module):
  """Define a PatchGAN discriminator.

  Layout: `n_layers` 4x4 stride-2 convs (the first without normalization),
  one 4x4 stride-1 conv, then a final 4x4 stride-1 conv mapping to a
  1-channel map of patch scores. With n_layers=3 each output score has a
  70x70 receptive field.
  """
  def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
    """
    Parameters:
      input_nc (int)  -- the number of channels in input images
      ndf (int)       -- the number of filters in the first conv layer
      n_layers (int)  -- the number of stride-2 conv layers in the discriminator
      norm_layer      -- normalization layer class, or a functools.partial wrapping one
    """
    super(NLayerDiscriminator, self).__init__()
    # A conv bias is redundant before BatchNorm2d (which has its own shift by
    # default) but is kept for InstanceNorm2d.
    if isinstance(norm_layer, partial):
      use_bias = norm_layer.func == nn.InstanceNorm2d
    else:
      use_bias = norm_layer == nn.InstanceNorm2d

    kw = 4
    padw = 1
    # First layer: no normalization.
    sequence = [
                nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True)]
    nf_mult = 1
    nf_mult_prev = 1
    # Downsample while gradually increasing the filter count (capped at ndf * 8).
    for n in range(1, n_layers):
      nf_mult_prev = nf_mult
      nf_mult = min(2 ** n, 8)
      sequence += [
                   nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                   norm_layer(ndf * nf_mult),
                   nn.LeakyReLU(0.2, True)
      ]

    # One more conv at stride 1 (after the loop, exactly once).
    nf_mult_prev = nf_mult
    nf_mult = min(2 ** n_layers, 8)
    sequence += [
                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
                 norm_layer(ndf * nf_mult),
                 nn.LeakyReLU(0.2, True)
    ]

    # Output a 1-channel prediction map.
    sequence += [
                 nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
    ]
    self.model = nn.Sequential(*sequence)

  def forward(self, input):
    """Standard forward: returns a 1-channel map of per-patch scores."""
    return self.model(input)

In [None]:
class UnetGenerator(nn.Module):
  """Create a Unet-based generator"""
  def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
    """Construct a Unet generator from the innermost layer to the outermost layer.

    Each UnetSkipConnectionBlock receives the previously built block as its
    `submodule`, so the U-Net is assembled recursively from the bottleneck out.
    In total max(num_downs, 5) blocks are nested.

    Parameters:
      input_nc (int)      -- the number of channels in input images
      output_nc (int)     -- the number of channels in output images
      num_downs (int)     -- the number of nested UnetSkipConnectionBlocks; presumably
                             one downsampling each (TODO confirm against
                             UnetSkipConnectionBlock). Values below 5 still build 5.
      ngf (int)           -- the number of filters in the outermost block
      norm_layer          -- normalization layer
      use_dropout (bool)  -- applied only to the intermediate ngf * 8 blocks
    """
    super(UnetGenerator, self).__init__()
    # Innermost (bottleneck) block.
    unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
    # Intermediate ngf * 8 blocks; each must wrap the previous one, so reassign unet_block.
    for _ in range(num_downs - 5):
      unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
    # Gradually reduce the number of filters from ngf * 8 to ngf.
    unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
    unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
    unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
    # Outermost block maps input_nc -> output_nc.
    self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)

  def forward(self, input):
    """Standard forward."""
    return self.model(input)

In [None]:
class ResnetGenerator(nn.Module):
  """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

  The output has the input's spatial size only when H and W are divisible by 4
  (two stride-2 downsamplings round up; the two transposed convs exactly double).
  """
  def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
    """Construct a Resnet-based generator

    Parameters:
      input_nc (int)      -- the number of channels in input images
      output_nc (int)     -- the number of channels in output images
      ngf (int)           -- the number of filters in the first conv layer
      norm_layer          -- normalization layer class, or a functools.partial wrapping one
      use_dropout (bool)  -- forwarded to every ResnetBlock
      n_blocks (int)      -- the number of ResnetBlocks at the bottleneck (must be >= 0)
      padding_type (str)  -- forwarded to every ResnetBlock
    """
    assert n_blocks >= 0
    super(ResnetGenerator, self).__init__()
    # A conv bias is redundant before BatchNorm2d but kept for InstanceNorm2d.
    if isinstance(norm_layer, partial):
      use_bias = norm_layer.func == nn.InstanceNorm2d
    else:
      use_bias = norm_layer == nn.InstanceNorm2d

    # Stem: reflection pad 3 + 7x7 conv with no padding keeps H and W.
    model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True)]
    # Downsampling: each stride-2 conv halves H and W (rounding up) and doubles channels.
    n_downsampling = 2
    for i in range(n_downsampling):
      mult = 2 ** i
      model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True)]
    # Bottleneck of ResnetBlocks at ngf * 4 channels.
    mult = 2 ** n_downsampling
    for _ in range(n_blocks):
      model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
    # Upsampling: output_padding=1 makes each transposed conv exactly double H and W.
    for i in range(n_downsampling):
      mult = 2 ** (n_downsampling - i)
      model += [
                nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
                norm_layer(ngf * mult // 2),
                nn.ReLU(True)]

    # Head: back to output_nc channels, squashed to [-1, 1] by Tanh.
    model += [nn.ReflectionPad2d(3)]
    model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
    model += [nn.Tanh()]

    self.model = nn.Sequential(*model)

  def forward(self, input):
    """Standard forward."""
    return self.model(input)

In [None]:
# https://cvnote.ddlee.cc/2019/09/02/cyclegan-pytorch-github
class CycleGANModel(nn.Module):
  """CycleGAN for learning image-to-image translation without paired data.

  Only the forward pass and the loss/backward helpers are defined here.
  NOTE(review): netG_A, netG_B, netD_A, netD_B, criterionGAN, criterionCycle,
  criterionIdt, fake_A_pool, fake_B_pool, opt, real_A and real_B are read as
  attributes but never assigned in this class — they must be set elsewhere.
  """
  def forward(self):
    """Run both translation cycles: A -> B -> A and B -> A -> B."""
    self.fake_B = self.netG_A(self.real_A)  # G_A(A)
    self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
    self.fake_A = self.netG_B(self.real_B)  # G_B(B)
    self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

  def backward_D_basic(self, netD, real, fake):
    """Calculate GAN loss for the discriminator and backpropagate it.

    Parameters:
      netD  -- the discriminator
      real  -- real images
      fake  -- images generated by a generator

    Returns the averaged discriminator loss, after calling .backward() on it.
    """
    # Real
    pred_real = netD(real)
    loss_D_real = self.criterionGAN(pred_real, True)
    # Fake: detach so this loss does not backpropagate into the generator.
    pred_fake = netD(fake.detach())
    loss_D_fake = self.criterionGAN(pred_fake, False)
    # Combined (averaged) loss and calculate gradients
    loss_D = (loss_D_real + loss_D_fake) * 0.5
    loss_D.backward()
    return loss_D

  def backward_D_A(self):
    """Calculate GAN loss for the discriminator D_A (real B vs. generated B)."""
    # fake_B_pool.query presumably returns a mix of current and past fakes (image pool; not visible here).
    fake_B = self.fake_B_pool.query(self.fake_B)
    self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

  def backward_D_B(self):
    """Calculate GAN loss for the discriminator D_B (real A vs. generated A)."""
    fake_A = self.fake_A_pool.query(self.fake_A)
    self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

  def backward_G(self):
    """Calculate the losses for generators G_A and G_B and backpropagate them.

    Sets loss_idt_A / loss_idt_B (identity), loss_G_A / loss_G_B (adversarial),
    loss_cycle_A / loss_cycle_B (cycle consistency) and their sum loss_G, then
    calls loss_G.backward(). Requires forward() to have set fake_A, fake_B,
    rec_A and rec_B.
    """
    lambda_idt = self.opt.lambda_identity
    lambda_A = self.opt.lambda_A
    lambda_B = self.opt.lambda_B
    # Identity loss (skipped when lambda_identity is not positive)
    if lambda_idt > 0:
      # G_A should be identity if real_B is fed: || G_A(B) - B ||
      self.idt_A = self.netG_A(self.real_B)
      self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
      # G_B should be identity if real_A is fed: || G_B(A) - A ||
      self.idt_B = self.netG_B(self.real_A)
      self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
    else:
      self.loss_idt_A = 0
      self.loss_idt_B = 0

    # GAN loss D_A(G_A(A))
    self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
    # GAN loss D_B(G_B(B))
    self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
    # Forward cycle loss || G_B(G_A(A)) - A ||
    self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
    # Backward cycle loss || G_A(G_B(B)) - B ||
    self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
    # Combined loss and calculate gradients
    self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B
                   + self.loss_idt_A + self.loss_idt_B)
    self.loss_G.backward()