<a href="https://colab.research.google.com/github/kiwiwa/GANs-from-scratch/blob/master/cyclegan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%%capture
%%shell
# Download one of the CycleGAN benchmark datasets and unpack it under ./datasets/.
FILE=horse2zebra

# Reject unknown dataset names before touching the network or the filesystem.
case "$FILE" in
    apple2orange|summer2winter_yosemite|horse2zebra|monet2photo|cezanne2photo|ukiyoe2photo|vangogh2photo|maps|cityscapes|facades|iphone2dslr_flower|ae_photos)
        ;;
    *)
        echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
        exit 1
        ;;
esac

URL="https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip"
ZIP_FILE="./datasets/$FILE.zip"
TARGET_DIR="./datasets/$FILE/"

# -p keeps re-runs of this cell from failing when the directories already exist.
mkdir -p ./datasets/ "$TARGET_DIR"

# wget ignores -N (timestamping) when -O is given, so only -O is used.
# On a failed download, remove the partial file and stop instead of unzipping it.
wget "$URL" -O "$ZIP_FILE" || { rm -f "$ZIP_FILE"; exit 1; }

# -o overwrites existing files without prompting, so a re-run cannot block on input.
unzip -o "$ZIP_FILE" -d ./datasets/
rm "$ZIP_FILE"

In [0]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchsummary import summary

In [0]:
# Global hyperparameters shared by the model definitions in this notebook.
# img_size / channels describe the training images (256 px; 3 channels,
# presumably RGB).
img_size, channels = 256, 3
# Base filter counts: ngf for the generator, ndf for the discriminator.
ngf, ndf = 32, 64

In [0]:
class ResidualBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 conv + instance-norm stages
    with an identity skip connection.

    Both convolutions map ``c`` channels to ``c`` channels, and the 1-pixel
    reflection padding offsets the unpadded 3x3 kernel. The output therefore
    has the same shape as the input, and ``forward`` returns ``F(x) + x``.

    Args:
        c: Number of input (and output) channels.
    """

    def __init__(self, c):
        super().__init__()

        def padded_conv():
            # Reflection pad -> 3x3 conv (no implicit padding) -> instance norm.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(c, c, kernel_size=3, stride=1, padding=0),
                nn.InstanceNorm2d(c),
            ]

        # Module order (and thus state_dict keys) is:
        # pad, conv, norm, relu, pad, conv, norm.
        layers = padded_conv()
        layers.append(nn.ReLU())
        layers.extend(padded_conv())
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``self.block(x) + x`` (same shape as ``x``)."""
        residual = self.block(x)
        return residual + x

In [0]:
class Generator(nn.Module):
    """ResNet-style CycleGAN generator: encoder, residual blocks, decoder.

    The encoder has a 9x9 stem and two stride-2 convolutions
    (ngf -> ngf*2 -> ngf*4 channels). A stack of ``ResidualBlock``s then runs
    at ``ngf*4`` channels. Two stride-2 transposed convolutions and a 9x9
    output conv map back to image space, followed by ``tanh``, so outputs lie
    in [-1, 1]. For spatial sizes divisible by 4, the output has the same
    height and width as the input.

    Uses the module-level ``ngf`` as the base filter count.

    Args:
        in_channels: Channels of the input image (default 3).
        out_channels: Channels of the generated image (default 3).
        n_residual_blocks: Number of residual blocks in the bottleneck
            (default 6; the CycleGAN paper uses 9 for 256x256 images).
    """

    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=6):
        super(Generator, self).__init__()

        # Encoding: the reflection-padded 9x9 stem keeps the spatial size,
        # then each stride-2 conv halves it and doubles the channel count.
        model = []
        model += [nn.ReflectionPad2d(4),
                  nn.Conv2d(in_channels, ngf, 9, 1, 0),
                  nn.InstanceNorm2d(ngf),
                  nn.ReLU()]
        model += [nn.Conv2d(ngf, ngf*2, 3, 2, 1),
                  nn.InstanceNorm2d(ngf*2),
                  nn.ReLU()]
        model += [nn.Conv2d(ngf*2, ngf*4, 3, 2, 1),
                  nn.InstanceNorm2d(ngf*4),
                  nn.ReLU()]

        # Transformation: shape-preserving residual blocks at ngf*4 channels.
        model += [ResidualBlock(ngf*4) for _ in range(n_residual_blocks)]

        # Decoding: kernel 3, stride 2, padding 1, output_padding 1 exactly
        # doubles the spatial size, mirroring the encoder's downsampling.
        model += [nn.ConvTranspose2d(ngf*4, ngf*2, 3, 2, 1, output_padding=1),
                  nn.InstanceNorm2d(ngf*2),
                  nn.ReLU()]
        model += [nn.ConvTranspose2d(ngf*2, ngf, 3, 2, 1, output_padding=1),
                  nn.InstanceNorm2d(ngf),
                  nn.ReLU()]
        # Output stem; tanh bounds values to [-1, 1].
        model += [nn.ReflectionPad2d(4),
                  nn.Conv2d(ngf, out_channels, 9, 1, 0),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Translate a batch ``x`` of shape (N, in_channels, H, W)."""
        return self.model(x)

In [0]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one averaged score per image.

    Layers, all 4x4 convolutions with stride 2 and padding 1:

    * a first conv (in_channels -> ndf) with LeakyReLU(0.2);
    * ``n_layers`` conv + instance-norm + LeakyReLU(0.2) stages, each
      doubling the channel count (ndf*2, ndf*4, ...);
    * a final conv to a single channel.

    The single-channel score map is then averaged over its spatial
    dimensions. Each conv halves even spatial sizes, so the input should be
    at least ``2 ** (n_layers + 2)`` pixels per side.

    Uses the module-level ``ndf`` as the base filter count.

    Args:
        in_channels: Channels of the input image (default 3).
        n_layers: Number of normalized intermediate conv stages (default 3).
    """

    def __init__(self, in_channels=3, n_layers=3):
        super(Discriminator, self).__init__()

        model = []
        model += [nn.Conv2d(in_channels, ndf, 4, 2, 1),
                  nn.LeakyReLU(0.2)]

        prev_ch = ndf
        for _ in range(n_layers):
            next_ch = prev_ch * 2
            model += [nn.Conv2d(prev_ch, next_ch, 4, 2, 1),
                      nn.InstanceNorm2d(next_ch),
                      nn.LeakyReLU(0.2)]
            prev_ch = next_ch

        # Use the width actually produced by the loop (ndf*8 for the default
        # n_layers=3), so this layer stays correct for any n_layers.
        model += [nn.Conv2d(prev_ch, 1, 4, 2, 1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return per-image scores of shape (N, 1) for a batch ``x``."""
        x = self.model(x)
        # Average the (N, 1, h, w) score map over its full spatial extent.
        x = F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
        return x