# TensorFlow Object Detection API Results Exporter

## Import library modules

In [None]:
import sys
sys.dont_write_bytecode = True

import numpy as np
import os
import time
import tensorflow as tf
import shutil
import xml.etree.ElementTree as ET

from PIL import Image
from matplotlib import pyplot as plt

# Object detection module
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_utils
from object_detection.builders import model_builder

In [None]:
## Avoid out of memory by setting GPU memory consumption growth
# Enable memory growth on every GPU that TensorFlow can see. This is done once,
# at notebook start-up, before any model is built or run.
# NOTE(review): TensorFlow requires this to happen before the GPUs are
# initialized -- keep this cell ahead of any cell that touches the GPU.
gpu_list = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpu_list:
    tf.config.experimental.set_memory_growth(gpu, True)

## Configuration parameters

In [None]:
# Notebook-wide settings read by the functions defined below.
CFG = {
    # Directory that model_selection() scans for model sub-folders.
    "model_path": "Model",
    # File-name extensions process() uses to pick the input images.
    "extension": ('.jpg', '.png'),
}

## Declare functions

### Model selection

In [None]:
def model_selection(model_dir):
    """Find the usable models under *model_dir* and let the user pick one.

    An entry of *model_dir* counts as a model when it contains all of
    ``pipeline.config``, ``checkpoint/ckpt-0.index`` and ``label_map.pbtxt``.

    Args:
        model_dir: Directory whose entries are scanned for models.

    Returns:
        The name of the chosen model folder (relative to *model_dir*).
        ``None`` when no valid model exists. When exactly one model exists it
        is returned without prompting; otherwise the user is asked repeatedly
        until a valid index is entered.

    Raises:
        OSError: If *model_dir* cannot be listed (e.g. it does not exist).
    """
    # Files (relative to the model folder) that must all be present.
    required_files = (
        ('pipeline.config',),
        ('checkpoint', 'ckpt-0.index'),
        ('label_map.pbtxt',),
    )

    print("Scanning for models ... ", end="")
    # os.listdir() returns entries in arbitrary order; sort them so that the
    # index shown next to each model is stable between runs and platforms.
    legit_model_list = sorted(
        model for model in os.listdir(model_dir)
        if all(os.path.exists(os.path.join(model_dir, model, *parts))
               for parts in required_files)
    )
    num_models = len(legit_model_list)
    print(f"Found {num_models}")

    if not legit_model_list:
        return None

    if num_models == 1:
        return legit_model_list[0]

    while True:
        for idx, model in enumerate(legit_model_list):
            print(f"{idx}: {model}")

        # Keep the try body minimal: only the int() conversion can raise here.
        try:
            model_num = int(input("Select a model: "))
        except ValueError:
            print("Invalid input. Please enter a valid number.")
            continue

        if 0 <= model_num < num_models:
            return legit_model_list[model_num]
        print("Invalid input. Please try again.")

In [None]:
def create_xml_annotation(image_filename, objects):
    """Build an XML annotation tree describing the detections of one image.

    The resulting document has an ``annotation`` root holding one ``filename``
    element followed by one ``object`` element per entry of *objects*. Each
    ``object`` contains ``classname``, ``bndbox`` (with ``xmin``, ``ymin``,
    ``xmax``, ``ymax``) and ``score``, in that order.

    Args:
        image_filename: Text stored in the ``filename`` element.
        objects: Iterable of dicts with the keys ``classname``, ``xmin``,
            ``ymin``, ``xmax``, ``ymax`` and ``score``. ``classname`` is stored
            as given; the other values are converted with ``str()``.

    Returns:
        An ``xml.etree.ElementTree.ElementTree`` wrapping the root element.
    """
    annotation = ET.Element("annotation")
    ET.SubElement(annotation, "filename").text = image_filename

    for detection in objects:
        node = ET.SubElement(annotation, "object")

        # Class name comes first inside <object>.
        ET.SubElement(node, "classname").text = detection["classname"]

        # Bounding box corners, emitted in a fixed order.
        box_node = ET.SubElement(node, "bndbox")
        for corner in ("xmin", "ymin", "xmax", "ymax"):
            ET.SubElement(box_node, corner).text = str(detection[corner])

        # Detection score closes the <object> element.
        ET.SubElement(node, "score").text = str(detection["score"])

    return ET.ElementTree(annotation)

In [None]:
def process(dir_processing, detection_model, category_index, threshold=0.5, export_roi=False, export_xml=False):
    """Run object detection on every image in *dir_processing* and export results.

    For each image file whose name ends with one of ``CFG["extension"]``
    (compared case-insensitively) the model is run, the detections are drawn
    on a copy of the image and shown with matplotlib, and optionally written
    to ``<dir_processing>/export_result``.

    Args:
        dir_processing: Folder containing the input images. Sub-folders are
            not scanned.
        detection_model: Model object exposing ``preprocess``, ``predict`` and
            ``postprocess`` (as returned by ``model_builder.build``).
        category_index: Mapping from class id to a dict with a ``'name'`` key.
            Predicted class numbers are shifted by one before the lookup.
        threshold: Minimum score for a detection to be drawn / exported.
        export_roi: When true, save the annotated image as
            ``<stem>_ROI<ext>`` in the export folder (JPEG and PNG only).
        export_xml: When true, save the detections as ``<stem>.xml`` in the
            export folder. Box coordinates are written exactly as returned by
            the model, i.e. normalized rather than in pixels.

    Returns:
        None.

    Note:
        The ``complete`` sub-folder is created as well, but this function does
        not currently put anything into it.
    """
    def detect_fn(model, tensor):
        """Run *model* on one batched image tensor and return NumPy results."""
        image, shapes = model.preprocess(tensor)
        prediction_dict = model.predict(image, shapes)
        detections = model.postprocess(prediction_dict, shapes)

        # Post processing the detection results (convert tensors to NumPy).
        # Only num_detections has a different shape, so it is handled apart;
        # every other entry keeps the first batch element, trimmed to the
        # number of valid detections.
        num_detections = int(detections.pop('num_detections'))
        detections = {key: value[0, :num_detections].numpy() for key, value in detections.items()}
        detections['num_detections'] = num_detections
        detections['detection_classes'] = detections['detection_classes'].astype(np.int64)
        return detections

    # Prepare result folders
    export_path = os.path.join(dir_processing, 'export_result')
    complete_path = os.path.join(dir_processing, 'complete')
    os.makedirs(export_path, exist_ok=True)
    os.makedirs(complete_path, exist_ok=True)

    # Collect the input images. Match on the real file-name suffix (not on a
    # substring anywhere in the path) and ignore case, so that "photo.JPG" is
    # picked up while "notes.jpg.txt" is not. Sorted for a reproducible order.
    extensions = tuple(ext.lower() for ext in CFG["extension"])
    image_list = [
        os.path.join(dir_processing, file)
        for file in sorted(os.listdir(dir_processing))
        if file.lower().endswith(extensions)
        and os.path.isfile(os.path.join(dir_processing, file))
    ]

    # Predicted class numbers start at 0 while category_index is looked up
    # with the number shifted by one.
    label_id_offset = 1

    for image_path in image_list:
        filename = os.path.basename(image_path)
        stem, ext = os.path.splitext(filename)

        print(f'Running inference for {filename} ... ')

        # Load the image into a NumPy array; the context manager closes the
        # underlying file handle as soon as the pixel data has been copied.
        with Image.open(image_path) as image_raw:
            image_np = np.array(image_raw.convert('RGB'))
            jpeg_quality = image_raw.info.get('quality', 95)

        # Convert the image to a batched tensor (batch of one).
        image_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)

        # Run inference
        detections = detect_fn(detection_model, image_tensor)

        # Visualization of the results of a detection.
        image_np_with_detections = image_np.copy()
        boxes = detections['detection_boxes']
        max_boxes_to_draw = detections['num_detections']
        scores = detections['detection_scores']

        vis_utils.visualize_boxes_and_labels_on_image_array(
                image_np_with_detections,
                boxes,
                detections['detection_classes'] + label_id_offset,
                scores,
                category_index,
                use_normalized_coordinates=True,
                max_boxes_to_draw=max_boxes_to_draw,
                min_score_thresh=threshold,
                line_thickness=2,
                agnostic_mode=False)

        plt.figure(figsize=(8, 8))
        plt.imshow(image_np_with_detections)
        plt.axis('off')
        print('Done')
        plt.show()

        ## Export image with ROI
        if export_roi:
            image_roi = Image.fromarray(image_np_with_detections)
            # Build the name from the stem so only the real suffix is touched.
            destination = os.path.join(export_path, f"{stem}_ROI{ext}")
            ext_lower = ext.lower()

            if ext_lower in ('.jpg', '.jpeg'):
                image_roi.save(destination,
                               format='JPEG',
                               subsampling=0,
                               quality=jpeg_quality,
                               optimize=True)
            elif ext_lower == '.png':
                # subsampling / quality are JPEG options and do not apply to PNG.
                image_roi.save(destination, format='PNG', optimize=True)
            else:
                print(f'Skipping ROI export for {filename}: unsupported extension "{ext}"')

        ## Export XML
        if export_xml:
            object_detection_results = []
            for box, class_num, score in zip(boxes, detections['detection_classes'], scores):
                if score < threshold:
                    continue

                # Each box is ordered [ymin, xmin, ymax, xmax]
                # (top, left, bottom, right).
                ymin, xmin, ymax, xmax = box

                object_detection_results.append({
                    "classname": category_index[class_num + label_id_offset]['name'],
                    "xmin": xmin,
                    "ymin": ymin,
                    "xmax": xmax,
                    "ymax": ymax,
                    "score": score,
                })

            xml_tree = create_xml_annotation(filename, object_detection_results)
            destination = os.path.join(export_path, f"{stem}.xml")
            xml_tree.write(destination)

## Execute

In [None]:
def initializing():
    """Select a model, build it, restore its checkpoint and load its label map.

    The model folder is chosen with ``model_selection(CFG["model_path"])`` and
    is expected to contain ``pipeline.config``, ``checkpoint/ckpt-0`` and
    ``label_map.pbtxt``.

    Returns:
        A ``(detection_model, category_index)`` tuple: the built detection
        model with the checkpoint restored, and the category index created
        from the label map (display names enabled).

    Raises:
        FileNotFoundError: If no usable model is found in the model folder.
    """
    ## Scan and look for models in Model folder
    model = model_selection(CFG["model_path"])
    if model is None:
        # Without this guard the None would reach os.path.join() below and
        # fail with an unhelpful TypeError.
        raise FileNotFoundError(
            f"No usable model found in '{CFG['model_path']}'. Each model folder must "
            "contain pipeline.config, checkpoint/ckpt-0.index and label_map.pbtxt."
        )

    ## Parameter
    model_root     = os.path.join(CFG["model_path"], model)
    PATH_TO_CFG    = os.path.join(model_root, 'pipeline.config')
    PATH_TO_CKPT   = os.path.join(model_root, 'checkpoint', 'ckpt-0')
    PATH_TO_LABELS = os.path.join(model_root, 'label_map.pbtxt')

    ## Load model
    print('Loading model... ', end='')
    start_time = time.time()

    ## Load pipeline config and build a detection model
    configs         = config_util.get_configs_from_pipeline_file(PATH_TO_CFG)
    model_config    = configs['model']
    detection_model = model_builder.build(model_config=model_config, is_training=False)

    ## Restore checkpoint
    # expect_partial() silences warnings about checkpoint values that are not
    # consumed by the model object built here.
    ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
    ckpt.restore(PATH_TO_CKPT).expect_partial()

    ## List of the strings that is used to add label for each box.
    category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)

    elapsed_time = time.time() - start_time
    print('Done! Took {} seconds'.format(elapsed_time))
    return detection_model, category_index

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
# Pick and load a model, then run detection on the "Images" folder with a
# score threshold of 0.8, exporting both the annotated images and XML files.
detection_model, category_index = initializing()
process("Images", detection_model, category_index, threshold=0.8, export_roi=True, export_xml=True)