In [None]:
import os
from os import listdir
from os.path import join
import numpy as np
import random
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms
import torchvision.datasets as datasets

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [None]:
# Extensions treated as images when scanning a directory.
# NOTE(review): the original cell called an `is_image_file` helper that is not
# defined in this notebook; this list is an assumption - extend it if needed.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _is_image_file(filename):
    """Return True when ``filename`` ends with a known image extension (case-insensitive)."""
    return filename.lower().endswith(_IMAGE_EXTENSIONS)


class data_from_dir(torch.utils.data.Dataset):
    """Dataset that yields every image file found directly inside a directory.

    Args:
        image_dir: Directory to scan (not recursive).
        transform: Optional callable applied to each loaded RGB PIL image.
            When ``None`` the PIL image itself is returned.
    """

    def __init__(self, image_dir, transform=None):
        super(data_from_dir, self).__init__()
        self.image_dir = image_dir
        # listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs and machines.
        self.image_filenames = sorted(
            name for name in listdir(image_dir) if _is_image_file(name))
        self.transform = transform

    def __getitem__(self, index):
        """Load image ``index`` as RGB and return it, transformed if a transform was given."""
        # Imported here so the class can be defined and its length queried
        # without pulling in the image backend; pil_loader opens the file,
        # converts it to RGB and closes the file handle.
        from torchvision.datasets.folder import pil_loader

        image = pil_loader(join(self.image_dir, self.image_filenames[index]))
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Number of image files discovered in ``image_dir``."""
        return len(self.image_filenames)

In [None]:
# --- Hyperparameters: everything a reader might want to tune lives here ---
batch_size = 16               # images per optimisation step
image_size = 125              # side length (pixels) of the low-resolution generator input
image_channels = 3            # colour channels of the image tensors
n_conv_blocks = 2             # number of conv_block stages inside the generator
up_sampling = 2               # multiplied with image_size to size the high-resolution images
# Later cells refer to this value as `n_upsampling`; defining both names keeps
# the notebook working under Restart & Run All.
n_upsampling = up_sampling
n_epochs = 100                # full passes over the training set
learning_rate_G = 0.00001     # Adam learning rate for the generator
learning_rate_D = 0.0000001   # Adam learning rate for the discriminator

In [None]:
class discriminator_model(nn.Module):
    """Convolutional real/fake classifier for RGB images.

    Eight 3x3 convolutions (every second one has stride 2, so the spatial size
    is halved four times) widen the feature maps from 64 to 512 channels. The
    result is pooled to 2x2, flattened to 2048 features and passed through two
    fully connected layers.

    forward() takes a ``(N, 3, H, W)`` tensor and returns ``(N, 1)`` values in
    ``[0, 1]`` suitable for ``nn.BCELoss``.
    """

    def __init__(self):
        super(discriminator_model, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)

        self.conv2 = nn.Conv2d(64, 64, 3, stride=2, padding=1)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv3_bn = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        self.conv4_bn = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.conv5_bn = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, 3, stride=2, padding=1)
        self.conv6_bn = nn.BatchNorm2d(256)
        self.conv7 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.conv7_bn = nn.BatchNorm2d(512)
        self.conv8 = nn.Conv2d(512, 512, 3, stride=2, padding=1)
        self.conv8_bn = nn.BatchNorm2d(512)
        # 2048 = 512 channels * 2 * 2 positions (see the pooling in forward()).
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x = F.relu(self.conv4_bn(self.conv4(x)))
        x = F.relu(self.conv5_bn(self.conv5(x)))
        x = F.relu(self.conv6_bn(self.conv6(x)))
        x = F.relu(self.conv7_bn(self.conv7(x)))
        x = F.relu(self.conv8_bn(self.conv8(x)))
        # Without this, fc1 only accepts 32x32 inputs (the one size whose
        # feature map is exactly 2x2). Pooling to a fixed 2x2 grid is the
        # identity for that size and lets any other resolution through.
        x = F.adaptive_avg_pool2d(x, (2, 2))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        x = torch.sigmoid(self.fc2(x))
        return x

In [None]:
class conv_block(nn.Module):
    """Densely connected stack of BatchNorm -> ELU -> 3x3 Conv -> Dropout2d stages.

    Every stage reads everything produced so far (the block input plus the
    outputs of all earlier stages, concatenated along the channel axis) and
    contributes ``k`` new channels. The block therefore maps ``in_channels``
    channels to ``in_channels + layers * k`` channels; the 3x3 / stride 1 /
    padding 1 convolutions leave height and width unchanged.

    Args:
        in_channels: Channels of the tensor entering the block.
        k: Channels added by each stage.
        layers: Number of stages.
        p: Dropout2d probability applied after each convolution.
    """

    def __init__(self, in_channels, k, layers, p=0.2):
        super(conv_block, self).__init__()
        self.layers = layers

        width = in_channels  # channel count seen by the stage being built
        for stage in range(1, layers + 1):
            tag = str(stage)
            self.add_module('batchnorm' + tag, nn.BatchNorm2d(width))
            self.add_module('conv' + tag,
                            nn.Conv2d(width, k, 3, stride=1, padding=1))
            self.add_module('drop' + tag, nn.Dropout2d(p=p))
            width += k

    def forward(self, x):
        for stage in range(1, self.layers + 1):
            tag = str(stage)
            norm = getattr(self, 'batchnorm' + tag)
            conv = getattr(self, 'conv' + tag)
            drop = getattr(self, 'drop' + tag)

            new_features = drop(conv(F.elu(norm(x.clone()))))
            # Append the fresh channels so later stages can see them too.
            x = torch.cat((x, new_features), dim=1)
        return x

In [None]:
class upsample_block(nn.Module):
    """Doubles the spatial size, then applies a 3x3 convolution and ELU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super(upsample_block, self).__init__()
        # Nearest-neighbour resize by a factor of two.
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        # 3x3 / stride 1 / padding 1 keeps the enlarged height and width.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3,
                               stride=1, padding=1)

    def forward(self, x):
        enlarged = self.upsample1(x)
        mixed = self.conv1(enlarged)
        return F.elu(mixed)

In [None]:
class generator_model(nn.Module):
    """Super-resolution generator: dense conv blocks followed by 2x upsample blocks.

    Pipeline: 9x9 conv (3 -> 64) -> ``n_conv_blocks`` x ``conv_block`` ->
    3x3 conv back to 64 channels + BatchNorm + ELU -> ``n_upsample_blocks`` x
    ``upsample_block`` -> 9x9 conv to 3 channels.

    Every convolution preserves height and width, and each upsample block
    doubles them, so a ``(N, 3, H, W)`` input yields a
    ``(N, 3, H * 2**n_upsample_blocks, W * 2**n_upsample_blocks)`` output.

    Args:
        n_conv_blocks: Number of dense ``conv_block`` stages.
        n_upsample_blocks: Number of 2x ``upsample_block`` stages.
    """

    def __init__(self, n_conv_blocks, n_upsample_blocks):
        super(generator_model, self).__init__()
        # The original assigned from undefined names (n_blocks / upsample) and
        # then looped over attributes that were never set.
        self.n_conv_blocks = n_conv_blocks
        self.n_upsample_blocks = n_upsample_blocks

        # A 9x9 kernel needs padding=4 to keep the spatial size; padding=1
        # silently shaved 6 pixels off each dimension.
        self.conv1 = nn.Conv2d(3, 64, 9, stride=1, padding=4)

        growth, stages = 12, 4  # each conv_block adds growth * stages channels
        channels = 64
        for i in range(self.n_conv_blocks):
            self.add_module('conv_block' + str(i + 1),
                            conv_block(channels, growth, stages))
            channels += growth * stages

        self.conv2 = nn.Conv2d(channels, 64, 3, stride=1, padding=1)
        self.conv2_bn = nn.BatchNorm2d(64)

        in_channels = 64
        out_channels = 256
        for i in range(self.n_upsample_blocks):
            self.add_module('upsample_block' + str(i + 1),
                            upsample_block(in_channels, out_channels))
            in_channels = out_channels
            out_channels = out_channels // 2  # halve the width after every block

        self.conv3 = nn.Conv2d(in_channels, 3, 9, stride=1, padding=4)

    def forward(self, x):
        x = self.conv1(x)

        for i in range(self.n_conv_blocks):
            x = getattr(self, 'conv_block' + str(i + 1))(x)

        x = F.elu(self.conv2_bn(self.conv2(x)))

        for i in range(self.n_upsample_blocks):
            # Name must match the one registered in __init__ (the original
            # looked up the misspelled 'upsample_blcok').
            x = getattr(self, 'upsample_block' + str(i + 1))(x)

        return self.conv3(x)

In [None]:
# Per-channel mean/std applied to every image tensor before it reaches the
# networks. NOTE(review): these look like the statistics conventionally used
# with ImageNet-pretrained torchvision models - confirm they suit this data.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# High-res tensor in [0, 1] -> normalised low-res tensor for the generator.
# transforms.Scale was removed from torchvision; Resize is its replacement.
# An explicit (h, w) pair guarantees image_size x image_size output.
scale = transforms.Compose([transforms.ToPILImage(),
                            transforms.Resize((image_size, image_size)),
                            transforms.ToTensor(),
                            normalize])

# PIL image -> square high-res tensor in [0, 1]. Resize(int) only fixes the
# shorter edge, so the centre crop is what makes the result square.
# `up_sampling` is the name the configuration cell defines (this cell used the
# undefined `n_upsampling`).
hr_image_size = image_size * up_sampling
transform = transforms.Compose([transforms.Resize(hr_image_size),
                                transforms.CenterCrop(hr_image_size),
                                transforms.ToTensor()])

In [None]:
dataset = data_from_dir('Data/CelebA/splits/train', transform=transform)
# drop_last: the training cell uses target tensors with exactly batch_size
# rows, so a shorter final batch would not line up with them.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, drop_last=True)

# `up_sampling` is a scale factor (the high-res images are image_size *
# up_sampling pixels), whereas generator_model expects a count of 2x blocks.
# Passing the factor directly would make the generator upscale by
# 2**up_sampling and its output would no longer match the high-res targets.
n_upsample_blocks = int(round(np.log2(up_sampling)))
assert 2 ** n_upsample_blocks == up_sampling, 'up_sampling must be a power of two'

net_G = generator_model(n_conv_blocks, n_upsample_blocks)
net_D = discriminator_model()

In [None]:
class FeatureExtractor(nn.Module):
    """Wraps the first ``feature_layer + 1`` layers of a CNN's ``features`` stack.

    NOTE(review): this notebook instantiated ``FeatureExtractor`` without ever
    defining it. This definition follows the usual perceptual-loss recipe; the
    default cut-off layer is an assumption - confirm it is the intended one.

    Args:
        cnn: Model exposing a ``features`` sequential (e.g. torchvision VGG).
        feature_layer: Index of the last layer of ``cnn.features`` to keep.
    """

    def __init__(self, cnn, feature_layer=11):
        super(FeatureExtractor, self).__init__()
        self.features = nn.Sequential(
            *list(cnn.features.children())[:feature_layer + 1])

    def forward(self, x):
        return self.features(x)


# BCE targets for the discriminator: 1 = real image, 0 = generated image.
target_real = Variable(torch.ones(batch_size, 1))
target_fake = Variable(torch.zeros(batch_size, 1))
target_real = target_real.cuda()
target_fake = target_fake.cuda()

# Placeholder only: the training cell overwrites inputs_G before reading it.
inputs_G = torch.FloatTensor(batch_size, image_channels, image_size, image_size)

net_G.cuda()
net_D.cuda()
feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained=True))
feature_extractor.cuda()
# The extractor is a fixed measuring device, not something to train: keep it
# in eval mode and skip gradient bookkeeping for its weights. Gradients still
# flow through it to the generator output.
feature_extractor.eval()
for param in feature_extractor.parameters():
    param.requires_grad = False

content_loss = nn.MSELoss()       # distance between extracted feature maps
adversarial_loss = nn.BCELoss()   # real/fake classification loss
content_loss.cuda()
adversarial_loss.cuda()

In [None]:
# Generator and discriminator are optimised independently: one Adam instance
# per network, each with its own learning rate from the configuration cell.
opt_G = optim.Adam(params=net_G.parameters(), lr=learning_rate_G)
opt_D = optim.Adam(params=net_D.parameters(), lr=learning_rate_D)

In [None]:
def plot_output(inputs_G, inputs_D_real, inputs_D_fake):
    """Show one random sample as low-res input, high-res target and generator output.

    All three batches are expected as CPU tensors of shape ``(N, 3, H, W)``
    normalised with mean ~(0.485, 0.456, 0.406) and std ~(0.229, 0.224, 0.225);
    the normalisation is undone before display and every panel is resized to
    250x250 pixels.

    Args:
        inputs_G: Low-resolution batch fed to the generator.
        inputs_D_real: Matching high-resolution ground-truth batch.
        inputs_D_fake: Generator output for ``inputs_G``.
    """
    display_size = (250, 250)
    # Normalize with mean = -m/s and std = 1/s inverts Normalize(m, s).
    to_image = transforms.Compose([
        transforms.Normalize(mean=[-2.118, -2.036, -1.804],
                             std=[4.367, 4.464, 4.444]),
        # The generator output is unbounded; values outside [0, 1] would wrap
        # around when ToPILImage converts to 8-bit.
        transforms.Lambda(lambda t: t.clamp(0, 1)),
        transforms.ToPILImage(),
        # transforms.Scale was removed from torchvision; Resize replaces it.
        transforms.Resize(display_size),
    ])

    # The three batches may differ in length; only pick an index all of them have.
    n_available = min(inputs_G.size(0), inputs_D_real.size(0), inputs_D_fake.size(0))
    i = random.randint(0, n_available - 1)

    figure, axes = plt.subplots(1, 3)
    panels = (('Low-res input', inputs_G[i]),
              ('High-res target', inputs_D_real[i]),
              ('Generated', inputs_D_fake[i]))
    for ax, (title, tensor) in zip(axes, panels):
        # clone(): some torchvision versions normalise in place, which would
        # otherwise modify the caller's batch.
        ax.imshow(to_image(tensor.clone()))
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    # Called once per epoch - release the figure so they do not pile up.
    plt.close(figure)

In [None]:
# Adversarial training: every batch first updates the discriminator on a
# real/fake pair, then updates the generator on content + adversarial loss.
for epoch in range(n_epochs):
    for i, inputs in enumerate(dataloader):
        # Use the actual batch length so a short final batch still works.
        n = inputs.size(0)

        # Low-res generator inputs are derived from the un-normalised
        # high-res batch, so build them before normalising.
        inputs_G = torch.stack([scale(image) for image in inputs])
        inputs = torch.stack([normalize(image) for image in inputs])

        real_labels = target_real[:n]
        fake_labels = target_fake[:n]

        inputs_D_real = Variable(inputs.cuda())
        inputs_D_fake = net_G(Variable(inputs_G).cuda())

        # ---- Discriminator step ----
        net_D.zero_grad()

        outputs = net_D(inputs_D_real)
        D_real = outputs.data.mean()

        loss_D_real = adversarial_loss(outputs, real_labels)
        loss_D_real.backward()

        # detach(): the discriminator update must not back-propagate into the generator.
        outputs = net_D(inputs_D_fake.detach())
        D_fake = outputs.data.mean()

        loss_D_fake = adversarial_loss(outputs, fake_labels)
        loss_D_fake.backward()

        opt_D.step()

        # ---- Generator step ----
        net_G.zero_grad()
        # Target features are constants: strip them from the graph.
        real_features = Variable(feature_extractor(inputs_D_real).data)
        fake_features = feature_extractor(inputs_D_fake)

        loss_G_content = content_loss(fake_features, real_features)
        # No detach here: the generator learns from the discriminator's
        # verdict only if gradients can flow back through net_D's output.
        loss_G_adversarial = adversarial_loss(net_D(inputs_D_fake), real_labels)

        loss_G_total = 0.005*loss_G_content + 0.001*loss_G_adversarial
        loss_G_total.backward()
        opt_G.step()

    # Mean discriminator outputs on the last batch of the epoch.
    print('epoch %d/%d  D(real)=%.3f  D(fake)=%.3f'
          % (epoch + 1, n_epochs, float(D_real), float(D_fake)))
    plot_output(inputs_G, inputs_D_real.cpu().data, inputs_D_fake.cpu().data)