In [11]:
import os

# --- Model ---
MODEL_NAME = 'inception'
IMAGE_SIZE = 224

# --- Hyperparameters ---
EPOCH = 100
LEARNING_RATE = 0.001
BATCH_SIZE = 16

# Shuffle the training data: True or False
SHUFFLE = True

# Parent directory of the TFRecords and of the saved results.
# The default is a machine-specific absolute path; set the CAPSTONE_ROOT environment
# variable (e.g. '~/Desktop/Git/capstone/Capstone') instead of editing it here.
# expanduser makes a '~'-style override work; it is a no-op for the default path.
CONST_ROOT_PATH = os.path.expanduser(
    os.environ.get('CAPSTONE_ROOT', 'C:/Users/USER/Desktop/Git/capstone/Capstone')
)

# TFRecord directory (sub-folder layout: (train / val)/*.tfrecords)
CONST_DIR_PATH = f"{CONST_ROOT_PATH}/tfrecord/tfrecord_1_2_3_4_5_6_7"

In [12]:
import os

# CONST_ROOT_PATH and CONST_DIR_PATH are defined once, in the configuration cell above.
# They are deliberately not re-assigned here: a second hard-coded copy would silently
# override any change made in the configuration cell.
CONST_SAVE_DIR = f"{CONST_ROOT_PATH}/result"

# One result directory per hyper-parameter combination,
# e.g. "inception_224_100_0.001_16_True".
EXPERIMENT_DIR = f"{MODEL_NAME}_{IMAGE_SIZE}_{EPOCH}_{LEARNING_RATE}_{BATCH_SIZE}_{SHUFFLE}"
# NOTE(review): Train_model writes its checkpoints/results under './checkpoint/...',
# not under RESULT_FILE_DIR — confirm which location is intended.
RESULT_FILE_DIR = os.path.join(os.path.expanduser(CONST_SAVE_DIR), EXPERIMENT_DIR)

In [None]:
import os

import pandas as pd
import matplotlib.pyplot as plt

import glob
from sklearn.model_selection import train_test_split

import keras.backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.applications import InceptionResNetV2

In [None]:
class Import_data:
    """Build Keras directory iterators for the training and validation image folders.

    Each path must follow the ``flow_from_directory`` layout: one sub-directory per
    class, containing that class's images.
    """

    def __init__(self, train_data_path, val_data_path, image_size=224, batch_size=8):
        """
        Args:
            train_data_path: directory holding the training class sub-directories.
            val_data_path: directory holding the validation class sub-directories.
            image_size: images are resized to ``(image_size, image_size)`` (default 224).
            batch_size: images per batch for both iterators (default 8).
        """
        self.train_data_path = train_data_path
        self.val_data_path = val_data_path
        self.image_size = image_size
        self.batch_size = batch_size

    def get_generators(self):
        """Return ``(train_generator, val_generator)`` yielding one-hot labelled batches.

        Random augmentation is applied to the training images only. Validation images
        are only rescaled, so validation metrics (and checkpoint selection based on them)
        reflect the real data rather than randomly distorted copies.
        """
        # featurewise_std_normalization was dropped: it only takes effect after
        # ImageDataGenerator.fit() on sample data, which is never called here, so it
        # had no effect apart from a warning on every batch.
        # NOTE(review): the InceptionResNetV2 ImageNet weights expect inputs scaled to
        # [-1, 1] (inception_resnet_v2.preprocess_input); rescale=1/255 gives [0, 1].
        train_datagen = ImageDataGenerator(
            rescale=1./255,
            shear_range=0.2,
            zoom_range=0.2,
            channel_shift_range=0.1,
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            horizontal_flip=True,
            fill_mode='constant',
            cval=0
        )
        val_datagen = ImageDataGenerator(rescale=1./255)

        target_size = (self.image_size, self.image_size)

        # Load the training data in batches of batch_size
        train_generator = train_datagen.flow_from_directory(
            self.train_data_path,
            target_size=target_size,
            batch_size=self.batch_size,
            class_mode='categorical'
        )

        # Load the validation data in batches of batch_size (no augmentation)
        val_generator = val_datagen.flow_from_directory(
            self.val_data_path,
            target_size=target_size,
            batch_size=self.batch_size,
            class_mode='categorical'
        )

        return train_generator, val_generator

In [None]:
class Load_model:
    """Build the classifier: an ImageNet-pretrained backbone followed by a dense head.

    Only the InceptionResNetV2 backbone is implemented. It is selected by
    ``model_name`` ``'inception'`` or ``'inception_resnet_v2'`` (``''``, the only value
    the original check accepted, is still supported).
    """

    # Backbone names that map to InceptionResNetV2.
    _INCEPTION_NAMES = ('', 'inception', 'inception_resnet_v2')

    def __init__(self, train_data_path, model_name, image_size=224):
        """
        Args:
            train_data_path: training directory with one sub-directory per class.
            model_name: backbone name (see the class docstring).
            image_size: side length of the square RGB network input (default 224).
        """
        # Count only sub-directories, the same way flow_from_directory discovers
        # classes; stray files (desktop.ini, .DS_Store, ...) must not add output units.
        self.num_class = sum(
            1 for entry in os.listdir(train_data_path)
            if os.path.isdir(os.path.join(train_data_path, entry))
        )
        self.model_name = model_name
        self.image_size = image_size

    def build_network(self):
        """Create, print a summary of, and return the (uncompiled) Sequential model.

        Raises:
            ValueError: if ``model_name`` is not a supported backbone name. Previously
                this case fell through and failed with an UnboundLocalError on ``model``.
        """
        if self.model_name not in self._INCEPTION_NAMES:
            raise ValueError(
                f"Unsupported model_name {self.model_name!r}; "
                f"expected one of {self._INCEPTION_NAMES}"
            )

        # pooling='avg' reduces the backbone's spatial feature map to one vector per
        # image. Without it the Dense layers act on every spatial position and the
        # softmax output is 4-D, which does not match the (batch, num_class) labels.
        network = InceptionResNetV2(
            include_top=False,
            weights='imagenet',
            input_shape=(self.image_size, self.image_size, 3),
            pooling='avg',
        )
        model = Sequential()
        model.add(network)
        model.add(Dense(2048, activation='relu'))
        model.add(Dense(self.num_class, activation='softmax'))

        model.summary()
        return model

In [None]:
class Train_model:
    """Train an image classifier and record its learning curves.

    Checkpoints, the per-epoch CSV and the loss/accuracy plot are all written to
    ``./checkpoint/<model_name>-<epoch>/``.
    """

    def __init__(self, train_data_path, val_data_path, model_name, epoch, learning_rate=0.001):
        """
        Args:
            train_data_path: training directory with one sub-directory per class.
            val_data_path: validation directory with the same class sub-directories.
            model_name: backbone name forwarded to ``Load_model``.
            epoch: number of training epochs.
            learning_rate: Adam learning rate (default 0.001, the value previously
                hard-coded in ``train``).
        """
        self.data = Import_data(train_data_path, val_data_path)
        self.model = Load_model(train_data_path, model_name)
        self.model_name = model_name
        self.epoch = epoch
        self.learning_rate = learning_rate

    def _save_folder(self):
        """Return the output directory shared by checkpoints and result files."""
        return './checkpoint/' + self.model_name + '-' + str(self.epoch) + '/'

    @staticmethod
    def _history_to_frame(history_dict):
        """Tabulate a Keras ``History.history`` dict per epoch.

        Expects the keys ``loss``, ``val_loss``, ``acc`` and ``val_acc``. Epochs are
        numbered from 1 so they match the ``{epoch}`` field in checkpoint file names.
        """
        return pd.DataFrame({
            'epoch': list(range(1, len(history_dict['loss']) + 1)),
            'train_loss': history_dict['loss'],
            'validation_loss': history_dict['val_loss'],
            'train_accuracy': history_dict['acc'],
            'validation_accuracy': history_dict['val_acc'],
        })

    def train(self):
        """Build, compile and fit the model; return the Keras ``History`` object."""
        train_generator, val_generator = self.data.get_generators()

        save_folder = self._save_folder()
        os.makedirs(save_folder, exist_ok=True)

        # Keep only the checkpoint with the best validation accuracy seen so far.
        check_point = ModelCheckpoint(
            save_folder + 'model-{epoch:03d}-{acc:03f}-{val_acc:03f}.h5',
            monitor='val_acc',
            save_best_only=True,
            mode='max',
            verbose=1
        )

        model = self.model.build_network()

        model.compile(
            loss='categorical_crossentropy',
            optimizer=Adam(learning_rate=self.learning_rate),
            metrics=['acc']
        )

        # len(iterator) is ceil(samples / batch_size), so every image (including the
        # last partial batch) is used. The previous `samples // batch_size` dropped
        # that batch and evaluated to 0 when a split had fewer images than one batch.
        history = model.fit(
            train_generator,
            steps_per_epoch=len(train_generator),
            validation_data=val_generator,
            validation_steps=len(val_generator),
            epochs=self.epoch,
            callbacks=[check_point],
            verbose=1
        )

        return history

    def save_result(self, history):
        """Write ``result.csv`` and ``result.png`` for ``history``, then clear the Keras session."""
        save_folder = self._save_folder()
        os.makedirs(save_folder, exist_ok=True)

        results = self._history_to_frame(history.history)
        results.to_csv(save_folder + 'result.csv', index=False, encoding='euc-kr')

        fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 6))
        loss_ax.plot(results['epoch'], results['train_loss'], 'b', label='Train Loss')
        loss_ax.plot(results['epoch'], results['validation_loss'], 'r', label='Validation Loss')
        loss_ax.set(title='Train and Validation Loss', xlabel='Epoch')
        loss_ax.legend()
        acc_ax.plot(results['epoch'], results['train_accuracy'], 'b', label='Train Accuracy')
        acc_ax.plot(results['epoch'], results['validation_accuracy'], 'r', label='Validation Accuracy')
        acc_ax.set(title='Train and Validation Accuracy', xlabel='Epoch')
        acc_ax.legend()
        fig.savefig(save_folder + 'result.png')
        # Close the figure; plt.cla() only cleared the current axes and left the
        # figure open, so repeated runs accumulated figures in memory.
        plt.close(fig)

        K.clear_session()

In [None]:
# Image folders in flow_from_directory layout (one sub-directory per class).
train_data_path = 'train/'
val_data_path = 'val/'
# Take the backbone name and epoch count from the configuration cell so the run
# matches the documented experiment settings (the old 'resnet_v1_50' label did not
# match the InceptionResNetV2 backbone that is actually built).
model_name = MODEL_NAME
epoch = EPOCH

if __name__ == '__main__':
    # Train_model takes train_data_path and val_data_path; the previous call passed a
    # nonexistent `train_path` keyword and omitted val_data_path (TypeError).
    model = Train_model(
        train_data_path=train_data_path,
        val_data_path=val_data_path,
        model_name=model_name,
        epoch=epoch,
    )
    history = model.train()
    model.save_result(history)