<h1>Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#load-dataset" data-toc-modified-id="load-dataset-1">load dataset</a></span></li><li><span><a href="#train-model" data-toc-modified-id="train-model-2">train model</a></span></li></ul></div>

# load dataset

In [1]:
import os
import tensorflow as tf
from functools import partial
from tensorflow.keras.applications import EfficientNetB0

In [2]:
def _parse_function(tfrecord_serialized, image_size, num_classes, raw_shape=(224, 224, 3)):
    """Decode one serialized TFRecord example into an ``(image, label)`` pair.

    Args:
        tfrecord_serialized: Scalar string tensor holding one serialized
            example with an ``'image'`` bytes feature (raw uint8 pixels) and a
            ``'label'`` int64 feature.
        image_size: Side length, in pixels, that the decoded image is resized to.
        num_classes: Length of the one-hot label vector.
        raw_shape: Shape the raw pixel bytes were written with. Defaults to
            ``(224, 224, 3)``, the layout this function previously hard-coded.

    Returns:
        ``(image, label)`` where ``image`` is a float32 tensor of shape
        ``(image_size, image_size, raw_shape[-1])`` whose pixel values are NOT
        rescaled (the EfficientNetB0 backbone used in this notebook starts with
        its own Rescaling/Normalization layers), and ``label`` is a float32
        one-hot vector of length ``num_classes``.
    """
    features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed_features = tf.io.parse_single_example(tfrecord_serialized, features)

    # Raw bytes -> uint8 pixels in the layout they were serialized with.
    image = tf.io.decode_raw(parsed_features['image'], tf.uint8)
    image = tf.reshape(image, raw_shape)
    # With the default (bilinear) method tf.image.resize returns float32,
    # even when image_size already matches the stored size.
    image = tf.image.resize(image, [image_size, image_size])

    # The feature spec above already parses the label as int64; no cast needed.
    label = tf.one_hot(parsed_features['label'], num_classes)

    return image, label

In [3]:
# Notebook-wide configuration.
BATCH_SIZE = 16
SEED = 42

# Seed TensorFlow's global RNG before any dataset shuffle or layer is created,
# so shuffling order and the new Dense head's initial weights repeat across
# Restart & Run All (GPU kernels may still add some nondeterminism).
tf.random.set_seed(SEED)

In [4]:
# Training split: shuffle -> decode -> batch -> prefetch.
# expanduser('~') honours $HOME like before, but does not crash when HOME is unset.
train_path = os.path.join(
    os.path.expanduser('~'), 'UDIGO', 'data', 'place_55_74604_train_shuffle.tfr'
)
train_dataset = tf.data.TFRecordDataset(train_path)

# Shuffle the compact serialized records *before* decoding. Shuffling after
# `map` would keep BATCH_SIZE * 100 decoded float32 images (224*224*3*4 bytes,
# about 0.6 MB each) resident in the shuffle buffer.
train_ds = train_dataset.shuffle(BATCH_SIZE * 100)
train_ds = train_ds.map(
    partial(_parse_function, image_size=224, num_classes=55),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
train_ds = train_ds.batch(BATCH_SIZE)
# Overlap input preprocessing with training steps.
train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)

In [5]:
# Validation split: decode -> batch -> prefetch (no shuffling needed for evaluation).
# expanduser('~') honours $HOME like before, but does not crash when HOME is unset.
val_path = os.path.join(
    os.path.expanduser('~'), 'UDIGO', 'data', 'place_55_21516_val_shuffle.tfr'
)
val_dataset = tf.data.TFRecordDataset(val_path)

val_ds = val_dataset.map(
    partial(_parse_function, image_size=224, num_classes=55),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
val_ds = val_ds.batch(BATCH_SIZE)
# Prefetch so decoding overlaps with evaluation, as in the training pipeline.
val_ds = val_ds.prefetch(tf.data.experimental.AUTOTUNE)

# train model

In [6]:
# ImageNet-pretrained EfficientNetB0 used as a feature extractor.
base_model = EfficientNetB0(
    include_top=False,           # drop the original ImageNet classification head
    weights='imagenet',          # start from pretrained weights
    input_shape=(224, 224, 3),
    pooling='avg',               # global-average-pool the last feature map -> (None, 1280)
)

In [7]:
# Inspect the backbone's layers (it begins with its own Rescaling and
# Normalization layers, so inputs are fed as unscaled pixel values).
base_model.summary()

Model: "efficientnetb0"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
rescaling (Rescaling)           (None, 224, 224, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
normalization (Normalization)   (None, 224, 224, 3)  7           rescaling[0][0]                  
__________________________________________________________________________________________________
stem_conv_pad (ZeroPadding2D)   (None, 225, 225, 3)  0           normalization[0][0]              
_____________________________________________________________________________________

In [8]:
# Classifier: pretrained backbone -> pooled features -> 55-way linear head.
# The head has no activation; it emits raw logits.
input_layer = tf.keras.layers.Input((224, 224, 3))
features = base_model(input_layer)
logits = tf.keras.layers.Dense(55)(features)
model = tf.keras.Model(input_layer, logits)

In [9]:
# Show the assembled classifier: backbone (as one Functional layer) plus the Dense head.
model.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
efficientnetb0 (Functional)  (None, 1280)              4049571   
_________________________________________________________________
dense (Dense)                (None, 55)                70455     
Total params: 4,120,026
Trainable params: 4,078,003
Non-trainable params: 42,023
_________________________________________________________________


In [10]:
# Labels are one-hot vectors and the Dense head outputs raw logits,
# hence categorical cross-entropy with from_logits=True.
model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'],
)

In [11]:
# The datasets are finite (no .repeat()), so each epoch is one full pass over
# train_ds and validation runs over all of val_ds; steps_per_epoch and
# validation_steps are therefore not needed.
history = model.fit(
    train_ds,
    epochs=10,
    validation_data=val_ds
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
