Import libraries

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

Define hyperparameters

In [None]:
# Mini-batch size intended for training; 32 is also the Keras `Model.fit` default.
BATCH_SIZE = 32
# Number of single-epoch `fit` calls performed per training run.
EPOCHS = 10
# Fraction of the training data passed to `fit` that is held out for validation.
VALIDATION_SPLIT = 0.2
# Minimum predicted class probability for an unlabeled example to be pseudo-labeled.
THRESHOLD = 0.95

# Seed NumPy's global RNG so the labeled/unlabeled split and the per-round
# shuffles are reproducible across "Restart & Run All".
# NOTE: only NumPy is seeded here; TensorFlow's own randomness (weight
# initialization, batch shuffling inside `fit`) is not controlled by this cell.
SEED = 42
np.random.seed(SEED)

Fetch and preprocess MNIST data

In [None]:
# Fraction of the training set that keeps its labels; the remainder is treated
# as unlabeled to simulate a semi-supervised setting.
LABELED_FRACTION = 0.1

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Add a trailing channel axis and rescale the images from [0,255] to the
# [0.0,1.0] range. Cast to float32 first: dividing the raw integer arrays by
# 255.0 would produce float64 arrays that take twice the memory.
x_train = x_train[..., np.newaxis].astype(np.float32) / 255.0
x_test = x_test[..., np.newaxis].astype(np.float32) / 255.0

print("Number of original training examples:", len(x_train))
print("Number of original test examples:", len(x_test))

# Simulate unlabeled dataset: draw a random subset (without replacement) that
# keeps its labels and mark it in a boolean mask over the training set.
num_labeled = int(len(x_train) * LABELED_FRACTION)
labeled_index = np.random.choice(len(x_train), num_labeled, replace=False)
labeled_mask = np.zeros(len(x_train), dtype=bool)
labeled_mask[labeled_index] = True
labeled_x_train, labeled_y_train = x_train[labeled_mask], y_train[labeled_mask]
# The true labels of the "unlabeled" part are kept alongside the images but
# are not used for training.
unlabeled_x_train, unlabeled_y_train = x_train[~labeled_mask], y_train[~labeled_mask]

print("Number of labeled training examples:", len(labeled_x_train))
print("Number of unlabeled training examples:", len(unlabeled_x_train))

Define model

In [None]:
# Small CNN for 28x28 single-channel images: two convolution + max-pooling
# stages, then a 32-unit hidden layer and a 10-way softmax output.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(
    filters=32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(units=32, activation='relu'))
model.add(tf.keras.layers.Dense(units=10, activation='softmax'))

# Integer class labels are used directly, hence the sparse loss and metric.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Snapshot the untrained weights; later cells reload "./init" before training
# so every run starts from the same initialization.
model.save_weights("./init")

Supervised (Baseline)

In [None]:
# Supervised baseline: train on the labeled subset only, starting from the
# saved initial weights, and keep the weights with the lowest validation loss.
model.load_weights("./init")
# Start at infinity so any finite first-epoch validation loss writes a
# checkpoint (a fixed sentinel such as 1000 could skip checkpointing entirely).
best_loss = float("inf")
for epoch in range(EPOCHS):
  # Train one epoch per call so the validation loss can be checked after
  # every epoch and the best weights saved.
  history = model.fit(
      labeled_x_train,
      labeled_y_train,
      batch_size=BATCH_SIZE,
      epochs=1,
      validation_split=VALIDATION_SPLIT,
  )
  loss = history.history['val_loss'][0]
  if loss < best_loss:
    model.save_weights("./checkpoint")
    print("Saving checkpoint")
    best_loss = loss

# Evaluate the best checkpoint (not necessarily the last epoch) on the test set.
model.load_weights("./checkpoint")
hist = model.evaluate(x_test, y_test)
supervised_loss, supervised_accuracy = hist[0], hist[1]

Semi-Supervised

In [None]:
# Self-training (pseudo-labeling): repeatedly train from the initial weights on
# the current labeled set, pseudo-label the confidently predicted unlabeled
# images, move them into the labeled set, and record test metrics per round.

# Stop once a round adds fewer than this many pseudo-labeled examples.
MIN_NEW_LABELS = 20

semi_supervised_loss = []
semi_supervised_accuracy = []

# Pool of images that have not been pseudo-labeled yet. Rebinding this local
# name leaves `unlabeled_x_train` itself untouched.
unlabeled_pool = unlabeled_x_train

while True:
  # Start at infinity so the first finite validation loss of this round always
  # overwrites the checkpoint left by the previous round.
  best_loss = float("inf")
  model.load_weights("./init")

  # Shuffle dataset: `validation_split` holds out the tail of the arrays, so
  # without shuffling the validation set would consist only of the most
  # recently appended pseudo-labeled examples.
  shuffler = np.random.permutation(len(labeled_y_train))
  labeled_x_train, labeled_y_train = labeled_x_train[shuffler], labeled_y_train[shuffler]

  # Supervised training, one epoch per call, keeping the best validation loss.
  for epoch in range(EPOCHS):
    history = model.fit(
        labeled_x_train,
        labeled_y_train,
        batch_size=BATCH_SIZE,
        epochs=1,
        validation_split=VALIDATION_SPLIT,
    )
    loss = history.history['val_loss'][0]
    if loss < best_loss:
      model.save_weights("./checkpoint")
      print("Saving checkpoint")
      best_loss = loss

  # Evaluate the best checkpoint of this round on the test set.
  model.load_weights("./checkpoint")
  hist = model.evaluate(x_test, y_test)
  semi_supervised_loss.append(hist[0])
  semi_supervised_accuracy.append(hist[1])

  # Nothing left to pseudo-label.
  if len(unlabeled_pool) == 0:
    break

  # Label unlabeled data
  prediction = model.predict(unlabeled_pool)

  # Select "confident" unlabeled entries: highest class probability above THRESHOLD.
  mask = np.amax(prediction, axis=1) > THRESHOLD
  num_new = int(np.sum(mask))
  new_x_train, new_y_train = unlabeled_pool[mask], np.argmax(prediction[mask], axis=1)

  # Move the confident entries from the unlabeled pool into the labeled set.
  # Removing them from the pool is essential: otherwise the same images are
  # appended again every round and the exit condition below may never trigger.
  labeled_x_train = np.concatenate((labeled_x_train, new_x_train), axis=0)
  labeled_y_train = np.concatenate((labeled_y_train, new_y_train))
  unlabeled_pool = unlabeled_pool[~mask]
  print("Pseudo-labeled examples added:", num_new)

  # Exit training when not enough unlabeled data is appended
  if num_new < MIN_NEW_LABELS:
    break


Generate plots

In [None]:
def plot_metric(semi_supervised_values, supervised_value, title, ylabel):
  """Plot a semi-supervised metric per self-training round against the baseline.

  Args:
    semi_supervised_values: sequence with one metric value per self-training round.
    supervised_value: scalar metric of the supervised baseline, drawn as a
      dashed horizontal reference line.
    title: figure title.
    ylabel: label of the y axis.
  """
  fig, ax = plt.subplots()
  # axhline spans the full x range, however many rounds were run.
  ax.axhline(supervised_value, color='b', linestyle='--', linewidth=2, label="supervised")
  # Markers keep the curve visible even when only a single round was run.
  ax.plot(semi_supervised_values, marker='o', label="semi-supervised")
  ax.set_title(title)
  ax.set_ylabel(ylabel)
  # Each point is one full self-training round (EPOCHS epochs), not one epoch.
  ax.set_xlabel('self-training round')
  ax.legend(loc='upper left')
  plt.show()


plot_metric(semi_supervised_loss, supervised_loss, 'Loss', 'loss')
plot_metric(semi_supervised_accuracy, supervised_accuracy, 'Accuracy', 'accuracy')