In [None]:
import os
import glob
import random
import numpy as np

from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from torch.autograd import Variable
from torchvision.utils import save_image
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms, models, datasets
from torch.optim.lr_scheduler import StepLR

from pylab import imread

In [None]:
# Prepare output folders: ./output/<epoch>/ receives the sample grids written
# during training and ./model_weight/ is created for checkpoints. PNGs left in
# an epoch folder by a previous run are deleted so runs never mix.
# makedirs(exist_ok=True) replaces the exists()/mkdir() pairs (idempotent, no
# check-then-create race).
os.makedirs("./output", exist_ok=True)
os.makedirs("./model_weight", exist_ok=True)

for epoch in range(200):
    epoch_dir = "./output/%03d" % epoch
    os.makedirs(epoch_dir, exist_ok=True)
    for stale_png in glob.glob(os.path.join(epoch_dir, "*.png")):
        os.remove(stale_png)

# Data

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Derive the CUDA flag from the hardware instead of hard-coding True: later
# cells choose torch.cuda.FloatTensor / .cuda() based on it, which fails on a
# CPU-only machine.
cuda = torch.cuda.is_available()

# Pinned host memory only helps host-to-GPU copies; without a GPU it is
# pointless (recent PyTorch versions warn about it), so tie it to `cuda`.
kwargs = {'num_workers': 2, 'pin_memory': cuda}

image_size = 32   # both MNIST (28x28) and MNIST-M are resized to this
batchSize = 64

In [None]:
def show_img(source, target, source_label, target_label, num=10, num_col=5):
    """Show the first `num` source images in a grid above the first `num` target images.

    Args:
        source, target: batches of channel-first images indexed along dim 0;
            each image is permuted (C, H, W) -> (H, W, C) for imshow. Presumably
            torch tensors as yielded by this notebook's DataLoaders — confirm
            for other callers.
        source_label, target_label: per-image labels shown as panel titles.
        num: images shown per domain (default 10, the previously hard-coded
            count); capped at the size of the smaller batch.
        num_col: panels per grid row (default 5).

    Pixel values are mapped from [-1, 1] (the range this notebook's transforms
    produce) back to [0, 1]; imshow would otherwise clip the negative half.
    """
    num = min(num, len(source), len(target))
    rows_per_domain = max(1, (num + num_col - 1) // num_col)   # ceil division
    panels_per_domain = rows_per_domain * num_col
    num_row = 2 * rows_per_domain

    fig, axes = plt.subplots(num_row, num_col,
                             figsize=(1.5 * num_col, 2 * num_row), squeeze=False)
    domains = ((source, source_label), (target, target_label))
    for d, (images, labels) in enumerate(domains):
        for i in range(panels_per_domain):
            # Row-major flat index: source fills the top rows, target the bottom ones.
            ax = axes.flat[d * panels_per_domain + i]
            if i >= num:
                ax.axis('off')   # unused panel in a partially filled row
                continue
            image = ((images[i].detach().cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1)
            ax.imshow(image.numpy(), cmap='gray')
            ax.set_title('Label: {}'.format(int(labels[i])))
    plt.tight_layout()
    plt.show()

In [None]:
# Source domain: MNIST digits resized to image_size, replicated to 3 channels
# to match the RGB MNIST-M target, and scaled from [0, 1] to [-1, 1].
# Normalize(0.5, 0.5) computes (t - 0.5) / 0.5 == t * 2 - 1, identical to the
# former lambda, but unlike a lambda it can be pickled when DataLoader workers
# are started with the 'spawn' method.
transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.Grayscale(3),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

mnist_trainset = datasets.MNIST(root='./data/mnist', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data/mnist', train=False, download=True, transform=transform)


In [None]:
def get_backgrounds(root="./images/train"):
    """Load every JPEG directly under `root` with plt.imread and return them as a list.

    Files are read in sorted name order: os.listdir order is arbitrary, and a
    stable order keeps MNIST-M generation reproducible under a fixed seed. The
    extension check is case-insensitive and also accepts '.jpeg', so such
    files are no longer skipped silently.

    Raises:
        FileNotFoundError: if `root` contains no JPEG files. Previously an
            empty list was returned, and the failure only surfaced later
            inside compose_image.
    """
    names = sorted(name for name in os.listdir(root)
                   if name.lower().endswith(('.jpg', '.jpeg')))
    if not names:
        raise FileNotFoundError("no .jpg/.jpeg background images found in %r" % root)
    return [plt.imread(os.path.join(root, name)) for name in names]

def compose_image(image, backgrounds):
    """Blend a binarised digit into a random crop of a random background (MNIST-M style).

    Args:
        image: a grayscale digit that reshapes to 28 x 28; pixels > 0 are "ink".
        backgrounds: non-empty list of H x W x 3 arrays with H, W >= 28
            (assumed uint8 images such as those from get_backgrounds() — confirm).

    Returns:
        A 28 x 28 x 3 uint8 array equal to |crop - digit|: the crop unchanged
        off the digit, colour-inverted (255 - value) on it.
    """
    digit = (np.asarray(image) > 0).astype(np.float32).reshape(28, 28) * 255.0
    digit = np.stack([digit, digit, digit], axis=2)

    # Pick by index: np.random.choice requires a 1-D array and raises
    # ValueError on a list of image arrays (ragged or not on recent NumPy).
    background = backgrounds[np.random.randint(len(backgrounds))]
    bg_height, bg_width = background.shape[:2]
    digit_height, digit_width = digit.shape[:2]

    # randint's upper bound is exclusive; +1 makes the last valid offset
    # reachable and lets a background exactly 28 px on a side work at all.
    top = np.random.randint(0, bg_height - digit_height + 1)
    left = np.random.randint(0, bg_width - digit_width + 1)

    crop = background[top:top + digit_height, left:left + digit_width].astype(np.float32)
    return np.abs(crop - digit).astype(np.uint8)

class MNISTM(Dataset):
    """MNIST-M style dataset: MNIST digits blended into random background crops.

    Every sample is composed with compose_image (and passed through
    `transform`) once in __init__. The dataset is therefore fixed for its
    lifetime, and __getitem__ is a plain list lookup.

    Args:
        train: use the MNIST training split if truthy, otherwise the test split.
        transform: optional callable applied to each composed 28 x 28 x 3
            uint8 array.
        root: where MNIST is stored / downloaded. The default './data/mnist'
            is the directory the plain-MNIST cells use. The previous hard-coded
            '.data/mnist' (missing '/') downloaded a second copy into a hidden
            '.data' folder.
    """

    def __init__(self, train=True, transform=None, root='./data/mnist'):
        self.data = datasets.MNIST(root=root, train=bool(train), download=True)
        self.backgrounds = get_backgrounds()
        self.transform = transform
        self.images = []
        self.targets = []
        for index in range(len(self.data)):
            # Index once per sample (the original called __getitem__ twice).
            digit, target = self.data[index]
            image = compose_image(np.array(digit), self.backgrounds)
            if self.transform is not None:
                image = self.transform(image)
            self.images.append(image)
            self.targets.append(target)

    def __getitem__(self, index):
        return self.images[index], self.targets[index]

    def __len__(self):
        return len(self.images)

In [None]:
# Target domain: MNIST-M composites arrive as H x W x 3 uint8 arrays. They are
# converted to PIL, resized to image_size and scaled to [-1, 1] like MNIST.
# Normalize(0.5, 0.5) is numerically identical to the former `t * 2 - 1`
# lambda. Unlike a lambda it can be pickled; MNISTM keeps `transform` as an
# attribute, so the dataset gets pickled along with it when DataLoader workers
# are spawned.
transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

trainset = MNISTM(train=True, transform=transform)
testset = MNISTM(train=False, transform=transform)

In [None]:
# Training loaders shuffle and drop the final partial batch, so every training
# batch has exactly batchSize samples. Test loaders are not shuffled, so
# evaluation order is deterministic; source_test previously shuffled, unlike
# target_test.
source_train = DataLoader(mnist_trainset, batch_size=batchSize, shuffle=True, drop_last=True, **kwargs)
source_test = DataLoader(mnist_testset, batch_size=batchSize, shuffle=False, drop_last=True, **kwargs)

target_train = DataLoader(trainset, batch_size=batchSize, shuffle=True, drop_last=True, **kwargs)
target_test = DataLoader(testset, batch_size=batchSize, shuffle=False, drop_last=True, **kwargs)

# Module

In [None]:
class Mixer(nn.Module):
    """Shared encoder/decoder that reconstructs two images and decodes a fused latent.

    For a 3 x 32 x 32 input, the layer definitions give these shapes:
    - the encoder downsamples twice to 256 x 8 x 8;
    - `filter` maps the channel-concatenated pair of latents (512 x 8 x 8)
      back to 256 x 8 x 8 (ConvTranspose x2 up, MaxPool x2 down);
    - the decoder upsamples to 3 x 32 x 32 and ends in Tanh.

    Args:
        detach_latent: if True (default), both encodings are detached before
            fusion, so losses on the mixed image update only `filter` and
            `decoder`, never `encoder`. The original forward called
            `encode_A.detach()` without using the result, which did nothing.
            Pass False to reproduce that old, un-detached behaviour.
    """

    def __init__(self, detach_latent=True):
        super(Mixer, self).__init__()
        self.detach_latent = detach_latent

        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # Fuses the concatenated latents; spatial size is preserved (x2 then /2).
        self.filter = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(256),
            nn.MaxPool2d(2),
            nn.ReLU(inplace=True)
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7),
            # NOTE(review): normalising the 3 output channels to zero mean / unit
            # variance per image before Tanh limits what the decoder can
            # reproduce (e.g. the mostly -1 MNIST backgrounds) — confirm intended.
            nn.InstanceNorm2d(3),
            nn.Tanh()
        )

    def forward(self, A, B):
        """Return (mixed image, reconstruction of A, reconstruction of B)."""
        encode_A = self.encoder(A)
        encode_B = self.encoder(B)

        reconA = self.decoder(encode_A)
        reconB = self.decoder(encode_B)

        if self.detach_latent:
            # Tensor.detach() returns a new tensor; it must be re-bound to take effect.
            encode_A = encode_A.detach()
            encode_B = encode_B.detach()

        mixed_latent = torch.cat([encode_A, encode_B], dim=1)
        mixed_image = self.filter(mixed_latent)
        new_image = self.decoder(mixed_image)

        return new_image, reconA, reconB

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that outputs one score per image.

    Four convolutions produce a 1-channel score map. LeakyReLU(0.2) follows
    the first three, and InstanceNorm sits on the middle two. forward()
    averages the map over its full spatial extent, giving an (N, 1) tensor.
    """

    def __init__(self, input_nc = 3):
        super(Discriminator, self).__init__()
        layers = [nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                  nn.LeakyReLU(0.2, inplace=True)]
        for in_ch, out_ch in ((64, 128), (128, 256)):
            layers += [nn.Conv2d(in_ch, out_ch, 4, padding=1),
                       nn.InstanceNorm2d(out_ch),
                       nn.LeakyReLU(0.2, inplace=True)]
        layers.append(nn.Conv2d(256, 1, 4, padding=1))
        # One Sequential in the same order keeps parameter names (model.0, model.2, ...).
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        score_map = self.model(x)
        batch = score_map.size()[0]
        # Global average pooling: the kernel covers the whole score map.
        pooled = F.avg_pool2d(score_map, score_map.size()[2:])
        return pooled.view(batch, -1)

In [None]:
# Smoke test: take one batch from each domain and check the shapes produced by
# the Mixer's encoder, decoder and full forward pass, plus the output range.
# It uses a throw-away, untrained Mixer on CPU.
# next(it) replaces it.next(); DataLoader iterators in recent PyTorch no
# longer provide the .next() alias.
source_iter = iter(source_train)
source_inputs, source_label = next(source_iter)

target_iter = iter(target_train)
target_inputs, target_label = next(target_iter)

test_tensor_source = source_inputs
test_tensor_target = target_inputs

perceptual = Mixer()
encoder = perceptual.encoder
decoder = perceptual.decoder

# Nothing is trained here, so skip building autograd graphs.
with torch.no_grad():
    encoder_out = encoder(test_tensor_target)
    decoder_out = decoder(encoder_out)

print(encoder_out.size())
print(decoder_out.size())

with torch.no_grad():
    mixed, reconA, reconB = perceptual(test_tensor_source, test_tensor_target)
print(mixed.size())
print(reconA.size())
print(reconB.size())

print("mixed: min: %.2f, max: %.2f " % (torch.min(mixed), torch.max(mixed)))
print("reconA: min: %.2f, max: %.2f " % (torch.min(reconA), torch.max(reconA)))
print("reconB: min: %.2f, max: %.2f " % (torch.min(reconB), torch.max(reconB)))

# Loss

In [None]:
def tv_loss(img, tv_weight=5e-2):
    """Un-normalised total-variation penalty for a 4-D (N, C, H, W) tensor.

    Squares the differences between neighbouring pixels along the width
    (last) and height (third) axes, sums everything, and scales by tv_weight.
    """
    dx = img[:, :, :, 1:] - img[:, :, :, :-1]
    dy = img[:, :, 1:, :] - img[:, :, :-1, :]
    return tv_weight * ((dy ** 2).sum() + (dx ** 2).sum())

def total_variation_loss(img, weight=5e-2):
    """Total-variation penalty averaged over every element of an (N, C, H, W) batch.

    Sums the squared differences between vertically and horizontally adjacent
    pixels, multiplies by `weight`, and divides by N*C*H*W so the value does
    not grow with batch or image size.
    """
    n, c, h, w = img.size()
    vertical = (img[:, :, 1:, :] - img[:, :, :-1, :]).pow(2).sum()
    horizontal = (img[:, :, :, 1:] - img[:, :, :, :-1]).pow(2).sum()
    return weight * (vertical + horizontal) / (n * c * h * w)

# Training

In [None]:
# Fresh networks for training, separate from the throw-away `perceptual`
# model used in the shape smoke test. Built left to right: Mixer first.
mixer, discriminator = Mixer(), Discriminator()

In [None]:
# Adam with betas=(0.5, 0.999) is the DCGAN/CycleGAN setup, which uses a
# learning rate of 2e-4. The previous 0.05 was ~250x larger, far outside the
# usual range for adversarial training with Adam.
learning_rate = 2e-4
beta = (0.5, 0.999)

criterion_adv = torch.nn.MSELoss()            # least-squares adversarial loss (mixer side)
criterion_discriminator = torch.nn.MSELoss()  # least-squares loss (discriminator side)
criterion_construct = torch.nn.L1Loss()       # reconstruction loss

optimizer_mix = torch.optim.Adam(mixer.parameters(), lr=learning_rate, betas=beta)
optimizer_dis = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=beta)

# training() steps these once per epoch: LR x0.4 every 10 epochs.
# NOTE(review): over the 200-epoch loop that is 0.4**20 (~1e-8) — learning
# effectively stops after a few dozen epochs; confirm this schedule is intended.
scheduler_mix = StepLR(optimizer_mix, step_size=10, gamma=0.4)
scheduler_dis = StepLR(optimizer_dis, step_size=10, gamma=0.4)

In [None]:
# Seed torch's RNGs (torch.manual_seed covers CPU and all CUDA devices) so the
# weight initialisation below and DataLoader shuffling are reproducible.
torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    # Inputs have a fixed size, so let cuDNN benchmark and cache the fastest conv algorithms.
    cudnn.benchmark = True

    # The loss modules hold no parameters, but move all three for consistency:
    # criterion_adv was skipped before because criterion_construct was listed twice.
    criterion_adv = criterion_adv.cuda()
    criterion_construct = criterion_construct.cuda()
    criterion_discriminator = criterion_discriminator.cuda()

    mixer.cuda()
    discriminator.cuda()

In [None]:
def weights_init_normal(m):
    """Initialise one module in place; meant for use with nn.Module.apply.

    - Any module whose class name contains 'Conv' (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02); the bias is left untouched.
    - BatchNorm2d: weight ~ N(1, 0.02), bias = 0.
    - Everything else, including the affine-free InstanceNorm2d layers used in
      this notebook, is left unchanged.

    Parameters that are None (e.g. BatchNorm2d(affine=False)) are skipped
    instead of raising AttributeError.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            # constant_ is the in-place API; the un-suffixed `constant` is deprecated.
            torch.nn.init.constant_(bias.data, 0.0)

In [None]:
# Re-initialise both networks in place, mixer first and then discriminator.
for net in (mixer, discriminator):
    net.apply(weights_init_normal)

In [None]:
# Fixed targets for the least-squares GAN losses: 1 = "real", 0 = "fake",
# shaped (batchSize, 1) to match the Discriminator's per-image output.
# Plain tensors replace the deprecated torch.autograd.Variable wrapper. They
# are allocated on `device` rather than through torch.cuda.FloatTensor, which
# fails on machines without a GPU.
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor   # kept for compatibility
real_label = torch.ones(batchSize, 1, device=device)
fake_label = torch.zeros(batchSize, 1, device=device)

# Preallocated input buffers, kept for code that copies batches into them.
input_s = torch.empty(batchSize, 3, image_size, image_size, device=device)
input_t = torch.empty(batchSize, 3, image_size, image_size, device=device)

In [None]:
# The discriminator treats source images as "real", so the mixer is pushed to
# make mixed images look like the source domain (same label as the source).
def _save_samples(epoch, step, batches, out_dir="./output"):
    """Save each named batch as <out_dir>/<epoch:03d>/<step>_<name>.png.

    The batches are in [-1, 1] (tanh outputs, and inputs scaled by this
    notebook's transforms). save_image clips values outside [0, 1] when it
    converts to 8-bit, so each batch is mapped with (x + 1) / 2 first.
    """
    for name, batch in batches.items():
        path = os.path.join(out_dir, "%03d" % epoch, "%d_%s.png" % (step, name))
        save_image((batch.detach() + 1) / 2, path)


def training(source, target, mixer, discriminator,
             critic_adv, cirtic_recon, cirtic_dis,
             optim_mix, optim_dis,
             sche_mix, sche_dis,
             use_cuda = True, epoch=None, sample_every=400):
    """Run one epoch of alternating mixer / discriminator updates.

    Mixer loss has four terms:
    - L1 reconstruction of both inputs, weight 5 each;
    - an adversarial term pushing D(mixed) towards 1 ("real");
    - a total-variation smoothness term on the mixed image.
    The discriminator then learns source images = 1 and detached mixed
    images = 0. Both LR schedulers are stepped once at the end.

    Args:
        source, target: iterables yielding (images, labels) batches; the epoch
            stops at the shorter of the two.
        mixer, discriminator: the networks being trained.
        critic_adv: loss between D(mixed) and an all-ones target.
        cirtic_recon: reconstruction loss (parameter name kept for callers).
        cirtic_dis: discriminator loss (parameter name kept for callers).
        optim_mix, optim_dis: optimizers for the two networks.
        sche_mix, sche_dis: LR schedulers, stepped once per call.
        use_cuda: accepted for backward compatibility. Batches are always moved
            to the device of `mixer`'s parameters, which is where they must be.
        epoch: index used in the log line and the sample folder
            ./output/<epoch>/. Defaults to the notebook-level `epoch` loop
            variable, which the original implementation read implicitly.
        sample_every: save sample grids every this many iterations.
    """
    if epoch is None:
        epoch = globals()["epoch"]
    device = next(mixer.parameters()).device

    steps = 0
    for i, ((s_img, _), (t_img, _)) in enumerate(zip(source, target)):
        org_s = s_img.to(device)
        org_t = t_img.to(device)

        # ---- Mixer update ----
        optim_mix.zero_grad()
        mixed, recon_s, recon_t = mixer(org_s, org_t)
        mixed_label = discriminator(mixed)

        loss_ss = cirtic_recon(recon_s, org_s) * 5.0
        loss_tt = cirtic_recon(recon_t, org_t) * 5.0

        # Lower when the discriminator scores the mixed image as real (1).
        loss_adv = critic_adv(mixed_label, torch.ones_like(mixed_label))
        TV_loss = total_variation_loss(mixed)

        mixer_loss = loss_ss + loss_tt + loss_adv + TV_loss
        mixer_loss.backward()
        optim_mix.step()

        # ---- Discriminator update ----
        # zero_grad also discards the gradients mixer_loss.backward() left on D.
        optim_dis.zero_grad()

        pred_real = discriminator(org_s)
        pred_fake = discriminator(mixed.detach())

        loss_dis_real = cirtic_dis(pred_real, torch.ones_like(pred_real))
        loss_dis_fake = cirtic_dis(pred_fake, torch.zeros_like(pred_fake))

        discriminator_loss = loss_dis_real + loss_dis_fake
        discriminator_loss.backward()
        optim_dis.step()

        if i % sample_every == 0:
            _save_samples(epoch, i, {"A": org_s, "B": org_t,
                                     "reconA": recon_s, "reconB": recon_t,
                                     "Mixed": mixed})
        steps += 1

    print("e: %d" % epoch)
    if steps == 0:
        print("no batches processed: source or target loader is empty")
    else:
        print("mixer_loss: %.2f, loss_ss: %.2f, loss_tt: %.2f, loss_adv: %.2f"
              % (mixer_loss.item(), loss_ss.item(), loss_tt.item(), loss_adv.item()))
        print("discriminator_loss: %.2f, D_real: %.2f, D_fake: %.2f"
              % (discriminator_loss.item(), loss_dis_real.item(), loss_dis_fake.item()))

    sche_mix.step()
    sche_dis.step()

In [None]:
# 200 epochs, matching the ./output/<epoch> folders created at the top.
# `epoch` stays a module-level name because training() uses it for the
# per-epoch sample folder and log line.
for epoch in range(200):
    training(source_train, target_train, mixer, discriminator,
             criterion_adv, criterion_construct, criterion_discriminator,
             optimizer_mix, optimizer_dis,
             scheduler_mix, scheduler_dis,
             use_cuda=torch.cuda.is_available())   # was hard-coded True: crashed without a GPU