In [25]:
import numpy as np
from PIL import Image
from matplotlib import pyplot

In [26]:
def input_image():
    image_name = input("Enter the name of image (ex: img1.jpg, img2.png,... ): ")
    try:
        image = Image.open(image_name).convert('RGB')
    except FileNotFoundError as err:
        print(err)
        return
    
    return image

In [27]:
def compute_distance(img_1d, centroids, index, cluster):
    #manhantan distance to calculate the distance
    return np.sum(np.abs(img_1d[index] - centroids[cluster]))

def assign_labels(img_1d, centroids, k_clusters):
    labels = np.zeros(img_1d.shape[0])
    for i in range(img_1d.shape[0]):
        tmp = 10000000
        for cluster in range(k_clusters):
            dist = compute_distance(img_1d, centroids, i, cluster)
            if (dist < tmp):
                tmp = dist
                labels[i] = cluster
    return labels

def update_centroids(img_1d, centroids, labels, k_clusters):
    for cluster in range(k_clusters):
        cluster_points = img_1d[labels == cluster]
        if (len(cluster_points) > 0):
            centroids[cluster] = np.mean(cluster_points, axis=0) 
    return centroids

def has_converged(centroids, new_centroids):
    return (set([tuple(a) for a in centroids]) == 
        set([tuple(a) for a in new_centroids]))

def kmeans(img_1d, k_clusters, max_iter, init_centroids='random'):
    #init centroids and labels
    if (init_centroids == 'random'):
        centroids = np.random.randint(0, 255, size=(k_clusters, img_1d.shape[1]))
    elif (init_centroids == 'in_pixels'):
        centroids = img_1d[np.random.choice(img_1d.shape[0], size=k_clusters, replace=False)]
    
    for _ in range(max_iter):
        labels = assign_labels(img_1d, centroids, k_clusters)
        new_centroids = update_centroids(img_1d, centroids, labels, k_clusters)
        if (has_converged(centroids, new_centroids)):
            break

    return centroids, labels



In [28]:
def handle_and_compress_image(image):
    #convert to matrix 
    img_matrix = np.array(image)
    heigth, width, num_channels = img_matrix.shape
    img_1d = img_matrix.reshape((heigth * width, num_channels))

    #start kmeans
    centroids, labels = kmeans(img_1d, 7, 1000)

    # getting back the 3d matrix (row, col, rgb(3))
    centroid = np.array(centroids)
    recovered = centroid[labels.astype(int), :]
    recovered = recovered.reshape(img_matrix.shape)
    
    # plotting the compressed image.
    format = int(input("Enter format to save (png: 1, pdf: 2): "))
    if (format == 1):
        pyplot.imshow(recovered)
        recovered = recovered.astype(np.uint8)
        pyplot.imsave("compressed_image.png", recovered)
    elif (format == 2):
        pyplot.imshow(recovered)
        pyplot.savefig("compressed_image.pdf", format="pdf")
        
    # pyplot.imshow(recovered)
    pyplot.show()

    

In [30]:
if __name__ == "__main__":
    image = input_image()
    handle_and_compress_image(image)

AttributeError: shape