# Training of a Twin Network model: Zebrafish (Danio rerio)
Training with differently sorted datasets. This training script uses image tuplets for preparation of image triplets for Twin Network training. Rotated anchor images are used in the same triplets as positive images.

# Table of Contents
* [General](#first-bullet)
* [Load dataset](#second-bullet)
* [Load model](#third-bullet)
* [Training](#fourth-bullet)

## General <a class="anchor" id="first-bullet"></a>

General imports and class definitions

In [None]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported automatically before each cell runs.
%load_ext autoreload
%autoreload 2

Import Python packages

In [None]:
import datetime
import os
from pathlib import Path
import tensorflow as tf

from twinnet_tools.tnconfig import ProjectConfig
from twinnet_tools.tnmodel import TNToolsModel, TNToolsNetwork
from twinnet_tools.tntraining_tuplet import TNToolsTrainingTupletsDataset, TNToolsTrainingTupletsImages

Load config file and paths from config file

In [None]:
# Project configuration object; later cells read their paths from `config.json`.
# NOTE(review): presumably "twinnet_config" names a JSON config file - confirm in ProjectConfig.
config = ProjectConfig("twinnet_config")

In [None]:
# All paths used by this notebook live in the "TrainingZebrafish2" section
# of the project configuration.
config_paths_script = config.json["TrainingZebrafish2"]

# Source JSON describing the training data, and the root output directory.
path_src_data_train_json, dir_data_dst = (
    config_paths_script[key]
    for key in ("path_src_data_train_json", "dir_data_dst")
)

Prepare paths to save outputs to

In [None]:
# Identifier of this training run; used as the name of the output sub-folder.
modelId = "model1"

# Output path: every artefact of this run (checkpoints, logs, saved models)
# is written below this folder.
outFolder = f"{dir_data_dst}/{modelId}"

# Create the folder (and missing parents) in one step. `exist_ok=True` avoids
# the check-then-create race of `os.path.exists` followed by `os.makedirs`
# and makes re-running the cell harmless.
os.makedirs(outFolder, exist_ok=True)

Make tool instances

In [None]:
# Image dimensions passed to both tuplet tools.
# NOTE(review): the *_min values are presumably minimum accepted source image
# sizes - confirm in twinnet_tools.tntraining_tuplet.
img_height, img_width = 224, 224
img_height_min, img_width_min = 300, 300

# Both tuplet tools take the same four size arguments, so build them in one go
# (dataset tool first, then image tool).
tools_dataset, tools_images = (
    tool_cls(img_height, img_width, img_height_min, img_width_min)
    for tool_cls in (TNToolsTrainingTupletsDataset, TNToolsTrainingTupletsImages)
)

# Factory for the embedding / twin network models.
tools_network = TNToolsNetwork()

Training parameters

In [None]:
# Fraction of the triplets used for training; the rest is held out.
split_train = 0.8

# Total number of image triplets requested from the dataset tool.
image_count = 1_000_000

# Number of triplets per batch.
batchSize = 5

# Keys of the source JSON to include in the dataset.
keys_include = ["normal_bright_complete"]

# Absolute sizes of the two splits (they always add up to image_count).
num_train = round(split_train * image_count)
num_test = image_count - num_train

## Load dataset <a class="anchor" id="second-bullet"></a>

Sort image paths to image triplets and make datasets

In [None]:
# Build the dataset from the training JSON, restricted to the selected keys
# and to `image_count` entries.
# NOTE(review): presumably yields path triplets that `images_parse_fn` decodes later - confirm.
dataset = tools_dataset(path_src_data_train_json, keys_include, image_count)

In [None]:
# Shuffle once with a *fixed* order. With the default
# `reshuffle_each_iteration=True` the order is redrawn every time the dataset
# is iterated, so the take/skip split below would change between epochs (and
# between the train and validation iterators), leaking validation samples
# into training.
dataset2 = dataset.shuffle(buffer_size=1024, reshuffle_each_iteration=False)

# Parse each element with the image tool.
dataset2 = dataset2.map(tools_images.images_parse_fn)

# Disjoint split: the first `num_train` elements are used for training and
# everything after them for validation. Skipping `num_train` (not `num_test`)
# is what keeps the two sets from overlapping.
train_dataset = dataset2.take(num_train)
val_dataset = dataset2.skip(num_train)

# Batch and prefetch both splits; the final, possibly smaller batch is kept.
train_dataset = train_dataset.batch(batchSize, drop_remainder=False)
train_dataset = train_dataset.prefetch(8)

val_dataset = val_dataset.batch(batchSize, drop_remainder=False)
val_dataset = val_dataset.prefetch(8)

# Show the structure of one training batch as a sanity check.
print(train_dataset.element_spec)

In [None]:
# Preview the first three parsed elements of the shuffled dataset; each one is
# passed to `visualize` as a separate positional argument.
tools_images.visualize(*dataset2.take(3).as_numpy_iterator())

## Load model <a class="anchor" id="third-bullet"></a>

Define model

In [None]:
# Build the embedding model first, then the twin network that wraps it.
# NOTE(review): presumably the embedding is shared by all three triplet
# branches - confirm in TNToolsNetwork.
twin_network_embedding = tools_network.tn_embedding_make()
twin_network = tools_network.tn_network_resnet_make(twin_network_embedding)

In [None]:
# Print the model summary of the twin network as a sanity check.
twin_network.summary()

In [None]:
# Wrap the twin network in the project's model class; this wrapper is what
# gets compiled and fitted in the cells below.
twin_network_model = TNToolsModel(twin_network)

Define callbacks

In [None]:
# Location inside the run folder where the checkpoint callback writes weights.
checkpoint_filepath = f"{outFolder}/checkpoints/"

In [None]:
# Stop training once the validation loss has not improved for 5 epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(patience=5, monitor='val_loss')

# Keep only the weights of the epoch with the lowest validation loss.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
)


In [None]:
# TensorBoard logs go into a time-stamped sub-folder of the run folder.
# `outFolder` has no trailing slash, so the separator must be added here;
# without it the logs land in a sibling directory named "<modelId>logs".
log_dir = outFolder + "/logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# histogram_freq=1 records weight histograms after every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                      histogram_freq=1)

Compile model

In [None]:
# Configure training: Adam optimizer with a learning rate of 1e-4.
# No loss is passed here.
# NOTE(review): presumably TNToolsModel computes the loss itself - confirm.
twin_network_model.compile(
    metrics=['accuracy'],
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
)

## Training <a class="anchor" id="fourth-bullet"></a>

Run training:
- 1000000 image triplets per dataset
- 1 run
- Up to 10 epochs per dataset (early stopping on `val_loss` with a patience of 5 may end training sooner)

In [None]:
# Train for at most 10 epochs, validating after each epoch. The callbacks
# handle early stopping, best-weights checkpointing and TensorBoard logging.
history = twin_network_model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
    callbacks=[
        early_stopping,
        model_checkpoint_callback,
        tensorboard_callback,
    ],
)

In [None]:
# Query the model's output shape for three identical image inputs of shape
# (batch, 224, 224, 3).
twin_network_model.compute_output_shape(
    input_shape=((None, 224, 224, 3),) * 3
)

In [None]:
# Persist the trained embedding model and, separately, its weights.
# The embedding model is bound to `twin_network_embedding` when the network is
# built; no variable named `embedding` exists in this notebook, so the old
# calls raised a NameError.
twin_network_embedding.save(outFolder+'/dir_dst_model_epochs_10/')
twin_network_embedding.save_weights(outFolder+'/dir_dst_model_epochs_10_weights/')