In [None]:
import os
import random
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage import io
from skimage.segmentation import mark_boundaries, slic


In [None]:
img_dir = "data/colorization/coco"
# Keep only regular, non-hidden files so the random draw cannot land on a
# sub-directory or an OS metadata entry (e.g. ".DS_Store"), which would make
# Image.open fail. Sorting makes the draw reproducible once `random` is
# seeded, because os.listdir() returns entries in arbitrary order.
candidates = sorted(
    name
    for name in os.listdir(img_dir)
    if not name.startswith(".") and os.path.isfile(os.path.join(img_dir, name))
)
if not candidates:
    raise FileNotFoundError(f"No image files found in {img_dir!r}")
img_path = os.path.join(img_dir, random.choice(candidates))
# Load the image (last expression, so the notebook renders it)
Image.open(img_path)

In [None]:
image = io.imread(img_path)
# Normalise to an (H, W, 3) array before any further processing: a grayscale
# file loads as a 2-D array and an RGBA file carries a fourth channel, and
# both break the RGB->HSV conversion and the 3-way channel split used later.
if image.ndim == 2:
    image = np.stack([image] * 3, axis=-1)
elif image.shape[-1] > 3:
    # Drop alpha; make the slice contiguous so it can be handed to OpenCV.
    image = np.ascontiguousarray(image[..., :3])
n_segments = 10
# Apply SLIC superpixel segmentation
segments = slic(image, n_segments=n_segments, compactness=20, sigma=1)
# Labels actually present in the label map (SLIC may return fewer than requested)
segment_ids = np.unique(segments)

In [None]:
# Overlay the superpixel borders on the photo to eyeball the segmentation.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(mark_boundaries(image, segments))
ax.set(title='SLIC Superpixel Segmentation')
plt.show()

In [None]:
def get_most_uniform_segments(hue, segments, segment_ids, top_k=6):
    """Rank segments by hue variance and return the most uniform ones.

    Args:
        hue: Array of per-pixel hue values with the same shape as `segments`.
        segments: Label map assigning a segment id to every pixel.
        segment_ids: Iterable of segment ids to evaluate.
        top_k: Maximum number of segments to return. Defaults to 6, the
            value that was previously hard-coded.

    Returns:
        A list of ``(segment_id, hue_variance)`` tuples sorted by ascending
        variance, containing at most `top_k` entries. Ids that do not occur
        in `segments` are skipped.

    NOTE(review): the variance is computed linearly. If `hue` is an angular
    quantity (e.g. an OpenCV HSV hue channel), values on either side of the
    wrap-around point will look non-uniform — confirm whether that matters.
    """
    seg_color_var = []
    for sid in segment_ids:
        segment_hue = hue[segments == sid]
        if segment_hue.size == 0:
            # An absent label would yield a NaN variance, and NaN makes the
            # sort order below undefined.
            continue
        seg_color_var.append((sid, segment_hue.var()))

    # Stable sort: segments with equal variance keep their input order.
    seg_color_var.sort(key=lambda x: x[1])
    return seg_color_var[:top_k]

def get_most_saturated_segments(saturation, segments, segment_ids, top_k=3):
    """Rank segments by mean saturation and return the most saturated ones.

    Args:
        saturation: Array of per-pixel saturation values with the same shape
            as `segments`.
        segments: Label map assigning a segment id to every pixel.
        segment_ids: Iterable of segments to evaluate. Each entry may be a
            bare segment id or a ``(segment_id, score)`` pair such as the
            output of `get_most_uniform_segments`; for pairs only the id is
            used.
        top_k: Maximum number of segments to return. Defaults to 3, the
            value that was previously hard-coded.

    Returns:
        A list of ``(segment_id, mean_saturation)`` tuples sorted by
        descending mean saturation, containing at most `top_k` entries. Ids
        that do not occur in `segments` are skipped.
    """
    seg_sats = []
    for entry in segment_ids:
        # Scalars (0-d) are bare ids; anything else is treated as a
        # (segment_id, score) pair whose first item is the id.
        sid = entry if np.ndim(entry) == 0 else entry[0]
        segment_sat = saturation[segments == sid]
        if segment_sat.size == 0:
            # The mean of an empty selection is NaN, which would corrupt the sort.
            continue
        seg_sats.append((sid, segment_sat.mean()))

    seg_sats.sort(key=lambda x: x[1], reverse=True)
    return seg_sats[:top_k]

In [None]:
# Split the image into hue / saturation / value planes.
hsv_img = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
h, s, v = cv2.split(hsv_img)

# First narrow down to the segments with the most uniform hue, then keep the
# most saturated of those and draw one of them at random.
most_uniform_segments = get_most_uniform_segments(
    hue=h, segments=segments, segment_ids=segment_ids
)
most_colorful_segments = get_most_saturated_segments(
    saturation=s, segments=segments, segment_ids=most_uniform_segments
)
high_sat_sid, sat_value = random.choice(most_colorful_segments)
(high_sat_sid, sat_value)

In [None]:
# 1 inside the chosen segment, 0 elsewhere; displayed scaled to 0/255.
binary_mask = np.isin(segments, high_sat_sid).astype(np.uint8)
Image.fromarray(binary_mask * 255, mode="L")

In [None]:
def get_hint_image(image, binary_mask):
    """Return a grayscale rendition of `image` with one region painted in its mean colour.

    Args:
        image: Image array whose first three channels along axis 2 are read
            as R, G and B.
        binary_mask: 2-D array matching the image's height and width; every
            non-zero entry marks a pixel of the hint region.

    Returns:
        The 3-channel array produced by converting `image` to gray and back
        to RGB, with the masked pixels overwritten by the floor of the
        region's per-channel mean (clamped to 0..255). If the mask selects
        no pixels, the plain grayscale image is returned.
    """
    mask = binary_mask.astype(bool)

    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    hint = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

    pixel_count = int(mask.sum())
    if pixel_count == 0:
        # Nothing selected: there is no colour to average, so skip the
        # division (which would be a divide-by-zero) and return gray only.
        return hint

    def mean_color(channel):
        # Select pixels with the boolean mask rather than multiplying the
        # channel by the mask: the product stays in the channel's dtype and
        # would wrap around for a uint8 mask stored as 0/255, and it would
        # disagree with the "non-zero means selected" rule used for painting.
        total = channel[mask].sum()
        return np.clip(int(total // pixel_count), 0, 255)

    # Index channels explicitly so a trailing alpha channel is simply ignored.
    for idx in range(3):
        hint[mask, idx] = mean_color(image[..., idx])
    return hint

# Build the colour-hint image for the selected segment and render it.
hint = get_hint_image(image=image, binary_mask=binary_mask)
Image.fromarray(hint)