Skip to content
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time

1.4. 训练 resnet50-flowers

数据集下载地址:http://download.tensorflow.org/example_images/flower_photos.tgz

将数据集解压到目录/tmp/tensorflow/

1.4.1. 定义主函数和超参

引入MoXing-TensorFlow

import moxing.tensorflow as mox

通过mox.get_flag获取命令行参数num_gpusworker_hosts,从而获取当前使用的GPU数量和节点数量

num_gpus = mox.get_flag('num_gpus')
num_workers = len(mox.get_flag('worker_hosts').split(','))

flowers数据集的格式如下:

/tmp/tensorflow/flower_photos
    |-- daisy
        |-- xxx0.jpg
        ...
    |-- dandelion
        |-- xxx1.jpg
        ...
    |-- roses
        |-- xxx2.jpg
        ...
    |-- sunflowers
        |-- xxx3.jpg
        ...
    |-- tulips
        |-- xxx4.jpg
        ...

每一个子目录代表一个分类,每个分类下有若干张图片,对于这种类型的数据集,可以使用mox.ImageClassificationRawMetadatamox.ImageClassificationRawDataset来读取。MoXing-TensorFlow预置了若干种解析数据集的类,一般会使用数据集元信息类+数据集读取类的模式来读取。数据集元信息类不会创建任何TensorFlow的数据流图,建议在main方法中直接实例化,那么代码的其他地方都能获取数据集的元信息(如样本数量,分类数量)。数据集读取类必须在input_fn中实例化,该类的实例化会在TensorFlow数据流图中创建节点。

创建一个数据集元信息类base_dir即指定flower_photos所在目录

data_meta = mox.ImageClassificationRawMetadata(base_dir=flags.data_url)

input_fn中创建一个数据增强方法(基于resnet50)和一个数据集读取类

def input_fn(mode):
  # 创建一个数据增强方法,该方法基于resnet50论文实现
  augmentation_fn = mox.get_data_augmentation_fn(name='resnet_v1_50',
                                                 run_mode=mode,
                                                 output_height=224,
                                                 output_width=224)

  # 创建`数据集读取类`,并将数据增强方法传入,最多读取20个epoch
  dataset = mox.ImageClassificationRawDataset(data_meta,
                                              batch_size=flags.batch_size,
                                              num_epochs=20,
                                              augmentation_fn=augmentation_fn)
  image, label = dataset.get(['image', 'label'])
  return image, label

定义model_fn

def model_fn(inputs, mode):
  images, labels = inputs

  # 获取一个resnet50的模型,输入images,输入logits和end_points,这里不关心end_points,仅取logits
  logits, _ = mox.get_model_fn(name='resnet_v1_50',
                               run_mode=mode,
                               num_classes=data_meta.num_classes,
                               weight_decay=0.00004)(images)

  # 计算交叉熵损失值
  labels_one_hot = slim.one_hot_encoding(labels, data_meta.num_classes)
  loss = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels_one_hot)

  # 获取正则项损失值,并加到loss上,这里必须要用mox.get_collection代替tf.get_collection
  regularization_losses = mox.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  regularization_loss = tf.add_n(regularization_losses)
  loss = loss + regularization_loss

  # 计算分类正确率
  accuracy = tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits, labels, 1), tf.float32))

  # 返回MoXing-TensorFlow用于定义模型的类ModelSpec
  return mox.ModelSpec(loss=loss, log_info={'loss': loss, 'accuracy': accuracy})

定义optimizer_fn

from moxing.tensorflow.optimizer import learning_rate_scheduler

def optimizer_fn():
  # 使用分段式学习率,0-10个epoch为0.01,10-20个epoch为0.001
  lr = learning_rate_scheduler.piecewise_lr('10:0.01,20:0.001',
                                            num_samples=data_meta.total_num_samples,
                                            global_batch_size=flags.batch_size * num_gpus * num_workers)
  return tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)

完整代码请参考:mox_flowers.py

执行训练:

python mox_flowers.py \
--data_url=/tmp/tensorflow/flower_photos \
--train_url=/tmp/flowers \
--num_gpus=4

使用 4 * Nvidia-Tesla-K80 运行时间大约为:698秒,在训练集上的训练精度约为:50%