In [1]:
from functools import partial
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator

%load_ext tensorboard
import easydict


In [2]:
import os
import sys

# Make the sibling ../CIFAR folder importable so the CIFARLoader helpers can be
# imported. The membership check keeps re-runs of this cell from appending the
# same entry to sys.path again and again.
module_path = os.path.abspath(os.path.join('..', 'CIFAR'))
is_on_path = module_path in sys.path
if not is_on_path:
    sys.path.append(module_path)

from CIFARLoader import save_cifar_fit, check_cifar


In [3]:
def clear_session(seed=42):
    """Reset Keras' global state and re-seed every random number generator.

    Call this before building a model so that weight initialisation and other
    random ops are reproducible when the notebook is re-run.

    Args:
        seed: seed applied to Python's ``random``, NumPy and TensorFlow.
            Defaults to 42, the same seed the data generators use.
    """
    import random

    keras.backend.clear_session()
    # Seed each library separately; none of them shares state with the others.
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)


In [4]:
VALIDATION_SPLIT = 0.1

# Pixel normalization shared by every generator, so train and test images are
# preprocessed identically.
# NOTE(review): featurewise mean/std are only applied once the generator has
# been fit on sample data via `ImageDataGenerator.fit`; presumably
# `save_cifar_fit` does that further down -- TODO confirm. Keras applies
# `rescale` *before* subtracting the fitted mean, so the data passed to `fit`
# should already be scaled by 1/255 -- verify.
NORMALIZATION_KWARGS = dict(
    featurewise_center=True,
    samplewise_center=False,
    featurewise_std_normalization=True,
    samplewise_std_normalization=False,
    zca_whitening=False,
    zca_epsilon=1e-06,
    rescale=1./255,
)

# Random augmentation -- for training images only.
AUGMENTATION_KWARGS = dict(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

train_datagen = ImageDataGenerator(
    **NORMALIZATION_KWARGS,
    **AUGMENTATION_KWARGS,
    validation_split=VALIDATION_SPLIT,
)

# Test images are evaluated as-is: same normalization, no random distortion.
# Augmenting them would measure the model on randomly rotated/shifted/flipped
# images and make the test score noisy and non-reproducible.
test_datagen = ImageDataGenerator(**NORMALIZATION_KWARGS)

In [5]:
# Folder the CIFAR images are exported to (the same ../CIFAR folder put on
# sys.path above). os.path.join keeps it portable; a '..\\CIFAR' literal only
# resolves as intended on Windows.
OUTPUT_DIR = os.path.join('..', 'CIFAR')


In [6]:
# Options for the CIFARLoader helpers: which dataset to export and where.
args = easydict.EasyDict({
    "dataset": "cifar10",
    "output": OUTPUT_DIR,
    "name_with_batch_index": False,
})

# Export the train/test images only when check_cifar says so (presumably when
# they are not on disk yet -- TODO confirm). The datagens are also handed to
# save_cifar_fit.
# NOTE(review): if check_cifar returns False on a re-run, save_cifar_fit is
# skipped; if it is the only place the datagens get fit, their featurewise
# statistics will be missing -- verify.
needs_export = check_cifar(args.dataset, args.output)
if needs_export:
    save_cifar_fit(args, train_datagen, test_datagen)


                      

Saving train images: 100%|██████████| 50000/50000 [00:19<00:00, 2504.56it/s]


                     

Saving test images: 100%|██████████| 10000/10000 [00:03<00:00, 2532.65it/s]


In [7]:
import glob

# Count the exported training images (one sub-folder per class under train/).
IMAGE_COUNT = len(glob.glob(os.path.join(OUTPUT_DIR, 'train', '*', '*')))
BATCH_SIZE = 50
# Batches per epoch over the training subset (what remains after the
# validation split). Keras expects an integer step count, so convert the
# float returned by np.ceil.
STEPS_PER_EPOCH = int(np.ceil((IMAGE_COUNT - IMAGE_COUNT * VALIDATION_SPLIT) / BATCH_SIZE))
IMAGE_COUNT, STEPS_PER_EPOCH

(50000, 900.0)

In [8]:
TARGET_SIZE = (224, 224)
TRAIN_DIR = os.path.join(OUTPUT_DIR, 'train')

# Iterator options shared by the training and validation subsets.
flow_kwargs = dict(
    target_size=TARGET_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    seed=42,
)

train_generator = train_datagen.flow_from_directory(
    TRAIN_DIR, subset='training', shuffle=True, **flow_kwargs)

# Validation images must not be randomly augmented, otherwise val metrics are
# measured on distorted images. Use a generator with the same normalization and
# validation_split but no augmentation. It reads the same folder, so Keras
# selects the same validation subset (the split comes from the file listing,
# not the augmentation -- verify when upgrading Keras). Reuse the featurewise
# statistics fitted on train_datagen; they are None if it was never fit.
val_datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rescale=1./255,
    validation_split=VALIDATION_SPLIT,
)
val_datagen.mean = train_datagen.mean
val_datagen.std = train_datagen.std

# No need to shuffle validation data; a fixed order keeps evaluation repeatable.
validation_generator = val_datagen.flow_from_directory(
    TRAIN_DIR, subset='validation', shuffle=False, **flow_kwargs)


Found 45000 images belonging to 10 classes.
Found 5000 images belonging to 10 classes.


## Inception network (GoogLeNet)

In [9]:
from datetime import datetime
from InceptionModules import InceptionStem

# Timestamp identifying this run; used in the checkpoint file name below.
date_time = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")

# Start from a clean Keras session with fixed seeds so weight initialisation
# is reproducible across re-runs.
clear_session()

# Input shape follows the generators' TARGET_SIZE plus 3 channels, so the two
# cannot drift apart.
inputs = keras.Input(shape=TARGET_SIZE + (3,))

image_net_model = InceptionStem()
# Calling the model once on a symbolic input builds it; the result is the
# (None, 10) class-score tensor.
image_net_model(inputs)

<tf.Tensor 'inception_v1/Identity:0' shape=(None, 10) dtype=float32>

In [10]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
image_net_model.summary()


Model: "inception_v1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 112, 112, 64)      9472      
_________________________________________________________________
maxpool_1 (MaxPooling2D)     (None, 56, 56, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 56, 56, 64)        4160      
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 56, 56, 192)       110784    
_________________________________________________________________
maxpool_2 (MaxPooling2D)     (None, 28, 28, 192)       0         
_________________________________________________________________
mixed_inc_blk_1 (InceptionMo (None, 28, 28, 256)       155504    
_________________________________________________________________
mixed_inc_blk_2 (InceptionMo (None, 28, 28, 480)      

In [11]:
# Save the best model seen so far during this run (ModelCheckpoint monitors
# val_loss by default). `overwrite` is not a ModelCheckpoint argument, so it is
# not passed.
# NOTE(review): saving a subclassed Keras model in .h5 format is not supported;
# if InceptionStem is subclassed, switch to save_weights_only=True -- verify.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "image_net_model_" + date_time + ".h5", save_best_only=True)

# One TensorBoard log directory per run, keyed by the same timestamp as the
# checkpoint, so repeated runs do not write into the same folder.
run_log_dir = Path.cwd() / "cifar10_logs" / ("run_" + date_time)
tensorboard_cb = tf.keras.callbacks.TensorBoard(run_log_dir)

callbacks = [model_checkpoint, tensorboard_cb]

In [12]:
# Labels are one-hot (class_mode='categorical'), hence CategoricalCrossentropy.
image_net_model.compile(
    loss=tf.losses.CategoricalCrossentropy(),
    metrics=['accuracy'],
    optimizer=tf.optimizers.Adam(learning_rate=0.0001),
)

history = image_net_model.fit(
    train_generator,
    validation_data=validation_generator,
    callbacks=callbacks,  # already a list; do not wrap it in another list
    verbose=2,
    workers=6,
    epochs=100,
    steps_per_epoch=int(STEPS_PER_EPOCH),  # Keras expects an integer step count
)


Epoch 1/100


KeyboardInterrupt: 

In [None]:
# Test images in a fixed order (so predictions line up with
# test_generator.classes), using the same image size, batch size and label
# encoding as training.
test_generator = test_datagen.flow_from_directory(
    os.path.join(OUTPUT_DIR, 'test'),
    target_size=TARGET_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False,
)


In [None]:
score = image_net_model.evaluate(test_generator, workers=8, verbose=2)
# Display the test metrics by name (loss plus compiled metrics), not as a bare list.
dict(zip(image_net_model.metrics_names, score))

In [None]:

def plot_train_val(history_dict, metric, label):
    """Plot per-epoch training vs. validation values of one metric.

    Args:
        history_dict: the ``History.history`` dict produced by ``Model.fit``;
            must contain the keys ``metric`` and ``'val_' + metric``.
        metric: metric key, e.g. ``'accuracy'`` or ``'loss'``.
        label: short name used in the legend, e.g. ``'acc'``.

    Returns:
        The matplotlib Axes the curves were drawn on.
    """
    fig, ax = plt.subplots()
    train_values = history_dict[metric]
    epochs = range(1, len(train_values) + 1)
    ax.plot(epochs, train_values, 'bo', label='Training ' + label)
    ax.plot(epochs, history_dict['val_' + metric], 'b', label='Validation ' + label)
    ax.set(title='Training and validation ' + metric, xlabel='Epoch', ylabel=metric)
    ax.legend()
    return ax


train_history = history.history
# Fail with a clear message rather than an IndexError if no epoch completed
# (e.g. when fit was interrupted).
assert train_history.get('accuracy'), "no completed training epochs recorded in `history`"

acc = train_history['accuracy'][-1]
val_acc = train_history['val_accuracy'][-1]
loss = train_history['loss'][-1]
val_loss = train_history['val_loss'][-1]

print("Training accuracy: ", acc)
print("Training loss: ", loss)

print("Validation accuracy: ", val_acc)
print("Validation loss: ", val_loss)

# Plot the trend of accuracy and loss over the epochs, one figure each.
plot_train_val(train_history, 'accuracy', 'acc')
plot_train_val(train_history, 'loss', 'loss')
plt.show()