In [2]:
import pickle
import numpy as np

In [3]:
with open('embeddings_all.pkl', 'rb') as file:
    data = pickle.load(file)

x = data["outputs"].cpu().numpy()
y = data["labels"].cpu().numpy()

In [4]:
# normalize embeddings
x = (x - np.min(x)) / (np.max(x) - np.min(x))

In [5]:
embeddings = np.array(x)

In [6]:
# do PCA
from sklearn.decomposition import PCA

pca = PCA(n_components=100)
pca.fit(embeddings)

pca_transform = pca.transform(embeddings)

In [7]:
fine_labels = ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm']

print(len(fine_labels))

100


In [8]:
# compute the centroid for each class
centroids = []
for i in range(100):
    centroids.append(np.mean(pca_transform[y == i], axis=0))

print(len(centroids))
#print(centroids)

100
[array([ 0.05010166, -0.01225002,  0.12724656, -0.1505487 , -0.08199433,
       -0.01974311,  0.07010168,  0.15655112,  0.08414223, -0.11111852,
        0.02555917,  0.11408956, -0.00374186,  0.04006504, -0.05533646,
        0.08079492,  0.165545  ,  0.1159436 , -0.1632486 ,  0.15103614,
       -0.00123266, -0.05124922, -0.06418547, -0.0287468 ,  0.02224082,
       -0.03053893, -0.06973585,  0.02891964, -0.0160425 ,  0.00385101,
        0.0535934 , -0.00983144, -0.07205262, -0.01133927, -0.03478529,
       -0.03621357, -0.03091669, -0.0586762 ,  0.03183748, -0.01978778,
        0.06451975,  0.04320461, -0.02074438, -0.03618882, -0.01500285,
       -0.05438012,  0.04452179, -0.01539901,  0.04021429,  0.05529677,
        0.03588022,  0.02589567,  0.02625083, -0.01195503, -0.03703377,
        0.01805221, -0.03363945,  0.00896023, -0.01263595,  0.03826422,
       -0.00895749, -0.0485932 , -0.00103038, -0.02223672,  0.02685896,
        0.04622087,  0.01942882, -0.03777413, -0.0178368 , 

In [10]:
# compute the distance matrix 
from scipy.spatial.distance import cdist

dist = cdist(centroids, centroids, metric='euclidean')
print(dist.shape)

print(dist)


(100, 100)
[[0.         0.74442849 0.71874767 ... 0.69761268 0.75287131 0.73432406]
 [0.74442849 0.         0.66236316 ... 0.67659019 0.72573713 0.69008426]
 [0.71874767 0.66236316 0.         ... 0.67533094 0.36532893 0.68393619]
 ...
 [0.69761268 0.67659019 0.67533094 ... 0.         0.66904051 0.69999079]
 [0.75287131 0.72573713 0.36532893 ... 0.66904051 0.         0.70571092]
 [0.73432406 0.69008426 0.68393619 ... 0.69999079 0.70571092 0.        ]]


In [28]:
taken = np.zeros(100, dtype=bool)

clusters = {}

dist_complete = dist.copy()

for i in range(100):
    dist_complete[i, i] = np.inf

for i in range(10):
    max_ind = np.argsort(dist_complete[i][:10])
    print(max_ind)
    dist_complete[:, max_ind] = np.inf

      

[4 3 7 9 2 1 5 8 6 0]
[0 1 2 3 4 5 6 7 8 9]
[0 1 2 3 4 5 6 7 8 9]
[0 1 2 3 4 5 6 7 8 9]
[0 1 2 3 4 5 6 7 8 9]
[0 1 2 3 4 5 6 7 8 9]
[0 1 2 3 4 5 6 7 8 9]
[0 1 2 3 4 5 6 7 8 9]
[0 1 2 3 4 5 6 7 8 9]
[0 1 2 3 4 5 6 7 8 9]


In [14]:
print(clusters)

{0: array([56, 82, 52,  9, 91, 60, 21, 76, 69, 39]), 1: array([90, 92, 31, 44, 49, 66, 72, 25, 28, 13]), 2: array([10, 34, 97, 45, 31, 64, 35, 49, 24, 44]), 3: array([54, 71, 30,  3, 20, 18, 79, 64, 42, 14]), 4: array([54, 73, 71, 62,  3, 49, 79, 26, 63, 92]), 5: array([24, 83, 19, 93, 86, 64, 90, 11, 39, 31]), 6: array([78, 13,  6, 44, 81, 53, 17, 23, 25, 28]), 7: array([23, 78,  6, 13, 17, 43, 76, 49, 25, 44]), 8: array([47, 78, 88, 77, 21, 83, 44, 43, 40, 96]), 12: array([36, 89, 16, 75, 67, 84, 80, 68, 70, 71]), 15: array([18, 30,  3, 37, 28, 42, 54, 71, 64, 33]), 22: array([60, 10, 39, 85, 86, 77, 16, 31, 38, 83]), 27: array([43, 54, 92, 71,  4, 79, 73, 90, 62, 41]), 29: array([43, 92, 44, 15, 26, 27, 37, 79, 33, 90]), 32: array([92, 66, 71, 90, 54, 45, 49, 11, 27, 64]), 46: array([11, 97, 35,  2, 32, 40, 21, 64, 83,  3]), 48: array([ 8, 41, 88, 45, 77, 26, 57, 14, 44, 22]), 50: array([73, 63, 79, 36, 62,  4, 64, 44, 54, 32]), 51: array([76, 79, 73, 62, 50, 44, 64, 29, 91, 40]), 5

In [12]:
# now perform agglomerative clustering but if a cluster reaches 10 elements, don't merge it with anything else

from sklearn.cluster import AgglomerativeClustering

clustering = AgglomerativeClustering(n_clusters=None, distance_threshold=0, linkage='single', compute_full_tree=True).fit(dist)

print(clustering.labels_)

# print dendrogram

from scipy.cluster.hierarchy import dendrogram
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(25, 10))
dn = dendrogram(clustering.children_, labels=fine_labels, orientation='top')
plt.show()



[58 83 94 56 84 86 87 76 60 74 55 99 79 88 77 69 81 59 61 85 93 89 67 80
 65 91 63 57 42 49 95 70 75 50 64 98 96 29 82 37 27 51 24 28 97 92 48 46
 62 53 78 34 66 31 45 47 68 52 43 90 32 73 54 33 26 36 71 38 39 44 25 40
 18 41 19 23 30 13 12 16 15 21 72 35 20  9 14 11  6 17  8  7 22  5 10  3
  4  1  2  0]


TypeError: Linkage matrix 'Z' must contain doubles.

<Figure size 2500x1000 with 0 Axes>