### Import libraries

In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import random
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from pylab import imread,subplot,imshow,show
from sklearn.metrics import confusion_matrix, classification_report
import tensorflow
from tensorflow import keras
from tensorflow.keras import optimizers, layers, models, callbacks
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Conv2D, Flatten, MaxPool2D, Dropout
from tensorflow.keras.applications import xception
from tensorflow.keras.preprocessing.image import ImageDataGenerator  

### Train, validation and test directories

In [None]:
# Root of the dataset; each split lives in its own subdirectory.
base_dir = 'data/grocery_store_dataset'

train_dir, validation_dir, test_dir = (
    os.path.join(base_dir, split_name) for split_name in ('train', 'val', 'test')
)

# Sanity check: report how many files sit (recursively) under each split.
for heading, split_root in (
    ('Total number of training images:', train_dir),
    ('Total number of validation images:', validation_dir),
    ('Total number of test images:', test_dir),
):
    print(heading, sum(len(names) for _, _, names in os.walk(split_root)))

# Human-readable class labels.
# NOTE(review): flow_from_directory numbers classes in alphanumeric order of the
# subdirectory names; confirm this list follows that same order before using it
# as target names.
classes = ['fruit', 'vegetables', 'dairy']

### Display dataset images

In [None]:
# Show three random training images from one class as a quick visual check.
# The path is built from train_dir so it points at the same directory the model
# trains on (the old hard-coded 'grocery_store_dataset/...' path dropped the
# 'data/' prefix used by base_dir).
sample_dir = os.path.join(train_dir, 'Fruit', 'Avocado')
multiple_images = glob(os.path.join(sample_dir, '*'))
random_samples = random.sample(multiple_images, 3)

fig, ax = plt.subplots(3, 1, figsize=(20, 20))
for axis, image_path in zip(ax, random_samples):
    axis.imshow(imread(image_path))
    # Hide ticks on *this* subplot; plt.axis('off') only affected the current
    # (last) axes, so the first two subplots kept their axes.
    axis.axis('off')

### Data preprocessing / augmentation

In [None]:
# Input resolution fed to the Xception base; must agree with input_shape used
# when the model is built.
IMAGE_SIZE = (224, 224)

# Xception's ImageNet weights expect pixels scaled to [-1, 1], which is what
# xception.preprocess_input does. The previous rescale=1./255 produced [0, 1]
# inputs that do not match the pretrained weights. Every split uses the same
# scaling so train/val/test inputs are consistent.
train_generator = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        fill_mode='nearest',
        preprocessing_function=xception.preprocess_input,
        zoom_range=0.2,
        horizontal_flip=True)

# Augmented training batches with one-hot ('categorical') labels.
train_dataset = train_generator.flow_from_directory(
        train_dir,
        target_size=IMAGE_SIZE,
        batch_size=25,
        class_mode='categorical')

# Test set: no augmentation; batch_size=1 and shuffle=False keep predictions
# aligned one-to-one with test_dataset.classes.
test_dataset = ImageDataGenerator(preprocessing_function=xception.preprocess_input).flow_from_directory(
        test_dir,
        target_size=IMAGE_SIZE,
        batch_size=1,
        shuffle=False,
        class_mode='categorical')

# Validation set: no augmentation.
validation_dataset = ImageDataGenerator(preprocessing_function=xception.preprocess_input).flow_from_directory(
        validation_dir,
        target_size=IMAGE_SIZE,
        batch_size=10,
        class_mode='categorical')

### Model configuration

In [None]:
# Xception convolutional base pretrained on ImageNet. With include_top=False the
# `classes` argument is ignored (it only sizes the built-in top classifier), so
# it is omitted. pooling='avg' makes the base output one pooled feature vector
# per image.
base_model = xception.Xception(weights='imagenet', include_top=False,
                               input_shape=(224, 224, 3), pooling='avg')

# Freeze every layer of the pretrained base so only the new head is trained.
base_model.trainable = False

# Size the softmax head from the classes flow_from_directory actually found,
# so the head cannot drift out of sync with the dataset.
num_classes = len(train_dataset.class_indices)

x = base_model.output
x = Dropout(0.2)(x)
output = Dense(num_classes, activation='softmax')(x)
model = Model(base_model.input, output)

model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.RMSprop(0.0001),
              metrics=['accuracy'])

### Model fitting

In [None]:
# Stop once val_accuracy has not improved by at least 0.01 for 3 epochs, and
# (when stopping triggers) roll the weights back to the best epoch it tracked
# instead of keeping the last epoch's weights.
early_stopping = callbacks.EarlyStopping(monitor='val_accuracy', min_delta=0.01, patience=3,
                                         verbose=1, restore_best_weights=True)
# Write the model to ./models.h5 whenever val_accuracy reaches a new best.
checkpoint = callbacks.ModelCheckpoint("./models.h5", monitor='val_accuracy', verbose=1,
                                       save_best_only=True)
# Distinct name on purpose: rebinding `callbacks` would shadow the imported
# tensorflow.keras.callbacks module and break re-running this cell.
training_callbacks = [early_stopping, checkpoint]

# One epoch = one full pass over the training set. A fixed steps_per_epoch
# larger than len(train_dataset) can make Keras interrupt training with
# "Your input ran out of data".
history = model.fit(train_dataset,
                    steps_per_epoch=len(train_dataset),
                    epochs=50,
                    validation_data=validation_dataset,
                    verbose=1,
                    callbacks=training_callbacks)

### Saving the model

In [None]:
# Persist the current in-memory model (architecture + weights) in HDF5 format.
# The best-val_accuracy checkpoint from training is written separately to
# ./models.h5 by the ModelCheckpoint callback.
base_model_path = 'base.h5'
model.save(base_model_path)

### Plotting the results

In [None]:
# Plot per-epoch training vs. validation curves. The x-axis is derived from the
# history length; the old hard-coded np.arange(0, 1) only worked when training
# ran for exactly one epoch and raised a shape mismatch otherwise.
metrics_history = history.history
epoch_axis = np.arange(1, len(metrics_history['loss']) + 1)

plt.figure(figsize=(10, 4))

for position, (title, key) in enumerate((('Loss', 'loss'), ('Accuracy', 'accuracy')), start=1):
    plt.subplot(1, 2, position)
    plt.title(title)
    plt.plot(epoch_axis, metrics_history[key], label='train')
    plt.plot(epoch_axis, metrics_history['val_' + key], label='val')
    plt.xlabel('epoch')
    plt.legend(loc='best')

plt.show()

### Prediction

In [None]:
# Predict on the unshuffled test set and reduce softmax outputs to class indices.
base_pred = model.predict(test_dataset)
base_pred = base_pred.argmax(axis=1)

# flow_from_directory assigns class indices in alphanumeric order of the
# subdirectory names. Read the names back from class_indices so each report row
# is labelled with the class that index really represents; the hard-coded
# `classes` list is not guaranteed to follow that order.
index_to_class = {index: name for name, index in test_dataset.class_indices.items()}
target_names = [index_to_class[i] for i in range(len(index_to_class))]

print(classification_report(test_dataset.classes, base_pred, target_names=target_names))