In [None]:
import os
import shutil
import sys

# append PYTHONPATH to load extensions
# (guarded so that re-running the cell does not keep growing sys.path)
for extension_dir in ('fastaugment', 'sigmoid_like_tf_op'):
    if extension_dir not in sys.path:
        sys.path.append(extension_dir)

# load TensorFlow
import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand instead of grabbing it all upfront.
# The setting is applied to every visible GPU (TensorFlow requires it to be the same
# on all of them) and is skipped altogether on CPU-only machines.
physical_devices = tf.config.list_physical_devices('GPU')
try:
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as error:
    # Memory growth can only be changed before the GPUs are initialized,
    # e.g. this happens when the cell is re-run in a live kernel.
    print(error)

# import other modules
import numpy
import matplotlib.pyplot as plt
from tensorflow.data.experimental import AUTOTUNE
from model import make_model
from imagenet_tools import ImageSet
from fast_augment import augment
from sigmoid_like import sigmoid_like
from glob import glob

# Global configuration: everything a reader may want to tune lives here.

# Root folder of the ImageNet 2012 dataset
imagenet_path = '/imagenet'

# Glob pattern of the TFRecord files caching the training set
tfrecord_pattern = os.path.join(imagenet_path, 'dogs_subset', 'train-*.tfrecord')

# Folder receiving the checkpoints written during training
checkpoints_dir = 'checkpoints'

# Selected class numbers: 119 dogs and tiger cat (120 classes in total)
classes = [*range(151, 269), 275, 282]

# Batch sizes used for training and for validation
train_batch_size = valid_batch_size = 64

In [None]:
# Setup the model
# The number of outputs is derived from the class list rather than hard-coded, so the
# classifier head and the one-hot labels stay in sync if `classes` is edited.
model = make_model(input_size=385, activation=sigmoid_like, num_classes=len(classes))
model.summary()

In [None]:
# Dataset preparation

# The image side length in pixels is read back from the model input layer
image_size = model.layers[0].input_shape[0][1]

# Training annotations and images live in sibling folders of the dataset root
train_annotations_dir, train_images_dir = (
    os.path.join(imagenet_path, subfolder, 'CLS-LOC', 'train')
    for subfolder in ('Annotations', 'Data')
)

# Build the training set as an ImageSet.
# The very first run is slow: relevant information gets written into a cache file (train_cache.txt)
train_set = ImageSet(
    train_annotations_dir,
    train_images_dir,
    os.path.join(imagenet_path, 'train_cache.txt'),
    batch_size=train_batch_size,
    image_size=image_size,
)

# Attach class names, keep only the selected classes, then shuffle
train_set.supply_class_names('map_clsloc.txt')
train_set.filter(classes)
train_set.shuffle()

# Dump the training set into TFRecord files to speed up training.
# Needs additional ~25 GB of storage and some time when run the first time.
# NOTE(review): the second argument (16) is presumably the number of TFRecord files - confirm.
train_set.make_tfrecord(tfrecord_pattern, 16)


def read_tfrecord(example):
    """ Parses a single serialized TFRecord example.

    The record is expected to carry a JPEG-encoded "image" string feature and an
    int64 "label" feature.

    Args:
        example:    the serialized example to parse

    Returns:
        A tuple of the decoded 3-channel image and its label one-hot encoded over
        `len(classes)` entries.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    pixels = tf.io.decode_jpeg(parsed["image"], channels=3)
    class_index = tf.cast(parsed["label"], tf.int32)
    one_hot_label = tf.one_hot(class_index, len(classes))
    return pixels, one_hot_label


def dataset_from_tfrecord(pattern, batch_size, shuffle_buffer=None):
    """ Constructs a shuffled, batched and augmented TFRecordDataset

    Args:
        pattern:            glob pattern matching the TFRecord files to read
        batch_size:         number of samples per batch
        shuffle_buffer:     shuffle buffer size in samples. Defaults to `batch_size`,
                            which only mixes samples locally; pass a larger value for
                            a stronger shuffling at the cost of memory.

    Returns:
        A tf.data.Dataset yielding the batches returned by `augment`.
    """
    if shuffle_buffer is None:
        shuffle_buffer = batch_size

    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False      # disable order, increase speed

    dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(pattern))
    dataset = dataset.with_options(ignore_order)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    dataset = dataset.shuffle(shuffle_buffer, reshuffle_each_iteration=True)
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(lambda x, y: augment(x, y,
                                               mixup=0.5,
                                               perspective=20))    # enable augmentation
    # Prefetching goes last so that batching and augmentation of the next batches
    # overlap with the training step consuming the current one.
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    return dataset


# Swap the training ImageSet for an augmented TFRecordDataset of the same batch size
train_set = dataset_from_tfrecord(tfrecord_pattern, train_set.batch_size)


# The validation set is prepared the same way as the training ImageSet.
# It is not converted to TFRecord, although it could be.
val_annotations_dir, val_images_dir = (
    os.path.join(imagenet_path, subfolder, 'CLS-LOC', 'val')
    for subfolder in ('Annotations', 'Data')
)
val_set = ImageSet(
    val_annotations_dir,
    val_images_dir,
    os.path.join(imagenet_path, 'val_cache.txt'),
    batch_size=valid_batch_size,
    image_size=image_size,
    use_annotations=False,
)
val_set.supply_class_names('map_clsloc.txt')
val_set.filter(classes)

In [None]:
# Show 9 augmented images taken from the first training batch
it = train_set.take(1).unbatch().as_numpy_iterator()

fig, axes = plt.subplots(3, 3, figsize=(12, 12))
for ax in axes.flat:
    # plain image tiles: no ticks, no grid
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    image, _ = next(it)
    ax.imshow(image)
plt.show()

In [None]:
# Pick a class by its index and get a sampler of its validation images
class_idx = 119
imgs = val_set.samples(class_idx)

# Show 9 of them on a 3x3 grid
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
for ax in axes.flat:
    # plain image tiles: no ticks, no grid
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    image, class_name = next(imgs)
    ax.imshow(image, cmap=plt.cm.binary)

# The sampler yields the class name along with every image; use it as the figure title
fig.suptitle(class_name)
plt.show()

In [None]:
# Load last checkpoint if available

# Checkpoint names start with a zero-padded epoch number ('weights{epoch:04d}-...'),
# so the lexicographic order is the chronological one.
checkpoint_files = sorted(glob(os.path.join(checkpoints_dir, '*.h5')))
initial_epoch = 0

if checkpoint_files:
    last_checkpoint = checkpoint_files[-1]

    # Recover the epoch from the file name rather than from the number of files:
    # counting files gives a wrong epoch (and a shifted LR schedule) as soon as
    # some older checkpoints have been removed.
    checkpoint_name = os.path.basename(last_checkpoint)
    try:
        initial_epoch = int(checkpoint_name[len('weights'):].split('-', 1)[0])
    except ValueError:
        # unexpected file name: fall back to one checkpoint per epoch assumption
        initial_epoch = len(checkpoint_files)

    print('Loading %s (epoch %d)' % (last_checkpoint, initial_epoch))
    model.load_weights(last_checkpoint)

In [None]:
# Train the model

# Starting from scratch: drop TensorBoard logs and checkpoints left by previous runs.
# Done in pure Python so that it is portable, honors `checkpoints_dir` and stays quiet
# when there is nothing to remove.
if initial_epoch == 0:
    shutil.rmtree('logs', ignore_errors=True)
    for stale_checkpoint in glob(os.path.join(checkpoints_dir, '*.h5')):
        os.remove(stale_checkpoint)

# set up saving callback: weights only, written after every epoch
# (save_best_only is off), with the epoch number and the loss in the file name
save_callback = tf.keras.callbacks.ModelCheckpoint(
    os.path.join(checkpoints_dir, 'weights{epoch:04d}-{loss:.2f}.h5'),
    monitor='loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=True,
    mode='min')

# define LR schedule
def lr_scheduler(epoch, _, initial_lr=0.01, cliff=400, step=50):
    """ Step-wise learning rate schedule.

    The learning rate stays at `initial_lr` before the `cliff` epoch. It is halved
    at the `cliff` epoch and then halved again every `step` epochs.

    Args:
        epoch:          epoch number
        _:              ignored (second argument passed by the scheduler callback)
        initial_lr:     learning rate used before the cliff
        cliff:          epoch number of the first halving
        step:           number of epochs between two consecutive halvings

    Returns:
        The learning rate to use for the given epoch.
    """
    halvings = (epoch - cliff) // step + 1
    if halvings < 0:
        halvings = 0
    return initial_lr * (0.5 ** halvings)

# fit: checkpoint saving, TensorBoard logging and LR scheduling run as callbacks;
# training resumes from `initial_epoch` and stops after epoch 600
training_callbacks = [
    save_callback,
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
]
model.fit(train_set,
          validation_data=val_set,
          initial_epoch=initial_epoch,
          epochs=600,
          callbacks=training_callbacks)

In [None]:
# Run validation pass: score the weights currently held by the model on the validation set.
# NOTE(review): the reported values depend on the loss/metrics the model is compiled with,
# which presumably happens inside make_model - confirm.
model.evaluate(val_set)

In [None]:
# Compute average of selected checkpoints

# Get checkpoint files (sorted, so that the averaging order is reproducible)
checkpoint_files = sorted(glob(os.path.join('ensembling', '*.h5')))
if not checkpoint_files:
    # fail with an explicit message instead of an IndexError further down
    raise FileNotFoundError("No checkpoint (*.h5) found in the 'ensembling' folder")
print('Averaging', len(checkpoint_files), 'models')

# Load every checkpoint into the model and grab its weights
weights = []
for entry in checkpoint_files:
    model.load_weights(entry)
    weights.append(model.get_weights())

# Compute the average: zip(*weights) groups the same weight array across checkpoints
avg_weights = [numpy.mean(same_array, axis=0) for same_array in zip(*weights)]

# Assign the averaged weights to the model and check how they perform
model.set_weights(avg_weights)
model.evaluate(val_set)

# Save the model
model.save_weights('model.h5')