In [1]:
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, MaxPooling2D
from tensorflow.keras.layers import Dropout, Dense, Flatten, Input
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import OneHotEncoder
import random
import time

In [2]:
data_size = 3000
lr = 0.001
batch_size = 32
val_frac = 0.2
num_epochs = 50
random.seed(123)
np.random.seed(123)
tf.random.set_seed(123)

In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train[:data_size].astype(np.float32)/255.0
y_train = y_train[:data_size]
x_test = x_test.astype(np.float32)/255.0
x_train = x_train.reshape((-1, 28, 28, 1))
x_test = x_test.reshape((-1, 28, 28, 1))

In [4]:
y_train = y_train.reshape(-1, 1)
y_test = y_test.reshape(-1, 1)
y_train = OneHotEncoder(sparse=True).fit_transform(y_train).toarray()
y_test = OneHotEncoder(sparse=True).fit_transform(y_test).toarray()

In [5]:
def get_mean_std(images):
    mean = ()
    std = ()
    for i in range(images.shape[-1]):
        mean += (np.mean(images[:, :, :, i]),)
        std += (np.std(images[:, :, :, i]),)
    return mean, std

def normalize(images, mean, std):
    for i in range(images.shape[-1]):
        images[:, :, :, i] = (images[:, :, :, i] - mean[i])/std[i]
    return images

mean, std = get_mean_std(x_train)
x_train = normalize(x_train, mean, std)
x_test = normalize(x_test, mean, std)

In [6]:
perm_idx = np.random.permutation(x_train.shape[0])
val_idx = perm_idx[:int(val_frac*x_train.shape[0])]
train_idx = perm_idx[int(val_frac*x_train.shape[0]):]
x_val, y_val = x_train[val_idx], y_train[val_idx]
x_train, y_train = x_train[train_idx], y_train[train_idx]

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

In [17]:
def fit_gmm(losses):    
    losses = losses.numpy()
    losses = losses-losses.min()/(losses.max()-losses.min())
    input_loss = np.reshape(losses, [-1, 1])

    gmm = GaussianMixture(n_components=2, max_iter=50, tol=1e-2, reg_covar=5e-4)
    gmm.fit(input_loss)
    prob = gmm.predict_proba(input_loss)
    prob = prob[:, gmm.means_.argmin()]
    return prob

@tf.function
def test_step(x_batch, y_batch, loss_fn):    
    y_pred = model(x_batch)
    loss = loss_fn(y_pred, y_batch)
    y_pred_single = tf.math.argmax(model(x_batch, training=False), axis=1)
    y_batch_single = tf.math.argmax(y_batch, axis = 1)
    acc_metric.update_state(y_pred_single, y_batch_single)
    return loss

# def eval(model, val_ds, acc_metric, batch_size, losses):
def eval(model, val_ds, acc_metric, batch_size, len_val_ds):
    val_loss = 0
    losses = tf.Variable(tf.zeros(len_val_ds))
    loss_fn = tf.keras.losses.CategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    for idx, (x_batch, y_batch) in enumerate(val_ds):   
        loss = test_step(x_batch, y_batch, loss_fn)
        val_loss += tf.math.reduce_mean(loss)
        for b in tf.range(x_batch.shape[0]):
            # i = tf.cast(batch_size*idx, tf.int32)+b
            i = batch_size*idx+b
            # if i < losses.shape[0]:   
            losses[i].assign(tf.cast(loss[b], tf.float32)) 
    val_acc = acc_metric.result()
    acc_metric.reset_states()

    # fit loss to Gaussian mixture model
    # prob = tf.py_function(func=fit_gmm, inp=[losses], Tout=tf.float32)
    prob = fit_gmm(losses)
    
    return val_loss, val_acc, prob 

In [18]:
model = Sequential([
    Input(shape=(28, 28, 1)),
    Conv2D(10, kernel_size = 5, padding='same'),
    MaxPooling2D(pool_size = 2),
    Conv2D(20, kernel_size = 3, padding = 'same'),      
    Dropout(0.5),
    MaxPooling2D(pool_size = 2),
    Flatten(),
    Dense(30, activation = 'relu'),
    Dropout(0.5),
    Dense(10, activation = 'softmax')    
])

In [19]:
batch_size = 32
acc_metric = tf.keras.metrics.CategoricalAccuracy()
len_val_ds = len(y_val)

In [20]:
tf.keras.backend.clear_session()

model.compile(optimizer='adam', 
                loss = 'categorical_crossentropy',
                metrics = ['accuracy'])
for epoch in range(num_epochs):
    history = model.fit(train_ds, epochs=1)    
    # losses = tf.Variable(tf.zeros(len_val_ds))
    # losses = tf.zeros(len_val_ds)
    # val_loss, val_acc, _ = eval(model, val_ds, acc_metric, batch_size, losses)
    val_loss, val_acc, _ = eval(model, val_ds, acc_metric, batch_size, len_val_ds)
    print(f"Epoch {epoch+1} : val_loss : {val_loss}, val_acc : {val_acc}")

Epoch 1 : val_loss : 221.46338953326145, val_acc : 0.6315789222717285
Epoch 2 : val_loss : 168.5628434934964, val_acc : 0.6842105388641357
Epoch 3 : val_loss : 128.8714358964935, val_acc : 0.6315789222717285
Epoch 4 : val_loss : 99.69262322168409, val_acc : 0.7894737124443054
Epoch 5 : val_loss : 90.9938499895264, val_acc : 0.7894737124443054
Epoch 6 : val_loss : 71.19128402862286, val_acc : 0.7894737124443054
Epoch 7 : val_loss : 64.81008659924798, val_acc : 0.7368420958518982
Epoch 8 : val_loss : 54.75731562107679, val_acc : 0.7368420958518982
Epoch 9 : val_loss : 51.1404191283976, val_acc : 0.7368420958518982
Epoch 10 : val_loss : 44.49596979530876, val_acc : 0.7368420958518982
Epoch 11 : val_loss : 44.98336947620851, val_acc : 0.7368420958518982
Epoch 12 : val_loss : 37.99436708119962, val_acc : 0.7368420958518982
Epoch 13 : val_loss : 38.889302977943316, val_acc : 0.7368420958518982
Epoch 14 : val_loss : 38.29008540822656, val_acc : 0.7894737124443054
Epoch 15 : val_loss : 32.7265

In [31]:
model.compile(optimizer='adam', 
                loss = 'categorical_crossentropy',
                metrics = ['accuracy'])
start_time = time.time()
history = model.fit(train_ds, validation_data = val_ds, epochs=50)
print(f"Training time : {(time.time()-start_time)/60:.2f} min.")


Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
Training time : 0.25 min.
