In [None]:
import os
import PIL
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.stats import norm
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision.utils as vutils

from vae_auto_encoder import VAEAutoEncoder
from image_label_dataset import ImageLabelDataset

In [None]:
# Select the compute device: CUDA when torch reports a usable GPU, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# --- Data / model hyper-parameters ---
batch_size = 32    # images per batch of the default dataloader
image_size = 128   # images are resized to image_size x image_size
z_dim_size = 200   # length of the latent vectors used throughout the notebook

# --- Filesystem locations ---
data_path = './data/celeba/'
csv_data_path = './data/celeba/list_attr_celeba.csv'
model_save_path = './vae_faces_model.pth'
save_folder = './images/celeba'

# Every plotting cell writes into save_folder with plt.savefig, which raises
# FileNotFoundError when the directory is missing, so create it up front.
os.makedirs(save_folder, exist_ok=True)

In [None]:
# Load the attribute table from the CSV file and preview its first rows.
att = pd.read_csv(filepath_or_buffer=csv_data_path)
att.head(5)

In [None]:
# Preprocessing: resize each image to image_size x image_size, then convert it
# to a tensor. A Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) step is NOT applied.
preprocessing_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# Dataset yielding (image, label) pairs built from the image folder and the CSV.
dataset = ImageLabelDataset(data_path, csv_data_path, transform=transform)

# Shuffled loader used by the preview / visualisation cells.
loader_options = dict(batch_size=batch_size,
                      shuffle=True,
                      num_workers=4,
                      pin_memory=True)
dataloader = DataLoader(dataset=dataset, **loader_options)


In [None]:
# Rebuild the VAE with the architecture below and move it to the selected device.
# NOTE(review): the meaning of the leading positional argument (4) is defined by
# VAEAutoEncoder, which is not visible in this notebook — confirm against the class.
model = VAEAutoEncoder(4,
                       encoder_channels=[3, 32, 64, 64, 64],
                       encoder_kernel_sizes=[3, 3, 3, 3],
                       encoder_strides=[2, 2, 2, 2],
                       decoder_channels=[64, 64, 64, 32, 3],
                       decoder_kernel_sizes=[3, 3, 3, 3],
                       decoder_strides=[2, 2, 2, 2],
                       linear_sizes=[4096, z_dim_size],
                       view_size=[-1, 64, 8, 8],
                       use_batch_norm=True,
                       use_dropout=True).to(device)

# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on a GPU still loads on a CPU-only machine (plain torch.load would raise).
state_dict = torch.load(model_save_path, map_location=device)
model.load_state_dict(state_dict)

# Inference mode: dropout is disabled and batch norm uses its running statistics.
model.eval()

In [None]:
# Number of images shown in each preview grid (also the grid's row length).
num_to_show = 10

# Take one batch from the shuffled dataloader; the labels are not needed here.
inputs, _ = next(iter(dataloader))
inputs = inputs.to(device)

# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    reconst_images, mu, log_var = model(inputs)

print(reconst_images.shape)
print(nn.MSELoss()(reconst_images, inputs))

# Draw the inputs and their reconstructions as two identically laid-out grids.
for image_batch, title, filename in (
        (inputs, "Training Images", 'input_images.png'),
        (reconst_images, "Reconstructed Images", 'output_images.png')):
    grid = vutils.make_grid(image_batch[:num_to_show],
                            nrow=num_to_show,
                            padding=2,
                            normalize=True)
    plt.figure(figsize=(10, 10))
    plt.axis("off")
    plt.title(title)
    # make_grid returns (channels, height, width); imshow wants (height, width, channels).
    plt.imshow(grid.cpu().numpy().transpose(1, 2, 0))
    plt.savefig(os.path.join(save_folder, filename))

In [None]:
# Encode the first 21 batches and collect the first value returned by
# model.encode (presumably the sampled latent vectors — confirm against the model).
output_list = []

# zip with range(21) stops after 21 batches without requesting a 22nd one.
for _, (inputs, _) in zip(range(21), dataloader):
    inputs = inputs.to(device)
    outputs, mu, log_var = model.encode(inputs)
    output_list.append(outputs.detach().cpu().numpy())

# Stack the per-batch arrays along the sample axis.
output_np = np.concatenate(output_list, axis=0)

# Standard normal density, drawn on every histogram as a reference curve.
x = np.linspace(-3, 3, 100)
reference_pdf = norm.pdf(x)

fig = plt.figure(figsize=(20, 20))
fig.subplots_adjust(hspace=0.6, wspace=0.4)

# One histogram per latent dimension for dimensions 0..49, on a 5 x 10 grid.
for position in range(1, 51):
    dim = position - 1
    ax = fig.add_subplot(5, 10, position)
    ax.hist(output_np[:, dim], density=True, bins=20)
    ax.axis('off')
    ax.text(0.5, -0.35, str(dim), fontsize=10, ha='center', transform=ax.transAxes)
    ax.plot(x, reference_pdf)

fig.savefig(os.path.join(save_folder, 'distribution.png'))
    

In [None]:
# Number of faces to generate from random latent vectors.
num_to_show = 30

# Draw latent vectors from a standard normal distribution.
z_new = torch.randn(size=(num_to_show, z_dim_size), device=device)

# Inference only: no autograd graph is needed for decoding.
with torch.no_grad():
    reconst = model.decode(z_new)

# Plot the freshly decoded samples. The previous code plotted `reconst_images`,
# i.e. the reconstructions left over from an earlier cell, not `reconst`.
generated_grid = vutils.make_grid(reconst[:num_to_show],
                                  nrow=10,
                                  padding=2,
                                  normalize=True)

plt.figure(figsize=(18, 5))
plt.axis("off")
plt.title("Generated Images")
# make_grid returns (channels, height, width); imshow wants (height, width, channels).
plt.imshow(generated_grid.cpu().numpy().transpose(1, 2, 0))
plt.savefig(os.path.join(save_folder, 'generated_images.png'))


In [None]:
def get_vector_from_label(label, batch_size, max_positive_samples=10000, movement_threshold=0.08):
    """Estimate the latent-space direction associated with an attribute label.

    Images are encoded batch by batch. Running means of the encoded vectors are
    kept separately for samples whose attribute equals 1 (POS) and -1 (NEG);
    the feature vector is ``mean(POS) - mean(NEG)``.

    Iteration stops at the first of:
      * the POS and NEG means together moved less than ``movement_threshold``
        in their most recent updates (convergence),
      * at least ``max_positive_samples`` positive samples were accumulated,
      * the dataloader is exhausted (a single pass over the dataset).

    Args:
        label: attribute name (str) forwarded to ``ImageLabelDataset``.
        batch_size: number of images encoded per step.
        max_positive_samples: cap on the number of positive samples to use.
        movement_threshold: convergence threshold on the summed movement of
            the two running means.

    Returns:
        NumPy array of length ``z_dim_size`` holding the feature direction,
        scaled to unit length. If the direction has zero norm (for example no
        batch was processed) the un-normalised vector is returned as is.
    """
    image_label_dataset = ImageLabelDataset(data_path,
                                            csv_data_path,
                                            transform=transform,
                                            label=label)

    image_label_dataloader = DataLoader(image_label_dataset,
                                        batch_size=batch_size,
                                        shuffle=False,
                                        num_workers=4,
                                        pin_memory=True)

    current_sum_POS = np.zeros(shape=z_dim_size, dtype='float32')
    current_n_POS = 0
    current_mean_POS = np.zeros(shape=z_dim_size, dtype='float32')

    current_sum_NEG = np.zeros(shape=z_dim_size, dtype='float32')
    current_n_NEG = 0
    current_mean_NEG = np.zeros(shape=z_dim_size, dtype='float32')

    # Start at infinity so convergence cannot be declared before both classes
    # have been seen, and so the names are always bound (the earlier code raised
    # UnboundLocalError when the first batch lacked one of the classes).
    movement_POS = np.inf
    movement_NEG = np.inf

    current_vector = np.zeros(shape=z_dim_size, dtype='float32')
    current_dist = 0.0
    converged = False

    print('label: ' + label)
    print('images : POS move : NEG move :distance : 𝛥 distance')

    # Inference only: no autograd graph is needed for encoding.
    with torch.no_grad():
        for inputs, attribute in image_label_dataloader:
            outputs, mu, log_var = model.encode(inputs.to(device))
            outputs = outputs.cpu().numpy()

            # Boolean masks are built in NumPy so they index `outputs` directly.
            # assumes `attribute` is a 1-D CPU tensor/array of +1 / -1 — TODO confirm
            attribute = np.asarray(attribute)
            z_POS = outputs[attribute == 1]
            z_NEG = outputs[attribute == -1]

            # A class that is absent from this batch keeps its previous mean and
            # its previously measured movement.
            if len(z_POS) > 0:
                current_sum_POS = current_sum_POS + np.sum(z_POS, axis=0)
                current_n_POS += len(z_POS)
                new_mean_POS = current_sum_POS / current_n_POS
                movement_POS = np.linalg.norm(new_mean_POS - current_mean_POS)
                current_mean_POS = new_mean_POS

            if len(z_NEG) > 0:
                current_sum_NEG = current_sum_NEG + np.sum(z_NEG, axis=0)
                current_n_NEG += len(z_NEG)
                new_mean_NEG = current_sum_NEG / current_n_NEG
                movement_NEG = np.linalg.norm(new_mean_NEG - current_mean_NEG)
                current_mean_NEG = new_mean_NEG

            current_vector = current_mean_POS - current_mean_NEG
            new_dist = np.linalg.norm(current_vector)
            dist_change = new_dist - current_dist
            current_dist = new_dist

            print(str(current_n_POS)
                  + '\t: ' + str(np.round(movement_POS, 3))
                  + '\t: ' + str(np.round(movement_NEG, 3))
                  + '\t: ' + str(np.round(new_dist, 3))
                  + '\t: ' + str(np.round(dist_change, 3))
                  )

            if movement_POS + movement_NEG < movement_threshold:
                converged = True
                print('Found the ' + label + ' vector')
                break

            # Enforce the sample cap; previously the `return` sat inside the
            # `while`, so this limit was never actually applied.
            if current_n_POS >= max_positive_samples:
                break

    if not converged:
        print('Stopped before the ' + label + ' vector converged')

    # Always return a unit-length direction (previously only the converged path
    # normalised), guarding against division by zero.
    if current_dist > 0:
        current_vector = current_vector / current_dist

    return current_vector


In [None]:
def add_vector_to_images(label, feature_vec):
    """Show how decoded faces change when a feature vector is added in latent space.

    Takes the first images of one batch from the global (shuffled) ``dataloader``,
    encodes them, and decodes ``z + factor * feature_vec`` for each factor in
    -4..4. Each grid row holds the original image followed by one decoded image
    per factor. The figure is saved to ``<save_folder>/<label>.png``.

    Args:
        label: name used for the output file (str).
        feature_vec: latent direction of length ``z_dim_size``; a NumPy array
            (as returned by ``get_vector_from_label``) or a torch tensor.
    """
    factors = [-4, -3, -2, -1, 0, 1, 2, 3, 4]

    # Only the images are needed; the attribute labels of the batch are unused.
    inputs, _ = next(iter(dataloader))

    # Never index past the end of a batch smaller than 5.
    num_to_show = min(5, len(inputs))

    images = []

    # Inference only: no autograd graph is needed for encode/decode.
    with torch.no_grad():
        z_points, mu, log_var = model.encode(inputs.to(device))

        # Convert explicitly so the arithmetic below is tensor-with-tensor on
        # one device and dtype, instead of relying on implicit torch/NumPy mixing.
        direction = torch.as_tensor(feature_vec,
                                    dtype=z_points.dtype,
                                    device=z_points.device)

        for i in range(num_to_show):
            images.append(inputs[i])
            for factor in factors:
                changed_z_point = z_points[i] + direction * factor
                changed_images = model.decode(changed_z_point).cpu()
                images.append(changed_images.squeeze(0))

    # One row per source image: the original plus one image per factor.
    grid = vutils.make_grid(images,
                            nrow=len(factors) + 1,
                            padding=2,
                            normalize=True)

    plt.figure(figsize=(18, 10))
    plt.axis("off")
    # make_grid returns (channels, height, width); imshow wants (height, width, channels).
    plt.imshow(grid.numpy().transpose(1, 2, 0))
    plt.savefig(os.path.join(save_folder, label + '.png'))

In [None]:
# Batch size used while estimating the attribute vectors.
BATCH_SIZE = 500

# One latent-space direction per attribute column of the label CSV.
attractive_vec = get_vector_from_label('Attractive', BATCH_SIZE)
mouth_open_vec = get_vector_from_label('Mouth_Slightly_Open', BATCH_SIZE)
smiling_vec = get_vector_from_label('Smiling', BATCH_SIZE)
lipstick_vec = get_vector_from_label('Wearing_Lipstick', BATCH_SIZE)
# Was computed from 'High_Cheekbones', although the vector is named, printed and
# saved as "Young"; use the 'Young' attribute so the visualisation matches.
young_vec = get_vector_from_label('Young', BATCH_SIZE)
male_vec = get_vector_from_label('Male', BATCH_SIZE)

In [None]:
# Latent direction for the 'Eyeglasses' attribute.
eyeglasses_vec = get_vector_from_label(label='Eyeglasses', batch_size=BATCH_SIZE)

In [None]:
# Latent direction for the 'Blond_Hair' attribute.
blonde_vec = get_vector_from_label(label='Blond_Hair', batch_size=BATCH_SIZE)

In [None]:
# (heading printed, label used for the output file name, feature vector)
vector_previews = [
    ('Attractive Vector', 'Attractive', attractive_vec),
    ('Mouth Open Vector', 'Mouth_Open', mouth_open_vec),
    ('Smiling Vector', 'Smiling', smiling_vec),
    ('Lipstick Vector', 'Lipstick', lipstick_vec),
    ('Young Vector', 'Young', young_vec),
    ('Male Vector', 'Male', male_vec),
    ('Eyeglasses Vector', 'Eyeglasses', eyeglasses_vec),
    ('Blond Vector', 'Blond', blonde_vec),
]

# Print each heading, then render and save the corresponding image grid.
for heading, file_label, vector in vector_previews:
    print(heading)
    add_vector_to_images(file_label, vector)

In [None]:
def morph_faces(start_image_filename, end_image_filename):
    """Render a latent-space interpolation between two face images.

    Both images are loaded from ``<data_path>/img_align_celeba_png``, passed
    through the global ``transform`` and encoded. Latent points are blended
    linearly for factors 0.0, 0.1, ..., 0.9 and decoded. The saved grid shows
    the start image, the decoded blends and the end image in a single row, in
    ``<save_folder>/<start name without extension>_<end_image_filename>``.

    Args:
        start_image_filename: file name of the image the morph starts from.
        end_image_filename: file name of the image the morph moves towards.
    """
    # `import PIL` alone does not guarantee that the Image submodule is loaded.
    from PIL import Image

    factors = np.arange(0, 1, 0.1)
    image_dir = os.path.join(data_path, 'img_align_celeba_png')

    def _load_image(filename):
        """Open one image, force 3-channel RGB and apply the global transform."""
        # The context manager closes the file handle deterministically.
        with Image.open(os.path.join(image_dir, filename)) as image:
            return transform(image.convert('RGB'))

    start_image = _load_image(start_image_filename)
    end_image = _load_image(end_image_filename)

    images = [start_image]

    # Inference only: no autograd graph is needed for encode/decode.
    with torch.no_grad():
        inputs = torch.stack((start_image, end_image), 0).to(device)
        z_points, mu, log_var = model.encode(inputs)

        for factor in factors:
            weight = float(factor)  # plain Python float for tensor arithmetic
            changed_z_point = z_points[0] * (1 - weight) + z_points[1] * weight
            changed_image = model.decode(changed_z_point)[0]
            images.append(changed_image.cpu())

    images.append(end_image)

    # Single row: start image + one image per factor + end image.
    grid = vutils.make_grid(images,
                            nrow=len(images),
                            padding=2,
                            normalize=True)

    plt.figure(figsize=(18, 8))
    plt.axis("off")
    # make_grid returns (channels, height, width); imshow wants (height, width, channels).
    plt.imshow(grid.numpy().transpose(1, 2, 0))

    # splitext removes the extension whatever its length (the old [:-4] slice
    # assumed exactly three characters after the dot).
    start_stem = os.path.splitext(start_image_filename)[0]
    plt.savefig(os.path.join(save_folder, start_stem + '_' + end_image_filename))
        

In [None]:
# Morph from image 000238 to image 000193.
start_image_file, end_image_file = '000238.png', '000193.png'

morph_faces(start_image_filename=start_image_file,
            end_image_filename=end_image_file)

In [None]:
# Morph from image 000112 to image 000258.
start_image_file, end_image_file = '000112.png', '000258.png'

morph_faces(start_image_filename=start_image_file,
            end_image_filename=end_image_file)

In [None]:
# Morph from image 000230 to image 000712.
start_image_file, end_image_file = '000230.png', '000712.png'

morph_faces(start_image_filename=start_image_file,
            end_image_filename=end_image_file)