## Import Library

In [1]:
# utility
import os
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"]="3"

# keras tensorflow wrapper
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.applications import InceptionV3, Xception
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.python.keras.metrics import top_k_categorical_accuracy
from tensorflow.python.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.python.keras.optimizers import Adam

# scikit-learn helper function
from sklearn.utils.class_weight import compute_class_weight

## Helper functions

In [2]:
def top_3_accuracy(true, pred):
    return top_k_categorical_accuracy(true, pred, k=3)

def path_join(dirname, img_paths):
    return [os.path.join(dirname, img_path) for img_path in img_paths]

## 이미지 데이터 전처리 및 generator 생성

In [3]:
TRAIN_PATH = '../training'
datagen = ImageDataGenerator(rescale=1./255,
                             validation_split=0.1)

batch_size = 32
input_shape = (224,224)

generator_train = datagen.flow_from_directory(directory=TRAIN_PATH,
                                              target_size=input_shape,
                                              shuffle=True,
                                              subset="training"
                                              )

generator_validate = datagen.flow_from_directory(directory=TRAIN_PATH,
                                                 target_size=input_shape,
                                                 shuffle=False,
                                                 subset="validation"
                                                 )
steps_train = generator_train.n / batch_size
steps_validate = generator_validate.n / batch_size

cls_train = generator_train.classes
cls_validate = generator_validate.classes

num_classes = generator_train.num_classes

class_weight = compute_class_weight(class_weight='balanced',
                                    classes=np.unique(cls_train),
                                    y=cls_train) 

Found 673383 images belonging to 20 classes.
Found 74813 images belonging to 20 classes.


## 모델 정의 및 구축

In [4]:
class Model():
    def __init__(self, name, class_weight, params):
        assert name != '', "Model name needs to be specified"
        self.name = name
        self.trained = False
        # feature extraction
        
    def construct_model(self):
        if self.name == 'inceptionv3':
            print('{:=^75}'.format('Downloading {}'.format(self.name)))
            self.base_model = InceptionV3(**params['network_params'])
            print('{:=^75}'.format('Download Complete'))
            
        elif self.name == 'xception':
            print('{:=^75}'.format('Downloading {}'.format(self.name)))
            self.base_model = Xception(**params['network_params'])
            print('{:=^75}'.format('Download Complete'))
            
            
        # 모델 구조  base model -> global average pooling -> dense
        print('{:=^75}'.format('Adding layers'))
        self.model = Sequential()
        self.model.add(self.base_model)
        self.model.add(GlobalAveragePooling2D())
        self.model.add(Dense(params['num_classes'], activation='softmax'))
        print('{:=^75}'.format('Added layers'))
        
        if params['mode'] == 'fe':
            self.model.layers[0].trainable = False
            
        # finetuning
        elif params['mode'] == 'ft':
            self.model.layers[0].trainable = True 
        
        # 지정 경로에 저장
        if not os.path.exists('weight_path/'):
            os.mkdir('weight_path/')
        self.weight_save_path = os.path.join('weight_path/', self.name + "_weights.h5")
        
        print('{:=^75}'.format('Saving weights to {}'.format(self.weight_save_path)))
        self.model.save_weights(self.weight_save_path)
        print('{:=^75}'.format('Saved weights'))
    
    
    # train with feature extraction
    def train(self):
        if self.trained == True:
            self.model.load_weights(self.weight_save_path)
            self.trained = False
        
        assert params['mode'] in ['fe', 'ft'], "mode must be either 'fe' or 'ft'"  
            
        # compile the model with designated parameters    
        self.model.compile(optimizer=Adam(lr=params['lr']),
                           loss='categorical_crossentropy',
                           metrics=['categorical_accuracy', top_3_accuracy])
        
        if not os.path.exists(params['log_path']):
            os.mkdir(params['log_path'])
        
        if not os.path.exists(params['cp_path']):
            os.mkdir(params['cp_path'])
        
        # csv logger callback 
        log_path = os.path.join(params['log_path'], self.name + '_' + params['mode'] + '.log')
        csvlog_callback = CSVLogger(log_path)
        
        # checkpoint callback 
        cp_path = os.path.join(params['cp_path'], self.name + '_' + params['mode'] + '-{epoch:04d}-{val_loss:.2f}.h5')
        cp_callback = ModelCheckpoint(cp_path,
                                      mode="max",
                                      save_best_only=True)
        
        print('{:=^75}'.format('training {} with {} mode'.format(self.name, params['mode'])))
        # actual data fitting
        self.model.fit_generator(generator=generator_train,
                                  epochs=params['epoch'],
                                  class_weight=class_weight,
                                  validation_data=generator_validate,
                                  validation_steps=steps_validate,
                                  callbacks=[cp_callback, csvlog_callback])
        
        # save model once done training    
        if not os.path.exists(params['model_path']):
            os.mkdir(params['model_path'])
            
        model_save_path = os.path.join(params['model_path'], model.name + '_' + params['mode'] + '.h5')
        self.model.save(model_save_path)
        self.trained = True

In [5]:
params = {
    'num_classes': num_classes,
    'log_path': 'log/',
    'cp_path': 'checkpoint/',
    'model_path': 'model/',
    'mode': 'fe',
    'lr': 0.0001,
    'epoch': 10,
    'network_params': {
    'include_top' : False, 
    'weights' : 'imagenet', 
    'input_shape' : input_shape + (3,)
    }
}

inception = Model(name='inceptionv3', class_weight=class_weight, params=params)
xception = Model(name='xception', class_weight=class_weight, params=params)

In [16]:
inception.construct_model()



In [6]:
xception.construct_model()



In [8]:
inception.model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
inception_v3 (Model)         (None, 5, 5, 2048)        21802784  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dense (Dense)                (None, 20)                40980     
Total params: 21,843,764
Trainable params: 21,809,332
Non-trainable params: 34,432
_________________________________________________________________


In [7]:
xception.model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
xception (Model)             (None, 7, 7, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dense (Dense)                (None, 20)                40980     
Total params: 20,902,460
Trainable params: 40,980
Non-trainable params: 20,861,480
_________________________________________________________________


In [None]:
inception.train()

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
  669/21044 [..............................] - ETA: 39:59 - loss: 1.2836 - categorical_accuracy: 0.5742 - top_3_accuracy: 0.8540

In [None]:
xception.train()

Epoch 1/10

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch 3/10
 2939/21044 [===>..........................] - ETA: 36:06 - loss: 1.2480 - categorical_accuracy: 0.5975 - top_3_accuracy: 0.8621

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [None]:
params.update({
    'mode': 'ft',
    'lr': 0.0001
})

In [None]:
inception.train()

In [None]:
xception.train()

## 모델 훈련

### Feature Extraction

In [39]:
# compile 전 feature extraction에서는 dense를 제외한 모든 layer의 weight를 고정
incep_model.layers[0].trainable = False
xcep_model.layers[0].trainable = False

# model 훈련전 compile을 실행
incep_model.compile(optimizer=Adam(lr=0.001), 
                    loss='categorical_crossentropy', 
                    metrics=['categorical_accuracy', top_3_accuracy])

xcep_model.compile(optimizer=Adam(lr=0.001),
                   loss='categorical_crossentropy', 
                   metrics=['categorical_accuracy', top_3_accuracy])


epochs = 20

class_weight = compute_class_weight(class_weight='balanced',
                                   classes=np.unique(cls_train),
                                   y=cls_train)

log_path = 'log/inception_fe.log'
checkpoint_path = "checkpoint/model-{epoch:04d}-{val_acc:.2f}.h5"
cp_callback = ModelCheckpoint(checkpoint_path,
                              mode="max",
                              save_best_only=True)

csvlog_callback = CSVLogger(log_path)
incep_model.fit_generator(generator=generator_train,
                          epochs=epochs,
                          class_weight=class_weight,
                          validation_data=generator_validate,
                          validation_steps=steps_validate,
                          callbacks=[cp_callback, csvlog_callback])

log_path = 'log/xception_fe.log'
csvlog_callback = CSVLogger(log_path)
checkpoint_path = "checkpoint/model-{epoch:04d}-{val_acc:.2f}.h5"

cp_callback = ModelCheckpoint(checkpoint_path,
                              mode="max",
                              save_best_only=True)
xcep_model.fit_generator(generator=generator_train,
                          epochs=epochs,
                          class_weight=class_weight,
                          validation_data=generator_validate,
                          validation_steps=steps_validate)

# 훈련이 끝난 모델을 지정한 경로에 저장
incep_model.save(model_save_path + 'inceptionv3.h5')
xcep_model.save(model_save_path + 'xception.h5')

Epoch 1/20

KeyError: 'val_acc'

### Fine Tuning

In [None]:
# 훈련이 끝난 모델의 weight를 초기 weight로 다시 불러옴
incep_model.load_weights('inceptionv3_weights.h5')
xcep_model.load_weights('xception_weights.h5')

# fine tuning에서는 모든 layer의 weight를 update
incep_model.layers[0].trainable=True
xcep_model.layers[0].trainable=True

# fine tuning에서는 learning rate를 낮춰서 훈련 & recompile
incep_model.compile(optimizer=Adam(lr=0.0001), 
                    loss='categorical_crossentropy', 
                    metrics=['categorical_accuracy', top_3_accuracy])

xcep_model.compile(optimizer=Adam(lr=0.0001),
                   loss='categorical_crossentropy', 
                   metrics=['categorical_accuracy', top_3_accuracy])


log_path = 'log/inception_ft.log'
incep_model.fit_generator(generator=generator_train,
                          epochs=epochs,
                          class_weight=class_weight,
                          validation_data=generator_validate,
                          validation_steps=steps_validate,
                          callbacks=[cp_callback, csvlog_callback])

log_path = 'log/xception_ft.log'
xcep_model.fit_generator(generator=generator_train,
                          epochs=epochs,
                          class_weight=class_weight,
                          validation_data=generator_validate,
                          validation_steps=steps_validate)

# 훈련이 끝난 모델을 지정한 경로에 저장
incep_model.save(model_save_path + 'inceptionv3_finetune.h5')
xcep_model.save(model_save_path + 'xception_finetune.h5')