<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Model" data-toc-modified-id="Model-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Model</a></span><ul class="toc-item"><li><span><a href="#Save-model" data-toc-modified-id="Save-model-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Save model</a></span></li><li><span><a href="#Save-architecture" data-toc-modified-id="Save-architecture-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>Save architecture</a></span></li><li><span><a href="#Save-weights" data-toc-modified-id="Save-weights-1.3"><span class="toc-item-num">1.3&nbsp;&nbsp;</span>Save weights</a></span></li></ul></li><li><span><a href="#Inference-Model" data-toc-modified-id="Inference-Model-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Inference Model</a></span><ul class="toc-item"><li><span><a href="#Save-model" data-toc-modified-id="Save-model-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Save model</a></span></li><li><span><a href="#Save-architecture" data-toc-modified-id="Save-architecture-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Save architecture</a></span></li><li><span><a href="#Save-weights" data-toc-modified-id="Save-weights-2.3"><span class="toc-item-num">2.3&nbsp;&nbsp;</span>Save weights</a></span></li></ul></li></ul></div>

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import tensorflow as tf

In [None]:
from nucleus.detection import *
from nucleus.detection.backbones.managers import *

## Model

In [None]:
# Build the feature-extraction backbone through its manager.
# alpha=0.25 is forwarded to create_model; presumably the MobileNet width
# multiplier (smaller = narrower network) — confirm against MobileNetManager.
backbone_manager = MobileNetManager()
backbone = backbone_manager.create_model(alpha=0.25)

In [None]:
# Disabled alternative, kept for reference only: use the stock Keras MobileNet
# (ImageNet weights, no classification head, variable input size) as the
# backbone instead of the one created by MobileNetManager above.
# backbone = tf.keras.applications.MobileNet(
#     include_top=False,
#     weights='imagenet',
#     input_shape=(None, None, 3)
# )

In [None]:
# Assemble the YOLO detector on top of the backbone created above,
# configured for four classes.
detector_manager = YoloManager()
detector = detector_manager.create_model(n_classes=4, backbone=backbone)

In [None]:
# Print the layer-by-layer summary of the detector; line_length=117 sets the
# width of the printed table (assumes `detector` is a tf.keras model).
detector.summary(line_length=117)

In [None]:
# Draw the detector graph with tensor shapes and layer names, expanding
# nested sub-models so their inner layers are visible too.
tf.keras.utils.plot_model(
    detector, expand_nested=True, show_layer_names=True, show_shapes=True
)

### Save model

In [None]:
# Persist the complete training detector through the manager, replacing any
# previous export. The backbone's custom objects are passed along with the model.
detector_manager.save_model(
    overwrite=True,
    save_format='tf',
    custom_objects=backbone_manager.custom_objects,
    model=detector,
)

In [None]:
# Round-trip check: reload the detector from the export written above,
# supplying the same backbone custom objects used when saving.
detector = detector_manager.load_model(
    custom_objects=backbone_manager.custom_objects, save_format='tf'
)

### Save architecture

In [None]:
# Persist only the architecture of the training detector; the destination is
# chosen by the manager and is not visible in this notebook.
detector_manager.save_model_arch(model=detector)

In [None]:
# Rebuild the detector from its saved architecture. The backbone's custom
# objects are supplied for the load.
# NOTE(review): presumably this yields a model without trained weights — confirm.
detector = detector_manager.load_model_arch(custom_objects=backbone_manager.custom_objects)

### Save weights

In [None]:
# Persist only the weights of the training detector via the manager.
detector_manager.save_model_weights(model=detector)

In [None]:
# Load saved weights into the existing detector; `detector` is rebound to
# whatever the manager returns (presumably the same model — confirm).
detector = detector_manager.load_model_weights(model=detector)

## Inference Model

In [None]:
# Create the matcher associated with this detector manager.
# NOTE(review): `matcher` is not referenced by any later cell visible in this
# notebook — confirm whether it is still needed here.
matcher = detector_manager.create_matcher()

In [None]:
# Derive the inference variant of the detector from the training model.
# Presumably this appends YoloInferenceLayer post-processing (that layer is
# required as a custom object when saving/loading below) — confirm.
inference_detector = detector_manager.create_inference_model(model=detector)

In [None]:
# Print the layer-by-layer summary of the inference model; line_length=117
# sets the width of the printed table.
inference_detector.summary(line_length=117)

In [None]:
# Draw the inference graph with tensor shapes and layer names, expanding
# nested sub-models.
# `to_file` is set explicitly because plot_model writes to 'model.png' by
# default, which would overwrite the diagram of the training detector that
# was rendered earlier in this notebook.
tf.keras.utils.plot_model(
    model=inference_detector,
    to_file='inference_detector.png',
    show_shapes=True,
    show_layer_names=True,
    expand_nested=True
)

### Save model

In [None]:
from nucleus.detection.layers import YoloInferenceLayer

In [None]:
# Persist the complete inference model, replacing any previous export.
# The custom objects are the backbone's plus YoloInferenceLayer, which only
# the inference model needs.
detector_manager.save_model(
    custom_objects={
        **backbone_manager.custom_objects,
        'YoloInferenceLayer': YoloInferenceLayer,
    },
    save_format='tf',
    overwrite=True,
    model=inference_detector,
)

In [None]:
# Round-trip check: reload the inference model (inference=True) with the same
# custom objects that were supplied when it was saved.
inference_detector = detector_manager.load_model(
    custom_objects={
        **backbone_manager.custom_objects,
        'YoloInferenceLayer': YoloInferenceLayer,
    },
    save_format='tf',
    inference=True,
)

### Save architecture

In [None]:
# Persist only the architecture of the inference model.
# NOTE(review): no `inference` flag is passed here, although the matching
# load call uses inference=True — confirm the manager does not write over the
# training model's architecture file.
detector_manager.save_model_arch(model=inference_detector)

In [None]:
# Rebuild the inference model from its saved architecture (inference=True),
# supplying the backbone's custom objects plus YoloInferenceLayer.
inference_detector = detector_manager.load_model_arch(
    custom_objects={
        **backbone_manager.custom_objects,
        'YoloInferenceLayer': YoloInferenceLayer,
    },
    inference=True,
)

### Save weights

In [None]:
# Persist only the weights of the inference model via the manager.
# NOTE(review): no `inference` flag is passed — confirm the destination does
# not collide with the training model's weights file.
detector_manager.save_model_weights(model=inference_detector)

In [None]:
# Load saved weights into the existing inference model; the name is rebound
# to whatever the manager returns (presumably the same model — confirm).
inference_detector = detector_manager.load_model_weights(model=inference_detector)