# Aim
Predict a trajectory of embryonic stages based on embryo images using Twin Network.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Configure TensorFlow through the environment BEFORE tensorflow is imported:
# setting these after the import may be too late for the log level to apply.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import cv2
import glob
import json
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pathlib
import shutil
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_io as tfio
from tensorflow.keras import applications, layers, models


from twinnet_tools.tnconfig import ProjectConfig

# Load the project configuration named "twinnet_config" and read the scripts
# directory from its "dir_scripts" entry.
config = ProjectConfig("twinnet_config")
dir_root_scripts = config.json["dir_scripts"]


import glob
import matplotlib as mpl
import pandas as pd
from pathlib import Path
import sys

# Make the scripts directory importable before loading the remaining
# twinnet_tools modules.
# NOTE(review): twinnet_tools.tnconfig is imported above, before this path is
# appended, so the package is presumably already importable - confirm whether
# this append is still needed.
sys.path.append(dir_root_scripts)
from twinnet_tools.tngeneral import TNToolsGeneral
from twinnet_tools.tninference import TNToolsEmbeddings
from twinnet_tools.tninference import TNToolsSimilarities
from twinnet_tools.tnmodel import TNToolsNetwork
from twinnet_tools.tnplot import TNToolsPlot

# Instantiate the TwinNet helper objects used by the cells below

tools_general = TNToolsGeneral()
# 224 / 250 match the default img_size / img_size_min of fn_image_tiff_parse
tools_embeddings = TNToolsEmbeddings(size_img=224,
                                     size_img_min=250)
tools_model = TNToolsNetwork()
tools_similarities = TNToolsSimilarities()
tools_plot = TNToolsPlot()

In [None]:
# Matplotlib defaults for every figure in this notebook: no LaTeX text
# rendering, 'none' as SVG font type, and a line width of 1.
new_rc_params = {
    'text.usetex': False,
    'svg.fonttype': 'none',
    'lines.linewidth': 1,
}
mpl.rcParams.update(new_rc_params)

# 1. Paths

In [None]:
# Dataset to analyze and reference information
dataSet = "35_5_J"
refDataSet = "28_5_ref"
alldatainfo = "dataSets_info_zfish.txt"
modelName = "zebrafish_temperature.h5"
referenceImagesName = "images_jsons_reference_zfish"

# Main paths
srcpath = "../temperature_zFish/"
codepath = "../TwinNet-main/code/Scripts/"

# Test dataset paths
datapath = os.path.join(srcpath, "testData")
dir_src_imgs_test = os.path.join(datapath, dataSet)

# Table describing all datasets.
# Built with os.path.join (not string concatenation) so the result stays valid
# even if the trailing slash of srcpath is removed.
dataSetsInfo = os.path.join(srcpath, alldatainfo)

# Model and reference images
path_model = os.path.join(srcpath, "models", modelName)
dir_src_jsons_reference = os.path.join(codepath, referenceImagesName)

# Output directory and the suffix appended to every result file name
dir_dst = os.path.join(srcpath, "results")
sufix_out = dataSet


Load image data

In [None]:
# Read the comma-separated text file dataSetsInfo as a pandas DataFrame
# (the first row is the header)
dataImages = pd.read_csv(dataSetsInfo, sep=',', header=0)

# Display the DataFrame
#print(dataImages)

# Retrieve data from the current dataSet.
# Values are read by column position from the first matching row:
# 1 -> number of embryos, 2 -> time interval, 3 -> initial time.
res = dataImages[dataImages['DataSetName'] == dataSet]
if not res.empty:
    img_time_interval = res.iloc[0, 2]
    initial_time = res.iloc[0, 3]
    NumberEmbryos2Analyze = res.iloc[0, 1]
    print(f"Analysing Dataset: {dataSet} -> timeInterval: {img_time_interval} sec, init time: {initial_time} hpf, {NumberEmbryos2Analyze} embryos")
else:
    print("The dataSet was not found.")

# Retrieve data from the reference dataSet
res_ref = dataImages[dataImages['DataSetName'] == refDataSet]
# The emptiness check must look at res_ref (not res); otherwise a missing
# reference dataset ends in an IndexError instead of the message below.
if not res_ref.empty:
    img_time_interval_reference = res_ref.iloc[0, 2]
    initial_time_reference = res_ref.iloc[0, 3]
    print(f"Reference Dataset: {refDataSet} -> timeInterval: {img_time_interval_reference} sec, init time: {initial_time_reference} hpf")
else:
    print("The Reference dataSet was not found.")
    
    

Load reference image paths

In [None]:
def fn_json_load(path_json):
    """Parse the JSON file at ``path_json`` and return the decoded object."""
    with open(path_json, 'rb') as handle:
        return json.load(handle)

# Data loading

# Every *.json file in the reference directory is one anchor; files are
# processed in sorted order so anchor numbering is reproducible.
json_anchor = sorted(glob.glob(f'{dir_src_jsons_reference}/*.json'))
Number_anchors = len(json_anchor)

# One loaded JSON object per anchor file
paths_anchor = [fn_json_load(path_json) for path_json in json_anchor]

# Show the first entry of each anchor as a quick sanity check
for anchor_entries in paths_anchor:
    print(anchor_entries[0])

# 2. Load model

In [None]:
# Load the embedding model stored at path_model through the TwinNet model helper.
# NOTE(review): presumably a ResNet50-based network, given the variable name and
# the resnet50 preprocessing applied in fn_image_tiff_parse - confirm.
resnet50 = tools_model.tn_embedding_load(path_model)

# 3. Calculate embeddings

In [None]:
def fn_image_tiff_parse(path_img, img_size=224, img_size_min=250):
    """
    Read one TIFF file and turn it into a ResNet50-ready image tensor.

    Steps: decode the TIFF, crop or pad it to ``img_size_min`` x ``img_size_min``,
    reshape to 4 channels, resize to ``img_size`` x ``img_size``, convert RGBA
    to RGB and apply the Keras ResNet50 input preprocessing.
    """
    side = img_size_min
    raw_bytes = tf.io.read_file(path_img)
    rgba = tfio.experimental.image.decode_tiff(raw_bytes)
    rgba = tf.image.resize_with_crop_or_pad(rgba, side, side)
    # Pin the static shape to (side, side, 4) before resizing
    rgba = tf.reshape(rgba, (side, side, 4))
    rgba = tf.image.resize(rgba, (img_size, img_size))
    rgb = tfio.experimental.color.rgba_to_rgb(rgba)
    return tf.keras.applications.resnet50.preprocess_input(rgb)

def fn_images_tiff_parse(paths_images, **kwargs):
    """
    Load multiple TIFF images into a single numpy array.

    Each path is parsed with ``fn_image_tiff_parse``; ``kwargs`` are forwarded
    to it unchanged.

    Images that raise ``cv2.error`` are skipped, and the skipped path is
    reported instead of being dropped silently: a skipped image makes the
    returned array shorter than ``paths_images``, which shifts every later
    index relative to the input list.

    Returns:
        numpy array with one entry per successfully parsed image, in input order.
    """
    image_segments = []
    num_images = len(paths_images)

    for position, path_image in enumerate(paths_images, start=1):
        print(f'[LOADING] Image arrays {position}/{num_images} ...'.ljust(50), end='\r')
        try:
            image_segment = fn_image_tiff_parse(path_image, **kwargs)
        except cv2.error as err:
            # NOTE(review): the parser is built on TensorFlow ops; confirm
            # whether cv2.error can actually be raised on this path.
            # Leading newline: the progress line above ends with '\r'.
            print(f'\n[WARNING] Skipping image {path_image}: {err}')
            continue
        image_segments.append(image_segment)
    return np.array(image_segments)

def list_to_embeddings(list_embryos_images, model_embedding):
    """Parse the given image paths and embed the resulting array with ``model_embedding``."""
    return tools_embeddings.imgs_to_embeddings(
        model_embedding,
        fn_images_tiff_parse(list_embryos_images),
    )

def fn_cosine_similarity(val_a, val_b):
    """
    Return the cosine similarity of 'val_a' and 'val_b'.

    Both inputs are squeezed first, so singleton dimensions are ignored.
    """
    vec_a, vec_b = np.squeeze(val_a), np.squeeze(val_b)
    dot_product = np.dot(vec_a, vec_b)
    norm_product = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    return dot_product / norm_product

def embryo_reference_similarities(embeddings_reference, embeddings_test):
    """
    Compare every test embedding with every reference embedding.

    For each test embedding one DataFrame is built:
      * row ``j``    -> position ``j`` within the reference sequences,
      * column ``k`` -> reference sequence ``k`` ('Anch_sim_01', 'Anch_sim_02', ...),
      * cell         -> cosine similarity between ``embeddings_reference[k][j]``
                        and the test embedding.

    Args:
        embeddings_reference: sequence of reference embedding sequences. The
            number of rows is taken from the first one, and every sequence is
            indexed up to that length.
        embeddings_test: sequence of test embeddings.

    Returns:
        dict mapping the index of each test embedding to its DataFrame.
    """
    Number_anchors = len(embeddings_reference)
    column_names = ['Anch_sim_{:02d}'.format(i) for i in range(1, Number_anchors + 1)]
    # Loop-invariant: row count comes from the first reference sequence
    num_positions = len(embeddings_reference[0]) if Number_anchors else 0
    num_test = len(embeddings_test)

    similarities = dict()
    for i, embedding_test in enumerate(embeddings_test):
        print(f'[INFO] {str(i + 1).zfill(3)}/{num_test}'.ljust(50), end='\r')
        # Collect all rows first and build the frame once; growing a DataFrame
        # with df.loc[j] = ... copies data on every insertion.
        rows = [
            [fn_cosine_similarity(reference[j], embedding_test)
             for reference in embeddings_reference]
            for j in range(num_positions)
        ]
        similarities[i] = pd.DataFrame(rows, columns=column_names)

    return similarities

def save_sims(sims, dir_dst, **kwargs):
    """
    Write each DataFrame of ``sims`` to its own CSV file inside ``dir_dst``.

    Files are named '<signature>similarities_<key>.csv', where the key is
    zero-padded to three characters and ``signature`` is an optional keyword
    argument (empty by default).
    """
    prefix = kwargs.get('signature', '')
    for key, frame in sims.items():
        print(str(key).ljust(10), end='\r')
        file_name = f'{prefix}similarities_{str(key).zfill(3)}.csv'
        frame.to_csv(f'{dir_dst}/{file_name}')
        
def save_json(data, file_path):
    """Serialize ``data`` as JSON into ``file_path``, replacing any existing file."""
    with open(file_path, 'w') as handle:
        json.dump(data, handle)
 

Calculate reference embeddings

In [None]:
# One set of embeddings per anchor, in the same order as paths_anchor
embeddings_reference = [
    list_to_embeddings(anchor_paths, resnet50) for anchor_paths in paths_anchor
]

# 4. Calculate embeddings and similarities

In [None]:
# Iterate over all subfolders in the root folder
allData_estimated_dev_age = []
allData_experimental_age = []
allData_fileNames = []
allData_CosDist_extimated_age = []

    
for folder_name in sorted(os.listdir(dir_src_imgs_test)):
    # Construct the full path of the subfolder
    subfolder_path = os.path.join(dir_src_imgs_test, folder_name)

    # Check if the current path is a directory
    if os.path.isdir(subfolder_path):
        print("Processing subfolder:", folder_name)
        # Load test image paths
        imgs_src = sorted(glob.glob(f'{subfolder_path}/*.tif'))
        print('***',  len(imgs_src), ' images loaded')
        if abs(24 - (len(imgs_src) * img_time_interval / 3600)) > 0.5:
            t_interval = (24*3600) / len(imgs_src)
            print('!!!!!! Check cosistency time interval expected : ', t_interval)

        # Calculate embedings
        embeddings_test = list_to_embeddings(imgs_src, resnet50)
        
        # Calculate similarities
        similarities_test = embryo_reference_similarities(embeddings_reference, embeddings_test)
        
        # Get median/mean/max similarities from the anchors
        sims_stat = list()

        for sims_test in similarities_test.values():
            # get median of similarities per time point
            sims_stat.append(sims_test.max(axis=1))
         
        df_sims_stat = pd.DataFrame(sims_stat).reset_index(drop=True)
        sims_stat = df_sims_stat.values

        # Get maxima of similarity sequence
        index_ref = np.argmax(sims_stat, axis=1)
        max_values = np.amax(sims_stat, axis=1)

        # Transform indixes to times
        experimental_age  = initial_time + (np.arange(len(imgs_src)) * (img_time_interval / 3600))
        estimated_dev_age = initial_time_reference + (index_ref * (img_time_interval_reference / 3600))
        
        # Store data
        allData_experimental_age.append(experimental_age.tolist())
        allData_estimated_dev_age.append(estimated_dev_age.tolist())
        allData_fileNames.append(folder_name)
        allData_CosDist_extimated_age.append(max_values.tolist())
        
        # plot
        #fig, axs = plt.subplots(dpi=300, figsize=(3,3))
        #plt.plot(experimental_age,estimated_dev_age)
        #plt.plot(experimental_age,experimental_age)
        #plt.xlabel('experimental_age')
        #plt.ylabel('estimated_dev_age')

        
        

In [None]:
# Save Data

# Create the output directory if it is missing; writing the files below
# would otherwise fail on a fresh checkout.
os.makedirs(dir_dst, exist_ok=True)

# One JSON file per result list, each tagged with the dataset suffix
save_json(allData_estimated_dev_age, os.path.join(dir_dst, 'allData_estimated_dev_age_'+sufix_out+'.json'))
save_json(allData_experimental_age, os.path.join(dir_dst, 'allData_experimental_age_'+sufix_out+'.json'))
save_json(allData_fileNames, os.path.join(dir_dst, 'allData_fileNames_'+sufix_out+'.json'))
save_json(allData_CosDist_extimated_age, os.path.join(dir_dst, 'allData_CosDist_extimated_age_'+sufix_out+'.json'))


# 5. Average plots

In [None]:
# Calculate the average and standard deviation per column
# Mean and standard deviation of the estimated age across embryos (axis 0),
# i.e. one value per time point
average = np.mean(allData_estimated_dev_age, axis=0)
std_dev = np.std(allData_estimated_dev_age, axis=0)

# The experimental ages of the first embryo serve as the shared x-axis
x = allData_experimental_age[0]
lower_band = average - std_dev
upper_band = average + std_dev

fig, ax = plt.subplots()

# Average curve, identity line (x against x) and a +/- one standard deviation band
ax.plot(x, average, label='Average')
ax.plot(x, x, label='Reference')
ax.fill_between(x, lower_band, upper_band, alpha=0.2, label='Standard Deviation')

# Axis labels and legend
ax.set_xlabel('Experimental age (hpf)')
ax.set_ylabel('Estimated developmental age (hpf)')
ax.legend()

# Display the plot
plt.show()

In [None]:
# Ratio of the mean estimated developmental age to the experimental age,
# computed per time point
ratio = average / np.array(allData_experimental_age[0])

# Plot the ratio over experimental age
plt.plot(x, ratio, label='Ratio')


# Add labels and title
plt.xlabel('Experimental age (hpf)')
plt.ylabel('Ratio of estimated developmental age to experimental age')

# Add a legend
plt.legend()

# Display the plot
plt.show()