In [None]:
import torch
import json
import cv2
from datetime import date
import numpy as np
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from skimage import measure
from ibug.face_detection import RetinaFacePredictor
from ibug.face_parsing import FaceParser as RTNetPredictor
%matplotlib inline

In [None]:
# --- Detection / parsing configuration -------------------------------------
threshold = 0.8 # default = 0.8
weights = None # r"C:\mahmoud_dev\machine learning\segmentation\face_parsing\ibug\face_parsing\rtnet\weights\rtnet101-fcn-14.torch" # default = None
num_classes = 14 # default = 11
max_num_faces = 50 # default = 50

parser_encoder = 'rtnet50'
parser_decoder = 'fcn'
rotate_image = False  # read by get_image_pred: rotate input 90 degrees clockwise when True
today = date.today()

# Always define `device`: previously it was only assigned when CUDA was
# available, so the constructors below raised NameError on CPU-only machines.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Both predictors are built once here and reused by later cells.
face_detector = RetinaFacePredictor(threshold=threshold, device=device, model=(RetinaFacePredictor.get_model('mobilenet0.25')))
face_parser = RTNetPredictor(device=device, ckpt=weights, encoder=parser_encoder, decoder=parser_decoder, num_classes=num_classes)

def get_image_pred(img, face_detector, face_parser):
    """Run face detection followed by face parsing on a single image.

    Args:
        img: Image array handed to both predictors with ``rgb=False``
            (presumably a BGR array as returned by ``cv2.imread`` -- confirm).
        face_detector: Callable invoked as ``face_detector(img, rgb=False)``.
        face_parser: Object exposing ``predict_img(img, faces, rgb=False)``.

    Returns:
        Tuple ``(faces, masks)``: the detector output and the parser output
        computed from it.

    Note:
        Reads the notebook-level flag ``rotate_image``. When it is truthy the
        image is rotated 90 degrees clockwise before detection, so the returned
        results refer to the rotated image, not to the caller's original.
    """
    frame = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) if rotate_image else img

    detections = face_detector(frame, rgb=False)
    parsed = face_parser.predict_img(frame, detections, rgb=False)
    return detections, parsed


In [None]:
# Class names; a name's position in this list is used as its numeric class id.
# NOTE(review): presumably this order matches the parser's label ids -- confirm.
categories_list = [
    'background', 'skin', 'left_eyebrow', 'nose', 'upper_lip', 'inner_mouth', 'lower_lip',
    'right_eyebrow', 'left_eye', 'hair', 'left_ear', 'right_ear', 'right_eye', 'glasses',
]

# Name -> id lookup derived from list position.
class_ids = {name: idx for idx, name in enumerate(categories_list)}
print(class_ids)


In [None]:
# Set paths and filenames
# NOTE(review): these are absolute, machine-specific Windows paths and must be
# edited before the notebook can run elsewhere.
# Directory of input images (presumably the argument passed to save_yolo; the
# call itself is not shown in these cells).
image_dir = r'D:\_Xchng\Mahmoud\segmenation\dataset\data\raw_images'
# Not referenced by any of the cells visible here.
json_filepath = r'D:\_Xchng\Mahmoud\segmenation\dataset\data\instances_default.json'
# Despite the name this is used as a directory: save_yolo joins it with
# "<name>.txt" when writing annotation files.
txt_filepath = r'D:\_Xchng\Mahmoud\segmenation\dataset\data\yolo_annotations\obj_train_data'

In [None]:
def save_yolo(image_dir):
    """Write one YOLO annotation ``.txt`` file per image found in ``image_dir``.

    For every ``.png`` / ``.jpg`` file the image is loaded, passed through
    ``get_image_pred`` (using the notebook-level ``face_detector`` and
    ``face_parser``), and the first returned mask is converted to YOLO lines by
    calling ``generate_segmentation_yolo`` once per entry of
    ``categories_list``. The result is written to
    ``<txt_filepath>/<image stem>.txt``.

    Only the first detected face is annotated; images that cannot be read or
    in which no face is detected are reported and skipped.

    Args:
        image_dir: Directory containing the images to annotate.

    Returns:
        None. The function's output is the set of files written to disk.

    NOTE(review): if the notebook flag ``rotate_image`` is set,
    ``get_image_pred`` rotates its own copy of the image, so the mask may not
    share the dimensions of ``image`` used for normalisation here -- confirm.
    """
    # Make sure the output directory exists before the first write.
    os.makedirs(txt_filepath, exist_ok=True)

    # Sorted only to make the processing order deterministic.
    for filename in sorted(os.listdir(image_dir)):
        # Ignore everything that is not an image; previously such entries
        # still produced a .txt file.
        if not filename.endswith((".png", ".jpg")):
            continue

        image_path = os.path.join(image_dir, filename)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            print(f"Skipping unreadable image: {image_path}")
            continue

        faces, masks = get_image_pred(image, face_detector, face_parser)
        if len(faces) == 0 or len(masks) == 0:
            print(f"Skipping image with no detected face: {image_path}")
            continue

        mask_arr = masks[0]  # assumes 1 face per image, loop for more faces.

        # Build the annotation from scratch for each image; it used to be
        # accumulated across the whole directory, so every file also contained
        # the annotations of all images processed before it.
        annotation = ''.join(
            generate_segmentation_yolo(mask_arr, class_name, image, categories_list)
            for class_name in categories_list
        )

        # splitext keeps dots that are part of the name ("a.b.jpg" -> "a.b").
        stem = os.path.splitext(filename)[0]
        with open(os.path.join(txt_filepath, f"{stem}.txt"), 'w') as f:
            f.write(annotation)


def generate_segmentation_yolo(mask, class_name, img, class_names):
    """Return YOLO bounding-box lines for one class of a label mask.

    Each connected region of ``mask`` whose pixels equal the index of
    ``class_name`` in ``class_names`` yields one line of the form
    ``"<class_index> <x_center> <y_center> <width> <height>\\n"``, with the box
    normalised by the width and height of ``img``.

    Despite the function name, the output is bounding boxes, not polygons.

    Args:
        mask: 2-D label map; assumed to hold, per pixel, an index into
            ``class_names`` and to have the same height/width as ``img``
            (TODO confirm against the parser output).
        class_name: Name of the class to annotate.
        img: Image array; only ``img.shape[:2]`` is used.
        class_names: Sequence of class names; position defines the class index.
            Index 0 is treated as background and never annotated.

    Returns:
        The concatenated annotation lines, or ``''`` when ``class_name`` is the
        background class or does not occur in ``mask``.

    Raises:
        ValueError: If ``class_name`` is not contained in ``class_names``.
    """
    class_dict = {name: index for index, name in enumerate(class_names)}
    if class_name not in class_dict:
        raise ValueError(f"Unknown class name: {class_name!r}")
    class_index = class_dict[class_name]

    # Background (label 0) is not an object class.
    if class_index == 0:
        return ''

    # Select only the requested class. The previous implementation merged all
    # non-background labels and guessed the class from the mean label inside
    # each box, which mixes classes and background pixels.
    binary_mask = (np.asarray(mask) == class_index).astype(np.uint8)
    if not binary_mask.any():
        return ''

    height, width = img.shape[:2]

    # RETR_EXTERNAL: outer contours only, so holes inside a region do not get
    # boxes of their own. CHAIN_APPROX_SIMPLE keeps segment end points, which
    # is all boundingRect needs.
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    lines = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        x_center = (x + w / 2) / width
        y_center = (y + h / 2) / height
        w_norm = w / width
        h_norm = h / height
        lines.append(f"{class_index} {x_center} {y_center} {w_norm} {h_norm}\n")

    return ''.join(lines)