<a href="https://colab.research.google.com/github/kiwiwa/GANs-from-scratch/blob/master/cyclegan/cyclegan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%%capture
%%shell
# Download one of the CycleGAN datasets and extract it under ./datasets/.
# NOTE: %%capture hides this cell's output, including the error message below.
set -e  # stop at the first failing command (e.g. a failed download)

FILE=horse2zebra

case "$FILE" in
    apple2orange|summer2winter_yosemite|horse2zebra|monet2photo|cezanne2photo|ukiyoe2photo|vangogh2photo|maps|cityscapes|facades|iphone2dslr_flower|ae_photos)
        ;;
    *)
        echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
        exit 1
        ;;
esac

URL="https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip"
ZIP_FILE="./datasets/$FILE.zip"
TARGET_DIR="./datasets/$FILE/"

# -p: do not fail when the directories already exist (re-running the cell).
mkdir -p ./datasets/ "$TARGET_DIR"
wget "$URL" -O "$ZIP_FILE"
# -o: overwrite silently instead of prompting when the files already exist.
unzip -o "$ZIP_FILE" -d ./datasets/
rm "$ZIP_FILE"

In [0]:
import torch.nn as nn
import torch.nn.functional
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchsummary import summary
import matplotlib.pyplot as plt

In [0]:
# --- data / architecture ---------------------------------------------------
img_size = 256  # images are resized and center-cropped to img_size x img_size
channels = 3    # number of image channels
ngf = 32        # base filter count of the generators
ndf = 64        # base filter count of the discriminators

# --- optimisation ----------------------------------------------------------
epochs = 200
batch_size = 1
lamb = 10       # weight of the cycle-consistency term in the generator loss
lr = 2e-4       # Adam learning rate

In [5]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(f"Device is {device}.")

Device is cuda.


In [0]:
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + body(x)``.

    ``body`` is ReflectionPad(1) -> Conv3x3(c, c) -> InstanceNorm -> ReLU ->
    ReflectionPad(1) -> Conv3x3(c, c) -> InstanceNorm. The reflection padding
    keeps height and width unchanged and the channel count stays ``c``, so the
    identity shortcut can be added directly.

    Args:
        c: number of input (and output) channels.
    """

    def __init__(self, c):
        super().__init__()

        def padded_conv(then_relu):
            # One reflection-padded 3x3 conv followed by instance norm.
            stage = [nn.ReflectionPad2d(1), nn.Conv2d(c, c, 3, 1, 0), nn.InstanceNorm2d(c)]
            if then_relu:
                stage.append(nn.ReLU())
            return stage

        self.block = nn.Sequential(*padded_conv(True), *padded_conv(False))

    def forward(self, x):
        residual = self.block(x)
        return residual + x

In [0]:
class Generator(nn.Module):
    """Encoder / residual-transformer / decoder image generator.

    Layout (k = kernel size, s = stride):
      * reflection-pad 4 + 9x9 conv, 3 -> ngf channels
      * two 3x3 s2 convs, ngf -> 2*ngf -> 4*ngf (each halves H and W)
      * six ResidualBlock(4*ngf)
      * two 3x3 s2 transposed convs, 4*ngf -> 2*ngf -> ngf (each doubles H and W)
      * reflection-pad 4 + 9x9 conv, ngf -> 3 channels, then tanh (output in [-1, 1])
    Outside the residual blocks, every conv except the last is followed by
    InstanceNorm2d and ReLU. When H and W are divisible by 4 the output has the
    same spatial size as the input.
    """

    def __init__(self):
        super().__init__()

        def with_norm_act(conv, out_ch):
            # conv -> InstanceNorm -> ReLU, the pattern shared by hidden layers
            return [conv, nn.InstanceNorm2d(out_ch), nn.ReLU()]

        layers = [nn.ReflectionPad2d(4)]
        layers += with_norm_act(nn.Conv2d(3, ngf, 9, 1, 0), ngf)

        # Encoder: halve the resolution twice while doubling the channels.
        for mult in (1, 2):
            c_in, c_out = ngf * mult, ngf * mult * 2
            layers += with_norm_act(nn.Conv2d(c_in, c_out, 3, 2, 1), c_out)

        # Transformer: residual blocks at the bottleneck width.
        layers += [ResidualBlock(ngf * 4) for _ in range(6)]

        # Decoder: mirror the encoder with transposed convolutions.
        for mult in (4, 2):
            c_in, c_out = ngf * mult, ngf * mult // 2
            layers += with_norm_act(
                nn.ConvTranspose2d(c_in, c_out, 3, 2, 1, output_padding=1), c_out)

        layers += [nn.ReflectionPad2d(4), nn.Conv2d(ngf, 3, 9, 1, 0), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        out = self.model(x)
        return out

In [0]:
class Discriminator(nn.Module):
    """Convolutional discriminator that returns one realness score per image.

    A 4x4 stride-2 conv + LeakyReLU(0.2) stem (3 -> ndf channels) is followed by
    ``n_layers`` 4x4 stride-2 Conv-InstanceNorm-LeakyReLU stages, each doubling
    the channel count. A final 4x4 stride-2 conv reduces to 1 channel. The
    resulting score map is averaged over its spatial positions, so ``forward``
    returns a tensor of shape ``(N, 1)``.

    NOTE(review): every conv has stride 2, so at the default depth the input must
    be at least 32x32 for the final conv to produce a non-empty map.

    Args:
        n_layers: number of Conv-InstanceNorm-LeakyReLU stages after the stem.
            The default (3) builds the original ndf -> 8*ndf network.
    """

    def __init__(self, n_layers=3):
        super(Discriminator, self).__init__()

        model = [nn.Conv2d(3, ndf, 4, 2, 1),
                 nn.LeakyReLU(0.2)]

        in_channels = ndf
        for _ in range(n_layers):
            out_channels = in_channels * 2
            model += [nn.Conv2d(in_channels, out_channels, 4, 2, 1),
                      nn.InstanceNorm2d(out_channels),
                      nn.LeakyReLU(0.2)]
            in_channels = out_channels

        # Use the channel count the loop actually produced. A hard-coded
        # ndf*8 is only correct when exactly three stages are built.
        model += [nn.Conv2d(in_channels, 1, 4, 2, 1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x = self.model(x)
        # Average the spatial score map down to one value per image: (N, 1).
        x = torch.nn.functional.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
        return x

In [0]:
import os
from PIL import Image
import random

class UnallignedDataset(Dataset):
    """Unpaired two-domain image dataset.

    Reads every file in ``<root>/<phase>A`` and ``<root>/<phase>B``. Item ``i``
    pairs the ``i``-th A image (wrapping around when A is the smaller set) with
    a uniformly random B image, so the two domains are not aligned. The dataset
    length is the size of the larger folder.

    Args:
        root: directory that contains the ``<phase>A`` / ``<phase>B`` folders.
        transform: callable applied to each image after it is converted to RGB.
        phase: folder prefix, e.g. ``'train'`` or ``'test'``.

    Raises:
        ValueError: if either folder contains no files.
    """

    def __init__(self, root, transform, phase='train'):
        dir_A = os.path.join(root, phase + 'A')
        dir_B = os.path.join(root, phase + 'B')

        self.A_paths = self._list_files(dir_A)
        self.B_paths = self._list_files(dir_B)
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)

        # Fail early with a clear message. Otherwise an empty folder only shows
        # up later as a ZeroDivisionError / randint ValueError in __getitem__.
        for directory, paths in ((dir_A, self.A_paths), (dir_B, self.B_paths)):
            if not paths:
                raise ValueError(f"no image files found in {directory!r}")

        self.transform = transform

    @staticmethod
    def _list_files(directory):
        """Return the full paths of the regular files in *directory*, sorted by name."""
        # os.listdir order is arbitrary; sorting makes index -> file reproducible.
        # Sub-directories are skipped because PIL cannot open them.
        return [os.path.join(directory, name)
                for name in sorted(os.listdir(directory))
                if os.path.isfile(os.path.join(directory, name))]

    @staticmethod
    def _load_rgb(path):
        """Open *path*, convert it to RGB and close the underlying file."""
        # convert() returns a new, fully loaded image, so the file can be
        # closed immediately instead of leaking a handle until GC.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        A_path = self.A_paths[index % self.A_size]
        B_path = self.B_paths[random.randint(0, self.B_size - 1)]

        A = self.transform(self._load_rgb(A_path))
        B = self.transform(self._load_rgb(B_path))
        return A, B

    def __len__(self):
        return max(self.A_size, self.B_size)

In [0]:
G = Generator()
F = Generator()
D_X = Discriminator()
D_Y = Discriminator()

transform = transforms.Compose([transforms.Resize(img_size), transforms.CenterCrop(img_size), transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
train_loader = torch.utils.data.DataLoader(dataset=UnallignedDataset('datasets/horse2zebra', transform), 
                                           batch_size=batch_size, 
                                           shuffle=True, 
                                           pin_memory=True, 
                                           num_workers=2)

mse_criterion = nn.MSELoss()
l1_criterion = nn.L1Loss()

G_optimizer = torch.optim.Adam(G_X.parameters(), lr=lr)
F_optimizer = torch.optim.Adam(G_Y.parameters(), lr=lr)
D_X_optimizer = torch.optim.Adam(D_X.parameters(), lr=lr)
D_Y_optimizer = torch.optim.Adam(D_Y.parameters(), lr=lr)

In [60]:
for epoch in range(epochs):
    for X, Y in train_loader:
        # Update discriminators
        # D_Y
        G_out = G(X).detach()
        D_Y_out_fake = D_Y(G_out)
        D_Y_out_real = D_Y(Y)
        L_D_Y_fake = mse_criterion(D_Y_out_fake, torch.zeros_like(D_Y_out_fake))
        L_D_Y_real = mse_criterion(D_Y_out_real, torch.ones_like(D_Y_out_real))
        L_D_Y = (L_D_Y_fake + L_D_Y_real)*0.5
        
        D_Y_optimizer.zero_grad()
        L_D_Y.backward()
        D_Y_optimizer.step()
        
        # D_X
        F_out = F(Y).detach()
        D_X_out_fake = D_X(F_out)
        D_X_out_real = D_X(X)
        L_D_X_fake = mse_criterion(D_X_out_fake, torch.zeros_like(D_X_out_fake))
        L_D_X_real = mse_criterion(D_X_out_real, torch.ones_like(D_X_out_real))
        L_D_X = (L_D_X_fake + L_D_X_real)*0.5
        
        D_X_optimizer.zero_grad()
        L_D_X.backward()
        D_X_optimizer.step()
        
        # Update generators
        G_out = G(X)
        D_Y_out = D_Y(G_out).detach()
        L_GAN_G = mse_criterion(D_Y_out, torch.ones_like(D_Y_out))
        
        F_out = F(Y)
        D_X_out = D_X(F_out).detach()
        L_GAN_F = mse_criterion(D_X_out, torch.ones_like(D_X_out))
        
        X_CYC = F(G(X))
        Y_CYC = G(F(X))
        L_CYC_forward = l1_criterion(X_CYC, X)
        L_CYC_backward = l1_criterion(Y_CYC, Y)
        
        L_GF = L_GAN_G + L_GAN_F + lamb*(L_CYC_forward + L_CYC_backward)
        G_optimizer.zero_grad()
        F_optimizer.zero_grad()
        L_GF.backward()
        G_optimizer.step()
        F_optimizer.step()
        

tensor([[0.7843]])
tensor([[-0.0079]])
tensor([[0.6152]])
tensor([[0.4232]])
tensor([[0.2559]])
tensor([[1.0016]])
tensor([[-0.0463]])
tensor([[1.0402]])
tensor([[0.2253]])
tensor([[0.7788]])
tensor([[0.4972]])
tensor([[1.1281]])
tensor([[0.2899]])
tensor([[0.1621]])
tensor([[0.3688]])
tensor([[0.1799]])
tensor([[0.3803]])
tensor([[0.3461]])
tensor([[0.8093]])
tensor([[-0.0169]])
tensor([[1.1456]])
tensor([[0.2930]])
tensor([[0.7979]])
tensor([[0.5001]])
tensor([[0.9959]])
tensor([[1.1031]])
tensor([[0.8095]])
tensor([[0.8832]])
tensor([[0.4842]])
tensor([[1.0100]])
tensor([[0.4401]])
tensor([[0.9927]])
tensor([[0.2457]])
tensor([[0.6622]])
tensor([[0.5700]])
tensor([[1.0468]])
tensor([[0.8332]])
tensor([[1.3823]])
tensor([[0.7277]])
tensor([[0.4512]])
tensor([[1.1207]])
tensor([[0.5466]])
tensor([[1.0307]])
tensor([[0.4958]])
tensor([[0.6778]])
tensor([[0.1418]])
tensor([[1.0669]])
tensor([[0.3119]])
tensor([[0.8023]])
tensor([[0.6337]])
tensor([[1.1034]])
tensor([[0.8504]])
tensor([[

KeyboardInterrupt: ignored