In [34]:
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import tensorflow as tf
from pathlib import Path
from keras import applications
from keras import layers
from keras import losses
from keras import ops
from keras import optimizers
from keras import metrics
from keras import Model
from keras.applications import resnet
from utilities import paths
from utilities import preprocessing
from utilities import visualization
from main import target_shape
from distance.DistanceLayer import  DistanceLayer
from model.SiameseModel import SiameseModel

In [35]:
# Collect the full path of every file in the anchor and positive directories.
# os.path.join picks the separator for the current OS, so the notebook is not
# tied to Windows-style "\\" paths. Sorting gives both lists a deterministic
# order (presumably matching file names pair anchor i with positive i - verify
# against the dataset layout).
anchor_images = sorted(
    os.path.join(paths.anchor_images_path, file_name)
    for file_name in os.listdir(paths.anchor_images_path)
)
positive_images = sorted(
    os.path.join(paths.positive_images_path, file_name)
    for file_name in os.listdir(paths.positive_images_path)
)

# Number of anchor/positive pairs; used later to size the train/validation split.
image_count = len(anchor_images)
if len(positive_images) != image_count:
    # The two lists are zipped element-wise later, so they must be equally long.
    raise ValueError("Number of images in the datasets don't match")

In [36]:
# Wrap the two path lists in tf.data datasets, one string element per image
# path. Both are built from the lists in their current order, so element i of
# each dataset comes from index i of the corresponding list.
anchor_dataset, positive_dataset = (
    tf.data.Dataset.from_tensor_slices(image_paths)
    for image_paths in (anchor_images, positive_images)
)

In [37]:
# Shuffle both path lists in place using a single seeded generator. The two
# calls consume the generator state in sequence (anchors first, then
# positives), so the lists receive different permutations.
rng = np.random.RandomState(seed=42)
for path_list in (anchor_images, positive_images):
    rng.shuffle(path_list)

In [38]:
# Pool of candidate negatives: every anchor path followed by every positive
# path, then mixed in place with its own seeded generator.
negative_images = [*anchor_images, *positive_images]
negative_rng = np.random.RandomState(seed=32)
negative_rng.shuffle(negative_images)

In [39]:
# Dataset of negative image paths, shuffled again at the tf.data level
# through a 4096-element buffer.
negative_dataset = tf.data.Dataset.from_tensor_slices(negative_images).shuffle(
    buffer_size=4096
)

In [40]:
# Combine the three path streams element-wise into (anchor, positive, negative)
# triplets; zip stops at the shortest input.
dataset = tf.data.Dataset.zip((anchor_dataset, positive_dataset, negative_dataset))

# Shuffle the triplets once and keep that order for every pass over the data.
# With the default reshuffle_each_iteration=True the order would change on each
# epoch, so the take()/skip() split below would select different elements every
# time and training triplets would leak into the validation set.
dataset = dataset.shuffle(buffer_size=1024, reshuffle_each_iteration=False)

# Turn each triplet of paths into whatever preprocess_triplets produces
# (presumably decoded, resized image tensors - see utilities.preprocessing).
dataset = dataset.map(preprocessing.preprocess_triplets)

In [41]:
# 80/20 split: the first 80% of the triplets (rounded) are used for training,
# everything after that point for validation.
split_index = round(image_count * 0.8)
train_dataset = dataset.take(split_index)
val_dataset = dataset.skip(split_index)

In [42]:
# Group each split into batches of 32 (a smaller final batch is kept, not
# dropped) and prefetch with a buffer size chosen by tf.data.
train_dataset = train_dataset.batch(32, drop_remainder=False).prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset.batch(32, drop_remainder=False).prefetch(tf.data.AUTOTUNE)

In [43]:
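# Optional sanity check (disabled): visualize the first training batch. The leading `*` unpacks
# that batch into separate positional arguments; see https://www.geeksforgeeks.org/python-star-or-asterisk-operator/
# visualization.visualize(*list(train_dataset.take(1).as_numpy_iterator())[0])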

In [44]:
# Backbone: ResNet50 initialised with ImageNet weights, without its
# classification head, taking 3-channel inputs of spatial size `target_shape`.
base_cnn = resnet.ResNet50(
    include_top=False,
    weights="imagenet",
    input_shape=target_shape + (3,),
)

# Embedding head on top of the backbone output: flatten, then two ReLU dense
# layers (512 and 256 units), each followed by batch normalization, and a
# final 256-unit dense layer with no activation.
flatten = layers.Flatten()(base_cnn.output)

dense1 = layers.Dense(units=512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)

dense2 = layers.Dense(units=256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)

output = layers.Dense(units=256)(dense2)

In [45]:
# Embedding model: maps the backbone's input image to the 256-unit output of
# the dense head.
embedding = Model(
    inputs=base_cnn.input,
    outputs=output,
    name="Embedding",
)

In [46]:
# Freeze the backbone up to (but not including) "conv5_block1_out"; that layer
# and every layer after it in base_cnn.layers stay trainable. The flag latches
# to True once the named layer is reached and never resets.
trainable = False
for layer in base_cnn.layers:
    trainable = trainable or layer.name == "conv5_block1_out"
    layer.trainable = trainable

In [47]:
# One image input per triplet role, all with the same (height, width, 3) shape.
anchor_input, positive_input, negative_input = (
    layers.Input(name=role, shape=target_shape + (3,))
    for role in ("anchor", "positive", "negative")
)

# Each input goes through ResNet preprocessing and then the same `embedding`
# model instance, so the three branches share weights. DistanceLayer receives
# the embeddings in (anchor, positive, negative) order.
distances = DistanceLayer()(
    *(
        embedding(resnet.preprocess_input(branch_input))
        for branch_input in (anchor_input, positive_input, negative_input)
    )
)

# Full network: three images in, the DistanceLayer output out.
siamese_network = Model(
    outputs=distances,
    inputs=[anchor_input, positive_input, negative_input],
)

In [48]:
# siamese_model = SiameseModel(siamese_network)
# siamese_model.compile(optimizer=optimizers.Adam(0.0001))
# siamese_model.fit(train_dataset, epochs=10, validation_data=val_dataset)