参考 http://www.cnblogs.com/zyly/p/9146787.html

In [7]:
import tensorflow as tf
from tensorflow.contrib.slim.nets import vgg

# Input resolution (side length in pixels) of slim's VGG-16 definition, read
# from the `default_image_size` attribute slim attaches to `vgg.vgg_16`
# (the cell output below shows 224).
IMAGE_SIZE = vgg.vgg_16.default_image_size
print(IMAGE_SIZE)

224


In [13]:
import os
import numpy as np
import cv2
import time

def _read_second_column(txt_file):
    """Return the second whitespace-separated field of each non-blank line.

    The CUB-200-2011 metadata files read below have one ``<image_id> <value>``
    pair per line; blank lines (e.g. a trailing newline) are skipped.
    """
    with open(txt_file, 'r') as f:
        return [line.split()[1] for line in f if line.strip()]


def _load_image(img_path, image_size):
    """Read *img_path* with OpenCV and resize it to ``image_size x image_size``.

    Raises:
        IOError: if OpenCV cannot read the file.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; fail here with the offending path.
        raise IOError('cannot read image: %s' % img_path)
    return cv2.resize(img, (image_size, image_size))


def input_data(npz_file, image_size=224,
               data_path=os.path.join('../../..', 'data', 'CUB_200_2011')):
    """Load the CUB-200-2011 train/test split as in-memory numpy arrays.

    If *npz_file* exists it is simply loaded.  Otherwise every image listed in
    ``images.txt`` under *data_path* is read with OpenCV, resized to
    ``image_size x image_size``, split according to
    ``train_test_split.txt`` and cached, together with the labels from
    ``image_class_labels.txt``, in *npz_file*.  The cache does not record
    *image_size*; delete it after changing the size.

    Args:
        npz_file: path of the ``.npz`` cache to read or create.
        image_size: side length images are resized to (default 224).
        data_path: root directory of the extracted CUB_200_2011 dataset.

    Returns:
        ``(train_img, test_img, train_label, test_label)``.  For a freshly
        built cache the images are stacked ``cv2.imread`` results (BGR channel
        order) of shape ``(N, image_size, image_size, 3)``, and the labels are
        the class ids exactly as written in ``image_class_labels.txt``.

    Raises:
        IOError: if an image file cannot be read.
        ValueError: if the metadata files list different numbers of images.
    """
    if os.path.exists(npz_file):
        # NpzFile keeps the archive open; the context manager closes it once
        # the arrays have been read into memory.
        with np.load(npz_file) as bird_data:
            return (bird_data['train_img'], bird_data['test_img'],
                    bird_data['train_label'], bird_data['test_label'])

    print(os.listdir(data_path))

    # Compare against '1' explicitly rather than relying on numpy's
    # string-to-bool cast.
    split_flags = _read_second_column(
        os.path.join(data_path, 'train_test_split.txt'))
    is_train = np.array([flag == '1' for flag in split_flags], dtype=bool)
    print(is_train, is_train.size)

    img_paths = _read_second_column(os.path.join(data_path, 'images.txt'))
    print(img_paths[:1], len(img_paths))

    img_labels = np.array(_read_second_column(
        os.path.join(data_path, 'image_class_labels.txt'))).astype('int')
    print(img_labels, len(img_labels))

    # The three files are matched line by line, so they must agree in length.
    if not len(is_train) == len(img_paths) == len(img_labels):
        raise ValueError('metadata size mismatch: %d split flags, %d paths, '
                         '%d labels' % (len(is_train), len(img_paths),
                                        len(img_labels)))

    img_dir = os.path.join(data_path, 'images')
    # images.txt uses '/' separators; re-join with the OS separator.
    full_paths = [os.path.join(img_dir, *rel.split('/')) for rel in img_paths]
    img_paths_train = [p for p, train in zip(full_paths, is_train) if train]
    print(img_paths_train[:1], len(img_paths_train))
    img_paths_test = [p for p, train in zip(full_paths, is_train) if not train]
    print(img_paths_test[:1], len(img_paths_test))

    train_img = np.array([_load_image(p, image_size) for p in img_paths_train])
    test_img = np.array([_load_image(p, image_size) for p in img_paths_test])
    train_label = img_labels[is_train]
    test_label = img_labels[~is_train]
    print(train_label, train_label.size)
    print(test_label, test_label.size)

    np.savez(npz_file, train_img=train_img, test_img=test_img,
             train_label=train_label, test_label=test_label)
    return train_img, test_img, train_label, test_label
    
x_train, x_test, y_train, y_test = input_data('bird_data_224.npz')
print('type:', type(x_train), type(y_train))
print('shape:', x_train.shape, y_train.shape)
print('size:', x_train.size, y_train.size)


# Number of bird classes in CUB-200-2011.
num_classes = 200

# Scale the 0-255 pixel values to float32 in [0, 1].  Divide in place: the
# expression `x.astype('float32') / 255` keeps two full float32 copies alive
# at once (~3.6 GB each for the 5994 x 224 x 224 x 3 training images), a
# likely cause of the MemoryError this cell hit.  Results are identical.
x_train = x_train.astype('float32')
x_train /= 255
x_test = x_test.astype('float32')
x_test /= 255
# Shift the class ids down by one (1..200 -> 0..199) so they can be used as
# 0-based indices.
y_train, y_test = np.asarray(y_train) - 1, np.asarray(y_test) - 1


type: <class 'numpy.ndarray'> <class 'numpy.ndarray'>
shape: (5994, 224, 224, 3) (5994,)
size: 902264832 5994


MemoryError: 

In [None]:

'''
Fine-tune the official slim VGG-16 checkpoint on CUB-200-2011 (200 classes)
with the TF-Slim training API.

fc8 is rebuilt for NUM_CLASSES outputs and initialised from scratch; every
other VGG-16 layer is restored from the official checkpoint.  All trainable
layers are then updated by the optimizer (pass `variables_to_train` to
slim.learning.create_train_op to freeze the restored layers).
'''
import tensorflow.contrib.slim as slim

batch_size = 128

learning_rate = 1e-4

# Number of output classes; taken from the data-loading cell.
NUM_CLASSES = num_classes

# Directory for the fine-tuned checkpoints and summary logs.
train_log_dir = './vgg_16_2016_08_28/slim_fine_tune'

# Officially released VGG-16 checkpoint.
checkpoint_file = './vgg_16_2016_08_28/vgg_16.ckpt'

# Per-channel RGB means subtracted by slim's vgg_preprocessing; the released
# checkpoint was trained on inputs preprocessed this way.
_VGG_RGB_MEANS = [123.68, 116.78, 103.94]

if not tf.gfile.Exists(train_log_dir):
    tf.gfile.MakeDirs(train_log_dir)


def _train_batches():
    """Yield shuffled (images, labels) batches from x_train / y_train forever.

    Each pass uses a fresh permutation; the last batch of a pass may hold
    fewer than batch_size samples.
    """
    num_samples = len(x_train)
    if num_samples == 0:
        # Without this guard the loop below would spin forever yielding nothing.
        raise ValueError('x_train is empty')
    while True:
        order = np.random.permutation(num_samples)
        for start in range(0, num_samples, batch_size):
            idx = order[start:start + batch_size]
            yield (x_train[idx].astype(np.float32, copy=False),
                   y_train[idx].astype(np.int64, copy=False))


# Build a fresh graph and make it the default one.
with tf.Graph().as_default():

    # Stream batches from the in-memory numpy arrays.  from_generator keeps
    # the multi-GB arrays out of the GraphDef (from_tensor_slices would embed
    # them as constants and exceed the 2 GB graph limit).
    dataset = tf.data.Dataset.from_generator(
        _train_batches,
        output_types=(tf.float32, tf.int64),
        output_shapes=(tf.TensorShape([None, IMAGE_SIZE, IMAGE_SIZE, 3]),
                       tf.TensorShape([None])))
    images, labels = dataset.prefetch(1).make_one_shot_iterator().get_next()

    # x_train holds cv2.imread (BGR) images scaled to [0, 1]; convert them to
    # the RGB, mean-subtracted 0-255 range the official checkpoint expects.
    train_images = tf.reverse(images, axis=[3]) * 255.0 - _VGG_RGB_MEANS
    # softmax_cross_entropy needs one-hot targets; y_train holds 0-based ids.
    train_labels = tf.one_hot(labels, NUM_CLASSES)

    # Build VGG-16 with its standard arg scope (weight decay etc.).  To freeze
    # layers, restrict variables_to_train in create_train_op below.
    with slim.arg_scope(vgg.vgg_arg_scope()):
        logits, end_points = vgg.vgg_16(train_images, is_training=True,
                                        num_classes=NUM_CLASSES)

    # Collect the variables to restore now, before the optimizer and global
    # step are created, so only VGG weights (minus the re-shaped fc8) are
    # looked up in the official checkpoint.
    variables_to_restore = slim.get_variables_to_restore(
        include=['vgg_16'], exclude=['vgg_16/fc8'])

    # Cross-entropy loss; get_total_loss also adds the regularisation losses.
    tf.losses.softmax_cross_entropy(onehot_labels=train_labels, logits=logits)
    total_loss = tf.losses.get_total_loss()

    # Log the total loss to the summaries.
    tf.summary.scalar('losses/total_loss', total_loss)

    # Plain SGD.  Because the restore list was built before any optimizer
    # slot variables exist, an optimizer with slots (e.g. AdamOptimizer)
    # could be substituted here without breaking the checkpoint restore.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    # create_train_op makes evaluating the loss also run the update ops and
    # apply the gradient step.
    train_tensor = slim.learning.create_train_op(total_loss, optimizer)

    # Look for a checkpoint from an earlier run.  Note: slim.learning.train
    # (through tf.train.Supervisor) restores the latest checkpoint in
    # train_log_dir by itself and only calls init_fn when none exists.
    ckpt = tf.train.latest_checkpoint(train_log_dir)
    if ckpt is not None:
        init_fn = slim.assign_from_checkpoint_fn(ckpt,
                                                 slim.get_model_variables())
        print('从上次训练保存后的模型继续训练！')
    else:
        # Restore everything except fc8 from the official checkpoint.
        init_fn = slim.assign_from_checkpoint_fn(checkpoint_file,
                                                 variables_to_restore)
        print('从官方模型加载训练！')

    print('开始训练！')
    # Run the training loop.
    slim.learning.train(train_tensor,
                        train_log_dir,
                        number_of_steps=100,      # training steps, one batch each
                        save_summaries_secs=300,  # seconds between summary writes
                        save_interval_secs=300,   # seconds between checkpoint saves
                        init_fn=init_fn)