In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.datasets import fetch_openml



In [1]:
import torch
for i in range(torch.cuda.device_count()):
    print(torch.cuda.get_device_properties(i).name)


NVIDIA GeForce RTX 2080 Ti


In [2]:
!nvidia-smi

Mon Feb 26 16:37:33 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
| 40%   38C    P5    49W / 260W |    460MiB / 11264MiB |     32%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.datasets import fetch_openml

# Load the MNIST dataset, False makes it return the data as a NumPy array
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='liac-arff')

# Flatten the images
X = mnist.data
y = mnist.target
#print(X.shape) #(70000, 784)

# Split the data
X_train, y_train = X[:60000], y[:60000]
X_test, y_test = X[60000:], y[60000:]

# Convert labels to integers
y_train = y_train.astype(int)
y_test = y_test.astype(int)

In [6]:
# Mahalanobis distance
def mahalanobis(x, y, VI):  #VI: inverse of the covariance matrix
    x_minus_y = x-y
    return np.sqrt(np.dot(np.dot(x_minus_y,VI), x_minus_y.T))

def kmeans_mahalanobis(X, k, iteration=100, tol=1e-2):
    old_centroids = X[np.random.choice(range(X.shape[0]), size= k, replace = False)]
    cov = np.cov(X, rowvar=False)
    reg_cov = cov + np.eye(cov.shape[0]) * 1e-6
    VI = np.linalg.inv(reg_cov)
    
    for n in range(iteration):
        print(n)
        labels = np.argmin(np.array([[mahalanobis(x, c, VI) for c in old_centroids]for x in X]), axis = 1 )
        new_centroids = np.array([X[labels == i].mean(axis = 0)for i in range(k)])
        if np.linalg.norm(old_centroids - new_centroids) < tol:
            break
        old_centroids = new_centroids
    return labels, old_centroids

In [7]:
def cluster_consistency(labels, y_train, k):
    Q = 0

    for i in range(k):
        
        cluster_labels = y_train[labels == i]
        class_counts = np.bincount(cluster_labels)

        mi = np.max(class_counts)
        Ni = len(cluster_labels)

        Qi = mi / Ni
        Q += Qi
    Q /=k

    return Q

k_values = [5, 10, 20, 40, 200]

for k in k_values:
    labels, centroids = kmeans_mahalanobis(X_train, k)
    Q = cluster_consistency(labels, y_train, k)
    print(f"Consistency for k={k}: {Q}")

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
Consistency for k=5: 0.42284549081588574
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
Consistency for k=10: 0.44387971378390445
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
Consistency for k=20: 0.6728455042213275
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
5