In [15]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import Sequence
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import Callback

# Libraries for VGG 
from tensorflow.keras.layers import Dense, Conv2D
from tensorflow.keras.layers import BatchNormalization, Activation
from tensorflow.keras.layers import MaxPooling2D, Input
from tensorflow.keras.layers import Flatten, AveragePooling2D
from tensorflow.keras.models import Model

import numpy as np
import os
import scipy.io as io
from skimage.transform import resize, rotate
from scipy.optimize import linear_sum_assignment

In [16]:
# VGG layer configurations, keyed by variant letter.
#   'A', 'B', 'D', 'E' : standard VGG backbones
#   'F'                : customized for IIC
#   'G'                : experimental
# Each list is consumed left to right by VGG.make_layers:
#   int  -> 3x3 Conv2D with that many filters (+ optional batch norm, then ReLU)
#   'M'  -> MaxPooling2D
#   'A'  -> AveragePooling2D with pool_size=3
cfg = dict(
    A=[64, 'M',
       128, 'M',
       256, 256, 'M',
       512, 512, 'M',
       512, 512, 'M'],
    B=[64, 64, 'M',
       128, 128, 'M',
       256, 256, 'M',
       512, 512, 'M',
       512, 512, 'M'],
    D=[64, 64, 'M',
       128, 128, 'M',
       256, 256, 256, 'M',
       512, 512, 512, 'M',
       512, 512, 512, 'M'],
    E=[64, 64, 'M',
       128, 128, 'M',
       256, 256, 256, 256, 'M',
       512, 512, 512, 512, 'M',
       512, 512, 512, 512, 'M'],
    F=[64, 'M', 128, 'M', 256, 'M', 512],
    G=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'A'],
)

class VGG:
    """Creates a VGG-style convolutional network used as a backbone
        feature extractor. The Keras model is built on construction
        and exposed through the ``model`` property.
    """

    def __init__(self, cfg, input_shape=(28, 28, 1)):
        """Store the configuration and build the network right away

        Arguments:
            cfg (sequence): Layer specification read in order; each
                entry is a filter count (int), 'M' or 'A'
                (see make_layers)
            input_shape (tuple or list): Input image dims, passed
                as the ``shape`` of the Keras Input layer
        """
        self.cfg = cfg
        self.input_shape = input_shape
        self._model = None
        self.build_model()

    def build_model(self):
        """Create the Input layer named 'x', stack the configured
            layers on top of it with make_layers, and wrap the
            result in a Keras Model named 'VGG'
        """
        image_in = Input(shape=self.input_shape, name='x')
        features = VGG.make_layers(self.cfg, image_in)
        self._model = Model(image_in, features, name='VGG')

    @property
    def model(self):
        """The Keras Model created by build_model"""
        return self._model

    @staticmethod
    def make_layers(cfg,
                    inputs,
                    batch_norm=True,
                    in_channels=1):
        """Stack the layers described by cfg on top of inputs

        Arguments:
            cfg (sequence): Layer specification read in order:
                'M' adds MaxPooling2D, 'A' adds AveragePooling2D
                with pool_size=3, anything else is used as the
                filter count of a 3x3 'same'-padded Conv2D
            inputs (tensor): Input from previous layer
            batch_norm (Bool): Whether to use batch norm
                between Conv2D and ReLU
            in_channels (int): Number of input channels
                (not used by the current implementation)

        Returns:
            tensor: Output of the last layer that was added
        """
        tensor = inputs
        for spec in cfg:
            # pooling tokens add one layer and move on
            if spec == 'M':
                tensor = MaxPooling2D()(tensor)
                continue
            if spec == 'A':
                tensor = AveragePooling2D(pool_size=3)(tensor)
                continue

            # everything else is a conv block: Conv2D -> [BN] -> ReLU
            conv = Conv2D(spec,
                          kernel_size=3,
                          padding='same',
                          kernel_initializer='he_normal')
            tensor = conv(tensor)
            if batch_norm:
                tensor = BatchNormalization()(tensor)
            tensor = Activation('relu')(tensor)

        return tensor

In [17]:
# Build the backbone feature extractor from the 'F' configuration (the variant
# customized for IIC), using VGG's default input shape of (28, 28, 1).
backbone = VGG(cfg['F'])
# Print the layer-by-layer summary of the backbone model.
backbone.model.summary()

Model: "VGG"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
x (InputLayer)               [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 28, 28, 64)        640       
_________________________________________________________________
batch_normalization_4 (Batch (None, 28, 28, 64)        256       
_________________________________________________________________
activation_4 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 14, 14, 128)       73856     
_________________________________________________________________
batch_normalization_5 (Batch (None, 14, 14, 128)       512     

In [19]:
# Pre-load test data for evaluation.
# NOTE: `dataset` is expected to be defined in an earlier cell (not shown here).
(_, _), (x_test, y_test) = dataset.load_data()
# Add a trailing channel axis: (N, H, W) -> (N, H, W, 1).
x_test = np.reshape(x_test, [-1, x_test.shape[1], x_test.shape[2], 1])
# Convert to float32 and divide by 255.
x_test = x_test.astype('float32') / 255
# float64 copy of the test images. Its shape is taken from x_test itself so this
# cell does not depend on `train_gen`, which is only created in a later cell and
# is therefore undefined here when the notebook is run top to bottom.
x_eval = x_test.astype('float64')
# Optional: evaluate on a subset only
#x_test = x_eval[0:2000]
#y_test = y_test[0:2000]

In [20]:
class DataGenerator(Sequence):
    """Keras Sequence that serves batches of paired images for IIC training.

    Every batch stacks the original images X on top of transformed copies
    Xbar = G(X), where G is a random crop followed by a resize back to the
    original shape (a transformation meant to leave the image class unchanged).

    The class reads the module-level names ``dataset`` (for loading the
    images) and ``batch_size`` (for batching).
    """

    def __init__(self, shuffle=True):
        """Load the training images, add a channel axis and divide by 255.

        Arguments:
            shuffle (Bool): Whether to shuffle the sample order, once here
                and again after every epoch
        """
        self.shuffle = shuffle
        (raw_images, _), (_, _) = dataset.load_data()
        #raw_images = raw_images[0:2000]
        n_samples, height, width = raw_images.shape[0], raw_images.shape[1], raw_images.shape[2]
        self.n_channels = 1
        self.n_labels = 10
        self.input_shape = [height, width, self.n_channels]
        # one index per sample; this list is what gets shuffled
        self.indexes = list(range(n_samples))
        # reshape to (N, H, W, C), convert to float32 and divide by 255
        with_channel = np.reshape(raw_images, [-1, height, width, self.n_channels])
        self.data = with_channel.astype('float32') / 255
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches per epoch (a trailing partial batch is dropped)."""
        full_batches = np.floor(len(self.indexes) / batch_size)
        return int(full_batches)

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(x_train, y)``."""
        first = index * batch_size
        return self.__data_generation(first, first + batch_size)

    def on_epoch_end(self):
        """Shuffle the sample order in place when shuffling is enabled."""
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def random_crop(self, image, target_shape, crop_sizes):
        """Randomly crop an image, then resize it to target_shape.

        Arguments:
            image (array): Image to crop and resize, indexed (H, W, C)
            target_shape (tuple or list): Output shape given to resize
            crop_sizes (list): Candidate amounts d by which height and
                width are reduced; one is picked at random

        Returns:
            array: The cropped image resized to target_shape
        """
        height, width = image.shape[0], image.shape[1]
        d = crop_sizes[np.random.randint(0, len(crop_sizes))]
        # coin flip: centered crop, or a random offset in [0, d] per axis
        if np.random.randint(0, 2):
            dx = dy = d // 2
        else:
            dx = np.random.randint(0, d + 1)
            dy = np.random.randint(0, d + 1)
        cropped = image[dx:(height - d + dx), dy:(width - d + dy), :]
        return resize(cropped, target_shape)

    def __data_generation(self, start_index, end_index):
        """Build one batch of paired images.

        Arguments:
            start_index (int): Position in self.indexes of the first sample
            end_index (int): Position in self.indexes one past the last sample

        Returns:
            tuple: ``(x_train, y)`` where x_train holds the originals X
                followed by the transformed images Xbar along axis 0, and y
                is an all-zero array with one entry per row of x_train
        """
        batch = self.data[self.indexes[start_index:end_index]]
        target_shape = (batch.shape[0], *self.input_shape)
        originals = np.zeros(target_shape)
        transformed = np.zeros(target_shape)
        # candidate crop amounts: 8, 10, 12
        crop_sizes = [8 + offset for offset in range(0, 5, 2)]
        for k, image in enumerate(batch):
            originals[k] = image
            transformed[k] = self.random_crop(image, target_shape[1:], crop_sizes)
        # for IIC, we are mostly interested in paired images X and Xbar = G(X)
        x_train = np.concatenate([originals, transformed], axis=0)

        y = np.zeros(len(x_train))
        return x_train, y

In [21]:
def mi_loss(y_true, y_pred):
    """Mutual information loss computed from the joint distribution
        matrix and the marginals

    Reads the module-level names ``batch_size`` (number of rows that
    belong to Z) and ``heads`` (divisor applied to the result).

    Arguments:
        y_true (tensor): Not used since this is unsupervised learning
        y_pred (tensor): Stack of softmax predictions; the first
            batch_size rows are Z, the remaining rows are Zbar

    Returns:
        tensor: Negative mutual information divided by ``heads``
    """
    size = batch_size
    n_clusters = y_pred.shape[-1]

    # first `size` rows are Z, shaped (size, n, 1)
    z = K.expand_dims(y_pred[0: size, :], axis=2)
    # remaining rows are Zbar, shaped (rows, 1, n)
    zbar = K.expand_dims(y_pred[size: y_pred.shape[0], :], axis=1)

    # joint distribution (Eq 10.3.2 & .3): per-sample outer products,
    # summed over the batch
    joint = K.sum(K.batch_dot(z, zbar), axis=0)
    # enforce symmetric joint distribution (Eq 10.3.4)
    joint = (joint + K.transpose(joint)) / 2.0
    # normalization of total probability to 1.0
    joint = joint / K.sum(joint)

    # marginal distributions (Eq 10.3.5 & .6), tiled to the joint's shape
    row_marginal = K.repeat_elements(
        K.expand_dims(K.sum(joint, axis=1), axis=1), rep=n_clusters, axis=1)
    col_marginal = K.repeat_elements(
        K.expand_dims(K.sum(joint, axis=0), axis=0), rep=n_clusters, axis=0)

    # keep every probability at or above epsilon before taking logs
    upper = np.finfo(float).max
    joint = K.clip(joint, K.epsilon(), upper)
    row_marginal = K.clip(row_marginal, K.epsilon(), upper)
    col_marginal = K.clip(col_marginal, K.epsilon(), upper)

    # negative MI loss (Eq 10.3.7)
    log_ratio = K.log(row_marginal) + K.log(col_marginal) - K.log(joint)
    neg_mi = K.sum((joint * log_ratio))
    # each head contribute 1/n_heads to the total loss
    return neg_mi / heads

# Create the training data generator and report the shape of its image array.
train_gen = DataGenerator(shuffle=True)
print(train_gen.data.shape)
# Number of output units per head, taken from the generator.
n_labels = train_gen.n_labels
# Build the IIC encoder: input -> VGG backbone -> flatten -> `heads` softmax heads.
inputs = Input(shape=train_gen.input_shape, name='x')
my_backbone = backbone.model
x = my_backbone(inputs)
x = Flatten()(x)
# One Dense softmax output per head, named z_head0, z_head1, ...
outputs = []
for i in range(heads):
    name = "z_head%d" % i
    outputs.append(Dense(n_labels,
                         activation='softmax',
                         name=name)(x))
my_model = Model(inputs, outputs, name='encoder')
# Train with Adam at learning rate 1e-3 and the mutual-information loss.
my_model.compile(optimizer=Adam(learning_rate=1e-3), loss=mi_loss)
my_model.summary()

(60000, 28, 28, 1)
Model: "encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
x (InputLayer)               [(None, 28, 28, 1)]       0         
_________________________________________________________________
VGG (Functional)             (None, 3, 3, 512)         1553664   
_________________________________________________________________
flatten_2 (Flatten)          (None, 4608)              0         
_________________________________________________________________
z_head0 (Dense)              (None, 10)                46090     
Total params: 1,599,754
Trainable params: 1,597,834
Non-trainable params: 1,920
_________________________________________________________________


In [22]:
# Module-level placeholder. unsupervised_labels() and
# AccuracyCallback.on_epoch_end() below each assign their own local `accuracy`,
# so none of the code shown in this notebook reads this value.
accuracy = 0

def lr_schedule(epoch):
    """Step-decay learning rate schedule.

    Starts at 1e-3 and multiplies by 0.8 once for every 400 completed epochs.

    Arguments:
        epoch (int): Which epoch

    Returns:
        float: Learning rate to use for that epoch
    """
    base_lr = 1e-3
    n_decays = epoch // 400
    return base_lr * 0.8 ** n_decays

def unsupervised_labels(y, yp, n_classes, n_clusters):
    """Clustering accuracy under the best one-to-one cluster/class mapping.

    Counts how often each predicted cluster co-occurs with each ground truth
    class, finds the cluster-to-class assignment that maximizes the number of
    matched samples (linear assignment), and reports the matched fraction.

    Arguments:
        y (array-like): Ground truth labels, integers in [0, n_classes)
        yp (array-like): Predicted clusters, integers in [0, n_clusters)
        n_classes (int): Number of classes
        n_clusters (int): Number of clusters; must equal n_classes

    Returns:
        float: Accuracy in percent; 0.0 when there are no samples

    Raises:
        ValueError: If y and yp do not have the same number of entries
    """
    assert n_classes == n_clusters
    labels = np.asarray(y).astype(int).ravel()
    clusters = np.asarray(yp).astype(int).ravel()
    if labels.shape[0] != clusters.shape[0]:
        raise ValueError(
            "y and yp must have the same length, got %d and %d"
            % (labels.shape[0], clusters.shape[0]))
    if labels.shape[0] == 0:
        # no samples: avoid dividing 0 by 0 below
        return 0.0

    # count matrix: C[cluster, class] = number of samples with that pair.
    # np.add.at accumulates repeated (cluster, class) pairs correctly.
    C = np.zeros([n_clusters, n_classes])
    np.add.at(C, (clusters, labels), 1)

    # optimal permutation using the Hungarian algorithm: the higher the
    # count, the lower the cost, so -C is used for linear assignment
    row, col = linear_sum_assignment(-C)
    accuracy = C[row, col].sum() / C.sum()
    return accuracy * 100

class AccuracyCallback(Callback):
    """Callback that reports the clustering accuracy of every head after each epoch.

    Reads the module-level names ``my_model``, ``x_test``, ``y_test``,
    ``heads`` and ``n_labels``. The best accuracy seen so far, over all heads
    and epochs, is kept in ``general_accuracy``; nothing is written to disk.
    """

    def __init__(self):
        super().__init__()
        # best accuracy (in percent) observed so far; 0 means "none yet"
        self.general_accuracy = 0

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate the accuracy of the current model weights and print it per head.

        Arguments:
            epoch (int): Index of the epoch that just ended (not used)
            logs (dict): Metrics passed by Keras (not used)
        """
        y_pred = my_model.predict(x_test)
        for head in range(heads):
            # with a single head, predict() output is used as is instead of indexed per head
            head_probs = y_pred if heads == 1 else y_pred[head]
            clusters = np.argmax(head_probs, axis=1)

            accuracy = unsupervised_labels(list(y_test), list(clusters), n_labels, n_labels)

            info = "Head %d accuracy: %0.2f%%"
            best_so_far = self.general_accuracy
            if best_so_far > 0:
                print((info + ", Old best accuracy: %0.2f%%") % (head, accuracy, best_so_far))
            else:
                print(info % (head, accuracy))

            # remember the new best accuracy
            if accuracy > best_so_far:
                self.general_accuracy = accuracy
                    
# Train the encoder using the data generator, the per-epoch accuracy callback,
# and the learning rate scheduler callback.
# Model.fit accepts a keras.utils.Sequence directly; Model.fit_generator is
# deprecated in TF 2.x and simply forwards to fit with these same arguments.
my_model.fit(train_gen,
             epochs=epochs,
             callbacks=[AccuracyCallback(),
                        LearningRateScheduler(lr_schedule, verbose=1)],
             workers=4,
             use_multiprocessing=False,
             shuffle=True)

Epoch 1/10

Epoch 00001: LearningRateScheduler reducing learning rate to 0.001.
Head 0 accuracy: 11.87%
Epoch 2/10

Epoch 00002: LearningRateScheduler reducing learning rate to 0.001.
Head 0 accuracy: 17.40%, Old best accuracy: 11.87%
Epoch 3/10

Epoch 00003: LearningRateScheduler reducing learning rate to 0.001.
Head 0 accuracy: 11.35%, Old best accuracy: 17.40%
Epoch 4/10

Epoch 00004: LearningRateScheduler reducing learning rate to 0.001.

KeyboardInterrupt: 