In [4]:
import tensorflow as tf
import tensorflow_datasets as tfds
import pandas as pd

from model import CNN
from main import train, validate

# Enable on-demand GPU memory growth (`allow_growth`) and apply the option by
# opening a TF1-compat interactive session with that configuration.
config = tf.compat.v1.ConfigProto(
    gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)
)
session = tf.compat.v1.InteractiveSession(config=config)

# Parameters

In [5]:
# Hyperparameters for the run; this dict is passed unchanged to train() and
# validate(). "default" in a comment is the reference value noted by the author.
args = dict(
    epochs=10,                 # number of training epochs
    batch_size=64,             # examples per batch (default: 64)
    learning_rate=0.01,        # optimizer learning rate (default: 0.01)
    labelled_examples=1,       # percentage of the train split used as labelled data
    validation_examples=20,    # percentage of the train split held out for validation
    val_iteration=800,         # number of batches before validation
    T=0.5,                     # temperature sharpening ratio (default: 0.5)
    K=2,                       # number of rounds of augmentation (default: 2)
    alpha=0.75,                # parameter for sampling from the Beta distribution (default: 0.75)
    lambda_u=100,              # multiplier for the unlabelled loss (default: 100)
    rampup_length=8,           # ramp-up length for the unlabelled loss multiplier (default: 16)
    weight_decay=0.02,         # decay rate for the model variables (default: 0.02)
    ema_decay=0.999,           # EMA decay for the EMA model variables (default: 0.999)
)

# Dataset

In [8]:
# Percentage boundaries that carve the MNIST 'train' split into three
# consecutive slices: labelled | unlabelled | validation.
labelled_pct = args["labelled_examples"]
val_start_pct = 100 - args["validation_examples"]

train_labelled = tfds.load('mnist', split=f'train[0:{labelled_pct}%]')
train_unlabelled = tfds.load('mnist', split=f'train[{labelled_pct}:{val_start_pct}%]')
val_dataset = tfds.load('mnist', split=f'train[{val_start_pct}:100%]')
test_dataset = tfds.load('mnist', split='test')

# Overview of the number of examples per split; validation and test data are
# counted in the "Labelled" row.
pd.DataFrame(
    {
        "Training": [len(train_labelled), len(train_unlabelled)],
        "Validation": [len(val_dataset), 0],
        "Test": [len(test_dataset), 0],
    },
    index=["Labelled", "Unlabelled"],
)

Unnamed: 0,Training,Validation,Test
Labelled,600,12000,10000
Unlabelled,47400,0,0


# Model
We creëren twee identieke modellen: een normaal model en een "Exponential Moving Average"-model (EMA). Het eerste wordt gebruikt om direct op te trainen; het EMA-model is een moving average van het eerste model en kan als alternatief eindproduct gebruikt worden. Het EMA-model is minder gevoelig voor grote veranderingen aan het model. Beide modellen worden na elke epoch gevalideerd; aan het eind wordt het EMA-model getest.

In [5]:
# Construct two instances of the same architecture: `model` is passed to
# train() as the network being trained, `ema_model` as its EMA counterpart.
model = CNN()
ema_model = CNN()

# Build both networks BEFORE copying weights. If CNN creates its variables
# lazily (as subclassed Keras models typically do), get_weights() on an unbuilt
# model returns an empty list and set_weights([]) is a silent no-op, which
# would leave the EMA model with its own independent random initialisation.
# Input: batches of 28x28 single-channel images (MNIST).
model.build(input_shape=(None, 28, 28, 1))
ema_model.build(input_shape=(None, 28, 28, 1))

# Start the EMA model from exactly the same weights as the trained model.
ema_model.set_weights(model.get_weights())

# Choose the optimizer. `learning_rate` is the supported keyword; `lr` is a
# deprecated alias that newer Keras releases reject.
optimizer = tf.keras.optimizers.Adam(learning_rate=args['learning_rate'])

# Training
We trainen voor een aantal epochs waarbij willekeurig door de trainingsdata wordt gelopen. In totaal worden er per epoch `val_iteration * batch_size` gelabelde en ongelabelde voorbeelden bekeken. Tussen de epochs door worden zowel het EMA-model als het normale model getest op de validatieset.

In [6]:
# One pass per epoch: run a training epoch, then score both networks on the
# validation split -- the EMA model first, the directly trained model second.
for epoch in range(args['epochs']):
    train(train_labelled, train_unlabelled, model, ema_model, optimizer, epoch, args)
    for network, split_name in ((ema_model, 'Validation EMA'), (model, 'Validation')):
        validate(val_dataset, network, epoch, args, split=split_name)


  0%|          | 0/800 [00:00<?, ?batch/s]

Epoch 0000: Validation EMA XE Loss: 2.4070, Validation EMA Accuracy: 13.900%
Epoch 0000: Validation XE Loss: 0.3012, Validation Accuracy: 96.050%


  0%|          | 0/800 [00:00<?, ?batch/s]

Epoch 0001: Validation EMA XE Loss: 2.1021, Validation EMA Accuracy: 17.117%
Epoch 0001: Validation XE Loss: 0.3077, Validation Accuracy: 96.317%


  0%|          | 0/800 [00:00<?, ?batch/s]

Epoch 0002: Validation EMA XE Loss: 0.5827, Validation EMA Accuracy: 83.308%
Epoch 0002: Validation XE Loss: 0.2226, Validation Accuracy: 96.342%


  0%|          | 0/800 [00:00<?, ?batch/s]

Epoch 0003: Validation EMA XE Loss: 0.5510, Validation EMA Accuracy: 82.075%
Epoch 0003: Validation XE Loss: 0.1889, Validation Accuracy: 96.925%


  0%|          | 0/800 [00:00<?, ?batch/s]

Epoch 0004: Validation EMA XE Loss: 0.1590, Validation EMA Accuracy: 97.533%
Epoch 0004: Validation XE Loss: 0.1384, Validation Accuracy: 97.175%


  0%|          | 0/800 [00:00<?, ?batch/s]

Epoch 0005: Validation EMA XE Loss: 0.2019, Validation EMA Accuracy: 97.433%
Epoch 0005: Validation XE Loss: 0.2133, Validation Accuracy: 96.300%


  0%|          | 0/800 [00:00<?, ?batch/s]

Epoch 0006: Validation EMA XE Loss: 1.3287, Validation EMA Accuracy: 88.642%
Epoch 0006: Validation XE Loss: 0.1599, Validation Accuracy: 97.633%


  0%|          | 0/800 [00:00<?, ?batch/s]

Epoch 0007: Validation EMA XE Loss: 10.1088, Validation EMA Accuracy: 26.842%
Epoch 0007: Validation XE Loss: 0.1378, Validation Accuracy: 97.417%


  0%|          | 0/800 [00:00<?, ?batch/s]

Epoch 0008: Validation EMA XE Loss: 0.1778, Validation EMA Accuracy: 98.167%
Epoch 0008: Validation XE Loss: 0.1148, Validation Accuracy: 97.617%


  0%|          | 0/800 [00:00<?, ?batch/s]

Epoch 0009: Validation EMA XE Loss: 0.1338, Validation EMA Accuracy: 98.283%
Epoch 0009: Validation XE Loss: 0.1271, Validation Accuracy: 97.583%


# Test
Uiteindelijk testen we het EMA-model op de testdataset.

In [7]:
# Final evaluation: score the EMA model on the MNIST test split.
# Only the EMA model is evaluated here; the directly trained `model` is not.
# NOTE(review): `epoch` is the loop variable left over from the training cell,
# so this cell requires that loop to have run at least one iteration in the
# current kernel; it reports whatever epoch index that loop ended on.
validate(test_dataset, ema_model, epoch, args, split='Test')

Epoch 0009: Test XE Loss: 0.1304, Test Accuracy: 98.340%
