# Color Compression

- **Student Info:** 21CLC05 - 21127135 - Diep Huu Phuc
- **GitHub Repository:** https://github.com/kru01/ColorCompression_K-means

## Import libraries and essential functions

- My implementation of K-means does **NOT** guarantee that every cluster contains at least one pixel, i.e., ***empty clusters can exist***.
  - A cluster is empty when ***no pixel is closer to its centroid than to any other centroid***, so its mean cannot be computed. ***Such centroids are discarded so that they do not affect later iterations***.
  - Therefore, ***the output may contain fewer clusters than the requested `k_clusters`.*** This happens much more often with `init_centroids=random`, because a randomly drawn color can lie far from every pixel in the image.
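
The effect described above can be reproduced on a tiny, hand-made example. The sketch below is illustrative only: the pixel values, the centroid values, and the helper name `assign_and_update` are made up for the demonstration and are not part of the notebook's `kmeans`. It runs a single assignment/update step and shows that a centroid lying far from every pixel receives no pixels and is dropped.

```python
import numpy as np

def assign_and_update(pixels, centroids):
    """One K-means step (hypothetical helper): assign pixels, then drop empty clusters."""
    pixels = pixels.astype(float)
    # distances[i, j] = Euclidean distance from centroid i to pixel j
    distances = np.linalg.norm(pixels - centroids[:, None], axis=2)
    labels = np.argmin(distances, axis=0)
    # Keep only the centroids that actually received pixels
    kept = [i for i in range(len(centroids)) if np.any(labels == i)]
    new_centroids = np.array([pixels[labels == i].mean(axis=0) for i in kept])
    return new_centroids, labels

# Four made-up RGB pixels: two dark, two bright
pixels = np.array([[10, 10, 10], [20, 20, 20],
                   [240, 240, 240], [250, 250, 250]], dtype=np.uint8)
# Three made-up centroids; the pure-green one is far from every pixel
centroids = np.array([[0, 0, 0], [255, 255, 255], [0, 255, 0]], dtype=float)

new_centroids, labels = assign_and_update(pixels, centroids)
print(labels)              # cluster index chosen for each pixel
print(len(new_centroids))  # number of clusters that survived the step
```

Three clusters were requested here, but only the black and white centroids attract pixels, so only two centroids remain after the update.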

In [None]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

k_clusters_defaults = [3, 5, 7]
init_centroids_defaults = ['random', 'in_pixels']
max_iter = 1000
convergence_threshold = 0.2
outfile_types = ['.png', '.pdf']

def kmeans(img_1d:np.ndarray, k_clusters:int, max_iter:int, init_centroids:str='random'):
    '''
    Inputs:
        img_1d: np.ndarray with shape=(height * width, num_channels)
            Original image in 1d array

        k_clusters: int representing the number of clusters
        max_iter: int representing the iteration limit

        init_centroids: str describing how to init centroids
            'random' --> Each centroid has c channels, the value of each channel is a random integer in [0, 255]
            'in_pixels' --> Each centroid is a random pixel of the original image

    Outputs:
        centroids: np.ndarray with shape=(k_clusters, num_channels)
            Store color centroids
            
        labels: np.ndarray with shape=(height * width,)
            Store label for pixels (cluster's index to which the pixel belongs)
    '''
    # Work in float so that subtracting uint8 pixels cannot wrap around
    img_1d = img_1d.astype(float)
    n_channels = img_1d.shape[1]

    if init_centroids == 'random':
        # Redraw until all k centroids are distinct colors
        centroids = np.random.choice(256, (k_clusters, n_channels)).astype(float)
        while len(np.unique(centroids, axis=0)) != k_clusters:
            centroids = np.random.choice(256, (k_clusters, n_channels)).astype(float)
    else: centroids = img_1d[np.random.choice(len(img_1d), k_clusters, replace=False)]

    for _ in range(max_iter):
        # distances[i, j] = distance from centroid i to pixel j
        distances = np.linalg.norm(img_1d - centroids[:, None], axis=2)
        labels = np.argmin(distances, axis=0)

        prev_centroids = centroids
        # Recompute each centroid as the mean of its pixels, skipping empty clusters
        kept = [i for i in range(len(prev_centroids)) if np.any(labels == i)]
        centroids = np.array([img_1d[labels == i].mean(axis=0) for i in kept])

        if len(kept) != len(prev_centroids):
            # Clusters were dropped: renumber labels to match the compacted centroids
            remap = np.zeros(len(prev_centroids), dtype=int)
            remap[kept] = np.arange(len(kept))
            labels = remap[labels]
            continue
        if np.allclose(prev_centroids, centroids, atol=convergence_threshold):
            break

    return centroids, labels

def reconstruct_img(img_1d:np.ndarray, og_shape:tuple, centroids:np.ndarray, labels:np.ndarray):
    # Replace every pixel with the color of its centroid, cast back to the image dtype
    img_1d = centroids.astype(img_1d.dtype)[labels]
    return img_1d.reshape(og_shape)

def compress_colors(img:Image.Image, k_clusters:int=0, init_centroids:str=""):
    img = np.asarray(img).copy()
    og_shape = img.shape
    # Flatten to (height * width, num_channels); a grayscale image gets one channel
    img = img.reshape(og_shape[0] * og_shape[1], -1)
    out_imgs = []
    out_k_clusters = []

    if k_clusters > 0: k_clusters_arr = [k_clusters]
    else: k_clusters_arr = k_clusters_defaults
    if init_centroids: init_centroids_arr = [init_centroids]
    else: init_centroids_arr = init_centroids_defaults

    for ic in init_centroids_arr:
        for kc in k_clusters_arr:
            centroids, labels = kmeans(img, kc, max_iter, ic)
            out_k_clusters.append(len(centroids))
            out_imgs.append(reconstruct_img(img.copy(), og_shape, centroids, labels))
    return out_imgs, out_k_clusters

## Main program handling interfaces, inputs and outputs

- **Input handling is lenient.** As long as the `img` file can be opened, any missing or invalid argument falls back to its default.
  - **`img.jpg`** - All default renditions of the img are output, i.e., the img is run through the 3 default `k_clusters` values, each with both `init_centroids` methods.
  - **`img.png 10`** - All renditions of the img with `k_clusters=10`, one per default `init_centroids` method, are output.
  - **`img.jpg 7 1`** - Only the rendition of the img with `k_clusters=7` and `init_centroids=in_pixels` is output.
- For saving, the 2 supported filetypes are `.png` and `.pdf`.
  - *To opt out of saving, leave the input blank or enter anything that is not a valid option.*
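
The fallback rules listed above can be summarized as a small pure function. This is a sketch, not the notebook's `main`: the name `parse_request` and its return convention are assumptions made for illustration, and only the two default lists mirror the ones defined earlier. It also assumes that an `init_centroids` value given without a valid `k_clusters` means "show all defaults".

```python
INIT_METHODS = ['random', 'in_pixels']  # same order as init_centroids_defaults
K_DEFAULTS = [3, 5, 7]                  # same values as k_clusters_defaults

def parse_request(line):
    """Hypothetical helper: turn 'img.jpg 7 1' into (filename, k values, init methods)."""
    filename, *rest = line.split() or [""]
    try:
        k = int(rest[0])
    except (IndexError, ValueError):
        k = 0
    try:
        method = int(rest[1])
    except (IndexError, ValueError):
        method = -1

    if k <= 0:
        # No usable k_clusters: every default rendition is requested
        return filename, K_DEFAULTS, INIT_METHODS
    if not 0 <= method < len(INIT_METHODS):
        # Valid k_clusters only: one rendition per init method
        return filename, [k], INIT_METHODS
    return filename, [k], [INIT_METHODS[method]]

print(parse_request("img.jpg"))
print(parse_request("img.png 10"))
print(parse_request("img.jpg 7 1"))
```

Keeping the parsing separate from plotting makes each fallback case easy to check without opening an image.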

In [None]:
def main():
    inp = input('''Input img, k_clusters, init_centroids (0 - random, 1 - in_pixels).
If k_clusters or init_centroids is omitted, all default renditions of the img will be shown!
Examples: img.jpg 7 0 -OR- img.png 5 -OR- img.jpg
--> ''')

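    # split() tolerates repeated spaces; an empty input becomes an empty filename
    filename, *clus_cent = inp.split() or [""]
    try: img = Image.open(filename)
    except (OSError, ValueError):
        print("Invalid file!")
        return

    # Missing, non-numeric, or non-positive values fall back to "use all defaults"
    try: k_clusters = max(int(clus_cent[0]), 0)
    except (IndexError, ValueError): k_clusters = 0
    try: init_centroids = int(clus_cent[1])
    except (IndexError, ValueError): init_centroids = -1

    # init_centroids is only honored together with a valid k_clusters
    if k_clusters == 0 or not 0 <= init_centroids < len(init_centroids_defaults):
        init_centroids = ""
    else: init_centroids = init_centroids_defaults[init_centroids]

    outfile_type = input("Save output images as (0 - .png, 1 - .pdf) --> ")
    try:
        choice = int(outfile_type)
        # Reject negative numbers, which would otherwise index from the end
        outfile_type = outfile_types[choice] if choice >= 0 else 0
    except (IndexError, ValueError): outfile_type = 0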

    out_imgs, out_k_clusters = compress_colors(img, k_clusters, init_centroids)

    if k_clusters and init_centroids:
        fig, axis = plt.subplots(1, 2, figsize=(12, 7))
        axis[0].set_title(filename)
        axis[0].imshow(img)
        axis[1].set_title(f'k = {out_k_clusters[0]}, init_centroids = {init_centroids}')
        axis[1].imshow(out_imgs[0])
        plt.tight_layout()
        plt.show()
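        if outfile_type:
            # Strip only the final extension and name the file after the actual cluster count
            Image.fromarray(out_imgs[0]).save(f'{filename.rsplit(".", 1)[0]}_k{out_k_clusters[0]}_{init_centroids}{outfile_type}')
        img.close()
        return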

    plt.title(filename)
    plt.imshow(img)
    plt.show()
    filename = filename.rsplit('.', 1)[0]  # strip only the final extension

    if k_clusters:
        fig, axis = plt.subplots(1, len(init_centroids_defaults), figsize=(12, 7))
        for i, ic in enumerate(init_centroids_defaults):
            axis[i].set_title(f'k = {out_k_clusters[i]}, init_centroids = {ic}')
            axis[i].imshow(out_imgs[i])
            if outfile_type:
                Image.fromarray(out_imgs[i]).save(f'{filename}_k{out_k_clusters[i]}_{ic}{outfile_type}')
        plt.tight_layout()
        plt.show()
        img.close()
        return

    fig, axis = plt.subplots(len(init_centroids_defaults), len(k_clusters_defaults), figsize=(12, 7))
    k_clus_ind = 0
    for i, ic in enumerate(init_centroids_defaults):
        for j in range(len(k_clusters_defaults)):
            axis[i][j].set_title(f'k = {out_k_clusters[k_clus_ind]}, init_centroids = {ic}')
            axis[i][j].imshow(out_imgs[k_clus_ind])
            if outfile_type:
                Image.fromarray(out_imgs[k_clus_ind]).save(f'{filename}_k{out_k_clusters[k_clus_ind]}_{ic}{outfile_type}')
            k_clus_ind += 1
    plt.tight_layout()
    plt.show()
    img.close()

main()

## References

- https://en.wikipedia.org/wiki/K-means_clustering
- https://youtu.be/4b5d3muPQmA
- https://towardsdatascience.com/create-your-own-k-means-clustering-algorithm-in-python-d7d4c9077670
- https://www.geeksforgeeks.org/how-to-convert-images-to-numpy-array/
- https://numpy.org/doc/stable/reference/generated/numpy.unique.html
- https://numpy.org/doc/stable/user/basics.broadcasting.html
- https://youtu.be/oG1t3qlzq14
- https://stackoverflow.com/questions/1401712/how-can-the-euclidean-distance-be-calculated-with-numpy
- https://towardsdatascience.com/the-concept-of-masks-in-python-50fd65e64707
- https://stackoverflow.com/questions/10580676/comparing-two-numpy-arrays-for-equality-element-wise
- https://stackoverflow.com/questions/49643907/clipping-input-data-to-the-valid-range-for-imshow-with-rgb-data-0-1-for-floa
- https://stackoverflow.com/questions/14770735/how-do-i-change-the-figure-size-with-subplots