<a href="https://colab.research.google.com/github/mprksa/blocks/blob/main/object_detector_training100.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/mprksa/blocks.git

In [None]:
!python --version
!pip install --upgrade pip
!pip install mediapipe-model-maker

In [None]:
from google.colab import files
import os
import json
import tensorflow as tf
assert tf.__version__.startswith('2')

from mediapipe_model_maker import object_detector

# **Prepare data**

In [None]:
# Both splits live inside the repository cloned above; each split folder is
# read later as a COCO folder (labels.json + images/).
train_dataset_path, validation_dataset_path = "blocks/train", "blocks/validation"

# **Review dataset**

In [None]:
# Print the "id: name" mapping of every category declared in the training
# split's labels.json.
labels_path = os.path.join(train_dataset_path, "labels.json")
with open(labels_path, "r") as f:
  labels_json = json.load(f)
for category_item in labels_json["categories"]:
  category_id, category_name = category_item["id"], category_item["name"]
  print(f"{category_id}: {category_name}")

In [None]:
#@title Visualize the training dataset
import matplotlib.pyplot as plt
from matplotlib import patches, text, patheffects
from collections import defaultdict
import math

def draw_outline(obj):
  """Give a matplotlib artist a 4-pt black stroke so it stays readable on any image."""
  black_stroke = patheffects.Stroke(linewidth=4, foreground='black')
  obj.set_path_effects([black_stroke, patheffects.Normal()])
def draw_box(ax, bb):
  """Draw box `bb` = [x, y, width, height] on `ax` as an outlined, unfilled red rectangle."""
  x, y, width, height = bb[0], bb[1], bb[2], bb[3]
  rectangle = patches.Rectangle((x, y), width, height, fill=False, edgecolor='red', lw=2)
  draw_outline(ax.add_patch(rectangle))
def draw_text(ax, bb, txt, disp):
  """Write `txt` in bold white, outlined, near the top-left corner of box `bb`.

  The text is anchored by its top edge at x = bb[0] and y = bb[1] - disp
  (data coordinates).
  """
  # Named `label` so the local does not shadow the imported matplotlib `text` module.
  label = ax.text(
      bb[0], bb[1] - disp, txt,
      verticalalignment='top', color='white', fontsize=10, weight='bold')
  draw_outline(label)
def draw_bbox(ax, annotations_list, id_to_label, image_shape):
  """Draw each annotation's box and category name on `ax`.

  Args:
    ax: matplotlib Axes showing the image.
    annotations_list: COCO annotation dicts with "category_id" and "bbox".
    id_to_label: mapping from category id to display name.
    image_shape: shape of the displayed image; labels are lifted by 5% of
      image_shape[0].
  """
  label_lift = image_shape[0] * 0.05
  for annotation in annotations_list:
    cat_id = annotation["category_id"]
    box = annotation["bbox"]
    draw_box(ax, box)
    draw_text(ax, box, id_to_label[cat_id], label_lift)
def visualize(dataset_folder, max_examples=None):
  """Show annotated images from a COCO-format dataset folder in a 3-column grid.

  Args:
    dataset_folder: directory containing `labels.json` and an `images/`
      subfolder holding the files named in `labels.json["images"]`.
    max_examples: maximum number of annotated images to show. `None` (or a
      value larger than the number of annotated images) shows all of them;
      0 or less shows nothing.
  """
  with open(os.path.join(dataset_folder, "labels.json"), "r") as f:
    labels_json = json.load(f)
  # COCO image ids are arbitrary integers, not positions in the "images" list,
  # so resolve file names through an explicit id -> file_name map.
  image_id_to_file = {item["id"]: item["file_name"] for item in labels_json["images"]}
  cat_id_to_label = {item["id"]: item["name"] for item in labels_json["categories"]}
  image_annots = defaultdict(list)
  for annotation_obj in labels_json["annotations"]:
    image_annots[annotation_obj["image_id"]].append(annotation_obj)

  annotated_images = list(image_annots.items())
  # Never allocate rows for images that do not exist.
  if max_examples is None or max_examples > len(annotated_images):
    max_examples = len(annotated_images)
  if max_examples <= 0:
    return

  n_cols = 3
  n_rows = math.ceil(max_examples / n_cols)
  # 8x8 inches per image; squeeze=False keeps `axs` 2-D even for a single row,
  # so the [row, col] indexing below always works.
  _, axs = plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 8 * n_rows), squeeze=False)
  for ind, (image_id, annotations_list) in enumerate(annotated_images[:max_examples]):
    ax = axs[ind // n_cols, ind % n_cols]
    img = plt.imread(os.path.join(dataset_folder, "images", image_id_to_file[image_id]))
    ax.imshow(img)
    draw_bbox(ax, annotations_list, cat_id_to_label, img.shape)
  # Blank out the unused cells of the last row.
  for ax in axs.flat[max_examples:]:
    ax.axis("off")
  plt.show()

# Preview five annotated training images with their boxes and labels.
visualize(train_dataset_path, max_examples=5)

In [None]:
# Load both COCO-format splits; each gets its own cache_dir under /tmp/od_data.
train_data = object_detector.Dataset.from_coco_folder(
    train_dataset_path, cache_dir="/tmp/od_data/train")
validation_data = object_detector.Dataset.from_coco_folder(
    validation_dataset_path, cache_dir="/tmp/od_data/validation")
for split_name, split in (("train_data", train_data), ("validation_data", validation_data)):
  print(f"{split_name} size: ", split.size)

# **Training**

In [None]:
# Architecture: MobileNet MultiHW AVG with 384x384 input (see the benchmark table).
spec = object_detector.SupportedModels.MOBILENET_MULTI_AVG_I384
# Only the export directory is set explicitly; all other HParams keep their defaults.
hparams = object_detector.HParams(export_dir='exported_model')
options = object_detector.ObjectDetectorOptions(supported_model=spec, hparams=hparams)

In [None]:
# Train the detector on the training split; the validation split is also
# passed in so validation losses are reported during training.
model = object_detector.ObjectDetector.create(
    options=options,
    train_data=train_data,
    validation_data=validation_data,
)

# **Evaluate**

In [None]:
# Score the trained float model on the validation split.
loss, coco_metrics = model.evaluate(validation_data, batch_size=4)
for caption, value in (("Validation loss", loss), ("Validation coco metrics", coco_metrics)):
  print(f"{caption}: {value}")

# **Training and validation loss curves**

In [None]:
import matplotlib.pyplot as plt

# Data
# NOTE(review): these per-epoch losses are hard-coded (presumably copied from
# the training log of the ObjectDetector.create cell). They do not update when
# the model is retrained; re-copy them after any new training run.
train_total_loss = [7.3300, 1.0832, 0.6793, 0.4450, 0.3384, 0.2700, 0.2394, 0.2079, 0.1880, 0.1742, 0.1796, 0.1546, 0.1492, 0.1481, 0.1330, 0.1444, 0.1419, 0.1366, 0.1182, 0.1124, 0.1199, 0.1166, 0.1286, 0.1090, 0.1015, 0.0991, 0.1097, 0.0965, 0.1089, 0.0979]
val_total_loss = [1.2431, 0.8980, 0.5816, 0.4320, 0.3684, 0.3062, 0.2955, 0.2812, 0.2674, 0.2656, 0.2597, 0.2618, 0.2617, 0.2554, 0.2498, 0.2590, 0.2549, 0.2520, 0.2564, 0.2615, 0.2498, 0.2550, 0.2665, 0.2611, 0.2585, 0.2540, 0.2546, 0.2583, 0.2738, 0.2616]

train_cls_loss = [7.1613, 0.9615, 0.5766, 0.3548, 0.2554, 0.1912, 0.1625, 0.1342, 0.1156, 0.1029, 0.1078, 0.0851, 0.0803, 0.0788, 0.0653, 0.0751, 0.0721, 0.0670, 0.0515, 0.0456, 0.0521, 0.0489, 0.0591, 0.0420, 0.0354, 0.0318, 0.0412, 0.0296, 0.0405, 0.0306]
val_cls_loss = [1.1204, 0.7906, 0.4824, 0.3346, 0.2668, 0.2165, 0.2068, 0.1901, 0.1762, 0.1767, 0.1719, 0.1733, 0.1699, 0.1677, 0.1619, 0.1695, 0.1664, 0.1626, 0.1645, 0.1646, 0.1612, 0.1648, 0.1752, 0.1698, 0.1719, 0.1679, 0.1680, 0.1700, 0.1766, 0.1733]

train_box_loss = [0.0022, 0.0012, 0.0008, 0.0006, 0.0004, 0.0004, 0.0003, 0.0003, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0001, 0.0002, 0.0002, 0.0002, 0.0001, 0.0001, 0.0001, 0.0001, 0.0002, 0.0001, 0.0001, 0.0001, 0.0002, 0.0001, 0.0002, 0.0001]
val_box_loss = [0.0012, 0.0009, 0.0008, 0.0007, 0.0008, 0.0006, 0.0006, 0.0006, 0.0006, 0.0006, 0.0005, 0.0006, 0.0006, 0.0005, 0.0005, 0.0006, 0.0006, 0.0006, 0.0006, 0.0007, 0.0006, 0.0006, 0.0006, 0.0006, 0.0005, 0.0005, 0.0005, 0.0006, 0.0007, 0.0006]

train_model_loss = [7.2693, 1.0225, 0.6185, 0.3843, 0.2777, 0.2093, 0.1787, 0.1472, 0.1274, 0.1135, 0.1190, 0.0940, 0.0886, 0.0875, 0.0724, 0.0839, 0.0813, 0.0761, 0.0577, 0.0519, 0.0594, 0.0561, 0.0681, 0.0485, 0.0410, 0.0387, 0.0493, 0.0361, 0.0485, 0.0376]
val_model_loss = [1.1824, 0.8373, 0.5208, 0.3713, 0.3077, 0.2455, 0.2348, 0.2205, 0.2067, 0.2049, 0.1991, 0.2012, 0.2011, 0.1948, 0.1893, 0.1985, 0.1943, 0.1915, 0.1959, 0.2011, 0.1894, 0.1945, 0.2060, 0.2007, 0.1981, 0.1936, 0.1942, 0.1979, 0.2135, 0.2013]

# Derive the epoch axis from the data instead of hard-coding 30. Fail fast if
# any series was pasted with a different number of values; otherwise
# matplotlib would raise a less obvious shape error when plotting.
epochs = range(1, len(train_total_loss) + 1)
for loss_series in (val_total_loss, train_cls_loss, val_cls_loss,
                    train_box_loss, val_box_loss,
                    train_model_loss, val_model_loss):
  assert len(loss_series) == len(epochs), "loss series have different lengths"

# Plotting: a 2x2 grid with one panel per loss component, training vs. validation.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
loss_panels = [
    (train_total_loss, val_total_loss,
     'Training Total Loss', 'Validation Total Loss',
     'Total Loss', 'Total Loss per Epoch'),
    (train_cls_loss, val_cls_loss,
     'Training Classification Loss', 'Validation Classification Loss',
     'Classification Loss', 'Classification Loss per Epoch'),
    (train_box_loss, val_box_loss,
     'Training Box Loss', 'Validation Box Loss',
     'Box Loss', 'Box Loss per Epoch'),
    (train_model_loss, val_model_loss,
     'Training Model Loss', 'Validation Model Loss',
     'Model Loss', 'Model Loss per Epoch'),
]
# axes.flat walks the grid row by row, matching subplot positions 1..4.
for ax, (train_series, val_series, train_label, val_label, y_label, title) in zip(axes.flat, loss_panels):
  ax.plot(epochs, train_series, label=train_label)
  ax.plot(epochs, val_series, label=val_label)
  ax.set_xlabel('Epochs')
  ax.set_ylabel(y_label)
  ax.set_title(title)
  ax.legend()

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Training and validation data
# NOTE(review): hard-coded per-epoch losses (the same values as the four-panel
# cell above), repeated here so this cell can run on its own. Re-copy them
# after any new training run.
train_total_loss = [7.3300, 1.0832, 0.6793, 0.4450, 0.3384, 0.2700, 0.2394, 0.2079, 0.1880, 0.1742, 0.1796, 0.1546, 0.1492, 0.1481, 0.1330, 0.1444, 0.1419, 0.1366, 0.1182, 0.1124, 0.1199, 0.1166, 0.1286, 0.1090, 0.1015, 0.0991, 0.1097, 0.0965, 0.1089, 0.0979]
val_total_loss = [1.2431, 0.8980, 0.5816, 0.4320, 0.3684, 0.3062, 0.2955, 0.2812, 0.2674, 0.2656, 0.2597, 0.2618, 0.2617, 0.2554, 0.2498, 0.2590, 0.2549, 0.2520, 0.2564, 0.2615, 0.2498, 0.2550, 0.2665, 0.2611, 0.2585, 0.2540, 0.2546, 0.2583, 0.2738, 0.2616]
train_cls_loss = [7.1613, 0.9615, 0.5766, 0.3548, 0.2554, 0.1912, 0.1625, 0.1342, 0.1156, 0.1029, 0.1078, 0.0851, 0.0803, 0.0788, 0.0653, 0.0751, 0.0721, 0.0670, 0.0515, 0.0456, 0.0521, 0.0489, 0.0591, 0.0420, 0.0354, 0.0318, 0.0412, 0.0296, 0.0405, 0.0306]
val_cls_loss = [1.1204, 0.7906, 0.4824, 0.3346, 0.2668, 0.2165, 0.2068, 0.1901, 0.1762, 0.1767, 0.1719, 0.1733, 0.1699, 0.1677, 0.1619, 0.1695, 0.1664, 0.1626, 0.1645, 0.1646, 0.1612, 0.1648, 0.1752, 0.1698, 0.1719, 0.1679, 0.1680, 0.1700, 0.1766, 0.1733]
train_box_loss = [0.0022, 0.0012, 0.0008, 0.0006, 0.0004, 0.0004, 0.0003, 0.0003, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0001, 0.0002, 0.0002, 0.0002, 0.0001, 0.0001, 0.0001, 0.0001, 0.0002, 0.0001, 0.0001, 0.0001, 0.0002, 0.0001, 0.0002, 0.0001]
val_box_loss = [0.0012, 0.0009, 0.0008, 0.0007, 0.0008, 0.0006, 0.0006, 0.0006, 0.0006, 0.0006, 0.0005, 0.0006, 0.0006, 0.0005, 0.0005, 0.0006, 0.0006, 0.0006, 0.0006, 0.0007, 0.0006, 0.0006, 0.0006, 0.0006, 0.0005, 0.0005, 0.0005, 0.0006, 0.0007, 0.0006]

# Derive the epoch axis from the data instead of hard-coding 30, and fail fast
# if any series has a different number of values.
epochs = range(1, len(train_total_loss) + 1)
for loss_series in (val_total_loss, train_cls_loss, val_cls_loss,
                    train_box_loss, val_box_loss):
  assert len(loss_series) == len(epochs), "loss series have different lengths"

# One stand-alone, gridded figure per loss component: total, cls and box loss.
loss_figures = [
    (train_total_loss, val_total_loss,
     'Train Total Loss', 'Val Total Loss', 'Total Loss per Epoch'),
    (train_cls_loss, val_cls_loss,
     'Train Classification Loss', 'Val Classification Loss', 'Classification Loss per Epoch'),
    (train_box_loss, val_box_loss,
     'Train Box Loss', 'Val Box Loss', 'Box Loss per Epoch'),
]
for train_series, val_series, train_label, val_label, title in loss_figures:
  fig, ax = plt.subplots(figsize=(12, 6))
  ax.plot(epochs, train_series, label=train_label)
  ax.plot(epochs, val_series, label=val_label)
  ax.set_xlabel('Epochs')
  ax.set_ylabel('Loss')
  ax.set_title(title)
  ax.legend()
  ax.grid(True)
  plt.show()


# **Export Model**

In [None]:
model.export_model()
!ls exported_model
files.download('exported_model/model.tflite')

## **Model quantization**

In [None]:
# Quantization-aware training fine-tunes `model` in place. The evaluation
# uses the same batch size as the float evaluation above (batch_size=4), so
# the float and QAT metrics are measured under identical settings.
qat_hparams = object_detector.QATHParams(
    learning_rate=0.1, batch_size=4, epochs=30, decay_steps=12, decay_rate=0.99)
model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)
qat_loss, qat_coco_metrics = model.evaluate(validation_data, batch_size=4)
print(f"QAT validation loss: {qat_loss}")
print(f"QAT validation coco metrics: {qat_coco_metrics}")

In [None]:
model.export_model('model_int8_qat.tflite')
!ls -lh exported_model
files.download('exported_model/model_int8_qat.tflite')

# **Post-training export (float32, i.e. no quantization)**

In [None]:
# Bring the `quantization` module (QuantizationConfig) into scope by name.
# Importing it this way, instead of `from mediapipe_model_maker import model`,
# keeps the trained ObjectDetector bound to `model` for the cells below.
from mediapipe_model_maker import quantization

In [None]:
# MediaPipe Model Maker's QuantizationConfig does not provide a float32 preset
# (its presets cover dynamic-range, int8 and float16 quantization). A float32
# model is simply an export without post-training quantization, so pass no
# config. For a float16 model, use
# `quantization.QuantizationConfig.for_float16()` and a matching file name.
quantization_config = None

In [None]:
model.restore_float_ckpt()
model.export_model(model_name="model_fp32.tflite", quantization_config=quantization_config)
!ls -lh exported_model
files.download('exported_model/model_fp32.tflite')

## Benchmarking
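Below is a summary of benchmarking results for the supported model architectures, as published in the MediaPipe Model Maker object detection tutorial. These models were trained and evaluated on that tutorial's android figurines dataset, not on the blocks dataset used in this notebook, so treat the numbers only as a rough reference. When considering the model benchmarking results, there are a few important caveats to keep in mind: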
* The android figurines dataset is a small and simple dataset with 62 training examples and 10 validation examples. Since the dataset is quite small, metrics may vary drastically due to variances in the training process. This dataset was provided for demo purposes and it is recommended to collect more data samples for better performing models.
* The float32 models were trained with the default HParams, and the QAT step for the int8 models was run with `QATHParams(learning_rate=0.1, batch_size=4, epochs=30, decay_rate=1)`. (The QAT cell in this notebook uses `decay_steps=12, decay_rate=0.99` instead.)
* For your own dataset, you will likely need to tune values for both HParams and QATHParams in order to achieve the best results. See the `HParams` and `QATHParams` documentation in the MediaPipe Model Maker object detector customization guide for more information on configuring training parameters.
* All latency numbers are benchmarked on the Pixel 6.


<table>
<col>
<col>
<colgroup span="2"></colgroup>
<colgroup span="2"></colgroup>
<colgroup span="2"></colgroup>
<thead>
<tr>
<th rowspan="2">Model architecture</th>
<th rowspan="2">Input Image Size</th>
<th colspan="2" scope="colgroup">Test AP</th>
<th colspan="2" scope="colgroup">CPU Latency</th>
<th colspan="2" scope="colgroup">Model Size</th>
</tr>
<tr>
<th>float32</th>
<th>QAT int8</th>
<th>float32</th>
<th>QAT int8</th>
<th>float32</th>
<th>QAT int8</th>
</tr>
</thead>
<tbody>
<tr>
<td>MobileNetV2</td>
<td>256x256</td>
<td>88.4%</td>
<td>73.5%</td>
<td>48ms</td>
<td>16ms</td>
<td>11MB</td>
<td>3.2MB</td>
</tr>
<tr>
<td>MobileNetV2 I320</td>
<td>320x320</td>
<td>89.1%</td>
<td>75.5%</td>
<td>75ms</td>
<td>33.38ms</td>
<td>10MB</td>
<td>3.3MB</td>
</tr>
<tr>
<td>MobileNet MultiHW AVG</td>
<td>256x256</td>
<td>88.5%</td>
<td>70.0%</td>
<td>56ms</td>
<td>19ms</td>
<td>13MB</td>
<td>3.6MB</td>
</tr>
<tr>
<td>MobileNet MultiHW AVG I384</td>
<td>384x384</td>
<td>92.7%</td>
<td>73.4%</td>
<td>238ms</td>
<td>41ms</td>
<td>13MB</td>
<td>3.6MB</td>
</tr>

</tbody>
</table>

