In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch

from torch.autograd import Variable
import torch.backends.cudnn as cudnn

from torchvision.utils import save_image

import numpy as np

import matplotlib.pyplot as plt
import os
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms, models, datasets

In [2]:
import glob

In [3]:
os.path.exists("./output/000")

True

In [4]:
for epoch in range (200):
  if (os.path.exists("./output/000")) == False:
    os.mkdir("./output/%03d" % epoch)
  else:
    files = glob.glob("./output/%03d/*.png" % epoch)

    for f in files:
      os.remove(f)

# Data

In [5]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def get_backgrounds():
    backgrounds = []
    for file in os.listdir("./images/train"):
        if file.endswith('.jpg'):
            backgrounds.append(plt.imread(os.path.join("./images/train",file)))
    return np.array(backgrounds)
backgrounds = get_backgrounds()


def compose_image(image):
    image = (image > 0).astype(np.float32)
    image = image.reshape([28,28])*255.0
    
    image = np.stack([image,image,image],axis=2)
    
    background = np.random.choice(backgrounds)
    w,h,_ = background.shape
    dw, dh,_ = image.shape
    x = np.random.randint(0,w-dw)
    y = np.random.randint(0,h-dh)
    
    temp = background[x:x+dw, y:y+dh]
    return np.abs(temp-image).astype(np.uint8)


class MNISTM(Dataset):
            
    def __init__(self, train=True,transform=None):
        if train:
            self.data = datasets.MNIST(root='.data/mnist',train=True, download=True)
        else:
            self.data = datasets.MNIST(root='.data/mnist',train=False, download=True)
        self.backgrounds = get_backgrounds()
        self.transform = transform
        self.images = []
        self.targets = []
        for index in range(len(self.data)):
            image = np.array(self.data.__getitem__(index)[0])
            target = self.data.__getitem__(index)[1]
            image = compose_image(image)
            if self.transform is not None:
                image = self.transform(image)
            self.images.append(image)
            self.targets.append(target)
        
    def __getitem__(self,index):
        
        #image = Image.fromarray(image.squeeze(), mode="RGB")
        image = self.images[index]
        target = self.targets[index]
        
        return image, target
        
    def __len__(self):
        return len(self.data)
    
def get_mnistm_loaders(data_aug = False, batch_size=128,test_batch_size=1000):
    if data_aug:
        train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(128),
            transforms.RandomCrop(128,padding=4),
            transforms.ToTensor()
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(128),
            transforms.ToTensor()
        ])
    test_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(128),
        transforms.ToTensor()
    ])
    
    kwargs = {}
    train_loader = DataLoader(
        MNISTM(train=True,transform=train_transform),batch_size=batch_size, shuffle=True, drop_last=True)
    train_eval_loader = DataLoader(
        MNISTM(train=True, transform=test_transform),batch_size=test_batch_size, shuffle=False, drop_last=True)
    test_loader = DataLoader(
        MNISTM(train=False,transform=test_transform),batch_size=test_batch_size, shuffle=False, drop_last=True)
    return train_loader, train_eval_loader, test_loader


def get_mnist_loaders(data_aug = False, batch_size=128,test_batch_size=1000):
    if data_aug:
        train_transform = transforms.Compose(
            [transforms.Resize(128),
            transforms.RandomCrop(128,padding=4),
            transforms.Grayscale(3),
            transforms.ToTensor()
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize(128),
            transforms.Grayscale(3),
            transforms.ToTensor()
        ])
    test_transform = transforms.Compose([
        transforms.Resize(128),
        transforms.Grayscale(3),
        transforms.ToTensor()
    ])
    kwargs = {}

    train_loader = DataLoader(
        datasets.MNIST(root='.data/mnist',train=True, download=True,transform=train_transform),batch_size=batch_size, shuffle=True, drop_last=True)
    train_eval_loader = DataLoader(
        datasets.MNIST(root='.data/mnist',train=True, download=True, transform=test_transform),batch_size=test_batch_size, shuffle=False, drop_last=True)
    test_loader = DataLoader(
        datasets.MNIST(root='.data/mnist',train=False, download=True, transform=test_transform),batch_size=test_batch_size, shuffle=False, drop_last=True)
    return train_loader, train_eval_loader, test_loader

  return np.array(backgrounds)


In [6]:
loader_source, mnist_eval_loader, mnist_test_loader = get_mnist_loaders(batch_size=16)
loader_target, mnistm_eval_loader,mnistm_test_loader = get_mnistm_loaders(batch_size=16)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to .data/mnist\MNIST\raw\train-images-idx3-ubyte.gz


9913344it [00:00, 12063601.29it/s]                             


Extracting .data/mnist\MNIST\raw\train-images-idx3-ubyte.gz to .data/mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to .data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz


29696it [00:00, 29776249.48it/s]         

Extracting .data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz to .data/mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to .data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz



1649664it [00:00, 10337953.55it/s]                            


Extracting .data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz to .data/mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to .data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz


5120it [00:00, 5130156.83it/s]          


Extracting .data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz to .data/mnist\MNIST\raw



  return np.array(backgrounds)


In [7]:
data_target_iter = iter(loader_source)
target_inputs, target_label = data_target_iter.next()

In [8]:
target_inputs.size()

torch.Size([16, 3, 64, 64])

# Module

## Encoder

In [9]:
class Encoder(nn.Module):
    def __init__(self, input_nc=3):
        super(Encoder, self).__init__()

        # Initial convolution block
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, 64, 7),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(inplace=True) ]

        model += [nn.Conv2d(64, 128, 3, stride=2, padding=1),
                  nn.InstanceNorm2d(128),
                  nn.ReLU(inplace=True)]
                  
        model += [nn.Conv2d(128, 256, 3, stride=2, padding=1),
                  nn.InstanceNorm2d(256),
                  nn.ReLU(inplace=True)]

        model += [nn.Conv2d(256, 256, 3, stride=2, padding=1),
                  nn.InstanceNorm2d(256),
                  nn.ReLU(inplace=True)]

        model += [nn.Conv2d(256, 512, 3, stride=2, padding=1),
                  nn.InstanceNorm2d(512),
                  nn.ReLU(inplace=True)]

        model += [nn.Conv2d(512, 1024, 3, stride=2, padding=1),
                  nn.InstanceNorm2d(1024),
                  nn.ReLU(inplace=True)]



        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


## Decoder

In [10]:
class Decoder(nn.Module):
    def __init__(self, input_nc=1024, output_nc=3):
        super(Decoder, self).__init__()
        model = [nn.ConvTranspose2d(input_nc, 512, 3, stride=2, padding=1, output_padding=1),
                  nn.InstanceNorm2d(512),
                  nn.ReLU(inplace=True)]

        model += [nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1),
                  nn.InstanceNorm2d(256),
                  nn.ReLU(inplace=True)]
        
        model += [nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
                  nn.InstanceNorm2d(128),
                  nn.ReLU(inplace=True)]

        model += [nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
                  nn.InstanceNorm2d(64),
                  nn.ReLU(inplace=True)]

        model += [nn.ConvTranspose2d(64, 64, 3, stride=2, padding=1, output_padding=1),
                  nn.InstanceNorm2d(64),
                  nn.ReLU(inplace=True)]
                  
        
        # Output layer
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(64, output_nc, 7),
                    nn.InstanceNorm2d(3),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

## Identity Generator

In [11]:
class Identity_Generator(nn.Module):
    def __init__(self, encoder, decoder):
        super(Identity_Generator, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, A, B):
        latentA = self.encoder(A)
        latentB = self.encoder(B)

        reconstructedA = self.decoder(latentA)
        reconstructedB = self.decoder(latentB)
        return reconstructedA, reconstructedB

## Perceptual

In [12]:
class Perceptual(nn.Module):
    def __init__(self, encoder, decoder, generator):
        super(Perceptual, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator

    def forward(self, A, B):

        reconstructedA, reconstructedB = self.generator(A, B)

        latentA = self.encoder(A)
        latentB = self.encoder(B)

        latentA.detach()
        latentB.detach()

        style = latentA[:, 0:512, : , :]
        content = latentB[:, 512:1024, :, :]
        
        mixed_latent = torch.cat([style, content], dim=1)
        mixed_image = self.decoder(mixed_latent)

        return mixed_image, reconstructedA, reconstructedB

## Discriminator

In [13]:
class Discriminator(nn.Module):
    def __init__(self, input_nc = 3):
        super(Discriminator, self).__init__()

        model = [   nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(64, 128, 4, padding=1),
                    nn.InstanceNorm2d(128), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(256, 512, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(512, 1, 4, padding=1)]


        self.model = nn.Sequential(*model)

    def forward(self, x):
        x =  self.model(x)
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

## Model test

In [14]:
test_tensor_source = torch.rand(20, 3, 64, 64)
test_tensor_target = torch.rand(20, 3, 64, 64)

encoder = Encoder()
decoder = Decoder()
generator = Identity_Generator(encoder, decoder)
perceptual = Perceptual(encoder, decoder, generator)
discriminator = Discriminator()

In [15]:
output = encoder(test_tensor_source)
print(output.size())

torch.Size([20, 1024, 2, 2])


In [16]:
output = decoder(output)
print(output.size())

torch.Size([20, 3, 64, 64])


In [17]:
output_a, output_b = generator(test_tensor_source, test_tensor_target)
print(output_a.size())
print(output_b.size())

torch.Size([20, 3, 64, 64])
torch.Size([20, 3, 64, 64])


In [18]:
mix, output_a, output_b = perceptual(test_tensor_source, test_tensor_target)

print(mix.size())
print(output_a.size())
print(output_b.size())

torch.Size([20, 3, 64, 64])
torch.Size([20, 3, 64, 64])
torch.Size([20, 3, 64, 64])


In [19]:
output = discriminator(test_tensor_source)
print(output.size())

torch.Size([20, 1])


In [20]:
print(mix[0][0][0] - output_a[0][0][0])

tensor([-0.1822, -0.0187, -0.0880,  0.1730,  0.0934,  0.1873, -0.0700,  0.0873,
        -0.6610,  0.1256, -0.5339,  0.5257,  0.0075, -0.2972,  0.6022,  0.0140,
        -0.4612, -0.3736, -0.1491,  0.1655,  0.2377, -0.6492,  0.4178, -1.1898,
         0.4611,  0.0185, -0.0891, -0.6819, -0.4517, -0.2797,  0.4677, -0.1813,
        -1.2597, -0.0622,  0.2252, -0.7779,  0.0116,  0.3160, -0.3435,  0.1735,
        -0.4187, -0.2169,  0.4381, -1.7511, -0.4523,  0.0623, -0.3610,  0.5910,
         0.7784, -0.1097,  0.1520, -0.0261, -0.2544,  0.6162,  0.1289,  0.2401,
         0.1021,  0.5830,  0.1024, -0.0797, -0.4688, -0.0559,  0.5697,  0.3826],
       grad_fn=<SubBackward0>)


# Loss

In [21]:
def tv_loss(img, tv_weight=5e-2):
    w_variance = torch.sum(torch.pow(img[:,:,:,:-1] - img[:,:,:,1:], 2))
    h_variance = torch.sum(torch.pow(img[:,:,:-1,:] - img[:,:,1:,:], 2))
    loss = tv_weight * (h_variance + w_variance)
    return loss

def total_variation_loss(img, weight=5e-2):
    bs_img, c_img, h_img, w_img = img.size()
    tv_h = torch.pow(img[:, :, 1:, :] - img[:, :, :-1, :], 2).sum()
    tv_w = torch.pow(img[:, :, :, 1:] - img[:, :, :, :-1], 2).sum()
    return weight * (tv_h + tv_w) / (bs_img * c_img * h_img * w_img)

def compute_content_loss(target_feature, content_feature):
    return torch.mean((target_feature - content_feature)**2)

def batch_gram_matrix(img):
    """
    Compute the gram matrix by converting to 2D tensor and doing dot product
    img: (batch, channel/depth, height, width)
    """
    b, d, h, w = img.size()
    img = img.view(b*d, h*w) # fix the dimension. It doesn't make sense to put b=1 when it's not always the case
    gram = torch.mm(img, img.t())
    return gram

style_weights = {'conv1_1': 1.,
                 'conv2_1': 0.75,
                 'conv3_1': 0.2,
                 'conv4_1': 0.2,
                 'conv5_1': 0.2}

def compute_style_loss(style_features, target_features):
    # the style loss
    # initialize the style loss to 0
    style_loss = 0
    style_grams = {layer: batch_gram_matrix(style_features[layer]) for layer in style_features}
    # then add to it for each layer's gram matrix loss
    for layer in style_weights:
        # get the "target" style representation for the layer
        target_feature = target_features[layer]
        target_gram = batch_gram_matrix(target_feature)
        _, d, h, w = target_feature.shape
        # get the "style" style representation
        style_gram = style_grams[layer]
        # the style loss for one layer, weighted appropriately
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram) ** 2)
        # add to the style loss
        style_loss += layer_style_loss / (d * h * w)

    return style_loss

# Training

## Training Method

In [22]:
def get_features(image, model, layers=None):
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',  ## content representation
                  '28': 'conv5_1'}

    features_arr = []
    features = {}
    x = image
    # model._modules is a dictionary holding each module in the model
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x

    return features

In [23]:
def train_with_autoencoder(source, target, preceptual, discriminator, criterion_adv, criterion_discriminator, criterion_construct, opt_pre, opt_dis, use_cuda, vgg, epoch):
    data_target_iter = iter(target)
    total = 0

    for batch_idx, (source_inputs, source_label) in enumerate(source):
        batch_size = source_inputs.size(0)
        total += batch_size
        target_inputs, target_label = data_target_iter.next()

        if use_cuda: 
            source_inputs, target_inputs = source_inputs.cuda(), target_inputs.cuda()
            real_label, fake_label = real_label.cuda(), fake_label.cuda()

        # preceptual
        mixed_image, reconstruct_source, reconstruct_target = preceptual(source_inputs, target_inputs)

        loss_ss = criterion_construct(source_inputs, reconstruct_source) * 30
        loss_tt = criterion_construct(target_inputs, reconstruct_target) * 30

        TV_loss = total_variation_loss(mixed_image)
        pred_fake = discriminator(mixed_image)
        loss_adv = criterion_adv(pred_fake, real_label)

        cuda_mixed_image = mixed_image.clone().requires_grad_(True).cuda()
        cuda_real_A = source_inputs.clone().requires_grad_(True).cuda()
        cuda_real_B = target_inputs.clone().requires_grad_(True).cuda()
        style_features = get_features(cuda_real_A, vgg)
        content_features = get_features(cuda_real_B, vgg)
        target_features = get_features(cuda_mixed_image, vgg)

        content_loss = compute_content_loss(target_features['conv4_2'], content_features['conv4_2']) * 0.1
        style_loss = compute_style_loss(style_features, target_features) * 0.05

        preceptual_loss = loss_ss + loss_tt + TV_loss + loss_adv + content_loss + style_loss

        opt_pre.zero_grad()
        preceptual_loss.backward()
        opt_pre.step()

        # Update discriminator
        pred_real = discriminator(source_inputs)
        pred_fake = discriminator(mixed_image.detach())

        loss_dis_real = criterion_discriminator(pred_real, real_label)
        loss_dis_fake = criterion_discriminator(pred_real, fake_label)

        discriminator_loss = loss_dis_real + loss_dis_fake

        opt_dis.zero_grad()
        discriminator_loss.backward()
        opt_dis.step()
    
        if  batch_idx % 200 == 0:
            real_A = source_inputs.data
            real_B = target_inputs.data
            mixed_image = mixed_image.data
            reconstructionA = reconstruct_source.data
            reconstructionB = reconstruct_target.data

            save_image(real_A, './output/%03d/00_%d_A.png' % ( epoch, batch_idx))
            save_image(real_B, 'output/%03d/00_%d_B.png' % ( epoch, batch_idx))
            save_image(reconstructionA, 'output/%03d/00_%d_reconA.png' % ( epoch, batch_idx))
            save_image(reconstructionB, 'output/%03d/00_%d_reconB.png' % ( epoch, batch_idx))
            save_image(mixed_image, 'output/%03d/00_%d_Mixed.png' % ( epoch, batch_idx))

    print ("e: %d, pre_loss: %.2f, D_real: %.2f, D_fake: %.2f" % (epoch, preceptual_loss, loss_dis_real, loss_dis_fake))


## Model hypermeter

In [24]:
class LambdaLR():
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)

In [25]:
encoder = Encoder(input_nc = 3)
decoder = Decoder()
generator = Identity_Generator(encoder, decoder)
perceptual = Perceptual(encoder, decoder, generator)
discriminator = Discriminator()

In [26]:
learning_rate = 0.0002
beta = (0.5, 0.999)

criterion_adv = torch.nn.MSELoss()
criterion_discriminator = torch.nn.MSELoss()
criterion_construct = torch.nn.L1Loss()

optimizer_pre = torch.optim.Adam(perceptual.parameters(), lr=learning_rate, betas=beta)
optimizer_dis = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=beta)

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_pre, lr_lambda=LambdaLR(200, 0, 100).step)
lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_dis, lr_lambda=LambdaLR(200, 0, 100).step)


In [27]:
vgg = models.vgg19(pretrained=True).features
for param in vgg.parameters():
    param.requires_grad_(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [28]:
if (torch.cuda.is_available()):
    torch.cuda.manual_seed_all(42)
    cudnn.benchmark = True

    criterion_construct = criterion_construct.cuda()
    criterion_discriminator = criterion_discriminator.cuda()
    criterion_construct = criterion_construct.cuda()
    
    encoder.cuda()
    decoder.cuda()
    generator.cuda()
    perceptual.cuda()
    discriminator.cuda()

## Init the network

In [29]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant(m.bias.data, 0.0)

In [30]:
encoder.apply(weights_init_normal)
decoder.apply(weights_init_normal)
generator.apply(weights_init_normal)
perceptual.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

## Training

In [31]:
batchSize = 16
cuda = True

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

real_label = Variable(Tensor(batchSize).fill_(1.0), requires_grad=False)
fake_label = Variable(Tensor(batchSize).fill_(0.0), requires_grad=False)
real_label = real_label[:, None]
fake_label = fake_label[:, None]


input_A = Tensor(batchSize, 3, 64, 64)
input_B = Tensor(batchSize, 3, 64, 64)

for epoch in range(0, 200):
    i = -1
    for batchA, batchB in zip(loader_target, loader_source):
        i += 1
        real_A = Variable(input_A.copy_(batchA[0]))
        real_B = Variable(input_B.copy_(batchB[0]))

        optimizer_pre.zero_grad()

        mixed_image, reconstructionA, reconstructionB = perceptual(real_A, real_B)

        loss_ss = criterion_construct(reconstructionA, real_A)*30.0
        loss_tt = criterion_construct(reconstructionB, real_B)*30.0

        TV_loss = total_variation_loss(mixed_image)
        pred_fake = discriminator(mixed_image)
        loss_adv = criterion_adv(pred_fake, real_label)

        cuda_mixed_image = mixed_image.clone().requires_grad_(True).cuda()
        cuda_real_A = real_A.clone().requires_grad_(True).cuda()
        cuda_real_B = real_B.clone().requires_grad_(True).cuda()
        style_features = get_features(cuda_real_A, vgg)
        content_features = get_features(cuda_real_B, vgg)
        target_features = get_features(cuda_mixed_image, vgg)

        content_loss = compute_content_loss(target_features['conv4_2'], content_features['conv4_2']) * 0.1
        style_loss = compute_style_loss(style_features, target_features) * 0.05

        preceptual_loss = loss_ss + loss_tt + TV_loss + loss_adv + content_loss + style_loss

        preceptual_loss.backward()
        optimizer_pre.step()

        optimizer_dis.zero_grad()

        pred_real = discriminator(real_A)
        pred_fake = discriminator(mixed_image.detach())

        loss_dis_real = criterion_discriminator(pred_real, real_label)
        loss_dis_fake = criterion_discriminator(pred_real, fake_label)

        discriminator_loss = loss_dis_real + loss_dis_fake
        discriminator_loss.backward()
        optimizer_dis.step()

        if  i % 200 == 0:
            real_A = real_A.data
            real_B = real_B.data
            mixed_image = mixed_image.data 
            reconstructionA = reconstructionA.data
            reconstructionB = reconstructionB.data

            save_image(real_A, './output/%03d/00_%d_A.png' % ( epoch, i))
            save_image(real_B, 'output/%03d/00_%d_B.png' % ( epoch, i))
            save_image(reconstructionA, 'output/%03d/00_%d_reconA.png' % ( epoch, i))
            save_image(reconstructionB, 'output/%03d/00_%d_reconB.png' % ( epoch, i))
            save_image(mixed_image, 'output/%03d/00_%d_Mixed.png' % ( epoch, i))

    lr_scheduler_G.step()
    lr_scheduler_D.step()
    print ("e: %d, pre_loss: %.2f, D_real: %.2f, D_fake: %.2f" % (epoch, preceptual_loss, loss_dis_real, loss_dis_fake))
    

KeyboardInterrupt: 