In [2]:
# Authenticate the Colab user, then mount Google Drive under /content/gdrive.
# force_remount=False reuses an existing mount instead of remounting it.
from google.colab import auth, drive

auth.authenticate_user()
drive.mount('/content/gdrive', force_remount=False)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
import os
from pathlib import Path

# Location of this project inside the mounted Google Drive.
folder = "colab/"
project_dir = "pix2pix"

base_path = Path("/content/gdrive/My Drive/")
project_path = base_path / folder / project_dir
os.chdir(project_path)

# Normalize sub-directory names by keeping only the text before the first space
# (e.g. a directory named "datasets 2" becomes "datasets").
# Directories whose name has no space are left alone. A rename is skipped, with a
# message, when it would produce an empty name or the target already exists;
# previously os.rename failed in those cases.
for entry in sorted(project_path.iterdir()):
    if not entry.is_dir() or " " not in entry.name:
        continue
    new_name = entry.name.split(" ", 1)[0]
    target = project_path / new_name
    if not new_name or target.exists():
        print(f"Skipping rename of '{entry.name}' (target '{new_name}' is empty or already exists)")
        continue
    entry.rename(target)
print(f"현재 디렉토리 위치: {os.getcwd()}")


In [0]:
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
from torch.autograd import Variable
from itertools import chain
from PIL import Image
from torch.utils.data import DataLoader

# Pick the compute device name, then report the PyTorch version and CUDA availability
# ("GPU 사용 가능 여부" = "GPU available?").
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'pytorch version: {torch.__version__}')
print(f'GPU 사용 가능 여부: {torch.cuda.is_available()}')

In [0]:
def to_variable(x):
    """Move a tensor to the GPU when one is available and return it detached.

    This replaces the deprecated ``torch.autograd.Variable`` wrapper. In modern
    PyTorch, ``Variable(t)`` returns a plain tensor detached from the graph with
    ``requires_grad=False``; ``x.detach()`` gives the same result without the
    deprecated API.

    Args:
        x: input tensor (any device).

    Returns:
        ``x`` on CUDA if available (otherwise on its original device), detached
        from any autograd graph.
    """
    if torch.cuda.is_available():
        x = x.cuda()
    return x.detach()

In [0]:
def denorm(x):
    """Map a tensor from the [-1, 1] range back to [0, 1] for saving/display.

    Inverts ``Normalize(mean=0.5, std=0.5)``: computes ``(x + 1) / 2`` and clamps
    the result into [0, 1].
    """
    return x.add(1).div(2).clamp(min=0, max=1)

In [0]:
def GAN_Loss(input, target, criterion):
    """Evaluate ``criterion`` against a constant real/fake label map.

    Args:
        input: discriminator output tensor.
        target: True for "real" (labels of 1.0), False for "fake" (labels of 0.0).
        criterion: loss callable invoked as ``criterion(input, labels)``.

    Returns:
        Whatever ``criterion`` returns (a scalar tensor for ``nn.BCELoss``).
    """
    # full_like builds the labels with input's shape, dtype and device. The old
    # version always made float32 labels and moved them to CUDA whenever CUDA
    # existed, which broke for CPU or float64 inputs. The labels are constants
    # (requires_grad=False).
    labels = torch.full_like(input, 1.0 if target else 0.0)
    return criterion(input, labels)

In [0]:
# Hyperparameters
which_direction = 'AtoB'  # 'AtoB': dataset A is the input and B the target; any other value swaps them

# Optimization
num_epochs, batchSize = 100, 1          # batchSize == 1 also selects InstanceNorm in both networks
lr, beta1, beta2 = 0.0002, 0.5, 0.999   # Adam learning rate and betas
lambda_A = 100.0                        # weight of the L1 term in the generator loss

# Logging / I/O
sample_path = './results'
log_step, sample_step = 10, 100         # print every log_step batches, save images every sample_step batches
num_workers = 2                         # DataLoader worker processes

In [0]:
class ImageFolder(torch.utils.data.Dataset):
    """Paired image dataset: item ``i`` pairs the i-th file of folder A with the i-th of folder B.

    Files in each folder are sorted by path, so pairs line up only if corresponding
    images sort to the same position in both folders (e.g. identical file names).

    Args:
        dir_base: directory containing both image folders.
        dir_A: sub-folder name holding the A images.
        dir_B: sub-folder name holding the B images.
        transform: optional callable applied to each PIL RGB image. Defaults to
            ``ToTensor()`` followed by ``Normalize`` into the [-1, 1] range.
    """

    def __init__(self, dir_base='./datasets', dir_A='pikachu_black',
                 dir_B='pikachu_resized2', transform=None):
        if transform is None:
            # ToTensor yields [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        self.transformP = transform

        self.dir_base = dir_base
        self.rootA = os.path.join(self.dir_base, dir_A)
        self.rootB = os.path.join(self.dir_base, dir_B)
        self.image_paths_A = self._list_files(self.rootA)
        self.image_paths_B = self._list_files(self.rootB)

        # Unpaired leftovers in the larger folder are dropped.
        self.image_len = min(len(self.image_paths_A), len(self.image_paths_B))
        if len(self.image_paths_A) != len(self.image_paths_B):
            import warnings
            warnings.warn(
                f"ImageFolder: {len(self.image_paths_A)} files in {self.rootA} but "
                f"{len(self.image_paths_B)} in {self.rootB}; only the first "
                f"{self.image_len} sorted pairs are used and may be misaligned")

    @staticmethod
    def _list_files(root):
        """Return the sorted paths of the visible regular files directly inside ``root``."""
        # Skip hidden files (e.g. ".DS_Store") and sub-directories, which PIL cannot open.
        return sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if not name.startswith('.') and os.path.isfile(os.path.join(root, name))
        )

    def _load(self, path):
        """Open ``path`` as RGB and apply ``self.transformP``."""
        # The context manager closes the file handle PIL opens.
        with Image.open(path) as img:
            return self.transformP(img.convert('RGB'))

    def __getitem__(self, index):
        """Return ``{'A': ..., 'B': ...}`` with the transformed ``index``-th image pair."""
        A = self._load(self.image_paths_A[index])
        B = self._load(self.image_paths_B[index])
        return {'A': A, 'B': B}

    def __len__(self):
        """Number of usable A/B pairs."""
        return self.image_len

In [0]:
# Shuffled loader over the paired dataset; each batch is a dict with keys 'A' and 'B'.
dataset = ImageFolder()
data_loader = DataLoader(
    dataset,
    batch_size=batchSize,
    num_workers=num_workers,
    shuffle=True,
)

In [0]:

# Output folder for sampled images. exist_ok=True replaces the separate
# exists()/makedirs() check, which could race.
os.makedirs(sample_path, exist_ok=True)

In [0]:
class Generator(nn.Module):
    """pix2pix U-Net generator with 8 downsampling and 8 upsampling stages.

    Per the layer shape comments: a 3x256x256 input is encoded to a 512x1x1
    bottleneck and decoded back to 3x256x256, squashed to [-1, 1] by a final Tanh.
    Every decoder stage after the first receives its predecessor's output
    concatenated (channel-wise) with the encoder feature map of the same resolution.

    Args:
        batch_size: 1 selects InstanceNorm2d; any other value selects BatchNorm2d.
    """

    def __init__(self, batch_size):
        super(Generator, self).__init__()

        use_bn = batch_size != 1  # InstanceNorm for single-image batches, BatchNorm otherwise

        def norm(channels):
            # One normalization layer of the selected kind.
            return nn.BatchNorm2d(channels) if use_bn else nn.InstanceNorm2d(channels)

        def down(in_ch, out_ch, normalize=True):
            # Encoder stage: LeakyReLU -> stride-2 conv (halves H and W) -> optional norm.
            # The LeakyReLU is in-place, so it overwrites the previous stage's output;
            # the skip tensors c1..c7 used in forward() therefore hold post-activation values.
            layers = [nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(in_ch, out_ch, 4, 2, 1)]
            if normalize:
                layers.append(norm(out_ch))
            return nn.Sequential(*layers)

        def up(in_ch, out_ch, dropout=False):
            # Decoder stage: ReLU -> stride-2 transposed conv (doubles H and W) -> norm [-> dropout].
            layers = [nn.ReLU(), nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1), norm(out_ch)]
            if dropout:
                layers.append(nn.Dropout(0.5))
            return nn.Sequential(*layers)

        # Encoder. Modules are created in the same order as before, so weight
        # initialization under a fixed seed and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)  # [3x256x256] -> [64x128x128]
        self.conv2 = down(64, 128)              # -> [128x64x64]
        self.conv3 = down(128, 256)             # -> [256x32x32]
        self.conv4 = down(256, 512)             # -> [512x16x16]
        self.conv5 = down(512, 512)             # -> [512x8x8]
        self.conv6 = down(512, 512)             # -> [512x4x4]
        self.conv7 = down(512, 512)             # -> [512x2x2]
        # -> [512x1x1]. InstanceNorm over a 1x1 map normalizes every value against
        # itself and outputs all zeros, which wiped out the bottleneck; current
        # PyTorch versions raise a ValueError for it instead. The bottleneck
        # therefore gets no InstanceNorm. BatchNorm still has batch statistics
        # and is kept. InstanceNorm2d has no parameters, so state_dict keys are unchanged.
        self.conv8 = down(512, 512, normalize=use_bn)

        # Decoder
        self.deconv8 = up(512, 512, dropout=True)      # [512x1x1] -> [512x2x2]
        self.deconv7 = up(512 * 2, 512, dropout=True)  # [(512+512)x2x2] -> [512x4x4]
        self.deconv6 = up(512 * 2, 512, dropout=True)  # [(512+512)x4x4] -> [512x8x8]
        self.deconv5 = up(512 * 2, 512)                # [(512+512)x8x8] -> [512x16x16]
        self.deconv4 = up(512 * 2, 256)                # [(512+512)x16x16] -> [256x32x32]
        self.deconv3 = up(256 * 2, 128)                # [(256+256)x32x32] -> [128x64x64]
        self.deconv2 = up(128 * 2, 64)                 # [(128+128)x64x64] -> [64x128x128]
        self.deconv1 = nn.Sequential(                  # [(64+64)x128x128] -> [3x256x256]
            nn.ReLU(),
            nn.ConvTranspose2d(64 * 2, 3, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` to the bottleneck, then decode it using U-Net skip connections.

        Args:
            x: image batch; the layer comments assume [N, 3, 256, 256].

        Returns:
            A 3-channel tensor in [-1, 1] ([N, 3, 256, 256] for a 256x256 input).
        """
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)
        c5 = self.conv5(c4)
        c6 = self.conv6(c5)
        c7 = self.conv7(c6)
        c8 = self.conv8(c7)

        # Each decoder output is concatenated with the encoder map of matching size.
        d7 = torch.cat((c7, self.deconv8(c8)), dim=1)
        d6 = torch.cat((c6, self.deconv7(d7)), dim=1)
        d5 = torch.cat((c5, self.deconv6(d6)), dim=1)
        d4 = torch.cat((c4, self.deconv5(d5)), dim=1)
        d3 = torch.cat((c3, self.deconv4(d4)), dim=1)
        d2 = torch.cat((c2, self.deconv3(d3)), dim=1)
        d1 = torch.cat((c1, self.deconv2(d2)), dim=1)
        return self.deconv1(d1)

class Discriminator(nn.Module):
    """Conditional PatchGAN discriminator.

    Concatenates two 3-channel images along the channel axis and returns a map of
    per-patch probabilities (Sigmoid output). Per the layer comments, a pair of
    256x256 images yields a [1x30x30] map.

    Args:
        batch_size: 1 selects InstanceNorm2d; any other value selects BatchNorm2d.
    """

    def __init__(self, batch_size):
        super(Discriminator, self).__init__()

        norm_layer = nn.InstanceNorm2d if batch_size == 1 else nn.BatchNorm2d

        # First conv has no normalization: [(3+3)x256x256] -> [64x128x128]
        layers = [nn.Conv2d(3 * 2, 64, 4, 2, 1)]

        # (in_channels, out_channels, stride) of the normalized stages:
        # -> [128x64x64] -> [256x32x32] -> [512x31x31] (stride 1, fully convolutional)
        for in_ch, out_ch, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers.extend([
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(in_ch, out_ch, 4, stride, 1),
                norm_layer(out_ch),
            ])

        # -> [1x30x30] patch probabilities (PatchGAN)
        layers.extend([
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 1),
            nn.Sigmoid(),
        ])

        self.main = nn.Sequential(*layers)

    def forward(self, x1, x2):
        """Score the image pair (x1, x2).

        In the training loop, x1 is the conditioning input image and x2 is either
        the real target or the generated one.
        """
        return self.main(torch.cat([x1, x2], dim=1))

In [0]:
# Build both networks and put them on the GPU (when available) before the
# optimizers are constructed. This is the order the Module.cuda() documentation
# recommends; a later .cuda() call on these modules is then a no-op.
generator = Generator(batchSize)
discriminator = Discriminator(batchSize)
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()

In [0]:
# Adversarial loss on the discriminator's sigmoid outputs, plus an L1 reconstruction loss.
criterionGAN, criterionL1 = nn.BCELoss(), nn.L1Loss()

In [0]:
# Separate Adam optimizers for G and D, sharing the learning rate and betas.
# betas is passed as a (beta1, beta2) tuple, the type torch.optim.Adam documents.
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

In [0]:
# Put both networks on the GPU when one is available (Module.cuda() returns the module).
if torch.cuda.is_available():
    generator, discriminator = generator.cuda(), discriminator.cuda()

In [89]:
# Training loop. Per batch: one discriminator update, then one generator update
# (adversarial loss + lambda_A * L1 distance to the target image).
AtoB = which_direction == 'AtoB'  # loop-invariant: which dataset side is the input
total_step = len(data_loader)  # batches per epoch (for the log line)
for epoch in range(num_epochs):
    for i, sample in enumerate(data_loader):
        real_A = to_variable(sample['A' if AtoB else 'B'])
        real_B = to_variable(sample['B' if AtoB else 'A'])
        fake_B = generator(real_A)

        # ===================== Train D =====================#
        d_optimizer.zero_grad()

        # fake_B is detached so loss_D does not backpropagate into the generator.
        # The generator gradients that backward used to produce were cleared
        # before the G update anyway, so results are unchanged. This skips that
        # wasted backward pass and the retain_graph=True it required.
        pred_fake = discriminator(real_A, fake_B.detach())
        loss_D_fake = GAN_Loss(pred_fake, False, criterionGAN)

        pred_real = discriminator(real_A, real_B)
        loss_D_real = GAN_Loss(pred_real, True, criterionGAN)

        # Combined loss (average of the fake and real terms)
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        loss_D.backward()
        d_optimizer.step()

        # ===================== Train G =====================#
        g_optimizer.zero_grad()

        # Re-score the non-detached fake with the just-updated discriminator.
        pred_fake = discriminator(real_A, fake_B)
        loss_G_GAN = GAN_Loss(pred_fake, True, criterionGAN)

        loss_G_L1 = criterionL1(fake_B, real_B)

        loss_G = loss_G_GAN + loss_G_L1 * lambda_A
        # This backward also accumulates gradients into D. They are cleared by
        # d_optimizer.zero_grad() at the next step, before D uses them.
        loss_G.backward()
        g_optimizer.step()

        # print the log info
        if (i + 1) % log_step == 0:
            print('Epoch [%d/%d], BatchStep[%d/%d], D_Real_loss: %.4f, D_Fake_loss: %.4f, G_loss: %.4f, G_L1_loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, total_step, loss_D_real.item(), loss_D_fake.item(), loss_G_GAN.item(), loss_G_L1.item()))

        # save the sampled images as a strip: [input | generated | target]
        if (i + 1) % sample_step == 0:
            res = torch.cat((real_A, fake_B, real_B), dim=3)
            torchvision.utils.save_image(denorm(res.detach()), os.path.join(sample_path, 'Generated-%d-%d.png' % (epoch + 1, i + 1)))


Epoch [1/100], BatchStep[10/205], D_Real_loss: 0.7266, D_Fake_loss: 0.6493, G_loss: 0.9916, G_L1_loss: 0.2985
Epoch [1/100], BatchStep[20/205], D_Real_loss: 0.7175, D_Fake_loss: 0.6804, G_loss: 0.9309, G_L1_loss: 0.1922
Epoch [1/100], BatchStep[30/205], D_Real_loss: 0.6068, D_Fake_loss: 0.6292, G_loss: 1.0652, G_L1_loss: 0.2586
Epoch [1/100], BatchStep[40/205], D_Real_loss: 0.6368, D_Fake_loss: 0.4503, G_loss: 1.1079, G_L1_loss: 0.1833
Epoch [1/100], BatchStep[50/205], D_Real_loss: 0.5874, D_Fake_loss: 0.8480, G_loss: 0.9378, G_L1_loss: 0.1652
Epoch [1/100], BatchStep[60/205], D_Real_loss: 0.6819, D_Fake_loss: 0.5144, G_loss: 0.9238, G_L1_loss: 0.1751
Epoch [1/100], BatchStep[70/205], D_Real_loss: 0.6609, D_Fake_loss: 0.7668, G_loss: 1.2413, G_L1_loss: 0.1247
Epoch [1/100], BatchStep[80/205], D_Real_loss: 0.5092, D_Fake_loss: 0.3553, G_loss: 1.3671, G_L1_loss: 0.1869
Epoch [1/100], BatchStep[90/205], D_Real_loss: 0.3536, D_Fake_loss: 0.9012, G_loss: 1.0380, G_L1_loss: 0.1243
Epoch [1/1

KeyboardInterrupt: ignored

In [0]:
###################################### Ignore (scratch experiment) ##################################
# Pick one drawing from ./datasets/pikachu_draw and set up the same preprocessing
# as the training dataset, so the trained generator can be tried on it below.

dir_base = './datasets'
rootA = os.path.join(dir_base, 'pikachu_draw')
# os.listdir order is arbitrary, so sort to make "the first image" deterministic;
# hidden files such as ".DS_Store" are skipped because PIL cannot open them.
image_paths_A = sorted(
    os.path.join(rootA, name) for name in os.listdir(rootA) if not name.startswith('.'))
if not image_paths_A:
    raise FileNotFoundError(f"no images found in {rootA}")
A_path = image_paths_A[0]

# ToTensor yields [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
        

In [0]:
# Translate the single drawing with the current generator.
# NOTE(review): no generator.eval() call is visible, so dropout and the norm
# layers behave as in training mode here.
with Image.open(A_path) as img:  # closes the file handle PIL opens
    A = transform(img.convert('RGB'))

# Inference only: no_grad avoids recording an autograd graph.
with torch.no_grad():
    real_A = to_variable(A).unsqueeze(0)  # add a batch dimension
    fake_B = generator(real_A)

In [0]:
# Side-by-side strip [input | generated | input]. The input is repeated in the
# third panel, presumably because this drawing has no target image.
res = torch.cat((real_A, fake_B, real_A), dim=3)
torchvision.utils.save_image(denorm(res.detach()), os.path.join(sample_path, 'Generated_draw_2.png'))