In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from math import comb
from sklearn.cluster import KMeans
import scipy as sp
from PIL import Image
import sys, os, time
sys.path.append(os.path.abspath("../../"))
from pipeoptz import Pipeline, Node, PipelineOptimizer, BoolParameter, IntParameter

In [None]:
# Build the optimisation dataset: every entry of X is the run-parameter dict
# handed to the pipeline, every entry of y is the matching image loaded from
# the `edited/` folder (presumably the expected result -- verify against data).
X, y = [], []
for sample_idx in range(5):
    source_img = np.array(Image.open(f"images/lolipop/{sample_idx}.png"))
    target_img = np.array(Image.open(f"images/lolipop/edited/{sample_idx}.png"))
    X.append({"image": source_img})
    y.append(target_img)

In [None]:
def ith_subset(n, i):
    """Return the ``i``-th subset of ``{0, ..., n-1}`` as an increasing list.

    Subsets are ranked by cardinality first (the empty set has rank 0, the
    full set has rank ``2**n - 1``) and lexicographically inside each
    cardinality class.

    Args:
        n: size of the ground set.
        i: rank of the requested subset.

    Returns:
        The list of selected indices, in increasing order.

    Raises:
        ValueError: if ``i`` is outside ``[0, 2**n - 1]``.
    """
    total = 2**n
    if not (0 <= i < total):
        raise ValueError(f"Index i must be in [0, {total - 1}]")

    # Step 1: walk through the cardinality classes until the rank falls
    # inside one of them; `rank` becomes the offset within that class.
    rank = i
    for size in range(n + 1):
        class_size = comb(n, size)
        if rank < class_size:
            cardinality = size
            break
        rank -= class_size

    # Step 2: unrank the combination. For each slot, skip candidate values
    # while the number of combinations starting with that candidate still
    # fits into the remaining rank.
    chosen = []
    candidate = 0
    for still_needed in range(cardinality - 1, -1, -1):
        while True:
            skipped = comb(n - 1 - candidate, still_needed)
            if skipped > rank:
                break
            rank -= skipped
            candidate += 1
        chosen.append(candidate)
        candidate += 1
    return chosen

def integer(n):
    """Return ``n`` unchanged.

    Identity helper: in this notebook it is wrapped in the "Palette size"
    node so that one tunable value can feed several downstream nodes.
    """
    return n

def to_mask(image):
    """Return a boolean mask that is True wherever channel 3 of ``image`` is non-zero."""
    alpha = image[:, :, 3]
    return np.not_equal(alpha, 0)

def biggest_mask(elements):
    """Return the element of ``elements`` with the largest ``.sum()``.

    The first element wins when several share the maximum. Returns ``None``
    when ``elements`` is empty.
    """
    if len(elements) == 0:
        return None
    sizes = [candidate.sum() for candidate in elements]
    return elements[sizes.index(max(sizes))]

def colored_mask(image, mask):
    """Keep the pixels of ``image`` selected by ``mask`` and zero out the rest.

    Args:
        image: array indexed as ``[row, col, channel]``.
        mask: 2-D array broadcast over the channel axis, or ``None``.

    Returns:
        ``image * mask[:, :, np.newaxis]`` when a mask is given; otherwise an
        all-zero array with the same shape *and dtype* as ``image`` so both
        branches return the same kind of array.
    """
    if mask is None:
        # zeros_like preserves the dtype of the image (np.zeros would silently
        # switch to float64, unlike the masked branch).
        return np.zeros_like(image)
    return image * mask[:, :, np.newaxis]

def extract_palette(image: np.ndarray, n_colors: int, sample_size: int = 0,
                    max_iter: int = 300, use_lab: bool = False) -> np.ndarray:
    """Extract a colour palette from an RGB or RGBA image with k-means.

    Args:
        image: ``(H, W, 3)`` RGB or ``(H, W, 4)`` RGBA array. For RGBA input,
            pixels whose alpha is 0 are ignored.
        n_colors: requested number of palette entries. It is reduced to the
            number of available pixels when there are fewer pixels than that.
        sample_size: when > 0 and smaller than the number of pixels, cluster a
            random subset of that many pixels instead of all of them.
        max_iter: maximum number of k-means iterations.
        use_lab: cluster in OpenCV's 8-bit Lab space instead of RGB.

    Returns:
        ``(k, 3)`` uint8 array of RGB colours sorted from brightest to darkest
        (Lab ``L`` when ``use_lab`` is set, Rec.601-weighted luminance
        otherwise). A ``(0, 3)`` array is returned when there is no pixel to
        cluster.

    Raises:
        ValueError: if ``image`` is not a 3-D array with 3 or 4 channels.
    """
    if image.ndim != 3 or image.shape[2] not in (3, 4):
        raise ValueError("Image must be RGB or RGBA.")

    if image.shape[2] == 4:
        opaque_mask = image[:, :, 3] != 0
        pixels = image[opaque_mask][:, :3]
    else:  # RGB
        pixels = image.reshape(-1, 3)

    # Nothing to cluster (fully transparent or zero-sized image): KMeans would
    # fail on an empty sample, so return an empty palette up front.
    if pixels.shape[0] == 0:
        return np.array([], dtype=np.uint8).reshape(0, 3)

    if use_lab:
        # uint8 input -> OpenCV returns 8-bit Lab (L scaled to 0..255, a/b offset by 128).
        pixels_lab = cv2.cvtColor(pixels.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_RGB2LAB)[0]
        data_for_kmeans = pixels_lab
    else:
        data_for_kmeans = pixels.astype(np.float32)

    if sample_size > 0 and data_for_kmeans.shape[0] > sample_size:
        # NOTE: uses the global NumPy RNG, so subsampling is only reproducible
        # if the caller seeds np.random beforehand.
        indices = np.random.choice(data_for_kmeans.shape[0], size=sample_size, replace=False)
        sample = data_for_kmeans[indices]
    else:
        sample = data_for_kmeans

    # KMeans needs at least as many samples as clusters; `sample` is non-empty here.
    if sample.shape[0] < n_colors:
        n_colors = sample.shape[0]

    kmeans = KMeans(n_clusters=n_colors, max_iter=max_iter, n_init='auto', random_state=0)
    kmeans.fit(sample)
    centers = kmeans.cluster_centers_

    if use_lab:
        # The centroids live in the 8-bit Lab encoding produced above, so they
        # must be converted back through the uint8 path. Feeding them as
        # float32 would make OpenCV interpret them as true Lab values
        # (L in 0..100, a/b in -127..127) and return RGB in 0..1.
        centers_u8 = np.clip(np.rint(centers), 0, 255).astype(np.uint8).reshape(1, -1, 3)
        palette = cv2.cvtColor(centers_u8, cv2.COLOR_LAB2RGB)[0]
        sorted_indices = np.argsort(centers[:, 0])[::-1]
        palette = palette[sorted_indices]
    else:
        luminance = 0.299 * centers[:, 0] + 0.587 * centers[:, 1] + 0.114 * centers[:, 2]
        sorted_indices = np.argsort(luminance)[::-1]
        palette = np.clip(centers[sorted_indices], 0, 255).astype(np.uint8)
    return palette

def recolor(image, palette):
    """Snap every pixel of ``image`` to its nearest colour in ``palette``.

    Distance is the Euclidean norm between the first three channels of the
    pixel and each palette entry.

    Args:
        image: ``(H, W, C)`` array; only the first three channels are matched.
        palette: ``(K, 3)`` array of colours.

    Returns:
        A uint8 ``(H, W, 3)`` array of palette colours. When the input has
        exactly 4 channels, the original 4th channel is stacked back on
        unchanged. If ``palette`` is empty there is nothing to snap to, so a
        copy of ``image`` is returned.
    """
    h, w, c = image.shape

    # An empty palette (extract_palette returns one for a fully transparent
    # image) would make argmin fail on an empty axis; leave pixels untouched.
    if len(palette) == 0:
        return image.copy()

    is_rgba = (c == 4)

    pixels_flat = image[:, :, :3].astype(np.float32).reshape(-1, 3)
    palette_float = palette.astype(np.float32)

    # (N, 1, 3) - (1, K, 3) -> (N, K) distance table.
    dists = np.linalg.norm(pixels_flat[:, np.newaxis, :] - palette_float[np.newaxis, :, :], axis=2)

    nearest_palette_indices = np.argmin(dists, axis=1)
    recolored_rgb = palette[nearest_palette_indices].reshape(h, w, 3).astype(np.uint8)

    if is_rgba:
        alpha_channel = image[:, :, 3:]
        return np.dstack((recolored_rgb, alpha_channel))
    return recolored_rgb

def remove_palette(image, recolored_image, palette, indices_to_remove):
    """Make transparent every pixel whose recoloured value is a selected palette colour.

    Args:
        image: RGBA array; it is copied, not modified.
        recolored_image: array whose first three channels are compared with
            the palette entries.
        palette: sequence of colours.
        indices_to_remove: palette indices to erase; indices outside
            ``[0, len(palette))`` are silently ignored.

    Returns:
        A copy of ``image`` with channel 3 set to 0 on the matching pixels.

    Raises:
        ValueError: if ``image`` does not have 4 channels.
    """
    if image.shape[2] != 4:
        raise ValueError("Original image must be RGBA.")

    rgb_view = recolored_image[:, :, :3]
    result = image.copy()
    palette_len = len(palette)

    for idx in indices_to_remove:
        if idx < 0 or idx >= palette_len:
            continue
        hits = (rgb_view == palette[idx]).all(axis=2)
        result[hits, 3] = 0
    return result

def isolate(binary_mask, sizemin=1):
    """Split ``binary_mask`` into one boolean mask per connected component.

    Components are found with ``scipy.ndimage.label`` (default structuring
    element) and returned in label order; components with fewer than
    ``sizemin`` pixels are dropped. An all-False mask yields an empty list.
    """
    if not np.any(binary_mask):
        return []

    labels, count = sp.ndimage.label(binary_mask)
    components = (labels == label_id for label_id in range(1, count + 1))
    return [component for component in components if np.sum(component) >= sizemin]


In [None]:
def IoU(im1, im2):
    """Intersection-over-union of the fully opaque regions of two images.

    A pixel counts as foreground when channel 3 equals 255. Returns 0 when
    neither image has any such pixel.
    """
    opaque_a = im1[:, :, 3] == 255
    opaque_b = im2[:, :, 3] == 255
    shared = np.logical_and(opaque_a, opaque_b).sum()
    covered = np.logical_or(opaque_a, opaque_b).sum()
    if covered > 0:
        return shared / covered
    return 0

def loss(f1, f2):
    """Return the reciprocal of ``IoU(f1, f2)``.

    A tiny epsilon (1e-20) is added to the IoU so that a zero overlap gives a
    very large finite loss instead of a division by zero.
    """
    overlap = IoU(f1, f2)
    return 1 / (overlap + 1e-20)

In [None]:
# Key under which the input image is exposed to the nodes.
RUN_IMAGE = "run_params:image"

pipeline = Pipeline("RemoveBG")

# Tunable scalar shared by the palette extraction and the subset selection.
pipeline.add_node(Node("Palette size", integer, fixed_params={"n":8}))
pipeline.add_node(
    Node("Extract palette", extract_palette, fixed_params={"use_lab":False}),
    predecessors={"image":RUN_IMAGE, "n_colors":"Palette size"},
)
pipeline.add_node(
    Node("Palette indices", ith_subset, fixed_params={"i": 37}),
    predecessors={"n":"Palette size"},
)

# Quantise the image to the palette, then erase the selected palette colours.
pipeline.add_node(
    Node("Recolor", recolor),
    predecessors={"image":RUN_IMAGE, "palette":"Extract palette"},
)
pipeline.add_node(
    Node("Remove palette", remove_palette),
    predecessors={
        "image":RUN_IMAGE,
        "recolored_image":"Recolor",
        "palette":"Extract palette",
        "indices_to_remove":"Palette indices",
    },
)

# Keep only the largest connected component of what is left.
pipeline.add_node(Node("To mask", to_mask), predecessors={"image":"Remove palette"})
pipeline.add_node(Node("Isolate", isolate), predecessors={"binary_mask": "To mask"})
pipeline.add_node(Node("Main element", biggest_mask), predecessors={"elements":"Isolate"})
pipeline.add_node(
    Node("Colored element", colored_mask),
    predecessors={"image":RUN_IMAGE, "mask":"Main element"},
)

# Export the graph, load the rendered picture, then delete both temporary
# files (the PNG is presumably produced by to_dot when its second argument is
# True -- confirm against the pipeoptz documentation).
pipeline.to_dot("RemoveBG.dot", True)
im = np.array(Image.open("RemoveBG.png"))
for tmp_file in ("RemoveBG.dot", "RemoveBG.png"):
    os.remove(tmp_file)
plt.imshow(im)

In [None]:
# Optimise the pipeline against the (X, y) pairs using `loss`.
# NOTE(review): the meaning of the third positional argument (0.01) is defined
# by PipelineOptimizer and is not visible here -- confirm before changing it.
optimizer = PipelineOptimizer(pipeline, loss, 0.01, X, y)

# Search space: (node name, fixed-parameter name, then the bounds for integers).
for search_param in (
    IntParameter("Palette size", "n", 8, 15),
    IntParameter("Palette indices", "i", 1, 63),
    BoolParameter("Extract palette", "use_lab"),
):
    optimizer.add_param(search_param)

In [None]:
# Run every optimisation method once, keeping what `optimize` returns and how
# long each run took.
params_and_log = {}
times = {}
# Method name -> keyword arguments forwarded to optimizer.optimize().
methods = {"ACO":{"iterations":10, "ants":10},
           "SA":{"iterations":50}, 
           "PSO":{"iterations":15, "swarm_size":10}, 
           "GA":{"generations":15, "population_size":20}, 
           "grid_search":{"max_combinations":100}, 
           "BO":{"iterations":25, "init_points":25}}

for method, method_kwargs in methods.items():
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments the way time.time() can.
    start = time.perf_counter()
    print(f"Method {method} :")
    params_and_log[method] = optimizer.optimize(method=method, verbose=True, **method_kwargs)
    times[method] = time.perf_counter() - start
    print()

In [None]:
plt.figure(figsize=(8,25))

# Top row: the first input image next to its reference image.
for slot, picture in ((1, X[0]["image"]), (2, y[0])):
    plt.subplot(7, 2, slot)
    plt.axis('off')
    plt.imshow(picture)

# One row per method: the logged curve on the left (presumably the loss
# history -- confirm against optimize()), the pipeline output obtained with
# that method's parameters on the right.
for row, method in enumerate(methods):
    best_params = params_and_log[method][0]
    logged_values = params_and_log[method][1]
    pipeline.set_fixed_params(best_params)
    ind, hist, _ = pipeline.run(X[0])

    curve_ax = plt.subplot(7, 2, 3 + 2*row)
    curve_ax.set_yscale('log')
    curve_ax.set_title(method)
    curve_ax.set_ylim(ymin=1, ymax=1e5)
    curve_ax.plot(logged_values)

    result_ax = plt.subplot(7, 2, 4 + 2*row)
    result_ax.axis('off')
    result_ax.set_title(f"Time: {times[method]:.2f}s")
    plt.imshow(hist[ind], cmap="gray", vmin=0, vmax=255)

In [None]:
# Print the first element of each method's result (the parameters that are
# fed to pipeline.set_fixed_params in the plotting cell).
for method_name, outcome in params_and_log.items():
    print(f"{method_name} :\t{outcome[0]}")

In [None]:
# Run the pipeline (with whatever parameters were set last) on every image of
# the folder and show the results in a grid.
n_images = 62                      # images are named 0.png .. (n_images-1).png
n_cols = 12
n_rows = -(-n_images // n_cols)    # ceiling division: no spare row when n_images % n_cols == 0

plt.figure(figsize=(10, 20))
for i in range(n_images):
    print(f"{i+1}/{n_images}", end="\t\r")
    plt.subplot(n_rows, n_cols, i+1)
    im = np.array(Image.open(f"images/lolipop/{i}.png"))
    res = pipeline.run({"image":im})
    # res[0] selects the entry of res[1] to display, as in the plotting cell.
    plt.imshow(res[1][res[0]])

In [None]:
# Re-run the pipeline on the first sample and display its result.
# run() returns a tuple whose first item selects the entry of the second item
# to show (presumably the last node's id and the per-node outputs -- confirm).
res = pipeline.run(X[0])
selected_key, node_outputs = res[0], res[1]
plt.imshow(node_outputs[selected_key])