# Takeaways


**Make sure to use `tf.keras.backend` functions in the cost function's `call` method so that the operations are differentiable by TensorFlow.**



In [17]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense, Dropout, Lambda
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.python.keras.utils.vis_utils import plot_model
from tensorflow.keras import backend as K

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFont, ImageDraw
import random


def create_pairs(x, digit_indices, num_classes=10):
    '''Build alternating positive and negative sample pairs.

    For every class ``d`` and every position ``i`` two pairs are emitted,
    in this order:

    * a positive pair (label 1): samples ``i`` and ``i + 1`` of class ``d``;
    * a negative pair (label 0): sample ``i`` of class ``d`` and sample ``i``
      of a different, randomly chosen class.

    Args:
        x: Indexable collection of samples; it is indexed with the values
            stored in ``digit_indices``.
        digit_indices: Sequence where ``digit_indices[d]`` holds the indices
            into ``x`` of the samples that belong to class ``d``.
        num_classes: Number of classes to pair up. Defaults to 10, the value
            that was previously hard-coded.

    Returns:
        A ``(pairs, labels)`` tuple of NumPy arrays. ``labels`` alternates
        1, 0, 1, 0, ... in step with ``pairs``.

    Raises:
        ValueError: If ``num_classes`` is smaller than 2, because no negative
            pair can be formed with a single class.
    '''
    if num_classes < 2:
        raise ValueError("num_classes must be at least 2 to build negative pairs")

    pairs = []
    labels = []
    # Each class must provide samples i and i + 1, so the smallest class
    # bounds how many pairs every class contributes.
    n = min(len(digit_indices[d]) for d in range(num_classes)) - 1

    for d in range(num_classes):
        for i in range(n):
            anchor = digit_indices[d][i]
            pairs.append([x[anchor], x[digit_indices[d][i + 1]]])
            # A non-zero offset modulo num_classes always lands on another class.
            inc = random.randrange(1, num_classes)
            dn = (d + inc) % num_classes
            pairs.append([x[anchor], x[digit_indices[dn][i]]])
            labels += [1, 0]

    return np.array(pairs), np.array(labels)


def create_pairs_on_set(images, labels):
    '''Create labelled sample pairs for one dataset split.

    Groups the sample indices by class label (0 through 9), hands them to
    ``create_pairs`` and returns the pairs together with the pair labels
    cast to ``float32``.
    '''
    indices_by_class = []
    for class_id in range(10):
        # np.where on a boolean mask returns a tuple; element 0 holds the indices.
        indices_by_class.append(np.where(labels == class_id)[0])

    pairs, pair_labels = create_pairs(images, indices_by_class)
    return pairs, pair_labels.astype('float32')


def show_image(image):
    """Render ``image`` in a new figure with a colorbar and the grid turned off."""
    fig = plt.figure()
    ax = fig.gca()
    drawn = ax.imshow(image)
    fig.colorbar(drawn, ax=ax)
    ax.grid(False)
    plt.show()

def initialize_base_network():
  """Build the embedding network applied to each image of a pair.

  The model flattens a 28x28 input, passes it through two SELU dense layers
  of 128 units (each followed by 10% dropout) and ends with a linear dense
  layer of 128 units.

  Returns:
    A Keras ``Model`` mapping the 28x28 input to the output of ``dense3``.
  """
  inputs = Input(shape=(28,28,))
  features = Flatten()(inputs)

  # Two identical hidden stages; layer names are kept explicit.
  for dense_name, dropout_name in (("dense1", "dout1"), ("dense2", "dout2")):
    features = Dense(
        128,
        activation="selu",
        kernel_initializer="lecun_normal",
        name=dense_name,
    )(features)
    features = Dropout(0.1, name=dropout_name)(features)

  embedding = Dense(128, name="dense3")(features)
  return Model(inputs=inputs, outputs=embedding)

def euclidian_distance(vecs):
  """Return the Euclidean distance between the two tensors in ``vecs``.

  The squared differences are summed over axis 1 with ``keepdims=True``.
  """
  left, right = vecs
  squared_distance = K.sum(K.square(left - right), axis=1, keepdims=True)
  # Clamp to epsilon so the square root is never taken at exactly zero.
  clamped = K.maximum(squared_distance, K.epsilon())
  return K.sqrt(clamped)

def eucl_dist_output_shape(shapes):
  """Return the distance layer's output shape for a pair of input shapes.

  Only the leading dimension of the first shape is used; the result is
  ``(leading_dim, 1)``.
  """
  first_shape, _ = shapes
  leading_dim = first_shape[0]
  return (leading_dim, 1)

# Seed every random number generator used below (Python's `random` drives the
# negative-pair sampling; NumPy and TensorFlow cover initialisation and
# shuffling) so that Restart & Run All reproduces the same run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# load the dataset
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# cast to float32 and divide by 255 to normalize the values
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

# create pairs on train and test sets
tr_pairs, tr_y = create_pairs_on_set(train_images, train_labels)
ts_pairs, ts_y = create_pairs_on_set(test_images, test_labels)

# One embedding network; both branches call this same instance.
base_network = initialize_base_network()
plot_model(
    base_network,
    to_file='base-model.png',
    show_shapes=True,
    show_layer_names=True,
)

# Left branch of the Siamese network.
input_left = Input(name="left_input", shape=(28,28))
vect_output_left = base_network(input_left)

# Right branch of the Siamese network.
input_right = Input(name="right_input", shape=(28,28))
vect_output_right = base_network(input_right)

# The model output is the Euclidean distance between the two embeddings.
output = Lambda(
    euclidian_distance,
    output_shape=eucl_dist_output_shape,
    name="output_layer",
)([vect_output_left, vect_output_right])

model = Model(inputs=[input_left, input_right], outputs=output)

plot_model(
    model,
    to_file="outer_model.png",
    show_shapes=True,
    show_layer_names=True,
)

# Optimizer handed to model.compile in the training cell.
rms = RMSprop(momentum=0.9)

In [18]:
class ContrastiveLoss(tf.keras.losses.Loss):
  """Contrastive loss on predicted pair distances.

  With ``y_true`` equal to 1 for a similar pair and 0 for a dissimilar pair,
  and ``y_pred`` the predicted distance, the loss is the mean of

      y_true * y_pred**2 + (1 - y_true) * max(margin - y_pred, 0)**2

  so similar pairs are pulled together and dissimilar pairs are pushed apart
  until they are at least ``margin`` away from each other.

  Args:
    margin: Distance beyond which dissimilar pairs stop contributing.
  """
  # Class-level value; every instance overrides it in __init__.
  margin = 0.5

  def __init__(self, margin):
    super().__init__()
    self.margin = margin

  def call(self, y_true, y_pred):
    """Return the mean contrastive loss for the given labels and distances."""
    similar_term = y_true * K.square(y_pred)
    # The margin term must be gated by the label (1 - y_true), not by the
    # prediction: it applies to dissimilar pairs only.
    dissimilar_term = (1 - y_true) * K.square(K.maximum(self.margin - y_pred, 0))
    return K.mean(similar_term + dissimilar_term)


In [19]:
# Compile with the contrastive loss (margin 1) and train on the image pairs,
# validating on the test-set pairs after every epoch.
model.compile(optimizer=rms, loss=ContrastiveLoss(margin=1))
history = model.fit(
    x=[tr_pairs[:, 0], tr_pairs[:, 1]],
    y=tr_y,
    batch_size=128,
    epochs=20,
    validation_data=([ts_pairs[:, 0], ts_pairs[:, 1]], ts_y),
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
