<a href="https://colab.research.google.com/github/laplaisanterie/GAN/blob/master/3.%20Disco_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Optional Colab user authentication, left disabled.
# from google.colab import auth
# auth.authenticate_user()

# Mount Google Drive under /content/gdrive; the project directory used by the
# next cell lives below that mount point. force_remount is passed as False.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=False)

In [0]:
import os
from pathlib import Path

# Location of the project inside the mounted Google Drive.
folder = "colab/"
project_dir = "disco"

base_path = Path("/content/gdrive/My Drive/")
project_path = base_path / folder / project_dir

# All relative paths used later (dataset folder, results folder) are resolved
# against the project directory.
os.chdir(project_path)

# Shorten every sub-directory name to the part before its first space.
# The listing is materialised first so that renaming entries cannot disturb
# the iteration.
for x in list(project_path.glob("*")):
    if not x.is_dir():
        continue

    short_name = x.name.split(" ", 1)[0]
    # Nothing to do when the name has no space; an empty result (name starts
    # with a space) is not a usable directory name, so leave it alone.
    if not short_name or short_name == x.name:
        continue

    target = x.with_name(short_name)
    if target.exists():
        # Two directories sharing the same first word would collide; keep
        # both untouched rather than failing or overwriting one of them.
        print(f"skip rename: '{x.name}' -> '{short_name}' (target already exists)")
        continue

    x.rename(target)

print(f"현재 디렉토리 위치: {os.getcwd()}")

In [0]:
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
from torch.autograd import Variable
from itertools import chain
from PIL import Image


# Report the installed PyTorch version and whether a CUDA device is usable.
print('pytorch version: {}'.format(torch.__version__))
print('GPU 사용 가능 여부: {}'.format(torch.cuda.is_available()))

# Device string: "cuda" when CUDA is available, "cpu" otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [0]:
def to_variable(x):
    """Wrap ``x`` in a ``Variable``, moving it to the GPU first when possible.

    Args:
        x: Tensor to wrap.

    Returns:
        ``Variable`` built from ``x.cuda()`` when ``torch.cuda.is_available()``
        is true, and from ``x`` itself otherwise.
    """
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed)

In [0]:
def denorm(x):
    """Rescale values from the [-1, 1] range to [0, 1].

    Computes ``(x + 1) / 2`` and clamps the result to [0, 1], so inputs
    outside [-1, 1] saturate at the nearest bound.

    Args:
        x: Tensor of values to rescale.

    Returns:
        Tensor with the same shape as ``x`` and every value in [0, 1].
    """
    return ((x + 1) / 2).clamp(0, 1)

In [0]:
def GAN_Loss(input, target, criterion):
    """Score discriminator outputs against an all-ones or all-zeros label tensor.

    Args:
        input: Discriminator output tensor.
        target: Truthy to compare against labels of 1.0 ("real"), falsy to
            compare against labels of 0.0 ("fake").
        criterion: Callable invoked as ``criterion(input, labels)``; in this
            file it is an ``nn.BCELoss`` instance.

    Returns:
        Whatever ``criterion`` returns for ``input`` and the constant labels.
    """
    # full_like copies shape, dtype and device from `input`, so the labels
    # always sit next to the predictions (the previous version built a CPU
    # float32 tensor and moved it based on global CUDA availability, which
    # mismatches when `input` itself is on the CPU). The result does not
    # require grad.
    labels = torch.full_like(input, 1.0 if target else 0.0)
    return criterion(input, labels)

In [0]:
def Feature_Loss(real_feats, fake_feats, criterion):
    """Sum a feature-matching loss over paired discriminator feature maps.

    For each (real, fake) pair the features are averaged over dimension 0,
    the squared difference of the two means is taken element-wise, and
    ``criterion`` is applied to that squared difference with a target of ones.

    Args:
        real_feats: Iterable of feature tensors computed from real images.
        fake_feats: Iterable of feature tensors computed from generated
            images, in the same order as ``real_feats``.
        criterion: Callable invoked as ``criterion(squared_diff, ones)``; in
            this file it is an ``nn.HingeEmbeddingLoss`` instance.

    Returns:
        Sum of the per-layer losses, or the integer ``0`` when the inputs are
        empty.
    """
    losses = 0
    for real_feat, fake_feat in zip(real_feats, fake_feats):
        mean_gap = real_feat.mean(0) - fake_feat.mean(0)
        l2 = mean_gap * mean_gap
        # ones_like keeps the target on the same device/dtype as the
        # features; the previous unconditional .cuda() failed on CPU-only
        # machines.
        losses = losses + criterion(l2, torch.ones_like(l2))

    return losses

In [0]:
# --- Optimisation ---
num_epochs = 200      # number of passes over data_loader in the training loop
lr = 2e-4             # learning rate handed to both Adam optimizers
beta1 = 0.5           # first Adam beta
beta2 = 0.999         # second Adam beta
batch_size = 8        # batch size handed to the DataLoader

# --- Loss-weight schedule ---
# While the global iteration counter is below decay_gan_loss the training loop
# uses starting_rate as `rate`, afterwards changed_rate. `rate` weights the
# reconstruction loss and (1 - rate) weights the GAN/feature term.
decay_gan_loss = 10_000
starting_rate = 0.01
changed_rate = 0.5

# --- Logging / sampling ---
sample_path = './results'   # directory that receives the sample image grids
log_step = 10               # print a log line every log_step batches of an epoch
sample_step = 50            # save a sample grid every sample_step iterations

In [0]:
from torch.utils.data import DataLoader

# Stand-alone preprocessing pipeline: Resize(64), conversion to a tensor, then
# per-channel normalisation with mean 0.5 and std 0.5.
# NOTE(review): nothing in the visible notebook reads `transform`; the
# ImageFolder dataset builds its own pipeline.
transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

In [0]:
class ImageFolder(torch.utils.data.Dataset):
    """Two-domain image dataset read from two sibling directories.

    Item ``i`` pairs the ``i``-th file of domain A with the ``i``-th file of
    domain B (by sorted file name, not by any content match). The dataset
    length is the smaller of the two file counts, so surplus files in the
    larger directory are never indexed.

    Args:
        dir_base: Directory containing the two domain folders.
        domain_A: Folder name (inside ``dir_base``) of the first domain.
        domain_B: Folder name (inside ``dir_base``) of the second domain.

    The defaults reproduce the previously hard-coded
    ``./fruit_datasets/Apple`` and ``./fruit_datasets/Banana`` layout.
    """

    def __init__(self, dir_base='./fruit_datasets', domain_A='Apple', domain_B='Banana'):
        # Resize to 64x64, convert to a tensor, then normalise each channel
        # with mean 0.5 / std 0.5.
        self.transformP = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.dir_base = dir_base
        self.rootA = os.path.join(self.dir_base, domain_A)
        self.rootB = os.path.join(self.dir_base, domain_B)

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.image_paths_A = [
            os.path.join(self.rootA, name) for name in sorted(os.listdir(self.rootA))
        ]
        self.image_paths_B = [
            os.path.join(self.rootB, name) for name in sorted(os.listdir(self.rootB))
        ]
        self.image_len = min(len(self.image_paths_A), len(self.image_paths_B))

    def __getitem__(self, index):
        """Return ``{'A': tensor, 'B': tensor}`` for the files at ``index``."""
        # Both images are forced to RGB before the shared transform is applied.
        A = Image.open(self.image_paths_A[index]).convert('RGB')
        B = Image.open(self.image_paths_B[index]).convert('RGB')

        return {'A': self.transformP(A), 'B': self.transformP(B)}

    def __len__(self):
        """Return the number of index-aligned A/B pairs."""
        return self.image_len

In [0]:
# Dataset built with its default folder layout.
train_data = ImageFolder()

# Shuffled mini-batches of `batch_size` samples, loaded by two worker processes.
data_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [0]:
class Generator(nn.Module):
    """Convolutional encoder-decoder that maps an image to an image.

    The network is a single ``nn.Sequential`` stored in ``self.main``. Four
    stride-2 convolutions reduce the spatial size, four stride-2 transposed
    convolutions restore it, and a final ``Tanh`` bounds the output to
    [-1, 1]. The shape comments below assume a 3x64x64 input.

    Args:
        extra_layers: When truthy, insert a 100-channel bottleneck (a 4x4
            convolution followed by a 4x4 transposed convolution) between the
            encoder and the decoder and use in-place activations. Any falsy
            value selects the plain variant.

    The order of layers inside ``self.main`` is the same as before, so
    existing ``state_dict`` keys (``main.<index>.*``) stay valid.
    """

    def __init__(self, extra_layers=False):
        super().__init__()

        # The original extra-layers variant used in-place activations and the
        # plain variant did not; that distinction is preserved.
        inplace = bool(extra_layers)

        layers = [
            # [-1, 3, 64x64] -> [-1, 64, 32x32] (no BatchNorm on the first layer)
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=inplace),
        ]
        layers += self._down(64, 128, inplace)    # [-1, 128, 16x16]
        layers += self._down(128, 256, inplace)   # [-1, 256, 8x8]
        layers += self._down(256, 512, inplace)   # [-1, 512, 4x4]

        if extra_layers:  # For Car & Face DB
            layers += [
                # [-1, 100, 1x1]
                nn.Conv2d(512, 100, 4, 1, 0, bias=False),
                nn.BatchNorm2d(100),
                nn.LeakyReLU(0.2, inplace=True),

                # [-1, 512, 4x4]
                nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
                nn.BatchNorm2d(512),
                nn.ReLU(True),
            ]
        # Otherwise: For Edges/Shoes/Handbags and Facescrub (no bottleneck).

        layers += self._up(512, 256, inplace)     # [-1, 256, 8x8]
        layers += self._up(256, 128, inplace)     # [-1, 128, 16x16]
        layers += self._up(128, 64, inplace)      # [-1, 64, 32x32]
        layers += [
            # [-1, 3, 64x64]
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        self.main = nn.Sequential(*layers)

    @staticmethod
    def _down(in_ch, out_ch, inplace):
        """Return [stride-2 4x4 conv, BatchNorm, LeakyReLU(0.2)] for one encoder stage."""
        return [
            nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=inplace),
        ]

    @staticmethod
    def _up(in_ch, out_ch, inplace):
        """Return [stride-2 4x4 transposed conv, BatchNorm, ReLU] for one decoder stage."""
        return [
            nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace),
        ]

    def forward(self, input):
        """Run ``input`` through ``self.main`` and return the generated image."""
        return self.main(input)

class Discriminator(nn.Module):
    """Convolutional discriminator that also exposes intermediate features.

    Four stride-2 convolutions (BatchNorm on all but the first, LeakyReLU(0.2)
    after each) are followed by a 4x4 convolution to a single channel and a
    sigmoid. The shape comments assume a 3x64x64 input.
    """

    def __init__(self):
        super().__init__()

        # [-1, 3, 64x64] -> [-1, 64, 32x32]
        self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False)
        self.layer1 = nn.LeakyReLU(0.2, inplace=False)

        # -> [-1, 128, 16x16]
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.layer2 = nn.LeakyReLU(0.2, inplace=False)

        # -> [-1, 256, 8x8]
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)
        self.layer3 = nn.LeakyReLU(0.2, inplace=False)

        # -> [-1, 512, 4x4]
        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(512)
        self.layer4 = nn.LeakyReLU(0.2, inplace=False)

        # -> [-1, 1, 1x1]
        self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False)

        self.sig = nn.Sigmoid()

    def forward(self, input):
        """Return ``(score, features)`` for a batch of images.

        ``score`` is the sigmoid of the last convolution's output and
        ``features`` is the list of activations after stages 2, 3 and 4.
        """
        h1 = self.layer1(self.conv1(input))
        h2 = self.layer2(self.bn2(self.conv2(h1)))
        h3 = self.layer3(self.bn3(self.conv3(h2)))
        h4 = self.layer4(self.bn4(self.conv4(h3)))

        score = self.sig(self.conv5(h4))

        return score, [h2, h3, h4]

In [0]:
# One generator per translation direction and one discriminator per domain.
generator_AtoB, generator_BtoA = Generator(), Generator()
discriminator_A, discriminator_B = Discriminator(), Discriminator()

# Move all four networks to the GPU when CUDA is available.
if torch.cuda.is_available():
    generator_AtoB, generator_BtoA, discriminator_A, discriminator_B = (
        net.cuda()
        for net in (generator_AtoB, generator_BtoA, discriminator_A, discriminator_B)
    )

In [0]:
# Loss functions used by the training loop; defaults are spelled out.
criterionGAN = nn.BCELoss(reduction='mean')          # real/fake scores vs. 1/0 labels
criterionRecon = nn.MSELoss(reduction='mean')        # reconstructed image vs. original
criterionFeature = nn.HingeEmbeddingLoss(margin=1.0, reduction='mean')  # feature matching

In [0]:
# Both generators share one optimizer and both discriminators share another.
# The chained parameter iterators are one-shot: each is consumed by the Adam
# constructor it is passed to.
g_params = chain.from_iterable(
    net.parameters() for net in (generator_AtoB, generator_BtoA)
)
d_params = chain.from_iterable(
    net.parameters() for net in (discriminator_A, discriminator_B)
)

g_optimizer = torch.optim.Adam(
    g_params, lr=lr, betas=[beta1, beta2], weight_decay=0.00001
)
d_optimizer = torch.optim.Adam(
    d_params, lr=lr, betas=[beta1, beta2], weight_decay=0.00001
)

In [0]:
total_step = len(data_loader)  # batches per epoch; only used in the log line
# Global iteration counter across epochs.
# NOTE(review): the name shadows the builtin `iter`; it is kept because other
# notebook cells may read it.
iter = 0

# Make sure the sample directory exists so a missing output directory does not
# abort the run at the first save.
os.makedirs(sample_path, exist_ok=True)

for epoch in range(num_epochs):
    for i, sample in enumerate(data_loader):
        input_A = sample['A']
        input_B = sample['B']

        # ===================== Random Shuffle =====================#
        # Permute A and B independently within the batch so that the i-th A is
        # not always presented together with the i-th B. NumPy's RNG is still
        # the source of randomness; the tensors are indexed directly instead
        # of round-tripping through NumPy arrays.
        idx_A = np.arange(input_A.size(0))
        idx_B = np.arange(input_B.size(0))
        np.random.shuffle(idx_A)
        np.random.shuffle(idx_B)

        input_A = input_A[torch.from_numpy(idx_A).long()]
        input_B = input_B[torch.from_numpy(idx_B).long()]

        A = to_variable(input_A)
        B = to_variable(input_B)

        # ===================== Forward =====================#
        # Clear gradients on all four networks; whichever loss is
        # back-propagated below leaves gradients on both generators and
        # discriminators, and only one optimizer steps per iteration.
        generator_AtoB.zero_grad()
        generator_BtoA.zero_grad()
        discriminator_A.zero_grad()
        discriminator_B.zero_grad()

        # Translations A -> B and B -> A.
        A_to_B = generator_AtoB(A)
        B_to_A = generator_BtoA(B)

        # Round trips A -> B -> A and B -> A -> B for the reconstruction loss.
        A_to_B_to_A = generator_BtoA(A_to_B)
        B_to_A_to_B = generator_AtoB(B_to_A)

        # Discriminator scores and intermediate features for real and
        # generated images in each domain.
        A_real, A_real_features = discriminator_A(A)
        A_fake, A_fake_features = discriminator_A(B_to_A)

        B_real, B_real_features = discriminator_B(B)
        B_fake, B_fake_features = discriminator_B(A_to_B)

        # ===================== Train D =====================#
        # Real images are labelled 1, generated images 0; the two terms are
        # averaged per domain.
        loss_D_A = (GAN_Loss(A_real, True, criterionGAN) + GAN_Loss(A_fake, False, criterionGAN)) * 0.5
        loss_D_B = (GAN_Loss(B_real, True, criterionGAN) + GAN_Loss(B_fake, False, criterionGAN)) * 0.5
        loss_D = loss_D_A + loss_D_B

        # ===================== Train G =====================#
        loss_G_Recon_A = criterionRecon(A_to_B_to_A, A)
        loss_G_Recon_B = criterionRecon(B_to_A_to_B, B)

        # The generators want their outputs to be scored as real.
        loss_G_A = GAN_Loss(A_fake, True, criterionGAN)
        loss_G_B = GAN_Loss(B_fake, True, criterionGAN)

        loss_G_A_feature = Feature_Loss(A_real_features, A_fake_features, criterionFeature)
        loss_G_B_feature = Feature_Loss(B_real_features, B_fake_features, criterionFeature)

        # Weight of the reconstruction term: starting_rate for the first
        # decay_gan_loss iterations, changed_rate afterwards.
        if iter < decay_gan_loss:
            rate = starting_rate
        else:
            rate = changed_rate

        loss_G_A_Total = (loss_G_A*0.1 + loss_G_A_feature*0.9) * (1.-rate) + loss_G_Recon_A * rate
        loss_G_B_Total = (loss_G_B*0.1 + loss_G_B_feature*0.9) * (1.-rate) + loss_G_Recon_B * rate

        loss_G = loss_G_A_Total + loss_G_B_Total

        # ===================== Optimized =====================#
        # Both losses are computed every iteration, but only one is
        # back-propagated: the discriminators are updated on every third
        # iteration and the generators on the other two.
        if iter % 3 == 0:
            loss_D.backward()
            d_optimizer.step()
        else:
            loss_G.backward()
            g_optimizer.step()

        # print the log info
        if (i + 1) % log_step == 0:
            print('Iteration [%d], Epoch [%d/%d], BatchStep[%d/%d], D_loss: %.4f, G_loss: %.4f'
                  % (iter + 1, epoch + 1, num_epochs, i + 1, total_step, loss_D.item(), loss_G.item()))

        # Save a grid of sample images. Everything is concatenated along
        # dim 2: input, translation and reconstruction for A, then for B.
        if (iter + 1) % sample_step == 0:
            res1 = torch.cat((torch.cat((A, A_to_B), dim=2), A_to_B_to_A), dim=2)
            res2 = torch.cat((torch.cat((B, B_to_A), dim=2), B_to_A_to_B), dim=2)
            res = torch.cat((res1, res2), dim=2)
            torchvision.utils.save_image(denorm(res.data), os.path.join(sample_path, 'Generated-%d-%d-%d.png' % (iter + 1, epoch + 1, i + 1)))

        iter += 1

    # save the model parameters for each epoch
    # (kept disabled: `args` is not defined anywhere in the visible notebook)
#     g_pathAtoB = os.path.join(args.model_path, 'generatorAtoB-%d.pkl' % (epoch + 1))
#     g_pathBtoA = os.path.join(args.model_path, 'generatorBtoA-%d.pkl' % (epoch + 1))
#     torch.save(generator_AtoB.state_dict(), g_pathAtoB)
#     torch.save(generator_BtoA.state_dict(), g_pathBtoA)