# 1. K-Means Clustering

In [None]:
import os
import cv2
import numpy as np
from sklearn.cluster import KMeans
from collections import Counter
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

def get_dominant_colors_kmeans(image, n_colors=5, resize=128):
    """Return the K-Means cluster colors of an image, largest cluster first.

    The image is converted from RGB to BGR channel order, resized to a
    ``resize`` x ``resize`` square, and its pixels are clustered into
    ``n_colors`` groups using a fixed random seed (42).

    Args:
        image: RGB image; anything ``np.array`` can turn into an
            H x W x 3 array.
        n_colors: Number of clusters to fit, and the maximum number of
            colors returned.
        resize: Side length, in pixels, of the square working image.

    Returns:
        A list of integer arrays in BGR channel order (cluster centers
        truncated with ``astype(int)``), ordered by descending cluster size.
    """
    bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    small = cv2.resize(bgr, (resize, resize))
    samples = small.reshape(-1, 3)

    model = KMeans(n_clusters=n_colors, random_state=42)
    model.fit(samples)

    centers = model.cluster_centers_
    # Rank cluster ids by how many pixels were assigned to each.
    ranked = Counter(model.labels_).most_common(n_colors)
    return [centers[cluster_id].astype(int) for cluster_id, _ in ranked]

def plot_palette(colors, title):
    """Display a row of 50x50 color swatches with a title.

    Args:
        colors: Sequence of colors in BGR channel order, one swatch each.
        title: Text shown above the swatch row.
    """
    swatch = 50
    plt.figure(figsize=(5, 1))
    strip = np.zeros((swatch, swatch * len(colors), 3), dtype=np.uint8)
    for idx, bgr in enumerate(colors):
        left = idx * swatch
        strip[:, left:left + swatch, :] = bgr
    # Swatches are stored as BGR; reverse the channel axis so imshow gets RGB.
    rgb_strip = strip[..., ::-1]
    plt.imshow(rgb_strip)
    plt.axis('off')
    plt.title(title)
    plt.show()

def save_palette_as_svg(colors, filename, title_text=None):
    """Write a horizontal strip of color swatches to ``filename`` as SVG.

    Each color becomes a 60x60 rectangle. When ``title_text`` is given, a
    30-pixel band above the swatches holds the horizontally centered title.

    Args:
        colors: Sequence of colors in BGR channel order; each item must be
            indexable with three numeric components.
        filename: Destination path. Missing parent directories are created;
            a bare file name is written to the current working directory.
        title_text: Optional caption. It is XML-escaped before embedding so
            characters such as ``&`` and ``<`` cannot break the document.
    """
    # Function-scope import keeps the block self-contained.
    from xml.sax.saxutils import escape

    def channel(value):
        # Clamp to 0..255 so the two-digit hex format always yields a valid color.
        return min(255, max(0, int(value)))

    n = len(colors)
    width = 60
    height = 60
    text_height = 30 if title_text else 0  # reserve room for the caption
    svg_height = height + text_height

    # Colors arrive as BGR, so index 2 is red and index 0 is blue.
    rects = "".join(
        f'<rect x="{i*width}" y="{text_height}" width="{width}" height="{height}" '
        f'fill="#{channel(color[2]):02x}{channel(color[1]):02x}{channel(color[0]):02x}"/>'
        for i, color in enumerate(colors)
    )

    text_tag = ""
    if title_text:
        font_size = 20
        text_x = (n * width) // 2
        text_y = 22  # caption baseline, inside the top band of the SVG
        safe_title = escape(str(title_text))
        text_tag = f'<text x="{text_x}" y="{text_y}" font-size="{font_size}" text-anchor="middle" font-family="Arial" fill="#222">{safe_title}</text>'

    svg = f'''<svg width="{n*width}" height="{svg_height}" xmlns="http://www.w3.org/2000/svg">{rects}{text_tag}</svg>'''

    # os.makedirs('') raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Explicit UTF-8: the SVG has no XML declaration, so readers assume UTF-8.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(svg)

# Root folder holding one sub-folder of preprocessed images per interior style.
base_dir = './preprocessed_google_images'
style_folders = [
    'antique_interior', 'modern_interior', 'natural_interior',
    'northern_european_interior', 'romantic_interior',
    'traditional_korean_style_interior', 'vintage_interior'
]

# Folder that receives the generated SVG palettes.
result_base_dir = 'kmeans_EDA'

# Maps style name -> list of up to five most frequent dominant colors (BGR arrays).
style_palette_dict = {}

for style in style_folders:
    folder_path = os.path.join(base_dir, style)
    # A missing style folder should not abort the remaining styles.
    if not os.path.isdir(folder_path):
        print(f"Skipping missing folder: {folder_path}")
        continue

    all_dominant_colors = []
    # Sorted listing: os.listdir order is arbitrary, and ties in the frequency
    # count below are broken by insertion order, so sorting keeps runs reproducible.
    for fname in tqdm(sorted(os.listdir(folder_path)), desc=style):
        fpath = os.path.join(folder_path, fname)
        try:
            # The context manager closes the file handle once the pixels are copied.
            with Image.open(fpath) as img:
                image = img.convert('RGB')
            colors = get_dominant_colors_kmeans(image, n_colors=5)
            all_dominant_colors.extend(colors)
        except Exception as e:
            # Best effort: report unreadable or non-image files and keep going.
            print(f"Error with file {fpath}: {e}")

    if all_dominant_colors:
        # NOTE(review): colors are counted by exact (B, G, R) equality, so
        # near-identical shades from different images are counted separately.
        # Consider quantizing or re-clustering before counting.
        color_tuples = [tuple(c) for c in all_dominant_colors]
        common_colors = [np.array(color) for color, _ in Counter(color_tuples).most_common(5)]
        style_palette_dict[style] = common_colors

        # 1. Show the palette.
        plot_palette(common_colors, f"KMeans_{style}")

        # 2. Save the palette as an SVG file directly in the EDA folder.
        svg_filename = os.path.join(result_base_dir, f'kmeans_{style}.svg')
        save_palette_as_svg(common_colors, svg_filename, title_text=f'kmeans_{style}.svg')
        print(f"SVG saved: {svg_filename}")

# 2. Mean-Shift Clustering

In [None]:
import os
import cv2
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from collections import Counter
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

def get_dominant_colors_meanshift(image, max_colors=5, resize=128, quantile=0.1, sample_size=500):
    """Return the Mean-Shift cluster colors of an image, largest cluster first.

    The image is converted from RGB to BGR channel order, resized to a
    ``resize`` x ``resize`` square, and its pixels are clustered with
    Mean-Shift using a bandwidth estimated from the pixels themselves.

    Args:
        image: RGB image; anything ``np.array`` can turn into an
            H x W x 3 array.
        max_colors: Number of colors to return.
        resize: Side length, in pixels, of the square working image.
        quantile: Passed to ``estimate_bandwidth`` as ``quantile``.
        sample_size: Passed to ``estimate_bandwidth`` as ``n_samples``.

    Returns:
        A list of integer arrays in BGR channel order (cluster centers
        truncated with ``astype(int)``), ordered by descending cluster size.
        If fewer than ``max_colors`` clusters are found, the most frequent
        color is repeated to fill the list.
    """
    bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    small = cv2.resize(bgr, (resize, resize))
    samples = small.reshape(-1, 3)

    radius = estimate_bandwidth(samples, quantile=quantile, n_samples=sample_size)
    # Enforce a minimum bandwidth of 1 when the estimate comes out too small.
    if radius < 1:
        radius = 1

    model = MeanShift(bandwidth=radius, bin_seeding=True)
    model.fit(samples)
    centers = model.cluster_centers_

    # Keep the max_colors clusters with the most assigned pixels.
    ranked = Counter(model.labels_).most_common(max_colors)
    palette = [centers[cluster_id].astype(int) for cluster_id, _ in ranked]

    # Pad with the top color when there are fewer clusters than requested.
    while len(palette) < max_colors:
        palette.append(palette[0])
    return palette

def plot_palette(colors, title):
    """Display a row of 50x50 color swatches with a title.

    Args:
        colors: Sequence of colors in BGR channel order, one swatch each.
        title: Text shown above the swatch row.
    """
    swatch = 50
    plt.figure(figsize=(5, 1))
    strip = np.zeros((swatch, swatch * len(colors), 3), dtype=np.uint8)
    for idx, bgr in enumerate(colors):
        left = idx * swatch
        strip[:, left:left + swatch, :] = bgr
    # Swatches are stored as BGR; reverse the channel axis so imshow gets RGB.
    rgb_strip = strip[..., ::-1]
    plt.imshow(rgb_strip)
    plt.axis('off')
    plt.title(title)
    plt.show()

def save_palette_as_svg(colors, filename, title_text=None):
    """Write a horizontal strip of color swatches to ``filename`` as SVG.

    Each color becomes a 60x60 rectangle. When ``title_text`` is given, a
    30-pixel band above the swatches holds the horizontally centered title.

    Args:
        colors: Sequence of colors in BGR channel order; each item must be
            indexable with three numeric components.
        filename: Destination path. Missing parent directories are created;
            a bare file name is written to the current working directory.
        title_text: Optional caption. It is XML-escaped before embedding so
            characters such as ``&`` and ``<`` cannot break the document.
    """
    # Function-scope import keeps the block self-contained.
    from xml.sax.saxutils import escape

    def channel(value):
        # Clamp to 0..255 so the two-digit hex format always yields a valid color.
        return min(255, max(0, int(value)))

    n = len(colors)
    width = 60
    height = 60
    text_height = 30 if title_text else 0  # reserve room for the caption
    svg_height = height + text_height

    # Colors arrive as BGR, so index 2 is red and index 0 is blue.
    rects = "".join(
        f'<rect x="{i*width}" y="{text_height}" width="{width}" height="{height}" '
        f'fill="#{channel(color[2]):02x}{channel(color[1]):02x}{channel(color[0]):02x}"/>'
        for i, color in enumerate(colors)
    )

    text_tag = ""
    if title_text:
        font_size = 20
        text_x = (n * width) // 2
        text_y = 22  # caption baseline, inside the top band of the SVG
        safe_title = escape(str(title_text))
        text_tag = f'<text x="{text_x}" y="{text_y}" font-size="{font_size}" text-anchor="middle" font-family="Arial" fill="#222">{safe_title}</text>'

    svg = f'''<svg width="{n*width}" height="{svg_height}" xmlns="http://www.w3.org/2000/svg">{rects}{text_tag}</svg>'''

    # os.makedirs('') raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Explicit UTF-8: the SVG has no XML declaration, so readers assume UTF-8.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(svg)

# Root folder holding one sub-folder of preprocessed images per interior style.
base_dir = './preprocessed_google_images'
style_folders = [
    'antique_interior', 'modern_interior', 'natural_interior',
    'northern_european_interior', 'romantic_interior',
    'traditional_korean_style_interior', 'vintage_interior'
]

# Folder that receives the generated SVG palettes.
result_base_dir = 'meanshift_EDA'

# Maps style name -> list of up to five most frequent dominant colors (BGR arrays).
style_palette_dict = {}

for style in style_folders:
    folder_path = os.path.join(base_dir, style)
    # A missing style folder should not abort the remaining styles.
    if not os.path.isdir(folder_path):
        print(f"Skipping missing folder: {folder_path}")
        continue

    all_dominant_colors = []
    # Sorted listing: os.listdir order is arbitrary, and ties in the frequency
    # count below are broken by insertion order, so sorting keeps runs reproducible.
    for fname in tqdm(sorted(os.listdir(folder_path)), desc=style):
        fpath = os.path.join(folder_path, fname)
        try:
            # The context manager closes the file handle once the pixels are copied.
            with Image.open(fpath) as img:
                image = img.convert('RGB')
            colors = get_dominant_colors_meanshift(image, max_colors=5)
            all_dominant_colors.extend(colors)
        except Exception as e:
            # Best effort: report unreadable or non-image files and keep going.
            print(f"Error with file {fpath}: {e}")

    if all_dominant_colors:
        # NOTE(review): colors are counted by exact (B, G, R) equality, so
        # near-identical shades from different images are counted separately,
        # while colors repeated as padding by the extractor are counted more
        # than once. Consider quantizing or re-clustering before counting.
        color_tuples = [tuple(c) for c in all_dominant_colors]
        common_colors = [np.array(color) for color, _ in Counter(color_tuples).most_common(5)]
        style_palette_dict[style] = common_colors

        # 1. Show the palette.
        plot_palette(common_colors, f"MeanShift_{style}")

        # 2. Save the palette as an SVG file directly in the EDA folder.
        svg_filename = os.path.join(result_base_dir, f'meanshift_{style}.svg')
        save_palette_as_svg(common_colors, svg_filename, title_text=f'meanshift_{style}.svg')
        print(f"SVG saved: {svg_filename}")

# 3. DBSCAN

In [None]:
import os
import cv2
import numpy as np
from sklearn.cluster import DBSCAN
from collections import Counter
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

def get_dominant_colors_dbscan(image, max_colors=5, resize=128, eps=7, min_samples=100):
    """Return the DBSCAN cluster colors of an image, largest cluster first.

    The image is converted from RGB to BGR channel order, resized to a
    ``resize`` x ``resize`` square, and its pixels are clustered with DBSCAN.
    Pixels labelled as noise (-1) are ignored when ranking clusters.

    Args:
        image: RGB image; anything ``np.array`` can turn into an
            H x W x 3 array.
        max_colors: Number of colors to return.
        resize: Side length, in pixels, of the square working image.
        eps: Passed to ``DBSCAN`` as ``eps``.
        min_samples: Passed to ``DBSCAN`` as ``min_samples``.

    Returns:
        A list of ``max_colors`` integer arrays in BGR channel order. Each is
        the mean color of one cluster (truncated with ``astype(int)``),
        ordered by descending cluster size and padded with the top color when
        fewer clusters exist. If every pixel is noise, the first pixel is
        repeated ``max_colors`` times.
    """
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    image = cv2.resize(image, (resize, resize))
    pixels = image.reshape(-1, 3)
    db = DBSCAN(eps=eps, min_samples=min_samples)
    labels = db.fit_predict(pixels)

    # Drop noise points (label -1).
    mask = labels != -1
    if not np.any(mask):
        # Everything is noise: repeat the first pixel. Cast to int so this
        # path returns the same dtype as the cluster-mean path below.
        return [pixels[0].astype(int)] * max_colors

    filtered_labels = labels[mask]
    filtered_pixels = pixels[mask]
    counts = Counter(filtered_labels)
    sorted_labels = [label for label, _ in counts.most_common(max_colors)]

    # A cluster's color is the mean of its member pixels.
    dominant_colors = [
        filtered_pixels[filtered_labels == label].mean(axis=0).astype(int)
        for label in sorted_labels
    ]

    # Pad with the top color when there are fewer clusters than requested.
    while len(dominant_colors) < max_colors:
        dominant_colors.append(dominant_colors[0])

    return dominant_colors

def plot_palette(colors, title):
    """Display a row of 50x50 color swatches with a title.

    Args:
        colors: Sequence of colors in BGR channel order, one swatch each.
        title: Text shown above the swatch row.
    """
    swatch = 50
    plt.figure(figsize=(5, 1))
    strip = np.zeros((swatch, swatch * len(colors), 3), dtype=np.uint8)
    for idx, bgr in enumerate(colors):
        left = idx * swatch
        strip[:, left:left + swatch, :] = bgr
    # Swatches are stored as BGR; reverse the channel axis so imshow gets RGB.
    rgb_strip = strip[..., ::-1]
    plt.imshow(rgb_strip)
    plt.axis('off')
    plt.title(title)
    plt.show()

def save_palette_as_svg(colors, filename, title_text=None):
    """Write a horizontal strip of color swatches to ``filename`` as SVG.

    Each color becomes a 60x60 rectangle. When ``title_text`` is given, a
    30-pixel band above the swatches holds the horizontally centered title.

    Args:
        colors: Sequence of colors in BGR channel order; each item must be
            indexable with three numeric components.
        filename: Destination path. Missing parent directories are created;
            a bare file name is written to the current working directory.
        title_text: Optional caption. It is XML-escaped before embedding so
            characters such as ``&`` and ``<`` cannot break the document.
    """
    # Function-scope import keeps the block self-contained.
    from xml.sax.saxutils import escape

    def channel(value):
        # Clamp to 0..255 so the two-digit hex format always yields a valid color.
        return min(255, max(0, int(value)))

    n = len(colors)
    width = 60
    height = 60
    text_height = 30 if title_text else 0  # reserve room for the caption
    svg_height = height + text_height

    # Colors arrive as BGR, so index 2 is red and index 0 is blue.
    rects = "".join(
        f'<rect x="{i*width}" y="{text_height}" width="{width}" height="{height}" '
        f'fill="#{channel(color[2]):02x}{channel(color[1]):02x}{channel(color[0]):02x}"/>'
        for i, color in enumerate(colors)
    )

    text_tag = ""
    if title_text:
        font_size = 20
        text_x = (n * width) // 2
        text_y = 22  # caption baseline, inside the top band of the SVG
        safe_title = escape(str(title_text))
        text_tag = f'<text x="{text_x}" y="{text_y}" font-size="{font_size}" text-anchor="middle" font-family="Arial" fill="#222">{safe_title}</text>'

    svg = f'''<svg width="{n*width}" height="{svg_height}" xmlns="http://www.w3.org/2000/svg">{rects}{text_tag}</svg>'''

    # os.makedirs('') raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Explicit UTF-8: the SVG has no XML declaration, so readers assume UTF-8.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(svg)

# Root folder holding one sub-folder of preprocessed images per interior style.
base_dir = './preprocessed_google_images'
style_folders = [
    'antique_interior', 'modern_interior', 'natural_interior',
    'northern_european_interior', 'romantic_interior',
    'traditional_korean_style_interior', 'vintage_interior'
]

# Folder that receives the generated SVG palettes.
result_base_dir = 'dbscan_EDA'

# Maps style name -> list of up to five most frequent dominant colors (BGR arrays).
style_palette_dict = {}

for style in style_folders:
    folder_path = os.path.join(base_dir, style)
    # A missing style folder should not abort the remaining styles.
    if not os.path.isdir(folder_path):
        print(f"Skipping missing folder: {folder_path}")
        continue

    all_dominant_colors = []
    # Sorted listing: os.listdir order is arbitrary, and ties in the frequency
    # count below are broken by insertion order, so sorting keeps runs reproducible.
    for fname in tqdm(sorted(os.listdir(folder_path)), desc=style):
        fpath = os.path.join(folder_path, fname)
        try:
            # The context manager closes the file handle once the pixels are copied.
            with Image.open(fpath) as img:
                image = img.convert('RGB')
            colors = get_dominant_colors_dbscan(image, max_colors=5)
            all_dominant_colors.extend(colors)
        except Exception as e:
            # Best effort: report unreadable or non-image files and keep going.
            print(f"Error with file {fpath}: {e}")

    if all_dominant_colors:
        # NOTE(review): colors are counted by exact (B, G, R) equality, so
        # near-identical shades from different images are counted separately,
        # while colors repeated as padding by the extractor are counted more
        # than once. Consider quantizing or re-clustering before counting.
        color_tuples = [tuple(c) for c in all_dominant_colors]
        common_colors = [np.array(color) for color, _ in Counter(color_tuples).most_common(5)]
        style_palette_dict[style] = common_colors

        # 1. Show the palette.
        plot_palette(common_colors, f"DBSCAN_{style}")

        # 2. Save the palette as an SVG file directly in the EDA folder.
        svg_filename = os.path.join(result_base_dir, f'dbscan_{style}.svg')
        save_palette_as_svg(common_colors, svg_filename, title_text=f'dbscan_{style}.svg')
        print(f"SVG saved: {svg_filename}")