In [1]:
import scipy.io as sio
import numpy as np
from sklearn.cluster import KMeans
from collections import Counter
import random

#### Part 1

In [2]:
matFile = sio.loadmat("mnist_10digits.mat")

xtrain = matFile["xtrain"]
ytrain = matFile["ytrain"]
xtest = matFile["xtest"]
ytest = matFile["ytest"]
ytrain = ytrain.flatten()

In [3]:
# Normalization
xtrain = xtrain.astype("float32") / 255.0
xtest = xtest.astype("float32") / 255.0

In [4]:
n_digits = len(np.unique(ytrain))

# Set seed for reproducibility
random.seed(5)
np.random.seed(5)

# Instantiate Kmeans clustering object
kmeans = KMeans(n_clusters=n_digits, random_state=45, n_init = 10)

cluster_labels = kmeans.fit_predict(xtrain)

In [5]:
def calc_purity(clusters):
    purity = {}

    for cluster in np.unique(clusters):
        # Find indices of each cluster
        indx = np.where(clusters == cluster)[0]

        # Get the true labels associated with the nodes that are grouped in the same cluster
        cluster_true_labels = ytrain[indx]
        # Count # of true labels that are in the same cluster
        count_labels = Counter(cluster_true_labels)

        # Get the majority label based on the maximum count
        maj_label = max(
            count_labels, key=count_labels.get
        )  # Code adapted from https://www.w3resource.com/python-exercises/dictionary/python-data-type-dictionary-exercise-80.php

        matched_num = count_labels[maj_label]

        # Compute purity rate = correctly assigned samples/ size of cluster
        purity[cluster] = matched_num / len(indx)


    return purity

In [6]:
np.random.seed(5)

purity = calc_purity(cluster_labels)
print(f"{purity = }")


purity = {np.int32(0): 0.5260193133047211, np.int32(1): 0.6232394366197183, np.int32(2): 0.8593668007696345, np.int32(3): 0.4266335066696812, np.int32(4): 0.3572308726335835, np.int32(5): 0.9047927461139896, np.int32(6): 0.8952198036705079, np.int32(7): 0.5327487473156765, np.int32(8): 0.7922858046158711, np.int32(9): 0.5261233815689261}


#### Part 2

In [7]:
from sklearn.metrics import pairwise_distances
from scipy import stats

In [8]:
matFile = sio.loadmat("mnist_10digits.mat")

xtrain = matFile["xtrain"]
ytrain = matFile["ytrain"]
xtest = matFile["xtest"]
ytest = matFile["ytest"]
ytrain = ytrain.flatten()

In [9]:
# If pixel > 128, assign as 1; otherwise 0
# Adapted code from https://stackoverflow.com/questions/19766757/replacing-numpy-elements-if-condition-is-met
xtrain[xtrain <= 128] = 0
xtrain[xtrain > 128] = 1
np.unique(xtrain)
xtrain.shape

(60000, 784)

In [10]:
def kmeans_hamming(k=10, epsilon = 0):
    n_feature = xtrain.shape[1]

    # Randomly creates centroids with binary values (0 and 1)
    centroids = np.random.choice([0, 1], size = (10, n_feature))
    
    # Store changes in centroids
    avg_hamming = np.inf

    while avg_hamming > epsilon:
        normalized_distances = pairwise_distances(xtrain, centroids, metric='hamming')
        
        # distances = normalized_distances * n_feature

        # Assign data points to the closest centroids
        labels = np.argmin(normalized_distances, axis = 1)
        
        # Code adapted from https://stackoverflow.com/questions/16330831/most-efficient-way-to-find-mode-in-numpy-array
        # Store updated centroids
        new_centroids = np.zeros((10, n_feature))

        # Find majority label for each feature
        for i in range(k):
            indx = labels[labels == i]
            assigned_data = xtrain[indx]
            majority = stats.mode(assigned_data, keepdims = True)[0]
            # Store new centroids in new_centroids
            new_centroids[i,:] = majority
        
        change = np.sum(new_centroids != centroids, axis = 1)
        avg_hamming = np.mean(change) # Average Hamming distance across centroids
        centroids = new_centroids

    return labels
    


In [11]:
np.random.seed(5)

labels = kmeans_hamming(k=10, epsilon = 0)

In [12]:
purity = calc_purity(labels)
print(f"{purity = }")

purity = {np.int64(0): 0.5296752519596865, np.int64(1): 0.7938212456633075, np.int64(2): 0.5611924686192469, np.int64(3): 0.3341034103410341, np.int64(4): 0.4146772767462423, np.int64(5): 0.7190119100132334, np.int64(6): 0.2898259705488621, np.int64(7): 0.4849068721901092, np.int64(8): 0.19253303415607081, np.int64(9): 0.3377684407096172}
