# Color Quantization

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin
from sklearn.datasets import load_sample_image
from sklearn.utils import shuffle
from scipy.ndimage import imread
import os

%matplotlib inline

def recreate_image(codebook, labels, w, h):
    """Recreate the (compressed) image from the code book & labels.

    Args:
        codebook: 2-D array-like of shape (n_colours, d); row ``k`` is the
            colour that label ``k`` stands for.
        labels: 1-D sequence of integer codebook indices, one per pixel, in
            row-major order (first ``h`` labels fill row 0, and so on).
        w: Size of the first axis of the output image.
        h: Size of the second axis of the output image.

    Returns:
        A float64 array of shape (w, h, d) where pixel ``[i, j]`` equals
        ``codebook[labels[i * h + j]]``. Labels beyond the first ``w * h``
        are ignored.

    Raises:
        IndexError: If fewer than ``w * h`` labels are supplied, or a label
            is out of range for the codebook.
    """
    codebook = np.asarray(codebook)
    d = codebook.shape[1]
    n_pixels = w * h

    flat_labels = np.asarray(labels).ravel()
    if flat_labels.size < n_pixels:
        raise IndexError(
            "expected at least %d labels, got %d" % (n_pixels, flat_labels.size)
        )

    # Allocate as float64 first so the result dtype does not depend on the
    # codebook's dtype, then fill it with one vectorised lookup instead of a
    # Python-level loop over every pixel.
    image = np.zeros((w, h, d))
    if n_pixels:
        image[...] = codebook[flat_labels[:n_pixels]].reshape(w, h, d)
    return image

def transformPhotos(album, n_colours):
    """Extract a K-Means colour palette from every file under ``album``.

    The directory tree is walked recursively. Each file is read as an image,
    a K-Means model with ``n_colours`` clusters is fitted on a shuffled
    sample of at most 1000 pixels, and the cluster centres are converted to
    integer channel values (truncated, not rounded).

    Args:
        album: Root directory passed to ``os.walk``.
        n_colours: Number of clusters, i.e. palette entries per image.

    Returns:
        The module-level ``colours`` list after one palette per file has
        been appended to it. Each palette is a list of ``n_colours`` lists
        of integer channel values.

    Note:
        This function appends to the global name ``colours`` instead of
        building its own list, so ``colours`` must already be defined as a
        list before the call, and repeated calls keep accumulating into it.
    """
    for root, _dirs, filenames in os.walk(album):
        for pic in filenames:
            # Convert to floats instead of the default 8-bit integer coding,
            # scaling the channel values into the range [0, 1]; the cluster
            # centres are scaled back by 255 below.
            # NOTE(review): assumes every file is an 8-bit image with a
            # channel axis (3-D array) - confirm for the album in use.
            img = imread(os.path.join(root, pic))
            jpg = np.array(img, dtype=np.float64) / 255

            # Flatten the image into a 2-D (pixels, channels) array.
            w, h, d = tuple(jpg.shape)
            image_array = np.reshape(jpg, (w * h, d))

            # Fit on a fixed-seed random sample of pixels to keep it fast.
            image_array_sample = shuffle(image_array, random_state=0)[:1000]
            kmeans = KMeans(n_clusters=n_colours, random_state=0).fit(image_array_sample)

            # Only the palette is kept; per-pixel labels and the rebuilt
            # image are not needed for the returned result.
            quantized_rgb = (kmeans.cluster_centers_ * 255).tolist()
            quantized_rgb = [[int(i) for i in j] for j in quantized_rgb]
            colours.append(quantized_rgb)
    return colours

In [2]:
# Module-level accumulator: transformPhotos refers to the name `colours`
# directly and appends one palette per image file to it, so it must be
# defined before the call below.
colours = []
photos = './photos'
# Extract a 7-colour palette for every file found under ./photos; the
# return value is the same `colours` list, so it is not captured here.
transformPhotos(photos, 7)
print(colours)

[[[240, 246, 250], [90, 78, 69], [170, 152, 132], [205, 188, 169], [43, 33, 28], [227, 213, 195], [137, 117, 98]], [[155, 138, 131], [39, 35, 39], [183, 169, 166], [85, 69, 65], [227, 224, 225], [205, 199, 199], [126, 107, 98]], [[20, 15, 15], [204, 207, 194], [87, 67, 51], [244, 246, 242], [113, 105, 91], [57, 36, 24], [155, 157, 141]], [[177, 159, 148], [35, 29, 33], [228, 227, 232], [77, 66, 67], [146, 127, 118], [208, 195, 188], [118, 100, 95]], [[205, 202, 210], [46, 35, 31], [134, 110, 90], [202, 180, 158], [100, 68, 52], [183, 144, 109], [233, 215, 192]], [[211, 202, 193], [90, 66, 55], [158, 140, 125], [38, 30, 30], [242, 238, 231], [133, 113, 96], [185, 173, 157]], [[229, 214, 201], [70, 59, 50], [139, 126, 120], [19, 18, 19], [104, 93, 87], [184, 169, 161], [50, 43, 35]], [[127, 129, 124], [30, 26, 26], [187, 197, 216], [66, 57, 55], [155, 158, 166], [246, 246, 246], [102, 99, 94]], [[122, 87, 59], [29, 9, 5], [183, 127, 69], [133, 70, 13], [238, 224, 202], [209, 166, 118], [