In [None]:
import os
import time

import train
import models

import numpy as np

from PIL import Image
from IPython.html.widgets import interact
from sklearn.metrics import roc_auc_score
from keras.callbacks import ModelCheckpoint, EarlyStopping

import matplotlib.pyplot as plt
from eval.eval import get_predictions, get_auc_score
from utils.visualization import prepare_img_to_plot, parula_map
from utils.visualization import get_gaussian_quality_map as get_quality_map
%matplotlib inline

%load_ext autoreload
%autoreload 2

In [None]:
"""
Data parameters' definition
"""
batch_size = 8
load_data = False
weak_dir = 'D:/work/datasets/quality/quality'
epochs = 1

"""
Dataset augmentation's parameters
"""
aug_params = {'horizontal_flip': True,
              'vertical_flip': True,
              'width_shift_range': 0.05, 
              'height_shift_range': 0.05, 
              'rotation_range': 360, 
              'zoom_range': 0.02}

"""
Model parameter's definition
"""
nf = 64
n_blocks = 4
input_size = 512
pooling_wreg = 1e-2
pooling_breg = 1e-3
lr = 2e-4

"""
Callbacks' definition
"""
experiment_path = "experiments/GAP.hdf5"
patience = 0
checkpointer = ModelCheckpoint(filepath=experiment_path, verbose=1, 
                               save_best_only=True, save_weights_only=False)
callbacks = [checkpointer]
if patience > 0:
    early = EarlyStopping(patience=patience, verbose=1)
    callbacks.append(early)

In [None]:
# Build the train / validation / test iterators from `weak_dir`.
# Images are resized to the model's input resolution (`input_size`) so the data
# pipeline and the network stay in sync when the config changes.
train_it, val_it, test_it = train.get_data_iterators(
    batch_size=batch_size,
    data_dir=weak_dir,
    target_size=(input_size, input_size),
    samplewise_center=False,
    samplewise_std_normalization=False,
    rescale=1/255.,          # scale 8-bit pixel values into [0, 1]
    fill_mode='constant',
    load_train_data=load_data,
    color_mode='rgb',
    **aug_params)

In [None]:
# Build the quality-assessment networks from the config-cell hyperparameters.
# `eyequal` is the model that is trained, checkpointed and scored below;
# `heatmap` produces the raw maps turned into per-image quality maps.
eyequal, heatmap = models.quality_assessment(
    nf,
    input_size=input_size,
    n_blocks=n_blocks,
    lr=lr,
    pooling_wreg=pooling_wreg,
    pooling_breg=pooling_breg,
)

In [None]:
# Train `eyequal` from the training iterator, validating on `val_it` each epoch;
# `callbacks` (from the config cell) checkpoints the best model to
# `experiment_path` and optionally stops early.
# NOTE(review): the positional arguments and `nb_val_samples` follow the Keras 1
# signature fit_generator(generator, samples_per_epoch, nb_epoch, ...), i.e. one
# epoch = `train_it.n` samples. Under Keras 2 the second positional argument is
# `steps_per_epoch` (batches, not samples) — confirm the installed version.
eyequal.fit_generator(train_it, train_it.n, epochs, validation_data=val_it, nb_val_samples=val_it.n, 
                      verbose=2, callbacks=callbacks)

In [None]:
# Restore the checkpoint written during training (ModelCheckpoint with
# save_best_only=True) instead of keeping the last-epoch weights in memory.
eyequal.load_weights(filepath=experiment_path)

In [None]:
# Inspect one validation batch interactively: for the selected image, show the
# raw input next to the predicted quality map overlaid on top of it.
x, y = next(val_it)

# Run both models once for the whole batch; the widget callback below only
# indexes into these cached predictions instead of re-running `eyequal` on
# every slider change.
heat_pred = heatmap.predict(x)
quality_pred = eyequal.predict(x)


def plot_figs(idx=0):
    """Show image `idx` of the batch `x` together with its quality map.

    Prints the model prediction and the ground-truth label for that image, then
    draws two panels: the input image, and the same image overlaid with the
    quality map derived from `heat_pred`.

    Args:
        idx: position of the image inside the current batch (0 <= idx < len(x)).
    """
    print('Pred = {0}; GT = {1}'.format(quality_pred[idx], y[idx]))

    # NOTE(review): channels-first layout assumed — x[idx] is (C, H, W) and the
    # heatmap output has a singleton axis at position 1; confirm in models.
    heat_pred_img = heat_pred[idx, 0]
    x_plot = np.transpose(x[idx], (1, 2, 0))  # CHW -> HWC for imshow

    plt.figure(figsize=(16, 21))

    plt.subplot(1, 2, 1)
    plt.title('Image')
    plt.imshow(x_plot)
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.title('Quality Map')
    plt.imshow(x_plot)
    im = get_quality_map(heat_pred_img, n_blocks=n_blocks)
    # Semi-transparent overlay on a fixed [0, 1] color scale so maps of
    # different images are directly comparable.
    plt.imshow(im, alpha=0.5, vmin=0, vmax=1, cmap=parula_map)
    plt.axis('off')

    plt.show()


# Offer only indices that exist in this batch: the iterator's final batch can
# be shorter than `batch_size`, and indexing past it would raise IndexError.
interact(plot_figs, idx=list(range(len(x))));

In [None]:
# Visualize the learned weights of the 'pool' layer as a spatial map.
# The layer is expected to hold exactly two arrays (weights, bias).
# NOTE(review): the reshape assumes one weight per output location, i.e.
# out_size**2 weights; it raises if the layer's weight count differs.
w, b = eyequal.get_layer('pool').get_weights()
out_size = models.get_out_size(input_size, n_blocks)
w = w.reshape((out_size, out_size))

fig, ax = plt.subplots()
pool_img = ax.imshow(w)
ax.set_title("'pool' layer weights")
fig.colorbar(pool_img, ax=ax)  # show the weight scale, not just relative shading
plt.show()

In [None]:
# Report the AUC of `eyequal` on each split, in train -> validation -> test
# order, passing each iterator's `.n` (presumably its sample count) to the
# scorer.
auc_reports = [('Train AUC = {0}', train_it),
               ('Validation AUC = {0}', val_it),
               ('Test AUC = {0}', test_it)]
for report_fmt, split_it in auc_reports:
    print(report_fmt.format(get_auc_score(eyequal, split_it, split_it.n)))