# In [None]:
import datetime
import os

import numpy as np
import tensorflow as tf
from keras.callbacks import LearningRateScheduler, ModelCheckpoint, TensorBoard
from keras.layers import Conv2D, Dense, DepthwiseConv2D, PReLU
from keras.optimizers import SGD, Adam
from keras.regularizers import l2
from keras.utils.multi_gpu_utils import multi_gpu_model

from nets.arcface import arcface
from nets.arcface_training import ArcFaceLoss, get_lr_scheduler
from utils.callbacks import (ExponentDecayScheduler, LFW_callback, LossHistory,
                             ParallelModelCheckpoint)
from utils.dataloader import FacenetDataset, LFWDataset
from utils.utils import get_acc, get_num_classes, show_config

tf.logging.set_verbosity(tf.logging.ERROR)

if __name__ == "__main__":
    #---------------------------------------------------------------------#
    #   train_gpu   Indices of the GPUs used for training.
    #               Defaults to the first card; use [0, 1] for two cards,
    #               [0, 1, 2] for three cards.
    #               With multiple GPUs, each card processes
    #               batch_size / number_of_cards samples per batch.
    #---------------------------------------------------------------------#
    train_gpu       = [0]
    #--------------------------------------------------------#
    #   Annotation file (cls_train.txt in the project root) that
    #   lists the face image paths and their labels.
    #--------------------------------------------------------#
    annotation_path = "cls_train.txt"
    #--------------------------------------------------------#
    #   Input image size: 112*112 is height*width;
    #   the last value is 3 for RGB or 1 for grayscale.
    #--------------------------------------------------------#
    input_shape     = [112, 112, 3]
    #--------------------------------------------------------#
    #   Backbone feature extractor, one of:
    #   mobilefacenet
    #   mobilenetv1
    #   mobilenetv2
    #   mobilenetv3
    #   iresnet50
    #
    #   Every backbone except mobilenetv1 can be trained from scratch.
    #--------------------------------------------------------#
    backbone        = "mobilefacenet"
    #----------------------------------------------------------------------------------------------------------------------------#
    #   To resume an interrupted run, point model_path at a weight file in the logs folder so the
    #   partially trained weights are loaded again, and adjust the training parameters below
    #   (e.g. Init_Epoch) so that the epoch count continues where it stopped.
    #
    #   When model_path = '' no whole-model weights are loaded.
    #
    #   These are whole-model weights, loaded below by layer name (skipping mismatched layers).
    #   To start from pretrained backbone weights, set model_path to the backbone weight file.
    #   To train from scratch, set model_path = ''.
    #----------------------------------------------------------------------------------------------------------------------------#
    model_path      = ""

    #----------------------------------------------------------------------------------------------------------------------------#
    #   Running out of GPU memory is unrelated to dataset size; if you hit it, reduce batch_size.
    #   Because of the BatchNorm layers, batch_size cannot be 1.
    #
    #   Suggested settings (adjust them to your own needs):
    #   (1) Training from pretrained weights:
    #       Adam:
    #           Init_Epoch = 0, Epoch = 100, optimizer_type = 'adam', Init_lr = 1e-3, weight_decay = 0.
    #       SGD:
    #           Init_Epoch = 0, Epoch = 100, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 5e-4.
    #       Epoch can be tuned between 100 and 300.
    #   (2) batch_size:
    #       As large as the GPU can hold. GPU memory exhaustion (OOM / CUDA out of memory) is
    #       unrelated to the amount of training data; reduce batch_size if it occurs.
    #       Because of the BatchNorm layers, batch_size must be at least 2.
    #----------------------------------------------------------------------------------------------------------------------------#
    #------------------------------------------------------#
    #   Training schedule
    #   Init_Epoch      epoch at which training (re)starts
    #   Epoch           total number of epochs to train
    #   batch_size      number of images per batch
    #------------------------------------------------------#
    Init_Epoch      = 0
    Epoch           = 100
    batch_size      = 12

    #------------------------------------------------------------------#
    #   Other training parameters: learning rate, optimizer, LR decay
    #------------------------------------------------------------------#
    #------------------------------------------------------------------#
    #   Init_lr         maximum learning rate
    #   Min_lr          minimum learning rate, 0.01 * Init_lr by default
    #------------------------------------------------------------------#
    Init_lr             = 1e-2
    Min_lr              = Init_lr * 0.01
    #------------------------------------------------------------------#
    #   optimizer_type  optimizer to use: 'adam' or 'sgd'
    #                   suggested Init_lr=1e-3 with Adam
    #                   suggested Init_lr=1e-2 with SGD
    #   momentum        SGD momentum / Adam beta_1
    #   weight_decay    L2 weight penalty, helps prevent overfitting.
    #                   Weight decay misbehaves with Adam; set it to 0
    #                   when using Adam.
    #------------------------------------------------------------------#
    optimizer_type      = "sgd"
    momentum            = 0.9
    weight_decay        = 5e-4
    #------------------------------------------------------------------#
    #   lr_decay_type   learning-rate decay schedule: 'step' or 'cos'
    #------------------------------------------------------------------#
    lr_decay_type       = "cos"
    #------------------------------------------------------------------#
    #   save_period     save weights every save_period epochs
    #                   (default: every epoch)
    #------------------------------------------------------------------#
    save_period         = 1
    #------------------------------------------------------------------#
    #   save_dir        folder for weights and log files
    #------------------------------------------------------------------#
    save_dir            = 'logs'
    #------------------------------------------------------------------#
    #   Number of data-loading workers; values > 1 enable multiprocessing.
    #   More workers load data faster but use more memory;
    #   use 2 or 1 on machines with little RAM.
    #------------------------------------------------------------------#
    num_workers     = 1
    #------------------------------------------------------------------#
    #   Whether to enable LFW evaluation (adds LFW_callback)
    #------------------------------------------------------------------#
    lfw_eval_flag   = False
    #------------------------------------------------------------------#
    #   LFW evaluation image folder and its pairs txt file
    #------------------------------------------------------------------#
    lfw_dir_path    = "lfw"
    lfw_pairs_path  = "model_data/lfw_pair.txt"

    #------------------------------------------------------#
    #   Select the GPUs to use
    #------------------------------------------------------#
    os.environ["CUDA_VISIBLE_DEVICES"]  = ','.join(str(x) for x in train_gpu)
    ngpus_per_node                      = len(train_gpu)
    print('Number of devices: {}'.format(ngpus_per_node))

    num_classes = get_num_classes(annotation_path)
    #-------------------------------------------#
    #   Build the model
    #-------------------------------------------#
    model_body = arcface(input_shape, num_classes, backbone=backbone, mode="train")
    if model_path != '':
        #------------------------------------------------------#
        #   Load weights by layer name, skipping layers whose
        #   shapes do not match.
        #------------------------------------------------------#
        print('Load weights {}.'.format(model_path))
        model_body.load_weights(model_path, by_name=True, skip_mismatch=True)

    #------------------------------------------------------#
    #   Weight decay, implemented as explicit L2 losses on each
    #   layer's weights.
    #   Iterate model_body rather than `model`: with several GPUs
    #   `model` is the multi_gpu_model wrapper, whose top-level
    #   layers wrap model_body as one nested layer, so looping over
    #   model.layers would regularize nothing. The losses are also
    #   registered before multi_gpu_model() builds the replicas.
    #   NOTE(review): layers inside nested sub-models of model_body
    #   (if the backbone builds any) are not visited — confirm.
    #------------------------------------------------------#
    for layer in model_body.layers:
        # DepthwiseConv2D is checked first: it derives from Conv2D in
        # Keras but stores its weights in depthwise_kernel.
        if isinstance(layer, DepthwiseConv2D):
            layer.add_loss(l2(weight_decay)(layer.depthwise_kernel))
        elif isinstance(layer, (Conv2D, Dense)):
            layer.add_loss(l2(weight_decay)(layer.kernel))
        elif isinstance(layer, PReLU):
            layer.add_loss(l2(weight_decay)(layer.alpha))

    if ngpus_per_node > 1:
        # Multi-GPU data parallelism
        model   = multi_gpu_model(model_body, gpus=ngpus_per_node)
    else:
        model   = model_body
    #-------------------------------------------------------#
    #   Hold out val_split (16%) of the annotation lines for
    #   validation and train on the remaining 84%. The fixed
    #   seed makes the split reproducible; the global seed is
    #   reset afterwards.
    #-------------------------------------------------------#
    val_split = 0.16
    with open(annotation_path, "r") as f:
        lines = f.readlines()
    np.random.seed(10101)
    np.random.shuffle(lines)
    np.random.seed(None)
    num_val = int(len(lines) * val_split)
    num_train = len(lines) - num_val

    show_config(
        num_classes = num_classes, backbone = backbone, model_path = model_path, input_shape = input_shape,
        Init_Epoch = Init_Epoch, Epoch = Epoch, batch_size = batch_size,
        Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type,
        save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train, num_val = num_val
    )

    #-------------------------------------------------------------------#
    #   Scale the learning rate linearly with batch_size relative to a
    #   nominal batch size of 64, clamped to an optimizer-specific range.
    #-------------------------------------------------------------------#
    nbs             = 64
    lr_limit_max    = 1e-3 if optimizer_type == 'adam' else 1e-1
    lr_limit_min    = 3e-4 if optimizer_type == 'adam' else 5e-4
    Init_lr_fit     = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
    Min_lr_fit      = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)

    optimizer = {
        'adam'  : Adam(lr = Init_lr_fit, beta_1 = momentum),
        'sgd'   : SGD(lr = Init_lr_fit, momentum = momentum, nesterov=True)
    }[optimizer_type]
    # ArcFace margin m and scale s (32 for mobilefacenet, 64 for the other backbones)
    m = 0.5
    s = 32 if backbone == "mobilefacenet" else 64
    model.compile(optimizer = optimizer, loss={'ArcMargin': ArcFaceLoss(s = s, m = m)}, metrics={'ArcMargin': get_acc()})

    #---------------------------------------#
    #   Build the learning-rate decay function
    #---------------------------------------#
    lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, Epoch)

    epoch_step          = num_train // batch_size
    epoch_step_val      = num_val // batch_size

    if epoch_step == 0 or epoch_step_val == 0:
        raise ValueError('数据集过小，无法进行训练，请扩充数据集。')

    train_dataset   = FacenetDataset(input_shape, lines[:num_train], batch_size, num_classes, random = True)
    val_dataset     = FacenetDataset(input_shape, lines[num_train:], batch_size, num_classes, random = False)

    #-------------------------------------------------------------------------------#
    #   Callbacks
    #   tb_callback     TensorBoard logging into log_dir
    #   checkpoint      weight saving; period sets how many epochs between saves
    #   lr_scheduler    applies the learning-rate decay schedule
    #-------------------------------------------------------------------------------#
    time_str        = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    log_dir         = os.path.join(save_dir, "loss_" + time_str)
    # Checkpoints are written directly into save_dir; make sure it exists.
    os.makedirs(save_dir, exist_ok=True)
    tb_callback     = TensorBoard(log_dir)
    loss_history    = LossHistory(log_dir)
    checkpoint_path = os.path.join(save_dir, "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5")
    if ngpus_per_node > 1:
        # With multiple GPUs, ParallelModelCheckpoint is given model_body
        # so the weights are saved from it instead of the multi_gpu wrapper,
        # which otherwise raises an error when saving.
        checkpoint  = ParallelModelCheckpoint(model_body, checkpoint_path,
                                monitor = 'val_loss', save_weights_only = True, save_best_only = False, period = save_period)
    else:
        checkpoint  = ModelCheckpoint(checkpoint_path,
                                monitor = 'val_loss', save_weights_only = True, save_best_only = False, period = save_period)
    lr_scheduler    = LearningRateScheduler(lr_scheduler_func, verbose = 1)
    callbacks       = [tb_callback, loss_history, checkpoint, lr_scheduler]
    #---------------------------------#
    #   Optional LFW evaluation
    #---------------------------------#
    if lfw_eval_flag:
        callbacks.append(LFW_callback(LFWDataset(dir=lfw_dir_path, pairs_path=lfw_pairs_path, batch_size=32, input_shape=input_shape)))

    print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
    model.fit_generator(
        generator           = train_dataset,
        steps_per_epoch     = epoch_step,
        validation_data     = val_dataset,
        validation_steps    = epoch_step_val,
        epochs              = Epoch,
        initial_epoch       = Init_Epoch,
        use_multiprocessing = num_workers > 1,
        workers             = num_workers,
        callbacks           = callbacks
    )
