In [None]:
import sys, os, json
import numpy as np
import pandas as pd
import seaborn as sns
import argparse, progressbar
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import matplotlib.axes._axes as axes
import tiny_imagenet_config as config
sns.set()  # apply seaborn's default plot styling to subsequent matplotlib figures

In [None]:
# import the necessary packages
from imutils import paths
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from loader_util.preprocessing import ImageToArrayPreprocessor, \
    AspectAwarePreprocessor, SimplePreProcessor, MeanPreprocessor
from loader_util.io import HDF5DatasetGenerator
from loader_util.datasets import SimpleDatasetLoader
from loader_util.nn.conv import FCHeadNet, DeepGoogleNet
from loader_util.callbacks import EpochCheckpoint, TrainingMonitor
##
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Conv2D, Activation, Flatten, Dense, \
    BatchNormalization, MaxPooling2D, AveragePooling2D, Dropout
from tensorflow.keras.layers import Input, concatenate
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import RMSprop, SGD, Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model, load_model

In [None]:

# Run configuration (notebook stand-in for command-line arguments):
#   checkpoints -- output path handed to EpochCheckpoint (outputPath)
#   model       -- saved model to resume from; empty string => build a new model
#   start_epoch -- value handed to the callbacks' startAt when resuming
# NOTE(review): set `checkpoints` before training; how EpochCheckpoint treats an
# empty path is not visible here — confirm it does not write to an unexpected dir.
args = dict(
    checkpoints="",
    model="",
    start_epoch=0,
)

In [None]:
# Augmentation pipeline intended for the training images: random rotations,
# zooms, shifts, shears and horizontal flips; pixels exposed by a transform are
# filled with the nearest valid value.
aug = ImageDataGenerator(rotation_range=18,
                         zoom_range=0.15,
                         width_shift_range=0.2,
                         height_shift_range=0.2,
                         shear_range=0.15,
                         horizontal_flip=True,
                         fill_mode='nearest')

# Load the training-set RGB means (a JSON object read later via the keys
# "R", "G", "B"). The context manager closes the file handle instead of
# leaking it as `open(...).read()` did.
with open(config.dataset_mean) as mean_file:
    means = json.load(mean_file)

In [None]:
# Init the image preprocessors, applied in list order. Presumably:
# SimplePreProcessor resizes to 64x64, MeanPreprocessor subtracts the
# training-set RGB means, ImageToArrayPreprocessor converts to a Keras array.
sp = SimplePreProcessor(64, 64)
mp = MeanPreprocessor(means["R"], means["G"], means["B"])
iap = ImageToArrayPreprocessor()

# Init the train and valid dataset generators. Only the training generator
# receives the `aug` ImageDataGenerator — previously it was built but never
# used, so no augmentation was applied. Validation data stays unaugmented.
# NOTE(review): `aug` keyword assumed from HDF5DatasetGenerator's signature
# (loader_util); confirm against the project implementation.
traingen = HDF5DatasetGenerator(dbPath=config.train_hdf5,
                                batchSize=64,
                                preprocessors=[sp, mp, iap],
                                aug=aug,
                                classes=config.num_classes)
validgen = HDF5DatasetGenerator(dbPath=config.valid_hdf5,
                                batchSize=64,
                                preprocessors=[sp, mp, iap],
                                classes=config.num_classes)

In [None]:
# Either resume from a saved model (lowering its learning rate for fine-tuning)
# or, when no model path is configured, build and compile a fresh DeepGoogleNet.
if args["model"]:
    print(f"[INFO] loading model: {args['model']}......")
    model = load_model(args["model"])  # type: Model

    # drop the restored optimizer's learning rate before continuing training
    lr_var = model.optimizer.learning_rate
    print(f"[INFO] old learning rate: {K.get_value(lr_var)}......")
    K.set_value(lr_var, 0.00001)
    print(f"[INFO] new learning rate: {K.get_value(lr_var)}......")
else:
    print("[INFO] compiling the model......")
    model = DeepGoogleNet.build(64, 64, 3,
                                classes=config.num_classes,
                                reg=0.0002)
    opt = Adam(learning_rate=0.001)
    model.compile(optimizer=opt,
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])

In [None]:
# Callbacks: checkpoint the model every 5 epochs and record/plot training
# history. Both receive `start_epoch` as startAt — presumably so a resumed run
# continues its epoch numbering (confirm in loader_util.callbacks).
resume_epoch = args["start_epoch"]

checkpoint_cb = EpochCheckpoint(outputPath=args["checkpoints"],
                                every=5,
                                startAt=resume_epoch)
monitor_cb = TrainingMonitor(figPath=config.fig_path,
                             jsonPath=config.json_path,
                             startAt=resume_epoch)

callbacks = [checkpoint_cb, monitor_cb]

In [None]:
# Train the net. In TF 2.x `Model.fit` accepts Python generators directly;
# `fit_generator` is deprecated and removed in newer TensorFlow releases.
# The step counts divide by 64, which must match the generators' batchSize.
try:
    H = model.fit(traingen.generator(),
                  steps_per_epoch=traingen.numImages // 64,
                  validation_data=validgen.generator(),
                  validation_steps=validgen.numImages // 64,
                  epochs=10,
                  max_queue_size=10,
                  callbacks=callbacks,
                  verbose=1)
finally:
    # close the HDF5 databases even if training fails or is interrupted
    traingen.close()
    validgen.close()