In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
import torch.utils.data as td
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import cv2
import torchvision
import pandas as pd

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load data
def load_image_data(input_path, max_size):
    # load a list of image file paths
    data_list = os.listdir(input_path)
    # create a list of images
    images = []
    for i, l in enumerate(data_list):
        if i > max_size:
            break
        p = input_path + l
        im = cv2.imread(p)
        images.append(im)
    return images

# Build Data Loader
def set_image_data(x_path, c_path, batch_size=100, max_size=100000):
    x_images = load_image_data(x_path, max_size)
    c_images = load_image_data(c_path, max_size)

    # list to torch tensor
    # each element is divided by 255 to float
    # axis change from (n_image, height, width, channel) to (n_image, channel, height, width)
    x_tensor = torch.from_numpy((np.array(x_images) / 255.).astype(np.float32)).to(device)
    x_tensor = x_tensor.permute((0,3,1,2))
    c_tensor = torch.from_numpy((np.array(c_images) / 255.).astype(np.float32)).to(device)
    c_tensor = c_tensor.permute((0,3,1,2))

    # output data set
    image_set = td.TensorDataset(x_tensor, c_tensor) # train data and the label
    dataloader = td.DataLoader(dataset=image_set, batch_size=batch_size, shuffle=True)
    return dataloader

In [3]:
# CNN-VAE Model components
class EncoderModule(nn.Module):
    def __init__(self, input_channels, output_channels, stride, kernel, pad):
        super().__init__()
        # 2D convolution: kernel=1->filter size=[1,1], pad=0->padding size=0 (NONE padding),pad=1->ZERO padding for size 1 (for both edges)
        # TO BE REVIEWED: Can Pooling give some improvement to the face generation??
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=kernel, padding=pad, stride=stride)
        # Batch normalisation
        self.bn = nn.BatchNorm2d(output_channels)
        # Relu activation
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
    
class Encoder(nn.Module):
    def __init__(self, color_channels, pooling_kernels, n_neurons_in_middle_layer):
        self.n_neurons_in_middle_layer = n_neurons_in_middle_layer
        super().__init__()
        # (N,3,50,100)->(N,32,50,100)
        self.bottle = EncoderModule(color_channels, 32, stride=1, kernel=1, pad=0)
        # (N,32,50,100)->(N,64,50,100)
        self.m1 = EncoderModule(32, 64, stride=1, kernel=3, pad=1)
        # (N,64,50,100)->(N,128,12,24) with stride size 4. TO BE REVIEWED: pad=0 or p
        self.m2 = EncoderModule(64, 128, stride=4, kernel=[3,5], pad=0)
        # (N,128,12,24)->(N,256,4,4) with stride size [2,4] TO BE REVIEWED: pad=0 or p
        self.m3 = EncoderModule(128, 256, stride=[2,4], kernel=[5,9], pad=0)
    def forward(self, x):
        out = self.m3(self.m2(self.m1(self.bottle(x))))
        return out.view(-1, self.n_neurons_in_middle_layer)

class DecoderModule(nn.Module):
    def __init__(self, input_channels, output_channels, stride, activation="relu", kernel=1):
        super().__init__()
        # 2D Deconvolution
        self.convt = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel, stride=stride)
        # Batch renormalisation
        self.bn = nn.BatchNorm2d(output_channels)
        # ReLu or Sigmoid activation
        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
    def forward(self, x):
        return self.activation(self.bn(self.convt(x)))

class Decoder(nn.Module):
    def __init__(self, in_channels, color_channels, pooling_kernels, decoder_input_size):
        self.decoder_input_size = decoder_input_size
        self.in_channels = in_channels
        super().__init__()
        # (N,128,12,24)<-(N,256,4,4) (in_channels=256 for VAE, 512 for CVAE)
        self.m1 = DecoderModule(self.in_channels, 128, stride=[3,6], kernel=[3,6])
        # (N,64,50,100)<-(N,128,12,24)
        self.m2 = DecoderModule(128, 64, stride=[4,4], kernel=[6,8])
        # (N,32,50,100)<-(N,64,50,100)
        self.m3 = DecoderModule(64, 32, stride=1, kernel=1)
        # (N,3,50,100)<-(N,32,50,100)
        self.bottle = DecoderModule(32, color_channels, stride=1, activation="sigmoid", kernel=1)
    def forward(self, x):
        # support 2 cases
        # VAE: input size (N,256*4*4) ->(N,256,4,4)
        # CVAE: input size (N,2*256*4*4) ->(N,512,4,4)
        out = x.view(-1, self.in_channels, self.decoder_input_size, self.decoder_input_size)
        out = self.m3(self.m2(self.m1(out)))
        return self.bottle(out)

In [4]:
# (C)VAE model
class VAE(nn.Module):
    def __init__(self, conditional):
        self.device = device
        self.conditional = conditional # if true, CVAE
        super().__init__()
        
        # Common feature in Encoder/Decoder (De)Convolution layer
        color_channels = 3
        pooling_kernel = [1, 1]
        encoder_output_size = 4 # Given by the Encoder architecture
        encoder_1d_output_size = 256 * encoder_output_size * encoder_output_size

        # Encoder for X (lower-half face)
        self.encoder_x = Encoder(color_channels, pooling_kernel, encoder_1d_output_size)        
        # Encoder for C (Upper-half face, "label")
        if self.conditional:
            self.encoder_c = Encoder(color_channels, pooling_kernel, encoder_1d_output_size) # ALMOST 100% SURE ABOUT THIS SETTING, BUT TO BE REVIEWED

        # Latent Features
        # output size [N,256,4,4]
        self.n_latent_features = 64 # dimension of z
        n_neurons_middle_layer = (2 if self.conditional else 1) * encoder_1d_output_size # we double if CVAE
        
        # for encoder X
        self.fc1 = nn.Linear(n_neurons_middle_layer, self.n_latent_features)
        self.fc2 = nn.Linear(n_neurons_middle_layer, self.n_latent_features)
        self.fc3 = nn.Linear(self.n_latent_features, encoder_1d_output_size)

        # for encoder C
        if self.conditional:
            self.fc4 = nn.Linear(encoder_1d_output_size, self.n_latent_features) 
            self.fc5 = nn.Linear(encoder_1d_output_size, self.n_latent_features)
            self.fc6 = nn.Linear(self.n_latent_features, encoder_1d_output_size)

        # Decoder for X (Labels should be taken into account also for this decoder stage...)
        self.decoder_x = Decoder(256*(2 if self.conditional else 1), color_channels, pooling_kernel, encoder_output_size)
        # Decoder for C (Upper-half face, "label")
        if self.conditional:
            self.decoder_c = Decoder(256, color_channels, pooling_kernel, encoder_output_size)

    # Reparametrisation trick
    def _reparameterize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        esp = torch.randn(*mu.size()).to(self.device)
        z = mu + std * esp
        return z
    
    # From the encoder_x output, Output the latent vector "zx" 
    def _bottleneck_x(self, h):
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self._reparameterize(mu, logvar)
        return z, mu, logvar

    # From the encoder_c output, Output the latent vector "zc" 
    def _bottleneck_c(self, h):
        mu, logvar = self.fc4(h), self.fc5(h)
        z = self._reparameterize(mu, logvar)
        return z, mu, logvar

    # Sampling lower-half face (given an input upper-half face image)
    def sampling(self, c=None, r=None):
        # assume latent features space ~ N(0, 1)
        n = c.size(0) if self.conditional else 64 # number of faces
        if r==None:
            zx = torch.randn(n, self.n_latent_features).to(self.device)
        else:
            zx = torch.full((n,self.n_latent_features), fill_value=r).to(self.device)
        zx = self.fc3(zx) # 256 * encode_out_size(4)* encode_out_size(4)
        
        # concat
        if self.conditional:
            hc = self.encoder_c(c) # 256 * encode_out_size(4)* encode_out_size(4)
            z = torch.cat((zx, hc), dim=-1) # 2* 256 * encode_out_size(4)* encode_out_size(4)
        else:
            z = zx
        # decode
        dx = self.decoder_x(z)
        d = torch.cat((c,dx), dim=2)
        return d

    def forward(self, x, c=None):
        # Encoder for X and for C
        # output size [N,256*4*4]
        hx = self.encoder_x(x)
        if self.conditional:
            hc = self.encoder_c(c)
        # concat if CVAE
        if self.conditional:
            h = torch.cat((hx, hc), dim=-1)
        else:
            h = hx
        # Bottle-neck
        # z,mu,logvar size [N,256*4*4]->[N,64]
        zx, mu_x, logvar_x = self._bottleneck_x(h)
        if self.conditional:
            zc, mu_c, logvar_c = self._bottleneck_c(hc)
            
        # decoder
        # output size [N,64]->[N,256*4*4]
        zx = self.fc3(zx)
        if self.conditional:
            zx = torch.cat((zx, hc), dim=-1) # 2* 256 * encode_out_size(4)* encode_out_size(4)
            zc = self.fc6(zc)
            
        dx = self.decoder_x(zx)
        if self.conditional:
            dc = self.decoder_c(zc)
            d = torch.cat((dc,dx), dim=2)
            mu = torch.sqrt(mu_x.pow(2) + mu_c.pow(2))
            logvar = torch.sqrt(logvar_x.pow(2) + logvar_c.pow(2))
        else:
            d = dx
            mu = mu_x
            logvar = logvar_x
        return d, mu, logvar

    # Model loss function
    def loss_function(self, recon_x, x, mu, logvar):
        # https://arxiv.org/abs/1312.6114 (Appendix B)
        BCE = F.binary_cross_entropy(recon_x, x, size_average=False)        
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

    def init_model(self):
        self.optimizer = optim.Adam(self.parameters(), lr=1e-3)

        if self.device == "cuda":
            self = self.cuda()
            torch.backends.cudnn.benchmark=True
        self.to(self.device)

In [5]:
# Train and Save (C)VAE Model
def train_and_save(data_loader, model_path, csv_log_path, nepoch=20, nReportFreq=20, conditional=True):
    # Init the model
    cvae = VAE(conditional)
    cvae.init_model()
    # Training mode
    cvae.train()
    history = []
    for epoch in range(nepoch):
        train_loss = 0
        samples_cnt = 0
        for i, data in enumerate(dataloader):
            # Extract input (lower-half face) and label (upper-half face)
            inputs, labels = data
            # Define them as torch Variable (auto-gradient will be applied on)
            inputs, labels = Variable(inputs), Variable(labels)
            # init the optimiser (FORGOT why we initialise the grad at the beginning of each epoch)
            cvae.optimizer.zero_grad()
            # forward
            recon_batch, mu, logvar = cvae.forward(inputs, labels)
            # evaluate loss
            inputs_full = torch.cat((labels, inputs), dim=2)
            loss = cvae.loss_function(recon_batch, inputs_full, mu, logvar)
            # back prop and update the weights
            loss.backward()
            cvae.optimizer.step()

            train_loss += loss.item()
            samples_cnt += inputs.size(0)

            if i%nReportFreq == 0:
                history.append([epoch, i, data_loader.batch_size*i, train_loss / samples_cnt])
                print('#Epoch: {:}, #Batch: {:}, #Pics:{:} Loss: {:.2f}'.format(epoch, i, data_loader.batch_size*i, train_loss / samples_cnt))

    # save model
    state_dict = cvae.state_dict()
    torch.save(state_dict, model_path)
    # save log
    if csv_log_path != None:
        col = ["#Epoch", "#Batch", "#Pic", "Loss"]
        df_history = pd.DataFrame(data=history, columns=col)
        df_history.to_csv(csv_log_path)

# Testing
def test_and_save(dataloader, model_path, reconstr_path, csv_log_path, nrow=10, nReportFreq=20, conditional=True):
    # Load the trained model
    cvae = VAE(conditional)
    cvae.init_model()
    state_dict = torch.load(model_path)
    cvae.load_state_dict(state_dict)

    # Evaluation
    cvae.eval()
    val_loss = 0
    samples_cnt = 0
    history = []
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            # Extract input (lower-half face) and label (upper-half face)
            inputs, labels = data
            # Define them as torch Variable (auto-gradient will be applied on)
            inputs, labels = Variable(inputs), Variable(labels)
            recon_batch, mu, logvar = cvae.forward(inputs, labels)
            inputs_full = torch.cat((labels, inputs), dim=2)
            val_loss += cvae.loss_function(recon_batch, inputs_full, mu, logvar).item()
            samples_cnt += inputs.size(0)

            # From BGR -> RGB
            perm = torch.LongTensor([2,1,0])
            recon_batch = recon_batch[:, perm, :, :]
            if reconstr_path != None:
                torchvision.utils.save_image(recon_batch, reconstr_path + '/test_' + str(i) + '.png', nrow=nrow)

            if i%nReportFreq == 0:
                history.append([i, dataloader.batch_size*i, val_loss / samples_cnt])
                print('#Batch: {:}, Loss: {:.2f}'.format(i, val_loss / samples_cnt))
                
    # save log
    if csv_log_path != None:
        col = ["#Batch", "#Pic", "Loss"]
        df_history = pd.DataFrame(data=history, columns=col)
        df_history.to_csv(csv_log_path)

# Face Generation with Random (or input) latent vector
def gen_face(dataloader, model_path, gen_path, nrow=10, conditional=True):
    # Load the trained model
    cvae = VAE(conditional)
    cvae.init_model()
    state_dict = torch.load(model_path)
    cvae.load_state_dict(state_dict)

    # Random sampling
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            if i > 0:
                break
            _, labels = data
            gen_batch = cvae.sampling(c=labels, r=None)

            # BGR -> RGB
            perm = torch.LongTensor([2,1,0])
            full_face_batch = gen_batch[:, perm, :, :]
            torchvision.utils.save_image(full_face_batch, gen_path + '/test_' + str(i) + '.png', nrow=nrow, normalize=False, range=(0,255))

In [None]:
# Data loader
x_path = '/root/userspace/project/image_bottom/' # lower-half face
c_path = '/root/userspace/project/image_top/' # upper-half face
dataloader = set_image_data(x_path, c_path, batch_size=100, max_size=100000)

# Train and Save the model
conditional = True
save_path = '/root/userspace/project/model/model.pth'
csv_log_path = '/root/userspace/project/model/model_log.csv'
train_and_save(dataloader, model_path=save_path, csv_log_path=csv_log_path, nepoch=30, nReportFreq=20, conditional=True)

# Test and Save the pictures
reconstr_path = '/root/userspace/project/model/reconstruction/'
csv_log_path = None
test_and_save(dataloader, model_path=save_path, reconstr_path=reconstr_path, csv_log_path=csv_log_path)

# Lower-half face generation
gen_path = '/root/userspace/project/model/generation/'
gen_face(dataloader, model_path=save_path, gen_path=gen_path)

    Found GPU0 GRID K520 which is of cuda capability 3.0.
    PyTorch no longer supports this GPU because it is too old.
    


#Epoch: 0, #Batch: 0, #Pics:0 Loss: 21919.39
#Epoch: 0, #Batch: 20, #Pics:2000 Loss: 20239.55
#Epoch: 0, #Batch: 40, #Pics:4000 Loss: 19844.59
#Epoch: 0, #Batch: 60, #Pics:6000 Loss: 19646.01
#Epoch: 0, #Batch: 80, #Pics:8000 Loss: 19506.06
#Epoch: 0, #Batch: 100, #Pics:10000 Loss: 19397.57
#Epoch: 0, #Batch: 120, #Pics:12000 Loss: 19318.98
#Epoch: 1, #Batch: 0, #Pics:0 Loss: 18782.31
#Epoch: 1, #Batch: 20, #Pics:2000 Loss: 18774.26
#Epoch: 1, #Batch: 40, #Pics:4000 Loss: 18735.52
#Epoch: 1, #Batch: 60, #Pics:6000 Loss: 18690.75
#Epoch: 1, #Batch: 80, #Pics:8000 Loss: 18660.97
#Epoch: 1, #Batch: 100, #Pics:10000 Loss: 18634.82
#Epoch: 1, #Batch: 120, #Pics:12000 Loss: 18606.11
#Epoch: 2, #Batch: 0, #Pics:0 Loss: 18479.76
#Epoch: 2, #Batch: 20, #Pics:2000 Loss: 18399.30
#Epoch: 2, #Batch: 40, #Pics:4000 Loss: 18376.83
#Epoch: 2, #Batch: 60, #Pics:6000 Loss: 18358.07
#Epoch: 2, #Batch: 80, #Pics:8000 Loss: 18350.37
#Epoch: 2, #Batch: 100, #Pics:10000 Loss: 18335.36
#Epoch: 2, #Batch: 120