In [8]:
from __future__ import print_function
import keras
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Activation, Concatenate
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense, Lambda
from siamese import SiameseNetwork
import numpy as np

In [2]:
# Experiment configuration.
# NOTE(review): the fit() calls visible in this notebook pass their own literal
# batch sizes (1000 and 10), so `batch_size` is not referenced by them -- confirm
# whether it is still needed.
batch_size = 128
num_classes = 10  # iterated over when picking the few-shot samples per label
epochs = 10       # epoch count for the full-data training run

# MNIST images are square: rows == cols == 28 pixels.
img_rows = img_cols = 28

In [3]:
# Fetch MNIST, already partitioned into training and test sets.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# The single grayscale channel goes first or last depending on the image data
# format reported by the Keras backend; derive the per-image shape once and
# reuse it for both splits.
if K.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
else:
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.reshape((x_train.shape[0],) + input_shape)
x_test = x_test.reshape((x_test.shape[0],) + input_shape)

# Cast to float32 and divide by 255 (pixel values presumably span 0-255, which
# would put the result in [0, 1]).
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

In [27]:
def create_base_model(input_shape):
    """Build the embedding network that is applied to each image of a pair.

    Architecture:
        Conv2D(32, 3x3) -> BatchNormalization -> ReLU -> MaxPooling2D(2x2)
        Conv2D(64, 3x3) -> BatchNormalization -> ReLU -> MaxPooling2D(2x2)
        Flatten -> Dense(128) -> BatchNormalization -> ReLU

    Args:
        input_shape: Shape of a single input image without the batch axis,
            e.g. ``(28, 28, 1)`` for channels-last MNIST.

    Returns:
        A Keras ``Model`` mapping a batch of images to 128-dimensional
        embeddings.
    """
    model_input = Input(shape=input_shape)

    # The Input tensor already fixes the input shape, so the first conv layer
    # is not given an `input_shape` kwarg: it is redundant in the functional
    # API and newer Keras versions warn when a layer receives it.
    embedding = Conv2D(32, kernel_size=(3, 3))(model_input)
    embedding = BatchNormalization()(embedding)
    embedding = Activation(activation='relu')(embedding)
    embedding = MaxPooling2D(pool_size=(2, 2))(embedding)

    embedding = Conv2D(64, kernel_size=(3, 3))(embedding)
    embedding = BatchNormalization()(embedding)
    embedding = Activation(activation='relu')(embedding)
    embedding = MaxPooling2D(pool_size=(2, 2))(embedding)

    # Collapse the feature maps and project them to the embedding vector.
    embedding = Flatten()(embedding)
    embedding = Dense(128)(embedding)
    embedding = BatchNormalization()(embedding)
    embedding = Activation(activation='relu')(embedding)

    return Model(model_input, embedding)


def create_head_model(embedding_shape):
    """Build the comparison head that scores a pair of embeddings.

    The head takes two embedding tensors, computes their element-wise
    absolute difference and feeds it to a single sigmoid unit.

    An alternative head (Concatenate -> Dense(8) -> BatchNormalization ->
    sigmoid -> Dense(1) -> BatchNormalization -> sigmoid) was previously left
    here as commented-out code; it is not used.

    Args:
        embedding_shape: Output shape of the base model including the leading
            batch axis; only ``embedding_shape[1:]`` is used.

    Returns:
        A Keras ``Model`` with inputs ``[embedding_a, embedding_b]`` and a
        single sigmoid output per pair.
    """
    per_sample_shape = embedding_shape[1:]
    left_embedding = Input(shape=per_sample_shape)
    right_embedding = Input(shape=per_sample_shape)

    # Element-wise |a - b| of the two embeddings.
    abs_difference = Lambda(lambda pair: K.abs(pair[0] - pair[1]))(
        [left_embedding, right_embedding]
    )
    pair_score = Dense(1, activation='sigmoid')(abs_difference)

    return Model([left_embedding, right_embedding], pair_score)

In [9]:
# Checkpointing / early-stopping configuration; both callbacks watch the
# validation accuracy.
# NOTE(review): patience=10 matches `epochs` (10) used by the full-data fit, so
# early stopping should never trigger during that run -- confirm intent.
siamese_checkpoint_path = "./siamese_checkpoint"
siamese_callbacks = [
    EarlyStopping(
        monitor='val_accuracy',
        patience=10,
        verbose=0,
    ),
    ModelCheckpoint(
        siamese_checkpoint_path,
        monitor='val_accuracy',
        save_best_only=True,
        verbose=0,
    ),
]

# Embedding network plus a pair-scoring head sized from the embedding output.
base_model = create_base_model(input_shape)
head_model = create_head_model(base_model.output_shape)

# Combine both parts into the siamese wrapper and configure training.
siamese_network = SiameseNetwork(base_model, head_model)
siamese_network.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'],
)

In [10]:
# Train on the full MNIST training split; the test split is passed as
# validation data, so the `val_accuracy` watched by the callbacks is computed
# on x_test/y_test.
# How SiameseNetwork.fit turns (images, labels) into training pairs is defined
# in the `siamese` module and is not visible here.
# NOTE(review): batch_size is hard-coded to 1000; the `batch_size` constant
# (128) from the config cell is not used -- confirm which value is intended.
siamese_network.fit(x_train, y_train,
                    validation_data=(x_test, y_test),
                    batch_size=1000,
                    epochs=epochs,
                    callbacks=siamese_callbacks)



Epoch 1/10
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 2/10
Epoch 3/10
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 4/10
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 5/10
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 6/10
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 7/10
Epoch 8/10
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 9/10
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 10/10
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets


In [11]:
# Report metrics for the full-data model.
# NOTE(review): this scores x_train/y_train, i.e. the data the model was fit
# on, whereas the few-shot run is scored on x_test -- confirm that training-set
# metrics are what is wanted here.
# The displayed pair is presumably [loss, accuracy], matching compile().
siamese_network.evaluate(x_train, y_train, batch_size=64)



[0.03771126642823219, 0.9888559579849243]

In [22]:
# Few-shot training set: the first 20 training images of every label
# 0..num_classes-1, gathered into one flat index array.
few_shot_index = np.concatenate(
    [np.flatnonzero(y_train == label)[:20] for label in range(num_classes)],
    axis=None,
)
x_train_fewshot = x_train[few_shot_index]
y_train_fewshot = y_train[few_shot_index]

# NOTE(review): this is the same directory the full-data run writes to, so that
# run's saved checkpoint is overwritten here -- confirm this is intended.
siamese_checkpoint_path = "./siamese_checkpoint"
siamese_callbacks = [
    EarlyStopping(
        monitor='val_accuracy',
        patience=10,
        verbose=0,
    ),
    ModelCheckpoint(
        siamese_checkpoint_path,
        monitor='val_accuracy',
        save_best_only=True,
        verbose=0,
    ),
]

# Newly constructed base and head models, so the few-shot run does not reuse
# the model objects trained on the full data set.
base_model = create_base_model(input_shape)
head_model = create_head_model(base_model.output_shape)

siamese_network = SiameseNetwork(base_model, head_model)
siamese_network.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'],
)

# Validate on the first 500 test images only.
siamese_network.fit(
    x_train_fewshot,
    y_train_fewshot,
    validation_data=(x_test[:500], y_test[:500]),
    batch_size=10,
    epochs=20,
    callbacks=siamese_callbacks,
)



Epoch 1/20
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 2/20
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 3/20
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 4/20
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 5/20
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 6/20
Epoch 7/20
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 8/20
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 9/20
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 10/20
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
INFO:tensorflow:Assets written to: .\siamese_checkpoint\assets


In [23]:
# Score the few-shot model on the full test split.
# NOTE(review): x_test[:500] was also passed as validation data to the few-shot
# fit, where it drove the val_accuracy-based checkpointing and early stopping,
# so this figure is not from fully held-out data -- consider scoring
# x_test[500:] instead.
siamese_network.evaluate(x_test, y_test, batch_size=10)



[0.41422051191329956, 0.8327999711036682]