In [None]:
from VisionEngine.utils.config import process_config
from VisionEngine.utils import factory

import os
from PIL import Image
from itertools import product
from dotenv import load_dotenv
from pathlib import Path

import numpy as np
import scipy

import numba

import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from sklearn.metrics import pairwise_distances
from tensorflow.keras.applications import VGG16
from tensorflow.keras import Model

from openTSNE import TSNE
from openTSNE.sklearn import TSNE as sklTSNE
from openTSNE.callbacks import ErrorLogger

import math

import tensorflow as tf
from tensorflow.keras.layers import Flatten

In [None]:
class LikeLihoodLayer(tf.keras.layers.Layer):
    """Keras layer that turns a reconstruction error into a Gaussian-shaped score.

    The layer is called on a two-element sequence ``[inputs, outputs]`` and
    returns ``exp(-0.5 * mse**2) / sqrt(2 * pi)``, where ``mse`` is the negated
    result of ``tf.losses.mean_squared_error(inputs, outputs)``.
    The layer owns no weights.
    """

    def __init__(self, **kwargs):
        super(LikeLihoodLayer, self).__init__(**kwargs)
        # Kept for reference only; nothing in this class reads it.
        self.model_input_shape = [256, 256, 3]

    def build(self, input_shape):
        # Nothing to create; defer to the base class.
        super(LikeLihoodLayer, self).build(input_shape)

    def call(self, layer_inputs, **kwargs):
        """Return the score for ``layer_inputs = [inputs, outputs]``."""
        inputs, outputs = layer_inputs
        # The sign flip does not affect the result because the value is squared below.
        mse = - tf.losses.mean_squared_error(inputs, outputs)
        out = 1./(tf.sqrt(2.*math.pi))*tf.exp(-.5*(mse)**2.)
        # The original had a second, unreachable `return [y_true, y_pred]` here that
        # referenced undefined names; it has been removed.
        return out

    def compute_output_shape(self, input_shape):
        # NOTE(review): call() returns the result of a mean-squared-error reduction,
        # so the true output shape presumably drops the last axis of input_shape[0]
        # -- TODO confirm and adjust if anything relies on this method.
        return input_shape[0]

    def get_config(self):
        """Return the base layer config (this layer adds no entries of its own)."""
        config = {}
        base_config = \
            super(LikeLihoodLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

def sample_likelihood(x):
    """Score ``x`` with a ``LikeLihoodLayer`` attached to the global ``model``.

    A temporary Keras model is built from ``model.model``'s input to the layer's
    output, where the layer receives the flattened input and the flattened
    output of ``model.model``. Returns whatever ``predict(x)`` yields for it.
    """
    autoencoder = model.model
    flat_input = Flatten()(autoencoder.input)
    flat_reconstruction = Flatten()(autoencoder.output)
    score = LikeLihoodLayer()([flat_input, flat_reconstruction])
    return Model(autoencoder.input, score).predict(x)

def embed_images(x):
    """Return the outputs of the four ``variational_layer*`` layers for ``x``.

    Builds a Keras model from the global ``model``'s inputs to those four layer
    outputs (in the order listed below) and runs ``predict`` on ``x``.
    """
    layer_names = (
        'variational_layer',
        'variational_layer_1',
        'variational_layer_2',
        'variational_layer_3',
    )
    latent_outputs = [model.model.get_layer(name).output for name in layer_names]
    return Model(model.model.inputs, latent_outputs).predict(x)

def reconstruct_images(x):
    """Run ``x`` through the global ``model``'s Keras model and return its prediction."""
    autoencoder = model.model
    return autoencoder.predict(x)

def imscatter(x, y, image, ax=None, zoom=1):
    """Place ``image[i]`` as a thumbnail at data coordinates ``(x[i], y[i])``.

    Args:
        x, y: coordinates; scalars are promoted to 1-D arrays.
        image: passed to ``plt.imread`` first; if that raises ``TypeError`` the
            value is used as-is (taken to be already-loaded image data). It is
            then indexed by point position.
        ax: target axes; defaults to the current axes.
        zoom: scale factor forwarded to ``OffsetImage``.

    Returns:
        The list of artists added to the axes, one per point.
    """
    axes = plt.gca() if ax is None else ax

    try:
        image = plt.imread(image)
    except TypeError:
        # Likely already an array...
        pass

    x, y = np.atleast_1d(x, y)

    artists = []
    for idx, point in enumerate(zip(x, y)):
        thumbnail = OffsetImage(image[idx], zoom=zoom)
        box = AnnotationBbox(thumbnail, point, xycoords='data', frameon=False)
        artists.append(axes.add_artist(box))

    # Added artists do not extend the data limits by themselves, so do it explicitly.
    axes.update_datalim(np.column_stack([x, y]))
    axes.autoscale()
    axes.grid(False)
    return artists

def plot_im(img):
    """Return ``img`` prepared for display.

    When the global ``config.model.last_activation`` is ``'tanh'`` the image is
    rescaled with ``img * 0.5 + 0.5`` (which maps -1 to 0 and 1 to 1); for any
    other activation the image is returned unchanged.

    The original computed the rescaled value but discarded it and returned the
    untouched input.
    """
    if config.model.last_activation == 'tanh':
        return img * 0.5 + 0.5
    return img

In [None]:
pwd

In [None]:
env_path = Path('../') / '.env'
load_dotenv(dotenv_path=env_path)

**Butterflies**

In [None]:
checkpoint_path = "/home/etheredge/Workspace/VisionEngine/checkpoints/butterflies_nouveau/2020-220-17/butterflies_nouveau.hdf5"

In [None]:
config_file = "/home/etheredge/Workspace/VisionEngine/VisionEngine/configs/butterfly_nouveau_config.json"
config = process_config(config_file)

In [None]:
model = factory.create(
            "VisionEngine.models."+config.model.name
            )(config)

In [None]:
model.load(checkpoint_path)

In [None]:
config.data_loader.use_generated = False
config.data_loader.use_real = True

In [None]:
data_loader = factory.create(
            "VisionEngine.data_loaders."+config.data_loader.name
            )(config)

In [None]:
z = embed_images(data_loader.get_test_data())
# lh = sample_likelihood(data_loader.get_test_data())

In [None]:
images_ = iter(data_loader.get_test_data())

In [None]:
# Show an input (left column) next to its reconstruction (right column) for
# element ID of each of the next three items yielded by `images_`.
# Uses the builtin next() rather than the `.next()` method, which is not part
# of the Python 3 iterator protocol, and replaces three pasted stanzas by a loop.
# NOTE(review): `[0]` is presumably the image part of an (images, ...) item -- confirm.
ID = 2
for row in range(3):
    images = next(images_)[0]
    x_hat = reconstruct_images(images)
    plt.subplot(3, 2, 2 * row + 1)
    plt.imshow(plot_im(images[ID]))
    plt.subplot(3, 2, 2 * row + 2)
    plt.imshow(plot_im(x_hat[ID]))

In [None]:
np.concatenate([z[0],z[1],z[2],z[3]], axis=1).shape

In [None]:
# `z` is the list of four latent arrays returned by embed_images, so len(z) is 4,
# not the number of samples. The learning rate is meant to scale with the sample
# count (as the later per-block cell does with len(z[i])/12), hence len(z[0]).
tsne_learning_rate = len(z[0]) / 12

# Joint embedding of all four latent blocks, concatenated feature-wise.
vision_engine_embedding = TSNE(callbacks=ErrorLogger(), n_jobs=8, exaggeration=4, learning_rate=tsne_learning_rate).fit(np.concatenate([z[0], z[1], z[2], z[3]], axis=1))
# One embedding per latent block.
h1 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[0])
h2 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[1])
h3 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[2])
h4 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[3])

In [None]:
labels = []
images = []
for image, label in data_loader.get_plot_data():
    labels.append(label.numpy().decode('utf8'))
    images.append(image.numpy().astype('uint8'))
images = np.stack(images)
labels = np.array(labels)

In [None]:
plt.figure(figsize=(40,10))

classnames, indices = np.unique( labels, return_inverse=True)
N = len(classnames)
cmap = plt.cm.rainbow
bounds = np.linspace(0,N,N+1)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

plt.subplot(141)
embedding = h1
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)

plt.subplot(142)
embedding = h2
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)
plt.subplot(143)
embedding = h3
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)
plt.subplot(144)
embedding = h4
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)

In [None]:
plt.figure(figsize=(10,10))
classnames, indices = np.unique( labels, return_inverse=True)
N = len(classnames)
cmap = plt.cm.rainbow
bounds = np.linspace(0,N,N+1)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

embedding = vision_engine_embedding

plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)

In [None]:
plt.figure(figsize=(80,80))
# plt.scatter(vision_engine_embedding[:,0],vision_engine_embedding[:,1],c=lh, s=10000)
# plt.colorbar()
embedding = vision_engine_embedding
imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.50);
# fig = plt.gcf()
# fig.savefig('what_real.png')

**GAN Generated Guppies**

In [None]:
checkpoint_path = "/home/etheredge/Workspace/VisionEngine/checkpoints/guppy_nouveau/2020-223-20/guppy_nouveau.hdf5"

In [None]:
config_file = "/home/etheredge/Workspace/VisionEngine/VisionEngine/configs/guppy_nouveau_config.json"
config = process_config(config_file)

In [None]:
# NOTE(review): `model` is still the instance created from the butterfly config
# earlier in the notebook; only `config` and `checkpoint_path` were reassigned.
# Confirm the guppy config describes the same architecture, otherwise re-create
# the model with factory.create(...)(config) before loading these weights.
model.load(checkpoint_path)

In [None]:
config.data_loader.use_generated = False
config.data_loader.use_real = True

In [None]:
data_loader = factory.create(
            "VisionEngine.data_loaders."+config.data_loader.name
            )(config)

In [None]:
z = embed_images(data_loader.get_test_data())
# lh = sample_likelihood(data_loader.get_test_data())

In [None]:
images_ = iter(data_loader.get_test_data())

In [None]:
# Show an input (left column) next to its reconstruction (right column) for
# element ID of each of the next three items yielded by `images_`.
# Uses the builtin next() rather than the `.next()` method, which is not part
# of the Python 3 iterator protocol, and replaces three pasted stanzas by a loop.
# NOTE(review): `[0]` is presumably the image part of an (images, ...) item -- confirm.
ID = 2
for row in range(3):
    images = next(images_)[0]
    x_hat = reconstruct_images(images)
    plt.subplot(3, 2, 2 * row + 1)
    plt.imshow(plot_im(images[ID]))
    plt.subplot(3, 2, 2 * row + 2)
    plt.imshow(plot_im(x_hat[ID]))

In [None]:
np.concatenate([z[0],z[1],z[2],z[3]], axis=1).shape

In [None]:
# `z` is the list of four latent arrays returned by embed_images, so len(z) is 4,
# not the number of samples. The learning rate is meant to scale with the sample
# count (as the later per-block cell does with len(z[i])/12), hence len(z[0]).
tsne_learning_rate = len(z[0]) / 12

# Joint embedding of all four latent blocks, concatenated feature-wise.
vision_engine_embedding = TSNE(callbacks=ErrorLogger(), n_jobs=8, exaggeration=4, learning_rate=tsne_learning_rate).fit(np.concatenate([z[0], z[1], z[2], z[3]], axis=1))
# One embedding per latent block.
h1 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[0])
h2 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[1])
h3 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[2])
h4 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[3])

In [None]:
labels = []
images = []
for image, label in data_loader.get_plot_data():
    labels.append(label.numpy().decode('utf8'))
    images.append(image.numpy().astype('uint8'))
images = np.stack(images)
labels = np.array(labels)

In [None]:
plt.figure(figsize=(40,10))

classnames, indices = np.unique( labels, return_inverse=True)
N = len(classnames)
cmap = plt.cm.rainbow
bounds = np.linspace(0,N,N+1)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

plt.subplot(141)
embedding = h1
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)

plt.subplot(142)
embedding = h2
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)
plt.subplot(143)
embedding = h3
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)
plt.subplot(144)
embedding = h4
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)

In [None]:
plt.figure(figsize=(10,10))
classnames, indices = np.unique( labels, return_inverse=True)
N = len(classnames)
cmap = plt.cm.rainbow
bounds = np.linspace(0,N,N+1)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

embedding = vision_engine_embedding

plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)

In [None]:
plt.figure(figsize=(80,80))
# plt.scatter(vision_engine_embedding[:,0],vision_engine_embedding[:,1],c=lh, s=10000)
# plt.colorbar()
embedding = vision_engine_embedding
imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.50);
# fig = plt.gcf()
# fig.savefig('what_real.png')

**GAN Generated Guppies : $\sigma = 1e-2$**

In [None]:
checkpoint_path = "/home/etheredge/Workspace/VisionEngine/checkpoints/guppy_nouveau/2020-224-14/guppy_nouveau.hdf5"

In [None]:
config_file = "/home/etheredge/Workspace/VisionEngine/VisionEngine/configs/guppy_nouveau_config_singlemmd.json"
config = process_config(config_file)

In [None]:
# NOTE(review): `model` is still the instance created from the butterfly config
# earlier in the notebook; only `config` and `checkpoint_path` were reassigned.
# Confirm the single-MMD guppy config describes the same architecture, otherwise
# re-create the model with factory.create(...)(config) before loading these weights.
model.load(checkpoint_path)

In [None]:
config.data_loader.use_generated = False
config.data_loader.use_real = True

In [None]:
data_loader = factory.create(
            "VisionEngine.data_loaders."+config.data_loader.name
            )(config)

In [None]:
z = embed_images(data_loader.get_test_data())
# lh = sample_likelihood(data_loader.get_test_data())

In [None]:
images_ = iter(data_loader.get_test_data())

In [None]:
# Show an input (left column) next to its reconstruction (right column) for
# element ID of each of the next three items yielded by `images_`.
# Uses the builtin next() rather than the `.next()` method, which is not part
# of the Python 3 iterator protocol, and replaces three pasted stanzas by a loop.
# NOTE(review): `[0]` is presumably the image part of an (images, ...) item -- confirm.
ID = 2
for row in range(3):
    images = next(images_)[0]
    x_hat = reconstruct_images(images)
    plt.subplot(3, 2, 2 * row + 1)
    plt.imshow(plot_im(images[ID]))
    plt.subplot(3, 2, 2 * row + 2)
    plt.imshow(plot_im(x_hat[ID]))

In [None]:
np.concatenate([z[0],z[1],z[2],z[3]], axis=1).shape

In [None]:
# `z` is the list of four latent arrays returned by embed_images, so len(z) is 4,
# not the number of samples. The learning rate is meant to scale with the sample
# count (as the later per-block cell does with len(z[i])/12), hence len(z[0]).
tsne_learning_rate = len(z[0]) / 12

# Joint embedding of all four latent blocks, concatenated feature-wise.
vision_engine_embedding = TSNE(callbacks=ErrorLogger(), n_jobs=8, exaggeration=4, learning_rate=tsne_learning_rate).fit(np.concatenate([z[0], z[1], z[2], z[3]], axis=1))
# One embedding per latent block.
h1 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[0])
h2 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[1])
h3 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[2])
h4 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[3])

In [None]:
labels = []
images = []
for image, label in data_loader.get_plot_data():
    labels.append(label.numpy().decode('utf8'))
    images.append(image.numpy().astype('uint8'))
images = np.stack(images)
labels = np.array(labels)

In [None]:
plt.figure(figsize=(40,10))

classnames, indices = np.unique( labels, return_inverse=True)
N = len(classnames)
cmap = plt.cm.rainbow
bounds = np.linspace(0,N,N+1)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

plt.subplot(141)
embedding = h1
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)

plt.subplot(142)
embedding = h2
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)
plt.subplot(143)
embedding = h3
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)
plt.subplot(144)
embedding = h4
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)

In [None]:
plt.figure(figsize=(10,10))
classnames, indices = np.unique( labels, return_inverse=True)
N = len(classnames)
cmap = plt.cm.rainbow
bounds = np.linspace(0,N,N+1)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

embedding = vision_engine_embedding

plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)

In [None]:
plt.figure(figsize=(80,80))
# plt.scatter(vision_engine_embedding[:,0],vision_engine_embedding[:,1],c=lh, s=10000)
# plt.colorbar()
embedding = vision_engine_embedding
imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.50);
# fig = plt.gcf()
# fig.savefig('what_real.png')

**Real Guppies - Finetune**

In [None]:
checkpoint_path = "/home/etheredge/Workspace/VisionEngine/checkpoints/guppy_nouveau_finetune/2020-224-11/guppy_nouveau_finetune.hdf5"

In [None]:
config_file = "/home/etheredge/Workspace/VisionEngine/VisionEngine/configs/guppy_nouveau_finetune_config.json"
config = process_config(config_file)

In [None]:
# NOTE(review): `model` is still the instance created from the butterfly config
# earlier in the notebook; only `config` and `checkpoint_path` were reassigned.
# Confirm the finetune config describes the same architecture, otherwise
# re-create the model with factory.create(...)(config) before loading these weights.
model.load(checkpoint_path)

In [None]:
config.data_loader.use_generated = False
config.data_loader.use_real = True

In [None]:
data_loader = factory.create(
            "VisionEngine.data_loaders."+config.data_loader.name
            )(config)

In [None]:
z = embed_images(data_loader.get_test_data())
# lh = sample_likelihood(data_loader.get_test_data())

In [None]:
images_ = iter(data_loader.get_test_data())

In [None]:
# Show an input (left column) next to its reconstruction (right column) for
# element ID of each of the next three items yielded by `images_`.
# Uses the builtin next() rather than the `.next()` method, which is not part
# of the Python 3 iterator protocol, and replaces three pasted stanzas by a loop.
# NOTE(review): `[0]` is presumably the image part of an (images, ...) item -- confirm.
ID = 2
for row in range(3):
    images = next(images_)[0]
    x_hat = reconstruct_images(images)
    plt.subplot(3, 2, 2 * row + 1)
    plt.imshow(plot_im(images[ID]))
    plt.subplot(3, 2, 2 * row + 2)
    plt.imshow(plot_im(x_hat[ID]))

In [None]:
np.concatenate([z[0],z[1],z[2],z[3]], axis=1).shape

In [None]:
# `z` is the list of four latent arrays returned by embed_images, so len(z) is 4,
# not the number of samples. The learning rate is meant to scale with the sample
# count (as the later per-block cell does with len(z[i])/12), hence len(z[0]).
tsne_learning_rate = len(z[0]) / 12

# Joint embedding of all four latent blocks, concatenated feature-wise.
vision_engine_embedding = TSNE(callbacks=ErrorLogger(), n_jobs=8, exaggeration=4, learning_rate=tsne_learning_rate).fit(np.concatenate([z[0], z[1], z[2], z[3]], axis=1))
# One embedding per latent block.
h1 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[0])
h2 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[1])
h3 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[2])
h4 = TSNE(callbacks=ErrorLogger(), exaggeration=4, learning_rate=tsne_learning_rate, n_jobs=8).fit(z[3])

In [None]:
labels = []
images = []
for image, label in data_loader.get_plot_data():
    labels.append(label.numpy().decode('utf8'))
    images.append(image.numpy().astype('uint8'))
images = np.stack(images)
labels = np.array(labels)

In [None]:
plt.figure(figsize=(40,10))

classnames, indices = np.unique( labels, return_inverse=True)
N = len(classnames)
cmap = plt.cm.rainbow
bounds = np.linspace(0,N,N+1)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

plt.subplot(141)
embedding = h1
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)

plt.subplot(142)
embedding = h2
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)
plt.subplot(143)
embedding = h3
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)
plt.subplot(144)
embedding = h4
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)

In [None]:
plt.figure(figsize=(10,10))
classnames, indices = np.unique( labels, return_inverse=True)
N = len(classnames)
cmap = plt.cm.rainbow
bounds = np.linspace(0,N,N+1)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

embedding = vision_engine_embedding

plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
        c=indices, cmap=cmap, norm=norm, s=400)

In [None]:
plt.figure(figsize=(80,80))
# plt.scatter(vision_engine_embedding[:,0],vision_engine_embedding[:,1],c=lh, s=10000)
# plt.colorbar()
embedding = h1
imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.50);
# fig = plt.gcf()
# fig.savefig('what_real.png')

In [None]:
plt.figure(figsize=(80,80))
# plt.scatter(vision_engine_embedding[:,0],vision_engine_embedding[:,1],c=lh, s=10000)
# plt.colorbar()
embedding = h1
imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.50);
fig = plt.gcf()
fig.savefig('what_h1_real.png')

In [None]:
plt.figure(figsize=(80,80))
# plt.scatter(vision_engine_embedding[:,0],vision_engine_embedding[:,1],c=lh, s=10000)
# plt.colorbar()
embedding = h3
imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.50);
fig = plt.gcf()
fig.savefig('what_h3.png')

In [None]:
plt.figure(figsize=(80,80))
# plt.scatter(vision_engine_embedding[:,0],vision_engine_embedding[:,1],c=lh, s=10000)
# plt.colorbar()
embedding = h4
imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.50);
fig = plt.gcf()
fig.savefig('what_h4_real.png')

In [None]:
# `lh` was never assigned on a fresh kernel: every `lh = sample_likelihood(...)`
# line earlier in the notebook is commented out. Compute it here so this cell
# no longer raises NameError.
# NOTE(review): assumes sample_likelihood yields one score per test sample, in
# the same order as `images` -- confirm before trusting the ranking.
lh = sample_likelihood(data_loader.get_test_data())
# Indices of the samples sorted from highest to lowest score.
ranked_lh_args = np.argsort(lh)[::-1]

In [None]:
lh[ranked_lh_args[:10]]

In [None]:
plt.imshow(images[ranked_lh_args[6]])

In [None]:
pwd

In [None]:
# Name of the directory that holds the checkpoint file (e.g. '2020-224-11').
# For the checkpoint paths used in this notebook this equals
# checkpoint_path.split('/')[7], without depending on the path depth.
run_id = Path(checkpoint_path).parent.name
image_output_folder = 'report_wppvae_gens/figures/images/{}'.format(run_id)
# The original template was 'panels{}' (no separator), which glued the run id
# onto the folder name instead of nesting it like the images folder.
plot_output_folder = 'report_wppvae_gens/figures/panels/{}'.format(run_id)
n_latents = 4   # number of latent blocks fed to the decoder
latent_size = 10  # dimensions per latent block
Path(image_output_folder).mkdir(parents=True, exist_ok=True)
Path(plot_output_folder).mkdir(parents=True, exist_ok=True)

In [None]:
pwd

In [None]:
np.max(z[1])

In [None]:
def make_rand_samples(model, n_samples=9, num_steps=300, mu=0., sigma=1.):
    """Save grids of decoder outputs while re-sampling one latent block at a time.

    For each latent block ``z`` (``n_latents`` of them) and each of ``num_steps``
    steps, block ``z`` is redrawn from a normal distribution while the other
    blocks keep their current values; the decoded images are tiled into one PNG
    named ``z{z}_{t:03d}.png`` under
    ``<image_output_folder>/explore_latents/random_normal/frames``.

    Args:
        model: Keras model exposing a layer named ``'decoder'``.
        n_samples: images per frame; tiled row by row on a square grid.
        num_steps: frames written per latent block.
        mu: mean of every latent dimension.
        sigma: value placed on the covariance diagonal (i.e. used as a variance).

    Reads the globals ``image_output_folder``, ``latent_size`` and ``n_latents``.

    Changes from the original: the canvas is sized from ``n_samples`` instead of
    a fixed 3x3; non-square ``n_samples`` no longer raises IndexError (the grid
    is rounded up); each latent block starts from its own draw rather than one
    draw repeated via list multiplication.
    """
    output_folder = os.path.join(image_output_folder, 'explore_latents/random_normal/frames')
    Path(output_folder).mkdir(parents=True, exist_ok=True)

    mean = [mu] * latent_size
    cov = np.diag([sigma] * latent_size)

    def draw():
        # One (n_samples, latent_size) batch for a single latent block.
        return np.random.multivariate_normal(mean, cov, n_samples)

    sample = [draw() for _ in range(n_latents)]

    tile = 256  # decoder output is reshaped to tile x tile x 3 below
    grid = int(np.ceil(np.sqrt(n_samples)))  # smallest square grid holding n_samples tiles
    locs = list(product(range(grid), range(grid)))
    decoder = model.get_layer('decoder')

    for z in range(n_latents):
        for t in range(num_steps):
            sample[z] = draw()
            generated = decoder.predict(sample, batch_size=10)
            generated = generated.reshape((n_samples, tile, tile, 3))
            image_container = Image.new('RGB', (tile * grid, tile * grid))
            for i in range(n_samples):
                j, k = locs[i]  # (row, column) of tile i
                img = (255 * np.array(generated[i])).astype(np.uint8)
                image_container.paste(Image.fromarray(img), (k * tile, j * tile))
            image_container.save(os.path.join(output_folder, 'z{}_{:03d}.png'.format(z, t)))


def make_traversal_from_zeros(model, n_samples=1, num_steps=11):
    """Save one traversal sheet per latent block, starting from all-zero latents.

    For latent block ``z_i`` the sheet has one row per latent dimension and one
    column per step; each tile is the decoder output when that single dimension
    is set to a value from ``np.linspace(-3, 3, num_steps)`` and every other
    latent value is zero. Sheets are written as ``z{z_i}.png`` under
    ``<image_output_folder>/explore_latents/traversal``.

    Args:
        model: Keras model exposing a layer named ``'decoder'``.
        n_samples: unused; kept for interface compatibility.
        num_steps: number of traversal values (columns).

    Reads the globals ``image_output_folder``, ``latent_size`` and ``n_latents``.

    Fix: the latent vectors were built with ``np.array([[0] * latent_size])``,
    an integer array, so assigning the fractional multipliers truncated them
    (e.g. -2.4 became -2). They are now float arrays.
    """
    output_folder = os.path.join(image_output_folder, 'explore_latents/traversal')
    Path(output_folder).mkdir(parents=True, exist_ok=True)
    multipliers = np.linspace(-3,3,num=num_steps)
    decoder = model.get_layer('decoder')

    for z_i in range(n_latents):
        image_container = Image.new('RGB', (256*num_steps,256*latent_size))
        for z_i_j in range(latent_size):
            for s in range(num_steps):
                # Fresh float zero vectors every step so earlier edits do not accumulate.
                sample = [np.zeros((1, latent_size)) for _ in range(n_latents)]
                sample[z_i][0][z_i_j] = multipliers[s]
                generated = decoder.predict(sample, batch_size=1)
                generated = generated.reshape((256, 256,3))
                img = (255 * np.array(generated)).astype(np.uint8)
                image_container.paste(Image.fromarray(img), (s*256, z_i_j*256))
        image_container.save(os.path.join(output_folder,'z{}.png'.format(z_i)))


def make_traversal_from_sample(model, z, n_samples=1, num_steps=11, sample_id=0):
    """Save one traversal sheet per latent block, starting from an encoded sample.

    The starting point is entry ``sample_id`` of each array in ``z``. For latent
    block ``b`` the sheet has one row per latent dimension and one column per
    step; each tile is the decoder output when that single dimension is
    overwritten with a value from ``np.linspace(-3, 3, num_steps)``. Sheets are
    written as ``{sample_id}sample{b}.png`` under
    ``<image_output_folder>/explore_latents/traversal``.

    ``n_samples`` is not used. Reads the globals ``image_output_folder`` and
    ``latent_size``.
    """
    sheet_dir = os.path.join(image_output_folder, 'explore_latents/traversal')
    Path(sheet_dir).mkdir(parents=True, exist_ok=True)
    sweep = np.linspace(-3, 3, num=num_steps)
    anchor = [block[sample_id] for block in z]

    for block_idx in range(4):
        canvas = Image.new('RGB', (256*num_steps, 256*latent_size))
        # Rows are latent dimensions, columns are sweep steps.
        for row, col in product(range(latent_size), range(num_steps)):
            # Rebuild (and thereby copy) the anchor latents for every tile.
            latents = [np.array([anchor[k]]) for k in range(4)]
            latents[block_idx][0][row] = sweep[col]
            decoded = model.get_layer('decoder').predict(latents, batch_size=1)
            decoded = decoded.reshape((256, 256, 3))
            pixels = (255 * np.array(decoded)).astype(np.uint8)
            canvas.paste(Image.fromarray(pixels.astype('uint8')), (col*256, row*256))
        canvas.save(os.path.join(sheet_dir, '{}sample{}.png'.format(sample_id, block_idx)))


In [None]:
make_rand_samples(model.model)

In [None]:
make_traversal_from_sample(model.model, z, sample_id=1350)

In [None]:
make_traversal_from_zeros(model.model)

In [None]:
# One t-SNE map per latent block; each uses learning rate = len(block) / 12.
(vision_engine_embedding1,
 vision_engine_embedding2,
 vision_engine_embedding3,
 vision_engine_embedding4) = [
    TSNE(callbacks=ErrorLogger(), n_jobs=-1, learning_rate=len(block)/12, exaggeration=4).fit(np.array(block))
    for block in (z[0], z[1], z[2], z[3])
]

In [None]:
plt.figure(figsize=(80,40))

classnames, indices = np.unique( labels, return_inverse=True)
N = len(classnames)
cmap = plt.cm.rainbow
bounds = np.linspace(0,N,N+1)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

plt.subplot(241)
embedding = vision_engine_embedding1
imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.40);


plt.subplot(242)
embedding = vision_engine_embedding2
imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.40);

plt.subplot(243)
embedding = vision_engine_embedding3
imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.40);

plt.subplot(244)
embedding = vision_engine_embedding4
imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.40);

plt.subplot(245)
embedding = vision_engine_embedding1
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.5,
        c=indices, cmap=cmap, norm=norm, s=400)

plt.subplot(246)
embedding = vision_engine_embedding2
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.5,
        c=indices, cmap=cmap, norm=norm, s=400)

plt.subplot(247)
embedding = vision_engine_embedding3
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.5,
        c=indices, cmap=cmap, norm=norm, s=400)

plt.subplot(248)
embedding = vision_engine_embedding4
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.5,
        c=indices, cmap=cmap, norm=norm, s=400)

# fig = plt.gcf()
# fig.savefig(os.path.join(
#     plot_output_folder,
#     'Zs_{}_real{}_gen{}.pdf'.format(
#         checkpoint_path.split('/')[7],
#         config.data_loader.use_real,
#         config.data_loader.use_generated
#     )
# )
#            )
# fig.savefig(os.path.join(
#     plot_output_folder,
#     'Zs_{}_real{}_gen{}.png'.format(
#         checkpoint_path.split('/')[7],
#         config.data_loader.use_real,
#         config.data_loader.use_generated
#     )
# )
#            )
# plt.clf()

In [None]:
def make_perceptual_loss_model(input_shape, layers=[13]):
    """Build a frozen VGG16 feature extractor.

    Args:
        input_shape: input shape passed to ``VGG16`` (no classification head).
        layers: indices into ``VGG16(...).layers`` whose outputs become the
            model outputs. The default list is only read, never mutated.

    Returns:
        A Keras ``Model`` from the VGG16 inputs to the selected layer outputs,
        with ImageNet weights and every layer marked non-trainable.
    """
    vgg = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)

    vgg.trainable = False
    for vgg_layer in vgg.layers:
        vgg_layer.trainable = False

    feature_maps = [vgg.layers[index].output for index in layers]
    return Model(vgg.inputs, feature_maps)

# @numba.jit(nopython=True, parallel=True)
# def calculate_perceptual_distances(X):
#     dists = np.zeros((989,989))
#     for layer in X:
#         layer.shape
#         for i in range(layer.shape[0]):
#             for j in range(layer.shape[0]):
#                 shape = (layer[i].shape[0]*layer[i].shape[1]*layer[i].shape[2])
#                 sqr = np.square(layer[i] - layer[j])
#                 sm = np.sum(sqr)
#                 val = sm / shape
#                 dists[i,j] =+ val
#     return dists


# NOTE(review): parallel=True is requested but the loops below use plain
# `range`, not `numba.prange` -- presumably nothing is actually parallelised;
# confirm against the numba documentation.
@numba.jit(nopython=True, parallel=True)
def calculate_perceptual_distances(X):
    """Return an array of norms of the slices ``X[i][j]``.

    The result has shape ``(len(X[0]), len(X))`` and entry ``[j, i]`` is
    ``np.linalg.norm`` of ``X[i][j]`` flattened to 1-D. Despite the name, no
    pairwise differences between items are computed here.

    NOTE(review): which axis of ``X`` is the sample axis depends on what the
    caller passes (a single array vs. a list of per-layer arrays) -- verify
    that the output rows really correspond to samples.
    """
    norm_dists = np.zeros((len(X[0]),len(X)))
    for i in range(len(X)):
        for j in range(len(X[0])):
            # Transposed store: first index of X becomes the column.
            norm_dists[j, i] = np.linalg.norm(X[i][j].flatten())
    return norm_dists

In [None]:
perceptual_model = make_perceptual_loss_model((256,256,3))

In [None]:
perception = perceptual_model.predict(data_loader.get_test_data()[0], batch_size=5)
perceptual_dists = calculate_perceptual_distances(perception)

In [None]:
# t-SNE expects a 2-D (n_samples, n_features) matrix. `images.flatten()` produced
# a single 1-D vector, so keep the sample axis and flatten only the pixels.
raw_image_embedding = TSNE(callbacks=ErrorLogger(), n_jobs=-1, learning_rate=len(z[0])/12, exaggeration=4).fit(images.reshape(len(images), -1))
perceptual_embedding = TSNE(callbacks=ErrorLogger(), n_jobs=-1, learning_rate=len(z[0])/12, exaggeration=4).fit(perceptual_dists)
# Concatenate the latent blocks feature-wise (axis=1), as the earlier joint
# embeddings do. The default axis=0 stacked them as extra samples, giving four
# times as many points as there are labels.
vision_engine_embedding = TSNE(callbacks=ErrorLogger(), n_jobs=-1, learning_rate=len(z[3])/12, exaggeration=4).fit(np.concatenate([z[0], z[1], z[2], z[3]], axis=1))

In [None]:
# Compare three embeddings side by side: raw pixels, VGG16 features, model latents.
# Fixes relative to the original cell:
#   * both fig.savefig(...) calls were missing their closing parenthesis (SyntaxError);
#   * class indices were recomputed from an undefined `y_train`; `labels` is used throughout;
#   * panel 233 was created but nothing was drawn into it.
classnames, indices = np.unique(labels, return_inverse=True)
N = len(classnames)
cmap = plt.cm.rainbow
bounds = np.linspace(0, N, N + 1)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

# NOTE(review): other cells draw thumbnails from `images` (built from
# get_plot_data()); confirm that get_test_data()[0] holds one image per
# embedded point, in the same order.
thumbnails = data_loader.get_test_data()[0]

plt.figure(figsize=(40, 40))

# Top row (231-233): thumbnails at their embedded position.
# Bottom row (234-236): the same points coloured by class.
for column, embedding in enumerate(
        (raw_image_embedding, perceptual_embedding, vision_engine_embedding), start=1):
    plt.subplot(2, 3, column)
    imscatter(embedding[:, 0], embedding[:, 1], thumbnails, zoom=0.40)
    plt.subplot(2, 3, column + 3)
    plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.5,
                c=indices, cmap=cmap, norm=norm, s=400)

fig = plt.gcf()
run_id = checkpoint_path.split('/')[7]
for extension in ('pdf', 'png'):
    fig.savefig(os.path.join(
        plot_output_folder,
        'compare_methods_{}.{}'.format(run_id, extension)
    ))
plt.clf()