# vgg16对cifar-10的迁移学习

In [1]:
import numpy as np
import tensorflow as tf

import vgg16
import utils
import pickle
import skimage
import skimage.io
import skimage.transform


## 数据读取和预处理

In [2]:
def load_image(img):
    # load image
    img = img.reshape(3, 32, 32)
    img = img.transpose(1,2,0)
    img = img / 255.0
    assert (0 <= img).all() and (img <= 1.0).all()
    # print "Original Image Shape: ", img.shape
    # we crop image from center
    short_edge = min(img.shape[:2])
    yy = int((img.shape[0] - short_edge) / 2)
    xx = int((img.shape[1] - short_edge) / 2)
    crop_img = img[yy: yy + short_edge, xx: xx + short_edge]
    # resize to 224, 224
    resized_img = skimage.transform.resize(crop_img, (224, 224))
    resized_img = resized_img.reshape((1,224, 224, 3))

    return resized_img
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

In [3]:
# cifar文件放在当前目录下
fil=unpickle('test_batch')
data=fil[b'data']
labels=fil[b'labels'][:5000]
batch=()
for i in range(5000):
    p=load_image(data[i])
    batch=batch + (p,)
batch = np.concatenate(batch, 0)


  warn("The default mode, 'constant', will be changed to 'reflect' in "
  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "


## 将cifar的每一张图片用vgg的卷积层转换为4096\*1的向量

In [4]:
codes=None

with tf.Session() as sess:
    
    vgg = vgg16.Vgg16()
    input_ = tf.placeholder("float", [None,224, 224, 3])
    with tf.name_scope("content_vgg"):
        vgg.build(input_)
    turn=len(batch)//64
    for i in range(turn+1):
        if i==turn:
            feed_dict = {input_: batch[turn*64:]}
        else:
            feed_dict = {input_: batch[i*64:i*64+64]}
        codes_batch = sess.run(vgg.relu6, feed_dict=feed_dict)
        if codes is None:
                codes = codes_batch
        else:
                codes = np.concatenate((codes, codes_batch))
        print(i)


C:\Users\a\Desktop\vgg\tensorflow-vgg\vgg16.npy
npy file loaded
build model started
build model finished: 3s
0
1


"\nfor x in get_batches_for_x(batch):\n        feed_dict = {input_: x}\n        codes_batch = sess.run(vgg.relu6, feed_dict=feed_dict)\n        if codes is None:\n            codes = codes_batch\n        else:\n            codes = np.concatenate((codes, codes_batch))\n        print('one turn!')\n"

## 划分训练集

In [16]:
with open('codes', 'w') as f:
    codes.tofile(f)
    
import csv
with open('labels', 'w') as f:
    writer = csv.writer(f, delimiter='\n')
    writer.writerow(labels)
    
from sklearn.preprocessing import LabelBinarizer

lb = LabelBinarizer()
lb.fit(labels)

labels_vecs = lb.transform(labels)

from sklearn.model_selection import StratifiedShuffleSplit

ss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)

train_idx, val_idx = next(ss.split(codes, labels))

half_val_len = int(len(val_idx)/2)
val_idx, test_idx = val_idx[:half_val_len], val_idx[half_val_len:]

train_x, train_y = codes[train_idx], labels_vecs[train_idx]
val_x, val_y = codes[val_idx], labels_vecs[val_idx]
test_x, test_y = codes[test_idx], labels_vecs[test_idx]

print("Train shapes (x, y):", train_x.shape, train_y.shape)
print("Validation shapes (x, y):", val_x.shape, val_y.shape)
print("Test shapes (x, y):", test_x.shape, test_y.shape)

Train shapes (x, y): (3993, 4096) (3993, 10)
Validation shapes (x, y): (499, 4096) (499, 10)
Test shapes (x, y): (500, 4096) (500, 10)


## 添加全连接网络

In [17]:
# 输入数据的维度
inputs_ = tf.placeholder(tf.float32, shape=[None, codes.shape[1]])
# 标签数据的维度
labels_ = tf.placeholder(tf.int64, shape=[None, labels_vecs.shape[1]])

# 加入一个256维的全连接的层
fc = tf.contrib.layers.fully_connected(inputs_, 256)

# 加入一个10维的全连接层
logits = tf.contrib.layers.fully_connected(fc, labels_vecs.shape[1], activation_fn=None)

# 计算cross entropy值
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=labels_, logits=logits)

# 计算损失函数
cost = tf.reduce_mean(cross_entropy)

# 采用用得最广泛的AdamOptimizer优化器
optimizer = tf.train.AdamOptimizer().minimize(cost)

# 得到最后的预测分布
predicted = tf.nn.softmax(logits)

# 计算准确度
correct_pred = tf.equal(tf.argmax(predicted, 1), tf.argmax(labels_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See `tf.nn.softmax_cross_entropy_with_logits_v2`.



## 训练添加的全连接层网络

In [42]:
def get_batches(x, y, n_batches=10):
    """ 这是一个生成器函数，按照n_batches的大小将数据划分了小块 """
    batch_size = len(x)//n_batches
    
    for ii in range(0, n_batches*batch_size, batch_size):
        # 如果不是最后一个batch，那么这个batch中应该有batch_size个数据
        if ii != (n_batches-1)*batch_size:
            X, Y = x[ii: ii+batch_size], y[ii: ii+batch_size] 
        # 否则的话，那剩余的不够batch_size的数据都凑入到一个batch中
        else:
            X, Y = x[ii:], y[ii:]
        # 生成器语法，返回X和Y
        yield X, Y
        

In [27]:
# 运行多少轮次
epochs = 100
# 统计训练效果的频率
iteration = 0
# 保存模型的保存器
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for x, y in get_batches(train_x, train_y):
            feed = {inputs_: x,
                    labels_: y}
            # 训练模型
            loss, _ = sess.run([cost, optimizer], feed_dict=feed)
            print("Epoch: {}/{}".format(e+1, epochs),
                  "Iteration: {}".format(iteration),
                  "Training loss: {:.5f}".format(loss))
            iteration += 1
            
            if iteration % 5 == 0:
                feed = {inputs_: val_x,
                        labels_: val_y}
                val_acc = sess.run(accuracy, feed_dict=feed)
                # 输出用验证机验证训练进度
                print("Epoch: {}/{}".format(e, epochs),
                      "Iteration: {}".format(iteration),
                      "Validation Acc: {:.4f}".format(val_acc))
    # 保存模型
    saver.save(sess, "checkpoints/cif.ckpt")

Epoch: 1/100 Iteration: 0 Training loss: 4.54140
Epoch: 1/100 Iteration: 1 Training loss: 5.77719
Epoch: 1/100 Iteration: 2 Training loss: 6.52939
Epoch: 1/100 Iteration: 3 Training loss: 8.52608
Epoch: 1/100 Iteration: 4 Training loss: 7.31165
Epoch: 0/100 Iteration: 5 Validation Acc: 0.4990
Epoch: 1/100 Iteration: 5 Training loss: 5.52604
Epoch: 1/100 Iteration: 6 Training loss: 5.32081
Epoch: 1/100 Iteration: 7 Training loss: 4.15429
Epoch: 1/100 Iteration: 8 Training loss: 2.53710
Epoch: 1/100 Iteration: 9 Training loss: 1.79610
Epoch: 0/100 Iteration: 10 Validation Acc: 0.6232
Epoch: 2/100 Iteration: 10 Training loss: 0.79379
Epoch: 2/100 Iteration: 11 Training loss: 0.86459
Epoch: 2/100 Iteration: 12 Training loss: 1.28435
Epoch: 2/100 Iteration: 13 Training loss: 1.39403
Epoch: 2/100 Iteration: 14 Training loss: 1.36121
Epoch: 1/100 Iteration: 15 Validation Acc: 0.5812
Epoch: 2/100 Iteration: 15 Training loss: 1.28178
Epoch: 2/100 Iteration: 16 Training loss: 1.20269
Epoch: 2/10

Epoch: 16/100 Iteration: 152 Training loss: 0.07581
Epoch: 16/100 Iteration: 153 Training loss: 0.08277
Epoch: 16/100 Iteration: 154 Training loss: 0.07821
Epoch: 15/100 Iteration: 155 Validation Acc: 0.7936
Epoch: 16/100 Iteration: 155 Training loss: 0.07891
Epoch: 16/100 Iteration: 156 Training loss: 0.10431
Epoch: 16/100 Iteration: 157 Training loss: 0.05541
Epoch: 16/100 Iteration: 158 Training loss: 0.06253
Epoch: 16/100 Iteration: 159 Training loss: 0.08634
Epoch: 15/100 Iteration: 160 Validation Acc: 0.7735
Epoch: 17/100 Iteration: 160 Training loss: 0.07726
Epoch: 17/100 Iteration: 161 Training loss: 0.09466
Epoch: 17/100 Iteration: 162 Training loss: 0.05601
Epoch: 17/100 Iteration: 163 Training loss: 0.04807
Epoch: 17/100 Iteration: 164 Training loss: 0.05719
Epoch: 16/100 Iteration: 165 Validation Acc: 0.7896
Epoch: 17/100 Iteration: 165 Training loss: 0.07346
Epoch: 17/100 Iteration: 166 Training loss: 0.13227
Epoch: 17/100 Iteration: 167 Training loss: 0.07627
Epoch: 17/10

Epoch: 29/100 Iteration: 284 Training loss: 0.01312
Epoch: 28/100 Iteration: 285 Validation Acc: 0.8016
Epoch: 29/100 Iteration: 285 Training loss: 0.01189
Epoch: 29/100 Iteration: 286 Training loss: 0.01557
Epoch: 29/100 Iteration: 287 Training loss: 0.01221
Epoch: 29/100 Iteration: 288 Training loss: 0.00912
Epoch: 29/100 Iteration: 289 Training loss: 0.01038
Epoch: 28/100 Iteration: 290 Validation Acc: 0.8156
Epoch: 30/100 Iteration: 290 Training loss: 0.00865
Epoch: 30/100 Iteration: 291 Training loss: 0.00749
Epoch: 30/100 Iteration: 292 Training loss: 0.00868
Epoch: 30/100 Iteration: 293 Training loss: 0.00840
Epoch: 30/100 Iteration: 294 Training loss: 0.00929
Epoch: 29/100 Iteration: 295 Validation Acc: 0.8216
Epoch: 30/100 Iteration: 295 Training loss: 0.00891
Epoch: 30/100 Iteration: 296 Training loss: 0.01146
Epoch: 30/100 Iteration: 297 Training loss: 0.01070
Epoch: 30/100 Iteration: 298 Training loss: 0.00709
Epoch: 30/100 Iteration: 299 Training loss: 0.00910
Epoch: 29/10

Epoch: 42/100 Iteration: 415 Training loss: 0.00410
Epoch: 42/100 Iteration: 416 Training loss: 0.00500
Epoch: 42/100 Iteration: 417 Training loss: 0.00445
Epoch: 42/100 Iteration: 418 Training loss: 0.00391
Epoch: 42/100 Iteration: 419 Training loss: 0.00445
Epoch: 41/100 Iteration: 420 Validation Acc: 0.8156
Epoch: 43/100 Iteration: 420 Training loss: 0.00377
Epoch: 43/100 Iteration: 421 Training loss: 0.00337
Epoch: 43/100 Iteration: 422 Training loss: 0.00361
Epoch: 43/100 Iteration: 423 Training loss: 0.00397
Epoch: 43/100 Iteration: 424 Training loss: 0.00423
Epoch: 42/100 Iteration: 425 Validation Acc: 0.8156
Epoch: 43/100 Iteration: 425 Training loss: 0.00395
Epoch: 43/100 Iteration: 426 Training loss: 0.00481
Epoch: 43/100 Iteration: 427 Training loss: 0.00430
Epoch: 43/100 Iteration: 428 Training loss: 0.00378
Epoch: 43/100 Iteration: 429 Training loss: 0.00428
Epoch: 42/100 Iteration: 430 Validation Acc: 0.8156
Epoch: 44/100 Iteration: 430 Training loss: 0.00362
Epoch: 44/10

Epoch: 56/100 Iteration: 551 Training loss: 0.00230
Epoch: 56/100 Iteration: 552 Training loss: 0.00242
Epoch: 56/100 Iteration: 553 Training loss: 0.00269
Epoch: 56/100 Iteration: 554 Training loss: 0.00283
Epoch: 55/100 Iteration: 555 Validation Acc: 0.8176
Epoch: 56/100 Iteration: 555 Training loss: 0.00266
Epoch: 56/100 Iteration: 556 Training loss: 0.00314
Epoch: 56/100 Iteration: 557 Training loss: 0.00284
Epoch: 56/100 Iteration: 558 Training loss: 0.00259
Epoch: 56/100 Iteration: 559 Training loss: 0.00282
Epoch: 55/100 Iteration: 560 Validation Acc: 0.8136
Epoch: 57/100 Iteration: 560 Training loss: 0.00241
Epoch: 57/100 Iteration: 561 Training loss: 0.00224
Epoch: 57/100 Iteration: 562 Training loss: 0.00236
Epoch: 57/100 Iteration: 563 Training loss: 0.00261
Epoch: 57/100 Iteration: 564 Training loss: 0.00276
Epoch: 56/100 Iteration: 565 Validation Acc: 0.8156
Epoch: 57/100 Iteration: 565 Training loss: 0.00259
Epoch: 57/100 Iteration: 566 Training loss: 0.00305
Epoch: 57/10

Epoch: 70/100 Iteration: 694 Training loss: 0.00201
Epoch: 69/100 Iteration: 695 Validation Acc: 0.8136
Epoch: 70/100 Iteration: 695 Training loss: 0.00190
Epoch: 70/100 Iteration: 696 Training loss: 0.00222
Epoch: 70/100 Iteration: 697 Training loss: 0.00201
Epoch: 70/100 Iteration: 698 Training loss: 0.00188
Epoch: 70/100 Iteration: 699 Training loss: 0.00200
Epoch: 69/100 Iteration: 700 Validation Acc: 0.8156
Epoch: 71/100 Iteration: 700 Training loss: 0.00171
Epoch: 71/100 Iteration: 701 Training loss: 0.00162
Epoch: 71/100 Iteration: 702 Training loss: 0.00168
Epoch: 71/100 Iteration: 703 Training loss: 0.00187
Epoch: 71/100 Iteration: 704 Training loss: 0.00196
Epoch: 70/100 Iteration: 705 Validation Acc: 0.8136
Epoch: 71/100 Iteration: 705 Training loss: 0.00186
Epoch: 71/100 Iteration: 706 Training loss: 0.00217
Epoch: 71/100 Iteration: 707 Training loss: 0.00196
Epoch: 71/100 Iteration: 708 Training loss: 0.00184
Epoch: 71/100 Iteration: 709 Training loss: 0.00196
Epoch: 70/10

Epoch: 84/100 Iteration: 836 Training loss: 0.00165
Epoch: 84/100 Iteration: 837 Training loss: 0.00150
Epoch: 84/100 Iteration: 838 Training loss: 0.00142
Epoch: 84/100 Iteration: 839 Training loss: 0.00150
Epoch: 83/100 Iteration: 840 Validation Acc: 0.8136
Epoch: 85/100 Iteration: 840 Training loss: 0.00128
Epoch: 85/100 Iteration: 841 Training loss: 0.00122
Epoch: 85/100 Iteration: 842 Training loss: 0.00126
Epoch: 85/100 Iteration: 843 Training loss: 0.00141
Epoch: 85/100 Iteration: 844 Training loss: 0.00147
Epoch: 84/100 Iteration: 845 Validation Acc: 0.8136
Epoch: 85/100 Iteration: 845 Training loss: 0.00140
Epoch: 85/100 Iteration: 846 Training loss: 0.00162
Epoch: 85/100 Iteration: 847 Training loss: 0.00147
Epoch: 85/100 Iteration: 848 Training loss: 0.00140
Epoch: 85/100 Iteration: 849 Training loss: 0.00147
Epoch: 84/100 Iteration: 850 Validation Acc: 0.8136
Epoch: 86/100 Iteration: 850 Training loss: 0.00126
Epoch: 86/100 Iteration: 851 Training loss: 0.00120
Epoch: 86/10

Epoch: 99/100 Iteration: 982 Training loss: 0.00098
Epoch: 99/100 Iteration: 983 Training loss: 0.00109
Epoch: 99/100 Iteration: 984 Training loss: 0.00114
Epoch: 98/100 Iteration: 985 Validation Acc: 0.8156
Epoch: 99/100 Iteration: 985 Training loss: 0.00109
Epoch: 99/100 Iteration: 986 Training loss: 0.00125
Epoch: 99/100 Iteration: 987 Training loss: 0.00114
Epoch: 99/100 Iteration: 988 Training loss: 0.00109
Epoch: 99/100 Iteration: 989 Training loss: 0.00114
Epoch: 98/100 Iteration: 990 Validation Acc: 0.8156
Epoch: 100/100 Iteration: 990 Training loss: 0.00098
Epoch: 100/100 Iteration: 991 Training loss: 0.00093
Epoch: 100/100 Iteration: 992 Training loss: 0.00097
Epoch: 100/100 Iteration: 993 Training loss: 0.00107
Epoch: 100/100 Iteration: 994 Training loss: 0.00112
Epoch: 99/100 Iteration: 995 Validation Acc: 0.8156
Epoch: 100/100 Iteration: 995 Training loss: 0.00107
Epoch: 100/100 Iteration: 996 Training loss: 0.00123
Epoch: 100/100 Iteration: 997 Training loss: 0.00112
Epoc

## 测试准确率

In [28]:
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    
    feed = {inputs_: test_x,
            labels_: test_y}
    test_acc = sess.run(accuracy, feed_dict=feed)
    print("Test accuracy: {:.4f}".format(test_acc))

INFO:tensorflow:Restoring parameters from checkpoints\cif.ckpt
Test accuracy: 0.8360
