# Project 01 - Color Compression

## Thông tin sinh viên

- Họ và tên: Hoàng Bảo Khanh
- MSSV: 22127183
- Lớp: 22CLC03

## Import các thư viện liên quan

In [1]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

## Helper functions

In [2]:
def read_img(img_path):
    '''
    Read image from img_path

    Parameters
    ----------
    img_path : str
        Path of image

    Returns
    -------
        Image (2D)
    '''

    # YOUR CODE HERE
    image = Image.open(img_path)
    return image

def show_img(img_2d):
    '''
    Show image

    Parameters
    ----------
    img_2d : <your type>
        Image (2D)
    '''

    # YOUR CODE HERE
    plt.imshow(img_2d)

def save_img(img_2d, img_path):
    '''
    Save image to img_path

    Parameters
    ----------
    img_2d : <your type>
        Image (2D)
    img_path : str
        Path of image
    '''

    # YOUR CODE HERE
    # Split the string by dot, and take the string in front of the dot 
    img = Image.fromarray(img_2d)
    img.save(img_path + '.png')
    img.save(img_path + '.pdf')

def convert_img_to_1d(img_2d):
    '''
    Convert 2D image to 1D image

    Parameters
    ----------
    img_2d : <your type>
        Image (2D)

    Returns
    -------
        Image (1D)
    '''

    # YOUR CODE HERE
    return np.reshape(img_2d, (img_2d.shape[0] * img_2d.shape[1], img_2d.shape[2]))

def kmeans(img_1d, k_clusters, max_iter, init_centroids='random'):
    '''
    K-Means algorithm

    Parameters
    ----------
    img_1d : np.ndarray with shape=(height * width, num_channels)
        Original (1D) image
    k_clusters : int
        Number of clusters
    max_iter : int
        Max iterator
    init_centroids : str, default='random'
        The method used to initialize the centroids for K-means clustering
        'random' --> Centroids are initialized with random values between 0 and 255 for each channel
        'in_pixels' --> A random pixel from the original image is selected as a centroid for each cluster

    Returns
    -------
    centroids : np.ndarray with shape=(k_clusters, num_channels)
        Stores the color centroids for each cluster
    labels : np.ndarray with shape=(height * width, )
        Stores the cluster label for each pixel in the image
    '''
    
    # YOUR CODE HERE

    centroids = initialize_centroids(img_1d, k_clusters, init_centroids)
        
    for i in range(max_iter):
        # Calculate distances from pixels to centroids and return the indices of the minimum distance to centroids
        labels = np.argmin(np.linalg.norm(img_1d - centroids[:, None], axis=2), axis=0)
        # Check if the label contains values from 0 to k_clusters - 1 or not
        # If not, random centroids again. 
        unique_label = np.unique(labels)
        if (len(unique_label) != k_clusters):
            centroids = initialize_centroids(img_1d, k_clusters, init_centroids)
            continue
        # Save the previous centroids
        old_centroids = centroids
        # Reset the curent centroids
        centroids = np.zeros(old_centroids.shape)
        for i in range(old_centroids.shape[0]):
            # get index where the label matchs the current i value
            index = np.where(i == labels)
            # Calculate the new centroids by using mean on row in each label
            centroids[i] = np.mean(img_1d[index], axis=0) 

        # Check if the old centroids are equal within a tolerance
        if (np.allclose(old_centroids, centroids, rtol=10e-5, equal_nan=False)):
            break

    return centroids, labels

        
def generate_2d_img(img_2d_shape, centroids, labels):
    '''
    Generate a 2D image based on K-means cluster centroids

    Parameters
    ----------
    img_2d_shape : tuple (height, width, 3)
        Shape of image
    centroids : np.ndarray with shape=(k_clusters, num_channels)
        Store color centroids
    labels : np.ndarray with shape=(height * width, )
        Store label for pixels (cluster's index on which the pixel belongs)

    Returns
    -------
        New image (2D)
    '''

    # YOUR CODE HERE
    result = centroids[labels].astype(np.uint8)
    result = result.reshape(img_2d_shape.shape)
    return result

# Your additional functions here
def initialize_centroids(img, k_clusters, init_centroids):
    if (init_centroids == 'random'):
        # random value 
        return np.random.randint(0, 256, size=(k_clusters, len(img[0])))
    elif (init_centroids == 'in_pixels'):
        #random pixel in the img
        index = np.random.randint(0, len(img), size=k_clusters)
        return img[index]
    else: return None


## Your tests

In [3]:
#YOUR CODE HERE
# image = read_img("image.png")
# image = np.array(image)
# img_1d = convert_img_to_1d(image)
# k_clusters = 3
# centroids, label = kmeans(img_1d, k_clusters, 100)
# result = generate_2d_img(image, centroids, label)
# show_img(result)
# img_path = "image_" + f"{k_clusters}" + "clusters_result"
# save_img(result, img_path)

## Main FUNCTION

In [4]:
# YOUR CODE HERE
def main():
  image_path = input("Enter your image path: ")
  image = read_img(image_path)
  image_array = np.array(image)
  print(image_array)
  image_1d = convert_img_to_1d(image_array)
  k_clusters = int(input("Enter the number of colors used for K-Means algorithm: "))
  max_iters = int(input("Enter the maximum of iterations using in K-Means algorithm: "))
  
  init_centroids = input("Enter the method for initializing centroids (Enter 'random' or 'in_pixels'): ")
  while (init_centroids != 'random' and init_centroids != 'in_pixels'):
    print("Invalid input. Please enter again.")
    init_centroids = input("Enter the method for initializing centroids (Enter 'random' or 'in_pixels'): ")

  centroids, labels = kmeans(image_1d, k_clusters, max_iters, init_centroids)
  image_result = generate_2d_img(image_array, centroids, labels)
  print("Image after reducing colors with " + f"{k_clusters}" + " clusters")
  show_img(image_result)
  
  img_path = "image_" + f"{k_clusters}" + "clusters_result"
  save_img(image_result, img_path) 
  

In [5]:
# Call main function
main()

[[[ 12   6   6]
  [ 11   5   5]
  [ 13   7   9]
  ...
  [ 38  18   0]
  [ 29   7   0]
  [ 25   3   0]]

 [[ 15   9   9]
  [ 15   9   9]
  [ 14   8  10]
  ...
  [ 51  30   0]
  [ 37  12   0]
  [ 31   8   0]]

 [[ 19  13  15]
  [ 19  13  15]
  [ 17  11  15]
  ...
  [ 65  40   0]
  [ 53  24   0]
  [ 50  23   0]]

 ...

 [[ 84  64  65]
  [ 85  67  67]
  [ 80  64  65]
  ...
  [ 46 126 237]
  [ 47 129 239]
  [ 48 130 238]]

 [[ 80  60  59]
  [ 81  63  61]
  [ 74  58  58]
  ...
  [ 45 127 235]
  [ 46 128 236]
  [ 47 130 236]]

 [[ 77  57  56]
  [ 79  61  59]
  [ 73  57  57]
  ...
  [ 44 126 234]
  [ 46 129 235]
  [ 47 130 236]]]


ValueError: invalid literal for int() with base 10: ''

: 