In [1]:
import numpy as np
import matplotlib.pyplot as plt
from utils import *

%matplotlib inline

In [11]:
def find_closet_centroids(X,centroids):
    k = centroids.shape[0]
    idx = np.zeros(X.shape[0],dtype=int)

    for i in range(X.shape[0]):
        dist = []
        for j in range(k):
            norm_ij = np.linalg.norm(X[i] - centroids[j])
            dist.append(norm_ij)
        idx[i] = np.argmin(dist)
    return idx

def compute_centroids(X,idx,K):
    m,n = X.shape
    centroids = np.zeros((K,n))

    for i in range(K):
        points = X[idx==i]
        centroids[i] = np.mean(points)
    return centroids


In [12]:
X = load_data()
initial_centroids = np.array([[3,3], [6,2], [8,5]])
idx = find_closet_centroids(X, initial_centroids)

print("First three elements in idx are:", idx[:3])
K = 3
centroids = compute_centroids(X, idx, K)

print("The centroids are:", centroids)


First three elements in idx are: [0 2 1]
The centroids are: [[2.79311264 2.79311264]
 [4.22357988 4.22357988]
 [5.36803564 5.36803564]]


In [15]:
def run_kMeans(X, initial_centroids, max_iters=10):
    m,n = X.shape
    K = initial_centroids.shape[0]
    centroids = initial_centroids
    idx = np.zeros(m)
    for i in range(max_iters):
        idx = find_closet_centroids(X,centroids)
        compute_centroids(X,idx,K)
    return idx,centroids

def get_kmeans(X,K):
    randix = np.random.permutation(X.shape[0])
    centroids = X[randix[:K]]
    return centroids


In [19]:
original_img = plt.imread('bird_small.png')
X_img = np.reshape(original_img, (original_img.shape[0] * original_img.shape[1], 3))

K = 16
max_iters = 10

# Using the function you have implemented above. 
initial_centroids = get_kmeans(X_img, K)

# Run K-Means - this can take a couple of minutes depending on K and max_iters
centroids, idx = run_kMeans(X_img, initial_centroids, max_iters)
idx = find_closet_centroids(X_img, centroids)

# Replace each pixel with the color of the closest centroid
X_recovered = centroids[idx, :] 

# Reshape image into proper dimensions
X_recovered = np.reshape(X_recovered, original_img.shape) 


KeyboardInterrupt: 