In [None]:
import cv2
from mediapipe.tasks import python
import numpy as np
import mediapipe as mp
from mediapipe.tasks.python import vision
import matplotlib.pyplot as plt
from ycb import ycb_colourspace_stuff as col


Load image

In [None]:
image_path = "./outputs/example_seg_img.png"
img_bgr = cv2.imread(image_path)
# cv2.imread does not raise on a missing/unreadable file — it returns None,
# which would otherwise only surface as a confusing error inside cvtColor.
if img_bgr is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
# OpenCV loads BGR; the RGB copy feeds MediaPipe (SRGB) and matplotlib below.
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
h, w = img_rgb.shape[:2]


Resize for MediaPipe

In [None]:
# The model file is selfie_multiclass_256x256, so prepare a 256x256 frame
# (this was 226x226, which looks like a typo for the model's input size).
SEGMENTER_INPUT_SIZE = (256, 256)  # (width, height), the order cv2.resize expects
input_img = cv2.resize(img_rgb, SEGMENTER_INPUT_SIZE)


Run Segmentation Model

In [None]:
# Segmentation model and the category ids kept as the "person" mask.
# NOTE(review): per the MediaPipe selfie-multiclass model card, 2 = body-skin
# and 3 = face-skin — confirm against the model version in use.
model_path = "./preprocessing/models/selfie_multiclass_256x256.tflite"
class_indices = [2, 3]

# Wrap the resized RGB frame in a MediaPipe image.
mp_image = mp.Image(data=input_img, image_format=mp.ImageFormat.SRGB)

# Only the per-pixel category mask is requested; confidence masks are disabled.
base_options = python.BaseOptions(model_asset_path=model_path)
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
    output_confidence_masks=False,
)


In [None]:
# Build the segmenter from `options` and run it once on `mp_image`; the
# `with` block releases the segmenter as soon as the result is obtained.
with vision.ImageSegmenter.create_from_options(options) as segmenter:
    segmentation_result = segmenter.segment(mp_image)



Get the segmentation mask and resize it back to the original image size

In [None]:
# Per-pixel class ids produced by the segmenter, at the model's resolution.
category_mask = segmentation_result.category_mask.numpy_view()

# Scale back to the original (w, h). Nearest-neighbour keeps every value a
# valid class id instead of blending neighbouring labels.
cat_mask_resized = cv2.resize(category_mask, (w, h), interpolation=cv2.INTER_NEAREST)

# 1.0 where the pixel's class is one of `class_indices`, 0.0 elsewhere.
mask = np.isin(cat_mask_resized, class_indices).astype(np.float32)


In [None]:
# The mask holds only 0.0/1.0; its mean is the fraction of selected pixels.
mask_min, mask_max, mask_mean = mask.min(), mask.max(), mask.mean()
print(f"Single mask stats: Minimum value is {mask_min}, Maximum value is {mask_max}, Mean value is {mask_mean}")

Convert to YCrCb

In [None]:
# OpenCV's YCrCb layout puts Cr before Cb: channel 0 = Y, 1 = Cr, 2 = Cb.
img_ycrcb = cv2.cvtColor(src=img_bgr, code=cv2.COLOR_BGR2YCrCb)
y, cr, cb = cv2.split(img_ycrcb)


Create the green-spill mask: flag pixels with high Cb (> 125) and low Cr (< 135)

In [None]:
# Chroma thresholds for flagging possible green spill (8-bit OpenCV YCrCb,
# neutral chroma = 128).
# NOTE(review): saturated green has *low* Cb as well as low Cr (about Cb 44,
# Cr 21 for RGB (0, 255, 0)); "Cb > 125" selects neutral-to-bluish chroma.
# Confirm these thresholds are the intended ones.
SPILL_CB_MIN = 125
SPILL_CR_MAX = 135

# Read chroma directly from img_ycrcb (channel 1 = Cr, 2 = Cb) rather than
# from `cb`/`cr`, which the correction cell overwrites — keeps this re-runnable.
spill_mask = (
    (img_ycrcb[..., 2] > SPILL_CB_MIN) & (img_ycrcb[..., 1] < SPILL_CR_MAX)
).astype(np.float32)


In [None]:
# spill_mask holds only 0.0/1.0, so the number of non-zero entries is the exact
# pixel count, reported as an integer rather than a float sum like "1234.0".
print(f"Number of pixels in spill mask: {np.count_nonzero(spill_mask)}")


In [None]:
# Fraction of all pixels flagged as spill.
spill_fraction = spill_mask.mean()
print(f"Mean value of spill mask: {spill_fraction}")


Combine masks: correct only pixels that are both inside the segmentation mask and flagged as spill

In [None]:
# Correct only pixels that are both in the segmentation mask and flagged as spill.
combined_mask = mask * spill_mask

# Optional feathering of the mask (7x7 Gaussian), presumably to soften hard
# edges in the correction. Off by default; replaces a commented-out line.
BLUR_COMBINED_MASK = False
if BLUR_COMBINED_MASK:
    combined_mask = cv2.GaussianBlur(combined_mask, (7, 7), 0)


In [None]:
# With the blur disabled these are 0.0/1.0 values; the mean is the corrected fraction.
cm_min, cm_max, cm_mean = combined_mask.min(), combined_mask.max(), combined_mask.mean()
print(f"Combined mask stats: Minimum value is {cm_min}, Maximum value is {cm_max}, Mean value is {cm_mean}")

Apply green spill correction

In [None]:
# Pull the chroma of masked pixels toward neutral (128):
#     new = old + strength * combined_mask * (128 - old)
# strength 1.0 sets fully-masked pixels to exactly 128; 2.0 mirrors them about
# 128 (e.g. Cr 100 -> 156). NOTE(review): confirm 2.0 over-correction is intended.
correction_strength = 2.0

# Start from the untouched chroma in img_ycrcb (channel 1 = Cr, 2 = Cb) every
# time, so re-running this cell does not apply the correction a second time.
cr_src = img_ycrcb[..., 1].astype(np.float32)
cb_src = img_ycrcb[..., 2].astype(np.float32)

cb_shifted = cb_src + correction_strength * combined_mask * (128 - cb_src)
cr_shifted = cr_src + correction_strength * combined_mask * (128 - cr_src)

# Round (not truncate) before the uint8 cast so values are not biased downward.
cb = np.clip(np.rint(cb_shifted), 0, 255).astype(np.uint8)
cr = np.clip(np.rint(cr_shifted), 0, 255).astype(np.uint8)

Reconstruct and convert back to RGB

In [None]:
# Re-assemble in OpenCV's Y, Cr, Cb order (luma `y` is untouched by the
# correction), then convert to BGR and on to RGB for display.
corrected_ycrcb = cv2.merge([y, cr, cb])
corrected_bgr = cv2.cvtColor(src=corrected_ycrcb, code=cv2.COLOR_YCrCb2BGR)
corrected_rgb = cv2.cvtColor(src=corrected_bgr, code=cv2.COLOR_BGR2RGB)


Display before and after green spill correction

In [None]:
# Side-by-side comparison of the input and the spill-corrected result.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
panels = [(img_rgb, "Original Image"), (corrected_rgb, "Corrected Image")]
for ax, (panel_img, panel_title) in zip(axes, panels):
    ax.imshow(panel_img)
    ax.set_title(panel_title)
    ax.axis('off')

fig.tight_layout()
plt.show()



Green quantification and heatmaps

In [None]:
# Amount of green reported by the project helper, before and after correction.
# Pass the image by keyword in both calls so each binds the same parameter
# (the second call previously relied on positional order).
total_og_green = col.quantify_green(img_ycb=img_ycrcb)
total_corrected_green = col.quantify_green(img_ycb=corrected_ycrcb)

In [None]:
# Reference values from an earlier run: original 2,360, corrected 1,977.
# They will drift if the thresholds or correction strength change.
print(f"Total OG Green: {total_og_green}")
print(f"Total Corrected Green: {total_corrected_green}")


In [None]:
# Optional: green-spill heatmap of the corrected image (project helper).
# Off by default; replaces a commented-out call.
SHOW_CORRECTED_HEATMAP = False
if SHOW_CORRECTED_HEATMAP:
    col.create_green_spill_heatmap(img_ycb=corrected_ycrcb)


In [None]:
# Optional: green-spill heatmap of the original image (project helper).
# Off by default; replaces a commented-out call.
SHOW_ORIGINAL_HEATMAP = False
if SHOW_ORIGINAL_HEATMAP:
    col.create_green_spill_heatmap(img_ycb=img_ycrcb)
