In [22]:
import tensorflow as tf
from pathlib import Path
import sleap
import json
import numpy as np
from sleap.nn.inference import (
    CentroidCrop,
    CentroidInferenceModel,
    TopDownInferenceModel,
    FindInstancePeaks,
    TopDownMultiClassFindPeaks,
    TopDownMultiClassInferenceModel,
    SingleInstanceInferenceModel,
    SingleInstanceInferenceLayer
)



In [12]:
import ast
def _parse_tensor_info(tensor_info):
    """Split one tensor description string from ``info.json`` into (name, shape, dtype).

    The expected layout (``Tensor("<name>", shape=(...), dtype=<dtype>)``) is the one
    implied by the string slicing below.
    """
    name = tensor_info.split("Tensor(")[1].split(", shape")[0].replace('"', "")
    shape = ast.literal_eval(tensor_info.split("shape=", 1)[1].split("), ")[0] + ")")
    dtype = tensor_info.split("dtype=")[1].split(")")[0]
    return name, shape, dtype


def export_frozen_graph(model, preds, output_path):
    """Export a SLEAP inference model as a frozen graph and verify the export.

    Calls ``model.export_model`` to write the model into ``output_path``. It then
    reloads ``frozen_graph.pb`` and checks that every input/output tensor listed in
    ``info.json`` exists in the reloaded graph with the recorded shape and dtype.

    Args:
        model: SLEAP inference model exposing ``export_model(path, tensors=...)``.
        preds: Dict of example predictions (e.g. from ``model.predict``). The type,
            shape, dtype and device of each value are passed to ``export_model`` as
            the ``tensors`` metadata.
        output_path: Destination directory, as a ``pathlib.Path`` or ``str``.

    Raises:
        AssertionError: If a tensor recorded in ``info.json`` differs in shape or
            dtype from the same tensor in the reloaded frozen graph.
    """
    tensors = {}
    for key, val in preds.items():
        # numpy dtypes are stored via str(); any other dtype object via repr().
        dtype = str(val.dtype) if isinstance(val.dtype, np.dtype) else repr(val.dtype)
        tensors[key] = {
            "type": f"{type(val).__name__}",
            "shape": f"{val.shape}",
            "dtype": dtype,
            "device": f"{val.device if hasattr(val, 'device') else 'N/A'}",
        }

    # Using a Path as a context manager was a deprecated no-op and is removed in
    # Python 3.13, so resolve the directory directly (also accepts plain strings).
    export_dir = Path(output_path)
    model.export_model(export_dir.as_posix(), tensors=tensors)

    # Reload the frozen graph into a fresh graph for verification.
    tf.compat.v1.reset_default_graph()
    with tf.compat.v2.io.gfile.GFile(str(export_dir / "frozen_graph.pb"), "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # import_graph_def prefixes node names with "import/" by default.
        tf.import_graph_def(graph_def)

    with open(export_dir / "info.json") as json_file:
        info = json.load(json_file)

    for tensor_info in info["frozen_model_inputs"] + info["frozen_model_outputs"]:
        saved_name, saved_shape, saved_dtype = _parse_tensor_info(tensor_info)

        loaded = graph.get_tensor_by_name(f"import/{saved_name}")
        loaded_shape = tuple(loaded.shape)
        loaded_dtype = loaded.dtype.name

        assert saved_shape == loaded_shape, (
            f"{saved_name}: shape {saved_shape} in info.json != {loaded_shape} in frozen graph"
        )
        assert saved_dtype == loaded_dtype, (
            f"{saved_name}: dtype {saved_dtype} in info.json != {loaded_dtype} in frozen graph"
        )

In [13]:
# Export the full top-down network with identity (centroid crop + multi-class instance model).

runs_folder = r"C:\Users\neurogears\Desktop\Sleap_demo\models"
id_model = tf.keras.models.load_model(
    str(Path(runs_folder) / "221027_multiclass.topdown" / "best_model.h5"), compile=False
)

centroid_model_path = str(Path(runs_folder) / "221027_104636.centroid.n=88" / "best_model.h5")
centroid_model = tf.keras.models.load_model(centroid_model_path, compile=False)

# Make sure crop_size matches the input size expected by the ID (multi-class) model.
centroid = CentroidCrop(keras_model=centroid_model, crop_size=96, input_scale=0.5)

instance_peaks = TopDownMultiClassFindPeaks(keras_model=id_model, return_class_vectors=True)
model = TopDownMultiClassInferenceModel(centroid, instance_peaks)

# Dummy batch used to trace the model before export.
# Shape is presumably (batch, height, width, channels) = 4 frames of 1088x1440, 1 channel.
preds = model.predict(np.zeros((4, 1088, 1440, 1), dtype="uint8"))

export_frozen_graph(model, preds, Path(runs_folder) / "BonsaiModels" / "topdown_id")



INFO:tensorflow:Assets written to: C:\Users\NEUROG~1\AppData\Local\Temp\tmpyzx4k2vb\assets
INFO:tensorflow:Assets written to: C:\Users\NEUROG~1\AppData\Local\Temp\tmpyzx4k2vb\assets


In [16]:
# Export the full top-down network without identity (centroid crop + centered-instance model).

runs_folder = r"C:\Users\neurogears\Desktop\Sleap_demo\models"
pose_model = tf.keras.models.load_model(
    str(Path(runs_folder) / "221027_111451.centered_instance.n=88" / "best_model.h5"),
    compile=False,
)

centroid_model_path = str(Path(runs_folder) / "221027_104636.centroid.n=88" / "best_model.h5")
centroid_model = tf.keras.models.load_model(centroid_model_path, compile=False)

# Make sure crop_size matches the input size expected by the centered-instance
# model (pose_model), which consumes the crops in the second stage.
centroid = CentroidCrop(keras_model=centroid_model, crop_size=96, input_scale=0.5)

instance_peaks = FindInstancePeaks(keras_model=pose_model)
model = TopDownInferenceModel(centroid, instance_peaks)

# Dummy batch used to trace the model before export.
# Shape is presumably (batch, height, width, channels) = 4 frames of 1088x1440, 1 channel.
preds = model.predict(np.zeros((4, 1088, 1440, 1), dtype="uint8"))

export_frozen_graph(model, preds, Path(runs_folder) / "BonsaiModels" / "topdown")


INFO:tensorflow:Assets written to: C:\Users\NEUROG~1\AppData\Local\Temp\tmpdx3gf39e\assets
INFO:tensorflow:Assets written to: C:\Users\NEUROG~1\AppData\Local\Temp\tmpdx3gf39e\assets


In [20]:
## Export a centroid-only model

runs_folder = r"C:\Users\neurogears\Desktop\Sleap_demo\models"

centroid_model_path = str(Path(runs_folder) / "221027_104636.centroid.n=88" / "best_model.h5")
centroid_model = tf.keras.models.load_model(centroid_model_path, compile=False)

# Only centroids are exported here: return_crops=False and there is no
# second-stage model in this cell consuming the crops.
centroid = CentroidCrop(
    keras_model=centroid_model, crop_size=96, input_scale=0.5, return_crops=False
)

model = CentroidInferenceModel(centroid)

# Dummy batch used to trace the model before export.
# Shape is presumably (batch, height, width, channels) = 4 frames of 1088x1440, 1 channel.
preds = model.predict(np.zeros((4, 1088, 1440, 1), dtype="uint8"))

export_frozen_graph(model, preds, Path(runs_folder) / "BonsaiModels" / "centroid")


INFO:tensorflow:Assets written to: C:\Users\NEUROG~1\AppData\Local\Temp\tmp94q1ut5d\assets
INFO:tensorflow:Assets written to: C:\Users\NEUROG~1\AppData\Local\Temp\tmp94q1ut5d\assets


In [23]:
# Single-instance model (for now it runs on the full frame rather than on crops).

runs_folder = r"C:\Users\neurogears\Desktop\Sleap_demo\models"

single_instance_model_path = str(
    Path(runs_folder) / "221027_092218.single_instance.n=88" / "best_model.h5"
)
single_instance_model = tf.keras.models.load_model(single_instance_model_path, compile=False)

model = SingleInstanceInferenceModel(
    SingleInstanceInferenceLayer(keras_model=single_instance_model)
)

# Dummy batch used to trace the model before export.
# Shape is presumably (batch, height, width, channels) = 4 frames of 1088x1440, 1 channel.
preds = model.predict(np.zeros((4, 1088, 1440, 1), dtype="uint8"))

export_frozen_graph(model, preds, Path(runs_folder) / "BonsaiModels" / "SingleInstance")


INFO:tensorflow:Assets written to: C:\Users\NEUROG~1\AppData\Local\Temp\tmp9fd6v9uj\assets
INFO:tensorflow:Assets written to: C:\Users\NEUROG~1\AppData\Local\Temp\tmp9fd6v9uj\assets


In [None]:
# Customized model (Not supported in the package yet....)
from test_run import *
import test_utils
from pathlib import Path
from sleap.nn.peak_finding import find_local_peaks, find_global_peaks, make_centered_bboxes, crop_bboxes
from sleap.nn.data.utils import  describe_tensors

# NOTE(review): `run_name` is not defined anywhere else in this notebook; it presumably
# comes from `from test_run import *` above (which may also rebind `runs_folder`) — confirm.
id_model = tf.keras.models.load_model(
    str(Path(runs_folder) / run_name / "best_model.h5"), compile=False
)


# Define the full model.
# Named differently from sleap's `TopDownInferenceModel` (imported at the top of the
# notebook) so that running this cell does not shadow the library class used by the
# "topdown without ID" export cell.
class CustomTopDownIdInferenceModel(tf.keras.Model):
    """Second-stage wrapper around a top-down multi-class (ID) Keras model.

    Takes image crops, runs ``this_model`` on them and converts the confidence maps
    into keypoint coordinates, peak values and class probabilities.

    Args:
        this_model: Keras model returning ``(confidence_maps, class_probabilities)``.
        td_output_stride: Factor by which peak coordinates are scaled back up to
            input-crop coordinates.
        input_size: ``(height, width)`` the crops are resized to before the forward
            pass. Defaults to ``(96, 96)``, the size previously hardcoded here.
        peak_threshold: Minimum confidence passed to ``find_global_peaks``.
            Defaults to ``0.2``, the value previously hardcoded here.
    """

    def __init__(
        self,
        this_model,
        td_output_stride,
        input_size=(96, 96),
        peak_threshold=0.2,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.this_model = this_model
        self.td_output_stride = td_output_stride
        # Stored as plain ints: Keras wraps list/tuple attributes in tracking wrappers.
        self.input_height, self.input_width = (int(v) for v in input_size)
        self.peak_threshold = float(peak_threshold)

    @tf.function
    def stage2(self, im):
        """Stage 2: predict the pose in each crop."""
        # Preprocessing: scale uint8 pixels to [0, 1] and resize to the network input.
        X = tf.cast(im, tf.float32) / 255
        X = tf.image.resize(X, [self.input_height, self.input_width])

        # Forward pass
        cms, class_probs = self.this_model(X)

        # Find keypoints in each crop
        #     pts: (n_centroids, n_nodes, 2)
        #     vals: (n_centroids, n_nodes)
        pts, vals = find_global_peaks(
            cms, threshold=self.peak_threshold, refinement="integral"
        )

        # Adjust coordinates for output stride
        pts = pts * self.td_output_stride

        return {"instance_peaks": pts, "instance_peak_vals": vals, "class_probabilities": class_probs}

    def call(self, imgs):
        return self.stage2(imgs)


td_output_stride = 4
td_input_size = [96, 96]  # Hardcoded in the network

inference_model = CustomTopDownIdInferenceModel(
    id_model, td_output_stride, input_size=td_input_size
)

# Dummy batch of 4 crops at the network input size, used to trace the model.
preds = inference_model.predict(np.zeros((4, *td_input_size, 1), dtype="uint8"))
describe_tensors(preds)

inference_model.save(
    str(Path(r"C:\Users\neurogears\Desktop\SLEAPNetwork_v1\Bonsai\customNetwork") / run_name),
    save_format="tf", save_traces=True)
