-
Notifications
You must be signed in to change notification settings - Fork 8
/
palette.py
141 lines (111 loc) · 4.09 KB
/
palette.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Available palettes
import csv
import base64
import scipy.spatial as sp
from PIL import Image, ImageColor
from ih.helpers import base_path, hex2rgb, rgb2hex
# Palette names are the stems of the "*.txt" files in the palettes directory.
# Sorted so the order (e.g. in get_palette's "Choices" message) does not depend
# on the filesystem's enumeration order.
PALETTE_DIR = base_path("palettes")
PALETTES = sorted(f.stem for f in PALETTE_DIR.glob("*.txt"))
# Emoji aliases accepted in place of a palette name (resolved in get_palette).
# Written as escapes so the keys survive any re-encoding of this file:
# U+1F9F5 SPOOL OF THREAD -> floss, U+1F9F6 BALL OF YARN -> wool,
# U+1F999 LLAMA -> alpaca.
# NOTE(review): the alpaca key was reconstructed as U+1F999 -- confirm.
PALETTE_OVERRIDE = {
    "\U0001F9F5": "floss",
    "\U0001F9F6": "wool",
    "\U0001F999": "alpaca",
}
PALETTE_DEFAULT = "wool"

# Thread preview images. A palette uses styling/<name>.png when that file
# exists; any palette whose name contains "floss" uses the shared
# styling/floss.png instead. Palettes without an entry fall back to
# THREAD_DEFAULT (see get_thread_image_path).
THREAD_DEFAULT = "wool.png"
THREAD_OVERRIDE = {}
for p in PALETTES:
    img = base_path("styling").joinpath(f"{p}.png")
    if img.exists():
        THREAD_OVERRIDE[p] = img
    # Checked after the per-palette image so floss palettes always share floss.png.
    if "floss" in p:
        THREAD_OVERRIDE[p] = base_path("styling").joinpath("floss.png")
# What a single unit of work is called, per palette. Palettes not listed in
# IDENTITY_OVERRIDE are described in terms of DEFAULT_IDENTITY.
IDENTITY_OVERRIDE = dict(lego="bricks", perler="beads")
DEFAULT_IDENTITY = "stitches"
def get_identity_name(palette_name):
    """Return the word used for one unit of work in the given palette.

    E.g., floss is "stitches", lego is "bricks", perler is "beads".

    :param palette_name: a palette name. Emoji aliases from PALETTE_OVERRIDE
        are not resolved here.
    :returns: IDENTITY_OVERRIDE[palette_name] when present, otherwise
        DEFAULT_IDENTITY.
    """
    return IDENTITY_OVERRIDE.get(palette_name, DEFAULT_IDENTITY)
def get_thread_image_path(palette_name):
    """Return the location of the mock thread image for a palette.

    Palettes with an entry in THREAD_OVERRIDE use that image; every other
    palette uses THREAD_DEFAULT from the styling directory.

    :returns: the image path as a ``str``. Override entries are stored as
        path objects and are converted here so both branches return the same
        type.
    """
    path = THREAD_OVERRIDE.get(palette_name)
    if path is None:
        path = base_path("styling").joinpath(THREAD_DEFAULT)
    return str(path)
def get_thread_image(palette_name):
    """Return the palette's thread image as an inline ``data:`` URI.

    The file from get_thread_image_path is read as bytes and base64-encoded.
    The MIME type is always labelled ``image/png``.
    """
    path = get_thread_image_path(palette_name)
    with open(path, "rb") as f:
        # Decode the base64 bytes directly. Stripping "b'" from str(bytes)
        # would also remove genuine 'b' characters at either end of the
        # encoded payload and corrupt the image.
        encoded = base64.b64encode(f.read()).decode("ascii")
    return f"data:image/png;base64,{encoded}"
def get_palette(palette_name):
    """Load a palette from ``PALETTE_DIR/<palette_name>.txt``.

    Each non-blank line of the file is a ``code,hex`` pair.

    :param palette_name: a name from PALETTES, or one of the emoji aliases in
        PALETTE_OVERRIDE.
    :returns: list of dicts, in file order, with keys ``"rgb"`` (the result of
        hex2rgb), ``"hex"`` (the hex string as written in the file) and
        ``"code"``.
    :raises ValueError: if the (alias-resolved) name is not in PALETTES.
    """
    palette_name = PALETTE_OVERRIDE.get(palette_name, palette_name)
    if palette_name not in PALETTES:
        raise ValueError(
            f"Invalid palette: {palette_name}. Choices: {', '.join(PALETTES)}"
        )

    palette = []
    # newline="" is what the csv module expects for files it reads.
    palette_file = PALETTE_DIR.joinpath(f"{palette_name}.txt")
    with open(palette_file, newline="", encoding="utf-8") as f:
        for row in csv.reader(f, delimiter=","):
            # Blank lines come through as []; skip them instead of failing
            # the two-value unpack below.
            if not row:
                continue
            code, h = row
            palette.append({"rgb": hex2rgb(h), "hex": h, "code": code})
    return palette
# This is a fun hack.
#
# Since we are limited to a colour palette of 256, but we might have >256
# colours to pick from, let's make sure that all the closest colours get
# selected!
#
# To do that, we take the colours used in the provided image (presumably
# already reduced to at most 256 -- TODO confirm at call sites), and then walk
# our palette to pick all the closest colour matches. We then stick those
# at the front of our palette so they won't be truncated off in
# get_palette_image below.
def reduce_palette(palette, image):
    """Reorder ``palette`` so each image colour's nearest match comes first.

    :param palette: list of palette entries (dicts with an ``"rgb"`` triplet).
    :param image: object whose ``getdata()`` yields per-pixel sequences; only
        the first three channels of each pixel are used.
    :returns: ``palette`` itself when it has at most 256 entries. Otherwise, a
        new list holding every entry exactly once: the nearest matches first
        (in palette order), then all remaining entries (in palette order).
    """
    # No-op if palette is smol
    if len(palette) <= 256:
        return palette

    print("Evaluating best colour selection...")
    palette_triplets = [x["rgb"] for x in palette]

    # Distinct image colours as RGB triplets. Querying each distinct colour
    # once gives the same matches as querying every pixel, far faster.
    my_colours = {tuple(x[0:3]) for x in image.getdata()}

    # Get nearest colour https://stackoverflow.com/a/22478139
    tree = sp.KDTree(palette_triplets)
    best_indices = set()
    for colour in my_colours:
        _, result = tree.query(colour)
        best_indices.add(int(result))

    # Stick the best matches at the front, followed by everything else.
    # Selecting by palette index keeps each entry exactly once. It also avoids
    # comparing rgb2hex output against the file's hex strings, which only works
    # if both use identical formatting.
    first_colours = [item for i, item in enumerate(palette) if i in best_indices]
    first_colours += [
        item for i, item in enumerate(palette) if i not in best_indices
    ]
    print("...Done")
    return first_colours
# Get a base image that has the palette we want.
# Math involved:
# putpalette requires a flattened list of up to 256 triples for RGB
# so:
# * get our RGB values
# * flatten
# * pad the list to 256 triples, if required (repeating the last colour)
# * then cut the list, if required.
# The result will always be length 256 * 3
# Math: https://stackoverflow.com/a/55755789/124019
def get_palette_image(palette):
    """Return a 1x1 ``"P"``-mode image whose palette is built from ``palette``.

    Only the first 256 entries fit; shorter palettes are padded by repeating
    the last colour. An empty palette raises IndexError.
    """
    # Flatten only the entries that can be used (at most 256 colours), in one
    # linear pass rather than repeated list concatenation.
    data = [channel for entry in palette[:256] for channel in entry["rgb"]]
    # Pad with copies of the last colour; a negative count yields nothing.
    data += list(palette[-1]["rgb"]) * (256 - len(palette))
    data = data[: 256 * 3]
    image = Image.new("P", (1, 1))
    image.putpalette(data)
    return image
def thread_name(rgb, palette):
    """Return the first palette entry whose colour equals ``rgb``.

    :param rgb: colour to look up. A list is treated like the equivalent
        tuple, so ``[r, g, b]`` and ``(r, g, b)`` match the same entry.
    :param palette: list of palette entries with an ``"rgb"`` key.
    :returns: the matching entry, or ``{"code": str(rgb)}`` when none matches.
    """
    target = tuple(rgb) if isinstance(rgb, list) else rgb
    for entry in palette:
        if tuple(entry["rgb"]) == target:
            return entry
    # Return a basic thread type if thread not found in palette
    return {"code": str(rgb)}