In [2]:
# %pip install scikit-image

In [3]:
import tensorflow as tf
import os
import matplotlib.pyplot as plt
from PIL import Image, ImageChops, ImageOps
import cv2
import numpy as np
from skimage.measure import label, regionprops

# Model input resolution as (height, width), the order tf.image.resize expects.
# It is square, so the same tuple is also valid where PIL expects (width, height).
image_size = (640, 640)
# True: predict() additionally saves the raw prediction and the combined image,
# and the batch loop plots every result. False: only the cropped area of
# interest is saved, under the input file name.
display_results = False
# Keras model whose per-pixel class scores predict() turns into a mask;
# loaded once and shared by every file processed below.
model = tf.keras.models.load_model('models/v0_5b.keras')

def display(image, prediction_image, combined_image, cleaned, expanded, combined_image2, only_area_of_interest_image):
  """Plot the stages of one prediction side by side in a 3x3 grid.

  Panels occupy grid cells 1-6 and 8 (cells 7 and 9 stay empty); every panel
  gets a title and has its axes hidden. The figure is shown with plt.show().
  """
  # (grid cell, panel title, image) in drawing order.
  panels = (
      (1, 'Input', image),
      (2, 'Output', prediction_image),
      (3, 'Combined', combined_image),
      (4, 'cleaned', cleaned),
      (5, 'expanded', expanded),
      (6, 'combined_image2', combined_image2),
      (8, 'only_area_of_interest_image', only_area_of_interest_image),
  )
  plt.figure(figsize=(15, 15))
  for cell, caption, picture in panels:
    plt.subplot(3, 3, cell)
    plt.title(caption)
    plt.imshow(picture)
    plt.axis('off')
  plt.show()

def create_mask(pred_mask):
  """Collapse per-class scores into a class-index mask with a channel axis.

  Takes the argmax over the last axis, appends a trailing axis of size 1 and
  returns the first entry along the leading (batch) axis. For model output of
  shape (batch, H, W, classes) -- presumably (1, 640, 640, classes) here,
  TODO confirm -- the result has shape (H, W, 1).
  """
  class_indices = tf.math.argmax(pred_mask, axis=-1)
  with_channel = tf.expand_dims(class_indices, axis=-1)
  return with_channel[0]

def _load_rgb_image(path):
  """Open the image at path as RGB with its EXIF orientation applied; the file handle is closed on return."""
  # convert() builds an independent in-memory image, so it stays valid after
  # the context manager closes the underlying file.
  with Image.open(path) as source:
    return ImageOps.exif_transpose(source.convert("RGB"))


def _mask_to_rgb_image(mask, size):
  """Render a (H, W, 1) mask array/tensor as an RGB PIL image resized to size=(width, height)."""
  mask_image = tf.keras.utils.array_to_img(mask)
  return mask_image.convert("RGB").resize(size, Image.Resampling.LANCZOS)


def predict(image_path):
  """Segment one input image and save the cropped area of interest.

  Reads data/predict/input/<image_path>, runs the module-level model on a copy
  resized to image_size and scaled to [0, 1], then cleans the predicted mask
  (drops connected patches under 200 pixels, dilates with a 10x10 kernel) and
  crops the image to the padded bounding box of what remains.

  Files written to data/predict/output/:
    display_results False: the crop, under the input file name.
    display_results True: the raw prediction under the input file name, plus
      <image_path>_combined.jpg and <image_path>_area_of_interest.jpg.

  Returns:
    (image_source, prediction_image, combined_image, clean_mask, expanded_mask,
     combined_image_with_expanded, only_area_of_interest_image), in the order
    display() takes them.
  """
  output_path = f'data/predict/output/{image_path}'
  image_source = _load_rgb_image(f'data/predict/input/{image_path}')
  # PIL reports (width, height), which is what the PIL resizes below expect.
  original_image_size = image_source.size

  image = tf.convert_to_tensor(np.array(image_source), dtype=tf.float32)
  image_input = tf.image.resize(image, image_size) / 255.0

  prediction = model.predict(image_input[tf.newaxis, ...])
  prediction_mask = create_mask(prediction)
  prediction_image = _mask_to_rgb_image(prediction_mask, original_image_size)
  if display_results:
    prediction_image.save(output_path)

  combined_image = ImageChops.multiply(image_source, prediction_image)

  clean_mask = remove_small_patches(prediction_mask.numpy(), min_area=200)
  expanded_mask = expand_areas(clean_mask, dilation_size=10)
  # array_to_img needs an explicit channel axis.
  expanded_image = _mask_to_rgb_image(np.expand_dims(expanded_mask, axis=-1), original_image_size)
  combined_image_with_expanded = ImageChops.multiply(image_source, expanded_image)
  if display_results:
    combined_image_with_expanded.save(f'{output_path}_combined.jpg')

  only_area_of_interest_image = resize_image_and_area(expanded_mask, combined_image_with_expanded)
  if display_results:
    only_area_of_interest_image.save(f'{output_path}_area_of_interest.jpg')
  else:
    only_area_of_interest_image.save(output_path)

  return (
      image_source,
      prediction_image,
      combined_image,
      clean_mask,
      expanded_mask,
      combined_image_with_expanded,
      only_area_of_interest_image,
  )


def remove_small_patches(mask, min_area=100):
    """Keep only the connected components of ``mask`` with at least ``min_area`` pixels.

    Components are found with ``skimage.measure.label`` using its defaults
    (0 is background). Pixels of every surviving component are set to 1,
    everything else to 0.

    Args:
        mask: Integer/bool mask array; zeros are background.
        min_area: Minimum component size in pixels to keep (inclusive).

    Returns:
        Array with the same shape and dtype as ``mask``.
    """
    labeled_mask = label(mask)
    # Pixel count of every label in a single pass (index 0 is the background).
    # minlength=1 keeps index 0 valid even for an empty input.
    areas = np.bincount(labeled_mask.ravel(), minlength=1)
    keep = areas >= min_area
    keep[0] = False  # background is never part of the output
    # Look up each pixel's label in the keep-table instead of scanning the
    # whole image once per region.
    return keep[labeled_mask].astype(mask.dtype, copy=False)

def expand_areas(mask, dilation_size=10):
    """Grow the non-zero areas of ``mask`` with one square dilation.

    Args:
        mask: Array with an ``astype`` method; it is cast to uint8 first.
        dilation_size: Side length, in pixels, of the all-ones square kernel.

    Returns:
        The uint8 array returned by ``cv2.dilate`` (single iteration).
    """
    square_kernel = np.ones((dilation_size,) * 2, dtype=np.uint8)
    as_bytes = mask.astype(np.uint8)
    return cv2.dilate(as_bytes, square_kernel, iterations=1)

def resize_image_and_area(segmentation, original_image, padding=10, output_size=None, resample=None):
    """Crop ``original_image`` to the padded bounding box of ``segmentation`` and resize the crop.

    The mask may have a different resolution than the image: the bounding box
    is mapped from mask pixels to image pixels using the ratio between the
    image size and the mask's own shape.

    Args:
        segmentation: 2-D mask; a trailing channel axis of size 1 is accepted.
            Any value > 0 counts as foreground.
        original_image: PIL image (anything with ``size``, ``crop`` and
            ``resize``) to crop.
        padding: Margin, in mask pixels, added on every side of the box.
        output_size: Target ``(width, height)`` of the result. Defaults to the
            module-level ``image_size`` (stored as (height, width)).
        resample: PIL resampling filter; defaults to LANCZOS.

    Returns:
        The resized crop.

    Raises:
        ValueError: If ``segmentation`` contains no foreground pixel.
    """
    seg = np.asarray(segmentation)
    if seg.ndim == 3 and seg.shape[-1] == 1:
        seg = seg[..., 0]

    non_zero_indices = np.argwhere(seg > 0)
    if non_zero_indices.size == 0:
        raise ValueError('segmentation mask is empty; nothing to crop')

    mask_height, mask_width = seg.shape
    # Half-open [start, end) bounds in mask pixels, padded and clipped to the
    # mask. The "+ 1" makes the last foreground row/column part of the box,
    # because PIL's crop treats right/lower as exclusive.
    y_start, x_start = np.maximum(non_zero_indices.min(axis=0) - padding, 0)
    y_end, x_end = non_zero_indices.max(axis=0) + 1 + padding
    y_end = min(y_end, mask_height)
    x_end = min(x_end, mask_width)

    # Scale from the mask's resolution (not a global constant) to the image's.
    image_width, image_height = original_image.size
    scale_x = image_width / mask_width
    scale_y = image_height / mask_height
    box = (
        int(x_start * scale_x),
        int(y_start * scale_y),
        min(image_width, int(np.ceil(x_end * scale_x))),
        min(image_height, int(np.ceil(y_end * scale_y))),
    )
    cropped_image = original_image.crop(box)

    if output_size is None:
        # image_size is (height, width) for tf.image.resize; PIL wants (width, height).
        output_size = (image_size[1], image_size[0])
    if resample is None:
        resample = Image.Resampling.LANCZOS
    return cropped_image.resize(output_size, resample)

# Run every regular file in the input folder through predict(), in sorted
# (deterministic) order. A failing file is reported and skipped so one bad
# image does not abort the batch.
input_dir = 'data/predict/input'
files = sorted(
    name for name in os.listdir(input_dir)
    if os.path.isfile(os.path.join(input_dir, name))
)
for file in files:
  try:
    image, prediction_image, combined_image, clean_mask, expanded_mask, combined_image_with_expanded, only_area_of_interest_image = predict(file)
    if display_results:
      display(image, prediction_image, combined_image, clean_mask, expanded_mask, combined_image_with_expanded, only_area_of_interest_image)
  except Exception as e:
    # Top-level batch boundary: report and continue with the next file.
    print(f'Error processing {file}: {e}')

