In [21]:
from collections import Counter
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score, confusion_matrix

# === 1. Data ===
# Training split: image embeddings and the label stored for each of them.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects,
# so these label files must come from a trusted source.
X = np.load("./data/image_emb.npy")
labels = np.load("./data/image_labels.npy", allow_pickle=True)

# Test split, read the same way.
X_test = np.load("./data/image_emb_test.npy")
labels_test = np.load("./data/image_labels_test.npy", allow_pickle=True)

# Distinct non-empty class names in sorted order.
# (The empty string presumably marks unlabelled samples — TODO confirm.)
unique_classes = sorted(set(labels) - {""})
num_classes = len(unique_classes)

print("Klasy:", unique_classes)

# === 2. Clustering ===
# One cluster per known class; random_state is fixed so the run is repeatable.
# NOTE(review): n_init is left at the library default, which differs between
# scikit-learn releases — pin it if results must match across environments.
kmeans = KMeans(n_clusters=num_classes, random_state=42).fit(X)

# Cluster index assigned to every training sample during fitting.
cluster_assign_train = kmeans.labels_

# === 3. Cluster → class mapping ===
# Every cluster is named after the most frequent known label among its training
# members. Samples whose label is "" are ignored when counting.

# Tally the known labels of each cluster in a single pass over the training
# set instead of rescanning all samples once per cluster.
label_counts = {c: Counter() for c in range(num_classes)}
for i, lbl in enumerate(labels):
    cluster_id = cluster_assign_train[i]
    if lbl != "" and cluster_id in label_counts:
        label_counts[cluster_id][lbl] += 1

# Label frequencies over all clusters; only consulted for clusters that ended
# up without a single labelled member.
overall_counts = sum(label_counts.values(), Counter())

cluster_to_class = {}
for c in range(num_classes):
    if label_counts[c]:
        # Ties resolve to the label that was encountered first in the cluster.
        cluster_to_class[c] = label_counts[c].most_common(1)[0][0]
        continue

    # A cluster with no labelled members used to crash with IndexError on
    # most_common(1)[0]; fall back to the globally most frequent label instead.
    if not overall_counts:
        raise ValueError(
            "Brak oznaczonych próbek – nie można zbudować mapowania klaster → klasa."
        )
    fallback = overall_counts.most_common(1)[0][0]
    print(
        f"Uwaga: klaster {c} nie zawiera oznaczonych próbek; "
        f"przypisano najczęstszą klasę '{fallback}'."
    )
    cluster_to_class[c] = fallback

print("\nMapowanie klaster → klasa:")
for c, cls in cluster_to_class.items():
    print(f"{c}: {cls}")

# === 4. Prediction on the test set ===
# Assign every test embedding to one of the fitted clusters.
cluster_assign_test = kmeans.predict(X_test)

# Translate each cluster index into the class name mapped to that cluster.
y_pred_test = np.array(list(map(cluster_to_class.__getitem__, cluster_assign_test)))

# Ground-truth labels of the test split, used as-is.
y_true_test = labels_test

# === 5. Evaluation ===
# Share of test samples whose predicted class equals the true label.
acc = accuracy_score(y_true_test, y_pred_test)
print("\nAccuracy:", acc)

# Confusion matrix with rows and columns ordered as in unique_classes.
cm = confusion_matrix(y_true_test, y_pred_test, labels=unique_classes)
print("\nConfusion matrix:")
print(cm)


Klasy: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

Mapowanie klaster → klasa:
0: deer
1: ship
2: truck
3: cat
4: frog
5: dog
6: horse
7: airplane
8: bird
9: automobile

Accuracy: 0.7693333333333333

Confusion matrix:
[[372   9  17 263   2   0   1   0  24   4]
 [  0 710   0  36   0   0   0   1   2  10]
 [  8   0 599 118  30   3   3   0   0   0]
 [  0   0   3 656   5  81   1   0   0   0]
 [  0   0   4 164 578   0  13   0   2   0]
 [  0   0   2 171   7 567   2   9   0   0]
 [  0   1   4 174   0   0 574   0   1   0]
 [  0   0   0 366   8   5   0 381   0   0]
 [  4  11   0  69   0   0   0   0 655   0]
 [  0  22   0  63   0   0   0   0   7 678]]
