In [None]:
from VisionEngine.utils.config import process_config
from VisionEngine.utils import factory
from VisionEngine.utils.eval import (embed_images, 
                                     reconstruct_images,
                                     reconstruct_images, 
                                     sample_likelihood)

from VisionEngine.utils.plotting import imscatter

from VisionEngine.utils.perceptual_loss import (make_perceptual_loss_model,
                                                calculate_perceptual_distances)

from VisionEngine.utils.disentanglement_score import dissentanglement_score

from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder

import os
from dotenv import load_dotenv
from pathlib import Path

import numpy as np
import tensorflow as tf

from openTSNE import TSNE
from openTSNE.callbacks import ErrorLogger

import matplotlib.pyplot as plt
import matplotlib as mpl

In [None]:
# Index of the GPU used by every tf.device(f'/device:GPU:{GPU}') block below.
# On a single-GPU machine the only valid value is 0.
GPU: int = 0

In [None]:
# Pull environment variables (VISIONENGINE_HOME is read below) from the .env file
# one directory above this notebook.
env_path = Path('..').joinpath('.env')
load_dotenv(dotenv_path=env_path)

In [None]:
# DHRL-trained guppies (original data).
# VISIONENGINE_HOME is expected from the .env file or the shell environment.
# Fail fast with an actionable message: os.path.join(None, ...) would otherwise raise
# an opaque TypeError when the variable is unset.
visionengine_home = os.getenv("VISIONENGINE_HOME")
if visionengine_home is None:
    raise RuntimeError(
        "VISIONENGINE_HOME is not set; define it in ../.env or in the environment"
    )

checkpoint_path = os.path.join(
    visionengine_home,
    "checkpoints/guppies_DHRL_model.hdf5"
)
config_file = os.path.join(
    visionengine_home,
    "VisionEngine/configs/guppies_DHRL_config.json"
)

In [None]:
config = process_config(config_file)

# Evaluation overrides for the data loader: no shuffling (presumably so later cells'
# per-sample arrays stay aligned), real data only, no generated data.
for option, value in (("shuffle", False), ("use_generated", False), ("use_real", True)):
    setattr(config.data_loader, option, value)

In [None]:
# Resolve the model class named in the config (under VisionEngine.models) and
# instantiate it with the config, on the selected GPU.
with tf.device(f'/device:GPU:{GPU}'):
    model_class = factory.create(f"VisionEngine.models.{config.model.name}")
    model = model_class(config)

In [None]:
# Load the checkpoint from checkpoint_path, then mark the model non-trainable for evaluation.
# NOTE(review): `model` is the VisionEngine wrapper from factory.create; confirm that
# setting `trainable` on it actually reaches the underlying Keras model.
model.load(checkpoint_path)
model.trainable = False

In [None]:
# Resolve the data-loader class named in the config (under VisionEngine.data_loaders)
# and instantiate it with the same config.
loader_class = factory.create(f"VisionEngine.data_loaders.{config.data_loader.name}")
data_loader = loader_class(config)

In [None]:
# Embed the test set, score it with the model, and collect the plotting images.
with tf.device(f'/device:GPU:{GPU}'):
    # Latent codes for the test set; later cells index z[0]..z[3].
    z = embed_images(data_loader.get_test_data(), model)
    # Standardize the likelihood scores (zero mean, unit std) for colouring scatter plots.
    lh = sample_likelihood(data_loader.get_test_data(), model)
    lh = (lh - tf.math.reduce_mean(lh)) / tf.math.reduce_std(lh)
    # Fresh iterator over the test data, consumed by the reconstruction cell.
    images_ = iter(data_loader.get_test_data())
    images = np.stack([image[0].numpy() for image in data_loader.get_plot_data()])
    # Flatten each image into one row vector (raw-pixel t-SNE input). Inferring the
    # feature size with -1 avoids hard-coding 256*256*4 and supports other image shapes.
    images = images.reshape(len(images), -1)

**Visualize Reconstructions**

In [None]:
# Show (original, reconstruction) pairs for one example from each of the next few
# test batches. Local names are used so the flattened `images` array from the
# embedding cell is NOT overwritten; later cells pass it to imscatter and t-SNE.
# NOTE(review): `plot_im` is neither imported nor defined in this notebook; it is
# presumably a display helper (e.g. from VisionEngine.utils.plotting). Confirm before running.
SAMPLE_ID = 2   # index of the example shown from each batch
N_BATCHES = 3   # number of (original, reconstruction) rows in the grid
with tf.device(f'/device:GPU:{GPU}'):
    for row in range(N_BATCHES):
        batch = next(images_)[0]
        # NOTE(review): passes `model` as embed_images/sample_likelihood do; confirm
        # reconstruct_images' signature in VisionEngine.utils.eval.
        x_hat = reconstruct_images(batch, model)
        plt.subplot(N_BATCHES, 2, 2 * row + 1)
        plt.imshow(plot_im(batch[SAMPLE_ID]))
        plt.subplot(N_BATCHES, 2, 2 * row + 2)
        plt.imshow(plot_im(x_hat[SAMPLE_ID]))

**Visualize Sample Likelihood**

In [None]:
# Left: thumbnails placed at their DHRL t-SNE coordinates.
# Right: the same coordinates coloured by the standardized likelihood `lh`.
plt.figure(figsize=(20,10))

cmap = plt.cm.viridis

# `vision_engine_embedding` is normally produced by the t-SNE cell further down.
# Compute the same embedding here if it does not exist yet, so this cell also works
# under Restart & Run All.
try:
    vision_engine_embedding
except NameError:
    vision_engine_embedding = TSNE(callbacks=ErrorLogger(), n_jobs=8).fit(
        np.concatenate([z[0], z[1], z[2], z[3]], axis=1)
    )

embedding = vision_engine_embedding
plt.subplot(121)
imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.15);

plt.subplot(122)
plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.5,
            c=lh, cmap=cmap, s=400, rasterized=True)

plt.colorbar();

In [None]:
# Perceptual-feature baseline: run the test data through the perceptual-loss model on
# the CPU, collect the predictions, and convert them into perceptual distances.
# NOTE(review): the model is built for (256, 256, 3) inputs, while the embedding cell
# flattens plot images as 256*256*4 values. Confirm the test images' channel count matches.
# NOTE(review): confirm get_test_data() yields unbatched elements. The reconstruction
# cell indexes a sample inside next(...)[0], which suggests the data may already be
# batched, in which case .batch(16) adds a second batch dimension. Elements may also be
# (image, label) pairs rather than bare images.
perception = []
with tf.device('/device:cpu:0'):
    perceptual_model = make_perceptual_loss_model((256,256,3))
    for batch in data_loader.get_test_data().batch(16):
        # extend() appends each element of predict's output (its rows, for a single-output model)
        perception.extend(perceptual_model.predict(batch))
    perceptual_distances = calculate_perceptual_distances(np.array(perception))

**Compare t-SNE Embeddings: Raw Pixels, Perceptual Distance, and Our Approach (DHRL)**

In [None]:
# t-SNE embeddings compared in the figures. All of them share the same openTSNE
# settings through one helper; a fixed random_state makes the stochastic embeddings
# reproducible across runs.
TSNE_SEED = 42


def _tsne(data, **kwargs):
    """Fit an openTSNE embedding of `data` (one row per sample) with the shared settings.

    Extra keyword arguments are forwarded to ``openTSNE.TSNE``.
    """
    return TSNE(callbacks=ErrorLogger(), n_jobs=8, random_state=TSNE_SEED, **kwargs).fit(data)


perceptual_embedding = _tsne(perceptual_distances.T)
# Raw-pixel baseline on the flattened plot images. The learning rate follows the N/12
# heuristic, with N the number of samples embedded. len(z) would count the latent
# levels (z[0]..z[3]), not the samples.
raw_image_embedding = _tsne(images, exaggeration=4, learning_rate=len(images) / 12)
# DHRL: the four latent levels concatenated feature-wise, then each level on its own.
vision_engine_embedding = _tsne(np.concatenate([z[0], z[1], z[2], z[3]], axis=1))
h1, h2, h3, h4 = (_tsne(z[level]) for level in range(4))

In [None]:
# Colour each per-level t-SNE embedding (h1..h4) by label.
# `labels` is built in the disentanglement section further down. Build it here with the
# same expression if it does not exist yet, so this cell also works under Restart & Run All.
try:
    labels
except NameError:
    labels = np.hstack([image[1] for image in data_loader.get_test_data()])

plt.figure(figsize=(40,10))

# `indices` maps each sample to an integer label id. `cmap`/`norm` give one discrete
# colour per label and are reused by the comparison figure below.
classnames, indices = np.unique(labels, return_inverse=True)
N = len(classnames)
cmap = plt.cm.rainbow
bounds = np.linspace(0,N,N+1)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

for panel, level_embedding in enumerate((h1, h2, h3, h4), start=1):
    plt.subplot(1, 4, panel)
    plt.title(f'h{panel}')
    plt.scatter(level_embedding[:, 0], level_embedding[:, 1], alpha=0.2,
                c=indices, cmap=cmap, norm=norm, s=400)

In [None]:
# Three-column comparison: raw pixels vs. perceptual-loss distances vs. DHRL latents.
# Top row: thumbnails at their t-SNE coordinates. Bottom row: the same points coloured
# with `indices`/`cmap`/`norm` from the per-level figure cell.
plt.figure(figsize=(30,20))
panels = (
    ('Raw Pixel Distribution', raw_image_embedding),
    ('Perceptual Loss Metric', perceptual_embedding),
    ('DHRL (Our method)', vision_engine_embedding),
)
for column, (title, embedding) in enumerate(panels, start=1):
    plt.subplot(2, 3, column)
    plt.title(title)
    imscatter(embedding[:, 0], embedding[:, 1], images, zoom=0.15)
    plt.subplot(2, 3, column + 3)
    plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.2,
                c=indices, cmap=cmap, norm=norm, s=400, rasterized=True)
plt.tight_layout()
fig = plt.gcf()

**Measure Disentanglement and Completeness Scores**

In [None]:
# Integer- and one-hot-encode the test-set labels. `inputs` is the one-hot matrix passed
# to the disentanglement score below. This is plain NumPy/scikit-learn work, so no
# tf.device context is needed.
labels = np.hstack([image[1] for image in data_loader.get_test_data()])
label_encoder = LabelEncoder()
integer_encoded = label_encoder.fit_transform(labels).reshape(-1, 1)
# scikit-learn 1.2 renamed `sparse` to `sparse_output` and 1.4 removed `sparse`.
# Fall back for older versions so a dense array is requested either way.
try:
    onehot_encoder = OneHotEncoder(sparse_output=False)
except TypeError:
    onehot_encoder = OneHotEncoder(sparse=False)
onehot_encoded = onehot_encoder.fit_transform(integer_encoded)
inputs = onehot_encoded

In [None]:
# Score the latent codes `z` against the one-hot labels `inputs`.
# NOTE(review): `h` is passed as the third positional argument to dissentanglement_score,
# and its meaning is not visible here (possibly a level/hierarchy setting). Confirm against
# VisionEngine.utils.disentanglement_score before changing it.
h = 3
# Result names suggest (weighted-average disentanglement, average completeness). Verify.
disent_w_avg, complete_avg = dissentanglement_score(z, inputs, h)

In [None]:
# Label the two numbers so the output is readable on its own.
print(f"disentanglement: {disent_w_avg}  completeness: {complete_avg}")