In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px


import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

from sklearn.metrics import confusion_matrix , classification_report 
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import roc_curve, auc, roc_auc_score

from IPython.display import clear_output
import warnings
warnings.filterwarnings('ignore')

from keras.models import model_from_json
import cv2, os
from keras.layers import Flatten
from keras.utils.vis_utils import plot_model
from tensorflow.keras.callbacks import ModelCheckpoint

In [None]:
# Dataset directories, relative to the notebook's working directory.
train_dir = "dataset/ImageEmotion (85-15)/train"
test_dir = "dataset/ImageEmotion (85-15)/test"

# Reproducibility and input geometry.
SEED = 12
IMG_HEIGHT, IMG_WIDTH = 224, 224
BATCH_SIZE = 64

# Optimisation settings.
LR = 0.001
momentum = 0.9
EPOCHS = 500
EARLY_STOPPING_CRITERIA = 3

# Emotion categories and one emoji per category, listed in the same order.
CLASS_LABELS = [
    'Amusement', 'Anger', 'Awe', 'Contentment',
    'Disgust', 'Excitement', 'Fear', 'Sadness',
]
CLASS_LABELS_EMOJIS = ["🥳", "😡", "😯", "😌", "🤢", "🤩", "😱", "😔"]

# Derived from the label list so the two cannot drift apart.
NUM_CLASSES = len(CLASS_LABELS)

In [None]:
# Both generators only rescale pixel values by 1/255.  No augmentation and no
# validation split are configured, so each generator serves every image found
# under its directory.
train_datagen = ImageDataGenerator(rescale=1.0 / 255)
test_datagen = ImageDataGenerator(rescale=1.0 / 255)

# Arguments shared by the train and test directory iterators.  The seed comes
# from the SEED constant in the configuration cell instead of a repeated literal.
flow_kwargs = dict(
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    color_mode="rgb",
    class_mode="categorical",
    seed=SEED,
)

# Training images are reshuffled between epochs.
train_generator = train_datagen.flow_from_directory(
    directory=train_dir,
    shuffle=True,
    **flow_kwargs,
)

# Test images are served in a fixed order (no shuffling).
test_generator = test_datagen.flow_from_directory(
    directory=test_dir,
    shuffle=False,
    **flow_kwargs,
)

In [None]:
""" Returns mean value of RGB """
def mean(inputs):
    """Return the means of `inputs` in RGB space and in HSV space.

    `inputs` is averaged with `mean_helper` as-is, then converted with
    `tf.image.rgb_to_hsv` and averaged again.
    Assumes `inputs` is a batch of RGB images shaped
    (batch, height, width, 3) -- TODO confirm against the caller.

    Returns:
        A tuple `(mean_rgb, mean_hsv)` holding the two `mean_helper` results.
    """
    rgb_mean = mean_helper(inputs)
    hsv_mean = mean_helper(tf.image.rgb_to_hsv(inputs))
    return rgb_mean, hsv_mean

""" Calculates mean value of a plane given a 3D matrix """
def mean_helper(org_mat):
    """Average `org_mat` over axis 1, then over axis 1 of that result.

    Each reduction removes one axis, so the output has two fewer dimensions
    than the input (the original axes 1 and 2 are averaged away).

    Returns:
        The tensor produced by the second `tf.math.reduce_mean` call.
    """
    averaged_once = tf.math.reduce_mean(org_mat, axis=1)
    return tf.math.reduce_mean(averaged_once, axis=1)

""" Calculates pleasure, arousal, dominance values of the image"""
def calculate_pad(hsv):
    """Compute pleasure, arousal and dominance (PAD) values for every image.

    The three values are fixed linear combinations of the saturation
    (`hsv[:, 1]`) and brightness (`hsv[:, 2]`) entries of each row.

    The computation is vectorised over the batch axis.  Unlike a Python loop
    over `hsv.shape[0]`, it works when the batch size is unknown (`None`) and
    returns an empty result for an empty batch.

    Args:
        hsv: tensor with one row per image and one scalar per HSV channel,
            i.e. shaped (batch, 3).

    Returns:
        Tensor shaped (batch, 1, 3); the last axis is ordered
        (pleasure, arousal, dominance).
    """
    # Keep one column per quantity so they can be joined along axis 1.
    saturation = tf.reshape(hsv[:, 1], (-1, 1))
    brightness = tf.reshape(hsv[:, 2], (-1, 1))  # 'value' in HSV

    pleasure = 0.69 * brightness + 0.22 * saturation
    arousal = 0.31 * brightness + 0.6 * saturation
    dominance = 0.76 * brightness + 0.32 * saturation

    pad = tf.concat([pleasure, arousal, dominance], axis=1)
    return tf.reshape(pad, (-1, 1, 3))

def lowfeature_extractor(inputs):
    """Build the low-level feature vector for `inputs`.

    Concatenates, along axis 1 and in this order: the flattened RGB means,
    the flattened PAD values from `calculate_pad`, and the flattened HSV
    means.  The RGB and HSV means come from `mean`.

    Returns:
        The concatenated feature tensor.
    """
    mean_rgb, mean_hsv = mean(inputs)
    flat_rgb = Flatten()(mean_rgb)
    flat_pad = Flatten()(calculate_pad(mean_hsv))
    flat_hsv = Flatten()(mean_hsv)
    return tf.concat([flat_rgb, flat_pad, flat_hsv], axis=1)

In [None]:
# ImageNet-pretrained MobileNet without its classification head.  The input
# size comes from the configuration cell so it stays in sync with the data
# generators.
# NOTE(review): the generators rescale pixels by 1/255, whereas
# tf.keras.applications.mobilenet.preprocess_input scales to [-1, 1] --
# confirm the [0, 1] range is intended for these pretrained weights.
base_model = tf.keras.applications.mobilenet.MobileNet(
    input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    include_top=False,
    weights='imagenet',
)

def classifier(inputs):
    """Attach the classification head to the shared `base_model`.

    `inputs` is passed through `base_model`, globally average-pooled, and fed
    to a softmax `Dense` layer named 'classification' with `NUM_CLASSES`
    units (one per emotion class in the configuration cell).

    Returns:
        The output tensor of the softmax layer.
    """
    features = base_model(inputs)
    pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
    return tf.keras.layers.Dense(
        NUM_CLASSES, activation='softmax', name='classification'
    )(pooled)
    
        
def final_model(inputs):
    """Return the model output for `inputs`; currently just the `classifier` output."""
    return classifier(inputs)

def define_compile_model():
    """Build the full model and compile it for training.

    The input shape and learning rate are read from the configuration cell
    (`IMG_HEIGHT`, `IMG_WIDTH`, `LR`) rather than repeated as literals.

    Returns:
        A `tf.keras.Model` compiled with SGD, categorical cross-entropy loss
        and the accuracy metric.
    """
    inputs = tf.keras.layers.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3))
    classification_output = final_model(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=classification_output)

    # NOTE(review): the configuration cell defines `momentum`, but it is not
    # passed to SGD here, so SGD runs with its default momentum -- confirm
    # whether that is intended.
    model.compile(
        optimizer=tf.keras.optimizers.SGD(LR),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

In [None]:
# Build and compile a fresh model, then print its layer summary.
model = define_compile_model()
model.summary()

In [None]:
# Checkpoint callback.  It was referenced by the fit call but never created,
# so this cell raised NameError on a fresh kernel.  The model is saved whenever
# validation accuracy improves; the file name records the epoch and both
# accuracies.
checkpoint = ModelCheckpoint(
    filepath="MobileNet-{epoch:03d}-{accuracy:.6f}-{val_accuracy:.6f}.h5",
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
    verbose=1,
)

# The generators already yield batches of BATCH_SIZE, so no batch size is
# passed to fit().
# NOTE(review): the epoch count is hard-coded to 100 although the
# configuration cell defines EPOCHS = 500 -- confirm which is intended.
training_run = model.fit(
    train_generator,
    epochs=100,
    validation_data=test_generator,
    callbacks=[checkpoint],
)

# Per-epoch metrics as a DataFrame: one row per epoch, one column per metric.
history = pd.DataFrame(training_run.history)