In [None]:
import tensorflow as tf
import os
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt

from model.ModelBuilder import ModelBuilder

import tensorflow_datasets as tfds

In [None]:
# Architecture to load: selects both the JSON config under model/0_Config/ and the
# checkpoint folder under logs/.
modelName = "MobileNetV3_FPN_TTFNet"
num_classes = 80                    # number of classes written into the training config
weights_tag = "_epoch300_mAP0.231"  # checkpoint prefix inside logs/<modelName>/weights/

# NOTE(review): model_dir and modelPart are not used anywhere in this notebook;
# kept only in case other code relies on them.
model_dir = "checkpoints/"
modelPart = modelName.split("_")

with open(os.path.join("model", "0_Config", modelName + ".json"), "r", encoding="utf-8") as config_file:
    config = json.load(config_file)

config['modelName'] = modelName
config['training_config']['num_classes'] = num_classes

model = ModelBuilder(config)
# Build the checkpoint path from modelName so switching architectures cannot silently
# load another model's weights. expect_partial() silences warnings about checkpoint
# values that inference does not restore.
model.load_weights(os.path.join("logs", modelName, "weights", weights_tag)).expect_partial()

In [None]:
# COCO 2017 validation split; dataset_info provides the class-id -> class-name mapping.
[test_dataset], dataset_info = tfds.load(name="coco/2017", split=["validation"], with_info=True)

label_feature = dataset_info.features["objects"]["label"]
labelMap_Func = label_feature.int2str
# One color per class, taken from the dataset metadata instead of a hard-coded 80.
# Fixed seed so each class keeps the same color across re-runs.
colors = np.random.default_rng(seed=0).random((label_feature.num_classes, 3)) * 255
score_threshold = 0.3    # detections scoring at or below this are not drawn
numPic = 5               # number of validation images to visualise
input_size = (320, 320)  # cv2 dsize (width, height) fed to the model


def draw_detections(image, detections, threshold):
    """Draw a box and a "<label>_<score>" caption for every detection above `threshold`.

    Args:
        image: H x W x 3 image array; it is drawn on in place and also returned.
        detections: 2-D array of model outputs. Columns are read as
            [y1, x1, y2, x2, class_id, score]. This layout is assumed from how the model
            output was originally indexed; the box coordinates are presumably normalised
            to [0, 1] because they are scaled by the image size. TODO confirm.
        threshold: minimum score (exclusive) for a detection to be drawn.

    Returns:
        The same `image` array with the annotations drawn on it.
    """
    height, width = image.shape[:2]
    kept = detections[detections[:, 5] > threshold]
    for y1, x1, y2, x2, cls, score in kept[:, :6]:
        class_id = int(cls)
        # OpenCV wants a plain numeric tuple for `color`, not a numpy array.
        color = tuple(float(c) for c in colors[class_id])
        top_left = (int(x1 * width), int(y1 * height))
        bottom_right = (int(x2 * width), int(y2 * height))
        caption = '{}_{:.2f}'.format(labelMap_Func(class_id), float(score))
        cv2.rectangle(image, top_left, bottom_right, color, 1)
        cv2.putText(image, caption, (top_left[0], top_left[1] + 5), cv2.FONT_HERSHEY_COMPLEX, 0.5, color,
                    thickness=1, lineType=cv2.LINE_AA)
    return image


for sample in test_dataset.take(numPic):
    original_image = sample['image'].numpy()
    # Scale pixels to [-1, 1], resize to the model input size and add a batch axis.
    input_img = np.expand_dims(cv2.resize(original_image / 127.5 - 1, dsize=input_size), 0)
    detections = model.predict(input_img)[0]

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(draw_detections(original_image, detections, score_threshold))
    ax.set_title('Detections with score > {}'.format(score_threshold))
    ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across loop iterations