<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Dataset" data-toc-modified-id="Dataset-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Dataset</a></span></li><li><span><a href="#Model" data-toc-modified-id="Model-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Model</a></span></li><li><span><a href="#Data-Augmentation" data-toc-modified-id="Data-Augmentation-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Data Augmentation</a></span></li><li><span><a href="#Input-pipeline" data-toc-modified-id="Input-pipeline-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Input pipeline</a></span></li><li><span><a href="#Train" data-toc-modified-id="Train-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Train</a></span></li><li><span><a href="#Scrap" data-toc-modified-id="Scrap-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Scrap</a></span></li></ul></div>

In [None]:
# Reload imported modules automatically before each cell is executed, so edits
# to local packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Does not seem to work with TF2.0 yet
# %load_ext tensorboard
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import logging
# Only let records of level ERROR and above through the "tensorflow" logger.
logging.getLogger("tensorflow").setLevel(logging.ERROR)

In [None]:
import tensorflow as tf

# Copied from: https://www.tensorflow.org/beta/guide/using_gpu
# Log the device every op is placed on (verbose, but shows which GPU is used).
tf.debugging.set_log_device_placement(True)

# Adapted from: https://www.tensorflow.org/beta/guide/using_gpu
# Index of the physical GPU this notebook should run on. When the machine has
# fewer GPUs than this index requires, the last available GPU is used instead
# of failing with an IndexError.
preferred_gpu_index = 1

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # Restrict TensorFlow to a single GPU
    try:
        for gpu in gpus:
            # Grow GPU memory usage on demand instead of reserving it all upfront
            tf.config.experimental.set_memory_growth(gpu, True)
        visible_gpu = gpus[min(preferred_gpu_index, len(gpus) - 1)]
        tf.config.experimental.set_visible_devices(visible_gpu, 'GPU')
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
    except RuntimeError as e:
        # Visible devices must be set before GPUs have been initialized
        print(e)

## Dataset

In [None]:
from pathlib import Path

from nucleus.dataset.detections import BasketballDetectionsDataset
from nucleus.visualize import BasketballDetectionsLabelColorMap

In [None]:
# Root directory under which the dataset files used by this notebook live.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
cache = Path('/data1/joan/nucleus/dataset_cache/')

In [None]:
# Location of the basketball detections JSON file inside the cache directory.
detections_json = cache / 'basketball_detections' / 'basketball_detections.json'
dataset = BasketballDetectionsDataset.load(path=detections_json)

In [None]:
# Keyword arguments handed to `view` as `box_args`; reused by the
# visualization cell at the end of the notebook.
box_args = {'label_color_map': BasketballDetectionsLabelColorMap}

# Display the first image of the dataset.
img = dataset.images[0]
img.view(box_args=box_args)

## Model

In [None]:
from nucleus.detection import *

In [None]:
# Manager for the MobileNetV2 backbone: it creates the backbone model and
# supplies `custom_objects` when the detector is loaded further down.
backbone_manager = MobileNetV2Manager()

In [None]:
# Create the backbone model through its manager.
# NOTE(review): (None, None, 3) presumably means variable height/width with 3
# channels, and alpha is presumably the MobileNetV2 width multiplier — confirm
# against MobileNetV2Manager.create_model.
# NOTE(review): `backbone` is not referenced by the remaining visible cells.
backbone = backbone_manager.create_model(
    input_shape=(None, None, 3),
    alpha=0.35
)

In [None]:
# Manager for the YOLO detector: used to load the detector and to build the
# inference model from it.
detector_manager = YoloManager()

In [None]:
# Load a saved detector through its manager.
# `custom_objects` comes from the backbone manager — presumably so the custom
# backbone pieces can be deserialized; TODO confirm.
# NOTE(review): no path is passed here, so load_model has to resolve the
# model location on its own — verify where it reads from.
detector = detector_manager.load_model(
    save_format='tf',
    custom_objects=backbone_manager.custom_objects
)

## Scrap

In [None]:
from nucleus.image import *
from nucleus.box import *

In [None]:
batch_size = 6

# Settings of the validation input pipeline, gathered in one place and
# expanded into the `get_ds` call below.
val_ds_options = dict(
    partition='val',
    n_examples=10,
    shuffle=None,
    repeat=1,
    batch=batch_size,
)
ds_val = dataset.get_ds(**val_ds_options)

In [None]:
# Pull one batch from the validation pipeline and print its shape and dtype.
# The loop variable `images` stays bound after the loop; the following cells
# run inference on that batch.
for images, _ in ds_val.take(1):
    print(images.shape, images.dtype)

In [None]:
# Build the inference model from the loaded detector.
# NOTE(review): the thresholds presumably discard detections scoring below 0.5
# and apply non-max suppression at an IoU of 0.25 — confirm in
# YoloManager.create_inference_model.
inference_detector = detector_manager.create_inference_model(
    model=detector,
    score_threshold=0.5,
    nms_iou_threshold=0.25
)

In [None]:
# Run the inference model on the batch fetched earlier and convert the
# prediction output into a TensorFlow tensor.
detections = tf.convert_to_tensor(inference_detector.predict(images))

In [None]:
# Show every image of the batch together with its predicted boxes.
# Iterate over the number of predictions actually present rather than
# `batch_size`: a batch can hold fewer than `batch_size` images (e.g. the last
# batch of the dataset), and indexing up to `batch_size` would then go out of
# range.
num_images = int(detections.shape[0])

for i in range(num_images):
    # NOTE(review): unpad_tensor presumably strips the padding rows added so
    # that per-image detections can be stacked into one batch — confirm.
    box_collection = BoxCollection.from_tensor(
        tensor=unpad_tensor(detections[i]),
        unique_labels=dataset.unique_boxes_labels
    )
    img = Image.from_hwc(hwc=images[i], box_collection=box_collection)

    # The index within the batch doubles as the figure id.
    img.view(figure_id=i, box_args=box_args)