In [1]:
import os
# Restrict the CUDA devices visible to this process to device index "1".
# This assignment sits above the TensorFlow import; presumably it has to stay
# there so TensorFlow only ever sees that device — confirm before reordering.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import tensorflow as tf
# Physical GPUs as seen by TensorFlow (after the visibility filter above).
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
  try:
    # Replace the first visible GPU with a single virtual device capped at
    # memory_limit=5120 (megabytes according to the TF API docs — verify).
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],[tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
  except RuntimeError as e:
    # Best effort: a RuntimeError raised by the configuration call is only reported.
    print(e)

In [2]:
import numpy as np 
import tensorflow as tf
from tensorflow.keras.initializers import he_normal
from tensorflow.keras.layers import BatchNormalization, Conv2D, Conv2DTranspose


#########################
#        ENCODER        #
#########################

class Encoder(tf.keras.Model):
  """Convolutional encoder of the conditional VAE.

  Four stride-2, 3x3, 'same'-padded convolutions (32, 64, 128 and 256
  filters), each followed by batch normalization and a leaky ReLU, then a
  flatten and a dense layer with ``2 * latent_dim`` outputs. The caller is
  expected to split that output in two halves (ConvCVAE uses them as the
  mean and the log-variance of the latent Gaussian).

  Args:
    latent_dim: size of the latent space; the output has twice this width.
  """

  def __init__(self, latent_dim):

    super(Encoder, self).__init__()

    self.enc_block_1 = self._down_conv(32)
    self.enc_block_2 = self._down_conv(64)
    self.enc_block_3 = self._down_conv(128)
    self.enc_block_4 = self._down_conv(256)

    # The normalization layers are owned by the model (instead of being
    # re-created on every forward pass) so that their scale/offset and their
    # moving statistics persist between calls and are exposed through
    # `trainable_variables`.
    self.enc_bn_1 = BatchNormalization()
    self.enc_bn_2 = BatchNormalization()
    self.enc_bn_3 = BatchNormalization()
    self.enc_bn_4 = BatchNormalization()

    self.flatten = tf.keras.layers.Flatten()
    self.dense = tf.keras.layers.Dense(latent_dim + latent_dim)

  @staticmethod
  def _down_conv(filters):
    """Builds one stride-2, 3x3, 'same'-padded convolution with He-normal init."""
    return Conv2D(
        filters=filters,
        kernel_size=3,
        strides=(2, 2),
        padding='same',
        kernel_initializer=he_normal())

  def __call__(self, conditional_input, latent_dim, is_train):
    """Encodes a batch of conditional inputs.

    Args:
      conditional_input: batch fed to the first convolution.
      latent_dim: unused; kept so existing callers keep working.
      is_train: forwarded to the batch-normalization layers as their
        `training` flag (batch statistics when True, moving averages
        when False).

    Returns:
      Tensor of width ``2 * latent_dim`` produced by the final dense layer.
    """
    blocks = (
        (self.enc_block_1, self.enc_bn_1),
        (self.enc_block_2, self.enc_bn_2),
        (self.enc_block_3, self.enc_bn_3),
        (self.enc_block_4, self.enc_bn_4),
    )

    x = conditional_input
    for conv, norm in blocks:
      x = conv(x)
      # `training` (call-time mode) is what selects batch vs. moving
      # statistics; the `trainable` constructor argument only freezes weights.
      x = norm(x, training=is_train)
      x = tf.nn.leaky_relu(x)

    return self.dense(self.flatten(x))



#########################
#        DECODER        #
#########################

class Decoder(tf.keras.Model):
  """Transposed-convolution decoder of the conditional VAE.

  A dense projection is reshaped to a 4x4 feature map, upsampled by four
  stride-2, 3x3, 'same'-padded transposed convolutions (256, 128, 64 and 32
  filters, each followed by batch normalization and a leaky ReLU) and mapped
  to 3 channels by a final stride-1 transposed convolution. The output is
  returned without activation (ConvCVAE applies the sigmoid).

  Args:
    batch_size: despite its name, this value only sets the width of the
      first projection: the dense layer has ``4 * 4 * batch_size * 8`` units
      reshaped to ``(4, 4, batch_size * 8)``.
  """

  def __init__(self, batch_size = 32):

    super(Decoder, self).__init__()

    self.batch_size = batch_size
    self.dense = tf.keras.layers.Dense(4*4*self.batch_size*8)
    self.reshape = tf.keras.layers.Reshape(target_shape=(4, 4, self.batch_size*8))

    self.dec_block_1 = self._up_conv(256)
    self.dec_block_2 = self._up_conv(128)
    self.dec_block_3 = self._up_conv(64)
    self.dec_block_4 = self._up_conv(32)
    self.dec_block_5 = self._up_conv(3, strides=(1, 1))

    # The normalization layers are owned by the model (instead of being
    # re-created on every forward pass) so that their scale/offset and their
    # moving statistics persist between calls and are exposed through
    # `trainable_variables`.
    self.dec_bn_1 = BatchNormalization()
    self.dec_bn_2 = BatchNormalization()
    self.dec_bn_3 = BatchNormalization()
    self.dec_bn_4 = BatchNormalization()

  @staticmethod
  def _up_conv(filters, strides=(2, 2)):
    """Builds one 3x3, 'same'-padded transposed convolution with He-normal init."""
    return Conv2DTranspose(
        filters=filters,
        kernel_size=3,
        strides=strides,
        padding='same',
        kernel_initializer=he_normal())

  def __call__(self, z_cond, is_train):
    """Decodes a batch of (latent code + label) vectors into image logits.

    Args:
      z_cond: batch of vectors fed to the dense projection.
      is_train: forwarded to the batch-normalization layers as their
        `training` flag (batch statistics when True, moving averages
        when False).

    Returns:
      Output of the last transposed convolution (3 channels, no activation).
    """
    # Project and reshape the input to a 4x4 feature map
    x = self.dense(z_cond)
    x = tf.nn.leaky_relu(x)
    x = self.reshape(x)

    blocks = (
        (self.dec_block_1, self.dec_bn_1),
        (self.dec_block_2, self.dec_bn_2),
        (self.dec_block_3, self.dec_bn_3),
        (self.dec_block_4, self.dec_bn_4),
    )
    for deconv, norm in blocks:
      x = deconv(x)
      # `training` (call-time mode) is what selects batch vs. moving
      # statistics; the `trainable` constructor argument only freezes weights.
      x = norm(x, training=is_train)
      x = tf.nn.leaky_relu(x)

    return self.dec_block_5(x)


#########################
#       Conv-CVAE       #
#########################

class ConvCVAE (tf.keras.Model) :
    """Convolutional conditional VAE built from an encoder and a decoder.

    Args:
        encoder: callable as ``encoder(conditional_input, latent_dim, is_train)``
            returning a tensor of width ``2 * latent_dim``.
        decoder: callable as ``decoder(z_cond, is_train)`` returning image logits.
        label_dim: length of the label (condition) vector.
        latent_dim: size of the latent space.
        batch_size: stored on the instance; not used by the methods of this class.
        beta: weight of the KL term in the loss.
        image_dim: (height, width, channels) of the input images.
    """

    def __init__(self, 
        encoder,
        decoder,
        label_dim,
        latent_dim,
        batch_size = 32,
        beta = 1,
        image_dim = (64, 64, 3)):

        super(ConvCVAE, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.label_dim = label_dim
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        # Keep the values supplied by the caller (they used to be overwritten
        # with the constants 1 and [64, 64, 3]).
        self.beta = beta
        self.image_dim = list(image_dim)


    def __call__(self, inputs, is_train):
        """Runs a full forward pass and computes the loss.

        Args:
            inputs: pair ``(images, labels)``.
            is_train: forwarded to the encoder and the decoder.

        Returns:
            dict with keys 'recon_img', 'latent_loss', 'reconstr_loss',
            'loss', 'z_mean' and 'z_log_var'.
        """
        input_img, input_label, conditional_input = self.conditional_input(inputs)

        encoded = self.encoder(conditional_input, self.latent_dim, is_train)
        z_mean, z_log_var = tf.split(encoded, num_or_size_splits=2, axis=1)
        z_cond = self.reparametrization(z_mean, z_log_var, input_label)
        logits = self.decoder(z_cond, is_train)

        recon_img = tf.nn.sigmoid(logits)

        # Loss computation #
        # KL divergence between the approximate posterior and a standard normal, one value per sample
        latent_loss = - 0.5 * tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)
        # Binary cross-entropy over the flattened batch, rescaled by the number of pixels (height * width)
        n_pixels = float(int(self.image_dim[0]) * int(self.image_dim[1]))
        reconstr_loss = n_pixels * tf.keras.losses.binary_crossentropy(
            tf.keras.backend.flatten(input_img),
            tf.keras.backend.flatten(recon_img))
        loss = reconstr_loss + self.beta * latent_loss # weighted ELBO loss
        loss = tf.reduce_mean(loss) 

        return {
                    'recon_img': recon_img,
                    'latent_loss': latent_loss,
                    'reconstr_loss': reconstr_loss,
                    'loss': loss,
                    'z_mean': z_mean,
                    'z_log_var': z_log_var
                }


    def conditional_input(self, inputs):
        """Builds the conditional input.

        The label vector of every sample is broadcast over the spatial
        dimensions of its image and concatenated to the image channels.

        Args:
            inputs: pair ``(images, labels)``.

        Returns:
            (images, labels, conditional_input), all as float32 tensors;
            conditional_input has ``channels + label_dim`` channels.
        """
        # Plain casts replace the InputLayer objects that were instantiated
        # (and used as identity ops) on every forward pass.
        input_img = tf.cast(inputs[0], tf.float32)
        input_label = tf.cast(inputs[1], tf.float32)

        labels = tf.reshape(input_label, [-1, 1, 1, self.label_dim]) # batch_size, 1, 1, label_size
        # ones_like on a one-channel slice yields (batch_size, H, W, 1) without
        # requiring a statically known batch size; broadcasting does the rest.
        labels = tf.ones_like(input_img[..., :1]) * labels # batch_size, H, W, label_size
        conditional_input = tf.concat([input_img, labels], axis=3)

        return input_img, input_label, conditional_input


    def reparametrization(self, z_mean, z_log_var, input_label):
        """Performs the reparametrization trick and appends the label.

        Samples ``z = z_mean + exp(z_log_var / 2) * eps`` with ``eps`` drawn
        from a standard normal, then concatenates ``input_label`` to ``z``.

        Returns:
            Tensor of shape (batch_size, latent_dim + label_dim).
        """
        # The noise takes its shape from z_mean so the batch size may be dynamic.
        eps = tf.random.normal(shape = tf.shape(z_mean), mean = 0.0, stddev = 1.0)
        z = z_mean + tf.math.exp(z_log_var * .5) * eps
        z_cond = tf.concat([z, tf.cast(input_label, z.dtype)], axis=1) # (batch_size, latent_dim + label_dim)

        return z_cond


In [3]:
import cv2
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import os
import pickle
import numpy as np
import tensorflow as tf



#######################
# Train Step Function #
#######################


def train_step(data, model, optimizer):
    """
    Runs a single optimization step.

    The model is called as ``model(data, is_train = True)`` under a gradient
    tape; the gradients of its 'loss' output with respect to
    ``model.trainable_variables`` are then applied through ``optimizer``.

    Returns the tuple (total_loss, recon_loss, latent_loss), each obtained as
    ``.numpy().mean()`` of the corresponding model output
    ('loss', 'reconstr_loss', 'latent_loss').
    """
    # Only the forward pass needs to be recorded by the tape.
    with tf.GradientTape() as tape:
        outputs = model(data, is_train = True)

    variables = model.trainable_variables
    gradients = tape.gradient(outputs['loss'], variables)
    optimizer.apply_gradients(zip(gradients, variables))

    loss_keys = ('loss', 'reconstr_loss', 'latent_loss')
    return tuple(outputs[key].numpy().mean() for key in loss_keys)




##################################
# Encoding and Decoding methods  #
##################################


def encode(self, inputs, label):
    """
    Encodes the input into the latent space.

    Evaluates ``self.z_mean`` through ``self.sess.run``, feeding ``inputs``
    for ``self.x`` and ``label`` for ``self.y``.
    """
    # NOTE(review): relies on session-style attributes (sess, z_mean, x, y)
    # that none of the model classes visible in this notebook define —
    # presumably written for a TF1 graph model; confirm before use.
    run = self.sess.run
    fetch = self.z_mean
    feed = {self.x: inputs, self.y: label}
    return run(fetch, feed_dict=feed)


def decode(self, label, z = None):
    """ 
    Generates data from a point of the latent space.

    When z is None, a (self.batch_size, self.latent_dim) array is sampled
    from a zero-mean normal with standard deviation 0.75. The result is
    ``self.generated_image`` evaluated through ``self.sess.run`` with z fed
    for ``self.z_sample_3`` and ``label`` for ``self.y``.
    """
    # NOTE(review): relies on session-style attributes (sess, generated_image,
    # z_sample_3, y) that none of the model classes visible in this notebook
    # define — presumably written for a TF1 graph model; confirm before use.
    if z is None:
        noise = np.random.randn(self.batch_size, self.latent_dim)
        z = 0.0 + noise * 0.75

    run = self.sess.run
    fetch = self.generated_image
    feed = {self.z_sample_3: z, self.y: label}
    return run(fetch, feed_dict=feed)




########################
#  Utils for plotting  #
########################


def batch_generator(batch_dim, test_labels, model_name):
    """
    Endless batch generator over the given list of labels.

    Each element of test_labels is indexed as (image_id, label): element [0]
    is handed to create_image_batch to load the image, element [1] is
    returned as the label.

    Yields (images, labels) pairs of numpy arrays holding batch_dim samples.
    When the number of labels is not a multiple of batch_dim, the remaining
    samples are yielded as a final, smaller batch before the next pass over
    test_labels starts. The generator stops if a pass yields nothing (empty
    test_labels) instead of looping forever.

    Raises ValueError if batch_dim is smaller than 1.
    """
    if batch_dim < 1:
        raise ValueError("batch_dim must be at least 1, got {}".format(batch_dim))

    while True:
        pending = []
        produced = False

        for label in test_labels:
            pending.append(label)
            if len(pending) == batch_dim:
                batch_imgs = create_image_batch(pending, model_name)
                batch_labels = [item[1] for item in pending]
                yield np.asarray(batch_imgs), np.asarray(batch_labels)
                pending = []
                produced = True

        # Remainder of the pass: fewer than batch_dim samples are left over.
        if pending:
            batch_imgs = create_image_batch(pending, model_name)
            batch_labels = [item[1] for item in pending]
            yield np.asarray(batch_imgs), np.asarray(batch_labels)
            produced = True

        # Nothing to iterate over: terminate rather than spin without yielding.
        if not produced:
            return


def get_image(image_path, model_name, img_size = 128, img_resize = 64, x = 25, y = 45):
    """
    Crops, resizes and normalizes the target image.

    The image is read with OpenCV, cropped to the img_size x img_size window
    whose top-left corner is (x, y), resized to img_resize x img_resize,
    converted from BGR to RGB and scaled to [0., 1.] as float32.

        - If model_name == "Dense", the image is returned as a flattened numpy array
          with dim (img_resize*img_resize*3)
        - otherwise, the image is returned as a numpy array with dim (img_resize, img_resize, 3)

    Raises OSError if the file cannot be read as an image.
    """

    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising;
    # fail here with the offending path instead of a TypeError further down.
    if img is None:
        raise OSError("Could not read image file: {}".format(image_path))

    img = img[y:y+img_size, x:x+img_size]
    img = cv2.resize(img, (img_resize, img_resize))
    img = np.array(img, dtype='float32')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img /= 255.0 # Normalization to [0.,1.]

    if model_name == "Dense" :
        img = img.ravel()
    
    return img


def create_image_batch(labels, model_name, img_dir = '/input/CelebA/img_align_celeba/img_align_celeba/'):
    """
    Returns the list of images corresponding to the given labels.

    Element [0] of every item in labels is taken as the image file name; it
    is joined to img_dir and loaded through get_image with model_name.

    img_dir defaults to the dataset location this function always used, so
    existing calls are unaffected; pass another directory to read the images
    from elsewhere.
    """
    return [
        get_image(os.path.join(img_dir, item[0]), model_name)
        for item in labels
    ]



def convert_batch_to_image_grid(image_batch, dim = 64, rows = 4, cols = 8):
    """
    Tiles a batch of square RGB images into a single image.

    image_batch must hold rows * cols images of shape (dim, dim, 3); any
    array-like is accepted. Images are laid out row by row: image number
    r * cols + c ends up at grid row r, grid column c.

    The defaults (rows = 4, cols = 8) reproduce the 32-image layout this
    function was originally hard-coded for.

    Returns a numpy array of shape (rows * dim, cols * dim, 3).
    Raises ValueError when the batch does not contain exactly
    rows * cols * dim * dim * 3 values.
    """
    batch = np.asarray(image_batch)
    expected = rows * cols * dim * dim * 3
    if batch.size != expected:
        raise ValueError(
            "image_batch has {} values, expected {} ({} x {} images of {}x{}x3)".format(
                batch.size, expected, rows, cols, dim, dim))

    tiles = batch.reshape(rows, cols, dim, dim, 3)
    # Bring the pixel-row axis next to the grid-row axis so that the final
    # reshape merges (grid row, pixel row) and (grid column, pixel column).
    tiles = tiles.transpose(0, 2, 1, 3, 4)
    return tiles.reshape(rows * dim, cols * dim, 3)


def imshow_grid(imgs, model_name, shape=(2, 5), name='default', save=False):
    """
    Plots images in a grid of a given shape.

    imgs:       indexable collection of images; the first shape[0] * shape[1]
                are shown. If fewer are supplied, the remaining cells stay empty.
    model_name: when equal to "Dense", every image is reshaped to (64, 64, 3)
                before being drawn.
    shape:      (rows, columns) of the grid.
    name:       base name of the output file, used only when save is True.
    save:       if True the figure is written to '<name>.png' and cleared,
                otherwise it is shown.
    """
    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=tuple(shape), axes_pad=0.05)
    n_cells = shape[0] * shape[1]
    n_imgs = len(imgs)

    for i in range(n_cells):
        grid[i].axis('off')
        if i >= n_imgs:
            # Not enough images to fill the grid: leave the cell blank.
            continue
        img = imgs[i]
        if model_name == "Dense":
            img = img.reshape(64, 64, 3)
        grid[i].imshow(img)

    if save:
        plt.savefig(str(name) + '.png')
        plt.clf()
    else:
        plt.show()


##########################################
#   Utils to save and read pickle files  #
##########################################

def save_data(file_name, data):
    """
    Pickles data into '<file_name>.pickle'.

    The type of data is printed once the target file has been opened.
    """
    target_path = file_name + '.pickle'
    with open(target_path, 'wb') as handle:
        print(type(data))
        pickle.dump(data, handle)


def read_data(file_name):
    """
    Reads '<file_name>.pickle' and returns its content.

    Objects are unpickled one after another until the end of the file; the
    last one read is returned (for a file written by a single pickle.dump
    that is simply its content). An empty file yields None.

    NOTE: unpickling can execute arbitrary code — only use this on files
    from a trusted source.
    """
    # Bound up front so that an empty file returns None instead of raising
    # UnboundLocalError at the return statement.
    objects = None
    with open(file_name + '.pickle', "rb") as openfile:
        while True:
            try:
                objects = pickle.load(openfile)
            except EOFError:
                break
    return objects


In [4]:
import tensorflow as tf
import timeit