In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from scipy.optimize import linear_sum_assignment
from sklearn import metrics
from sklearn.cluster import KMeans
from tensorflow.keras.datasets import mnist
from tensorflow.keras.initializers import VarianceScaling

try:  # removed in scikit-learn 0.23; only available on old installs
    from sklearn.utils.linear_assignment_ import linear_assignment
except ImportError:
    linear_assignment = None

from autoencoder import autoencoder
# Alias of the builder: the model-construction cell rebinds `autoencoder` to the
# built model, so the builder must stay reachable under a separate name.
from autoencoder import autoencoder as build_autoencoder
from clusteringlayer import ClusteringLayer

In [None]:
# Enlarge all seaborn/matplotlib text so the annotated heatmap stays legible.
sns.set(font_scale=3)

In [None]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Clustering is unsupervised, so pool both splits; the labels `y` are only
# used afterwards to score the clustering.
x = np.concatenate([x_train, x_test])
y = np.concatenate([y_train, y_test])

# One flat vector per image, pixel intensities rescaled from 0..255 to 0..1.
x = x.reshape(len(x), -1) / 255.

In [None]:
# Layer sizes passed to `autoencoder(...)`: flattened input width first, then
# (presumably) the encoder widths ending in the embedding size -- confirm in autoencoder.py.
dims = [x.shape[-1], 500, 500, 2000, 10]
init = VarianceScaling(scale=1. / 3., mode='fan_in', distribution='uniform')
# `lr` is a deprecated alias that newer TF/Keras optimizers reject outright;
# `learning_rate` is the supported keyword.
pretrain_optimizer = tf.keras.optimizers.SGD(learning_rate=1, momentum=0.9)
pretrain_epochs = 30
batch_size = 256

# Seed the global RNGs so weight initialisation (TensorFlow) and KMeans, which
# falls back to NumPy's global RNG when no random_state is given, are reproducible.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Call the builder through its alias: this assignment rebinds the name
# `autoencoder` to the returned model, so calling `autoencoder(...)` here would
# break on any re-run of the cell (it would invoke the model, not the builder).
autoencoder, encoder = build_autoencoder(dims, init=init)

In [None]:
# Diagram of the encoder returned by the builder, with tensor shapes.
tf.keras.utils.plot_model(
    encoder,
    show_shapes=True,
    dpi=64,
)

In [None]:
# Diagram of the full autoencoder, with tensor shapes.
tf.keras.utils.plot_model(
    autoencoder,
    show_shapes=True,
    dpi=64,
)

In [None]:
# Pretraining: the autoencoder learns to reproduce its input (targets are the
# inputs themselves) under a mean-squared-error loss.
autoencoder.compile(loss='mse', optimizer=pretrain_optimizer)
autoencoder.fit(
    x,
    x,
    epochs=pretrain_epochs,
    batch_size=batch_size,
    verbose=2,
)

In [None]:
# Stack a 10-cluster ClusteringLayer on the encoder output; the resulting model's
# predictions are used below as the soft assignments q.
clustering_layer = ClusteringLayer(10, name='clustering')(encoder.output)
model = tf.keras.models.Model(
    inputs=encoder.input,
    outputs=clustering_layer,
)

In [None]:
# Diagram of encoder + clustering head, with tensor shapes.
tf.keras.utils.plot_model(
    model,
    show_shapes=True,
    dpi=64,
)

In [None]:
# The clustering model is trained to match the target distribution p under
# KL divergence, with SGD (learning rate 0.01, momentum 0.9).
model.compile(
    loss='kld',
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),
)

In [None]:
# Initial cluster centres: k-means (10 clusters, best of 20 restarts) on the
# pretrained encoder's output for every sample.
embedded = encoder.predict(x)
kmeans = KMeans(n_init=20, n_clusters=10)
y_pred = kmeans.fit_predict(embedded)

In [None]:
# Snapshot of the k-means labels; the training loop compares against it to
# measure how many assignments changed.
y_pred_last = y_pred.copy()

In [None]:
# Copy the k-means centroids into the clustering layer as its starting weights.
centroids = [kmeans.cluster_centers_]
model.get_layer(name='clustering').set_weights(centroids)

In [None]:
# Auxiliary target distribution p: square each soft assignment q_ij, divide by
# that cluster's total soft frequency sum_i q_ij, then renormalise every row to 1.
def target_distribution(q):
    """Return the auxiliary target distribution for soft assignments ``q``.

    ``q`` is a 2-D array (rows = samples, columns = clusters). Each entry is
    squared and divided by its column sum; each row of that result is then
    scaled to sum to 1. Squaring emphasises the larger entries of each row.
    """
    weight = q ** 2 / q.sum(axis=0)
    return weight / weight.sum(axis=1, keepdims=True)

In [None]:
# Settings and mutable state for the self-training loop.
maxiter = 6000                        # upper bound on mini-batch updates
update_interval = 140                 # recompute q/p and check convergence this often
index_array = np.arange(x.shape[0])   # sample order used to slice mini-batches
index = 0                             # current mini-batch number
loss = 0                              # most recent training loss (printed by the loop)

In [None]:
# Convergence threshold: stop once under 0.1% of labels change between checks.
tol = 1e-3

In [None]:
# Self-training: every `update_interval` steps recompute the soft assignments q
# and the target p, report metrics, and stop once hard labels stabilise;
# every step trains on the next mini-batch towards p.
for ite in range(int(maxiter)):
    if ite % update_interval == 0:
        q = model.predict(x, verbose=0)
        p = target_distribution(q)  # update the auxiliary target distribution p

        # evaluate the clustering performance
        y_pred = q.argmax(1)
        if y is not None:
            # Cluster ids are an arbitrary permutation of the classes, so raw
            # accuracy_score is meaningless; match clusters to classes with the
            # Hungarian algorithm and count the matched samples instead.
            cm = metrics.confusion_matrix(y, y_pred)
            row_ind, col_ind = linear_sum_assignment(-cm)
            acc = np.round(cm[row_ind, col_ind].sum() / y.size, 5)
            nmi = np.round(metrics.normalized_mutual_info_score(y, y_pred), 5)
            ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5)
            loss = np.round(loss, 5)
            print('Iter %d: acc = %.5f, nmi = %.5f, ari = %.5f' % (ite, acc, nmi, ari), ' ; loss=', loss)

        # check stop criterion - fraction of samples whose hard label changed
        delta_label = np.sum(y_pred != y_pred_last).astype(np.float32) / y_pred.shape[0]
        y_pred_last = np.copy(y_pred)
        if ite > 0 and delta_label < tol:
            print('delta_label ', delta_label, '< tol ', tol)
            print('Reached tolerance threshold. Stopping training.')
            break
    idx = index_array[index * batch_size: min((index + 1) * batch_size, x.shape[0])]
    loss = model.train_on_batch(x=x[idx], y=p[idx])
    # Strict `<`: with `<=`, a sample count that is an exact multiple of
    # batch_size would advance to an index whose slice is empty.
    index = index + 1 if (index + 1) * batch_size < x.shape[0] else 0

In [None]:
# Final evaluation of the trained clustering model on every sample.
q = model.predict(x, verbose=0)
y_pred = q.argmax(1)  # hard label = column of q with the largest value

if y is not None:
    # Cluster ids are an arbitrary permutation of the classes, so score with
    # Hungarian-matched accuracy rather than raw accuracy_score.
    cm = metrics.confusion_matrix(y, y_pred)
    row_ind, col_ind = linear_sum_assignment(-cm)
    acc = np.round(cm[row_ind, col_ind].sum() / y.size, 5)
    nmi = np.round(metrics.normalized_mutual_info_score(y, y_pred), 5)
    ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5)
    loss = np.round(loss, 5)
    print('Acc = %.5f, nmi = %.5f, ari = %.5f' % (acc, nmi, ari), ' ; loss=', loss)

In [None]:
# Rows: true digit; columns: raw (unmatched) cluster id.
confusion_matrix = metrics.confusion_matrix(y, y_pred)

fig, ax = plt.subplots(figsize=(16, 14))
sns.heatmap(confusion_matrix, annot=True, fmt="d", annot_kws={"size": 20}, ax=ax)
ax.set_title("Confusion matrix", fontsize=30)
ax.set_ylabel('True label', fontsize=25)
ax.set_xlabel('Clustering label', fontsize=25)
plt.show()

In [None]:
# Linear assignment (Hungarian / Munkres): match each cluster id to one true
# class so that the number of agreeing samples is maximal, then report that
# fraction as the clustering accuracy. `sklearn.utils.linear_assignment_` was
# removed in scikit-learn 0.23, so scipy's linear_sum_assignment is used.
y_true = y.astype(np.int64)
D = max(y_pred.max(), y_true.max()) + 1
# w[i, j] = number of samples put in cluster i whose true label is j
# (rows: clusters, columns: true classes); np.add.at handles repeated indices.
w = np.zeros((D, D), dtype=np.int64)
np.add.at(w, (y_pred, y_true), 1)
# Maximise matches by minimising the negated counts; stack the result into a
# (k, 2) array of (cluster, class) pairs, the layout the old sklearn helper returned.
ind = np.asarray(linear_sum_assignment(-w)).T

float(w[ind[:, 0], ind[:, 1]].sum()) / y_pred.size

In [None]:
# Cluster-by-class count matrix: w[i, j] = samples in cluster i with true label j.
w

In [None]:
# Optimal (cluster, true class) pairs found by the linear assignment.
ind

In [None]:
# For each cluster (row of w), the true class it contains most often; unlike
# the assignment above, two clusters may map to the same class here.
w.argmax(axis=1)