# Binary image classifier with:
* TPU / Multi-GPU ready code
* TFRecords dataset created for parallel distributed processing
* Dataset created from directories with separate classes
* Preprocessing and augmentation as a Keras layer in dataset preprocessor
* Transfer learning based on ResNet and EfficientNet

* Builds on:

https://medium.com/ai%C2%B3-theory-practice-business/image-dataset-with-tfrecord-files-7188b565bfc

https://keras.io/examples/keras_recipes/creating_tfrecords/

https://www.kaggle.com/code/donkeys/keras-binary-cats-dogs-resnet-98

https://towardsdatascience.com/a-comprehensive-guide-to-training-cnns-on-tpu-1beac4b0eb1c

In [None]:
# Initial imports
import tensorflow as tf
import keras_preprocessing
from keras_preprocessing import image
from tensorflow.keras import layers
import tensorflow_hub as hub
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import random
import math
import PIL
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
from keras.regularizers import l2
from keras.models import Sequential, Model, load_model
from keras.layers import (Activation, Dropout, Flatten, Dense, GlobalMaxPooling2D,
                         BatchNormalization, Input, Conv2D, GlobalAveragePooling2D)

import glob

In [None]:
try: 
    # Selects the tf.distribute strategy that the later cells use for batch
    # sizing (`strategy.num_replicas_in_sync`) and model creation
    # (`strategy.scope()`).

    # For use with TPU, swap the MirroredStrategy line below for the two
    # commented-out TPU lines:

    # Locate TPUs on the network
    # tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect() # TPU detection
    
    # TPUStrategy contains the necessary distributed training code that will work on TPUs 
    # with their 8 compute cores
    # strategy = tf.distribute.TPUStrategy(tpu)
    
    # Multi GPU training: list every GPU to mirror the model onto
    strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0"]) # add "/gpu:1", ... for more GPUs

except ValueError: # presumably raised when the requested TPU/GPU is unavailable -- TODO confirm
    strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU

In [None]:
# Report how many replicas the selected strategy keeps in sync
print(f'Number of accelerators: {strategy.num_replicas_in_sync}')

In [None]:
# Show the current working directory; the relative data paths used in this notebook resolve against it
!pwd

In [None]:
# Root directory of the image dataset; it is read with one sub-directory per class ('Cat', 'Dog')
PATH_IMAGES = './data/PetImages'

In [None]:
# List the image root to check that the class sub-directories are present
!ls $PATH_IMAGES

In [None]:
# Global batch size: 32 samples per replica times the number of replicas in sync
BATCH_SIZE = 32 * strategy.num_replicas_in_sync

# Let tf.data choose parallelism and prefetch buffer sizes at run time
AUTOTUNE = tf.data.AUTOTUNE

# Side length (in pixels) that images are resized to: 256.
# This is related to the feature size optimization, a multiple of 128 required for TPU
IMG_SIZE = 128 * 2

In [None]:
# Root directory of the combined train/validation images
train_val_dir = PATH_IMAGES

# Names of the entries found in each class directory
train_val_cat_files = os.listdir(f'{PATH_IMAGES}/Cat')
train_val_dog_files = os.listdir(f'{PATH_IMAGES}/Dog')

# Add a set for final model testing if needed
# test_dir = 

# Directory where tfrecords will be stored
tfrecords_dir = 'tfrecords'

# NOTE: the train/validation split ratio is not defined in this cell; it is
# hard-coded where TRAIN_CNT is computed.

In [None]:
# Total number of listed files over both class directories
TRAIN_TOTAL = sum(map(len, (train_val_cat_files, train_val_dog_files)))

# 75 % of the samples (rounded down) go to training, the rest to validation
TRAIN_CNT = int(0.75 * TRAIN_TOTAL)
VALID_CNT = TRAIN_TOTAL - TRAIN_CNT

In [None]:
# Class name constants
CAT, DOG = 'cat', 'dog'

# Converting image files dataset to TFRecords

In [None]:
# Sharding of the training and validation samples into TFRecords files

# Defining how many samples will be stored in a single TFRecords file
samples_per_tfrecord = 4096

# Number of TFRecords files per split, as a ceiling division on the SAMPLE
# count: every full file plus one partially filled file for any remainder.
# (Testing the file count for a remainder instead would give zero files for a
# split smaller than one shard and an extra empty file for an exact multiple.)
tfrecords_cnt_trn = math.ceil(TRAIN_CNT / samples_per_tfrecord)
tfrecords_cnt_val = math.ceil(VALID_CNT / samples_per_tfrecord)

# Create each output directory independently, so that an already existing
# top-level directory does not prevent its sub-directories from being created
os.makedirs(f'{tfrecords_dir}/train', exist_ok=True)
os.makedirs(f'{tfrecords_dir}/valid', exist_ok=True)

# Defining TFRecords auxiliary routines

In [None]:
def image_feature(value):
    """Return a bytes-list feature holding the JPEG encoding of `value`.

    `value` is handed to `tf.io.encode_jpeg`; the resulting encoded string is
    stored as the single element of the feature's `BytesList`.
    """
    jpeg_bytes = tf.io.encode_jpeg(value).numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg_bytes]))

def int64_feature(value):
    """Return an int64-list feature whose single element is `value` (used for label integers)."""
    label_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=label_list)

In [None]:
def create_example(image, label):
    """Pack one (image, label) pair into a `tf.train.Example`.

    The example has two features: "image" (JPEG bytes produced by
    `image_feature`) and "label" (integer produced by `int64_feature`).
    """
    features = tf.train.Features(
        feature={
            "image": image_feature(image),
            "label": int64_feature(label),
        }
    )
    return tf.train.Example(features=features)

In [None]:
def parse_tfrecord_fn(example):
    """Decode one serialized example read from a TFRecords file.

    Returns the parsed feature dict with "image" replaced by the decoded
    3-channel image tensor and "label" kept as the stored int64 scalar.
    """
    parsed = tf.io.parse_single_example(
        example,
        {
            "image": tf.io.FixedLenFeature([], dtype=tf.string),
            "label": tf.io.FixedLenFeature([], dtype=tf.int64),
        },
    )
    # Replace the JPEG bytes by the decoded image, forcing three channels
    parsed["image"] = tf.io.decode_jpeg(parsed["image"], channels=3)
    return parsed

In [None]:
# Collect every entry one level below the image root (i.e. inside each class
# directory). The pattern is built from PATH_IMAGES so that it always points
# at the same directory the sample counts (TRAIN_CNT) were derived from.
all_files_list = glob.glob(f'{PATH_IMAGES}/*/*')

# Shuffle in place so that both classes are mixed before splitting
random.shuffle(all_files_list)

# Train val split: the first TRAIN_CNT shuffled files are used for training,
# the remainder for validation
train_files_list = all_files_list[:TRAIN_CNT]
valid_files_list = all_files_list[TRAIN_CNT:]

# Creating and storing TFRecords

In [None]:
%%time
for tfrec_id in range(tfrecords_cnt_trn):

    files_batch = train_files_list[tfrec_id*samples_per_tfrecord:(tfrec_id+1)*samples_per_tfrecord]

    with tf.io.TFRecordWriter(
        tfrecords_dir + "/train/tfrecord_%.6i.tfrec" % (tfrec_id)
    ) as writer:
    
        for i in range(len(files_batch)):
    
            image = tf.io.decode_jpeg(tf.io.read_file(files_batch[i]))
        
            if 'Dog' in files_batch[i]:
                example = create_example(image, 0)
    
            elif 'Cat' in files_batch[i]:
                example = create_example(image, 1)
        
            else:
                continue
            
            writer.write(example.SerializeToString())
    
for tfrec_id in range(tfrecords_cnt_val):

    files_batch = valid_files_list[tfrec_id*samples_per_tfrecord:(tfrec_id+1)*samples_per_tfrecord]

    with tf.io.TFRecordWriter(
        tfrecords_dir + "/valid/tfrecord_%.6i.tfrec" % (tfrec_id)
    ) as writer:
    
        for i in range(len(files_batch)):
    
            image = tf.io.decode_jpeg(tf.io.read_file(files_batch[i]))
        
            if 'Dog' in files_batch[i]:
                example = create_example(image, 0)
    
            elif 'Cat' in files_batch[i]:
                example = create_example(image, 1)
        
            else:
                continue
            
            writer.write(example.SerializeToString())
        

# Testing raw dataset made out of stored TFRecords

In [None]:
# Read back the first validation shard from the directory the shards were
# written to (tfrecords_dir) as a quick sanity check of the stored data
raw_dataset = tf.data.TFRecordDataset(f'{tfrecords_dir}/valid/tfrecord_000000.tfrec')

In [None]:
# Decode every serialized example into a dict with "image" and "label" entries
parsed_dataset = raw_dataset.map(parse_tfrecord_fn)

In [None]:
# Display the dataset object itself as the cell output
parsed_dataset

In [None]:
# Preview the first parsed sample: print every non-image entry, then show the image
for features in parsed_dataset.take(1):
    for key, value in features.items():
        if key == 'image':
            continue
        print(f'{key}: {value}')

    plt.figure(figsize=[2, 2])
    plt.imshow(features['image'].numpy())
    plt.show()
        

# Keras preprocessing for the ResNet case and the ImageNet dataset

In [None]:
#from keras.applications.resnet50 import preprocess_input as resnet_preprocess
from keras.applications.imagenet_utils import preprocess_input as resnet_preprocess

def prepare_sample(features):
    """Turn a parsed sample dict into a `(preprocessed_image, label)` pair.

    The image is resized to IMG_SIZE x IMG_SIZE and passed through
    `resnet_preprocess`; the label is returned unchanged.
    """
    resized = tf.image.resize(features['image'], size=(IMG_SIZE, IMG_SIZE))
    label = features['label']
    return resnet_preprocess(resized), label

# Dataset creation from TFRecords with all auxiliary mappings

In [None]:
def get_dataset(filenames, batch_size, augment_sample_fn=None):
    """Build the input pipeline for a list of TFRecords files.

    Args:
        filenames: TFRecords file name(s) to read.
        batch_size: number of samples per batch; also sets the shuffle buffer
            to ten batches.
        augment_sample_fn: optional function mapped over the BATCHED dataset
            to apply data augmentation.

    Returns:
        A shuffled, endlessly repeating, batched and prefetched dataset of
        `(image, label)` pairs as produced by `prepare_sample`.
    """
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    samples = records.map(parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
    samples = samples.map(prepare_sample, num_parallel_calls=AUTOTUNE)

    # Shuffle individual samples, repeat forever, then group into batches
    batches = samples.shuffle(10 * batch_size).repeat().batch(batch_size)

    # Apply data augmentation
    if augment_sample_fn:
        batches = batches.map(augment_sample_fn, num_parallel_calls=AUTOTUNE)

    return batches.prefetch(AUTOTUNE)

In [None]:
# Collect the stored shard files of each split
train_filenames, valid_filenames = (
    tf.io.gfile.glob(f'{tfrecords_dir}/{split}/*.tfrec')
    for split in ('train', 'valid')
)

In [None]:
# Display the training shard file names as the cell output
train_filenames

# Plotting a batch from the dataset

In [None]:
# Batches of 9 samples, matching the 3x3 grid drawn by plot_batch_9
ds = get_dataset(train_filenames, 9)

In [None]:
def plot_batch_9(ds):
    """Draw the first nine images of the next batch of `ds` in a 3x3 grid.

    Only element 0 of the batch (the images) is used; each image is reshaped
    to (IMG_SIZE, IMG_SIZE, 3) before being shown.
    """
    batch_iter = iter(ds)
    plt.clf()
    plt.figure(figsize=[30, 30])
    images = next(batch_iter)[0]

    for cell in range(1, 10):
        plt.subplot(3, 3, cell)
        plt.imshow(tf.reshape(images[cell - 1], (IMG_SIZE, IMG_SIZE, 3)))

    plt.show()

In [None]:
# Show one batch of nine images after resizing and preprocessing
plot_batch_9(ds)

In [None]:
def plot_learning_curves(history):
    """Plot training/validation accuracy and loss per epoch.

    Args:
        history: object whose `history` dict holds the per-epoch lists
            'accuracy', 'val_accuracy', 'loss' and 'val_loss' (as returned by
            `model.fit` with validation data and the 'accuracy' metric).

    Accuracy is drawn on the current figure, loss on a second, new figure.
    """
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    epochs = range(len(acc))

    plt.plot(epochs, acc, 'r', label='Training accuracy')
    plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
    plt.title('Training and validation accuracy')
    plt.legend(loc=0)

    # Second figure: the loss curves, which were read but previously never drawn
    plt.figure()
    plt.plot(epochs, loss, 'r', label='Training loss')
    plt.plot(epochs, val_loss, 'b', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend(loc=0)

    plt.show()

# When training on an infinite dataset (ds.repeat()), one has to set
* steps_per_epoch
* validation_steps

# Note: remember to tune the batch size for TPU and to adjust the learning rate according to the (large) batch size (not done here)

In [None]:
# Batches needed to see every sample of a split once (rounded up); required
# because the datasets repeat endlessly
steps_per_epoch = math.ceil(TRAIN_CNT/BATCH_SIZE)
validation_steps = math.ceil(VALID_CNT/BATCH_SIZE)

# ResNet-50 Transfer Learning from TFRecords

In [None]:
from keras.applications.resnet50 import ResNet50

* The model creation function lets you specify how many of the last base-model layers are trainable; the remaining layers are kept frozen

In [None]:
def define_model(trainable_layers_count, show_summary=False):
    """Build and compile a ResNet50-based binary classifier.

    Args:
        trainable_layers_count: either the string 'all' (every base-model
            layer is trainable) or an int giving how many of the LAST
            base-model layers are left trainable. 0 (or a negative value)
            freezes the whole base model.
        show_summary: when True, print the Keras model summary.

    Returns:
        The compiled Keras `Model`: ImageNet-initialised ResNet50 without its
        top, followed by global average pooling, dropout, a 1024-unit ReLU
        layer, dropout and a single sigmoid output named 'final_output'.
    """
    input_tensor = Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    base_model = ResNet50(include_top=False,
                          weights='imagenet',
                          input_tensor=input_tensor)

    unfreeze_all = trainable_layers_count == 'all'
    for layer in base_model.layers:
        layer.trainable = unfreeze_all

    # Guard the zero case explicitly: `layers[-0:]` is the WHOLE list, which
    # would unfreeze every layer instead of none.
    if not unfreeze_all and trainable_layers_count > 0:
        for layer in base_model.layers[-trainable_layers_count:]:
            layer.trainable = True

    print('Base model has {} layers'.format(len(base_model.layers)))

    # Classification head
    x = GlobalAveragePooling2D()(base_model.output)
    x = Dropout(0.5)(x)
    x = Dense(1024, activation='relu', kernel_regularizer=l2(5e-4))(x)
    x = Dropout(0.5)(x)
    final_output = Dense(1, activation='sigmoid', name='final_output')(x)

    model = Model(input_tensor, final_output)

    if show_summary:
        model.summary()

    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'],
                  steps_per_execution=32,
                  jit_compile=True)

    return model

# Creating useful callback functions

In [None]:
from keras.callbacks import (ModelCheckpoint, LearningRateScheduler, 
                            EarlyStopping, ReduceLROnPlateau, CSVLogger)

# The checkpoint and the CSV log are both written below './working'; make
# sure the directory exists before training starts
os.makedirs('./working', exist_ok=True)

# Keep only the weights of the epoch with the lowest validation loss
checkpoint = ModelCheckpoint('./working/Resnet50_best.h5', monitor='val_loss',
                             verbose=1, save_best_only=True, mode='min',
                             save_weights_only=True)

# Halve the learning rate when the validation loss has not improved for 3 epochs.
# `min_delta` is the current name of the deprecated `epsilon` argument.
# NOTE: this callback is defined but not included in callbacks_list below.
reduceLROnPlat = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3,
                                   verbose=1, mode='auto', min_delta=0.0001)

# Stop training after 7 epochs without improvement of the validation loss
early = EarlyStopping(monitor='val_loss',
                      mode='min',
                      patience=7)

# Append the per-epoch metrics to a comma-separated log file
csv_logger = CSVLogger(filename='./working/training_log_csv',
                       separator=',',
                       append=True)

callbacks_list = [checkpoint, csv_logger, early]

# Creating model in the distributed strategy scope

In [None]:
# Build and compile the model inside the strategy scope;
# only the last 3 base-model layers are left trainable
with strategy.scope():
    model = define_model(3, show_summary=True)

# Training

In [None]:
# Batch with the global BATCH_SIZE, the same value steps_per_epoch and
# validation_steps were derived from
train_ds = get_dataset(train_filenames, BATCH_SIZE)
val_ds = get_dataset(valid_filenames, BATCH_SIZE)

In [None]:
%%time

# Train for up to 6 epochs. Both datasets repeat endlessly, so the number of
# batches per epoch / per validation run is given explicitly.
history = model.fit(train_ds,
                   steps_per_epoch=steps_per_epoch,
                   epochs=6,
                   validation_data=val_ds,
                   validation_steps=validation_steps,
                   verbose=1,
                   callbacks=callbacks_list)

In [None]:
# This loads the best weights stored by the ModelCheckpoint callback
# (save_best_only=True, monitored on val_loss)
model.load_weights('./working/Resnet50_best.h5')

In [None]:
# Evaluate on the validation set; `steps` is needed because val_ds repeats endlessly
model.evaluate(val_ds, steps=validation_steps)

In [None]:
# Plot the curves recorded during the training run above
plot_learning_curves(history)

# Training with data augmentation

# Data augmentation placed outside model in the data pipeline, TPU may not support augmentation ops

In [None]:
# Random augmentation pipeline: horizontal flip, rotation, zoom and contrast.
# The layers are referenced by their stable names; the
# `layers.experimental.preprocessing` aliases are deprecated.
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.2),
    layers.RandomContrast(factor=0.2),
])

def data_augment(img, label):
    """Apply the random augmentation pipeline to `img`; `label` passes through.

    `training=True` is passed explicitly so that the random layers are active
    when the pipeline is called from `Dataset.map` rather than from `fit`.

    NOTE(review): get_dataset maps this function after prepare_sample, i.e. on
    images that already went through resnet_preprocess -- confirm that the
    augmentation layers (RandomContrast in particular) behave as intended on
    inputs in that value range.
    """
    return data_augmentation(img, training=True), label

In [None]:
steps_per_epoch = math.ceil(TRAIN_CNT/BATCH_SIZE)
validation_steps = math.ceil(VALID_CNT/BATCH_SIZE)

# Batch with the same global batch size the step counts above are based on,
# so that one epoch still covers the whole split on multi-replica runs
batch_size = BATCH_SIZE

# Only the training pipeline is augmented
train_ds = get_dataset(train_filenames, batch_size, augment_sample_fn=data_augment)
val_ds = get_dataset(valid_filenames, batch_size)

In [None]:
# Early stopping callback automatically retrieving best weights:
# stops after 4 epochs without improvement of the monitored quantity
# (no `monitor` is given, so the Keras default is used)
early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=4,
                                                    restore_best_weights=True)

In [None]:
# Create a fresh model (last 3 base-model layers trainable) for the run with augmentation
with strategy.scope():
    model = define_model(3, show_summary=True)

In [None]:
%%time

# Train on the augmented pipeline; early stopping replaces the checkpoint/CSV callbacks here
history = model.fit(train_ds,
                   steps_per_epoch=steps_per_epoch,
                   epochs=6,
                   validation_data=val_ds,
                   validation_steps=validation_steps,
                   verbose=1,
                   callbacks=[early_stopping_cb])



In [None]:
# Plot the curves recorded during the augmented training run
plot_learning_curves(history)

# Transfer learning based on EfficientNet

In [None]:
# Ask TensorFlow Hub to load models in uncompressed format. The variable name
# has to be spelled exactly TFHUB_MODEL_LOAD_FORMAT to have any effect.
os.environ["TFHUB_MODEL_LOAD_FORMAT"] = "UNCOMPRESSED"

# TF Hub handle of the pre-trained EfficientNetV2-B0 model.
# NOTE(review): this is the `classification` variant, although it is passed to
# create_feature_vectors_model -- confirm whether the `feature_vector` variant
# was intended.
efficientnet_url = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b0/classification/2"

In [None]:
def create_feature_vectors_model(model_url):
    """Build and compile a binary classifier on top of a frozen TF Hub model.

    Args:
        model_url: TF Hub handle loaded as a non-trainable `hub.KerasLayer`.

    Returns:
        The compiled `Sequential` model: hub layer, dropout (0.5) and a
        single sigmoid output unit. The summary is printed as a side effect.
    """
    backbone = hub.KerasLayer(model_url,
                              trainable=False,
                              name='feature_extraction_layer')
    head = [
        layers.Dropout(0.5),
        layers.Dense(1, activation='sigmoid', name='output_layer'),
    ]

    model = tf.keras.Sequential([backbone, *head])

    # Fix the input shape so that the summary can be printed before training
    model.build([None, IMG_SIZE, IMG_SIZE, 3])
    model.summary()

    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'],
                  steps_per_execution=32)

    return model

In [None]:
# Build and compile the TF Hub based model inside the strategy scope
with strategy.scope():
    model = create_feature_vectors_model(efficientnet_url)

In [None]:
# Train the TF Hub based model for up to 3 epochs on the existing datasets.
# NOTE(review): train_ds / val_ds are built by get_dataset, whose
# prepare_sample step applies resnet_preprocess -- confirm that this matches
# the input scaling the TF Hub model expects.
history = model.fit(train_ds,
                   steps_per_epoch=steps_per_epoch,
                   epochs=3,
                   validation_data=val_ds,
                   validation_steps=validation_steps,
                   verbose=1,
                   callbacks=[early_stopping_cb])