# Siamese Network Explorer

In [None]:
import os
import pathlib

import numpy as np
import cv2

import skimage

import tensorflow as tf

import matplotlib.pyplot as plt
%matplotlib inline

import typing

In [None]:
# Seed TensorFlow's global RNG and NumPy's RNG so random ops in this notebook
# (e.g. weight initialisation, dataset shuffling) are repeatable across runs.
SEED = 123
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Sentinel telling tf.data to tune buffer sizes / parallelism dynamically.
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
# IPython normally echoes only the last expression of a cell ('last_expr').
# 'all' echoes every top-level expression, so the check cells below can show
# several values (shapes, similarity scores, ...) from a single cell.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [None]:
def load_tf_from_path(
    img_path: typing.Union[str, pathlib.Path, tf.Tensor],
    resizeSize: typing.Optional[tuple] = None,
    resizeMethod: typing.Optional[str] = tf.image.ResizeMethod.LANCZOS3,
    channels: int = 0) -> tf.Tensor:
    """
    Load a JPEG image from disk as a tf.Tensor with a leading batch axis.

    --- Parameters ---
        img_path: path of the image (str, pathlib.Path / os.PathLike, or a
            scalar tf string tensor, e.g. inside a tf.data map function)
        resizeSize: optional (height, width) tuple for the target size, in the
            order expected by tf.image.resize; if falsy the original size is kept
        resizeMethod: resize method to be used (LANCZOS3 default)
        channels: number of colour channels to decode to. 0 (default) keeps the
            channel count stored in the file; 3 forces RGB, which is needed when
            grayscale and colour JPEGs must be stacked or fed to an RGB model

    --- Returns ---
        uint8 tf.Tensor of shape [1, height, width, channels]
    """
    # Path-like objects are converted to str; str/bytes/tf string tensors are
    # passed through unchanged so the function also works inside Dataset.map.
    if isinstance(img_path, os.PathLike):
        img_path = os.fspath(img_path)
    img = tf.io.read_file(img_path)
    img = tf.image.decode_jpeg(img, channels=channels)
    if resizeSize:
        img = tf.image.resize(img, size=resizeSize, method=resizeMethod)
        # resize returns float32 and may produce values outside [0, 255]
        # (e.g. Lanczos overshoot): clip and round before casting back to uint8.
        img = tf.cast(tf.round(tf.clip_by_value(img, 0, 255)), dtype=tf.uint8)
    return tf.expand_dims(img, axis=0)  # [1, height, width, channels]

def make_dataset(
    image_paths: list,
    batch_size: int = 1,
    shuffle: bool = True,
    resizeSize: typing.Optional[tuple] = (28,28),
    resizeMethod: typing.Optional[str] = tf.image.ResizeMethod.LANCZOS3):
    """
    Build a simple in-memory tf.data pipeline from a list of JPEG paths.

    Every image is loaded eagerly with load_tf_from_path and the results are
    stacked. All images must therefore end up with the same shape: the same
    size after resizing and the same number of channels.

    --- Parameters ---
        image_paths: paths of the images to load (str or pathlib.Path)
        batch_size: number of images per batch
        shuffle: if True, shuffle the elements (tf.data reshuffles each epoch)
        resizeSize: (height, width) passed to load_tf_from_path; None keeps
            the original image sizes
        resizeMethod: resize method passed to load_tf_from_path

    --- Returns ---
        prefetched tf.data.Dataset of uint8 batches shaped
        [batch, 1, height, width, channels]. The size-1 axis comes from
        load_tf_from_path returning each image with its own batch axis.

    --- Raises ---
        ValueError: if image_paths is empty.

    TODO: make TF Records
    """
    if not image_paths:
        # from_tensor_slices([]) would build a dataset of float32 scalars and
        # shuffle() rejects a zero-sized buffer; fail early with a clear message.
        raise ValueError("make_dataset() needs at least one image path")
    images = [load_tf_from_path(path, resizeSize, resizeMethod) for path in image_paths]
    dataset = tf.data.Dataset.from_tensor_slices(images)
    if shuffle:
        # Any buffer at least as large as the dataset yields a full shuffle.
        dataset = dataset.shuffle(len(images) * 10)
    dataset = dataset.batch(batch_size=batch_size)
    return dataset.prefetch(AUTOTUNE)

def get_siamese_model(input_dim: tuple, rescaling=True, embedding_dim: int = 8, filters: int = 16):
    """
    Build a siamese network that scores the similarity of two images.

    A small shared CNN (two Conv2D + MaxPool2D stages, Flatten, Dense) maps each
    image to an L2-normalised embedding. The siamese model outputs the dot
    product of the two embeddings, i.e. their cosine similarity in [-1, 1].

    --- Parameters ---
        input_dim: shape of one image without the batch axis,
            e.g. (height, width, channels)
        rescaling: if True, the embedding network starts with a Rescaling layer
            mapping pixel values from [0, 255] to [0, 1]
        embedding_dim: length of the embedding vector (8 by default)
        filters: number of filters in each Conv2D layer (16 by default)

    --- Returns ---
        tf.keras.Model taking [image_a, image_b] and returning the similarity
        score with shape [batch, 1].
    """
    image_input = tf.keras.layers.Input(input_dim)
    x = image_input
    if rescaling:
        x = tf.keras.layers.Rescaling(1./255)(x)

    for _ in range(2):
        x = tf.keras.layers.Conv2D(filters, (3, 3), padding='same', activation='swish')(x)
        x = tf.keras.layers.MaxPool2D((2, 2))(x)
    x = tf.keras.layers.Flatten()(x)
    embedding = tf.keras.layers.Dense(embedding_dim, activation=None)(x)
    # Unit-length embeddings (components in [-1, 1]) make the dot product
    # below a cosine similarity.
    embedding = tf.math.l2_normalize(embedding, axis=-1)

    # Build the model from the Input tensor itself, not from the Rescaling
    # output: a Functional model whose input is an intermediate tensor skips
    # every layer before it, so rescaling would silently never run.
    model = tf.keras.Model(image_input, embedding)

    # Both branches call the same model, so they share weights.
    input1, input2 = tf.keras.layers.Input(input_dim), tf.keras.layers.Input(input_dim)
    left_embedding = model(input1)
    right_embedding = model(input2)
    dot_product = tf.keras.layers.dot([left_embedding, right_embedding], axes=1, normalize=False)  # sum(left * right)
    # TF scope names may not contain spaces (a name like 'Siamese Model' raises
    # a ValueError once the model is traced in graph mode, e.g. during fit()).
    siamese_model = tf.keras.Model(inputs=[input1, input2], outputs=dot_product, name='siamese_model')
    # summary() prints by itself and returns None, so it is not wrapped in print().
    siamese_model.summary()
    #plot_model(siamese_model, to_file='siamese_model.png')
    return siamese_model

### Model unit-tests

In [None]:
# Path.glob yields files in filesystem-dependent order; sort so that
# img_paths[0] / img_paths[1] refer to the same images on every run.
img_paths = sorted(str(p) for p in pathlib.Path(skimage.data_dir).glob('*.jpg'))
if not img_paths:
    raise FileNotFoundError(f"no *.jpg images found in {skimage.data_dir}")
len(img_paths)

In [None]:
img_test = load_tf_from_path(img_paths[0], resizeSize=(28, 28))
img_test.shape  # leading axis of size 1 is the batch axis
# Drop the batch axis before plotting.
plt.imshow(img_test[0].numpy())

In [None]:
# 28x28 inputs with 3 channels, matching the resize used for the test images.
siamese_model = get_siamese_model(input_dim=(28, 28, 3))

In [None]:
img_test_2 = load_tf_from_path(img_paths[1], resizeSize=(28, 28))
# Plot the single image of the batch.
plt.imshow(img_test_2[0].numpy())

In [None]:
# Sanity-check the similarity scores. The embeddings are L2-normalised and then
# dot-multiplied, so each score is a cosine similarity in [-1, 1]:
# an image compared with itself should score ~1.0.
siamese_model([img_test, img_test])
siamese_model([img_test_2, img_test_2])
# Two different images: any value in [-1, 1]; with untrained weights the
# number is not meaningful yet.
siamese_model([img_test, img_test_2])

In [None]:
dataset = make_dataset(img_paths)
# Preview only the first batch: elements are [batch, 1, H, W, C], so index the
# batch item and then its size-1 axis to get a plottable H x W x C image.
for first_batch in dataset.take(1):
    plt.imshow(first_batch[0, 0].numpy())

In [None]:
# resize_method = lambda x: tf.image.resize(x, (28,28), method=tf.image.ResizeMethod.LANCZOS3)
# cast_method = lambda x: tf.cast(tf.round(tf.clip_by_value(x, 0, 255)), dtype=tf.uint8)

# input_layer = tf.keras.layers.Input((28,28,3))
# input_layer = cast_method(resize_method(input_layer))
# if True:
#     rescaling_layer = lambda x: tf.keras.layers.Rescaling(1./255)(x)
#     input_layer = rescaling_layer(input_layer)


# conv2D = lambda input, filters: tf.keras.layers.Conv2D(filters, (3,3), padding='same', activation='swish')(input)
# maxPool = lambda x: tf.keras.layers.MaxPool2D((2,2))(x)
# flatten = lambda x: tf.keras.layers.Flatten()(x)
# output_dense = lambda x, units: tf.keras.layers.Dense(units, activation=None)(x)
# l2_normalize = lambda x: tf.math.l2_normalize(x, axis=-1)
# # sigmoid = lambda x: tf.keras.layers.Activation(activation='sigmoid')(x)

# x1 = conv2D(input_layer, 16)
# x2 = maxPool(x1)
# x3 = conv2D(x2, 16)
# x4 = maxPool(x3)
# x5 = flatten(x4)
# embedding = output_dense(x5, 8)
# embedding = l2_normalize(embedding)
# # embedding = sigmoid(embedding)
# model = tf.keras.Model(input_layer, embedding)

In [None]:
# # noise = tf.random.normal((1,1))
# img_test_2 = read_tf_img(img_paths[1])
# _ = plt.imshow(img_test_2.numpy()[0])

In [None]:
# i = tf.image.resize(img_test_2, (28,28), method=tf.image.ResizeMethod.LANCZOS3)
# i = tf.cast(tf.round(tf.clip_by_value(i, 0, 255)), dtype=tf.uint8) 
# plt.imshow(i.numpy()[0])

In [None]:
# left, right = model(img_test), model(img_test) 
# tf.keras.layers.dot([left, right], axes=1, normalize=False)
# tf.reduce_sum(tf.math.multiply(left, right))
# left, right = model(img_test), model(img_test_2) 
# tf.keras.layers.dot([left, right], axes=1, normalize=False)
# tf.reduce_sum(tf.math.multiply(left, right)

### Train siamese network