# **Generic Python code to export trained StarDist models and use them with the DeepImageJ plugin**

https://deepimagej.github.io/deepimagej/index.html


If you are using Google Colab, mount your Google Drive. Otherwise, skip this step.

In [0]:
from google.colab import drive
drive.mount('/content/drive')


Install the following packages:
- A version of TensorFlow compatible with this export (<= 1.13). Here we used TensorFlow 1.13.1.
- The `stardist` Python package. Here we used StarDist 0.3.6.
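The export further down relies on the TensorFlow 1.x `tf.saved_model.builder` API, so it can help to confirm that the installed TensorFlow satisfies the `<= 1.13` constraint before going further. The helper `is_deepimagej_compatible` below is hypothetical (it is not part of TensorFlow or StarDist); it only compares the major and minor version numbers, so `1.13.1` counts as compatible.

```python
import re


def is_deepimagej_compatible(version, max_major_minor=(1, 13)):
    """Return True if a TensorFlow version string is <= max_major_minor.

    Hypothetical helper: only the major and minor numbers are compared,
    so "1.13.1" satisfies the "<= 1.13" requirement.
    """
    match = re.match(r"^(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    major_minor = (int(match.group(1)), int(match.group(2)))
    # Tuples compare numerically element by element, so (1, 2) < (1, 13)
    return major_minor <= tuple(max_major_minor)


print(is_deepimagej_compatible("1.13.1"))  # True
print(is_deepimagej_compatible("2.0.0"))   # False
```

In the notebook it would be called as `is_deepimagej_compatible(tf.__version__)` after `import tensorflow as tf`.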

In [0]:
!pip install tensorflow==1.13.1
!pip install stardist==0.3.6

# Load your trained StarDist model

Check the input and output sizes of your model. They differ when the `grid` parameter of the model is not `(1,1)`, because StarDist then predicts on a grid subsampled by that factor. An output size that differs from the input size can lead to errors in DeepImageJ, so keep it in mind as well if you want to perform shape measurements on the output image.
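As a rough guide, StarDist predicts one object-probability channel and `n_rays` radial-distance channels, each subsampled along the spatial axes by the `grid` factor. The sketch below computes the expected shape of the concatenated output exported later in this notebook. The function name is hypothetical, it assumes the input size is divisible by the grid factor, and `n_rays=32` is only the StarDist default; after loading the model, the actual values can be read from its configuration (for example `model_paper.config.n_rays`; attribute names may depend on your StarDist version).

```python
def expected_output_shape(input_hw, grid=(1, 1), n_rays=32):
    """Shape (H, W, C) of the concatenated StarDist output for one 2D image.

    Hypothetical helper. Assumes:
      * each spatial axis of the output is the input size divided by `grid`;
      * the exported output stacks 1 probability channel and `n_rays`
        distance channels.
    """
    out_hw = []
    for size, g in zip(input_hw, grid):
        if size % g != 0:
            raise ValueError(f"Input size {size} is not divisible by grid factor {g}")
        out_hw.append(size // g)
    return (out_hw[0], out_hw[1], 1 + n_rays)


print(expected_output_shape((256, 256), grid=(1, 1)))  # (256, 256, 33)
print(expected_output_shape((256, 256), grid=(2, 2)))  # (128, 128, 33)
```

With `grid=(2, 2)` the output image is half the size of the input along each axis, which is the mismatch described above.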

In [0]:
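from stardist.models import StarDist2D
# Without shape completion
# config=None: load the configuration saved with the model.
# `basedir` is the folder that contains the model folder; `name` is the model folder itself.
model_paper = StarDist2D(None, name='name_of_your_model', basedir='/content/drive/My Drive/the_path_to_your_model/folder_containing_the_model')
# Indicate which weights you want to use (e.g. 'weights_best.h5' or 'weights_last.h5')
model_paper.load_weights('weights_best.h5')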
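`StarDist2D(None, name=..., basedir=...)` looks for the model in the folder `basedir/name`, which normally holds the `config.json` written at training time together with weight files such as `weights_best.h5`. If loading fails, a quick standard-library check like the one below shows which of these files are missing. `missing_model_files` is a hypothetical helper; it only checks that the files exist, not that their contents are valid.

```python
from pathlib import Path


def missing_model_files(basedir, name, weights="weights_best.h5"):
    """List the expected files that are absent from basedir/name.

    Hypothetical helper: checks for the configuration file and for the
    weights file you intend to pass to `load_weights`.
    """
    model_dir = Path(basedir) / name
    expected = ["config.json", weights]
    return [f for f in expected if not (model_dir / f).is_file()]
```

Call it with the same `basedir` and `name` values passed to `StarDist2D`; an empty list means both files were found.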

# Save as a TensorFlow SavedModel
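In TensorFlow 1.x, `SavedModelBuilder` refuses to write into an export directory that already exists (it raises an `AssertionError`), so re-running the export cell with the same `OUTPUT_DIR` fails. One way around this, sketched with the hypothetical helper `unused_export_dir`, is to append a numeric suffix until a free folder name is found and use the result as `OUTPUT_DIR` before creating the builder.

```python
import os


def unused_export_dir(path):
    """Return `path`, or `path_1`, `path_2`, ... if that folder already exists.

    Hypothetical helper: it does not create the folder, since
    SavedModelBuilder creates the export directory itself.
    """
    path = os.path.normpath(path)  # drop any trailing separator
    candidate, i = path, 1
    while os.path.exists(candidate):
        candidate = f"{path}_{i}"
        i += 1
    return candidate
```

Usage: `OUTPUT_DIR = unused_export_dir(OUTPUT_DIR)` right before `tf.saved_model.builder.SavedModelBuilder(OUTPUT_DIR)`.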

In [0]:
import keras
import keras.backend as K
from keras.layers import concatenate
import tensorflow as tf

# Write the path where you would like to save the model.
# The code will automatically create a new folder called "new_folder", where the
# TensorFlow model will be saved. This folder must not exist yet.
OUTPUT_DIR = "/content/drive/My Drive/the_path_where_you_want_to_save_your_model/new_folder"
builder = tf.saved_model.builder.SavedModelBuilder(OUTPUT_DIR)

# StarDist has two outputs (object probability and radial distances).
# DeepImageJ can only read one output, so we concatenate them as different
# channels in order to use them in ImageJ.
signature = tf.saved_model.signature_def_utils.predict_signature_def(
    # `inputs` is always a list of input tensors; take the image input
    inputs={'input': model_paper.keras_model.inputs[0]},
    # concatenate the two StarDist outputs along the channel axis
    outputs={'output': concatenate([model_paper.keras_model.output[0],
                                    model_paper.keras_model.output[1]], axis=3)})
signature_def_map = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}

builder.add_meta_graph_and_variables(K.get_session(), [tf.saved_model.tag_constants.SERVING],
                                     signature_def_map=signature_def_map)
builder.save()
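In the exported model, the single `output` tensor stacks the two StarDist outputs along the channel axis: channel 0 holds the object probability and the remaining `n_rays` channels hold the radial distances, since the probability output is listed first in the concatenation. When post-processing the image that DeepImageJ returns, the channels can be separated again. The sketch below does this with plain nested lists (H x W x channels); the function name is hypothetical and it assumes the `[probability, distances...]` ordering just described.

```python
def split_stardist_output(image):
    """Split an H x W x (1 + n_rays) nested list into probability and distances.

    Assumes channel 0 is the object probability and channels 1..n_rays are
    the radial distances.
    """
    prob = [[pixel[0] for pixel in row] for row in image]
    dist = [[pixel[1:] for pixel in row] for row in image]
    return prob, dist


# Tiny 1 x 2 example with n_rays = 3 (made-up values, for illustration only)
demo = [[[0.9, 5.0, 4.0, 6.0], [0.1, 0.5, 0.2, 0.3]]]
prob, dist = split_stardist_output(demo)
print(prob)  # [[0.9, 0.1]]
print(dist)  # [[[5.0, 4.0, 6.0], [0.5, 0.2, 0.3]]]
```

With a NumPy array `out` of shape `(H, W, 1 + n_rays)`, the equivalent is `out[..., 0]` for the probability map and `out[..., 1:]` for the distances.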