In [1]:
import tensorflow as tf
import numpy as np
import cv2
from pathlib import Path
from ipywidgets import widgets, interact_manual
import matplotlib.pyplot as plt

from training_schemes import load_compressor_with_range
from datasets.cifar10 import read_images, pipeline, normalize
from visualization.tensorboard import draw_text_line

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Applied to every visible GPU (as in the TensorFlow GPU guide) rather than only the
# first one; on a CPU-only machine the loop is simply empty instead of raising
# IndexError from `list_physical_devices('GPU')[0]`.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, enable=True)

In [2]:
# Root of all experiment artefacts and datasets used in this notebook.
DATA_ROOT = Path.home() / 'thesis-compression-data'

# Compressor checkpoint, together with its fitted alpha <-> bpp mapping and the
# parameters it was trained with (the 'lmbda' entry is used when compressing below).
compressor, alpha_bpp_fit, parameters = load_compressor_with_range(
    DATA_ROOT / 'experiments/compressor/19.03-best/act7_norm')

# Target bit rates to visualise (10 evenly spaced values, in bits per pixel).
bpp_range = np.linspace(0.8, 1.5, 10).astype(np.float32)

# Raw CIFAR-10 test images; the second return value is not needed here.
val_images, _ = read_images(DATA_ROOT / 'datasets/cifar-10/test')
batch_size = 6

# Alpha value the compressor needs for each target bpp.
alpha_range = alpha_bpp_fit.inverse_numpy(bpp_range)

# Endless, shuffled stream of un-augmented batches for interactive inspection.
val_dataset = pipeline(
    val_images,
    batch_size=batch_size,
    flip=False,
    crop=False,
    classifier_normalize=False,
    shuffle_buffer_size=10000,
    repeat=True,
)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Loading weights and alpha-bpp fit from epoch 150





In [3]:
# Classifier from the 'cifar_10_normal_training' experiment. It scores how well each
# reconstruction keeps the information needed to predict the true label.
original_model_path = (Path.home()
                       / 'thesis-compression-data/experiments/cifar_10_normal_training/'
                       / 'sophisticated-poetic-warthog-of-variation')
original_model = tf.keras.models.load_model(str(original_model_path / 'model.hdf5'))
# Load the weights stored in final_weights.hdf5 on top of the loaded model.
original_model.load_weights(str(original_model_path / 'final_weights.hdf5'))

def preprocess_for_model(X):
    """Prepare images for `original_model`.

    Reverses the order of the last axis and then applies the dataset's `normalize`.
    For channels-last images this swaps RGB <-> BGR; presumably the classifier was
    trained on that channel order and normalization — confirm against its training code.

    Args:
        X: image tensor/array whose last axis is the channel axis.

    Returns:
        The channel-reversed, normalized images.
    """
    return normalize(X[..., ::-1])

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


In [4]:
# Build the TF1-style graph once: fetch one validation batch, compress every image at
# every alpha in `alpha_range`, and classify the reconstructions with `original_model`.
sess = tf.keras.backend.get_session()
item = tf.compat.v1.data.make_one_shot_iterator(val_dataset).get_next()

n_alphas = len(alpha_range)
# Number of images actually in this batch. It is taken from the tensor rather than the
# `batch_size` constant, so 'alpha' and 'lambda' always have the same length as 'X'
# (a short batch would otherwise give mismatched inputs).
n_images = tf.shape(item[0])[0]

to_compress = {
    # Each image repeated n_alphas times in a row: [x0, x0, ..., x1, x1, ...].
    'X': tf.repeat(item[0], n_alphas, axis=0),
    # The alpha schedule cycled once per image: [a0, a1, ..., a0, a1, ...].
    'alpha': tf.tile(alpha_range, [n_images]),
    # The same trained lambda for every compressed image.
    'lambda': tf.repeat(parameters['lmbda'], n_alphas * n_images),
}

reconstruction = compressor.forward_with_range_coding(to_compress)
original_model_pred = original_model(preprocess_for_model(reconstruction['X_reconstructed']))
# Classifier score for the true class of each reconstruction. This assumes item[1]
# holds one-hot labels — TODO confirm against datasets.cifar10.pipeline.
true_label_pred = tf.reduce_sum(original_model_pred * tf.repeat(item[1], n_alphas, axis=0), axis=1)

In [6]:
@interact_manual(scale=widgets.IntSlider(min=4, max=10))
def show_images(scale):
    """Fetch one batch, compress it at every alpha and plot originals beside reconstructions.

    Runs the previously built graph once. Each image in the batch gets one row: the
    original, followed by its reconstruction at every entry of `alpha_range`. Under each
    row a text line lists, per reconstruction, the range-coded bpp and the value of
    `true_label_pred` (the classifier's score for the true label).

    Args:
        scale: integer nearest-neighbour upscaling factor for every image row; the
            text cells are sized to match.
    """
    reconstruction_result, item_result, true_label_result = sess.run(
        [reconstruction, item, true_label_pred])
    originals = item_result[0]
    reconstructed = reconstruction_result['X_reconstructed']
    bpps = reconstruction_result['range_coded_bpp']

    # Take the image count from the fetched batch rather than the global `batch_size`,
    # so the reshapes below stay valid if the pipeline ever yields a short batch.
    n_images = originals.shape[0]
    n_alphas = len(alpha_range)
    img_height, img_width, n_channels = reconstructed.shape[1:4]

    # Reconstructions are image-major ([x0@a0, x0@a1, ..., x1@a0, ...], from the
    # repeat/tile in the graph). Group them per image and lay each group side by side:
    # (n_images, H, n_alphas * W, C).
    strips = reconstructed.reshape(n_images, n_alphas, img_height, img_width, n_channels)
    strips = strips.transpose(0, 2, 1, 3, 4).reshape(
        n_images, img_height, n_alphas * img_width, n_channels)
    rows = np.concatenate([originals, strips], axis=2)

    # One "bpp\nprediction" label per reconstruction, in the same order as `strips`.
    labels = np.reshape([f'{bpp:0.2f}\n{pred:0.2f}'
                         for bpp, pred in zip(bpps, true_label_result)],
                        (n_images, n_alphas))

    panels = []
    for row, row_labels in zip(rows, labels):
        panels.append(cv2.resize(row, dsize=None, fx=scale, fy=scale,
                                 interpolation=cv2.INTER_NEAREST))
        # cell_dimension is presumably (height, width). Panels are stacked vertically,
        # so each of the 1 + n_alphas cells must be exactly one upscaled image wide;
        # this uses the real image width instead of a hard-coded 32.
        panels.append(draw_text_line(['Bpp\nPred'] + list(row_labels),
                                     background_color=255, font_color=0, font_size=48,
                                     cell_dimension=(100, scale * img_width)))
    canvas = np.concatenate(panels, axis=0)

    # Size the figure so it has as many pixels as the canvas at this dpi.
    dpi = 20
    fig, ax = plt.subplots(figsize=(canvas.shape[1] / dpi, canvas.shape[0] / dpi), dpi=dpi)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(canvas)
    # Render (and, on the inline backend, release) the figure inside the widget output.
    plt.show()

interactive(children=(IntSlider(value=4, description='scale', max=10, min=4), Button(description='Run Interact…