# Project 01 - Color Compression

## Thông tin sinh viên

- Họ và tên: Nguyễn Hoàng Trung Kiên
- MSSV: 22127478
- Lớp: 22CLC08

## Import các thư viện liên quan

In [None]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

## Helper functions

In [None]:
def read_img(img_path):
    '''
    Read an image file into a NumPy array.

    Parameters
    ----------
    img_path : str
        Path of image

    Returns
    -------
    np.ndarray
        Image (2D): shape (height, width, channels) for multi-band images
        such as RGB/RGBA, or (height, width) for single-band images such as
        grayscale.
    '''
    # Using Image.open as a context manager releases the underlying file
    # handle as soon as the pixels have been copied into the array; a bare
    # Image.open() keeps it open until the object is garbage-collected.
    with Image.open(img_path) as img:
        img_2d = np.array(img)
    return img_2d

def show_img(img_2d):
    '''
    Display an image in a Matplotlib window.

    Parameters
    ----------
    img_2d : np.ndarray
        Image (2D) array. Values are clipped to [0, 255] and cast to uint8
        so that float-valued images (e.g. K-means centroids) display
        correctly.
    '''
    displayable = np.clip(img_2d, 0, 255).astype(np.uint8)
    plt.imshow(displayable)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()

def save_img(img_2d, img_path):
    '''
    Save an image array to img_path.

    The format is chosen from the file extension: '.png', '.jpg'/'.jpeg'
    and '.pdf' are handled explicitly; any other extension is passed to
    Pillow's own extension-based format detection.

    Parameters
    ----------
    img_2d : np.ndarray
        Image (2D); may hold float values (e.g. K-means centroids).
    img_path : str
        Path of image

    Raises
    ------
    ValueError
        If Pillow cannot determine an output format from img_path's
        extension (previously such paths were silently not written).
    '''
    # Round and clip before the uint8 cast: a plain cast truncates
    # (127.9 -> 127) and does not clamp values outside [0, 255].
    pixels = np.clip(np.rint(img_2d), 0, 255).astype(np.uint8)
    img = Image.fromarray(pixels)

    # Drop the alpha channel: JPEG cannot store RGBA.
    if img.mode == 'RGBA':
        img = img.convert('RGB')

    lowered = img_path.lower()
    if lowered.endswith('.png'):
        img.save(img_path, "PNG")
    elif lowered.endswith(('.jpg', '.jpeg')):
        img.save(img_path, "JPEG")
    elif lowered.endswith('.pdf'):
        img.save(img_path, "PDF", resolution=100.0)
    else:
        # Let Pillow infer the format; it raises ValueError when it can't.
        img.save(img_path)

def convert_img_to_1d(img_2d):
    '''
    Flatten an image into a list of pixels.

    Parameters
    ----------
    img_2d : np.ndarray
        Image (2D) with shape (height, width, channels), or (height, width)
        for a single-band (e.g. grayscale) image.

    Returns
    -------
    np.ndarray with shape=(height * width, channels)
        Image (1D); single-band images get channels == 1.
    '''
    img = np.asarray(img_2d)
    # Single-band images come back from PIL as 2-D arrays; give them an
    # explicit channel axis so downstream code always sees (n_pixels, C).
    if img.ndim == 2:
        return img.reshape(-1, 1)
    height, width, channels = img.shape
    return img.reshape(height * width, channels)



def _assign_labels(pixels, centroids):
    '''
    Return, for every pixel, the index of its nearest centroid.

    Parameters
    ----------
    pixels : np.ndarray with shape=(n_pixels, num_channels), float
    centroids : np.ndarray with shape=(k_clusters, num_channels), float

    Returns
    -------
    np.ndarray with shape=(n_pixels, )
        Index of the closest centroid; ties go to the lowest index.
    '''
    # Squared Euclidean distance has the same argmin as the true distance,
    # so the sqrt is skipped. The table is filled one centroid at a time so
    # peak memory is O(n_pixels * k) rather than O(n_pixels * k * channels).
    sq_dist = np.empty((pixels.shape[0], centroids.shape[0]))
    for j, centroid in enumerate(centroids):
        sq_dist[:, j] = ((pixels - centroid) ** 2).sum(axis=1)
    return np.argmin(sq_dist, axis=1)


def kmeans(img_1d, k_clusters, max_iter, init_centroids='random'):
    '''
    K-Means algorithm

    Parameters
    ----------
    img_1d : np.ndarray with shape=(height * width, num_channels)
        Original (1D) image
    k_clusters : int
        Number of clusters (must be >= 1)
    max_iter : int
        Maximum number of iterations
    init_centroids : str, default='random'
        The method used to initialize the centroids for K-means clustering
        'random' --> Centroids are initialized with random values between 0 and 255 for each channel
        'in_pixels' --> A random pixel from the original image is selected as a centroid for each cluster

    Returns
    -------
    centroids : np.ndarray with shape=(k_clusters, num_channels), float64
        Stores the color centroids for each cluster
    labels : np.ndarray with shape=(height * width, )
        Stores the cluster label for each pixel in the image; always
        computed against the returned centroids.

    Raises
    ------
    ValueError
        If k_clusters < 1 or init_centroids is not a supported method.
    '''
    if k_clusters < 1:
        raise ValueError(f"k_clusters must be at least 1, got {k_clusters}")

    # Work in float64: images read with PIL are uint8, and uint8 subtraction
    # and squaring wrap around (0 - 16 -> 240, 16 ** 2 -> 0), which corrupts
    # the distances when centroids are taken straight from the pixels.
    pixels = np.asarray(img_1d, dtype=np.float64)
    n_pixels, num_channels = pixels.shape

    # Initialize centroids
    if init_centroids == 'random':
        centroids = np.random.randint(0, 256, (k_clusters, num_channels)).astype(np.float64)
    elif init_centroids == 'in_pixels':
        random_idx = np.random.choice(n_pixels, k_clusters, replace=False)
        centroids = pixels[random_idx]
    else:
        raise ValueError(
            f"init_centroids must be 'random' or 'in_pixels', got {init_centroids!r}")

    for _ in range(max_iter):
        labels = _assign_labels(pixels, centroids)
        new_centroids = centroids.copy()
        for j in range(k_clusters):
            members = labels == j
            # An empty cluster keeps its previous centroid.
            if members.any():
                new_centroids[j] = pixels[members].mean(axis=0)
        # Converged: re-assigning no longer moves any centroid.
        if np.array_equal(centroids, new_centroids):
            break
        centroids = new_centroids
    else:
        # Iteration budget exhausted (or max_iter <= 0): label the pixels
        # against the centroids that are actually returned.
        labels = _assign_labels(pixels, centroids)

    return centroids, labels

def generate_2d_img(img_2d_shape, centroids, labels):
    '''
    Generate a 2D image based on K-means cluster centroids

    Parameters
    ----------
    img_2d_shape : tuple (height, width, num_channels) or (height, width)
        Shape of image; a 2-tuple denotes a single-band (grayscale) image.
    centroids : np.ndarray with shape=(k_clusters, num_channels)
        Store color centroids
    labels : np.ndarray with shape=(height * width, )
        Store label for pixels (cluster's index on which the pixel belongs)

    Returns
    -------
    np.ndarray
        New image (2D): every pixel replaced by its cluster's centroid.
    '''
    height, width = img_2d_shape[:2]
    labels_2d = np.asarray(labels).reshape(height, width)
    # Fancy indexing maps each label to its centroid colour in one step.
    new_img_2d = np.asarray(centroids)[labels_2d]
    if len(img_2d_shape) == 2:
        # Grayscale: drop the singleton channel axis to match the input shape.
        new_img_2d = new_img_2d.reshape(height, width)
    return new_img_2d


# Your additional functions here


## Your tests

In [None]:
# YOUR CODE HERE

## Main function

In [None]:
# YOUR CODE HERE
def main():
    '''
    Interactive driver: compress the colours of an image with K-means.

    Prompts for the input image path, the number of clusters and the
    centroid-initialisation method, shows the original and compressed
    images, then saves the compressed image to a user-chosen path.

    Raises
    ------
    ValueError
        If the number of clusters is not a positive integer.
    '''
    # strip() tolerates stray whitespace around pasted paths.
    img_path = input("Enter the image file path: ").strip()

    img_2d = read_img(img_path)
    img_1d = convert_img_to_1d(img_2d)

    k_clusters = int(input("Enter number of clusters: "))  # K-means clustering
    if k_clusters < 1:
        raise ValueError(f"Number of clusters must be at least 1, got {k_clusters}")

    # kmeans() only implements these two initialisation methods; ask again
    # rather than failing deep inside the algorithm.
    centroids_init = input("Enter init_centroids(random / in_pixels): ").strip().lower()
    while centroids_init not in ('random', 'in_pixels'):
        print("Please type either 'random' or 'in_pixels'.")
        centroids_init = input("Enter init_centroids(random / in_pixels): ").strip().lower()

    max_iter = 100  # Maximum number of iterations
    centroids, labels = kmeans(img_1d, k_clusters, max_iter, init_centroids=centroids_init)
    new_img_2d = generate_2d_img(img_2d.shape, centroids, labels)

    print("Original: ")
    show_img(img_2d)
    print("After being compressed: ")
    show_img(new_img_2d)

    output_path = input("Enter the output file path (with .png, .pdf or .jpg extension): ").strip()
    save_img(new_img_2d, output_path)
   

In [None]:
# Call main function only when this file is executed as a script, not when
# it is imported (e.g. to reuse or test the helper functions).
if __name__ == "__main__":
    main()