In [1]:
import tensorflow as tf
import os
import matplotlib.pyplot as plt
from PIL import Image, ImageOps, ImageDraw
import numpy as np


2024-11-01 16:57:05.971306: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Input layout: {source}/<folder>/image/<file> and {source}/<folder>/seg/<stem>.png;
# cropped results are mirrored under {target}/<folder>/.
source, target = 'data_source', 'data'
# TFLite face detector whose output box drives the crop.
face_detection_model = 'models/face_detection_v0_4.tflite'
# Detector input size, and size of the saved crops.
image_prediction_size, image_output_size = (96, 96), (96, 96)
# True: show the first sample instead of writing any files.
display_results = False


In [3]:
def load_model(model_path):
    """Return a TFLite interpreter for ``model_path`` with its tensors allocated.

    The returned interpreter is ready to be passed to ``predict``.
    """
    tflite_interpreter = tf.lite.Interpreter(model_path=model_path)
    # Tensor buffers must be allocated before set_tensor()/invoke() can be used.
    tflite_interpreter.allocate_tensors()
    return tflite_interpreter

def predict(interpreter, input):
    """Run one inference and return the first item of the first output tensor.

    Args:
      interpreter: an interpreter with allocated tensors (see ``load_model``).
      input: array-like fed to the model's first input tensor after conversion
        to a float32 numpy array.
        NOTE(review): assumes the model input is float32 and that ``input``
        already carries the batch dimension — confirm against the model.

    Returns:
      ``output[0]`` of the model's first output tensor (i.e. the first batch item).
    """
    # Resolve the tensor indices of the first input and first output.
    in_index = interpreter.get_input_details()[0]['index']
    out_index = interpreter.get_output_details()[0]['index']

    interpreter.set_tensor(in_index, np.array(input, dtype=np.float32))
    interpreter.invoke()
    return interpreter.get_tensor(out_index)[0]


In [4]:
# Load the face detector once; get_face_bbox reads this module-level `model`.
model = load_model(model_path=face_detection_model)

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


In [5]:
def mkdir(path):
    """Create directory ``path`` and any missing parents; no-op if it already exists.

    ``exist_ok=True`` makes the call idempotent without a separate existence
    check, which avoids the race between checking and creating. A non-directory
    already occupying ``path`` raises ``FileExistsError`` instead of being
    silently accepted.
    """
    os.makedirs(path, exist_ok=True)

# Mirror every dataset folder of `source` under `target` with image/ and seg/ subfolders.
for folder in os.listdir(source):
    # Skip stray files (e.g. OS metadata) so only real dataset folders are mirrored.
    if not os.path.isdir(os.path.join(source, folder)):
        continue
    mkdir(f'{target}/{folder}/image')
    mkdir(f'{target}/{folder}/seg')

In [6]:
def display(image, image_with_bbox, prediction_image, label):
  """Show one sample's processing stages in a 2x2 grid, axes hidden.

  Panels in row-major order: the input image, the input with its bounding box,
  the cropped/resized output, and the label mask.
  NOTE(review): this name shadows IPython's own ``display`` inside the notebook.
  """
  panels = (
    ('Input', image),
    ('Input with bounding box', image_with_bbox),
    ('Output', prediction_image),
    ('Label', label),
  )
  plt.figure(figsize=(15, 15))
  for position, (title, picture) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    plt.title(title)
    plt.imshow(picture)
    plt.axis('off')
  plt.show()

def get_face_bbox(folder, image_path):
  """Crop one image and its segmentation label to the face found by ``model``.

  Reads ``{source}/{folder}/image/{image_path}`` and its label
  ``{source}/{folder}/seg/{stem}.png``, runs the face detector on an
  ``image_prediction_size`` copy of the image scaled to [0, 1], crops both the
  image and the label to the predicted box, and resizes both to
  ``image_output_size`` so they stay pixel-aligned.

  Unless ``display_results`` is set, the crop is written to
  ``{target}/{folder}/image/{image_path}`` and the binary label to
  ``{target}/{folder}/seg/{stem}.png``. The label is always PNG because a lossy
  format such as JPEG would corrupt the 0/1 mask values.

  Args:
    folder: sub-directory of ``source`` that holds ``image/`` and ``seg/``.
    image_path: file name of the image inside ``{source}/{folder}/image``.

  Returns:
    ``(image_source, image_with_bbox, prediction_image, prediction_label)``:
    the EXIF-oriented RGB input, a copy of it (with the box drawn in red when
    ``display_results`` is set), the resized face crop, and the resized 0/1 label.
  """
  # splitext (not split('.')) so names containing extra dots keep their full stem.
  stem = os.path.splitext(image_path)[0]

  with Image.open(f'{source}/{folder}/image/{image_path}') as raw_image:
    # Apply EXIF orientation first so the detector sees the upright image.
    image_source = ImageOps.exif_transpose(raw_image).convert("RGB")
  with Image.open(f"{source}/{folder}/seg/{stem}.png") as raw_label:
    # NOTE(review): no EXIF transpose is applied to the label; this assumes the
    # label was drawn on the already-oriented image — confirm for rotated photos.
    label_source = raw_label.convert("L")

  image = tf.convert_to_tensor(np.array(image_source), dtype=tf.float32)
  image_input = tf.image.resize(image, image_prediction_size) / 255.0

  prediction = predict(model, image_input[tf.newaxis, ...])
  width, height = image_source.size
  # Scale normalised coordinates to pixels; min/max makes the box independent
  # of which corner the model emits first.
  bbox = prediction * [width, height, width, height]
  left, right = min(bbox[0], bbox[2]), max(bbox[0], bbox[2])
  top, bottom = min(bbox[1], bbox[3]), max(bbox[1], bbox[3])
  crop_box = (left, top, right, bottom)

  image_with_bbox = image_source.copy()
  if display_results:
    ImageDraw.Draw(image_with_bbox).rectangle([(left, top), (right, bottom)], outline="red", width=3)

  prediction_image = image_source.crop(crop_box).resize(image_output_size, Image.Resampling.LANCZOS)

  label_cropped_array = np.array(label_source.crop(crop_box))
  # Label classes 4 and 5 become foreground (1); every other class is background (0).
  # NOTE(review): the meaning of classes 4/5 is dataset-specific — confirm.
  label_binary = np.isin(label_cropped_array, (4, 5)).astype(np.uint8)
  # NEAREST keeps the mask strictly 0/1 while matching the output image size.
  prediction_label = Image.fromarray(label_binary).resize(image_output_size, Image.Resampling.NEAREST)

  if not display_results:
    prediction_image.save(f'{target}/{folder}/image/{image_path}')
    prediction_label.save(f'{target}/{folder}/seg/{stem}.png')
  return image_source, image_with_bbox, prediction_image, prediction_label

In [7]:
# Every (folder, image file name) pair to process. Listings are sorted so the
# processing order is reproducible, and stray non-directory entries in `source`
# (e.g. OS metadata files) are skipped instead of crashing os.listdir.
files = [
  (folder, f)
  for folder in sorted(os.listdir(source))
  if os.path.isdir(os.path.join(source, folder))
  for f in sorted(os.listdir(f'{source}/{folder}/image'))
]

def report_progress(done, total):
  """Print a ``done / total (pct%)`` progress line for the batch conversion."""
  print(f'{done} / {total} ({done / total * 100:.2f}%)')

progress_every = 100  # print a progress line every N processed files
c = 0  # processed-file count; stays 0 (and prints nothing) when `files` is empty
for c, (folder, file_name) in enumerate(files, start=1):
  image, image_with_bbox, prediction_image, prediction_label = get_face_bbox(folder, file_name)
  if display_results:
    # Preview mode: show only the first sample (get_face_bbox saved nothing).
    display(image, image_with_bbox, prediction_image, prediction_label)
    break
  if c % progress_every == 0:
    report_progress(c, len(files))
# Final line for a trailing partial batch (already printed when c is a multiple).
if c % progress_every != 0 and not display_results:
  report_progress(c, len(files))


100 / 22188 (0.45%)
200 / 22188 (0.90%)
300 / 22188 (1.35%)
400 / 22188 (1.80%)
500 / 22188 (2.25%)
600 / 22188 (2.70%)
700 / 22188 (3.15%)
800 / 22188 (3.61%)
900 / 22188 (4.06%)
1000 / 22188 (4.51%)
1100 / 22188 (4.96%)
1200 / 22188 (5.41%)
1300 / 22188 (5.86%)
1400 / 22188 (6.31%)
1500 / 22188 (6.76%)
1600 / 22188 (7.21%)
1700 / 22188 (7.66%)
1800 / 22188 (8.11%)
1900 / 22188 (8.56%)
2000 / 22188 (9.01%)
2100 / 22188 (9.46%)
2200 / 22188 (9.92%)
2300 / 22188 (10.37%)
2400 / 22188 (10.82%)
2500 / 22188 (11.27%)
2600 / 22188 (11.72%)
2700 / 22188 (12.17%)
2800 / 22188 (12.62%)
2900 / 22188 (13.07%)
3000 / 22188 (13.52%)
3100 / 22188 (13.97%)
3200 / 22188 (14.42%)
3300 / 22188 (14.87%)
3400 / 22188 (15.32%)
3500 / 22188 (15.77%)
3600 / 22188 (16.22%)
3700 / 22188 (16.68%)
3800 / 22188 (17.13%)
3900 / 22188 (17.58%)
4000 / 22188 (18.03%)
4100 / 22188 (18.48%)
4200 / 22188 (18.93%)
4300 / 22188 (19.38%)
4400 / 22188 (19.83%)
4500 / 22188 (20.28%)
4600 / 22188 (20.73%)
4700 / 22188 (21.18