In [1]:
import numpy as np
import distance_utils

In [2]:
def read_data(file_name, n_features=784):
    """Parse a headerless CSV of labelled samples into ``[label, attribs]`` rows.

    Every non-blank line is expected to have the form
    ``label,attr_1,...,attr_n`` with at least ``n_features`` attribute columns
    after the label. All values are returned as the raw strings read from the
    file; numeric conversion is left to the caller.

    Args:
        file_name: Path of the CSV file to read.
        n_features: Number of attribute columns to keep after the label.
            Defaults to 784. Extra trailing columns are ignored.

    Returns:
        A list with one ``[label, attribs]`` entry per data line, where
        ``label`` is the first token of the line and ``attribs`` is a list of
        the following ``n_features`` tokens.

    Raises:
        IndexError: If a non-blank line has fewer than ``n_features + 1``
            comma-separated tokens.
        OSError: If the file cannot be opened.
    """
    data_set = []
    with open(file_name, "rt") as f:
        for line_number, line in enumerate(f, start=1):
            line = line.rstrip("\n")
            # Blank lines (e.g. stray newlines at the end of the file) carry no
            # sample; skip them instead of failing on a missing column.
            if not line.strip():
                continue
            tokens = line.split(",")
            if len(tokens) < n_features + 1:
                # Same exception type the index-based loop used to raise, but
                # with enough context to locate the bad row.
                raise IndexError(
                    "line %d of %s has %d columns, expected at least %d"
                    % (line_number, file_name, len(tokens), n_features + 1)
                )
            data_set.append([tokens[0], tokens[1:n_features + 1]])
    return data_set


def get_labels(data_set):
    """Return the label of every row in ``data_set`` converted to ``int``.

    The label is read from position 0 of each row.
    """
    labels = []
    for entry in data_set:
        labels.append(int(entry[0]))
    return labels


def get_features(data_set):
    """Return the attribute values of every row converted to ``int``.

    The attributes are read from position 1 of each row; the result is a list
    containing one list of integers per row.
    """
    return [list(map(int, entry[1])) for entry in data_set]


In [3]:
# Parse the validation CSV once, then turn the parsed rows into a label
# vector and a feature matrix.
validation_set = read_data("valid.csv")
labels = np.asarray(get_labels(validation_set))
features = np.asarray(get_features(validation_set))

In [4]:
# Display the parsed feature matrix (NumPy abbreviates large arrays).
features

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [5]:
# Number of centroids to initialise.
# NOTE(review): presumably one cluster per class label — confirm.
n_clusters = 10

In [6]:
# Draw `n_clusters` distinct row indices of `features` to act as the initial
# centroids. A generator with a fixed seed makes the draw identical on every
# "Restart & Run All", so the outputs of the following cells are reproducible.
rng = np.random.default_rng(42)
random_indices = rng.choice(features.shape[0], size=n_clusters, replace=False)

In [7]:
# Sanity check: the draw should yield exactly `n_clusters` indices.
len(random_indices)

10

In [8]:
# Use the randomly selected rows of `features` as the starting centroids.
# Indexing with an integer array returns a new array, not a view of `features`.
centroids = features[random_indices]

In [9]:
# Display the initial centroids.
centroids

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [10]:
# Zero-filled array with the same shape and dtype as `centroids`.
# NOTE(review): presumably a placeholder for the centroids of the previous
# iteration, so the first comparison reports a change — confirm.
previous_centroids = np.zeros_like(centroids)

In [11]:
# Display the zero-initialised array.
previous_centroids

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [12]:
# True only if every element of `centroids` matches `previous_centroids`
# within np.allclose's default tolerances.
np.allclose(centroids, previous_centroids)

False

In [13]:
# Snapshot the current centroids. `.copy()` gives `previous_centroids` its own
# buffer; a plain assignment would alias `centroids`, so any in-place update of
# `centroids` would silently change the snapshot as well and a later
# np.allclose(centroids, previous_centroids) would always be True.
previous_centroids = centroids.copy()

In [14]:
# Name of the distance metric; it is combined with the suffix "_distance" to
# look up the matching function in `distance_utils`.
metric = 'euclidean'

In [15]:
# Resolve the distance function by name: "<metric>_distance" in distance_utils.
distance_metric = getattr(distance_utils, f"{metric}_distance")

In [16]:
# Evaluate the selected distance function on the samples and the centroids.
# NOTE(review): presumably returns one row per sample and one column per
# centroid — confirm against distance_utils.
distance_metric(features, centroids)

array([[2571.72082466, 1960.46499586, 2146.97484848, ..., 2291.55122133,
        2766.63315241, 2209.98393659],
       [2660.0428568 , 2726.43888617, 2387.52445014, ..., 2781.29556142,
        2807.43316928, 2450.5438172 ],
       [2285.27722607, 2080.48816387, 2183.46948685, ..., 2300.63621635,
        2717.64070473, 2299.28532375],
       ...,
       [2879.87725433,    0.        , 2208.72814081, ..., 2272.9126688 ,
        2670.15093206, 2307.20696948],
       [2981.80247501, 2265.41188308, 2479.90544175, ..., 2417.01530818,
        2073.47220864, 2201.86307476],
       [2812.17709257, 2015.53045127, 2405.27773864, ..., 2459.19885329,
        2683.57690406, 2436.57608131]])