## Image Classification Training Pipeline

#### 使用说明

* 编辑配置文件 `config/your-config-file.json`
* 在“修改运行相关的变量”单元中修改 `config` 和 `run_dir` 参数
* 运行 notebook

In [1]:
import os
import json
import tensorflow as tf
from datasets.dataset import ImageData
from nets import dishnet1_7 as dishnet
from utils.utils import write_pb, save_model, mkdir_p
from train import *

  from ._conv import register_converters as _register_converters


#### 修改运行相关的变量

In [2]:
# Path to the JSON config file, and the base output directory; the next cell
# creates <run_dir>/<project_name>/{logs,checkpoints} underneath it.
config, run_dir = "config/dishnet-config.json", "RUN"

### 准备训练所需的环境、变量和数据，并创建神经网络

In [3]:
# Prepare the run directories, load the config, build the input pipeline and
# build the network graph. Module-level names defined here (graph, net, saver,
# handle, train_iterator, val_iterator, is_training, train_writer,
# eval_writer, hypes, hypes_train, ckpt_dir, learning_rate, net_params,
# run_dir, export_model) are used by the later cells.

# Make sure the base run directory exists.
if not os.path.exists(run_dir):
    os.makedirs(run_dir)

# Explicit checks instead of `assert`, which is stripped under `python -O`.
if not os.path.isdir(run_dir):
    raise IOError("run_dir '{}' exists but is not a directory.".format(run_dir))
if not os.path.exists(config) or os.path.isdir(config):
    raise IOError("Config file '{}' does not exist or is a directory.".format(config))

# Load the config with a context manager so the file handle is closed.
with open(config, 'r') as config_file:
    hypes = json.load(config_file)
hypes_train = hypes['train']
project_name = hypes_train['project_name']

# Per-project output directories under the base run directory. They are
# stored back into the config dict, which is later passed to
# train_and_evaluate.
hypes_train['run_dir'] = os.path.join(run_dir, project_name)
hypes_train['log_dir'] = os.path.join(run_dir, project_name, 'logs')
hypes_train['ckpt_dir'] = os.path.join(run_dir, project_name, 'checkpoints')
mkdir_p(hypes_train['run_dir'])
mkdir_p(hypes_train['log_dir'])
mkdir_p(hypes_train['ckpt_dir'])

# Set the TensorFlow log level.
tf.logging.set_verbosity(tf.logging.INFO)

# Recover training settings from the config dict.
# NOTE: this rebinds `run_dir` from the base directory to the per-project
# directory; any later use of the global `run_dir` sees the per-project path.
run_dir = hypes_train["run_dir"]
log_dir = hypes_train["log_dir"]
ckpt_dir = hypes_train["ckpt_dir"]
pretrain_model = hypes_train["pretrain_model"]
data_train = hypes_train["data_train"]
data_val = hypes_train["data_val"]
learning_rate = hypes_train["learning_rate"]
export_model = hypes_train["export_model"]
is_tfrecord = hypes_train.get("tfrecord_flag", False)
batch_size = hypes_train.get("batch_size", 32)

# Data-augmentation parameters: taken from the config when present, otherwise
# the defaults. The defaults are only computed when actually needed.
if "augment" in hypes_train:
    augment_params = hypes_train["augment"]
else:
    augment_params = get_default_augment_params()

if is_tfrecord:
    # For TFRecord input the class count is not derived from the data here,
    # so it must be given in the config.
    _num_classes = hypes_train.get('num_classes', None)
    if _num_classes is None:
        raise ValueError("Missing num_classes setting in the config file when training data is tfrecord files.")

    img_data = ImageData(data_train, data_val,
                         is_tfrecord=is_tfrecord,
                         batch_size=batch_size,
                         augment_params=augment_params,
                         output_like='Inception')

else:
    # To experiment on only some classes, list them in the config under
    # "use_classes", e.g. "use_classes": ["000101", "000011"]; only those
    # classes are then used for training and validation. An empty list or a
    # missing key means all classes are used (small_classes_set=None).
    small_classes_set = hypes_train.get('use_classes', None) or None

    img_data = ImageData(data_train, data_val,
                         small_classes_set=small_classes_set,
                         batch_size=batch_size,
                         augment_params=augment_params,
                         output_like='Inception')
    _num_classes = len(img_data.classes)


# Network configuration. 'exclude': ['logits'] is passed to build_net,
# presumably to skip restoring the pretrained classifier head — confirm in
# build_net.
net_params = {
    'num_classes': _num_classes,
    'pretrain_model': pretrain_model,
    'checkpoint_path': ckpt_dir,
    'exclude': ['logits'],
    'adam_beta1': hypes_train['adam_beta1']
}

tf.logging.info('Train on {} classes.'.format(net_params['num_classes']))

graph = tf.Graph()
with graph.as_default():
    train_dataset = img_data.data_input_fn(mode='train')
    val_dataset = img_data.data_input_fn(mode='eval')
    # Feedable iterator: the string fed into `handle` selects which dataset's
    # iterator `images, labels` are pulled from.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_dataset.output_types, train_dataset.output_shapes)
    train_iterator = train_dataset.make_initializable_iterator()
    val_iterator = val_dataset.make_initializable_iterator()
    images, labels = iterator.get_next()

    net = build_net(images, labels, net_params)
    net['handle'] = handle
    saver = tf.train.Saver()

# Indicates whether we are in training or in test mode
is_training = net['is_training']

# Summary writers for TensorBoard; the train writer also records the graph.
train_writer = tf.summary.FileWriter(os.path.join(log_dir, 'train'), graph)
eval_writer = tf.summary.FileWriter(os.path.join(log_dir, 'eval'))


INFO:tensorflow:Train on 1083 classes.
Instructions for updating:
Use the retry module or similar alternatives.


### 开始训练

In [None]:
# Train the network, then run a full evaluation on the validation set.
with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # String handles fed into the `handle` placeholder to switch the shared
    # feedable iterator between the training and validation datasets.
    net['train_handle'] = sess.run(train_iterator.string_handle())
    net['val_handle'] = sess.run(val_iterator.string_handle())
    # Resume from the latest checkpoint in ckpt_dir, if there is one.
    saved_ckpt = tf.train.latest_checkpoint(ckpt_dir)
    if saved_ckpt:
        saver.restore(sess, saved_ckpt)

    try:
        # The train/eval loop is delegated to train_and_evaluate, which
        # receives both iterators and both summary writers.
        train_and_evaluate(sess, train_iterator, val_iterator,
                           train_writer, eval_writer, net,
                           saver, hypes, float(learning_rate))
        tf.logging.info('Training finished, evaluating the model on validation set...')
        # NOTE(review): the trailing -1 is passed through to fully_evaluate;
        # its meaning (step / sentinel) should be confirmed there.
        eval_result = fully_evaluate(sess, val_iterator, net['eval_metric'],
                                     {is_training: False, handle: net['val_handle']}, -1)
        tf.logging.info('======= eval result ========')
        tf.logging.info(eval_result)
    finally:
        # FileWriter buffers events and only flushes periodically, so flush
        # explicitly: TensorBoard then sees the final summaries even if
        # training is interrupted. The writers are flushed rather than closed
        # so that this cell can be re-run with the same writers.
        train_writer.flush()
        eval_writer.flush()

INFO:tensorflow:step 0 -- train loss: 6.913 acc_top1: 0.000 acc_top5: 0.031
INFO:tensorflow:step 0 -- on validation set: acc_top1: 0.000 acc_top5: 0.003
INFO:tensorflow:step 20 -- train loss: 6.908 acc_top1: 0.125 acc_top5: 0.188
INFO:tensorflow:step 40 -- train loss: 6.844 acc_top1: 0.094 acc_top5: 0.125
INFO:tensorflow:step 60 -- train loss: 6.410 acc_top1: 0.219 acc_top5: 0.219
INFO:tensorflow:step 80 -- train loss: 6.748 acc_top1: 0.219 acc_top5: 0.219
INFO:tensorflow:step 100 -- train loss: 6.682 acc_top1: 0.125 acc_top5: 0.125
INFO:tensorflow:step 100 -- on validation set: acc_top1: 0.014 acc_top5: 0.020
INFO:tensorflow:step 120 -- train loss: 6.838 acc_top1: 0.156 acc_top5: 0.156
INFO:tensorflow:step 140 -- train loss: 7.074 acc_top1: 0.125 acc_top5: 0.125
INFO:tensorflow:step 160 -- train loss: 6.864 acc_top1: 0.156 acc_top5: 0.156
INFO:tensorflow:step 180 -- train loss: 5.654 acc_top1: 0.281 acc_top5: 0.281
INFO:tensorflow:step 200 -- train loss: 5.792 acc_top1: 0.312 acc_top5

### 训练完成后导出模型

In [None]:
# Export the trained model. The per-project run directory is read from the
# config dict filled in by the setup cell, not from the global `run_dir`.
# That global is also the base directory set in the config cell, and it would
# point at the wrong place if that cell were re-run after training.
export_pb_file(net_params, hypes_train['run_dir'], export_model)