#### Install Libraries

In [1]:
! pip install livelossplot

Collecting livelossplot
  Downloading livelossplot-0.5.4-py3-none-any.whl (22 kB)
Installing collected packages: livelossplot
Successfully installed livelossplot-0.5.4


### Import libraries

In [2]:
import os
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import tensorflow.keras as keras

import sys
sys.path.append('/home/jovyan/work')

import nima as nima


In [23]:
! pwd

/home/jovyan


In [3]:
AVA_DATASET_DIR = '/home/jovyan/work/data/AVA/'
PROJECT_ROOT_DIR = '/home/jovyan/work/'
AVA_IMAGES_DIR = os.path.join(AVA_DATASET_DIR, 'images')
WEIGHTS_DIR = os.path.join(PROJECT_ROOT_DIR, )
assert os.path.isdir(AVA_DATASET_DIR), f'Invalid directory : {AVA_DATASET_DIR}'

from nima.utils.ava_dataset_utils import load_data, get_rating_columns

### Load Dataset

In [20]:
print(f'Project Root Directory : {PROJECT_ROOT_DIR}')

df_train, df_valid, df_test = load_data(AVA_DATASET_DIR, sample_size=100)
print(f"Training length : {len(df_train)}, Validation length : {len(df_valid)}, Test size : {len(df_test)}")

Project Root Directory : /home/jovyan/work/
Number of samples picked 100
Training length : 65, Validation length : 29, Test size : 5


In [21]:
df_valid

Unnamed: 0,image_id,count_rating_1,count_rating_2,count_rating_3,count_rating_4,count_rating_5,count_rating_6,count_rating_7,count_rating_8,count_rating_9,count_rating_10
0,953757,0,2,3,15,24,44,26,11,5,0
1,954071,0,0,6,12,46,45,13,2,1,0
2,954013,0,0,1,5,18,39,31,22,9,7
3,953946,0,0,4,4,34,56,14,8,5,0
4,953897,0,0,0,5,19,46,29,22,5,2
5,954184,0,0,4,8,41,56,10,3,4,0
6,953004,0,4,9,34,54,18,2,2,1,0
7,954218,0,0,7,12,40,46,10,3,4,2
8,953863,0,0,2,11,29,44,24,13,5,2
9,953756,0,2,3,9,35,50,20,5,2,2


### Create Model

In [5]:
from nima.model.model_builder import NIMA

nima_cnn = NIMA(base_model_name='vgg', metrics=['accuracy'])
nima_cnn.build()
# nima_cnn.model.summary()
nima_cnn.compile()

Model's module - tensorflow.keras.applications.vgg19.VGG19
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5


### Create Datagenerators

In [47]:
from nima.model.data_generator import NimaDataGenerator

x_col, y_cols = 'image_id', get_rating_columns()

train_datagen = NimaDataGenerator(
    df_train,
    AVA_IMAGES_DIR,
    x_col,
    y_cols,
    nima_cnn.preprocessing_function(),
    is_train=True,
    batch_size=32,
)
valid_datagen = NimaDataGenerator(
    df_valid,
    AVA_IMAGES_DIR,
    x_col,
    y_cols,
    nima_cnn.preprocessing_function(),
    is_train=True,
    batch_size=32,
)
test_datagen = NimaDataGenerator(
    df_test,
    AVA_IMAGES_DIR,
    x_col,
    None,
    nima_cnn.preprocessing_function(),
    is_train=False,
    batch_size=32,
)

Found 65 valid image filenames belonging to 10 classes.
Found 29 valid image filenames belonging to 10 classes.
Found 5 valid image filenames belonging to 10 classes.


In [48]:
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from livelossplot.inputs.keras import PlotLossesCallback
import time
    
arg_verbose, arg_epochs = 1, 32

# set model weight and path
weight_filename = f'{nima_cnn.base_model_name}_weight_best.hdf5'
weight_filepath = os.path.join(WEIGHTS_DIR, weight_filename)
print(f'Model Weight path : {weight_filepath}')

es = EarlyStopping(monitor='val_loss', patience=4, verbose=arg_verbose)
ckpt = ModelCheckpoint(
    filepath=weight_filepath,
    save_weights_only=True,
    monitor="val_earth_movers_distance",
    mode="auto",
    save_best_only=True,
)
lr = ReduceLROnPlateau(monitor='val_loss', patience=2, verbose=1)
plot_loss = PlotLossesCallback()

# start training
start_time = time.perf_counter()
history = nima_cnn.model.fit(train_datagen, validation_data=valid_datagen,
    epochs=arg_epochs, callbacks=[es, ckpt, lr, plot_loss],
    verbose=arg_verbose)
end_time = time.perf_counter()
print(f'Time taken : {time.strftime("%H:%M:%S", time.gmtime(end_time-start_time))}')

result_df = pd.DataFrame(history.history)

Model Weight path : /home/jovyan/work/VGG19_weight_best.hdf5
Getting items ...
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31]
0 (224, 224, 3) (32, 224, 224, 3)
1 (224, 224, 3) (32, 224, 224, 3)
2 (224, 224, 3) (32, 224, 224, 3)
3 (224, 224, 3) (32, 224, 224, 3)
4 (224, 224, 3) (32, 224, 224, 3)
5 (224, 224, 3) (32, 224, 224, 3)
6 (224, 224, 3) (32, 224, 224, 3)
7 (224, 224, 3) (32, 224, 224, 3)
8 (224, 224, 3) (32, 224, 224, 3)
9 (224, 224, 3) (32, 224, 224, 3)
10 (224, 224, 3) (32, 224, 224, 3)
11 (224, 224, 3) (32, 224, 224, 3)
12 (224, 224, 3) (32, 224, 224, 3)
13 (224, 224, 3) (32, 224, 224, 3)
14 (224, 224, 3) (32, 224, 224, 3)
15 (224, 224, 3) (32, 224, 224, 3)
16 (224, 224, 3) (32, 224, 224, 3)
17 (224, 224, 3) (32, 224, 224, 3)
18 (224, 224, 3) (32, 224, 224, 3)
19 (224, 224, 3) (32, 224, 224, 3)
20 (224, 224, 3) (32, 224, 224, 3)
21 (224, 224, 3) (32, 224, 224, 3)
22 (224, 224, 3) (32, 224, 224, 3)
23 (224, 224, 3) (32, 224, 

KeyboardInterrupt: 

In [50]:
# result_df.head()
np.empty((2, *(224, 224), 3))

array([[[[ 4.66312181e-310,  4.66312857e-310,  4.66312181e-310],
         [ 4.66312857e-310, -1.61459546e+308, -1.67099367e+308],
         [-1.61481319e+308, -1.55863356e+308, -1.55819467e+308],
         ...,
         [-1.27554356e+308, -1.27554356e+308, -1.21958767e+308],
         [-1.27620361e+308, -1.21980454e+308, -1.10722670e+308],
         [-1.21936394e+308, -1.16296572e+308, -1.16296572e+308]],

        [[-1.16296572e+308, -1.10634892e+308, -1.16274714e+308],
         [-1.10634892e+308, -1.04995071e+308, -9.93552498e+307],
         [-9.37154286e+307, -9.37154286e+307, -8.89801320e+307],
         ...,
         [-1.04906950e+308, -1.10546771e+308, -1.10546771e+308],
         [-9.92894163e+307, -9.93329624e+307, -9.93329624e+307],
         [-9.93111894e+307, -9.93330482e+307, -9.93547355e+307]],

        [[-9.93329624e+307, -1.04951011e+308, -1.04972869e+308],
         [-1.04972869e+308, -1.10612691e+308, -1.10612691e+308],
         [-1.10546771e+308, -1.16186593e+308, -1.10546686e