In [None]:
import pandas as pd
import os
# import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf; tf.compat.v1.disable_eager_execution()
from keras import backend as K
from keras.layers import Input, Dense, Conv2D, Conv2DTranspose, Flatten, Lambda, Reshape
from keras.models import Model
from keras.losses import binary_crossentropy
from keras.datasets import mnist
# np.random.seed(25)
# tf.executing_eagerly()

In [None]:
# Load the precomputed nearest-neighbor tables for the digits dataset as NumPy arrays.
# Per the original note these are expected to be 55000 x 25 matrices (25 nearest
# neighbors per training image, plus the matching norms) -- TODO confirm against the files.
def _read_headerless_csv(csv_path):
    """Read a comma-separated file that has no header row and return a NumPy array."""
    return pd.read_csv(csv_path, sep=',', header=None).to_numpy()

# Neighbors taken from the same class as the query image, and their norms
nearest_neighbors_in_class = _read_headerless_csv("Outputs/NearestNeighbors/digits/nearest_neighbors_in_class.csv")
nearest_neighbors_in_class_norms = _read_headerless_csv("Outputs/NearestNeighbors/digits/nearest_neighbors_in_class_norms.csv")
# Neighbors taken from a different class than the query image, and their norms
nearest_neighbors_other_class = _read_headerless_csv("Outputs/NearestNeighbors/digits/nearest_neighbors_other_class.csv")
nearest_neighbors_other_class_norms = _read_headerless_csv("Outputs/NearestNeighbors/digits/nearest_neighbors_other_class_norms.csv")

In [None]:
# Sanity check: show the dimensions of the out-of-class nearest-neighbor table
print(nearest_neighbors_other_class.shape)

In [None]:
####################################################################################
## Dataset selection cell --- Digit-MNIST or Fashion-MNIST (MNIST toy world)      ##
####################################################################################

DATASET = '_DigitMNIST'
# DATASET = '_FashionMNIST'
REGIME = '_TRAINED_' # '_RANDOM_'

if DATASET == '_DigitMNIST':
    # Digit-MNIST: pre-shuffled train / test arrays from Keras
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() #digit_mnist
    print("x_train shape:", x_train.shape, "y_train shape:", y_train.shape)

    # Text label for each class index 0-9
    labels = [str(digit) for digit in range(10)]

else:
    # Any other DATASET value falls through to Fashion-MNIST
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data() #fashion_mnist
    print("x_train shape:", x_train.shape, "y_train shape:", y_train.shape)

    # Text label for each class index 0-9
    labels = [
        "T-shirt", "Trouser", "Pullover", "Dress", "Coat",
        "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
    ]

# Shapes of the full training images and labels as loaded (before any split)
print("x_train shape:", x_train.shape, "y_train shape:", y_train.shape)

# Keep handles on the integer labels. These are captured BEFORE the validation
# split below, so y_train_labels still covers every loaded training label.
y_train_labels, y_test_labels = y_train, y_test

# Number of training and test examples
print(x_train.shape[0], 'train set')
print(x_test.shape[0], 'test set')


# Scale pixel values by 1/255 as float32
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Hold out the first 5000 training examples for validation; the rest stay in train
x_valid, x_train = x_train[:5000], x_train[5000:]
y_valid, y_train = y_train[:5000], y_train[5000:]

# Add a trailing channel axis: (n, 28, 28) -> (n, 28, 28, 1)
w, h = 28, 28
x_train = x_train.reshape((len(x_train), w, h, 1))
x_valid = x_valid.reshape((len(x_valid), w, h, 1))
x_test = x_test.reshape((len(x_test), w, h, 1))

# One-hot encode the validation and test labels (y_train stays as integer labels)
y_valid = tf.keras.utils.to_categorical(y_valid, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Preview one training example; img_index must be a valid row of the post-split x_train
img_index = 0
# y_train holds integer class labels, which double as indices into `labels`
label_index = y_train[img_index]
# Print the numeric label followed by its text name
print("y = " + str(label_index) + " " + labels[label_index])
# Show the selected training image
plt.imshow(x_train[img_index])
plt.show()

In [None]:
# Expose the image arrays under the names used by the model code, with shape
# (no_of_data, 28, 28, 1). x_train / x_test already have the trailing channel axis
# from the dataset cell, so the reshape keeps their layout unchanged.
X_train_new = x_train.reshape(x_train.shape[:3] + (1,))
X_test_new = x_test.reshape(x_test.shape[:3] + (1,))

print(x_train.shape)

In [None]:
############################################# Standard boilerplate ##########################################
# citation: https://becominghuman.ai/using-variational-autoencoder-vae-to-generate-new-images-14328877e88d
#############################################################################################################

img_height   = X_train_new.shape[1]    # 28
img_width    = X_train_new.shape[2]    # 28
num_channels = X_train_new.shape[3]    # 1
input_shape =  (img_height, img_width, num_channels)   # (28,28,1)
latent_dim = 6    # Dimension of the latent space

# Encoder: two stacked stride-2 convolutions (28x28 -> 14x14 -> 7x7 for 28x28 inputs)
encoder_input = Input(shape=input_shape)
encoder_conv = Conv2D(filters=8, kernel_size=3, strides=2,
                padding='same', activation='relu')(encoder_input)
# Chain onto the first convolution's output so both layers are part of the encoder
# (wiring this layer to encoder_input would silently drop the 8-filter layer).
encoder_conv = Conv2D(filters=16, kernel_size=3, strides=2,
                padding='same', activation='relu')(encoder_conv)
encoder = Flatten()(encoder_conv)

# Two linear heads of size latent_dim. `sigma` is consumed as a log-variance:
# compute_latent uses exp(sigma / 2) as the standard deviation.
mu = Dense(latent_dim)(encoder)
sigma = Dense(latent_dim)(encoder)

def compute_latent(x):
    """Sample a latent vector with the reparameterization trick.

    Args:
        x: pair ``(mu, sigma)`` of tensors of shape (batch, latent_dim), where
            ``sigma`` is treated as the log-variance of the latent Gaussian.

    Returns:
        Tensor ``mu + exp(sigma / 2) * eps`` with ``eps`` drawn from a standard
        normal of shape (batch, latent_dim).
    """
    mu, sigma = x
    batch = K.shape(mu)[0]          # dynamic batch size (only known at run time)
    dim = K.int_shape(mu)[1]        # static latent dimensionality
    eps = K.random_normal(shape=(batch,dim))
    return mu + K.exp(sigma/2)*eps

latent_space = Lambda(compute_latent, output_shape=(latent_dim,))([mu, sigma])
# Static shape of the last encoder feature map; the decoder starts from this shape
conv_shape = K.int_shape(encoder_conv)

# Decoder: dense projection back to the encoder feature-map shape, then two
# stride-2 transposed convolutions that undo the encoder's two downsampling steps.
decoder_input = Input(shape=(latent_dim,))
decoder = Dense(conv_shape[1]*conv_shape[2]*conv_shape[3], activation='relu')(decoder_input)
decoder = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(decoder)
decoder_conv = Conv2DTranspose(filters=16, kernel_size=3, strides=2,
                           padding='same', activation='relu')(decoder)
# Chain onto the previous transposed convolution so both upsampling layers are used
# (wiring this layer to `decoder` would silently drop the 16-filter layer).
decoder_conv = Conv2DTranspose(filters=8, kernel_size=3, strides=2,
                           padding='same', activation='relu')(decoder_conv)
# Final stride-1 layer maps to num_channels with sigmoid outputs in [0, 1]
decoder_conv =  Conv2DTranspose(filters=num_channels, kernel_size=3,
                          padding='same', activation='sigmoid')(decoder_conv)

# Wrap the two halves as standalone models. Note the names `encoder` / `decoder`
# are rebound here from intermediate tensors to Model objects.
encoder = Model(encoder_input, latent_space)
decoder = Model(decoder_input, decoder_conv)

# VAE: image -> sampled latent vector -> reconstructed image
vae = Model(encoder_input, decoder(encoder(encoder_input)))

In [None]:
def kl_reconstruction_loss(true, pred):
    """VAE training loss: reconstruction term plus KL divergence term.

    Args:
        true: batch of target images.
        pred: batch of reconstructed images produced by the model.

    Returns:
        Scalar tensor: the mean over the batch of the (unweighted) sum of the
        reconstruction loss and the KL divergence loss.

    Reads the notebook-level tensors ``mu`` and ``sigma`` (``sigma`` acting as a
    log-variance) and the image dimensions ``img_width`` / ``img_height``.
    """
    # Binary cross-entropy over all flattened pixels, scaled by the pixel count per image
    flat_true = K.flatten(true)
    flat_pred = K.flatten(pred)
    reconstruction_loss = binary_crossentropy(flat_true, flat_pred) * img_width * img_height

    # KL divergence between the encoder's Gaussian and a standard normal, per sample
    kl_terms = 1 + sigma - K.square(mu) - K.exp(sigma)
    kl_loss = K.sum(kl_terms, axis=-1) * -0.5

    # The two terms are simply added (no extra weighting), then averaged
    total_loss = reconstruction_loss + kl_loss
    return K.mean(total_loss)

vae.compile(optimizer='adam', loss=kl_reconstruction_loss)

In [None]:
# Train the VAE as an autoencoder: the images serve as both input and target.
# NOTE(review): validation uses the test images (X_test_new); the x_valid split
# created in the dataset cell is not used here -- confirm this is intended.
history = vae.fit(
    x=X_train_new,
    y=X_train_new,
    epochs=20,
    batch_size=32,
    validation_data=(X_test_new, X_test_new),
)

In [None]:
# Training and validation loss per epoch, as recorded by vae.fit.
# Each curve is labelled so the two can be told apart in the figure.
plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()

In [None]:
#########################################################################################################
##### ###################for generating morphs between vectors in the latent space ######################
#########################################################################################################

#==> generalize this for arbitrary dimensionality of z space
#==> add titles, and path vectors through the space
#==> Don't use linspace!!!!

def display_image_sequence(s1,s2,s3,s4,s5,s6,e1,e2,e3,e4,e5,e6,no_of_imgs,s_im,e_im):
    """Decode and plot a linear path between two points of a 6-D latent space.

    Args:
        s1..s6: coordinates of the starting latent point.
        e1..e6: coordinates of the ending latent point.
        no_of_imgs: number of evenly spaced latent points (and decoded frames)
            along the path, endpoints included.
        s_im: image drawn in the first panel, before the decoded frames.
        e_im: image drawn in the last panel, after the decoded frames.

    Uses the notebook-level ``decoder`` model to turn latent points into images
    and shows a single row of ``no_of_imgs + 2`` panels. Returns nothing.
    """
    start_coords = (s1, s2, s3, s4, s5, s6)
    end_coords = (e1, e2, e3, e4, e5, e6)

    # One column per latent coordinate, one row per frame along the path
    coord_columns = [
        np.linspace(start, stop, no_of_imgs)[:, np.newaxis]
        for start, stop in zip(start_coords, end_coords)
    ]
    new_points = np.hstack(coord_columns)

    # Decode the latent path and drop the trailing channel axis for plotting
    decoded = decoder.predict(new_points)
    frames = decoded.reshape(decoded.shape[:3])

    # Layout: [start image] [decoded frames ...] [end image]
    fig, axes = plt.subplots(ncols=no_of_imgs+2, sharex=False,
                             sharey=True, figsize=(20, 7))
    axes[0].imshow(s_im, cmap='gray')
    for panel_idx, frame in enumerate(frames[:no_of_imgs], start=1):
        panel = axes[panel_idx]
        panel.imshow(frame, cmap='gray')
        # Axis ticks are hidden on the decoded frames only
        panel.get_xaxis().set_visible(False)
        panel.get_yaxis().set_visible(False)
    axes[no_of_imgs + 1].imshow(e_im, cmap='gray')
    plt.show()

In [None]:
CLASS = 2
# Ranking file for the chosen class; its entries are used below as row indices into
# X_train_new and the nearest-neighbor tables.
rank_table = pd.read_csv("Outputs/Seed_diff_ranks/ranked_" + str(labels[CLASS]) + "s.csv", sep=',', header=None)
rank = rank_table.to_numpy().squeeze()
print(rank.shape)

# Preview the example at rank position 5000 next to its first out-of-class neighbor.
# Per the original notes, smaller rank positions are more confusable -- TODO confirm
# the ordering used when the ranking file was produced.
image_idx = rank[5000]
NN_other_class = nearest_neighbors_other_class[image_idx, 0]
NN_other_class_norm = nearest_neighbors_other_class_norms[image_idx, 0]
orig = X_train_new[image_idx]

plt.imshow(orig, cmap='gray')
plt.show()
plt.imshow(X_train_new[NN_other_class], cmap='gray')
plt.show()

# Latent codes of the ranked images, in rank order. The encoder model outputs the
# sampled latent vector, so these codes include sampling noise.
origs = encoder.predict(X_train_new[rank])
# Latent codes of each ranked image's first out-of-class neighbor, same order
first_other_class_neighbors = nearest_neighbors_other_class[rank, 0]
knns_out = encoder.predict(X_train_new[first_other_class_neighbors])

# Number of frames per morph
fn = 20
# One morph figure per rank position from 5000 through the end of the ranking
for rnk in range(5000, len(rank)):
    image_idx = rank[rnk]
    NN_other_class = nearest_neighbors_other_class[image_idx, 0]
    NN_other_class_norm = nearest_neighbors_other_class_norms[image_idx, 0]
    # Pixel-space endpoints: the ranked image and its out-of-class neighbor
    orig_im = X_train_new[image_idx]
    knn_im = X_train_new[NN_other_class]
    # Linear interpolation between the two latent codes, decoded and displayed;
    # the first six latent coordinates of each code are passed positionally.
    display_image_sequence(*origs[rnk][:6], *knns_out[rnk][:6], fn, orig_im, knn_im)