In [77]:
import os
import colorsys

from PIL import Image
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import clip

In [35]:
from pathlib import Path

state_name = "sac+logos+ava1-l14-linearMSE.pth"
if not Path(state_name).exists():
    url = f"https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/{state_name}?raw=true"
    import requests

    r = requests.get(url)
    with open(state_name, "wb") as f:
        f.write(r.content)

In [44]:
# https://github.com/grexzen/SD-Chad/tree/main
class AestheticPredictor(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        return self.layers(x)


device = "cuda" if torch.cuda.is_available() else "cpu"
state_name = "sac+logos+ava1-l14-linearMSE.pth"
pt_state = torch.load(state_name, map_location=torch.device("cpu"))

# CLIP embedding dim is 768 for CLIP ViT L 14
predictor = AestheticPredictor(768)
predictor.load_state_dict(pt_state)
predictor.to(device)
predictor.eval()

clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)


def get_image_features(
    image, device=device, model=clip_model, preprocess=clip_preprocess
):
    image = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
        # l2 normalize
        image_features /= image_features.norm(dim=-1, keepdim=True)
    image_features = image_features.cpu().detach().numpy()
    return image_features


def get_score(image):
    image_features = get_image_features(image)
    score = predictor(torch.from_numpy(image_features).to(device).float())
    return score.item()

  pt_state = torch.load(state_name, map_location=torch.device("cpu"))


In [70]:
def img_to_arr(image):
    image = image.resize((100, 100))

    # Convert the image to RGB
    image = image.convert("RGB")

    # Convert the image to a numpy array
    image_array = np.array(image)

    return image_array


def extract_color_palette(image, n_colors=10, format="HSV"):
    # Reshape the image array to a 2D array of pixels
    pixels = image.reshape(-1, 3)

    # Use KMeans to find the most common colors
    kmeans = KMeans(n_clusters=n_colors)
    kmeans.fit(pixels)
    colors = kmeans.cluster_centers_
    return colors if format == "RGB" else [rgb_to_hsv(color) for color in colors]


def display_color_palette(colors):
    plt.figure(figsize=(10, 2))
    for i, color in enumerate(colors):
        plt.fill_between([i, i + 1], 0, 1, color=color / 255.0)
    plt.xticks([])
    plt.yticks([])
    plt.show()


def rgb_to_hsv(rgb):
    r, g, b = rgb / 255.0
    h, s, v = colorsys.rgb_to_hsv(r, g, b)

    return (round(h * 360, 4), round(s, 4), round(v, 4))

In [79]:
image_paths = os.listdir("image")

data = []

for image_path in tqdm(image_paths):
    image = Image.open("image/" + image_path)
    score = get_score(image)

    img_arr = img_to_arr(image)
    hsvs = extract_color_palette(img_arr, n_colors=9, format="HSV")
    mean_hsv = tuple(np.round(np.mean(hsvs, axis=0), 4))

    data.append([image_path[:-4], round(score, 4), mean_hsv, *hsvs])

100%|██████████| 1510/1510 [13:40<00:00,  1.84it/s]


In [80]:
features = pd.DataFrame(
    data=data,
    columns=[
        "id",
        "Aesthetic_Score",
        "HSV_MEAN",
        "HSV_1",
        "HSV_2",
        "HSV_3",
        "HSV_4",
        "HSV_5",
        "HSV_6",
        "HSV_7",
        "HSV_8",
        "HSV_9",
    ],
)
features.head()

Unnamed: 0,id,Aesthetic_Score,HSV_MEAN,HSV_1,HSV_2,HSV_3,HSV_4,HSV_5,HSV_6,HSV_7,HSV_8,HSV_9
0,348,6.4164,"(170.6655, 0.3543, 0.4293)","(207.0312, 0.4427, 0.3094)","(20.5862, 0.4998, 0.6823)","(241.9402, 0.5852, 0.0874)","(195.2978, 0.0706, 0.5595)","(354.8251, 0.3616, 0.3029)","(268.5475, 0.2214, 0.1578)","(35.3204, 0.3131, 0.8505)","(5.924, 0.4391, 0.4814)","(206.5174, 0.2549, 0.4328)"
1,1804,6.5671,"(212.7679, 0.4648, 0.5424)","(294.5466, 0.2214, 0.318)","(28.0085, 0.4799, 0.8216)","(248.4722, 0.4416, 0.1474)","(335.2901, 0.3448, 0.4817)","(35.6514, 0.3988, 0.9617)","(181.4001, 0.8142, 0.7327)","(191.3901, 0.8362, 0.5014)","(358.8891, 0.3426, 0.6628)","(241.2631, 0.3034, 0.2539)"
2,1810,6.4905,"(83.8641, 0.5575, 0.4876)","(176.7723, 0.8507, 0.3795)","(33.0526, 0.5018, 0.7686)","(3.481, 0.42, 0.0473)","(6.5196, 0.6514, 0.6413)","(38.1707, 0.2284, 0.9399)","(177.6161, 0.9164, 0.5233)","(173.3469, 0.6346, 0.2052)","(135.3847, 0.2159, 0.5382)","(10.4332, 0.5979, 0.3455)"
3,3961,6.8126,"(149.1291, 0.2325, 0.3577)","(194.0966, 0.4491, 0.2413)","(33.0916, 0.1369, 0.4401)","(220.1904, 0.2443, 0.0857)","(194.286, 0.3407, 0.3373)","(190.8234, 0.3034, 0.4352)","(114.2887, 0.0357, 0.702)","(203.2183, 0.1318, 0.1392)","(171.2715, 0.1165, 0.5571)","(20.8952, 0.3338, 0.2814)"
4,1186,5.9921,"(220.3679, 0.5815, 0.4527)","(228.851, 0.6813, 0.1271)","(0.1531, 0.3318, 0.8072)","(317.2909, 0.4176, 0.3766)","(198.3426, 0.7738, 0.5509)","(207.7513, 0.7562, 0.3499)","(185.4368, 0.6794, 0.809)","(343.6052, 0.3973, 0.6004)","(285.474, 0.4683, 0.226)","(216.4059, 0.7275, 0.2272)"


In [81]:
features.to_csv("features.csv", index=False)