## **Setup**

In [1]:
# Import the necessary packages
import os
import gc
import numpy as np

import matplotlib.pyplot as plt

from sklearn.cluster import KMeans
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

import tensorflow as tf
import tensorflow_addons as tfa

In [2]:
# Hyper-parameters for the whole notebook, kept in one dict so they can be
# tuned in one place.
config = {
    # --- Run metadata ---
    "AUTHOR": "Kiernan",

    # --- Data ---
    "IMAGE_SIZE": (28, 28, 1),

    # --- Training ---
    "LR_STYLE": "REDUCE",   # one of 'REDUCE' | 'SCHEDULE'
    "LR": 0.001,            # alternative value: 0.000001
    "BATCH_SIZE": 64,
    "EPOCHS": 30,

    # --- Loss ---
    "MARGIN": 0.5,

    # --- Model ---
    # Disabled options:
    #   "RUN_FOR_BASE": "3cmjv1lo"
    #   "FREEZE": "ALL"  -- which body layers to freeze: 'ALL' | 'BN' | 'None'
    "FIRST_FILTERS": 16,
    "CONV_LAYERS": 4,
    "N_FILTERS": 8,
    "KERNEL_SIZE": (3, 3),
    "EMBEDDING_SIZE": 16,
    "VECTOR_SIZE": 16,
    "DROPOUT": 0.2,
}

## **Initialize WANDB**

In [3]:
import wandb
from wandb.keras import WandbCallback
from secrets import WANDB
# Authenticate with the key imported above, then open a run that records
# `config` as its hyper-parameters.
# NOTE(review): `secrets` is also the name of a standard-library module, so
# `from secrets import WANDB` presumably relies on a local secrets.py being
# found first -- confirm, or rename the local module.
wandb.login(key=WANDB)
run = wandb.init(
    project="deep-clustering-evaluation",
    entity="kmcguigan",
    group="arcface-model",
    config=config,
    job_type="train",
)

[34m[1mwandb[0m: Currently logged in as: [33mall-off-nothing[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\kiern/.netrc
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


## **Loading Data**

### **Load the presplit data**

In [4]:
def load_split(path):
    """Read one pre-split data file and return ``(X, y)``.

    Each file holds two arrays saved back-to-back in the same stream: the
    first ``np.load`` call returns the images, the second the labels.

    Args:
        path: location of the ``.npy`` file to read.

    Returns:
        Tuple ``(X, y)`` in the order the arrays are stored in the file.
    """
    with open(path, mode='rb') as infile:
        # NOTE(security): allow_pickle=True lets np.load unpickle object
        # arrays, which can execute arbitrary code. Only load trusted files;
        # switch to allow_pickle=False if the arrays are plain numeric.
        X = np.load(infile, allow_pickle=True)
        y = np.load(infile, allow_pickle=True)
    return X, y


X_train, y_train = load_split('data/train.npy')
X_val, y_val = load_split('data/val.npy')
X_test, y_test = load_split('data/test.npy')

print(f"Train data shape: {X_train.shape} Val data shape: {X_val.shape} Test data shape: {X_test.shape}")

Train data shape: (50000, 28, 28, 1) Val data shape: (10000, 28, 28, 1) Test data shape: (10000, 28, 28, 1)


### **Create a data generator**

In [20]:
def to_dataset(X, y):
    """Wrap arrays in a cached, shuffled, batched and prefetched ``tf.data`` pipeline.

    Each element is ``({"images": ..., "labels": ...}, label)``: the label is
    supplied both as a named model input and as the training target.

    Args:
        X: image array, sliced along its first axis.
        y: label array, sliced along its first axis.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``config["BATCH_SIZE"]``.
    """
    features = {"images": X, "labels": y}
    pipeline = tf.data.Dataset.from_tensor_slices((features, y))
    pipeline = pipeline.cache()
    # Shuffle buffer is larger than the dataset, i.e. a full shuffle.
    pipeline = pipeline.shuffle(X.shape[0] + 1)
    pipeline = pipeline.batch(config["BATCH_SIZE"])
    return pipeline.prefetch(tf.data.experimental.AUTOTUNE)

# Build one pipeline per split. to_dataset shuffles every split, so the
# validation and test batches are shuffled as well as the training ones.
train_ds = to_dataset(X_train, y_train)
val_ds = to_dataset(X_val, y_val)
test_ds = to_dataset(X_test, y_test)

## **Define Metrics**

In [6]:
def pairwise_distance(embeddings, squared=False):
    """Return the matrix of Euclidean distances between all rows of ``embeddings``.

    Uses ``||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2`` on the Gram matrix.

    Args:
        embeddings: 2-D tensor with one embedding per row.
        squared: when True, return squared distances and skip the square root.

    Returns:
        Square tensor whose ``[i, j]`` entry is the distance between rows i and j.
    """
    gram = tf.matmul(embeddings, tf.transpose(embeddings))
    sq_norms = tf.linalg.diag_part(gram)
    sq_dists = tf.expand_dims(sq_norms, 1) - 2.0 * gram + tf.expand_dims(sq_norms, 0)
    # Rounding can leave tiny negative values; clamp them to zero.
    sq_dists = tf.maximum(sq_dists, 0.0)
    if squared:
        return sq_dists

    # Add a tiny epsilon where the distance is exactly zero before the square
    # root, then zero those entries again afterwards.
    zero_mask = tf.cast(tf.equal(sq_dists, 0.0), tf.float32)
    dists = tf.sqrt(sq_dists + zero_mask * 1e-16)
    return dists * (1.0 - zero_mask)

def angular_distances(embeddings):
    """Return the pairwise cosine distance ``1 - cos(a, b)`` between rows.

    Rows are L2-normalised first; the result is clamped at zero and its
    diagonal is forced to exactly zero.

    Args:
        embeddings: 2-D tensor with one embedding per row.

    Returns:
        Square tensor of cosine distances with a zero diagonal.
    """
    unit_rows = tf.math.l2_normalize(embeddings, axis=-1)
    cos_dist = 1 - tf.matmul(unit_rows, tf.transpose(unit_rows))
    cos_dist = tf.maximum(cos_dist, 0.0)
    n_rows = tf.shape(cos_dist)[0]
    off_diagonal = tf.ones_like(cos_dist) - tf.linalg.diag(tf.ones([n_rows]))
    return tf.math.multiply(cos_dist, off_diagonal)

def apply_metric(embeddings, labels, metric):
    """Average a pairwise distance over same-label and different-label pairs.

    Args:
        embeddings: 2-D tensor with one embedding per row.
        labels: class labels, one per row of ``embeddings``; any shape that
            flattens to one label per sample (``(batch,)`` or ``(batch, 1)``).
        metric: callable mapping ``embeddings`` to a square distance matrix,
            e.g. ``pairwise_distance`` or ``angular_distances``.

    Returns:
        Tuple ``(pos_dist_mean, neg_dist_mean)``: mean distance over pairs that
        share a label (self-pairs excluded) and over pairs that do not.
    """
    # tf.transpose is a no-op on rank-1 tensors, so force a column vector;
    # otherwise the equality below would not broadcast to (batch, batch).
    labels = tf.reshape(labels, (-1, 1))
    adj = tf.equal(labels, tf.transpose(labels))
    adj_not = tf.math.logical_not(adj)
    # Subtract the identity so a sample is never its own positive pair.
    adj = tf.cast(adj, tf.float32) - tf.linalg.diag(tf.ones([tf.shape(labels)[0]]))
    adj_not = tf.cast(adj_not, tf.float32)
    distances = metric(embeddings)
    pos_dist = tf.math.multiply(distances, adj)
    neg_dist = tf.math.multiply(distances, adj_not)
    # Average only over the selected pairs, not over the whole matrix.
    pos_dist_mean = tf.reduce_mean(tf.ragged.boolean_mask(pos_dist, mask=tf.math.equal(adj, 1.0)))
    neg_dist_mean = tf.reduce_mean(tf.ragged.boolean_mask(neg_dist, mask=tf.math.equal(adj_not, 1.0)))
    return pos_dist_mean, neg_dist_mean

In [7]:
def positive_distance(labels, embeddings):
    """Mean Euclidean distance between same-label pairs (self-pairs excluded).

    Delegates to ``apply_metric`` with ``pairwise_distance`` so the pair
    masking lives in one place. The ``(labels, embeddings)`` argument order
    follows the Keras ``(y_true, y_pred)`` metric signature.

    Args:
        labels: class labels, one per row of ``embeddings``.
        embeddings: 2-D tensor with one vector per sample.

    Returns:
        Scalar tensor with the mean positive-pair distance.
    """
    pos_dist_mean, _ = apply_metric(embeddings, labels, pairwise_distance)
    return pos_dist_mean

def negative_distance(labels, embeddings):
    """Mean Euclidean distance between pairs with different labels.

    Delegates to ``apply_metric`` with ``pairwise_distance`` so the pair
    masking lives in one place. The ``(labels, embeddings)`` argument order
    follows the Keras ``(y_true, y_pred)`` metric signature.

    Args:
        labels: class labels, one per row of ``embeddings``.
        embeddings: 2-D tensor with one vector per sample.

    Returns:
        Scalar tensor with the mean negative-pair distance.
    """
    _, neg_dist_mean = apply_metric(embeddings, labels, pairwise_distance)
    return neg_dist_mean

def positive_angular(labels, embeddings):
    """Mean cosine distance between same-label pairs (self-pairs excluded).

    Delegates to ``apply_metric`` with ``angular_distances`` so the pair
    masking lives in one place. The ``(labels, embeddings)`` argument order
    follows the Keras ``(y_true, y_pred)`` metric signature.

    Args:
        labels: class labels, one per row of ``embeddings``.
        embeddings: 2-D tensor with one vector per sample.

    Returns:
        Scalar tensor with the mean positive-pair cosine distance.
    """
    pos_dist_mean, _ = apply_metric(embeddings, labels, angular_distances)
    return pos_dist_mean

def negative_angular(labels, embeddings):
    """Mean cosine distance between pairs with different labels.

    Delegates to ``apply_metric`` with ``angular_distances`` so the pair
    masking lives in one place. The ``(labels, embeddings)`` argument order
    follows the Keras ``(y_true, y_pred)`` metric signature.

    Args:
        labels: class labels, one per row of ``embeddings``.
        embeddings: 2-D tensor with one vector per sample.

    Returns:
        Scalar tensor with the mean negative-pair cosine distance.
    """
    _, neg_dist_mean = apply_metric(embeddings, labels, angular_distances)
    return neg_dist_mean

In [8]:
class MetricHandler:
    """Share one ``apply_metric`` evaluation between two readers.

    Reader 0 receives the positive-pair mean and reader 1 the negative-pair
    mean. Whichever reader calls first triggers the computation and caches
    both values; the other reader then collects its cached value and the
    handler resets for the next pair of reads.

    The two readers must therefore be called alternately. A reader that calls
    twice before the other has collected its value gets an ``Exception``.
    """

    def __init__(self, metric):
        # Which reader has already consumed the current computation.
        self.has_read = {0: False, 1: False}
        # Distance function forwarded to apply_metric.
        self.metric = metric
        # Cached (positive, negative) means for the current computation.
        self.results = {0: None, 1: None}

    def read_metric(self, reader, embeddings, labels):
        """Return the metric value belonging to ``reader``.

        Args:
            reader: 0 for the positive-pair mean, 1 for the negative-pair mean.
            embeddings: forwarded to ``apply_metric`` when a computation is needed.
            labels: forwarded to ``apply_metric`` when a computation is needed.

        Returns:
            The cached or freshly computed value for ``reader``.

        Raises:
            Exception: if ``reader`` asks again before the other reader has
                collected its value.
        """
        if self.has_read[reader]:
            raise Exception(f'{reader} reader re-reading data it already has')
        other = 1 - reader
        if self.has_read[other]:
            # Second read of the pair: hand out the cached value and reset.
            value = self.results[reader]
            self.results[0] = None
            self.results[1] = None
            self.has_read[0] = False
            self.has_read[1] = False
            return value
        # First read of the pair: compute both values and remember who read.
        metric_results = apply_metric(embeddings, labels, self.metric)
        self.results[0] = metric_results[0]
        self.results[1] = metric_results[1]
        self.has_read[reader] = True
        return self.results[reader]
    
# One shared handler per distance function. Reader slot 0 is the positive-pair
# mean and slot 1 the negative-pair mean.
distance_handler = MetricHandler(pairwise_distance)
angular_handler = MetricHandler(angular_distances)


def pos_distance(labels, embeddings):
    """Positive-pair mean from the Euclidean handler (reader 0)."""
    return distance_handler.read_metric(reader=0, embeddings=embeddings, labels=labels)


def neg_distance(labels, embeddings):
    """Negative-pair mean from the Euclidean handler (reader 1)."""
    return distance_handler.read_metric(reader=1, embeddings=embeddings, labels=labels)


def pos_angle(labels, embeddings):
    """Positive-pair mean from the cosine-distance handler (reader 0)."""
    return angular_handler.read_metric(reader=0, embeddings=embeddings, labels=labels)


def neg_angle(labels, embeddings):
    """Negative-pair mean from the cosine-distance handler (reader 1)."""
    return angular_handler.read_metric(reader=1, embeddings=embeddings, labels=labels)

In [9]:
def get_lr_callback(plot=False, batch_size=config['BATCH_SIZE'], epochs=config['EPOCHS']):
    """Build a ramp-up / exponential-decay learning-rate scheduler callback.

    The rate climbs linearly from ``config['LR']`` to
    ``config['LR'] * 5 * batch_size`` over the first 4 epochs, then decays
    towards ``config['LR']`` by a factor of 0.9 per epoch.

    Args:
        plot: when True, scatter-plot the schedule for ``epochs`` epochs.
        batch_size: scales the peak learning rate.
        epochs: number of epochs drawn when ``plot`` is True.

    Returns:
        A ``tf.keras.callbacks.LearningRateScheduler`` wrapping the schedule.
    """
    base_lr = config['LR']
    peak_lr = config['LR'] * 5 * batch_size
    floor_lr = config['LR']
    ramp_epochs = 4
    sustain_epochs = 0
    decay_rate = 0.9

    def lrfn(epoch):
        # Linear warm-up.
        if epoch < ramp_epochs:
            return (peak_lr - base_lr) / ramp_epochs * epoch + base_lr
        # Optional plateau at the peak (zero epochs long with these settings).
        if epoch < ramp_epochs + sustain_epochs:
            return peak_lr
        # Exponential decay towards the floor.
        return (peak_lr - floor_lr) * decay_rate**(epoch - ramp_epochs - sustain_epochs) + floor_lr

    if plot:
        epoch_axis = list(range(epochs))
        schedule = [lrfn(e) for e in epoch_axis]
        plt.scatter(epoch_axis, schedule)
        plt.gca().get_yaxis().get_major_formatter().set_scientific(False)
        plt.show()

    return tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=False)

# Select the learning-rate strategy named in the config.
if config["LR_STYLE"] == "REDUCE":
    # Multiply the rate by 0.9 after 2 epochs without val_loss improvement.
    lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.9,
        patience=2,
    )
elif config["LR_STYLE"] == "SCHEDULE":
    lr_callback = get_lr_callback(plot=True)
else:
    raise Exception(f"config LR_STYLE {config['LR_STYLE']} is not understood")

## **Create Model**

### **Build the body model**

In [10]:
def freeze_all(model):
    """Mark every layer in ``model.layers`` as non-trainable."""
    for current_layer in model.layers:
        current_layer.trainable = False

def freeze_BN(model):
    """Make every layer trainable except BatchNormalization layers, which are frozen."""
    for current_layer in model.layers:
        is_batch_norm = isinstance(current_layer, tf.keras.layers.BatchNormalization)
        current_layer.trainable = not is_batch_norm
            
def freeze_none(model):
    """Mark every layer in ``model.layers`` as trainable."""
    for current_layer in model.layers:
        current_layer.trainable = True

In [11]:
def create_body(image_shape):
    """Build the convolutional feature extractor named ``"body"``.

    A stride-2 conv block is followed by ``config["CONV_LAYERS"]`` further
    conv blocks, a 1x1 convolution to ``config["EMBEDDING_SIZE"]`` channels
    and global average pooling.

    Args:
        image_shape: shape of one input image, without the batch axis.

    Returns:
        A ``tf.keras`` functional model mapping images to pooled feature vectors.
    """
    def conv_block(layer_inputs, n_filters, kernel_size, **kwargs):
        # Conv -> BatchNorm -> ReLU; extra kwargs (e.g. strides) go to Conv2D.
        features = tf.keras.layers.Conv2D(n_filters, kernel_size, padding="same", **kwargs)(layer_inputs)
        features = tf.keras.layers.BatchNormalization()(features)
        return tf.keras.layers.ReLU()(features)

    image_input = tf.keras.layers.Input(shape=image_shape)

    kernel = config["KERNEL_SIZE"]
    features = conv_block(image_input, config["FIRST_FILTERS"], kernel, strides=2)
    for _ in range(config["CONV_LAYERS"]):
        features = conv_block(features, config["N_FILTERS"], kernel)

    features = tf.keras.layers.Conv2D(config["EMBEDDING_SIZE"], (1,1), padding="same")(features)
    pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
    return tf.keras.models.Model(inputs=image_input, outputs=pooled, name="body")

# Build the body for the per-sample image shape (the training array's shape
# without its first axis) and print its layer summary.
body = create_body(X_train.shape[1:])
body.summary()

Model: "body"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 14, 14, 16)        160       
                                                                 
 batch_normalization (BatchN  (None, 14, 14, 16)       64        
 ormalization)                                                   
                                                                 
 re_lu (ReLU)                (None, 14, 14, 16)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 14, 14, 8)         1160      
                                                                 
 batch_normalization_1 (Batc  (None, 14, 14, 8)        32        
 hNormalization)                                              

### **Create the head**

In [12]:
def create_head(input_shape):
    """Build the projection head named ``"head"``.

    Applies dropout, a dense projection to ``config['VECTOR_SIZE']`` units and
    L2 normalisation along the last axis.

    Args:
        input_shape: width of the incoming feature vector (an integer).

    Returns:
        A ``tf.keras`` functional model mapping features to unit-length vectors.
    """
    feature_input = tf.keras.layers.Input(shape=(input_shape,))
    dropped = tf.keras.layers.Dropout(config["DROPOUT"])(feature_input)
    projected = tf.keras.layers.Dense(config['VECTOR_SIZE'])(dropped)
    normalised = tf.keras.layers.Lambda(lambda t: tf.math.l2_normalize(t, axis=-1))(projected)
    return tf.keras.models.Model(inputs=feature_input, outputs=normalised, name="head")

# Build the head with an input width of config['EMBEDDING_SIZE'] and print
# its layer summary.
head = create_head(input_shape=config['EMBEDDING_SIZE'])
head.summary()

Model: "head"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 16)]              0         
                                                                 
 dropout (Dropout)           (None, 16)                0         
                                                                 
 dense (Dense)               (None, 16)                272       
                                                                 
 lambda (Lambda)             (None, 16)                0         
                                                                 
Total params: 272
Trainable params: 272
Non-trainable params: 0
_________________________________________________________________


### **Create the loss function**

In [13]:
# ArcMarginProduct Keras layer (additive angular margin head)
import math
class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.

    The layer holds one weight column per class. It computes the cosine
    between each L2-normalised input row and each L2-normalised class column,
    replaces the target-class cosine ``cos(theta)`` with ``cos(theta + m)``,
    and multiplies every entry by the scale ``s``.

    Args:
        n_classes: number of classes (columns of the weight matrix).
        s: scale applied to every output value.
        m: angular margin, in radians, added to the target-class angle.
        easy_margin: if True, apply the margin only where the cosine is
            positive; otherwise use ``cosine - sin(pi - m) * m`` where the
            cosine is not above ``cos(pi - m)``.
        ls_eps: label-smoothing amount applied to the one-hot targets;
            0 disables smoothing.

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        # Threshold and fallback offset used when easy_margin is False.
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config().copy()
        config.update({
            'n_classes': self.n_classes,
            's': self.s,
            'm': self.m,
            'ls_eps': self.ls_eps,
            'easy_margin': self.easy_margin,
        })
        return config

    def build(self, input_shape):
        """Create the class-weight matrix.

        ``input_shape`` is indexed as a pair whose first entry is the shape of
        the embedding input; its last dimension sets the number of weight rows.
        """
        super(ArcMarginProduct, self).build(input_shape[0])
        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        """Return the scaled, margin-adjusted cosines.

        Args:
            inputs: pair ``(X, y)`` of embeddings and integer-valued class
                labels; ``y`` is cast to int32 and one-hot encoded.

        Returns:
            Tensor of ``s``-scaled cosines with the margin applied to the
            entries selected by the (optionally smoothed) one-hot labels.
        """
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        # Rounding can push |cosine| slightly above 1, making 1 - cos^2
        # negative and the square root NaN; clamp the radicand at zero.
        sine = tf.math.sqrt(tf.maximum(1.0 - tf.math.pow(cosine, 2), 0.0))
        # cos(theta + m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        # Margin-adjusted value for the target class, plain cosine elsewhere.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

### **Create the full model**

In [14]:
def get_model(image_size, nclasses):
    """Assemble the training classifier and the embedding-only model.

    Both models share the module-level ``body`` and ``head`` networks. The
    classifier takes the true labels as a second input because
    ``ArcMarginProduct`` needs them to apply the margin.

    Args:
        image_size: shape of one input image, without the batch axis.
        nclasses: number of classes for the margin layer.

    Returns:
        Tuple ``(model, embedding_model)``: the compiled classifier with
        inputs ``images`` and ``labels`` and a softmax output, and a model
        mapping ``images`` to the head's embeddings.
    """
    inputs = tf.keras.layers.Input(shape=image_size, name="images")
    labels = tf.keras.layers.Input(shape=(), name="labels")
    x = body(inputs)
    embeddings = head(x)
    # Take the margin from the config instead of relying on the layer default,
    # so the logged hyper-parameter is the one actually used.
    x = ArcMarginProduct(nclasses, m=config["MARGIN"])([embeddings, labels])
    outputs = tf.keras.layers.Softmax(dtype='float32')(x)
    model = tf.keras.models.Model(inputs=[inputs, labels], outputs=outputs)
    embedding_model = tf.keras.models.Model(inputs=inputs, outputs=embeddings)

    loss = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam(learning_rate=config['LR'])
    # NOTE(review): Keras passes the model output as the second argument of a
    # compiled metric, so these distances are presumably measured on the
    # softmax output rather than on `embeddings` -- confirm this is intended.
    metrics = [
        positive_distance,
        negative_distance,
        positive_angular,
        negative_angular
    ]
    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
    return model, embedding_model

# Build the compiled classifier and the embedding-only model for 10 classes,
# then print the classifier's layer summary.
model, embedding_model = get_model(config["IMAGE_SIZE"], nclasses=10)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 images (InputLayer)            [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 body (Functional)              (None, 16)           3408        ['images[0][0]']                 
                                                                                                  
 head (Functional)              (None, 16)           272         ['body[0][0]']                   
                                                                                                  
 labels (InputLayer)            [(None,)]            0           []                               
                                                                                              

## **Evaluate Models Initial Performance**

In [15]:
def kmeans_cluster_accuracy(X, y):
    """Score how well K-means clusters of the embeddings line up with the labels.

    Embeds ``X`` with the module-level ``embedding_model``, fits K-means with
    10 clusters, assigns each cluster the most frequent true label among its
    members, prints that mapping and returns the resulting accuracy.

    Args:
        X: images to embed.
        y: true labels as a NumPy array, one per image.

    Returns:
        Accuracy of the cluster-derived labels against ``y``.
    """
    embeddings = embedding_model.predict(X)
    cluster_ids = KMeans(n_clusters=10, random_state=123).fit_predict(embeddings)

    # Majority vote: each cluster takes the most common true label inside it.
    label_mappings = {}
    for cluster in np.unique(cluster_ids):
        member_labels = y[np.where(cluster_ids==cluster)]
        classes, frequencies = np.unique(member_labels, return_counts=True)
        label_mappings[cluster] = classes[np.argmax(frequencies)]
    print(label_mappings)

    to_class = np.vectorize(lambda cluster: label_mappings[cluster])
    predicted = to_class(cluster_ids)
    return accuracy_score(y.reshape((-1,1)), predicted.reshape((-1,1)))

In [16]:
# Clustering accuracy on the test split. This cell comes before the training
# cell, so it is logged as the initial ('init') score.
acc = kmeans_cluster_accuracy(X_test, y_test)
print(acc)
run.log({'test/init-test-clustering-accuracy': acc})

{0: 4, 1: 2, 2: 1, 3: 4, 4: 0, 5: 0, 6: 9, 7: 0, 8: 1, 9: 8}
0.2509


In [17]:
# Same initial score on the validation split. The W&B key still uses the
# 'test/' prefix.
acc = kmeans_cluster_accuracy(X_val, y_val)
print(acc)
run.log({'test/init-val-clustering-accuracy': acc})

{0: 3, 1: 2, 2: 0, 3: 1, 4: 4, 5: 4, 6: 0, 7: 0, 8: 1, 9: 9}
0.249


## **Train the Model**

In [18]:
# Stop after 4 epochs without val_loss improvement and restore the best weights.
stopper = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=4, restore_best_weights=True)
# Train with early stopping, the learning-rate callback chosen earlier, and
# W&B logging.
hist = model.fit(train_ds,
                 validation_data=val_ds,
                 epochs=config["EPOCHS"],
                 callbacks=[stopper, lr_callback, WandbCallback()])

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30


In [21]:
# Evaluate on the test split and log every returned metric under 'test/'.
ev = model.evaluate(test_ds, return_dict=True)
log_dict = {f'test/{met}': val for met, val in ev.items()}
run.log(log_dict)



In [22]:
# Clustering accuracy on the test split, in the cell after training.
acc = kmeans_cluster_accuracy(X_test, y_test)
print(acc)
run.log({'test/test-clustering-accuracy': acc})

{0: 1, 1: 3, 2: 6, 3: 0, 4: 4, 5: 2, 6: 9, 7: 8, 8: 5, 9: 7}
0.9756


In [23]:
# Clustering accuracy on the validation split, in the cell after training.
# The W&B key still uses the 'test/' prefix.
acc = kmeans_cluster_accuracy(X_val, y_val)
print(acc)
run.log({'test/val-clustering-accuracy': acc})

{0: 4, 1: 3, 2: 1, 3: 0, 4: 5, 5: 2, 6: 7, 7: 9, 8: 8, 9: 6}
0.9722


In [24]:
# Mark the W&B run as finished.
run.finish()

0,1
epoch,▁▂▂▃▄▅▅▆▇▇█
loss,█▂▂▁▁▁▁▁▁▁▁
lr,█████████▁▁
negative_angular,▁▆▇▇▇██████
negative_distance,▁▆▇▇▇██████
positive_angular,█▆▄▃▃▂▂▂▁▁▁
positive_distance,█▇▅▄▃▂▂▂▂▁▁
test/init-test-clustering-accuracy,▁
test/init-val-clustering-accuracy,▁
test/loss,▁

0,1
best_epoch,6.0
best_val_loss,1.27693
epoch,10.0
loss,1.17225
lr,0.0009
negative_angular,0.96424
negative_distance,1.33124
positive_angular,0.26055
positive_distance,0.40224
test/init-test-clustering-accuracy,0.2509
