# Image Color Reduction with K-Means
### Installation
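You may need to install the packages imported below: `ipywidgets`, `numpy`, `matplotlib` and `Pillow` (e.g. `pip install ipywidgets numpy matplotlib pillow`).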

In [1]:
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import matplotlib.pyplot as plt
import PIL

In [4]:
def get_pixel_error(center, pixelcell):
    """Return the squared Euclidean distance between two colour vectors.

    Both arguments are converted to float64 first, so integer inputs such as
    uint8 pixels cannot wrap around when subtracted or squared. The squared
    differences are summed over the first axis (exactly what the builtin
    ``sum`` did), so two 1-D vectors yield a scalar.
    """
    sqr_diff = np.square(np.asarray(center, dtype=float) - np.asarray(pixelcell, dtype=float))
    return np.sum(sqr_diff, axis=0)

class KMeans:
    """Naive k-means clustering, used here for colour quantisation.

    Data points are the rows of ``data`` (e.g. one RGB pixel per row).
    Internally the labels live in an array shaped like ``data`` in which
    every column of row ``i`` holds the label of point ``i``.

    Parameters
    ----------
    k : int
        Number of clusters to form.
    iterations : int, default 15
        Maximum number of assign/update rounds; the loop stops early once
        the labels stop changing.
    initial_picktype : str, default 'edgecolors'
        ``'random'`` picks ``k`` distinct data points as initial centers;
        any other value uses a fixed black/white/red/green/blue palette
        (truncated to ``k``, so at most five centers; 3-channel data only).
    """

    def __init__(self, k, iterations=15, initial_picktype='edgecolors'):
        self.k = k
        self.iterations = iterations
        self.initial_picktype = initial_picktype

    def cluster_datapoints(self, data):
        """Cluster the rows of *data*.

        Returns
        -------
        (labels, centers)
            ``labels`` is a list with one int cluster index per row of
            *data*; ``centers`` is a list of 1-D center arrays indexed by
            those labels.
        """
        # Work in float64 so integer inputs (e.g. uint8 pixels) cannot wrap
        # around when differences are squared or rows are summed.
        data = np.asarray(data, dtype=float)

        cluster_centers = self.get_initial_cluster_centers(data)
        clusters = np.zeros_like(data)

        for iteration in range(self.iterations):
            previous = clusters.copy()
            clusters = self.assign_each_point_to_cluster(data, cluster_centers, clusters)
            cluster_centers = self.calculate_new_cluster_center(data, cluster_centers, clusters)
            # Centers are derived only from the labels (empty clusters keep
            # their old center), so once a real assignment repeats, every
            # further round would reproduce the same labels and centers.
            # Iteration 0 is excluded: its "previous" labels are the zero
            # initialisation, not an actual assignment.
            if iteration > 0 and np.array_equal(previous, clusters):
                break
        return [int(cclass) for cclass in clusters[:, 0]], cluster_centers

    def get_initial_cluster_centers(self, data):
        """Dispatch to the initialisation selected by ``initial_picktype``."""
        if self.initial_picktype == 'random':
            return self.get_random_cluster_centers(data)
        else:
            return self.get_edgecolor_cluster_centers(data)

    def get_random_cluster_centers(self, data):
        """Return up to ``k`` distinct rows of *data*, drawn at random.

        Rows are drawn uniformly by index and duplicates of already chosen
        centers are rejected, so frequent values are more likely to be
        picked. If *data* has fewer than ``k`` distinct rows, all distinct
        rows are returned instead of looping forever.
        """
        data = np.asarray(data)
        n_distinct = len(np.unique(data, axis=0)) if len(data) else 0
        target = min(self.k, n_distinct)
        centers = []
        while len(centers) < target:
            random_point = data[np.random.randint(low=0, high=len(data))]
            if self.data_point_is_not_in_centers(random_point, centers):
                # Copy so a later in-place edit of *data* cannot move a center.
                centers.append(random_point.copy())
        return centers

    def get_edgecolor_cluster_centers(self, data):
        """Return a fixed palette: black, white, red, green, blue.

        The palette is truncated to the first ``k`` colours.

        Raises
        ------
        ValueError
            If the data points do not have exactly 3 channels.
        """
        if len(data[0]) != 3:
            raise ValueError('edgecolor initialisation requires 3-channel (RGB) data')
        black = data[0] * 0
        white = 1.0 - black
        red = np.array([1, 0, 0])
        green = np.array([0, 1, 0])
        blue = np.array([0, 0, 1])
        return [black, white, red, green, blue][:self.k]

    @staticmethod
    def data_point_is_not_in_centers(random_point, centers):
        """Return True if *random_point* equals none of *centers* element-wise."""
        for center in centers:
            if (random_point == center).all():
                return False
        return True

    def assign_each_point_to_cluster(self, data, cluster_centers, clusters):
        """Label every row of *data* with the index of its nearest center.

        Distance is the squared Euclidean distance; on ties the lowest
        center index wins (strict ``<``, as in the scalar version). Prints
        how many rows changed label, writes the labels into *clusters* in
        place (every column of a row gets that row's label) and returns it.
        """
        points = np.asarray(data, dtype=float)
        best_cluster = np.full(len(points), -1)
        best_error = np.full(len(points), np.inf)
        # Loop over the (few) centers and vectorise over the (many) points,
        # keeping memory at O(len(points)).
        for c, center in enumerate(cluster_centers):
            error = np.sum(np.square(np.asarray(center, dtype=float) - points), axis=1)
            closer = error < best_error
            best_error[closer] = error[closer]
            best_cluster[closer] = c
        changed = np.any(clusters != best_cluster[:, None], axis=1)
        print(int(np.count_nonzero(changed)))
        clusters[:] = best_cluster[:, None]
        return clusters

    def get_best_cluster_for_point(self, datapoint, cluster_centers):
        """Return the index of the center closest to a single *datapoint*.

        Returns -1 when *cluster_centers* is empty.
        """
        best_cluster = -1
        best_error = float('inf')
        for c in range(len(cluster_centers)):
            clusterpoint = cluster_centers[c]
            error = get_pixel_error(clusterpoint, datapoint)
            if error < best_error:
                best_error = error
                best_cluster = c
        return best_cluster

    def calculate_new_cluster_center(self, data, cluster_centers, clusters):
        """Return the mean of each cluster's members as its new center.

        A cluster with no members keeps its previous center instead of
        becoming NaN through a division by zero.
        """
        points = np.asarray(data, dtype=float)
        labels = np.asarray(clusters)[:, 0]
        new_cluster_centers = []
        for c, old_center in enumerate(cluster_centers):
            members = points[labels == c]
            if len(members) == 0:
                new_cluster_centers.append(np.asarray(old_center, dtype=float))
            else:
                new_cluster_centers.append(members.mean(axis=0))
        return new_cluster_centers
    
        
    
class Parameterbox:
    """Holds the user-tunable settings and renders the colour-reduced image."""

    def __init__(self):
        # Number of colours (k-means clusters) the output image may use.
        self.used_colors = 16

    def generate(self, source_image=None):
        """Return a copy of the image with its colours reduced by k-means.

        Parameters
        ----------
        source_image : array-like, optional
            Image indexed as (height, width, channels). Defaults to the
            module-level ``image`` global.

        Returns
        -------
        numpy.ndarray
            Array of the same shape in which every pixel is replaced by the
            center of its cluster. Integer images are first scaled by 1/255;
            float images are used as-is.
        """
        if source_image is None:
            source_image = image  # module-level global loaded by the notebook
        new_image = np.asarray(source_image)
        # Integer images (e.g. uint8 JPEGs) are scaled to [0, 1]. Float
        # images are assumed to already be in that range (matplotlib loads
        # PNGs as floats) and must not be divided again.
        if np.issubdtype(new_image.dtype, np.integer):
            new_image = new_image / 255.0
        else:
            new_image = new_image.astype(float)

        height = new_image.shape[0]
        width = new_image.shape[1]
        dimension = new_image.shape[2]

        reshaped_input = new_image.reshape(height * width, dimension)
        kmeans = KMeans(self.used_colors, initial_picktype='random')
        clusters, cluster_centers = kmeans.cluster_datapoints(reshaped_input)

        # Replace every pixel by its cluster center in one vectorised gather.
        quantized = np.asarray(cluster_centers, dtype=float)[np.asarray(clusters, dtype=int)]
        return quantized.reshape((height, width, dimension))
        
        

parameterbox = Parameterbox()

# Sample images; switch by commenting/uncommenting one of these lines.
image = plt.imread('Images/panda-small.jpg')
# image = plt.imread('Images/brush.jpg')
# image = plt.imread('Images/flowers.jpg')
# image = plt.imread('Images/vietnam.jpg')

# Height/width ratio of the full-resolution image, used to size the figure.
img_ratio = image.shape[0] / float(image.shape[1])

# Keep every 4th row and column (16x fewer pixels to cluster).
image = image[::4, ::4, :]



# Widgets: output area for the figure, colour-count slider, redraw button.
plt_output = widgets.Output()
used_color_slider = widgets.IntSlider(
    min=2,
    max=100,
    step=1,
    description='Used colors:',
    # Start from the value generate() actually uses, so the slider and the
    # rendered image agree before the slider is first moved.
    value=parameterbox.used_colors)
update_btn = widgets.Button(description='Update')

# Total figure width in inches; the height is scaled by the image aspect ratio.
fig_width = 30

# Event handlers for the widgets.
def used_color_slider_eventhandler(change):
    """Store the slider's new value as the number of colours to use."""
    new_value = change.new
    parameterbox.used_colors = new_value

def update_btn_eventhandler(obj):
    """Redraw the plot when the Update button is clicked (*obj* is unused)."""
    return update_plot()
    

#util functions
def update_plot():
    """Redraw the original image next to its colour-reduced version.

    Clears ``plt_output`` and, inside it, draws ``image`` and
    ``parameterbox.generate()`` side by side with axes hidden. Reads the
    module-level ``plt_output``, ``fig_width``, ``img_ratio``, ``image`` and
    ``parameterbox`` globals.
    """
    plt_output.clear_output()

    with plt_output:
        _fig, axs = plt.subplots(1, 2, figsize=(fig_width, fig_width * img_ratio))

        axs[0].imshow(image)
        axs[0].axis('off')

        axs[1].imshow(parameterbox.generate())
        axs[1].axis('off')
        plt.show()



# Connect the widgets to their handlers.
update_btn.on_click(update_btn_eventhandler)
used_color_slider.observe(used_color_slider_eventhandler, names='value')


# Show the plot area and draw the initial figure.
# The slider and button displays are currently disabled:
#display(used_color_slider)
display(plt_output)
#display(update_btn)
update_plot()

Output()