In [7]:
import os 
import numpy as np
import torch 
import torch.utils.data as data
from PIL import Image
from torchvision import datasets, transforms

import random
import os.path

import torchvision
import torch.functional as F
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.utils as vutils

from torch.utils.data import DataLoader


In [15]:
# Device configuration: use CUDA when PyTorch can see a GPU, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [2]:
class Dataset(data.Dataset):
    """Paired image dataset serving (source, target) images from two folders.

    The i-th source file is paired with the i-th target file after both
    folder listings have been sorted by file name, so matching pairs must
    sort to the same position in their respective folders.

    Args:
        src_path: Directory holding the source images.
        trg_path: Directory holding the target images.
        transform: Callable applied to each PIL image (source and target
            alike), or None to return the PIL images unchanged.
    """

    def __init__(self, src_path, trg_path, transform):
        self.src_path = src_path
        self.trg_path = trg_path

        # os.listdir() gives no ordering guarantee; sorting both listings keeps
        # the index-based source/target pairing deterministic and aligned.
        self.src_images = self._list_files(src_path)
        self.trg_images = self._list_files(trg_path)
        self.transform = transform

    @staticmethod
    def _list_files(folder):
        """Return the sorted names of the regular files directly inside `folder`."""
        # Sub-directories (e.g. Jupyter's .ipynb_checkpoints) are skipped because
        # they cannot be opened as images.
        return sorted(
            entry for entry in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, entry))
        )

    @staticmethod
    def _load_rgb(path):
        """Open the image at `path`, convert it to RGB and release the file handle."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the underlying file is closed.
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return the (source, target) pair stored at position `index`."""
        src_img = self._load_rgb(os.path.join(self.src_path, self.src_images[index]))
        trg_img = self._load_rgb(os.path.join(self.trg_path, self.trg_images[index]))

        if self.transform is not None:
            src_img = self.transform(src_img)
            trg_img = self.transform(trg_img)

        return src_img, trg_img

    def __len__(self):
        """Number of complete pairs, i.e. the size of the smaller folder."""
        # Using only the source count would raise IndexError in __getitem__
        # whenever the target folder holds fewer files.
        return min(len(self.src_images), len(self.trg_images))


In [10]:
def get_loader(batch_size, src_path, trg_path, transform, shuffle=True, num_workers=1):
    """Build a DataLoader over the paired images found in two folders.

    Args:
        batch_size: Number of (source, target) pairs per batch.
        src_path: Directory holding the source images.
        trg_path: Directory holding the target images.
        transform: Callable applied to every loaded image (or None).
        shuffle: Whether the loader reshuffles the pairs each epoch.
        num_workers: Number of loader worker processes. Defaults to 1, the
            value that used to be hard-coded; pass 0 to load in the main
            process.

    Returns:
        A DataLoader yielding (source_batch, target_batch) tuples.
    """
    dataset = Dataset(src_path, trg_path, transform)
    dataloader = data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

    return dataloader

In [11]:
def conv_block(idx, name, in_c, out_c, activation, kernel_size=3, stride=1, padding=1, transpose=False, bn=True, bias=True, drop=False):
    """Assemble a conv -> [batch norm] -> [activation] -> [dropout] block.

    Args:
        idx: String suffix appended to every sub-module name.
        name: String prefix for every sub-module name (the owning network's name).
        in_c: Number of input channels.
        out_c: Number of output channels.
        activation: One of 'relu', 'leaky_relu', 'sigmoid' or 'tanh'. Any
            other value (e.g. None) adds no activation layer.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        transpose: If True use nn.ConvTranspose2d instead of nn.Conv2d.
        bn: If True add nn.BatchNorm2d after the convolution.
        bias: Whether the convolution has a bias term.
        drop: If True append nn.Dropout (default probability) at the end.

    Returns:
        An nn.Sequential holding the requested layers in the order above.
    """
    block = nn.Sequential()

    if transpose:
        block.add_module(name + ' Conv2d_Transpose' + idx, nn.ConvTranspose2d(in_c, out_c, kernel_size, stride, padding, bias=bias))
    else:
        block.add_module(name + ' Conv2d' + idx, nn.Conv2d(in_c, out_c, kernel_size, stride, padding, bias=bias))

    if bn:
        block.add_module(name + ' Batch_norm' + idx, nn.BatchNorm2d(out_c))

    if activation == 'relu':
        # 'relu' previously built nn.ELU under the "ReLU" name; it now builds
        # the activation that was asked for. The module name is unchanged, so
        # state_dict keys are unaffected.
        block.add_module(name + ' ReLU' + idx, nn.ReLU(inplace=True))
    elif activation == 'leaky_relu':
        block.add_module(name + ' Leaky_ReLU' + idx, nn.LeakyReLU(0.2, inplace=True))
    elif activation == 'sigmoid':
        block.add_module(name + ' Sigmoid' + idx, nn.Sigmoid())
    elif activation == 'tanh':
        block.add_module(name + ' Tanh' + idx, nn.Tanh())

    if drop:
        block.add_module(name + " Drop_out" + idx, nn.Dropout())

    return block


In [None]:
class G(nn.Module):
    """
    Encoder/decoder generator with skip connections between mirrored stages.

    Shapes for a 128 x 128 RGB input, listed as batch x height x width x
    channels (the tensors themselves are laid out N x C x H x W; the
    concatenations in forward() are along the channel dimension):

    input   : ? x 128 x 128 x 3
    layer0  : ? x 64 x 64 x 32
    layer1  : ? x 32 x 32 x 64
    layer2  : ? x 16 x 16 x 128
    layer3  : ? x 8 x 8 x 256
    layer4  : ? x 4 x 4 x 256   (concatenated with avg-pooled layer3 -> 512 channels)
    dlayer4 : ? x 8 x 8 x 256
    dlayer3 : ? x 16 x 16 x 128
    dlayer2 : ? x 32 x 32 x 64
    dlayer1 : ? x 64 x 64 x 32
    dlayer0 : ? x 128 x 128 x 3
    """
    def __init__(self):
        super().__init__()
        self.name = "G"

        self.build()

        # Custom initialisation: only nn.Conv2d and nn.BatchNorm2d are touched;
        # every other module keeps whatever its constructor set up.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                k_h, k_w = module.kernel_size
                fan = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, (2. / fan) ** 0.5)

    def build(self):
        """Create every sub-module, in encoder-then-decoder order."""
        act = 'leaky_relu'
        self.pooling = nn.Sequential(nn.AvgPool2d((2, 2), 2))

        halve = dict(kernel_size=4, stride=2)                    # strided conv: halves H and W
        double = dict(transpose=True, kernel_size=4, stride=2)   # transposed conv: doubles H and W
        keep = dict()                                            # default 3x3 conv: same H and W
        keep_drop = dict(drop=True)                              # same, followed by dropout

        # (attribute name, idx suffix, in channels, out channels, activation, extra kwargs)
        layer_specs = (
            # down
            ('layer0_0', '0_0', 3, 32, act, halve),
            ('layer0_1', '0_1', 32, 32, act, keep),
            ('layer1_0', '1_0', 32, 64, act, halve),
            ('layer1_1', '1_1', 64, 64, act, keep),
            ('layer2_0', '2_0', 64, 128, act, halve),
            ('layer2_1', '2_1', 128, 128, act, keep),
            ('layer3_0', '3_0', 128, 256, act, halve),
            ('layer3_2', '3_2', 256, 256, act, keep_drop),
            ('layer4_0', '4_0', 256, 512, act, halve),
            ('layer4_2', '4_2', 512, 256, act, keep),
            # up (the first block of each stage consumes a skip concatenation)
            ('dlayer4_0', 'up4_0', 512, 512, act, double),
            ('dlayer4_2', 'up4_2', 512, 256, act, keep_drop),
            ('dlayer3_0', 'up3_0', 512, 256, act, double),
            ('dlayer3_2', 'up3_2', 256, 128, act, keep),
            ('dlayer2_0', 'up2_0', 256, 128, act, double),
            ('dlayer2_1', 'up2_1', 128, 64, act, keep),
            ('dlayer1_0', 'up1_0', 128, 64, act, double),
            ('dlayer1_1', 'up1_1', 64, 32, act, keep),
            ('dlayer0_0', 'up0_0', 64, 32, act, double),
            ('dlayer0_1', 'up0_1', 32, 3, 'tanh', keep),
        )
        for attr, idx, in_c, out_c, activation, extra in layer_specs:
            setattr(self, attr, conv_block(idx, self.name, in_c, out_c, activation, **extra))

    def forward(self, x):
        """Map an image batch to a same-sized 3-channel batch (tanh output)."""
        # Encoder: each stage is a down-sampling block followed by a refining block.
        enc0 = self.layer0_1(self.layer0_0(x))
        enc1 = self.layer1_1(self.layer1_0(enc0))
        enc2 = self.layer2_1(self.layer2_0(enc1))
        enc3 = self.layer3_2(self.layer3_0(enc2))
        enc4 = self.layer4_2(self.layer4_0(enc3))

        # Bottleneck: deepest features joined with an average-pooled copy of enc3.
        bottleneck = torch.cat((enc4, self.pooling(enc3)), 1)

        # Decoder: up-sample, refine, then concatenate the mirrored encoder output.
        dec4 = self.dlayer4_2(self.dlayer4_0(bottleneck))
        dec3 = self.dlayer3_2(self.dlayer3_0(torch.cat((dec4, enc3), 1)))
        dec2 = self.dlayer2_1(self.dlayer2_0(torch.cat((dec3, enc2), 1)))
        dec1 = self.dlayer1_1(self.dlayer1_0(torch.cat((dec2, enc1), 1)))

        return self.dlayer0_1(self.dlayer0_0(torch.cat((dec1, enc0), 1)))

In [12]:
class D(nn.Module):
    """
    Patch discriminator scoring a pair of images stacked along the channel axis.

    Shapes for two 128 x 128 RGB inputs, listed as batch x height x width x
    channels (the tensors themselves are laid out N x C x H x W):

    input  : ? x 128 x 128 x 6 (3 x 2)
    layer0 : ? x 64 x 64 x 32
    layer1 : ? x 32 x 32 x 64
    layer2 : ? x 32 x 32 x 128
    layer3 : ? x 31 x 31 x 256
    layer4 : ? x 31 x 31 x 512 (layer4_0 / layer4_1)
    output : ? x 30 x 30 x 1   (sigmoid scores)
    """
    def __init__(self):
        super().__init__()
        self.name = "D"

        self.build()

        # Custom initialisation: only nn.Conv2d and nn.BatchNorm2d are touched.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                k_h, k_w = module.kernel_size
                fan = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, (2. / fan) ** 0.5)

    def build(self):
        """Create every sub-module in the order forward() applies them."""
        act = 'leaky_relu'

        halve = dict(kernel_size=4, stride=2)                # strided conv: halves H and W
        shrink = dict(kernel_size=2, stride=1, padding=0)    # 2x2 conv without padding: H and W drop by 1
        keep = dict()                                        # default 3x3 conv: same H and W

        # (attribute name, idx suffix, in channels, out channels, activation, extra kwargs)
        layer_specs = (
            ('layer0', '0', 6, 32, act, dict(bn=False, kernel_size=4, stride=2)),
            ('layer1_0', '1_0', 32, 64, act, halve),
            ('layer1_1', '1_1', 64, 64, act, keep),
            ('layer2_0', '2_0', 64, 128, act, keep),
            ('layer2_1', '2_1', 128, 128, act, keep),
            ('layer3_0', '3_0', 128, 256, act, shrink),
            ('layer3_1', '3_1', 256, 256, act, keep),
            ('layer3_2', '3_2', 256, 256, act, keep),
            ('layer4_0', '4_0', 256, 512, act, keep),
            ('layer4_1', '4_1', 512, 512, act, keep),
            ('layer4_2', '4_2', 512, 1, 'sigmoid', shrink),
        )
        for attr, idx, in_c, out_c, activation, extra in layer_specs:
            setattr(self, attr, conv_block(idx, self.name, in_c, out_c, activation, **extra))

    def forward(self, src_input, trg_input):
        """Return the per-patch score map for the channel-wise stacked image pair."""
        out = torch.cat((src_input, trg_input), 1)

        pipeline = (
            self.layer0,
            self.layer1_0, self.layer1_1,
            self.layer2_0, self.layer2_1,
            self.layer3_0, self.layer3_1, self.layer3_2,
            self.layer4_0, self.layer4_1, self.layer4_2,
        )
        for stage in pipeline:
            out = stage(out)

        return out


In [None]:
class Pix2Pix:
    """Training harness that trains the generator `G` against the discriminator `D`.

    Args:
        batch_size: Number of image pairs per batch.
        epoch_iter: Epoch index at which training stops (exclusive).
        lr: Learning rate shared by both Adam optimizers.
        src_path: Directory of source images.
        trg_path: Directory of target images.
        sample_img_path: Directory receiving periodic sample images.
        save_model_path: Directory receiving the per-epoch checkpoints.
        restore_D_model_path: Discriminator checkpoint to resume from; an
            empty string or None skips restoring it.
        restore_G_model_path: Generator checkpoint to resume from; an empty
            string or None skips restoring it.
        gpu: If truthy, train on CUDA; otherwise on the CPU.
        start_epoch: Index of the first epoch to run. Checkpoints are named
            after the epoch index, so pass the index to continue from when
            resuming. The loop used to start at a hard-coded 3; callers that
            relied on that must now pass start_epoch=3.
    """

    def __init__(self, batch_size, epoch_iter, lr, src_path, trg_path, sample_img_path, save_model_path, restore_D_model_path, restore_G_model_path, gpu, start_epoch=0):
        self.batch_size = batch_size
        self.epoch_iter = epoch_iter
        self.start_epoch = start_epoch
        self.lr = lr

        self.src_path = src_path
        self.trg_path = trg_path
        self.sample_img_path = sample_img_path
        self.save_model_path = save_model_path
        self.restore_D_model_path = restore_D_model_path
        self.restore_G_model_path = restore_G_model_path

        self.D = D()
        self.G = G()
        # Number of discriminator / generator updates performed per batch.
        self.d_step = 1
        self.g_step = 1

        self.gpu = gpu

        # Only ToTensor is applied; no mean/std normalisation is performed.
        self.transform = transforms.Compose([transforms.ToTensor()])

    def load_dataset(self):
        """Create separate ImageFolder loaders for the source and target folders."""
        # Uses self.transform: the attribute set in __init__ (there is no
        # `self.transformations`).
        src_data = dset.ImageFolder(self.src_path, self.transform)
        self.src_loader = DataLoader(src_data, batch_size=self.batch_size)
        trg_data = dset.ImageFolder(self.trg_path, self.transform)
        self.trg_loader = DataLoader(trg_data, batch_size=self.batch_size)

    @staticmethod
    def _set_requires_grad(module, flag):
        """Switch gradient tracking on or off for every parameter of `module`."""
        for p in module.parameters():
            p.requires_grad = flag

    @staticmethod
    def _ensure_dir(path):
        """Create `path` if it is a non-empty path that does not exist yet."""
        if path:
            os.makedirs(path, exist_ok=True)

    def _restore_checkpoints(self, device):
        """Load whichever of the D / G checkpoints has a path configured."""
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        restored = False
        if self.restore_D_model_path:
            self.D.load_state_dict(torch.load(self.restore_D_model_path, map_location=device))
            restored = True
        if self.restore_G_model_path:
            self.G.load_state_dict(torch.load(self.restore_G_model_path, map_location=device))
            restored = True
        if restored:
            print('Pretrained model load success!')

    def _discriminator_step(self, fake, trg, bce, optimizer):
        """Run one discriminator update; `fake` must already be detached from G."""
        self.D.zero_grad()

        # NOTE(review): the discriminator is conditioned on the *target* image
        # (fake pair = (G(src), trg), real pair = (trg, trg)), as in the
        # original code. Canonical pix2pix conditions on the source image -
        # confirm this is intended.
        d_fake = self.D(fake, trg)
        d_real = self.D(trg, trg)

        # Targets are shaped from the outputs so a smaller final batch works.
        fake_loss = bce(d_fake, torch.zeros_like(d_fake))
        real_loss = bce(d_real, torch.ones_like(d_real))
        (fake_loss + real_loss).backward()
        optimizer.step()

        return fake_loss, real_loss

    def _generator_step(self, generated, trg, bce, mse, optimizer):
        """Run one generator update through the (frozen) discriminator."""
        self.G.zero_grad()

        # Score the fake with the *current* discriminator so the graph matches
        # the weights the preceding discriminator update produced.
        d_fake = self.D(generated, trg)
        adversarial_loss = bce(d_fake, torch.ones_like(d_fake))
        distance_loss = mse(generated, trg) * 100
        loss = adversarial_loss + distance_loss
        loss.backward()
        optimizer.step()

        return loss

    def train(self):
        """Train D and G, logging losses, saving samples and per-epoch checkpoints."""
        device = torch.device('cuda' if self.gpu else 'cpu')

        data_loader = get_loader(self.batch_size, self.src_path, self.trg_path, self.transform)
        print('Dataset Load Success!')

        self._restore_checkpoints(device)

        # Move the networks before building the optimizers.
        self.D = self.D.to(device)
        self.G = self.G.to(device)

        D_adam = optim.Adam(self.D.parameters(), lr=self.lr, betas=(0.5, 0.999))
        G_adam = optim.Adam(self.G.parameters(), lr=self.lr, betas=(0.5, 0.999))

        BCE_loss = nn.BCELoss()
        MSE_loss = nn.MSELoss()

        self._ensure_dir(self.sample_img_path)
        self._ensure_dir(self.save_model_path)

        self.D.train()
        self.G.train()
        print('Training Start')
        for epoch in range(self.start_epoch, self.epoch_iter):
            for step, (src, trg) in enumerate(data_loader):
                # Use the batch produced by the loop itself instead of building
                # a brand-new loader iterator on every step.
                src_data = src.to(device)
                trg_data = trg.to(device)

                src_generated = self.G(src_data)

                # training D (on a detached fake so no gradient reaches G)
                for _ in range(self.d_step):
                    D_fake_loss, D_real_loss = self._discriminator_step(
                        src_generated.detach(), trg_data, BCE_loss, D_adam)

                # training G (D frozen so only G receives gradients)
                self._set_requires_grad(self.D, False)
                try:
                    for g_i in range(self.g_step):
                        if g_i > 0:
                            # G changed in the previous iteration: regenerate.
                            src_generated = self.G(src_data)
                        G_loss = self._generator_step(
                            src_generated, trg_data, BCE_loss, MSE_loss, G_adam)
                finally:
                    self._set_requires_grad(self.D, True)

                # logging losses
                if step % 20 == 0:
                    print(f"Epoch: {epoch} & Step: {step} => D-fake Loss: {D_fake_loss.item()}, D-real Loss: {D_real_loss.item()}, G Loss: {G_loss.item()}")

                # save sample images
                if step % 50 == 0:
                    vutils.save_image(src_data[0], os.path.join(self.sample_img_path, f'epoch-{epoch}-step-{step}-src_input.jpg'))
                    vutils.save_image(trg_data[0], os.path.join(self.sample_img_path, f'epoch-{epoch}-step-{step}-trg_input.jpg'))
                    vutils.save_image(src_generated.detach()[0], os.path.join(self.sample_img_path, f'epoch-{epoch}-step-{step}-generated.jpg'))

            # save model
            torch.save(self.D.state_dict(), os.path.join(self.save_model_path, str(epoch) + 'D' + '.pth'))
            torch.save(self.G.state_dict(), os.path.join(self.save_model_path, str(epoch) + 'G' + '.pth'))
                