In [1]:
import emnist
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import optimizers
from tensorflow.keras import metrics
from tensorflow.keras.layers import Input, Lambda, Conv2D, MaxPooling2D, BatchNormalization, Dense, Flatten, Activation, Dropout
from tensorflow.keras.models import Sequential, Model
import methods as M

In [2]:
images, labels = emnist.extract_training_samples('balanced')
images = images.copy().astype('float') / 255

In [3]:
target_shape = images.shape[1:]
target_shape

(28, 28)

In [4]:
classes = np.unique(labels)
n_train_classes = 40
n_val_classes = classes.shape[0] - n_train_classes

In [5]:
rng = np.random.RandomState(seed=0)
reordered_classes = classes.copy()
rng.shuffle(reordered_classes)
reordered_classes

array([28, 33, 30,  4, 18, 11, 42, 31, 22, 10, 27, 32, 29, 43,  2, 45, 26,
       15, 25, 16, 40, 20, 41,  8, 13,  5, 17, 34, 14, 37,  7, 38,  1, 12,
       35, 24,  6, 23, 36, 21, 19,  9, 39, 46,  3,  0, 44], dtype=uint8)

In [6]:
def get_image_by_label(label):
    return images[np.random.choice(np.where(labels == label)[0], 1, False)[0]]

In [7]:
def get_train_triplets(batch_size):
    triplets = [np.zeros((batch_size,) + target_shape) for _ in range(3)]
    train_classes = reordered_classes[:n_train_classes]
    for i in range(batch_size):
        class1, class2 = np.random.choice(train_classes, 2, replace=False)
        assert(class1 != class2)
        triplets[0][i] = get_image_by_label(class1)
        triplets[1][i] = get_image_by_label(class1)
        triplets[2][i] = get_image_by_label(class2)
    return triplets

In [8]:
class DistanceLayer(layers.Layer):
    """
    This layer is responsible for computing the distance between the anchor
    embedding and the positive embedding, and the anchor embedding and the
    negative embedding.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, anchor, positive, negative):
        ap_distance = tf.reduce_sum(tf.square(anchor - positive), -1)
        an_distance = tf.reduce_sum(tf.square(anchor - negative), -1)
        return (ap_distance, an_distance)

In [9]:
encoder = Sequential([
    Conv2D(16, (3, 3), input_shape=target_shape + (1,), activation='relu', kernel_regularizer='l2'),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling2D(pool_size=2, strides=(2, 2)),
    Dropout(0.25),

    Conv2D(32, (3, 3), kernel_regularizer='l2'),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling2D(pool_size=2, strides=(2, 2)),
    Dropout(0.25),

    Flatten(),
    
    Dense(32),
])

In [10]:
anchor_input = layers.Input(name="anchor", shape=target_shape + (1,))
positive_input = layers.Input(name="positive", shape=target_shape + (1,))
negative_input = layers.Input(name="negative", shape=target_shape + (1,))

distances = DistanceLayer()(
    encoder(anchor_input),
    encoder(positive_input),
    encoder(negative_input),
)

siamese_network = Model(
    inputs=[anchor_input, positive_input, negative_input], outputs=distances
)

In [11]:
class SiameseModel(Model):
    """
    The Siamese Network model with a custom training and testing loops.

    Computes the triplet loss using the three embeddings produced by the
    Siamese Network.

    The triplet loss is defined as:
       L(A, P, N) = max(‖f(A) - f(P)‖² - ‖f(A) - f(N)‖² + margin, 0)
    """

    def __init__(self, siamese_network, margin=0.5):
        super().__init__()
        self.siamese_network = siamese_network
        self.margin = margin
        self.loss_tracker = metrics.Mean(name="loss")

    def call(self, inputs):
        return self.siamese_network(inputs)

    def train_step(self, data):
        # GradientTape is a context manager that records every operation that
        # you do inside. We are using it here to compute the loss so we can get
        # the gradients and apply them using the optimizer specified in
        # `compile()`.
        with tf.GradientTape() as tape:
            loss = self._compute_loss(data)

        # Storing the gradients of the loss function with respect to the
        # weights/parameters.
        gradients = tape.gradient(loss, self.siamese_network.trainable_weights)

        # Applying the gradients on the model using the specified optimizer
        self.optimizer.apply_gradients(
            zip(gradients, self.siamese_network.trainable_weights)
        )

        # Let's update and return the training loss metric.
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result(), "pdistance": self.siamese_network(data)[0], "ndistance": self.siamese_network(data)[1]}

    def test_step(self, data):
        loss = self._compute_loss(data)

        # Let's update and return the loss metric.
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def _compute_loss(self, data):
        # The output of the network is a tuple containing the distances
        # between the anchor and the positive example, and the anchor and
        # the negative example.
        ap_distance, an_distance = self.siamese_network(data)

        # Computing the Triplet Loss by subtracting both distances and
        # making sure we don't get a negative value.
        loss = ap_distance - an_distance
        loss = tf.maximum(loss + self.margin, 0.0)
        return loss

    @property
    def metrics(self):
        # We need to list our metrics here so the `reset_states()` can be
        # called automatically.
        return [self.loss_tracker]

In [12]:
siamese_model = SiameseModel(siamese_network, 5)
siamese_model.compile(optimizer=optimizers.legacy.Adam())
num_iterations = 2000
batch_size = 10
for _ in range(num_iterations):
    siamese_model.fit(get_train_triplets(batch_size))



In [21]:
for _ in range(500):
    siamese_model.fit(get_train_triplets(batch_size))



In [25]:
train_images, train_labels, oneshot_images, oneshot_labels, validation_images, validation_labels = M.get_emnist(np.random.RandomState(seed=0), n_train_classes, n_val_classes, 5)

Output shapes:  [(96000, 28, 28), (96000,), (16800, 28, 28), (16800,)]
Train labels:  [ 1  2  4  5  6  7  8 10 11 12 13 14 15 16 17 18 20 21 22 23 24 25 26 27
 28 29 30 31 32 33 34 35 36 37 38 40 41 42 43 45]
Test labels:  [ 0  3  9 19 39 44 46]


In [26]:
pred = M.train_fewshot(encoder, 1, oneshot_images, oneshot_labels).predict(encoder(validation_images))

print("Accuracy: ", np.sum(pred == validation_labels)/len(validation_labels))

Accuracy:  0.7881896808827915
