In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
%matplotlib inline
plt.style.use('ggplot')

In [None]:
def count_colors(img):
    """Return the number of distinct colours (unique pixel vectors) in an image.

    Parameters
    ----------
    img : np.ndarray
        Image of shape (H, W, C); a pixel is the length-C vector ``img[i, j, :]``.

    Returns
    -------
    int
        Number of distinct pixel vectors (0 for an image with no pixels).
    """
    arr = np.asarray(img)
    n_pixels = arr.shape[0] * arr.shape[1]
    if n_pixels == 0:
        return 0
    # One row per pixel; np.unique over rows is vectorised and far faster than
    # building a Python set of per-pixel tuples on multi-megapixel images.
    pixels = arr.reshape(n_pixels, arr.shape[2])
    return len(np.unique(pixels, axis=0))

In [None]:
def create_bit_encoding(bits=8):
    """Build a 256x256 RGB test image whose green channel is a ramp quantised to ``bits`` bits.

    The 256 columns are split into bands of width ``256 // 2**bits``; the band
    starting at column ``i`` is filled with ``(0, i, 0)``. For integer ``bits``
    in 0..8 this gives ``2**bits`` green levels from left to right.

    Parameters
    ----------
    bits : int, default 8
        Bit depth of the green ramp; ``int(2**bits)`` must lie in ``[1, 256]``.

    Returns
    -------
    np.ndarray
        ``uint8`` array of shape (256, 256, 3).

    Raises
    ------
    ValueError
        If ``bits`` yields fewer than 1 or more than 256 levels.
    """
    n_colors = int(2 ** bits)
    if not 1 <= n_colors <= 256:
        raise ValueError(
            f"bits must give between 1 and 256 levels, got int(2**{bits}) = {n_colors}"
        )
    step_size = 256 // n_colors
    image = np.zeros((256, 256, 3), np.uint8)
    # Every column is written by this loop, so no earlier fill would survive.
    for start in range(0, image.shape[1], step_size):
        image[:, start:start + step_size, :] = (0, start, 0)
    return image

In [None]:
# plt.imshow(create_bit_encoding(bits=8))
plt.imshow(create_bit_encoding(bits=5))

In [None]:
def reduce_color(img, k=8):
    """Quantise an image to ``k`` colours with OpenCV k-means.

    Every pixel is treated as a point in colour space, the points are clustered
    into ``k`` groups, and each pixel is replaced by its cluster centre rounded
    to the nearest 8-bit value.

    Parameters
    ----------
    img : np.ndarray
        Image of shape (H, W, C) or a single-channel (H, W) image; assumed to
        hold 8-bit range values.
    k : int, default 8
        Number of colours (clusters) to keep.

    Returns
    -------
    np.ndarray
        ``uint8`` array with the same shape as ``img``.
    """
    # Derive the channel count instead of assuming 3: reshaping a 4-channel or
    # grayscale image into rows of 3 would silently mix values across pixels.
    n_channels = img.shape[2] if img.ndim == 3 else 1
    # One row per pixel; cv2.kmeans requires float32 samples.
    samples = np.float32(img.reshape(-1, n_channels))

    # Stop after 10 iterations or once the accuracy epsilon (1.0) is reached;
    # run 10 random initialisations and keep the most compact clustering.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, labels, centers = cv2.kmeans(
        samples, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
    )

    # Round (rather than truncate) the float centres back to valid uint8 values.
    palette = np.clip(np.rint(centers), 0, 255).astype(np.uint8)
    return palette[labels.flatten()].reshape(img.shape)


In [None]:
img_path = '../PGMData/imgs/schilderij_3527.jpg'
img_bgr = cv2.imread(img_path)
# cv2.imread returns None (instead of raising) when the file cannot be read.
if img_bgr is None:
    raise FileNotFoundError(f'Could not read image: {img_path}')
print(img_bgr.shape)
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
rimg = reduce_color(img, k=8)
plt.imshow(rimg)
print(count_colors(img))
print(count_colors(rimg))

In [None]:
def color_quantization(img, k=8, random_state=None):
    """Quantise an image to ``k`` colours with scikit-learn's KMeans.

    Parameters
    ----------
    img : np.ndarray
        Image of shape (rows, cols, C) or a single-channel (rows, cols) image.
    k : int, default 8
        Number of clusters / output colours.
    random_state : int or None, default None
        Passed to ``KMeans`` for reproducible clustering; None keeps the
        previous unseeded behaviour.

    Returns
    -------
    np.ndarray
        Integer array of shape (rows, cols, C) (C = 1 for a 2-D input). Each
        pixel is replaced by its cluster centre, rounded to the nearest integer.
    """
    rows, cols = img.shape[:2]
    n_channels = img.shape[2] if img.ndim == 3 else 1
    pixels = img.reshape(-1, n_channels)
    kmeans = KMeans(n_clusters=k, random_state=random_state)
    # fit_predict is documented as equivalent to fit(X).predict(X), without a
    # second full pass just to re-label the training pixels.
    labels = kmeans.fit_predict(pixels)
    quantized = kmeans.cluster_centers_[labels]
    # Round rather than truncate, so a centre like 127.9 becomes 128, not 127.
    return np.rint(quantized).reshape(rows, cols, n_channels).astype(int)

In [None]:
img2 = color_quantization(img, k=8)
# Use separate panels: calling imshow twice on the same axes hides the
# quantised result behind the original.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].imshow(img2)
axes[0].set_title(f'KMeans, k=8 ({count_colors(img2)} colors)')
axes[1].imshow(img)
axes[1].set_title('Original');

In [None]:
def color_quantization_gmm(img, k=8, random_state=None):
    """Quantise an image to ``k`` colours with a Gaussian mixture model.

    Each pixel is assigned to its most probable mixture component and replaced
    by that component's mean colour.

    Parameters
    ----------
    img : np.ndarray
        Image of shape (rows, cols, C) or a single-channel (rows, cols) image.
    k : int, default 8
        Number of mixture components / output colours.
    random_state : int or None, default None
        Passed to ``GaussianMixture`` for reproducible fits; None keeps the
        previous unseeded behaviour.

    Returns
    -------
    np.ndarray
        Integer array of shape (rows, cols, C) (C = 1 for a 2-D input) holding
        the component means, rounded to the nearest integer.
    """
    rows, cols = img.shape[:2]
    n_channels = img.shape[2] if img.ndim == 3 else 1
    pixels = img.reshape(rows * cols, n_channels)
    gmm = GaussianMixture(n_components=k, random_state=random_state).fit(pixels)
    labels = gmm.predict(pixels)
    quantized = gmm.means_[labels]
    # Round rather than truncate the float means.
    return np.rint(quantized).reshape(rows, cols, n_channels).astype(int)

In [None]:
img2 = color_quantization_gmm(img, k=8)
plt.imshow(img2)
plt.title('GaussianMixture, k=8');

In [None]:
plt.imshow(img)
plt.title('Original');

In [None]:
# NOTE(review): disabled exploration cell; every line below is commented out.
# If revived, consider subsampling `tmp` first (scattering every pixel of a full
# image is likely slow), and check that the installed matplotlib accepts
# aspect="equal" on 3D axes (some releases raise NotImplementedError) — TODO confirm.
# from mpl_toolkits.mplot3d import Axes3D
# tmp = img.reshape(img.shape[0]*img.shape[1], img.shape[2])

# fig = plt.figure()
# fig.set_figheight(16)
# fig.set_figwidth(16)
# ax = fig.add_subplot( 111,  aspect = "equal", projection="3d")
# ax.scatter(tmp[:, 0], tmp[:, 1], tmp[:, 2])

In [None]:
# Bare last expression: Jupyter displays the shape as the cell output.
img.shape

In [None]:
from ColorQuantization import ColorQuantization

In [None]:
# Instantiate the project's ColorQuantization helper (imported from the local
# ColorQuantization module, not shown here) with n_colors=8.
cq = ColorQuantization(n_colors=8)

In [None]:
# Debug check: shape of the private _reshape() output for a one-image batch.
# NOTE(review): relies on a private method; presumably it flattens the batch to
# one row per pixel — confirm against ColorQuantization before relying on it.
cq._reshape([img]).shape

In [None]:
# Fit the quantiser on a batch (list) containing just `img`.
cq.fit([img])

In [None]:
# Recolour `img`; colorize is given a batch (list), and [0] takes the first result.
new_img = cq.colorize([img])[0]
# NOTE(review): new_img is never used; this plots the pixels of the *original*
# img under the title "8 Colors". If cq.plot_pixels accepts a colors= argument
# like the notebook's own plot_pixels, passing colors=new_img was probably
# intended — confirm against ColorQuantization.
cq.plot_pixels(img, title="8 Colors")

In [None]:
import matplotlib

In [None]:
def plot_pixels(data, title, colors=None, N=10000):
    """Scatter a random sample of pixels in the Red-Green and Red-Blue planes.

    Parameters
    ----------
    data : np.ndarray
        Either an (H, W, 3) image, which is flattened and divided by 255, or an
        already-flattened (n, 3) array used as-is (so it should already be
        scaled to [0, 1]).
    title : str
        Figure super-title.
    colors : np.ndarray, optional
        Per-pixel drawing colour, following the same convention as ``data``
        (a 3-D array is flattened and divided by 255). Defaults to each pixel's
        own colour.
    N : int, default 10000
        Maximum number of pixels to draw; smaller inputs are drawn in full.
    """
    if data.ndim == 3:
        data = data.reshape(-1, 3) / 255
    if colors is not None and colors.ndim == 3:
        colors = colors.reshape(-1, 3) / 255

    # Random subsample (without replacement) keeps large images quick to draw.
    idx = np.random.permutation(data.shape[0])[:N]
    sample = data[idx]
    # One RGB row per plotted point, passed directly as `c`. Sizing the colours
    # from the sample rather than from N keeps inputs with fewer than N pixels
    # working.
    point_colors = sample if colors is None else colors[idx]

    red, green, blue = sample[:, 0], sample[:, 1], sample[:, 2]

    fig, ax = plt.subplots(1, 2, figsize=(16, 6))
    ax[0].scatter(red, green, c=point_colors, marker='.')
    ax[0].set(xlabel='Red', ylabel='Green', xlim=(0, 1), ylim=(0, 1))

    ax[1].scatter(red, blue, c=point_colors, marker='.')
    ax[1].set(xlabel='Red', ylabel='Blue', xlim=(0, 1), ylim=(0, 1))

    fig.suptitle(title, size=20)

In [None]:
china_path = 'china.png'
china_bgr = cv2.imread(china_path)
# cv2.imread returns None on failure; fail clearly instead of inside cvtColor.
if china_bgr is None:
    raise FileNotFoundError(f'Could not read image: {china_path}')
img = cv2.cvtColor(china_bgr, cv2.COLOR_BGR2RGB)
plot_pixels(img, title=f"Input color space: {count_colors(img)} unique colors")

In [None]:
def plot_img(img, recolored_img, n_colors=8):
    """Show an original image above its colour-reduced version.

    Parameters
    ----------
    img : array-like
        Original image, passed straight to ``imshow``.
    recolored_img : array-like
        Recoloured image, passed straight to ``imshow``.
    n_colors : int, default 8
        Palette size shown in the lower panel's title.
    """
    # Two vertically stacked panels with ticks hidden.
    fig, ax = plt.subplots(2, 1, figsize=(8, 8),
                           subplot_kw=dict(xticks=[], yticks=[]))
    ax[0].imshow(img)
    ax[0].set_title('Original Image', size=16)
    ax[1].imshow(recolored_img)
    ax[1].set_title(f'{n_colors}-color Image', size=16)

# Fit an 8-colour ColorQuantization on the china image, recolour it, and print
# the distinct-colour counts before and after.
cq = ColorQuantization(n_colors=8)
cq.fit([img])
recolored_batch = cq.colorize([img])
recolor = recolored_batch[0]
colors_before, colors_after = count_colors(img), count_colors(recolor)
print(colors_before, colors_after)


In [None]:
# Original above, quantised version below.
plot_img(img=img, recolored_img=recolor)


In [None]:
# Original pixel positions, each drawn in its quantised colour.
plot_pixels(data=img, colors=recolor, title="Reduced color space: 8 colors")