In [1]:
import os
import pickle
import types
from glob import glob

import cv2 as cv
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras import models, Sequential
from tensorflow.keras.layers import Dropout, Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam

from models import efficientNetV2B0_model, efficientNetV2B3_model
from config import efficientNet_config

from keras.preprocessing.image import ImageDataGenerator

In [2]:
# Input resolution for the EfficientNetV2-B0 backbone, taken from the project config.
height      = efficientNet_config['height_B0']
width       = efficientNet_config['width_B0']
input_shape = efficientNet_config['input_shape_B0']

# Train on the second GPU.
# CUDA_VISIBLE_DEVICES is only honoured if it is set before TensorFlow initialises
# CUDA, i.e. before any model or op is created. Assigning it after the model has
# been built is silently ignored. Setting it even earlier (before
# `import tensorflow`) is the safest option.
# NOTE(review): if `models.py` already touches the GPU at import time, this is
# too late as well -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# 讀取資料

In [3]:
# Root of the dataset snapshot. Each split contains one sub-folder per class
# ("0" and "1") holding PNG images.
dataset_root   = '../dataset_smooth_22video_20221031'
train_dir      = dataset_root + '/train'
validation_dir = dataset_root + '/validation'

In [4]:
def _load_class_images(split_dir, label):
    """Read and resize every PNG in ``<split_dir>/<label>/``.

    Args:
        split_dir: directory of one split (train or validation).
        label: integer class id; also the name of the sub-folder.

    Returns:
        (images, labels): a list of (height, width, 3) arrays as produced by
        ``cv.imread`` + ``cv.resize``, and a parallel list of ``[label]`` entries.

    Raises:
        IOError: if an image file cannot be decoded.
    """
    images, labels = [], []
    # sorted() gives a deterministic load order; raw glob order depends on the filesystem.
    for img_path in sorted(glob(os.path.join(split_dir, str(label), '*.png'))):
        img_array = cv.imread(img_path)
        if img_array is None:
            # cv.imread returns None instead of raising on unreadable/corrupt files.
            raise IOError('Could not read image: {}'.format(img_path))
        # cv.resize takes dsize as (width, height), not (height, width).
        # NOTE(review): cv.imread yields BGR channel order. ImageNet-pretrained Keras
        # backbones expect RGB. If efficientNetV2B0_model loads pretrained weights,
        # consider cv.cvtColor(..., cv.COLOR_BGR2RGB) here AND at inference time.
        images.append(cv.resize(img_array, (width, height)))
        labels.append([label])
    return images, labels


def _load_split(split_dir):
    """Load all class-0 images followed by all class-1 images of one split."""
    images, labels = [], []
    for label in (0, 1):
        class_images, class_labels = _load_class_images(split_dir, label)
        images.extend(class_images)
        labels.extend(class_labels)
    return images, labels


train_img_arrays, train_img_labels = _load_split(train_dir)
validation_img_arrays, validation_img_labels = _load_split(validation_dir)

In [5]:
# Number of training samples across both classes.
num_train_samples = len(train_img_labels)
print('訓練集數量= ', num_train_samples)

訓練集數量=  723


In [6]:
# Number of validation samples across both classes.
num_validation_samples = len(validation_img_labels)
print('驗證集數量= ', num_validation_samples)

驗證集數量=  379


In [7]:
# Stack the per-image lists into dense arrays. Images become (N, height, width, 3)
# and labels become (N, 1), because each label was stored as a one-element list.
train_img_arrays      = np.array(train_img_arrays)
train_img_labels      = np.array(train_img_labels)
validation_img_arrays = np.array(validation_img_arrays)
validation_img_labels = np.array(validation_img_labels)

# Sanity checks. glob() silently returns nothing for a mistyped path, so an empty
# split would otherwise only fail much later inside model.fit.
for split_name, split_images, split_labels in (
        ('train', train_img_arrays, train_img_labels),
        ('validation', validation_img_arrays, validation_img_labels)):
    assert split_images.ndim == 4, '{} images missing or ragged: shape {}'.format(split_name, split_images.shape)
    assert len(split_images) == len(split_labels), '{} images/labels count mismatch'.format(split_name)

In [8]:
# Report the final array shapes of both splits.
for caption, split_images in (('訓練集維度= ', train_img_arrays),
                              ('驗證集維度= ', validation_img_arrays)):
    print(caption, split_images.shape)

訓練集維度=  (723, 224, 224, 3)
驗證集維度=  (379, 224, 224, 3)


# 資料擴增

In [None]:
# Parameter reference (Chinese article on ImageDataGenerator augmentation):
# https://medium.com/ai%E5%8F%8D%E6%96%97%E5%9F%8E/preprocessing-data-image-data-augmentation%E5%AF%A6%E4%BD%9C%E8%88%87%E5%8F%83%E6%95%B8%E8%AA%AA%E6%98%8E-d05f2ed24194

# NOTE(review): every argument below equals its ImageDataGenerator default, so no
# augmentation is currently applied. `datagen.flow(...)` only batches the training
# arrays (and shuffles them, which is flow's default).
# Augmentations that were tried and left disabled:
#   rotation_range, width_shift_range=0.2, height_shift_range=0.2,
#   shear_range=0.2, zoom_range=0.2, horizontal_flip=True, rescale=1./255.
generator_options = dict(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    zca_epsilon=1e-6,
    zoom_range=0.,
    channel_shift_range=0.,
    fill_mode='nearest',
    rescale=None,
)
datagen = ImageDataGenerator(**generator_options)

In [None]:
# Build the EfficientNetV2-B0 based model from the project's models module.
# NOTE(review): binary_crossentropy implies a single sigmoid output -- confirm in models.py.
model = efficientNetV2B0_model()

# TODO: try the approach in
# https://stackoverflow.com/questions/71909901/guiding-tensorflow-keras-model-training-to-achieve-best-recall-at-precision-0-95
# (maximise recall at a fixed precision instead of plain recall).

# Track recall at a 0.5 threshold next to accuracy. The metric is named 'recall'
# so that Keras logs 'val_recall', which the checkpoint callback monitors.
recall_metric = tf.keras.metrics.Recall(name='recall', thresholds=0.5)
model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy', recall_metric],
)


In [None]:
# Run tag used for every artefact this notebook writes (weights, plots, history).
# NOTE(review): `day` is hard-coded and predates the 20221031 dataset, so every
# run overwrites the same files -- bump it per experiment.
day = '20220917'
checkpoint_filepath = '../model/202209/{}.weights'.format(day)

# Save weights only when validation recall improves (checked once per epoch).
# NOTE(review): recall alone rewards predicting everything as class 1. Consider
# monitoring a metric that also accounts for precision (see the TODO on compile).
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_recall',
    mode='max',
    save_freq='epoch',
    save_weights_only=True,
    save_best_only=True,
    verbose=1,
)

callbacks = [model_checkpoint_callback]

In [None]:
# Train for a fixed number of epochs. The checkpoint callback keeps the weights
# of the best-val_recall epoch on disk.
# NOTE: do not select the GPU here. CUDA_VISIBLE_DEVICES is only honoured if it is
# set before TensorFlow initialises CUDA, and the model has already been built by
# this point, so an assignment here is silently ignored. Set it in an early cell,
# ideally before `import tensorflow`.
epochs = 1000
history = model.fit(
    datagen.flow(train_img_arrays, train_img_labels, batch_size=32),
    validation_data=(validation_img_arrays, validation_img_labels),
    epochs=epochs,
    callbacks=callbacks,
    verbose=1,
)

# 畫圖

In [None]:
# Per-epoch metric curves recorded by model.fit. Validation entries carry a 'val_' prefix.
metrics_log = history.history
loss, val_loss         = metrics_log['loss'], metrics_log['val_loss']
accuracy, val_accuracy = metrics_log['accuracy'], metrics_log['val_accuracy']
recall, val_recall     = metrics_log['recall'], metrics_log['val_recall']

In [None]:
# Training vs. validation recall per epoch, with the best val_recall marked.
# The x-axis is derived from the recorded history rather than `epochs`, so the
# plot still lines up if training stopped early.
x = list(range(1, len(recall) + 1))

max_valrecall_y = max(val_recall)
max_valrecall_x = val_recall.index(max_valrecall_y) + 1   # 1-based epoch of the first maximum

fig, ax = plt.subplots(figsize=(24, 4))
ax.plot(x, recall, 'r', label='recall')              # red solid line: training
ax.plot(x, val_recall, 'b', label='val_recall')      # blue solid line: validation
ax.plot(max_valrecall_x, max_valrecall_y, 'd', color='g')   # green diamond at the best epoch
ax.text(max_valrecall_x, max_valrecall_y,
        "({},{})".format(max_valrecall_x, round(max_valrecall_y, 2)),
        ha='left', va='top', fontsize=20)

ax.legend()
ax.set_ylim(0, 1.0)
ax.set(xlabel='epoch', ylabel='recall', title='recall')
fig.savefig('../model/202209/{}_recall.png'.format(day))
plt.show()

In [None]:
# Training vs. validation accuracy per epoch, with the best val_accuracy marked.
# The x-axis is derived from the recorded history rather than `epochs`, so the
# plot still lines up if training stopped early.
x = list(range(1, len(accuracy) + 1))

max_valacc_y = max(val_accuracy)
max_valacc_x = val_accuracy.index(max_valacc_y) + 1   # 1-based epoch of the first maximum

fig, ax = plt.subplots(figsize=(24, 4))
ax.plot(x, accuracy, 'r', label='accuracy')              # red solid line: training
ax.plot(x, val_accuracy, 'b', label='val_accuracy')      # blue solid line: validation
ax.plot(max_valacc_x, max_valacc_y, 'd', color='g')     # green diamond at the best epoch
ax.text(max_valacc_x, max_valacc_y,
        "({},{})".format(max_valacc_x, round(max_valacc_y, 2)),
        ha='left', va='top', fontsize=20)

ax.legend()
ax.set_ylim(0, 1.1)
ax.set(xlabel='epoch', ylabel='accuracy', title='accuracy')
fig.savefig('../model/202209/{}_acc.png'.format(day))
plt.show()

In [None]:
# Training vs. validation loss per epoch, with the lowest val_loss marked.
# The x-axis is derived from the recorded history rather than `epochs`, so the
# plot still lines up if training stopped early.
x = list(range(1, len(loss) + 1))

min_valloss_y = min(val_loss)
min_valloss_x = val_loss.index(min_valloss_y) + 1   # 1-based epoch of the first minimum

fig, ax = plt.subplots(figsize=(24, 4))
ax.plot(x, loss, 'r', label='loss')              # red solid line: training
ax.plot(x, val_loss, 'b', label='val_loss')      # blue solid line: validation
ax.plot(min_valloss_x, min_valloss_y, 'd', color='g')   # green diamond at the best epoch
ax.text(min_valloss_x, min_valloss_y,
        "({},{})".format(min_valloss_x, round(min_valloss_y, 2)),
        ha='left', va='top', fontsize=20)

ax.legend()
# NOTE(review): binary cross-entropy is unbounded, so values above 1.0
# (e.g. an overfitting val_loss) are clipped by this y-limit.
ax.set_ylim(0, 1.0)
ax.set(xlabel='epoch', ylabel='loss', title='loss')
fig.savefig('../model/202209/{}_loss.png'.format(day))
plt.show()

In [None]:
# Persist the training curves as plain data.
# The Keras History callback also holds a reference to the trained model
# (`history.model`), which makes pickling it heavy and version-fragile.
# A SimpleNamespace exposing the same `history` / `epoch` / `params` attributes
# keeps `pickle.load(f).history[...]` working for existing readers.
history_record = types.SimpleNamespace(history=history.history,
                                       epoch=history.epoch,
                                       params=history.params)
with open('../model/202209/{}.history'.format(day), 'wb') as f:
    pickle.dump(history_record, f)