In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

In [None]:
NUM_CLUSTERS = 2

# Load the per-item records and stack their embeddings into one matrix for clustering/PCA.
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced yourself.
with open('clustered_embeddings.pkl', 'rb') as pkl_file:
    dataset = pickle.load(pkl_file)
all_embeddings = np.concatenate([data['embeds'] for data in dataset], axis=0)
print(all_embeddings.shape)

In [3]:
def cluster_embeddings(embeddings, n_clusters, random_state=0):
    """Run k-means over the items' embeddings and attach a cluster label to each item.

    Args:
        embeddings: list of dicts, each holding an ``'embeds'`` array. Every item must
            contribute exactly one row when concatenated along axis 0 (e.g. shape
            ``(1, dim)``) so that k-means labels line up with items.
        n_clusters: number of k-means clusters.
        random_state: seed forwarded to ``KMeans`` for reproducible clustering.

    Returns:
        The same list, mutated in place: each dict gains an ``int`` ``'label'`` key.

    Raises:
        ValueError: if the stacked embeddings do not have exactly one row per item
            (or if ``embeddings`` is empty, from ``np.concatenate``).
    """
    stacked = np.concatenate([data['embeds'] for data in embeddings], axis=0)
    # KMeans labels rows, but labels are written back per item; any row/item mismatch
    # would silently attach labels to the wrong items, so fail loudly instead.
    if stacked.shape[0] != len(embeddings):
        raise ValueError(
            f"expected one embedding row per item, got {stacked.shape[0]} rows "
            f"for {len(embeddings)} items"
        )
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(stacked)
    for data, label in zip(embeddings, kmeans.labels_):
        # Plain int keeps the pickled records free of numpy scalar types.
        data["label"] = int(label)
    return embeddings

In [None]:
# Cluster the records (labels are added in place) and persist them with their labels.
clustered_embeddings = cluster_embeddings(dataset, NUM_CLUSTERS)
# Context manager guarantees the pickle is flushed and the handle closed.
with open("clustered_embeddings.pkl", "wb") as pkl_file:
    pickle.dump(clustered_embeddings, pkl_file)

In [None]:
# Project the stacked embeddings to 2-D with PCA and colour each point by its k-means label.
pca = PCA(n_components=2)
reduced_embeddings = pca.fit_transform(all_embeddings)

colors = ["red", "blue", "green", "purple"]
point_labels = [item["label"] for item in clustered_embeddings]
# Rows of `reduced_embeddings` are matched to items purely by position.
assert len(point_labels) == reduced_embeddings.shape[0], "expected one embedding row per item"
assert max(point_labels) < len(colors), (
    f"colour palette has {len(colors)} entries but labels go up to {max(point_labels)}"
)

# One vectorised scatter call: a per-point loop creates one artist per item and is
# very slow for large datasets.
plt.scatter(
    reduced_embeddings[:, 0],
    reduced_embeddings[:, 1],
    c=[colors[label] for label in point_labels],
)
plt.title("Clustered embeddings")
plt.xlabel("PCA 1")
plt.ylabel("PCA 2")
plt.grid()
plt.show()

In [1]:
import os
import matplotlib.pyplot as plt
import pickle
from collections import Counter

In [2]:
# Reload the labelled records. NOTE(review): pickle.load can execute arbitrary code — trusted files only.
with open('clustered_embeddings.pkl', 'rb') as pkl_file:
    dataset = pickle.load(pkl_file)

In [None]:
# Quick look at the dataset size and the structure of its first record.
first_record = dataset[0]
print(len(dataset))
print(first_record.keys())
print(first_record['patch'])
print(first_record['embeds'].shape)
print(first_record['label'])

In [None]:
# Distinct cluster ids present in the loaded records.
unique_labels = {record['label'] for record in dataset}
num_labels = len(unique_labels)

print("Number of unique labels: ", num_labels)
print("Labels: ", unique_labels)

In [None]:
# Choose a random sample of up to 10 patches from each cluster and show them in a grid.
import random

SAMPLES_PER_CLUSTER = 10
GRID_ROWS, GRID_COLS = 2, 5
PATCH_DIR = 'datasets/patches'
sampler = random.Random(0)  # fixed seed so re-runs show the same patches

clusters = {}
for record in dataset:
    clusters.setdefault(record['label'], []).append(record)

for label, members in sorted(clusters.items()):
    print(label, len(members))

    sample = sampler.sample(members, min(SAMPLES_PER_CLUSTER, len(members)))

    fig, axs = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(20, 8), dpi=100)
    for ax, record in zip(axs.flat, sample):
        image = plt.imread(os.path.join(PATCH_DIR, record['patch']))
        ax.imshow(image)
        ax.set_title(record['patch'])
    # Hide grid cells left empty when a cluster has fewer members than the grid holds.
    for ax in axs.flat[len(sample):]:
        ax.axis('off')

    plt.show()
