In [None]:
import os
from PIL import Image
import numpy as np
from sklearn.cluster import KMeans
from collections import Counter

def rgb_to_hex(rgb):
    """
    Format the first three components of ``rgb`` as a lowercase ``#rrggbb`` string.

    Each component is passed through ``int()`` (floats are truncated toward
    zero, not rounded) and rendered with at least two hexadecimal digits.
    Any indexable sequence works, e.g. a tuple, list or NumPy array.
    """
    red, green, blue = (int(rgb[k]) for k in range(3))
    return f'#{red:02x}{green:02x}{blue:02x}'

def get_color_name(rgb):
    """
    Map an (r, g, b) triple to a coarse color name.

    Rules are checked in this order and the first match wins:
      1. every channel > 200                        -> "White"
      2. every channel < 50                         -> "Black"
      3. one channel exceeds both others by > 20    -> "Red" / "Green" / "Blue"
      4. two channels > 180 and the third < 100     -> "Yellow" / "Purple" / "Cyan"
      5. max channel - min channel < 30             -> "Gray"
      6. anything else                              -> "Other"

    ``rgb`` must unpack into exactly three values.
    """
    r, g, b = rgb
    channels = (r, g, b)

    if all(c > 200 for c in channels):
        return "White"
    if all(c < 50 for c in channels):
        return "Black"

    # A single channel "dominates" when it beats both of the others by more than 20.
    dominance_rules = (
        ("Red", r, (g, b)),
        ("Green", g, (r, b)),
        ("Blue", b, (r, g)),
    )
    for label, value, others in dominance_rules:
        if value > max(others) + 20:
            return label

    # Secondary hues: two bright channels paired with one dark channel.
    mix_rules = (
        ("Yellow", (r, g), b),
        ("Purple", (r, b), g),
        ("Cyan", (g, b), r),
    )
    for label, bright, dark in mix_rules:
        if all(c > 180 for c in bright) and dark < 100:
            return label

    # Low channel spread means the color is close to neutral.
    if max(channels) - min(channels) < 30:
        return "Gray"
    return "Other"

# Path to the image folder
image_folder = '/home/data/shangliujun/stable-diffusion-webui/outputs/backgroud_REMOVE'

# Number of dominant colors extracted per image by K-means.
N_COLORS = 10


def _load_visible_pixels(image_path):
    """Return an (N, 3) uint8 array with the RGB values of every visible pixel.

    ``Image.convert('RGB')`` silently discards the alpha channel, so fully
    transparent pixels would otherwise be clustered using whatever RGB values
    are stored underneath them. When the image carries transparency (an alpha
    band or a ``transparency`` entry), pixels with alpha == 0 are dropped.
    Partially transparent pixels are kept. Opaque images are returned unchanged.
    """
    with Image.open(image_path) as img:
        has_transparency = 'A' in img.getbands() or 'transparency' in img.info
        if has_transparency:
            rgba = np.array(img.convert('RGBA'))
            rgb = rgba[..., :3].reshape((-1, 3))
            alpha = rgba[..., 3].reshape(-1)
            return rgb[alpha > 0]
        return np.array(img.convert('RGB')).reshape((-1, 3))


def _color_distribution(pixels, n_colors=N_COLORS):
    """Cluster ``pixels`` with K-means and return ``{color name: fraction of pixels}``.

    Clusters whose centers map to the same name (via ``get_color_name``) have
    their fractions summed. Returns an empty dict when ``pixels`` is empty.
    """
    total_pixels = pixels.shape[0]
    if total_pixels == 0:
        return {}

    # KMeans raises if asked for more clusters than there are samples.
    n_clusters = min(n_colors, total_pixels)
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    kmeans.fit(pixels)

    # Pixel count per cluster, including clusters that ended up empty.
    counts = np.bincount(kmeans.labels_, minlength=n_clusters)

    color_ratio = {}
    for center, count in zip(kmeans.cluster_centers_, counts):
        name = get_color_name(center)
        color_ratio[name] = color_ratio.get(name, 0.0) + count / total_pixels
    return color_ratio


def _write_distribution(txt_file_path, color_ratio):
    """Write one ``"<name> Percentage: <pct>"`` line per color, largest share first."""
    with open(txt_file_path, 'w', encoding='utf-8') as txt_file:
        for color, ratio in sorted(color_ratio.items(), key=lambda item: item[1], reverse=True):
            txt_file.write(f"{color} Percentage: {ratio:.2%}\n")


# Iterate through all image files in the folder (sorted for a reproducible order)
for filename in sorted(os.listdir(image_folder)):
    if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        continue
    image_path = os.path.join(image_folder, filename)

    try:
        pixels = _load_visible_pixels(image_path)
    except OSError as err:
        # PIL raises UnidentifiedImageError (an OSError subclass) for corrupt or
        # unsupported files; skip them instead of aborting the whole batch.
        print(f"Skipped (cannot read): {filename} ({err})")
        continue

    if pixels.shape[0] == 0:
        print(f"Skipped (no visible pixels): {filename}")
        continue

    color_ratio = _color_distribution(pixels)

    # The report is written next to the image: <name>_color_distribution.txt
    txt_file_path = os.path.splitext(image_path)[0] + '_color_distribution.txt'
    _write_distribution(txt_file_path, color_ratio)

    print(f"Processed: {filename}")

print("All images have been processed.")