# Training of a Twin Network model: Stickleback (Gasterosteus aculeatus)

# Table of Contents
* [General](#first-bullet)
* [Load dataset](#second-bullet)
* [Load model](#third-bullet)
* [Training](#fourth-bullet)

## General <a class="anchor" id="first-bullet"></a>

General imports and class definitions

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell is executed, so edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

Import Python packages

In [None]:
import os

# Logging / GPU-memory switches are read by the native libraries when they
# are loaded, so they must be in the environment *before* TensorFlow (and any
# package that may pull in OpenCV) is imported. Setting them afterwards, as
# was done previously, is too late for the import-time log output.
os.environ["OPENCV_LOG_LEVEL"] = "FATAL"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import datetime
import json
import sys

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers

from twinnet_tools.tnconfig import ProjectConfig
from twinnet_tools.tnmodel import TNToolsDistanceLayer
from twinnet_tools.tnmodel import TNToolsModel
from twinnet_tools.tntraining import TNToolsTrainingImages, TNToolsTrainingPaths

Load config file and paths from config file

In [None]:
# Project configuration object for "twinnet_config"; its parsed content is
# read through `config.json` in the next cell.
config = ProjectConfig("twinnet_config")

In [None]:
# Section of the configuration that belongs to this notebook.
config_paths_script = config.json["TrainingStickleback"]

# Source location(s) of the training data and destination for all outputs.
dir_data_src, dir_data_dst = (
    config_paths_script[key] for key in ("dir_data_src", "dir_data_dst")
)

Prepare paths to save outputs to

In [None]:
modelId = "model1"

# Output path: one sub-folder per model id below the destination directory.
outFolder = f"{dir_data_dst}/{modelId}"

# exist_ok=True replaces the check-then-create pattern: it is a no-op when
# the folder already exists and cannot race with a concurrent creation.
os.makedirs(outFolder, exist_ok=True)

Make tool instances

In [None]:
# Image size fed to the network (also used as model input shape below).
img_height = img_width = 224

# NOTE(review): the meaning of the 500 / 400 arguments is defined in
# twinnet_tools and is not visible here - confirm that the two helpers are
# meant to use different values.
it = TNToolsTrainingImages(img_height, 500, img_width, 500)
pg = TNToolsTrainingPaths(img_height, 400, img_width, 400)

Training parameters

In [None]:
# Triplets per batch, and number of triplets requested per dataset.
batchSize, image_count = 64, 150_000

## Load dataset <a class="anchor" id="second-bullet"></a>

Load paths of datasets to use for image triplet preparation

In [None]:
# The config entry may hold a single directory (str) or a list of directories.
# A lone string is wrapped in a list; otherwise the comprehension below would
# iterate over the individual characters of the path.
list_paths_datasets_directories = (
    [dir_data_src] if isinstance(dir_data_src, str) else dir_data_src
)

# One exclusion file is expected next to each dataset directory.
list_paths_jsons_exclude = [
    f"{path_dataset}/EMBRYOS_TO_EXCLUDE.json"
    for path_dataset in list_paths_datasets_directories
]

Sort image paths into image triplets (anchor, positive, negative)

In [None]:
# Triplet paths for training run 1.
paths_anchor, paths_positive, paths_negative = pg.procedure_paths_binary(
    list_paths_datasets_directories,
    list_paths_jsons_exclude,
    image_count,
)

In [None]:
# Triplet paths for training run 2.
# NOTE(review): presumably each call draws a new set of triplets - confirm
# against TNToolsTrainingPaths.procedure_paths_binary.
paths_anchor2, paths_positive2, paths_negative2 = pg.procedure_paths_binary(
    list_paths_datasets_directories,
    list_paths_jsons_exclude,
    image_count,
)

In [None]:
# Triplet paths for training run 3.
# NOTE(review): presumably each call draws a new set of triplets - confirm
# against TNToolsTrainingPaths.procedure_paths_binary.
paths_anchor3, paths_positive3, paths_negative3 = pg.procedure_paths_binary(
    list_paths_datasets_directories,
    list_paths_jsons_exclude,
    image_count,
)

Make TensorFlow datasets

In [None]:
dataset_anchor = tf.data.Dataset.from_tensor_slices(paths_anchor)
dataset_positive = tf.data.Dataset.from_tensor_slices(paths_positive)
dataset_negative = tf.data.Dataset.from_tensor_slices(paths_negative)

dataset = tf.data.Dataset.zip((dataset_anchor, dataset_positive, dataset_negative))

# Shuffle ONCE with a fixed order. With the default
# reshuffle_each_iteration=True the order changes on every epoch, so the
# take()/skip() split below would select different elements each time and
# validation samples would leak into the training set.
dataset = dataset.shuffle(buffer_size=1024, reshuffle_each_iteration=False)

# 80 % of the triplets for training, the remainder for validation.
n_train = round(image_count * 0.8)
# Only the training part is reshuffled per epoch; this is done on the path
# strings (before image loading) so the shuffle buffer stays small.
train_dataset = dataset.take(n_train).shuffle(buffer_size=1024)
val_dataset = dataset.skip(n_train)

# Map path triplets to image triplets.
dataset = dataset.map(it.parse_triple_fn)
train_dataset = train_dataset.map(it.parse_triple_fn)
val_dataset = val_dataset.map(it.parse_triple_fn)

train_dataset = train_dataset.batch(batchSize, drop_remainder=False)
train_dataset = train_dataset.prefetch(8)

val_dataset = val_dataset.batch(batchSize, drop_remainder=False)
val_dataset = val_dataset.prefetch(8)

In [None]:
dataset_anchor2 = tf.data.Dataset.from_tensor_slices(paths_anchor2)
dataset_positive2 = tf.data.Dataset.from_tensor_slices(paths_positive2)
dataset_negative2 = tf.data.Dataset.from_tensor_slices(paths_negative2)

dataset2 = tf.data.Dataset.zip((dataset_anchor2, dataset_positive2, dataset_negative2))

# Shuffle ONCE with a fixed order. With the default
# reshuffle_each_iteration=True the order changes on every epoch, so the
# take()/skip() split below would select different elements each time and
# validation samples would leak into the training set.
dataset2 = dataset2.shuffle(buffer_size=1024, reshuffle_each_iteration=False)

# 80 % of the triplets for training, the remainder for validation.
n_train = round(image_count * 0.8)
# Only the training part is reshuffled per epoch; this is done on the path
# strings (before image loading) so the shuffle buffer stays small.
train_dataset2 = dataset2.take(n_train).shuffle(buffer_size=1024)
val_dataset2 = dataset2.skip(n_train)

# Map path triplets to image triplets using the first augmentation function.
dataset2 = dataset2.map(it.augment_triple_one)
train_dataset2 = train_dataset2.map(it.augment_triple_one)
val_dataset2 = val_dataset2.map(it.augment_triple_one)

train_dataset2 = train_dataset2.batch(batchSize, drop_remainder=False)
train_dataset2 = train_dataset2.prefetch(8)

val_dataset2 = val_dataset2.batch(batchSize, drop_remainder=False)
val_dataset2 = val_dataset2.prefetch(8)


In [None]:
dataset_anchor3 = tf.data.Dataset.from_tensor_slices(paths_anchor3)
dataset_positive3 = tf.data.Dataset.from_tensor_slices(paths_positive3)
dataset_negative3 = tf.data.Dataset.from_tensor_slices(paths_negative3)

dataset3 = tf.data.Dataset.zip((dataset_anchor3, dataset_positive3, dataset_negative3))

# Shuffle ONCE with a fixed order. With the default
# reshuffle_each_iteration=True the order changes on every epoch, so the
# take()/skip() split below would select different elements each time and
# validation samples would leak into the training set.
dataset3 = dataset3.shuffle(buffer_size=1024, reshuffle_each_iteration=False)

# 80 % of the triplets for training, the remainder for validation.
n_train = round(image_count * 0.8)
# Only the training part is reshuffled per epoch; this is done on the path
# strings (before image loading) so the shuffle buffer stays small.
train_dataset3 = dataset3.take(n_train).shuffle(buffer_size=1024)
val_dataset3 = dataset3.skip(n_train)

# Map path triplets to image triplets using the second augmentation function.
dataset3 = dataset3.map(it.augment_triple_two)
train_dataset3 = train_dataset3.map(it.augment_triple_two)
val_dataset3 = val_dataset3.map(it.augment_triple_two)

train_dataset3 = train_dataset3.batch(batchSize, drop_remainder=False)
train_dataset3 = train_dataset3.prefetch(8)

val_dataset3 = val_dataset3.batch(batchSize, drop_remainder=False)
val_dataset3 = val_dataset3.prefetch(8)

In [None]:
def visualize(anchor, positive, negative, n_rows=4):
    """Plot the first triplets of a batch as a grid with three columns.

    Each row shows one triplet: anchor, positive, negative (left to right).

    Args:
        anchor: Batch of anchor images, indexable along the first axis.
        positive: Batch of positive images, same layout as ``anchor``.
        negative: Batch of negative images, same layout as ``anchor``.
        n_rows: Maximum number of triplets to draw (default 4). The value is
            clamped to the batch size, so short batches no longer fail.

    Raises:
        ValueError: If there is no triplet to draw (empty batch or
            ``n_rows`` < 1).
    """
    n_rows = min(n_rows, len(anchor), len(positive), len(negative))
    if n_rows < 1:
        raise ValueError("visualize() needs at least one triplet to display")

    def show(ax, image):
        # Pixel values are scaled by 1/255 for display.
        # NOTE(review): assumes images are in the 0-255 range - confirm
        # against the parsing/augmentation functions.
        ax.imshow(image / 255, cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    # 2.25 inches per row keeps the original 9 x 9 figure for four rows.
    fig = plt.figure(figsize=(9, 2.25 * n_rows))

    # squeeze=False keeps a 2-D axes array even when only one row is drawn.
    axs = fig.subplots(n_rows, 3, squeeze=False)
    for row, triplet in enumerate(zip(anchor[:n_rows], positive[:n_rows], negative[:n_rows])):
        for col, image in enumerate(triplet):
            show(axs[row, col], image)


# Fetch a single training batch as NumPy arrays and display its first triplets.
visualize(*next(iter(train_dataset.take(1).as_numpy_iterator())))

## Load model <a class="anchor" id="third-bullet"></a>

Define model

In [None]:
# ResNet50 backbone with ImageNet weights and without the classification head.
base_cnn = tf.keras.applications.resnet.ResNet50(
    include_top=False,
    weights="imagenet",
    input_shape=(img_height, img_width, 3),
)

# Embedding head: two dense + batch-norm stages followed by a linear
# 256-dimensional output layer.
flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(units=512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(units=256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(units=256)(dense2)

embedding = tf.keras.models.Model(base_cnn.input, output, name="Embedding")

# Freeze every backbone layer that comes before "conv5_block1_out"; that
# layer and all layers after it stay trainable.
trainable = False
for layer in base_cnn.layers:
    trainable = trainable or layer.name == "conv5_block1_out"
    layer.trainable = trainable

In [None]:
# One image input per triplet member.
anchor_input = tf.keras.layers.Input(shape=(img_height, img_width, 3), name="anchor")
positive_input = tf.keras.layers.Input(shape=(img_height, img_width, 3), name="positive")
negative_input = tf.keras.layers.Input(shape=(img_height, img_width, 3), name="negative")

# Each input goes through the ResNet preprocessing and the shared embedding
# model; the three embeddings are handed to the distance layer in the order
# anchor, positive, negative.
distances = TNToolsDistanceLayer()(
    *(
        embedding(tf.keras.applications.resnet.preprocess_input(branch_input))
        for branch_input in (anchor_input, positive_input, negative_input)
    )
)

twin_network = tf.keras.models.Model(
    inputs=[anchor_input, positive_input, negative_input],
    outputs=distances,
)


In [None]:
# Print the layer overview of the assembled three-input network.
twin_network.summary()

Define callbacks

In [None]:
# Location handed to the ModelCheckpoint callback.
checkpoint_filepath = f"{outFolder}/checkpoints/"

In [None]:
# Stop a run when the validation loss has not improved for 5 epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

# Keep only the weights of the epoch with the lowest validation loss.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
)


In [None]:
# TensorBoard logs go to a time-stamped sub-folder INSIDE the model's output
# folder. The leading "/" was missing before, which placed the logs in a
# sibling directory named "<modelId>logs".
log_dir = outFolder + "/logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                      histogram_freq=1)

## Training <a class="anchor" id="fourth-bullet"></a>

Run training:
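- 150000 image triplets per dataset (80 % used for training, 20 % for validation)
- 3 runs, each run with a different dataset
- up to 10 epochs per dataset (a run stops early if the validation loss does not improve for 5 epochs)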

In [None]:
# Wrap the three-input network in the project's training model and compile
# it with Adam at a learning rate of 1e-4.
twin_network_model = TNToolsModel(twin_network)
twin_network_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    metrics=['accuracy'],
)

Training iteration 1

In [None]:
# Run 1: epochs 0-9 on the first dataset.
history = twin_network_model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
    callbacks=[early_stopping, model_checkpoint_callback, tensorboard_callback],
)

In [None]:
# Query the output shape for the three image inputs (batch dimension open).
# The shape is derived from img_height / img_width instead of repeating the
# literal 224, so it stays in sync with the model definition.
twin_network_model.compute_output_shape(
    input_shape=((None, img_height, img_width, 3),) * 3
)

In [None]:
# Persist the embedding model after run 1: full model and weights only.
embedding.save(f"{outFolder}/dir_dst_model_epochs_10/")
embedding.save_weights(f"{outFolder}/dir_dst_model_epochs_10_weights/")

Training iteration 2

In [None]:
# New time-stamped TensorBoard folder for run 2, located INSIDE the model's
# output folder (the "/" after outFolder was missing before, which wrote the
# logs to a sibling directory named "<modelId>logs").
log_dir = outFolder + "/logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Run 2: continue with epochs 10-19 on the second dataset.
history2 = twin_network_model.fit(
    train_dataset2,
    epochs=20,
    initial_epoch=10,
    callbacks=[early_stopping, model_checkpoint_callback, tensorboard_callback],
    validation_data=val_dataset2,
)

In [None]:
# Persist the embedding model after run 2: full model and weights only.
embedding.save(f"{outFolder}/dir_dst_model_epochs_20/")
embedding.save_weights(f"{outFolder}/dir_dst_model_epochs_20_weights/")

Training iteration 3

In [None]:
# New time-stamped TensorBoard folder for run 3, located INSIDE the model's
# output folder (the "/" after outFolder was missing before, which wrote the
# logs to a sibling directory named "<modelId>logs").
log_dir = outFolder + "/logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Run 3: continue with epochs 20-29 on the third dataset.
history3 = twin_network_model.fit(
    train_dataset3,
    epochs=30,
    initial_epoch=20,
    callbacks=[early_stopping, model_checkpoint_callback, tensorboard_callback],
    validation_data=val_dataset3,
)

In [None]:
# Persist the embedding model after run 3: full model and weights only.
embedding.save(f"{outFolder}/dir_dst_model_epochs_30/")
embedding.save_weights(f"{outFolder}/dir_dst_model_epochs_30_weights/")