In [1]:
import tensorflow as tf
from train.train_tool import arcface_loss,read_single_tfrecord,average_gradients
from core import Arcface_model,config
import time
import os
from evaluate.evaluate import evaluation,load_bin


def train(image,label,train_phase_dropout,train_phase_bn, images_batch, images_f_batch, issame_list_batch):
    """Build the multi-GPU training graph and run the training loop.

    The input batch is split evenly across ``config.gpu_num`` towers. Each
    tower computes embeddings with ``Arcface_model.get_embd``, logits with
    ``arcface_loss`` and a loss made of the mean sparse softmax cross-entropy
    plus the summed regularization losses. Per-tower gradients are combined
    with ``average_gradients`` and applied by a momentum optimizer.

    During the loop:
      * every 100 steps the merged summaries are written to ``model_path``;
      * every 1000 steps the step count and elapsed time are printed;
      * every 5000 steps each entry of ``eval_datasets`` is evaluated and the
        result is appended to the record file ``f``;
      * after step 150000 a checkpoint is saved every
        ``config.model_save_gap`` steps.

    Args:
        image: placeholder fed with a batch of training images.
        label: placeholder fed with the matching integer labels.
        train_phase_dropout: boolean placeholder passed to the model.
        train_phase_bn: boolean placeholder passed to the model.
        images_batch: per-dataset evaluation images, indexed like
            ``eval_datasets``.
        images_f_batch: per-dataset second image set returned by
            ``load_bin`` (presumably flipped copies -- confirm in
            ``evaluate.evaluate``).
        issame_list_batch: per-dataset pair labels passed to ``evaluation``.

    Relies on these module-level names being defined before the call:
    ``batch_size``, ``img_size``, ``addr``, ``model_path``, ``model_name``,
    ``train_step``, ``eval_datasets``, ``begin`` and the open record file
    ``f``. The record file is closed before this function returns or raises.
    """
    train_images_split = tf.split(image, config.gpu_num)
    train_labels_split = tf.split(label, config.gpu_num)

    global_step = tf.Variable(name='global_step', initial_value=0, trainable=False)
    inc_op = tf.assign_add(global_step, 1, name='increment_global_step')
    # The schedule in config is rescaled relative to a batch size of 512.
    # NOTE(review): scale is 0 when batch_size > 512, which makes the
    # division below raise ZeroDivisionError.
    scale = int(512.0/batch_size)
    lr_steps = [scale*s for s in config.lr_steps]
    lr_values = [v/scale for v in config.lr_values]
    lr = tf.train.piecewise_constant(global_step, boundaries=lr_steps, values=lr_values, name='lr_schedule')
    opt = tf.train.MomentumOptimizer(learning_rate=lr, momentum=config.momentum)

    embds = []
    logits = []
    inference_loss = []
    wd_loss = []
    total_train_loss = []
    pred = []
    tower_grads = []

    for i in range(config.gpu_num):
        sub_train_images = train_images_split[i]
        sub_train_labels = train_labels_split[i]

        with tf.device("/gpu:%d"%(i)):
            # Towers after the first reuse the variables created by tower 0.
            with tf.variable_scope(tf.get_variable_scope(),reuse=(i>0)):

                net, end_points = Arcface_model.get_embd(sub_train_images, train_phase_dropout, train_phase_bn,config.model_params)

                logit = arcface_loss(net,sub_train_labels,config.s,config.m)
                arc_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits = logit , labels = sub_train_labels))
                L2_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
                train_loss = arc_loss + L2_loss

                pred.append(tf.to_int32(tf.argmax(tf.nn.softmax(logit),axis=1)))
                tower_grads.append(opt.compute_gradients(train_loss))

                embds.append(net)
                logits.append(logit)
                inference_loss.append(arc_loss)
                wd_loss.append(L2_loss)
                total_train_loss.append(train_loss)

    # UPDATE_OPS is a graph-wide collection, so it is read once after every
    # tower has been built. Reading it inside the loop produced a nested
    # list in which earlier towers' ops were repeated.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    embds = tf.concat(embds, axis=0)
    logits = tf.concat(logits, axis=0)
    pred = tf.concat(pred, axis=0)
    wd_loss = tf.add_n(wd_loss)/config.gpu_num
    inference_loss = tf.add_n(inference_loss)/config.gpu_num

    train_ops = [opt.apply_gradients(average_gradients(tower_grads))]
    train_ops.extend(update_ops)
    train_op = tf.group(*train_ops)

    with tf.name_scope('loss'):
        train_loss = tf.add_n(total_train_loss)/config.gpu_num
        tf.summary.scalar('train_loss',train_loss)

    with tf.name_scope('accuracy'):
        train_accuracy = tf.reduce_mean(tf.cast(tf.equal(pred, label), tf.float32))
        tf.summary.scalar('train_accuracy',train_accuracy)

    saver=tf.train.Saver(max_to_keep=20)
    merged=tf.summary.merge_all()

    train_images,train_labels=read_single_tfrecord(addr,batch_size,img_size)

    tf_config = tf.ConfigProto(allow_soft_placement=True)
    tf_config.gpu_options.allow_growth = True
    with tf.Session(config=tf_config) as sess:
        sess.run((tf.global_variables_initializer(),
                  tf.local_variables_initializer()))
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess,coord=coord)
        writer_train=tf.summary.FileWriter(model_path,sess.graph)
        print("start")
        try:
            for i in range(1,train_step):
                image_batch,label_batch=sess.run([train_images,train_labels])
                feed = {image:image_batch,label:label_batch,train_phase_dropout:True,train_phase_bn:True}
                sess.run([train_op,inc_op],feed_dict=feed)

                if i%100==0:
                    summary=sess.run(merged,feed_dict=feed)
                    writer_train.add_summary(summary,i)

                if i%1000==0:
                    print('times: ',i)
                    print('time: ',time.time()-begin)

                if i%5000==0:
                    f.write("itrations: %d"%(i)+'\n')
                    for idx in range(len(eval_datasets)):
                        tpr, fpr, accuracy, best_thresholds = evaluation(sess, images_batch[idx], images_f_batch[idx], issame_list_batch[idx], batch_size, img_size, dropout_flag=config.eval_dropout_flag, bn_flag=config.eval_bn_flag, embd=embds, image=image, train_phase_dropout=train_phase_dropout, train_phase_bn=train_phase_bn)
                        dataset_name = eval_datasets[idx].split("/")[-1].split(".")[0]
                        print("%s datasets get %.3f acc"%(dataset_name,accuracy))
                        f.write("\t %s \t %.3f \t \t "%(dataset_name,accuracy)+str(best_thresholds)+'\n')
                    f.write('\n')
                    # Push the record to disk now so an interrupted run
                    # keeps every evaluation that was completed.
                    f.flush()

                if i>150000 and i%config.model_save_gap==0:
                    saver.save(sess,os.path.join(model_path,model_name),global_step=i)
        except  tf.errors.OutOfRangeError:
            print("finished")
        finally:
            coord.request_stop()
            writer_train.close()
            # Close the record file here so it is released even when the
            # loop is aborted by an unexpected exception or Ctrl-C.
            f.close()
        coord.join(threads)
        
            
def main():
    """Create the input placeholders, load the evaluation sets and start training.

    Reads the module-level ``batch_size``, ``img_size`` and ``eval_datasets``.
    Every path in ``eval_datasets`` is loaded with ``load_bin`` and the three
    values it returns are collected into parallel lists that are handed to
    ``train`` together with the placeholders.
    """
    with tf.name_scope('input'):
        image = tf.placeholder(
            dtype=tf.float32,
            shape=[batch_size, img_size, img_size, 3],
            name='image',
        )
        label = tf.placeholder(dtype=tf.int32, shape=[batch_size], name='label')
        train_phase_dropout = tf.placeholder(dtype=tf.bool, shape=None, name='train_phase_dropout')
        train_phase_bn = tf.placeholder(dtype=tf.bool, shape=None, name='train_phase_bn')

    # One entry per evaluation dataset, in the same order as eval_datasets.
    eval_images, eval_images_f, eval_issame = [], [], []
    for bin_path in eval_datasets:
        imgs, imgs_f, same_flags = load_bin(bin_path, img_size)
        eval_images.append(imgs)
        eval_images_f.append(imgs_f)
        eval_issame.append(same_flags)

    train(
        image=image,
        label=label,
        train_phase_dropout=train_phase_dropout,
        train_phase_bn=train_phase_bn,
        images_batch=eval_images,
        images_f_batch=eval_images_f,
        issame_list_batch=eval_issame,
    )


if __name__ == "__main__":

    # Settings below are read as module-level globals by main() and train().
    img_size = config.img_size
    batch_size = config.batch_size
    addr = config.addrt
    model_name = config.model_name
    train_step = config.train_step
    model_path = config.model_patht
    eval_datasets = config.eval_datasets

    # Reference point for the elapsed-time printout during training.
    begin=time.time()

    # Evaluation record; train() appends one section per evaluation round.
    f = open("./eval_record.txt", 'w')
    try:
        f.write("\t dataset \t accuracy \t best_thresholds \t"+'\n')
        main()
    finally:
        # Guarantees the handle is released even if main() fails before
        # training starts. train() may already have closed the file;
        # closing a file twice is a no-op.
        f.close()
# tensorboard --logdir=/home/dell/Desktop/insightface/model/Arcface_model/

  from ._conv import register_converters as _register_converters


Instructions for updating:
Use the retry module or similar alternatives.
reading ./data/lfw.bin
done!
reading ./data/agedb_30.bin
done!
reading ./data/calfw.bin
done!
reading ./data/cfp_ff.bin
done!
reading ./data/cfp_fp.bin
done!
reading ./data/cplfw.bin
done!
reading ./data/lfw_face.db
done!
start
times:  1000
time:  733.7754483222961
times:  2000
time:  1272.9938457012177
times:  3000
time:  1811.6488132476807
times:  4000
time:  2350.833826303482
times:  5000
time:  2890.7193360328674
lfw datasets get 0.947 acc
agedb_30 datasets get 0.774 acc
calfw datasets get 0.822 acc
cfp_ff datasets get 0.943 acc
cfp_fp datasets get 0.793 acc
cplfw datasets get 0.730 acc
lfw_face datasets get 0.774 acc
times:  6000
time:  3650.346690416336
times:  7000
time:  4190.842154979706
times:  8000
time:  4737.433312654495
times:  9000
time:  5276.991191864014
times:  10000
time:  5816.671426773071
lfw datasets get 0.955 acc
agedb_30 datasets get 0.819 acc
calfw datasets get 0.843 acc
cfp_ff datasets ge

cplfw datasets get 0.829 acc
lfw_face datasets get 0.897 acc
times:  96000
time:  56140.9234855175
times:  97000
time:  56679.2535970211
times:  98000
time:  57218.10208916664
times:  99000
time:  57758.682121515274
times:  100000
time:  58298.09419798851
lfw datasets get 0.983 acc
agedb_30 datasets get 0.903 acc
calfw datasets get 0.901 acc
cfp_ff datasets get 0.983 acc
cfp_fp datasets get 0.914 acc
cplfw datasets get 0.847 acc
lfw_face datasets get 0.901 acc
times:  101000
time:  59054.986095666885
times:  102000
time:  59593.79058718681
times:  103000
time:  60132.30726027489
times:  104000
time:  60671.68413114548
times:  105000
time:  61209.03508281708
lfw datasets get 0.988 acc
agedb_30 datasets get 0.910 acc
calfw datasets get 0.904 acc
cfp_ff datasets get 0.984 acc
cfp_fp datasets get 0.919 acc
cplfw datasets get 0.836 acc
lfw_face datasets get 0.899 acc
times:  106000
time:  61964.09339737892
times:  107000
time:  62501.722397089005
times:  108000
time:  63040.20493555069
time

lfw_face datasets get 0.879 acc
times:  196000
time:  114466.33062195778
times:  197000
time:  115004.25596928596
times:  198000
time:  115543.2208340168
times:  199000
time:  116082.37719321251
times:  200000
time:  116621.50878477097
lfw datasets get 0.982 acc
agedb_30 datasets get 0.896 acc
calfw datasets get 0.894 acc
cfp_ff datasets get 0.981 acc
cfp_fp datasets get 0.907 acc
cplfw datasets get 0.827 acc
lfw_face datasets get 0.869 acc
times:  201000
time:  117378.05904436111
times:  202000
time:  117916.85774278641
times:  203000
time:  118457.09539484978
times:  204000
time:  118996.05268406868
times:  205000
time:  119535.05997920036
lfw datasets get 0.981 acc
agedb_30 datasets get 0.895 acc
calfw datasets get 0.898 acc
cfp_ff datasets get 0.984 acc
cfp_fp datasets get 0.914 acc
cplfw datasets get 0.831 acc
lfw_face datasets get 0.881 acc
times:  206000
time:  120291.70757699013
times:  207000
time:  120830.53612017632
times:  208000
time:  121368.76617360115
times:  209000
tim

cfp_fp datasets get 0.909 acc
cplfw datasets get 0.816 acc
lfw_face datasets get 0.871 acc
times:  296000
time:  172824.65280985832
times:  297000
time:  173362.2290751934
times:  298000
time:  173901.64044332504
times:  299000
time:  174441.52443432808
times:  300000
time:  174981.22070002556
lfw datasets get 0.981 acc
agedb_30 datasets get 0.887 acc
calfw datasets get 0.885 acc
cfp_ff datasets get 0.978 acc
cfp_fp datasets get 0.904 acc
cplfw datasets get 0.818 acc
lfw_face datasets get 0.870 acc
times:  301000
time:  175763.1212360859
times:  302000
time:  176301.44482898712
times:  303000
time:  176839.60068941116
times:  304000
time:  177377.89766669273
times:  305000
time:  177917.10427331924
lfw datasets get 0.982 acc
agedb_30 datasets get 0.884 acc
calfw datasets get 0.893 acc
cfp_ff datasets get 0.979 acc
cfp_fp datasets get 0.910 acc
cplfw datasets get 0.821 acc
lfw_face datasets get 0.874 acc
times:  306000
time:  178674.22027683258
times:  307000
time:  179213.20144462585
t

cfp_ff datasets get 0.978 acc
cfp_fp datasets get 0.904 acc
cplfw datasets get 0.822 acc
lfw_face datasets get 0.863 acc
times:  396000
time:  230947.4613852501
times:  397000
time:  231481.21852517128
times:  398000
time:  232015.21160316467
times:  399000
time:  232550.84869384766
times:  400000
time:  233086.77590489388
lfw datasets get 0.981 acc
agedb_30 datasets get 0.887 acc
calfw datasets get 0.884 acc
cfp_ff datasets get 0.980 acc
cfp_fp datasets get 0.911 acc
cplfw datasets get 0.823 acc
lfw_face datasets get 0.867 acc
times:  401000
time:  233838.3126449585
times:  402000
time:  234372.8021941185
times:  403000
time:  234908.0305583477
times:  404000
time:  235443.64558243752
times:  405000
time:  235977.6408891678
lfw datasets get 0.982 acc
agedb_30 datasets get 0.886 acc
calfw datasets get 0.882 acc
cfp_ff datasets get 0.980 acc
cfp_fp datasets get 0.903 acc
cplfw datasets get 0.814 acc
lfw_face datasets get 0.875 acc
times:  406000
time:  236729.4195933342
times:  407000
t

calfw datasets get 0.884 acc
cfp_ff datasets get 0.979 acc
cfp_fp datasets get 0.907 acc
cplfw datasets get 0.824 acc
lfw_face datasets get 0.862 acc
times:  496000
time:  288880.93689084053
times:  497000
time:  289415.4436380863
times:  498000
time:  289950.80377697945
times:  499000
time:  290485.6396062374
times:  500000
time:  291021.55594825745
lfw datasets get 0.981 acc
agedb_30 datasets get 0.890 acc
calfw datasets get 0.888 acc
cfp_ff datasets get 0.979 acc
cfp_fp datasets get 0.904 acc
cplfw datasets get 0.821 acc
lfw_face datasets get 0.876 acc
times:  501000
time:  291774.38452744484
times:  502000
time:  292309.7618224621
times:  503000
time:  292844.26513433456
times:  504000
time:  293379.3446998596
times:  505000
time:  293914.9059064388
lfw datasets get 0.982 acc
agedb_30 datasets get 0.876 acc
calfw datasets get 0.881 acc
cfp_ff datasets get 0.978 acc
cfp_fp datasets get 0.906 acc
cplfw datasets get 0.820 acc
lfw_face datasets get 0.871 acc
times:  506000
time:  29466

calfw datasets get 0.883 acc
cfp_ff datasets get 0.980 acc
cfp_fp datasets get 0.904 acc
cplfw datasets get 0.816 acc
lfw_face datasets get 0.868 acc
times:  596000
time:  346869.8447408676
times:  597000
time:  347405.83916926384
times:  598000
time:  347942.94552493095
times:  599000
time:  348480.22383499146
times:  600000
time:  349018.0357296467
lfw datasets get 0.980 acc
agedb_30 datasets get 0.888 acc
calfw datasets get 0.892 acc
cfp_ff datasets get 0.981 acc
cfp_fp datasets get 0.906 acc
cplfw datasets get 0.823 acc
lfw_face datasets get 0.877 acc
times:  601000
time:  349788.9088294506
times:  602000
time:  350324.583565712
times:  603000
time:  350859.54296803474
times:  604000
time:  351396.11115789413
times:  605000
time:  351930.73644447327
lfw datasets get 0.978 acc
agedb_30 datasets get 0.877 acc
calfw datasets get 0.874 acc
cfp_ff datasets get 0.976 acc
cfp_fp datasets get 0.909 acc
cplfw datasets get 0.816 acc
lfw_face datasets get 0.865 acc
times:  606000
time:  35268

agedb_30 datasets get 0.886 acc
calfw datasets get 0.884 acc
cfp_ff datasets get 0.980 acc
cfp_fp datasets get 0.908 acc
cplfw datasets get 0.817 acc
lfw_face datasets get 0.864 acc
times:  696000
time:  404945.2719180584
times:  697000
time:  405479.98262023926
times:  698000
time:  406014.7625286579
times:  699000
time:  406550.1990222931
times:  700000
time:  407087.20648503304
lfw datasets get 0.978 acc
agedb_30 datasets get 0.885 acc
calfw datasets get 0.889 acc
cfp_ff datasets get 0.981 acc
cfp_fp datasets get 0.905 acc
cplfw datasets get 0.813 acc
lfw_face datasets get 0.864 acc
times:  701000
time:  407839.3524568081
times:  702000
time:  408373.6162369251
times:  703000
time:  408908.68925857544
times:  704000
time:  409444.52795910835
times:  705000
time:  409979.20729732513
lfw datasets get 0.979 acc
agedb_30 datasets get 0.870 acc
calfw datasets get 0.878 acc
cfp_ff datasets get 0.976 acc
cfp_fp datasets get 0.909 acc
cplfw datasets get 0.820 acc
lfw_face datasets get 0.869

lfw datasets get 0.980 acc
agedb_30 datasets get 0.890 acc
calfw datasets get 0.880 acc
cfp_ff datasets get 0.981 acc
cfp_fp datasets get 0.908 acc
cplfw datasets get 0.823 acc
lfw_face datasets get 0.869 acc
times:  796000
time:  462895.6436638832
times:  797000
time:  463431.0734779835
times:  798000
time:  463965.2667198181
times:  799000
time:  464500.9242298603
times:  800000
time:  465036.8528599739
lfw datasets get 0.981 acc
agedb_30 datasets get 0.878 acc
calfw datasets get 0.883 acc
cfp_ff datasets get 0.978 acc
cfp_fp datasets get 0.906 acc
cplfw datasets get 0.817 acc
lfw_face datasets get 0.861 acc
times:  801000
time:  465788.4464726448
times:  802000
time:  466322.47972750664
times:  803000
time:  466857.70747327805
times:  804000
time:  467392.72637319565
times:  805000
time:  467928.9157040119
lfw datasets get 0.980 acc
agedb_30 datasets get 0.881 acc
calfw datasets get 0.881 acc
cfp_ff datasets get 0.979 acc
cfp_fp datasets get 0.911 acc
cplfw datasets get 0.818 acc
lf

lfw datasets get 0.980 acc
agedb_30 datasets get 0.887 acc
calfw datasets get 0.886 acc
cfp_ff datasets get 0.981 acc
cfp_fp datasets get 0.902 acc
cplfw datasets get 0.815 acc
lfw_face datasets get 0.874 acc
times:  896000
time:  520740.40788531303
times:  897000
time:  521273.7773902416
times:  898000
time:  521806.9630625248
times:  899000
time:  522340.1413769722
times:  900000
time:  522874.2444908619
lfw datasets get 0.977 acc
agedb_30 datasets get 0.877 acc
calfw datasets get 0.879 acc
cfp_ff datasets get 0.977 acc
cfp_fp datasets get 0.901 acc
cplfw datasets get 0.819 acc
lfw_face datasets get 0.867 acc
times:  901000
time:  523644.3012678623
times:  902000
time:  524178.50917220116
times:  903000
time:  524712.9828300476
times:  904000
time:  525246.8427066803
times:  905000
time:  525782.444631815
lfw datasets get 0.979 acc
agedb_30 datasets get 0.880 acc
calfw datasets get 0.881 acc
cfp_ff datasets get 0.977 acc
cfp_fp datasets get 0.905 acc
cplfw datasets get 0.815 acc
lfw_

KeyboardInterrupt: 