In [None]:
import os

In [None]:
!python -m ensurepip --default-pip && python -m pip install pillow scikit-learn matplotlib numpy jupyter rembg onnxruntime -q

In [None]:
# Root folder of the product images, one sub-folder per category
# (e.g. <data_path>/dress/<file>.jpg), relative to this notebook.
data_path = os.path.join(os.pardir, "artifacts", "images")
data_path

In [None]:
# Sample image for the first experiments, given as (category folder, file name).
img = ("dress", "3f844e1e-4a00-4b64-8c1d-3b847191bf11.jpg")
image_path = os.path.abspath(os.path.join(data_path, img[0], img[1]))

In [None]:
from PIL import Image
# Open the sample image and render it as this cell's output.
image = Image.open(image_path)
image 

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def display_palette(palette, palette_width=1000, palette_height=100):
    """Show a color palette as a horizontal strip of proportional bands.

    Args:
        palette: sequence of ``(color, percentage)`` pairs as returned by the
            extractors in this notebook; ``color`` is an RGB triple (0-255)
            and ``percentage`` its share. Shares are normalized by their
            total, so they need not sum to exactly 100.
        palette_width: width of the rendered strip in pixels. Kept large so
            that small shares still get a visible band.
        palette_height: height of the rendered strip in pixels.
    """
    fig, ax = plt.subplots(1, 1)
    strip = np.zeros((palette_height, palette_width, 3), dtype=np.uint8)

    shares = np.array([percentage for _, percentage in palette], dtype=float)
    total = shares.sum()
    if total > 0:
        # Put each band edge at its rounded cumulative share. Every band is
        # then within one pixel of its exact width, instead of truncating
        # each band independently and piling the leftover onto the last one.
        edges = np.rint(np.cumsum(shares) / total * palette_width).astype(int)
        edges[-1] = palette_width
        x_start = 0
        for (color, _), x_end in zip(palette, edges):
            strip[:, x_start:x_end] = color
            x_start = x_end

    ax.imshow(strip)
    ax.set_title('Color Palette')
    ax.axis('off')
    plt.show()

In [None]:
from sklearn.cluster import KMeans
import numpy as np
from collections import Counter

def extract_color_palette_from_image(image, n_colors=5, resize_width=150):
    """Extract the dominant colors of a PIL image with k-means.

    The image is converted to RGB and resized to ``resize_width``. The aspect
    ratio is kept, with a height of at least 1 px. The resized image is shown
    in the notebook, and its pixels are clustered into ``n_colors`` groups.

    Args:
        image: PIL image in any mode.
        n_colors: number of k-means clusters.
        resize_width: width used for clustering; smaller is faster.

    Returns:
        list of ``(color, percentage)`` pairs sorted by percentage, descending.
        ``color`` is the int-truncated cluster center (RGB) and ``percentage``
        the share of resized pixels assigned to that cluster. Clusters that
        received no pixels are omitted.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    original_width, original_height = image.size
    aspect_ratio = original_height / original_width
    # Guard against a zero height for very wide images.
    resize_height = max(1, int(resize_width * aspect_ratio))
    image = image.resize((resize_width, resize_height), Image.Resampling.LANCZOS)
    display(image)

    pixels = np.array(image).reshape(-1, 3)

    kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=10)
    kmeans.fit(pixels)

    colors = kmeans.cluster_centers_.astype(int)
    labels = kmeans.labels_

    # Row k of cluster_centers_ belongs to label k, so count pixels per label
    # index. Counter(labels).values() would be in first-seen order and pair
    # colors with the wrong percentages.
    counts = np.bincount(labels, minlength=len(colors))
    n_pixels = len(labels)

    palette = [
        (color, (int(count) / n_pixels) * 100)
        for color, count in zip(colors, counts)
        if count > 0
    ]
    palette.sort(key=lambda entry: entry[1], reverse=True)
    return palette

# Baseline: k-means palette with 5 clusters for the sample dress image.
img = ("dress", "3f844e1e-4a00-4b64-8c1d-3b847191bf11.jpg")
image_path = os.path.abspath(os.path.join(data_path, img[0], img[1]))
image = Image.open(image_path)
display_palette(
    extract_color_palette_from_image(image, n_colors=5, resize_width=150)
)

In [None]:
# Same k-means baseline (5 clusters) on a shoes image.
img = ("shoes", "5d147b66-5238-4ed0-9c6f-c68286fea4ad.jpg")
image_path = os.path.abspath(os.path.join(data_path, img[0], img[1]))
image = Image.open(image_path)
display_palette(
    extract_color_palette_from_image(image, n_colors=5, resize_width=150)
)

In [None]:
from rembg import remove
from PIL import Image

# Preview rembg's background removal on the sample dress image.
img = ("dress", "3f844e1e-4a00-4b64-8c1d-3b847191bf11.jpg")
image_path = os.path.abspath(os.path.join(data_path, img[0], img[1]))
image = Image.open(image_path)
image_removed_background = remove(image)
display(image_removed_background)

In [None]:
# Preview rembg's background removal on the shoes image.
img = ("shoes", "5d147b66-5238-4ed0-9c6f-c68286fea4ad.jpg")
image_path = os.path.abspath(os.path.join(data_path, img[0], img[1]))
image = Image.open(image_path)
image_removed_background = remove(image)
display(image_removed_background)

### New clustering method with a data-driven number of clusters

In [None]:
from sklearn.cluster import AgglomerativeClustering
import numpy as np
from sklearn.preprocessing import StandardScaler
from collections import Counter

def new_extract_method(image, resize_width=150, distance_threshold=50):
    """Extract a palette whose size is chosen by agglomerative clustering.

    The image is converted to RGB, resized to ``resize_width`` (aspect ratio
    kept, at least 1 px high) and shown. Its pixels are standardized per
    channel and grouped with Ward linkage. Merging stops at
    ``distance_threshold``, so the number of colors is data-driven rather
    than fixed.

    Args:
        image: PIL image in any mode. Any alpha channel is dropped by the
            RGB conversion.
        resize_width: width used for clustering.
        distance_threshold: Ward linkage distance (in standardized-channel
            units) above which clusters are not merged.

    Returns:
        list of ``(color, percentage)`` pairs sorted by percentage, descending.
        ``color`` is the int-truncated mean RGB of a cluster and
        ``percentage`` its share of the resized pixels.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    original_width, original_height = image.size
    aspect_ratio = original_height / original_width
    # Guard against a zero height for very wide images.
    resize_height = max(1, int(resize_width * aspect_ratio))
    image = image.resize((resize_width, resize_height), Image.Resampling.LANCZOS)
    display(image)

    pixels = np.array(image).reshape(-1, 3)

    # Standardize each channel so the threshold is not tied to the raw scale.
    pixels_scaled = StandardScaler().fit_transform(pixels)

    # NOTE(review): Ward linkage without a connectivity constraint is expected
    # to need memory quadratic in the pixel count; keep resize_width small.
    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=distance_threshold,
        linkage='ward'
    )
    labels = clustering.fit_predict(pixels_scaled)

    # np.unique returns sorted labels with aligned counts, keeping each color
    # paired with its own percentage. Counter(labels).values() is in
    # first-seen order and would misalign them.
    unique_labels, counts = np.unique(labels, return_counts=True)
    colors = np.array([pixels[labels == label].mean(axis=0) for label in unique_labels]).astype(int)
    color_percentages = [(int(count) / len(labels)) * 100 for count in counts]

    palette = list(zip(colors, color_percentages))
    palette.sort(key=lambda entry: entry[1], reverse=True)
    return palette

# Agglomerative palette (cluster count set by the distance threshold) for the dress.
img = ("dress", "3f844e1e-4a00-4b64-8c1d-3b847191bf11.jpg")
image_path = os.path.abspath(os.path.join(data_path, img[0], img[1]))
image = Image.open(image_path)
display_palette(
    new_extract_method(image)
)

In [None]:
# Agglomerative palette for the shoes image.
img = ("shoes", "5d147b66-5238-4ed0-9c6f-c68286fea4ad.jpg")
image_path = os.path.abspath(os.path.join(data_path, img[0], img[1]))
image = Image.open(image_path)
display_palette(
    new_extract_method(image)
)

In [None]:
# Replace the removed background with opaque red, then cluster the result.
# The red pixels are clustered too, because new_extract_method uses every
# pixel. rembg expects `bgcolor` as an (R, G, B, A) tuple; a 3-tuple cannot be
# unpacked into four channels.
img = ("dress", "3f844e1e-4a00-4b64-8c1d-3b847191bf11.jpg")
image_path = os.path.abspath(os.path.join(data_path, *img))
image = Image.open(image_path)
image_removed_background = remove(image, bgcolor=(255, 0, 0, 255))
display_palette(new_extract_method(image_removed_background))

In [None]:
# new_extract_method converts the rembg cut-out to RGB, dropping alpha, so
# pixels rembg made transparent are still clustered. The next cell's variant
# filters pixels by alpha instead.
img = ("shoes", "5d147b66-5238-4ed0-9c6f-c68286fea4ad.jpg")
image_path = os.path.abspath(os.path.join(data_path, img[0], img[1]))
image = Image.open(image_path)
image_removed_background = remove(image)
display_palette(
    new_extract_method(image_removed_background)
)

In [None]:
def new_extract_method_including_bg_removal(image, resize_width=150, alpha_threshold=128, distance_threshold=50):
    """Remove the background with rembg, then cluster only foreground pixels.

    A pixel is kept as foreground when its alpha exceeds ``alpha_threshold``.
    If none qualify, the cut-off is relaxed to 64 and then to any non-zero
    alpha. The surviving RGB values are standardized per channel and grouped
    with Ward agglomerative clustering, stopping at ``distance_threshold``.

    Args:
        image: PIL image in any mode; passed to rembg's ``remove``.
        resize_width: width the cut-out is resized to before clustering; the
            height keeps the aspect ratio (at least 1 px).
        alpha_threshold: alpha value a pixel must exceed to count as
            foreground.
        distance_threshold: Ward linkage distance (in standardized-channel
            units) above which clusters are not merged.

    Returns:
        list of ``(color, percentage)`` pairs sorted by percentage, descending.
        ``color`` is the int-truncated mean RGB of a cluster and
        ``percentage`` its share of foreground pixels. The list is empty when
        no pixel has non-zero alpha.
    """
    image = remove(image)

    if image.mode != 'RGBA':
        image = image.convert('RGBA')

    original_width, original_height = image.size
    aspect_ratio = original_height / original_width
    # Guard against a zero height for very wide images.
    resize_height = max(1, int(resize_width * aspect_ratio))
    image = image.resize((resize_width, resize_height), Image.Resampling.LANCZOS)
    display(image)

    pixels_rgba = np.array(image).reshape(-1, 4)
    alpha = pixels_rgba[:, 3]

    # Relax the alpha cut-off until some pixels survive. A fallback only runs
    # when it is actually lower than the threshold already tried.
    foreground_mask = alpha > alpha_threshold
    if not foreground_mask.any() and alpha_threshold > 64:
        print("Warning: No foreground pixels found. Lowering alpha threshold.")
        foreground_mask = alpha > 64  # Fallback threshold
    if not foreground_mask.any() and alpha_threshold > 0:
        print("Warning: Still no foreground pixels found. Using all non-zero alpha pixels.")
        foreground_mask = alpha > 0

    pixels = pixels_rgba[foreground_mask][:, :3]

    # StandardScaler needs at least one sample and agglomerative clustering
    # needs at least two, so handle the degenerate cut-outs directly.
    if len(pixels) == 0:
        print("Warning: No pixels with non-zero alpha. Returning an empty palette.")
        return []
    if len(pixels) == 1:
        return [(pixels[0].astype(int), 100.0)]

    pixels_scaled = StandardScaler().fit_transform(pixels)

    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=distance_threshold,
        linkage='ward'
    )
    labels = clustering.fit_predict(pixels_scaled)

    # np.unique returns sorted labels with aligned counts, keeping each color
    # paired with its own percentage. Counter(labels).values() is in
    # first-seen order and would misalign them.
    unique_labels, counts = np.unique(labels, return_counts=True)
    colors = np.array([pixels[labels == label].mean(axis=0) for label in unique_labels]).astype(int)
    color_percentages = [(int(count) / len(labels)) * 100 for count in counts]

    palette = list(zip(colors, color_percentages))
    palette.sort(key=lambda entry: entry[1], reverse=True)
    return palette

# End-to-end check: background removal plus agglomerative palette on a hat.
img = ("hat", "ff62e59f-c949-4f8e-8fca-798ab3e174c5.jpg")
image_path = os.path.abspath(os.path.join(data_path, img[0], img[1]))
image = Image.open(image_path)
display_palette(
    new_extract_method_including_bg_removal(image)
)

### Check on batches of images per category

In [None]:
import glob

# Run the background-removal palette on a slice of one category folder.
category_name = "pants"
folder_path = os.path.abspath(os.path.join(data_path, category_name))
# glob returns files in filesystem-dependent order; sort so the slice below
# picks the same images on every run.
file_names = sorted(glob.glob(os.path.join(folder_path, "*.jpg")))

start_idx = 0
end_idx = 10

for file_name in file_names[start_idx:end_idx]:
    # Close each file once its palette has been computed.
    with Image.open(file_name) as image:
        display_palette(new_extract_method_including_bg_removal(image))

In [None]:
# NOTE(review): disabled draft. It calls `extract_color_palette` and
# `rgb_to_hex`, neither of which is defined in this notebook. To revive it,
# open the image and call `extract_color_palette_from_image`, and add a
# hex-conversion helper.
# def get_color_palette_info(image_path, n_colors=5):
#     palette = extract_color_palette(image_path, n_colors)
    
#     result = []
#     for i, (color, percentage) in enumerate(palette):
#         color_info = {
#             'rank': i + 1,
#             'rgb': tuple(color),
#             'hex': rgb_to_hex(color),
#             'percentage': round(percentage, 1)
#         }
#         result.append(color_info)
    
#     return result

# color_palette_info = get_color_palette_info(image_path, n_colors=3)
# color_palette_info

In [None]:
n_colors_to_extract = 4
max_images = 2

category = "shirt"

jpg_files = get_jpg_files(f'{data_path}/{category}')
print(f"Found {len(jpg_files)} JPG files in the folder.")

max_display = min(max_images, len(jpg_files))

for single_img in jpg_files[:max_display]:
    single_img_path = os.path.join(data_path, category, single_img)
    palette = extract_color_palette(single_img_path, n_colors=n_colors_to_extract)
    display_palette(palette)

In [None]:
n_colors_to_extract = 4
max_images = 2

category = "longsleeve"

jpg_files = get_jpg_files(f'{data_path}/{category}')
print(f"Found {len(jpg_files)} JPG files in the folder.")

max_display = min(max_images, len(jpg_files))

for single_img in jpg_files[:max_display]:
    single_img_path = os.path.join(data_path, category, single_img)
    palette = extract_color_palette(single_img_path, n_colors=n_colors_to_extract)
    display_palette(palette)

In [None]:
n_colors_to_extract = 4
max_images = 3

category = "t-shirt"

jpg_files = get_jpg_files(f'{data_path}/{category}')
print(f"Found {len(jpg_files)} JPG files in the folder.")

max_display = min(max_images, len(jpg_files))

for single_img in jpg_files[:max_display]:
    single_img_path = os.path.join(data_path, category, single_img)
    palette = extract_color_palette(single_img_path, n_colors=n_colors_to_extract)
    display_palette(palette)