In [1]:
import os
import numpy as np
import tarfile
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import get_file

In [2]:
# Master switch: True builds and trains a new network below; False skips
# training and only evaluates a checkpoint saved by an earlier run.
TRAIN = False
# tf.data's auto-tuning sentinel, passed as num_parallel_calls in the pipeline.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [3]:
def residual_module(data,
                    filters,
                    stride,
                    reduce=False,
                    reg=0.0001,
                    bn_eps=2e-5,
                    bn_momentum=0.9):
    """Pre-activation bottleneck residual block: (BN -> ReLU -> Conv) x 3 plus a shortcut.

    Args:
        data: input Keras tensor (BatchNormalization uses axis=-1, so presumably
            channels-last -- TODO confirm for non-default image_data_format).
        filters: number of output channels; the two inner convolutions use
            int(filters / 4) channels (bottleneck).
        stride: stride of the 3x3 convolution, and of the shortcut projection
            when ``reduce`` is True.
        reduce: if True, the shortcut is a strided 1x1 convolution so it matches
            the residual branch. Must be True whenever ``stride`` != (1, 1) or
            the channel count of ``data`` differs from ``filters``; otherwise
            Add() receives tensors of different shapes.
        reg: L2 regularisation factor for every convolution kernel.
        bn_eps: epsilon for every BatchNormalization layer.
        bn_momentum: momentum for every BatchNormalization layer.

    Returns:
        Keras tensor: element-wise sum of the residual branch and the shortcut.
    """
    bottleneck_filters = int(filters / 4.)
    shortcut = data

    bn_1 = BatchNormalization(axis=-1,
                              epsilon=bn_eps,
                              momentum=bn_momentum)(data)
    act_1 = ReLU()(bn_1)
    conv_1 = Conv2D(filters=bottleneck_filters,
                    kernel_size=(1, 1),
                    use_bias=False,
                    kernel_regularizer=l2(reg))(act_1)

    bn_2 = BatchNormalization(axis=-1,
                              epsilon=bn_eps,
                              momentum=bn_momentum)(conv_1)
    act_2 = ReLU()(bn_2)
    # The only strided layer of the residual branch.
    conv_2 = Conv2D(filters=bottleneck_filters,
                    kernel_size=(3, 3),
                    strides=stride,
                    padding='same',
                    use_bias=False,
                    kernel_regularizer=l2(reg))(act_2)

    bn_3 = BatchNormalization(axis=-1,
                              epsilon=bn_eps,
                              momentum=bn_momentum)(conv_2)
    act_3 = ReLU()(bn_3)
    conv_3 = Conv2D(filters=filters,
                    kernel_size=(1, 1),
                    use_bias=False,
                    kernel_regularizer=l2(reg))(act_3)

    if reduce:
        # Project from act_1 (the block input after BN + ReLU), NOT act_3:
        # act_3 has already been strided by conv_2, so striding it again would
        # shrink the shortcut twice and Add() would get mismatched shapes.
        shortcut = Conv2D(filters=filters,
                          kernel_size=(1, 1),
                          strides=stride,
                          use_bias=False,
                          kernel_regularizer=l2(reg))(act_1)

    return Add()([conv_3, shortcut])

In [4]:
def build_resnet(input_shape,
                 classes,
                 stages,
                 filters,
                 reg=1e-3,
                 bn_eps=2e-5,
                 bn_momentum=0.9):
    """Build a pre-activation bottleneck ResNet classifier.

    Args:
        input_shape: shape of a single input image, without the batch dimension.
        classes: number of output classes (units of the final Dense layer).
        stages: number of residual modules in each stage. The first stage keeps
            the spatial resolution; every later stage halves it (stride 2).
        filters: filters[0] is the width of the stem 3x3 convolution and
            filters[i + 1] the output width of stage i, so it needs at least
            len(stages) + 1 entries.
        reg: L2 regularisation factor applied to the stem convolution, every
            residual-module convolution and the classifier.
        bn_eps: epsilon for every BatchNormalization layer.
        bn_momentum: momentum for every BatchNormalization layer.

    Returns:
        tf.keras Model named 'resnet' that outputs softmax class probabilities.

    Raises:
        ValueError: if ``filters`` has fewer than len(stages) + 1 entries.
    """
    if len(filters) < len(stages) + 1:
        raise ValueError(
            f'filters needs len(stages) + 1 = {len(stages) + 1} entries '
            f'(stem width followed by one width per stage), got {len(filters)}')

    inputs = Input(shape=input_shape)
    x = BatchNormalization(axis=-1,
                           epsilon=bn_eps,
                           momentum=bn_momentum)(inputs)
    x = Conv2D(filters[0], (3, 3),
               use_bias=False,
               padding='same',
               kernel_regularizer=l2(reg))(x)

    for i, (num_modules, stage_filters) in enumerate(zip(stages, filters[1:])):
        stride = (1, 1) if i == 0 else (2, 2)
        # The first module of each stage projects the shortcut to the new
        # width / resolution; the remaining modules are identity blocks.
        x = residual_module(data=x,
                            filters=stage_filters,
                            stride=stride,
                            reduce=True,
                            reg=reg,
                            bn_eps=bn_eps,
                            bn_momentum=bn_momentum)
        for _ in range(num_modules - 1):
            x = residual_module(data=x,
                                filters=stage_filters,
                                stride=(1, 1),
                                reg=reg,
                                bn_eps=bn_eps,
                                bn_momentum=bn_momentum)

    x = BatchNormalization(axis=-1,
                           epsilon=bn_eps,
                           momentum=bn_momentum)(x)
    x = ReLU()(x)
    # NOTE(review): an 8x8 pool assumes an 8x8 final feature map, e.g. a
    # 32x32 input after three stages (two stride-2 downsamplings).
    x = AveragePooling2D((8, 8))(x)
    x = Flatten()(x)
    x = Dense(classes, kernel_regularizer=l2(reg))(x)
    x = Softmax()(x)

    return Model(inputs, x, name='resnet')

In [5]:
def load_images_and_labels(image_path, target_size=(32,32)):
    """Turn one PNG path into a mean-centred image tensor and a class label vector.

    Args:
        image_path: scalar string tensor with the path of a PNG file whose
            parent directory name is the class name.
        target_size: (height, width) the image is resized to.

    Returns:
        (image, label) where image is a float32 tensor of shape
        target_size + (3,) with CINIC_MEAN_RGB subtracted, and label is a
        float32 vector with 1.0 at the position(s) of CINIC_10_CLASSES equal to
        the parent directory name and 0.0 elsewhere.
    """
    raw_bytes = tf.io.read_file(image_path)
    pixels = tf.image.decode_png(raw_bytes, channels=3)
    # Convert to float32, then remove the per-channel dataset mean.
    pixels = tf.image.convert_image_dtype(pixels, tf.float32) - CINIC_MEAN_RGB
    pixels = tf.image.resize(pixels, target_size)

    # The class is encoded as the name of the directory holding the file.
    path_parts = tf.strings.split(image_path, os.path.sep)
    class_name = path_parts[-2]
    one_hot = tf.cast(tf.equal(class_name, CINIC_10_CLASSES), tf.float32)

    return pixels, one_hot


In [6]:
def prepare_dataset(data_pattern, shuffle=False, batch_size=None):
    """Build a batched, prefetched tf.data pipeline of (image, label) pairs.

    Args:
        data_pattern: glob pattern matching the PNG files to load.
        shuffle: if True, the complete list of matched file names is shuffled
            before decoding, so each batch mixes examples from different class
            directories. If False, files are read in a deterministic order.
        batch_size: examples per batch; defaults to the global BATCH_SIZE.

    Returns:
        tf.data.Dataset yielding batches of (images, labels) produced by
        load_images_and_labels.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    # list_files shuffles by default, so the flag must be passed explicitly for
    # shuffle=False to mean "no shuffling". Shuffling cheap file-name strings
    # before map/batch mixes individual examples; shuffling after .batch() would
    # only permute whole batches and buffer decoded batches in memory.
    files = tf.data.Dataset.list_files(data_pattern, shuffle=shuffle)
    dataset = (files
               .map(load_images_and_labels, num_parallel_calls=AUTOTUNE)
               .batch(batch_size))

    # Let tf.data size the prefetch buffer rather than holding many full batches.
    return dataset.prefetch(AUTOTUNE)

In [7]:
# Per-channel (R, G, B) mean subtracted from every image in load_images_and_labels.
# NOTE(review): presumably computed over CINIC-10 pixels on a [0, 1] scale -- confirm source.
CINIC_MEAN_RGB = np.asarray((0.47889522, 0.47227842, 0.43047404))

In [8]:
# Class names as they appear as directory names under train/, valid/ and test/.
# The list order fixes each class's position in the label vector.
CINIC_10_CLASSES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

In [9]:
# Where the CINIC-10 archive is downloaded from, and the local name it is cached under.
DATASET_URL = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/3192/CINIC-10.tar.gz?sequence=4&isAllowed=y'
DATA_NAME = 'cinic10'
FILE_EXTENSION = 'tar.gz'
FILE_NAME = f'{DATA_NAME}.{FILE_EXTENSION}'  # -> 'cinic10.tar.gz'

In [10]:
# Download (or reuse the cached copy of) the archive, then extract it once into
# a DATA_NAME directory next to the cached archive.
downloaded_file_location = get_file(origin=DATASET_URL, fname=FILE_NAME, extract=False)

data_directory = os.path.join(os.path.dirname(downloaded_file_location), DATA_NAME)

# NOTE(review): only the directory's existence is checked, so an interrupted
# extraction is not redone -- delete the directory to force a re-extract.
if not os.path.exists(data_directory):
    # Open the archive only when it is actually needed, and close it afterwards.
    with tarfile.open(downloaded_file_location) as tar:
        if hasattr(tarfile, 'data_filter'):
            # The archive comes from the network: reject absolute paths, '..'
            # members and other unsafe entries where the stdlib supports it.
            tar.extractall(data_directory, filter='data')
        else:
            tar.extractall(data_directory)

In [11]:
# Glob patterns <split>/<class>/<image>.png. Each component is joined with
# os.path.join so every separator is os.path.sep, which is what
# load_images_and_labels splits on to recover the class directory name.
train_pattern = os.path.join(data_directory, 'train', '*', '*.png')
test_pattern = os.path.join(data_directory, 'test', '*', '*.png')
valid_pattern = os.path.join(data_directory, 'valid', '*', '*.png')

In [12]:
BATCH_SIZE = 128
BUFFER_SIZE = 1024

train_dataset = prepare_dataset(train_pattern, shuffle=True)
# Loss and accuracy on held-out data do not depend on example order, so only
# the training set is shuffled; this avoids needless shuffle work on evaluation.
test_dataset = prepare_dataset(test_pattern, shuffle=False)
valid_dataset = prepare_dataset(valid_pattern, shuffle=False)

In [13]:
if TRAIN:
    model = build_resnet(input_shape=(32, 32, 3),
                         classes=10,
                         stages=(9, 9, 9),
                         filters=(64, 64, 128, 256),
                         reg=5e-3)
    model.compile(loss='categorical_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy'])

    # monitor only selects which checkpoints are written when
    # save_best_only=True; without it a file is saved after every epoch.
    # File names encode the epoch and validation accuracy of the checkpoint.
    model_checkpoint_callback = ModelCheckpoint(
        filepath='./model.{epoch:02d}-{val_accuracy:.2f}.hdf5',
        save_weights_only=False,
        save_best_only=True,
        monitor='val_accuracy',
        mode='max')

    EPOCHS = 100
    model.fit(train_dataset,
              validation_data=valid_dataset,
              epochs=EPOCHS,
              callbacks=[model_checkpoint_callback])

In [14]:
# Checkpoint to evaluate. Its name follows the training cell's
# 'model.{epoch:02d}-{val_accuracy:.2f}.hdf5' pattern (here: epoch 38,
# val_accuracy 0.72); update it after retraining.
BEST_CHECKPOINT_PATH = 'model.38-0.72.hdf5'

model = load_model(BEST_CHECKPOINT_PATH)
# For a model compiled with metrics=['accuracy'], evaluate() returns [loss, accuracy].
result = model.evaluate(test_dataset)
print(f'Test accuracy: {result[1]}')

Test accuracy: 0.7195666432380676
