<h1>Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#load-dataset" data-toc-modified-id="load-dataset-1">load dataset</a></span></li><li><span><a href="#train-model" data-toc-modified-id="train-model-2">train model</a></span></li></ul></div>

# load dataset

In [1]:
# python 기본
import os
from functools import partial
# 실험관리
import wandb
from wandb.keras import WandbCallback
# 딥러닝 프레임워크
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0, EfficientNetB1

In [2]:
def _parse_function(tfrecord_serialized, image_size):
    features = {'image': tf.io.FixedLenFeature([], tf.string),
                'label': tf.io.FixedLenFeature([], tf.int64)
               }
    parsed_features = tf.io.parse_single_example(tfrecord_serialized, features)

    image = tf.io.decode_raw(parsed_features['image'], tf.uint8)
    image = tf.reshape(image, [256, 256, 3])
    image = tf.image.resize(image, [image_size, image_size])

    label = tf.cast(parsed_features['label'], tf.int64)
    label = tf.one_hot(label, 10)

    return image, label

In [3]:
tfr_path = './storage/data/place_10_15055.tfr'
dataset = tf.data.TFRecordDataset(tfr_path)

In [4]:
config = {
    "layer_1": 512,
    "layer_1_activation": "relu",
}

In [5]:
wandb.init(
    project="place_10_15055",
    config=config
          )

[34m[1mwandb[0m: Currently logged in as: [33mkec0130[0m (use `wandb login --relogin` to force relogin)


In [6]:
configs = wandb.config

In [7]:
dataset_size = len(list(dataset))
train_size = int(0.8 * dataset_size)
val_size = int(0.2 * dataset_size)
batch_size = 64

In [8]:
parsed_dataset = dataset.map(
    partial(_parse_function, image_size=256), 
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
dataset = parsed_dataset.shuffle(dataset_size)

In [9]:
train_ds = dataset.take(train_size)
train_ds = train_ds.batch(batch_size)
train_ds = train_ds.repeat()
train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)

In [10]:
val_ds = dataset.skip(train_size)
val_ds = val_ds.batch(batch_size)

# train model

In [11]:
base_model = EfficientNetB1(
    input_shape=(256, 256, 3),
    include_top=False,
    weights='imagenet',
    pooling='avg'
)

In [12]:
base_model.summary()

Model: "efficientnetb1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
rescaling (Rescaling)           (None, 256, 256, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
normalization (Normalization)   (None, 256, 256, 3)  7           rescaling[0][0]                  
__________________________________________________________________________________________________
stem_conv_pad (ZeroPadding2D)   (None, 257, 257, 3)  0           normalization[0][0]              
_____________________________________________________________________________________

Total params: 6,575,239
Trainable params: 6,513,184
Non-trainable params: 62,055
__________________________________________________________________________________________________


In [13]:
# Keras Functional API
input_layer = tf.keras.layers.Input((256, 256, 3))
model = base_model(input_layer)
model = tf.keras.layers.Dense(configs['layer_1'], activation=configs['layer_1_activation'])(model)
model = tf.keras.layers.BatchNormalization()(model)
model = tf.keras.layers.Dense(10)(model)
model = tf.keras.Model(input_layer, model)

In [14]:
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
efficientnetb1 (Functional)  (None, 1280)              6575239   
_________________________________________________________________
dense (Dense)                (None, 512)               655872    
_________________________________________________________________
batch_normalization (BatchNo (None, 512)               2048      
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130      
Total params: 7,238,289
Trainable params: 7,175,210
Non-trainable params: 63,079
_________________________________________________________________


In [15]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

In [16]:
save_dir = './storage/models/EfficientNetB1/'
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

In [17]:
es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
mc = tf.keras.callbacks.ModelCheckpoint(
    filepath=save_dir+'{epoch}-{val_loss:.2f}-{val_accuracy:.2f}.h5',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1
)
wan = WandbCallback()

In [18]:
history = model.fit(
    train_ds,
    epochs=9999,
    validation_data=val_ds,
    steps_per_epoch=dataset_size//batch_size,
    callbacks=[es, mc, wan]
)

Epoch 1/9999

Epoch 00001: val_accuracy improved from -inf to 0.93889, saving model to ./storage/models/EfficientNetB11-0.24-0.94.h5
Epoch 2/9999

Epoch 00002: val_accuracy improved from 0.93889 to 0.96247, saving model to ./storage/models/EfficientNetB12-0.11-0.96.h5
Epoch 3/9999

Epoch 00003: val_accuracy did not improve from 0.96247
Epoch 4/9999

Epoch 00004: val_accuracy improved from 0.96247 to 0.96945, saving model to ./storage/models/EfficientNetB14-0.08-0.97.h5
Epoch 5/9999

Epoch 00005: val_accuracy did not improve from 0.96945
Epoch 6/9999

Epoch 00006: val_accuracy improved from 0.96945 to 0.99469, saving model to ./storage/models/EfficientNetB16-0.02-0.99.h5
Epoch 7/9999

Epoch 00007: val_accuracy did not improve from 0.99469
Epoch 8/9999

Epoch 00008: val_accuracy did not improve from 0.99469
Epoch 9/9999

Epoch 00009: val_accuracy did not improve from 0.99469
Epoch 10/9999

Epoch 00010: val_accuracy did not improve from 0.99469
Epoch 11/9999

Epoch 00011: val_accuracy did