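Testing self-training semi-supervised algorithms on the MNIST dataset

This experiment tests two self-training variants:
- Variant 1 treats the unlabeled dataset as a single pool; on every iteration, samples are selected from the entire pool, and selected samples stay in the pool (sampling with replacement).
- Variant 2 splits the unlabeled dataset into $N$ parts; each selection step draws from only one of the parts.

Variant 2 runs faster than variant 1, because each iteration only has to run prediction on one unlabeled subset rather than on the whole pool.

What is harder to explain is that variant 2 also reaches a higher prediction accuracy than variant 1; the reason is not clear.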
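
The difference between the two pool strategies can be sketched without the model. This is only an illustration under stated assumptions: `split_pool` and `select_confident` are hypothetical helper names that do not appear in the notebook, the probabilities are hand-written toy values, and the chunk size of 10000 is taken from the second training loop further down.

```python
import numpy as np

def split_pool(n_samples, chunk_size):
    """Return index slices that partition an unlabeled pool into chunks."""
    return [slice(start, min(start + chunk_size, n_samples))
            for start in range(0, n_samples, chunk_size)]

def select_confident(probs, threshold):
    """Keep rows whose top class probability exceeds the threshold.

    Returns the kept row indices and their one-hot pseudo-labels.
    """
    confidence = probs.max(axis=-1)
    keep = np.flatnonzero(confidence > threshold)
    pseudo = np.eye(probs.shape[1])[probs[keep].argmax(axis=-1)]
    return keep, pseudo

# Variant 1 scores all 59900 samples per step; variant 2 scores one chunk per step.
chunks = split_pool(59900, 10000)
print(len(chunks), chunks[-1])

# Toy softmax outputs for three samples and two classes.
toy_probs = np.array([[0.95, 0.05], [0.60, 0.40], [0.08, 0.92]])
keep, pseudo = select_confident(toy_probs, 0.9)
print(keep, pseudo)
```

With these sizes each step of variant 2 predicts on at most 10000 samples instead of 59900, which is where the speed difference comes from.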

In [1]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
import tensorflow.keras.layers as L
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# tf.enable_eager_execution()

Prepare the dataset

In [2]:
# Load dataset
(xtr,ytr),(xte,yte) = mnist.load_data()
xtr = np.expand_dims(xtr,3).astype('float32')/255.
ytr = np.eye(10)[ytr]
xte = np.expand_dims(xte,3).astype('float32')/255.
yte = np.eye(10)[yte]
perm = np.load('perm-60000.npy')

# 100 labeled samples, 59900 unlabeled samples
xla = xtr[perm[0:100]]
yla = ytr[perm[0:100]]
xun = xtr[perm[100:]]
yun = ytr[perm[100:]]  # true labels of the unlabeled pool; not used for training
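
The file `perm-60000.npy` is not shown in the notebook. Presumably it stores a fixed permutation of the indices 0 to 59999 so that the labeled/unlabeled split is reproducible. Under that assumption, a file of this kind could be produced roughly as follows; the seed is arbitrary and is not the author's.

```python
import numpy as np

# Assumption: the split file is a permutation of all 60000 training indices.
rng = np.random.RandomState(0)  # arbitrary seed chosen for this sketch
perm = rng.permutation(60000)
# np.save('perm-60000.npy', perm)  # uncomment to write the file

labeled_idx = perm[:100]     # indices of the 100 labeled samples
unlabeled_idx = perm[100:]   # indices of the 59900 unlabeled samples
print(len(labeled_idx), len(unlabeled_idx))
```

A purely random draw of 100 indices is not guaranteed to be class-balanced, so it can be worth checking the label counts of the labeled subset.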

Build the model

In [4]:
# Build the model: two conv/pool/dropout stages followed by a small classifier head
model = Sequential([
    L.Conv2D(64, (3,3), padding='same', activation='relu', name='conv1', input_shape=(28,28,1)),
    L.MaxPooling2D((2,2), name='pool1'),
    L.Dropout(0.2),
    L.Conv2D(16, (3,3), padding='same', activation='relu', name='conv2'),
    L.MaxPooling2D((2,2)),
    L.Dropout(0.2),
    L.Flatten(),
    L.Dense(128),  # no activation is set, so this layer is linear
    L.Dropout(0.5),
    L.Dense(10, activation='softmax')])
# tf.train.AdamOptimizer is the TF 1.x API; under TF 2.x the imported Adam(0.001) plays the same role
model.compile(loss='categorical_crossentropy',
              optimizer=tf.train.AdamOptimizer(0.001),
              metrics=['accuracy'])
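
As a sanity check on the architecture, the tensor shapes and parameter counts can be worked out by hand. The sketch below uses plain arithmetic rather than TensorFlow; the helper names `conv_params` and `dense_params` are hypothetical, and the totals can be compared against `model.summary()`.

```python
def conv_params(kernel, in_ch, out_ch):
    """Weights plus biases of a 2D convolution with a square kernel."""
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(in_units, out_units):
    """Weights plus biases of a fully connected layer."""
    return in_units * out_units + out_units

# 'same' padding keeps 28x28; each 2x2 max-pool halves the spatial size.
side = 28 // 2 // 2          # 7 after the two pooling layers
flat = side * side * 16      # Flatten output: 7 * 7 * 16 = 784

layers = {
    'conv1': conv_params(3, 1, 64),
    'conv2': conv_params(3, 64, 16),
    'dense_128': dense_params(flat, 128),
    'dense_10': dense_params(128, 10),
}
total = sum(layers.values())
print(layers, total)
```

Most of the parameters sit in the first dense layer, which matters when only 100 labeled samples are available for the initial training.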

Train the model on the labeled data

In [13]:
# # Train the model on the 100 labeled samples (run once, then reuse the saved weights)
# history = model.fit(x=xla, y=yla, batch_size=20, epochs=30, validation_data=(xte, yte),
#           shuffle=True)
# model.save_weights('./model/model.ckpt')
model.load_weights('./model/model.ckpt')
acc = model.evaluate(xte, yte)  # returns [loss, accuracy]
print('Model accuracy: {}'.format(acc[1]))

Model accuracy: 0.7968


In [12]:
# Variant 1: self-training on the whole unlabeled pool.
# Only pseudo-labeled samples are used for fitting here; xla/yla are not mixed back in.
threshold = 1
while True:
    xbatch = xun  # score the entire unlabeled pool on every iteration
    pbatch = model.predict(xbatch)
    conf = np.max(pbatch, axis=-1)        # top class probability per sample
    xbatch = xbatch[conf >= threshold]    # keep only the confident samples
    pbatch = pbatch[conf >= threshold]
    pbatch = np.argmax(pbatch, axis=-1)   # hard pseudo-labels
    pbatch = np.eye(10)[pbatch]           # back to one-hot
    model.fit(xbatch, pbatch,
              batch_size=128, epochs=1,
              validation_data=(xte, yte),
              shuffle=True)
    if threshold <= 0:
        break
    threshold -= 0.02

Train on 61 samples, validate on 10000 samples
Epoch 1/1
Train on 28853 samples, validate on 10000 samples
Epoch 1/1
Train on 43911 samples, validate on 10000 samples
Epoch 1/1
Train on 49126 samples, validate on 10000 samples
Epoch 1/1
Train on 51860 samples, validate on 10000 samples
Epoch 1/1
Train on 52991 samples, validate on 10000 samples
Epoch 1/1
Train on 54393 samples, validate on 10000 samples
Epoch 1/1
Train on 54976 samples, validate on 10000 samples
Epoch 1/1
Train on 56055 samples, validate on 10000 samples
Epoch 1/1
Train on 56178 samples, validate on 10000 samples
Epoch 1/1
Train on 56831 samples, validate on 10000 samples
Epoch 1/1
Train on 56996 samples, validate on 10000 samples
Epoch 1/1
Train on 57583 samples, validate on 10000 samples
Epoch 1/1
Train on 57742 samples, validate on 10000 samples
Epoch 1/1
Train on 58174 samples, validate on 10000 samples
Epoch 1/1
Train on 58395 samples, validate on 10000 samples
Epoch 1/1
Train on 58568 samples, validate on 10000 s
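
The true labels of the unlabeled pool are kept in `yun` but never used for training, so they can serve as a diagnostic for the open question of why the two variants differ: at a given threshold, how many of the selected pseudo-labels are actually correct? The sketch below uses a hypothetical helper, `pseudo_label_report`, and toy inputs; in the notebook one would pass `model.predict(xun)` and `yun` instead.

```python
import numpy as np

def pseudo_label_report(probs, y_true_onehot, threshold):
    """Return (number selected, fraction of selected pseudo-labels that are correct)."""
    confidence = probs.max(axis=-1)
    keep = confidence >= threshold
    n_selected = int(keep.sum())
    if n_selected == 0:
        # Nothing passed the threshold, so the accuracy is undefined.
        return 0, float('nan')
    predicted = probs[keep].argmax(axis=-1)
    actual = y_true_onehot[keep].argmax(axis=-1)
    return n_selected, float((predicted == actual).mean())

# Toy softmax outputs and true labels for four samples and two classes.
toy_probs = np.array([[0.97, 0.03], [0.10, 0.90], [0.55, 0.45], [0.02, 0.98]])
toy_true = np.eye(2)[[0, 0, 1, 1]]
print(pseudo_label_report(toy_probs, toy_true, 0.9))
```

Logging this pair on every iteration of both loops would show whether one variant trains on noisier pseudo-labels than the other.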

In [14]:
# Variant 2: self-training on one chunk of the unlabeled pool at a time.
# Only pseudo-labeled samples are used for fitting here; xla/yla are not mixed back in.
threshold = 0.9
for i in range(10):
    for j in range(0, len(xun), 10000):
        xbatch = xun[j:j+10000]  # iterate over the unlabeled data in N chunks (this step seems to matter)
        pbatch = model.predict(xbatch)
        conf = np.max(pbatch, axis=-1)        # top class probability per sample
        xbatch = xbatch[conf > threshold]     # keep only the confident samples
        pbatch = pbatch[conf > threshold]
        pbatch = np.argmax(pbatch, axis=-1)   # hard pseudo-labels
        pbatch = np.eye(10)[pbatch]           # back to one-hot
        print('SSL iteration {}-{}, threshold: {}, {} samples selected'.format(i, j, threshold, xbatch.shape[0]))

        model.fit(xbatch, pbatch,
                  batch_size=128, epochs=1,
                  validation_data=(xte, yte),
                  shuffle=True)
    threshold -= 0.1


SSL iteration 0-0, threshold: 0.9, 6642 samples selected
Train on 6642 samples, validate on 10000 samples
Epoch 1/1
SSL iteration 0-10000, threshold: 0.9, 7633 samples selected
Train on 7633 samples, validate on 10000 samples
Epoch 1/1
SSL iteration 0-20000, threshold: 0.9, 8076 samples selected
Train on 8076 samples, validate on 10000 samples
Epoch 1/1
SSL iteration 0-30000, threshold: 0.9, 8293 samples selected
Train on 8293 samples, validate on 10000 samples
Epoch 1/1
SSL iteration 0-40000, threshold: 0.9, 8410 samples selected
Train on 8410 samples, validate on 10000 samples
Epoch 1/1
SSL iteration 0-50000, threshold: 0.9, 8572 samples selected
Train on 8572 samples, validate on 10000 samples
Epoch 1/1
SSL iteration 1-0, threshold: 0.8, 9168 samples selected
Train on 9168 samples, validate on 10000 samples
Epoch 1/1
SSL iteration 1-10000, threshold: 0.8, 9130 samples selected
Train on 9130 samples, validate on 10000 samples
Epoch 1/1
SSL iteration 1-20000, threshold: 0.8, 9227 samples selec
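
One detail of the loop above: repeatedly subtracting 0.1 from a float accumulates rounding error, so the later thresholds are not exactly 0.7, 0.6, and so on, and the final one is only approximately 0. The effect on sample selection is probably negligible, but the logged values become harder to read. One alternative, shown purely as a sketch, is to compute the schedule up front; using `np.linspace` with rounding is a choice made for this example, not something the notebook does.

```python
import numpy as np

# Repeated subtraction, as in the loop above: rounding error accumulates.
t = 0.9
drifting = []
for _ in range(10):
    drifting.append(t)
    t -= 0.1

# Precomputed schedule: 0.9, 0.8, ..., 0.0, rounded to one decimal place.
schedule = np.round(np.linspace(0.9, 0.0, 10), 1)

print(drifting)
print(schedule.tolist())
```

The outer loop could then iterate with `for i, threshold in enumerate(schedule):` instead of updating `threshold` by hand.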