In [None]:
import glob
import random
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import sys
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import math
import itertools
import datetime
import time
from torchvision.utils import save_image, make_grid
from torchvision import datasets
from torch.autograd import Variable

In [None]:
# Colab setup: install the Kaggle CLI, mount Google Drive, install the Kaggle
# API token from Drive, then download and unpack the selfie2anime dataset into /content.
# NOTE(review): unpinned install; consider `%pip install -q kaggle==<version>`.
!pip install kaggle
# NOTE(review): `files` is not used in this part of the notebook.
from google.colab import files, drive

# Mount Google Drive (holds kaggle.json and the saved checkpoints).
drive.mount('/gdrive')

# Copy the Kaggle API token (kaggle.json) from Drive into ~/.kaggle/.
%cd /gdrive/MyDrive/kaggle
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/

# Restrict the token's file permissions to prevent the Kaggle permission warning.
!chmod 600 ~/.kaggle/kaggle.json
%cd /content/
!kaggle datasets download -d arnaud58/selfie2anime  # Kaggle API command copied from the dataset page

# Unzip the dataset archive into /content.
# Presumably this yields the trainA/trainB/testA/testB folders that the dataset
# class reads under dataset_name ("/content") — verify after unzipping.
!unzip /content/selfie2anime.zip

In [None]:
# Backup direction (disabled): when run from /content/saved_model/, copy the
# epoch-0 checkpoints to /gdrive/MyDrive/model on Drive.
# # %cd /content/saved_model/
# !mkdir -p /gdrive/MyDrive/model
# !cp D_A_0.pth /gdrive/MyDrive/model
# !cp D_B_0.pth /gdrive/MyDrive/model
# !cp G_AB_0.pth /gdrive/MyDrive/model
# !cp G_BA_0.pth /gdrive/MyDrive/model
# Restore direction: copy the epoch-8 checkpoints from the Drive root into
# /content/saved_model/, the same directory the training loop saves to.
# NOTE(review): the model-construction cell must look for these files under
# /content/saved_model/ for the resume to take effect.
%cd /gdrive/MyDrive
!mkdir -p /content/saved_model/
!cp D_A_8.pth /content/saved_model/
!cp D_B_8.pth /content/saved_model/
!cp G_AB_8.pth /content/saved_model/
!cp G_BA_8.pth /content/saved_model/

/gdrive/MyDrive


In [None]:
def to_rgb(image):
    """Return an RGB copy of a PIL image of any mode.

    A blank RGB canvas of the same size is created and ``image`` is pasted
    onto it, so PIL performs the mode conversion (no paste mask is given).
    """
    canvas = Image.new("RGB", image.size)
    canvas.paste(image)
    return canvas

In [None]:
class ImageDataset(Dataset):
    """Two-domain image dataset for CycleGAN training.

    Reads images from ``root/<split>A/*.*`` and ``root/<split>B/*.*`` and
    yields dicts ``{"A": item_A, "B": item_B}``, where each item is the image
    passed through the composed ``transforms_``.

    Args:
        root: dataset root directory.
        transforms_: list of transforms composed and applied to both images.
            ``None`` means no transform (items are then RGB PIL images).
        unaligned: if True, the B image is picked at random instead of by
            index, so A/B pairs are re-mixed on every access.
        mode: ``"train"`` reads trainA/trainB; any other value reads testA/testB.

    Raises:
        FileNotFoundError: if either domain folder contains no files.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        # Compose(None) would fail on the first call; treat None as "no transforms".
        self.transform = transforms.Compose(transforms_ if transforms_ is not None else [])
        self.unaligned = unaligned

        split = "train" if mode == "train" else "test"
        self.files_A = sorted(glob.glob(os.path.join(root, split + "A", "*.*")))
        self.files_B = sorted(glob.glob(os.path.join(root, split + "B", "*.*")))

        # An empty domain would otherwise either yield an empty loader (silent
        # no-op training) or crash later with a modulo-by-zero in __getitem__.
        if not self.files_A or not self.files_B:
            raise FileNotFoundError(
                "No images found for split '%s' under %r (need %sA/ and %sB/ folders)"
                % (split, root, split, split)
            )

    def __getitem__(self, index):
        image_A = Image.open(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            # Random partner from domain B: the two domains are unpaired.
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # Non-RGB inputs (grayscale, palette, RGBA, ...) are converted to RGB
        # so both domains have three channels.
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)

        return {"A": self.transform(image_A), "B": self.transform(image_B)}

    def __len__(self):
        # The larger domain defines an epoch; the smaller one wraps around via modulo.
        return max(len(self.files_A), len(self.files_B))

In [None]:
def weights_init_normal(m):
    """Re-initialise one module in place; intended for ``nn.Module.apply``.

    * Modules whose class name contains "Conv": weight ~ N(0, 0.02), bias set
      to 0 when the module has one.
    * Modules whose class name contains "BatchNorm2d": weight ~ N(1, 0.02),
      bias set to 0.
    Every other module is left untouched.
    """
    layer_name = type(m).__name__

    if "Conv" in layer_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        bias = getattr(m, "bias", None)
        if bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)
    elif "BatchNorm2d" in layer_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [None]:
class ResidualBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 conv + InstanceNorm stages.

    A ReLU follows only the first stage. Channels and spatial size are
    preserved, and ``forward`` returns ``x + block(x)``.
    """

    def __init__(self, in_features):
        super().__init__()
        layers = []
        for with_activation in (True, False):
            layers += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]
            if with_activation:
                layers.append(nn.ReLU(inplace=True))
        # Same layer order as before, so state_dict keys (block.1.*, block.5.*) are unchanged.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

In [None]:
class GeneratorResNet(nn.Module):
    """ResNet-style CycleGAN generator.

    Layout: 7x7 stem (64 ch) -> two stride-2 downsamples (128, 256 ch) ->
    ``num_residual_blocks`` ResidualBlocks -> two 2x upsamples (128, 64 ch) ->
    7x7 output conv back to the input channel count -> tanh.

    Args:
        input_shape: (channels, height, width). Only ``channels`` is used; any
            height/width divisible by 4 comes out at the same size.
        num_residual_blocks: number of residual blocks at the bottleneck.
    """

    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()
        channels = input_shape[0]

        # Stem: reflection-pad by 3 so the 7x7 conv preserves spatial size.
        # The pad depends on the kernel size, not on `channels`; tying it to
        # `channels` only worked by coincidence for 3-channel images.
        out_features = 64
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling: stride-2 3x3 convs, doubling channels each time.
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Bottleneck.
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling: 2x upsample + 3x3 conv, halving channels each time.
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output: back to `channels`; tanh bounds the result to [-1, 1].
        model += [nn.ReflectionPad2d(3), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """PatchGAN-style discriminator producing a 1-channel map of patch scores.

    Args:
        input_shape: (channels, height, width) of the images to score.

    Attributes:
        output_shape: (1, height // 16, width // 16), the per-image output size.
    """

    def __init__(self, input_shape):
        super().__init__()
        in_channels, height, width = input_shape

        # Four stride-2 stages divide each spatial dim by 16; the final
        # (1,0,1,0) zero-pad + 4x4/pad-1 conv keeps that size.
        self.output_shape = (1, height // 16, width // 16)

        def stage(n_in, n_out, normalize=True):
            """Conv 4x4 stride 2 -> optional InstanceNorm -> LeakyReLU(0.2)."""
            layers = [nn.Conv2d(n_in, n_out, 4, stride=2, padding=1)]
            if normalize:
                layers += [nn.InstanceNorm2d(n_out)]
            layers += [nn.LeakyReLU(0.2, inplace=True)]
            return layers

        widths = [in_channels, 64, 128, 256, 512]
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            # The first stage is left unnormalized.
            layers.extend(stage(n_in, n_out, normalize=idx > 0))
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [None]:
# ---- Training configuration ----
# Dataset root: the dataset class looks for trainA/trainB/testA/testB below it.
dataset_name = "/content"

# Image geometry.
channels = 3
img_height = img_width = 256

# Generator depth.
n_residual_blocks = 9

# Adam hyper-parameters.
lr = 0.0002
b1, b2 = 0.5, 0.999

# Epoch bookkeeping: the LR factor is 1 until decay_epoch, then decays linearly to 0 at n_epochs.
n_epochs = 200
init_epoch = 0
decay_epoch = 100

# Loss weights (cycle-consistency and identity terms of the generator loss).
lambda_cyc = 100
lambda_id = 5.0

# Data loading, sampling and checkpoint cadence.
n_cpu = 2
batch_size = 1
sample_interval = 100
checkpoint_interval = 5

In [None]:
# Create the folders the rest of the notebook writes to: sample grids go to
# <dataset_name>/images/ and checkpoints to <dataset_name>/saved_model/.
# (Formatting dataset_name into "/content/images/%s" produced the spurious
# nested path /content/images//content instead.)
os.makedirs(os.path.join(dataset_name, "images"), exist_ok=True)
os.makedirs(os.path.join(dataset_name, "saved_model"), exist_ok=True)

In [None]:
# Loss functions: least-squares adversarial loss, plus L1 for the
# cycle-consistency and identity terms (two separate L1Loss instances).
criterion_GAN, criterion_cycle, criterion_identity = (
    torch.nn.MSELoss(),
    torch.nn.L1Loss(),
    torch.nn.L1Loss(),
)

In [None]:
input_shape = (channels, img_height, img_width)

# Two generators (A->B, B->A) and one discriminator per domain.
G_AB = GeneratorResNet(input_shape, n_residual_blocks)
G_BA = GeneratorResNet(input_shape, n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

# Resume from saved checkpoints when available. They are looked up in
# <dataset_name>/saved_model/: the Drive-restore cell copies them to
# /content/saved_model/, and the training loop saves new ones there.
# (Checking /content/D_A_8.pth meant the copied checkpoints were never found.)
resume_epoch = 8  # NOTE(review): when resuming, init_epoch should likely be resume_epoch + 1.
networks = {"G_AB": G_AB, "G_BA": G_BA, "D_A": D_A, "D_B": D_B}
checkpoint_paths = {
    net_name: os.path.join(dataset_name, "saved_model", "%s_%d.pth" % (net_name, resume_epoch))
    for net_name in networks
}
missing = [path for path in checkpoint_paths.values() if not os.path.isfile(path)]

if not missing:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for net_name, net in networks.items():
        # torch.load unpickles; only load checkpoints from a trusted source (here: own Drive).
        net.load_state_dict(torch.load(checkpoint_paths[net_name], map_location=device))
    print("Resumed all networks from epoch-%d checkpoints" % resume_epoch)
elif len(missing) < len(checkpoint_paths):
    # A partial set would silently mix resumed and freshly initialised networks.
    raise FileNotFoundError("Incomplete checkpoint set; missing: %s" % ", ".join(missing))
else:
    for net in networks.values():
        net.apply(weights_init_normal)
    print("No checkpoints found; initialised all networks from scratch")

In [None]:
# Put the four networks and the loss modules on the GPU when one is available;
# `cuda` is reused later to pick the tensor type.
cuda = torch.cuda.is_available()
if cuda:
    G_AB, G_BA = G_AB.cuda(), G_BA.cuda()
    D_A, D_B = D_A.cuda(), D_B.cuda()
    for _loss in (criterion_GAN, criterion_cycle, criterion_identity):
        _loss.cuda()

In [None]:
# One Adam optimizer covers both generators, since they are updated together
# from a single combined loss; each discriminator gets its own Adam.
optimizer_G = torch.optim.Adam(
    [*G_AB.parameters(), *G_BA.parameters()],
    betas=(b1, b2),
    lr=lr,
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), betas=(b1, b2), lr=lr)
optimizer_D_B = torch.optim.Adam(D_B.parameters(), betas=(b1, b2), lr=lr)

In [None]:
class LambdaLR:
    """Multiplicative learning-rate factor for ``torch.optim.lr_scheduler.LambdaLR``.

    The factor is 1.0 until ``decay_start_epoch``. It then decays linearly to
    0.0 at ``n_epochs`` and is clamped at 0.0 afterwards, so the learning rate
    never turns negative if training runs past ``n_epochs``.

    Args:
        n_epochs: epoch at which the factor reaches zero.
        offset: added to the scheduler's step count (e.g. the starting epoch),
            so the schedule is expressed in absolute epochs.
        decay_start_epoch: epoch at which the linear decay begins; must be
            smaller than ``n_epochs``.

    Raises:
        AssertionError: if decay would not start before ``n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, \
            "Decay must start before the training session ends"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the LR multiplier (in [0, 1]) for scheduler step ``epoch``."""
        decay_length = self.n_epochs - self.decay_start_epoch
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        return max(0.0, 1.0 - epochs_into_decay / decay_length)

In [None]:
# One linear-decay schedule per optimizer, each with its own LambdaLR helper
# whose .step(epoch) supplies the LR multiplier.
lr_scheduler_G, lr_scheduler_D_A, lr_scheduler_D_B = (
    torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=LambdaLR(n_epochs, init_epoch, decay_epoch).step
    )
    for optimizer in (optimizer_G, optimizer_D_A, optimizer_D_B)
)

In [None]:
# Tensor type used to cast input batches: CUDA float tensors when on GPU, torch.Tensor otherwise.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.Tensor

In [None]:
class ReplayBuffer:
    """Bounded history of previously generated samples.

    ``push_and_pop`` returns a batch the same size as its input. While the
    buffer holds fewer than ``max_size`` samples, every incoming sample is
    stored and returned as-is. Once full, each incoming sample is, with
    probability 0.5, swapped for a random stored sample: the stored one is
    returned (cloned) and the new one takes its slot.
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole"
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        returned = []
        # Iterate over the detached batch one sample at a time, keeping a batch dim of 1.
        for sample in data.data:
            sample = sample.unsqueeze(0)

            if len(self.data) < self.max_size:
                self.data.append(sample)
                returned.append(sample)
                continue

            if random.uniform(0, 1) <= 0.5:
                returned.append(sample)
                continue

            slot = random.randint(0, self.max_size - 1)
            returned.append(self.data[slot].clone())
            self.data[slot] = sample

        return Variable(torch.cat(returned))

In [None]:
# Independent history buffers of generated images, one per discriminator.
fake_A_buffer, fake_B_buffer = ReplayBuffer(), ReplayBuffer()

In [None]:
# Training augmentation: upscale by 12%, random-crop back to img_height x img_width,
# random horizontal flip, then map pixels from [0, 1] to [-1, 1] to match the
# generators' tanh output range.
transforms_ = [
    # Passing an int/PIL constant as `interpolation` is deprecated since
    # torchvision 0.13; InterpolationMode is the supported form.
    transforms.Resize(int(img_height * 1.12), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop((img_height, img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

  "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "


In [None]:
# Training batches: unpaired A/B images (B drawn at random), shuffled each epoch.
dataloader = DataLoader(
    ImageDataset(dataset_name, transforms_=transforms_, unaligned=True),
    shuffle=True,
    batch_size=batch_size,
    num_workers=n_cpu,
)

# Test-split batches of 5 images per domain, consumed by sample_images for preview grids.
# NOTE(review): this reuses the random crop/flip training transforms, so the
# previews are augmented; a deterministic transform list may suit evaluation better.
val_dataloader = DataLoader(
    ImageDataset(dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
    num_workers=1,
    shuffle=True,
    batch_size=5,
)

In [None]:
def sample_images(batches_done):
    """Save a grid of test-set translations to ``<dataset_name>/images/<batches_done>.png``.

    Draws one batch from ``val_dataloader`` and stacks four strips vertically:
    real A, G_AB(real A), real B, G_BA(real B).

    Args:
        batches_done: global batch counter, used as the output file name.
    """
    imgs = next(iter(val_dataloader))

    # Remember each generator's mode so the caller's train/eval state is restored.
    was_training = (G_AB.training, G_BA.training)
    G_AB.eval()
    G_BA.eval()
    try:
        # Inference only: skip building an autograd graph for the sample batch.
        with torch.no_grad():
            real_A = imgs["A"].type(Tensor)
            fake_B = G_AB(real_A)
            real_B = imgs["B"].type(Tensor)
            fake_A = G_BA(real_B)
    finally:
        G_AB.train(was_training[0])
        G_BA.train(was_training[1])

    # One horizontal strip per tensor (up to 5 images per row), rescaled for display.
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)

    # Concatenate along the height dimension to stack the strips.
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, "%s/images/%s.png" % (dataset_name, batches_done), normalize=False)

In [None]:
# ---- Training loop ----
# Per batch: (1) update both generators on identity + adversarial + cycle losses,
# then (2) D_A and (3) D_B on real images vs. replay-buffered generated images.
prev_time = time.time()

# Short-run caps kept from the original experiment ("to test few epochs/batches");
# set either to None to use the full schedule / full dataset.
max_epochs_to_run = 100
max_batches_per_epoch = 40

for epoch in range(init_epoch, n_epochs):
    if max_epochs_to_run is not None and epoch >= max_epochs_to_run:
        break
    for i, batch in enumerate(dataloader):
        if max_batches_per_epoch is not None and i >= max_batches_per_epoch:
            break
        real_A = batch["A"].type(Tensor)
        real_B = batch["B"].type(Tensor)

        # Least-squares GAN targets, one per discriminator output cell:
        # 1 for real, 0 for fake. (Fake targets must be zeros; with ones,
        # D is never trained to reject generated images.)
        target_shape = (real_A.size(0), *D_A.output_shape)
        valid = torch.ones(target_shape, device=real_A.device)
        fake = torch.zeros(target_shape, device=real_A.device)

        # ---- (1) Generators ----
        G_AB.train()
        G_BA.train()
        # Clear the previous iteration's generator gradients; otherwise they accumulate.
        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image already in its output domain should leave it unchanged.
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)
        loss_identity = (loss_id_A + loss_id_B) / 2

        # Adversarial loss: generators try to make the discriminators output "valid".
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle-consistency loss: A -> B -> A and B -> A -> B should reconstruct the inputs.
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity
        loss_G.backward()
        optimizer_G.step()

        # ---- (2) Discriminator A ----
        # zero_grad also discards the gradients loss_G.backward() pushed into D_A.
        optimizer_D_A.zero_grad()
        loss_real = criterion_GAN(D_A(real_A), valid)
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        loss_D_A = (loss_real + loss_fake) / 2
        loss_D_A.backward()
        optimizer_D_A.step()

        # ---- (3) Discriminator B ----
        optimizer_D_B.zero_grad()
        loss_real = criterion_GAN(D_B(real_B), valid)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        loss_D_B = (loss_real + loss_fake) / 2
        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

        # ---- Progress ----
        batches_done = epoch * len(dataloader) + i
        batches_left = n_epochs * len(dataloader) - batches_done
        # ETA extrapolates from the duration of the most recent batch only.
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
            % (
                epoch,
                n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_GAN.item(),
                loss_cycle.item(),
                loss_identity.item(),
                time_left,
            )
        )

        # Save a preview grid every `sample_interval` batches (not every batch).
        if batches_done % sample_interval == 0:
            sample_images(batches_done)

    # Advance the per-epoch learning-rate schedules.
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    # Save checkpoints every `checkpoint_interval` epochs (-1 disables saving).
    if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
        torch.save(G_AB.state_dict(), "%s/saved_model/G_AB_%d.pth" % (dataset_name, epoch))
        torch.save(G_BA.state_dict(), "%s/saved_model/G_BA_%d.pth" % (dataset_name, epoch))
        torch.save(D_A.state_dict(), "%s/saved_model/D_A_%d.pth" % (dataset_name, epoch))
        torch.save(D_B.state_dict(), "%s/saved_model/D_B_%d.pth" % (dataset_name, epoch))

[Epoch 0/200] [Batch 9/3400] [D loss: 0.816931] [G loss: 57.214153, adv: 0.913433, cycle: 0.536030, identity: 0.539553] ETA: 527 days, 23:13:16.997074

KeyboardInterrupt: ignored