# Color Compression

## Library inclusion

In [None]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# --- Python native module
from typing import Callable, Any # type hint lib
import os # path
from time import process_time   # test evaluation
import csv # exporting test results
from pathlib import Path # mkdir
import multiprocessing # running parrellel tests
from itertools import repeat # packing args

## Helper functions

In [None]:
def read_img(img_path):
    '''
    Load an image file and return its pixels as a floating-point RGB array.

    Parameters
    ----------
    img_path : str
        Path of the image file to open

    Returns
    -------
        np.ndarray, shape (height, width, 3), dtype float64
            Pixel values of the image after conversion to "RGB" mode
    '''
    rgb_image = Image.open(img_path).convert("RGB")
    return np.asarray(rgb_image, dtype=np.float64)

def show_img(img_2d, caption_: str = ""):
    '''
    Display an image with matplotlib, axes hidden, under an optional title.

    Parameters
    ----------
    img_2d : np.ndarray
        Image (2D) whose channel values are on the 0-255 scale
    caption_ : str, default=""
        Title drawn above the image

    Returns
    ----------
        None
    '''
    # imshow interprets floating-point colours on a 0..1 scale, so rescale first
    unit_scaled = img_2d.astype(np.float64) / 255
    drawn = plt.imshow(unit_scaled)
    axes = drawn.axes
    axes.axis("off")
    axes.set_title(caption_)
    plt.show()


def save_img(img_2d, img_path):
    '''
    Write an RGB image to disk.

    Parameters
    ----------
    img_2d : np.ndarray, shape (height, width, 3)
        Image (2D); values are cast to unsigned 8-bit integers before saving
    img_path : str
        Destination path; passed unchanged to PIL's Image.save

    Returns
    ----------
        None
    '''
    as_bytes = np.uint8(img_2d)
    picture = Image.fromarray(as_bytes, "RGB")
    picture.save(img_path)


def convert_img_to_1d(img_2d: np.ndarray):
    '''
    Flatten the two spatial axes of an image into a single list of pixels.

    Parameters
    ----------
    img_2d : np.ndarray, shape (height, width, num_channels)
        Image (2D)

    Returns
    -------
        np.ndarray, shape (height * width, num_channels)
            The same pixels in row-major order, one pixel per row
    '''
    dims = img_2d.shape
    pixel_count = dims[0] * dims[1]
    return np.reshape(img_2d, (pixel_count, dims[2]))


# Random initialization: centroids are arbitrary colours, independent of the image
def kmeans_rand_init(img_1d: np.ndarray, k: int, random_state: Any = 0):
    '''
    Draw k centroids uniformly at random from the 8-bit colour space.
    random_state is ignored and present only so that every initializer
    shares the same signature.

    Parameters
    ----------
    img_1d : np.ndarray, shape (N, num_channels)
        The input image; only its number of channels is used.
    k : int
        The number of centroids.

    Returns
    ----------
        np.ndarray, shape (k, num_channels), dtype float64
            Integer-valued channel intensities in the range 0..255
    '''
    channel_count = img_1d.shape[1]

    def random_color():
        # one independent draw per channel, taken from NumPy's global RNG
        return [np.random.randint(0, 256) for _ in range(channel_count)]

    palette = [random_color() for _ in range(k)]
    return np.asarray(palette, dtype=np.float64)

# Sampling k data points as centroids using a uniformly-distributed probability
def kmeans_img_rand_init(img_1d: np.ndarray, k: int, random_state: Any = 0):
    '''
    Pick k pixels of the input image, uniformly at random, as centroids.
    random_state is ignored and present only so that every initializer
    shares the same signature.

    Pixels are drawn without replacement by position, so the same colour can
    still be returned twice when it occurs at several positions.

    Parameters
    ----------
    img_1d : np.ndarray, shape (N, num_channels)
        The input image.
    k : int
        The number of centroids.

    Returns
    ----------
        np.ndarray, shape (k, num_channels)
            The selected pixels, with the dtype of img_1d
    '''
    generator = np.random.default_rng()
    return generator.choice(img_1d, size=k, replace=False)

# k-means++: sample each new centroid with a probability proportional to its
# contribution to the overall potential (phi)
def kmeans_pp_init(img_1d: np.ndarray, k: int, random_state: Any = 0):
    '''
    Sample k pixels of the input image as centroids with the k-means++ method:
    the first centroid is a uniformly random pixel, and every further centroid
    is drawn with probability proportional to the squared distance between a
    pixel and its closest already-chosen centroid.
    random_state is ignored and present only so that every initializer
    shares the same signature.

    Parameters
    ----------
    img_1d : np.ndarray, shape (N, num_channels)
        The input image.
    k : int
        The number of centroids.

    Returns
    ----------
        np.ndarray, shape (k, num_channels), dtype float64
            The chosen centroids. They are pairwise distinct colours as long
            as the image contains at least k distinct colours; otherwise the
            surplus centroids are uniformly random pixels.

    Raises
    ----------
        ValueError
            If k is smaller than 1, or if img_1d contains no pixels.
    '''
    if k < 1:
        raise ValueError("k must be a positive integer")

    rng = np.random.default_rng()
    n_points = img_1d.shape[0]
    centroids = np.zeros((k, img_1d.shape[1]))
    centroids[0] = rng.choice(img_1d)

    def squared_dist_to(center: np.ndarray) -> np.ndarray:
        # squared Euclidean distance of every pixel to one centre, shape (N,)
        offset = img_1d - center
        return np.einsum("...i,...i->...", offset, offset)

    # Running minimum of the squared distance to the centroids chosen so far.
    # Only the newest centroid can lower it, so there is no need to recompute
    # the distances to all previous centroids in every round.
    closest = squared_dist_to(centroids[0])

    for kith in range(1, k):
        potential = closest.sum()
        if np.isfinite(potential) and potential > 0:
            # Pixels equal to an existing centroid have probability 0 and can
            # therefore never be drawn again.
            picked = rng.choice(n_points, p=closest / potential)
        else:
            # Every pixel already coincides with a centroid (fewer distinct
            # colours than k): the distribution is undefined, so fall back to
            # a uniform draw instead of dividing by zero.
            picked = rng.integers(n_points)
        centroids[kith] = img_1d[picked]
        np.minimum(closest, squared_dist_to(centroids[kith]), out=closest)
    return centroids

# Registry of the centroid-initialization procedures, keyed by the name that
# kmeans() accepts for its init_centroids argument. Each entry is called as
# func(img_1d, k) and returns the initial centroids.
initialize_func: dict[str, Callable[[np.ndarray, int], np.ndarray]] = {
    "random": kmeans_rand_init,
    "in_pixels": kmeans_img_rand_init,
    "kmeans++": kmeans_pp_init
}


def kmeans(img_1d, k_clusters, max_iter, init_centroids='random'):
    '''
    K-Means algorithm

    Alternates between assigning every pixel to its nearest centroid and
    moving every centroid to the mean of its pixels, until the centroids stop
    changing as 8-bit colours or max_iter updates have been performed.

    Parameters
    ----------
    img_1d : np.ndarray with shape=(height * width, num_channels)
        Original (1D) image
    k_clusters : int or array-like
        Number of clusters. For backward compatibility an array-like of
        cluster labels is also accepted; its size is then the cluster count.
    max_iter : int
        Maximum number of centroid updates. With max_iter <= 0 the initial
        centroids are returned together with their labels.
    init_centroids : str or callable, default='random'
        The method used to initialize the centroids for K-means clustering
        'random' --> Centroids are initialized with random values between 0 and 255 for each channel
        'in_pixels' --> A random pixel from the original image is selected as a centroid for each cluster
        'kmeans++' --> Centroid selection based on its contribution to the overall potential (loss)
        Any other string falls back to 'random'.
        A callable is used directly as f(img_1d, k) and must return the
        initial centroids with shape (k, num_channels).

    Returns
    -------
    centroids : np.ndarray with shape=(k_clusters, num_channels)
        Stores the color centroids for each cluster
    labels : np.ndarray with shape=(height * width, )
        Stores the cluster label for each pixel in the image; always
        consistent with the returned centroids
    '''
    if callable(init_centroids):
        init_func = init_centroids
    else:
        init_func = initialize_func.get(init_centroids, initialize_func["random"])

    # NumPy integers are not instances of int; without the explicit check
    # they would be treated as a one-element array, i.e. a single cluster.
    if isinstance(k_clusters, (int, np.integer)):
        n_clusters = int(k_clusters)
    else:
        n_clusters = np.asarray(k_clusters).size

    # np.array copies, so the initializer's output (possibly a view of the
    # image) is never modified, and the centroids are float64 so that the
    # cluster means are not truncated.
    centroids = np.array(init_func(img_1d, n_clusters), dtype=np.float64)

    def assign_labels(centroids_set: np.ndarray) -> np.ndarray:
        # (N, 1, C) - (K, C) -> (N, K, C); the norm over the channel axis
        # gives the N x K matrix of pixel-to-centroid distances
        distances = np.linalg.norm(img_1d[:, np.newaxis] - centroids_set, axis=2)
        # index of the nearest centroid for every pixel, shape (N,)
        return np.argmin(distances, axis=1)

    def update_centroids(centroids_set: np.ndarray, labels: np.ndarray) -> np.ndarray:
        # work on a copy: the caller still needs the previous centroids to
        # test for convergence
        updated = centroids_set.copy()
        for kith in range(centroids_set.shape[0]):
            members = img_1d[labels == kith]
            # an empty cluster keeps its previous centroid
            if len(members) > 0:
                updated[kith] = members.mean(axis=0)
        return updated

    # https://machinelearningcoban.com/2017/01/01/kmeans/
    def convergent(centroids_set_a: np.ndarray, centroids_set_b: np.ndarray) -> bool:
        # converged once no centroid changes when seen as an 8-bit colour
        return np.array_equal(np.floor(centroids_set_a), np.floor(centroids_set_b))

    labeling = assign_labels(centroids)
    remaining = max_iter
    while remaining > 0:
        new_centroids = update_centroids(centroids, labeling)
        converged = convergent(centroids, new_centroids)
        centroids = new_centroids
        # relabel against the centroids that are actually returned
        labeling = assign_labels(centroids)
        if converged:
            break
        remaining -= 1
    return (centroids, labeling)


def generate_2d_img(img_2d_shape, centroids, labels):
    '''
    Rebuild a 2D image in which every pixel takes the colour of its centroid.

    Parameters
    ----------
    img_2d_shape : tuple (height, width, num_channels)
        Shape of the image; only the first two entries are used
    centroids : np.ndarray with shape=(k_clusters, num_channels)
        Colour centroids
    labels : np.ndarray with shape=(height * width, )
        For every pixel, the index in centroids of the cluster it belongs to

    Returns
    -------
        np.ndarray, shape (height, width, num_channels), dtype float64
    '''
    height, width = img_2d_shape[0], img_2d_shape[1]
    # integer-array indexing looks up one centroid row per pixel
    pixel_colors = centroids[labels]
    picture = pixel_colors.reshape(height, width, -1)
    return picture.astype(np.float64)



## Tests

In [None]:
def eval_potential(img_1d: np.ndarray, c_set: np.ndarray, label_set: np.ndarray):
    '''
    Evaluate the k-means potential: the sum, over all pixels, of the squared
    Euclidean distance between a pixel and the centroid it is labelled with.

    Parameters
    ----------
    img_1d : np.ndarray, shape (N, 3)
        The input image
    c_set : np.ndarray, shape (k, 3)
        The centroids set
    label_set : np.array, shape (N, )
        The 0-indexed label set (based on c_set) of each data point in the original input image

    Returns
    ----------
        str
            The evaluated potential, formatted in scientific notation by
            np.format_float_scientific (convertible back with float())
    '''
    # integer-array indexing gathers the assigned centroid of every pixel in
    # one vectorized step instead of a Python-level loop over all N pixels
    assigned = np.asarray(c_set)[label_set]
    dist_vec = img_1d - assigned
    # row-wise inner product = squared distance of each pixel to its centroid
    potential = np.sum(np.einsum("...i, ...i -> ...", dist_vec, dist_vec))
    return np.format_float_scientific(potential)

def mth_job(mth_, k_, trials_, iteration_, img_1d_, org_shape_):
    '''
    Run k-means repeatedly for one initialization method and one k, and
    average the measured potential and CPU time over all trials.

    Parameters
    ----------
    mth_ : str
        The k-means initialization method.
    k_ : int
        Number of centroids.
    trials_ : int
        Number of times the clustering is repeated.
    iteration_ : int
        Maximum number of k-means iterations per trial.
    img_1d_ : np.ndarray, shape (N, 3)
        The input image.
    org_shape_ : Array-like
        The original image shape.

    Returns
    ----------
        tuple (k_, avg_run_result, modified)
            k_ is returned unchanged, avg_run_result holds the mean
            [potential, seconds] over the trials, and modified is the image
            rebuilt from the last trial (None when no trial was run).
    '''
    measurements = []
    last_fit = None
    for _ in range(trials_):
        # only the clustering itself is timed
        started = process_time()
        centroids, labeling = kmeans(img_1d_, k_clusters=k_, max_iter=iteration_, init_centroids=mth_)
        elapsed = process_time() - started

        potential = eval_potential(img_1d_, centroids, labeling)
        measurements.append([potential, np.round(elapsed, 2)])
        last_fit = (centroids, labeling)

    # only the final trial is rendered; averaging images is out of scope
    modified = None
    if last_fit is not None:
        modified = generate_2d_img(org_shape_, *last_fit)

    avg_run_result = np.mean(np.asarray(measurements, dtype=np.float64), axis=0)
    return (k_, avg_run_result, modified)

# convenience driver: iterate over several seeding methods and several k's of the algorithm
def test_eval(file_path="", methods=["random", "in_pixels", "kmeans++"], k_test_set=[3, 5, 7, 10, 25], iteration: int = 2000, trials: int = 20):
    '''
    Evaluate every initialization method on every k of the test set.

    Parameters
    ----------
    file_path : str
        Path of the image to cluster.
    methods : list of str
        Initialization methods to evaluate.
    k_test_set : list of int
        Cluster counts to evaluate for every method.
    iteration : int
        Maximum number of k-means iterations per trial.
    trials : int
        Number of trials per (method, k) pair.

    Returns
    ----------
        dict
            results[method][k] = (averaged run result, image of the last trial)
    '''
    org = read_img(file_path)
    org_flatten = convert_img_to_1d(org)
    results = {}

    for mth in methods:
        # one argument tuple per k; the k's of a method run in parallel
        jobs = [(mth, k, trials, iteration, org_flatten, org.shape) for k in k_test_set]
        with multiprocessing.Pool() as pool:
            outcomes = pool.starmap(mth_job, jobs)
        results[mth] = {k_: (run_res, img_map) for k_, run_res, img_map in outcomes}
    return results

# 
def test_run():
    '''
    Test entry point: runs the evaluation, shows and saves the resulting
    images, and exports the averaged measurements to a CSV file. Customize the
    methods, k values and input image here if needed.

    Parameters
    ----------
        None

    Returns
    ----------
        None
    '''
    # test_results layout: { method: { K: (averaged run result, image of shape (W, H, D)) } },
    # where W, H, D are the dimensions of the input image
    methods = ["random", "in_pixels", "kmeans++"]
    k_test_set = [3, 5, 7, 10, 25]
    test_results = test_eval(file_path="free-nature-images.jpg", methods=methods, k_test_set=k_test_set, trials=20, iteration=2000)
    result_path = "./outputs/"
    Path(result_path).mkdir(exist_ok=True)

    # show every resulting image and store it next to the CSV report
    for mth, per_k in test_results.items():
        for k, (_, image) in per_k.items():
            k_text = str(k)
            show_img(image, "method: " + mth + ", K = " + k_text)
            save_img(image, img_path=result_path + mth + "_k-" + k_text + ".png")

    # one CSV row per K, one column per initialization method
    with open(result_path + "test_results.csv", "w+", newline='') as out_file:
        writer = csv.DictWriter(out_file, fieldnames=["K", *methods], dialect="excel")
        writer.writeheader()

        for k in k_test_set:
            row = {"K": str(k)}
            row.update((mth, str(test_results[mth][k][0])) for mth in methods)
            writer.writerow(row)
                

# The call below is commented out so that the tests (computationally expensive operations) do not run; uncomment it to run them
# test_run()

## Main function

In [None]:
def main():
    '''
    Interactive entry point: asks for an image, the number of clusters, the
    iteration limit and the initialization method, shows the compressed image
    and optionally saves it as "output.pdf" or "output.png" in the current
    working directory.

    Parameters
    ----------
        None

    Returns
    ----------
        None
    '''
    img_path = input("Enter image path: ")
    org = read_img(img_path=img_path)
    k_clusters = int(input("Number of clusters: "))

    max_iter = int(input("Max iteration: "))
    init_type = input("Initialization type (Leave blank for default=random) [random, in_pixels, kmeans++]:").strip()
    # kmeans() silently falls back to "random" for a blank or unknown name;
    # resolve the name here so the caption reports the method actually used
    if init_type not in initialize_func:
        if init_type:
            print("Unknown initialization type, using random")
        init_type = "random"

    (centroids, labeling) = kmeans(convert_img_to_1d(org), k_clusters=k_clusters, max_iter=max_iter, init_centroids=init_type)
    modified = generate_2d_img(org.shape, centroids, labeling)
    print("Done! Showing...")
    show_img(modified, "method: " + init_type + ", K = " + str(k_clusters))

    # an empty answer counts as "no"
    while True:
        cont = input("Save?(Y/N): ").strip().lower()
        if cont in ("", "n"):
            return
        if cont == "y":
            break
        print("Invalid input")

    while True:
        save_format = input("Save format: (pdf/png)").strip().lower()
        if save_format in ("pdf", "png"):
            break
        print("Invalid option, try again")

    save_path = os.path.join(os.getcwd(), "output." + save_format)
    save_img(modified, save_path)
    

In [None]:
# Call main function only when this code is executed directly (script or
# notebook cell). The guard keeps the interactive prompt from starting when
# the module is merely imported, which is what multiprocessing worker
# processes do under the "spawn" start method.
if __name__ == "__main__":
    main()