In [2]:
import numpy as np
from sklearn.neighbors import KDTree
from sklearn.neighbors import DistanceMetric
from sklearn.cluster import MiniBatchKMeans, KMeans
from tqdm import tqdm
from sklearn.metrics.pairwise import pairwise_distances_argmin
import time


In [3]:
vqa_features = np.load('./data/train_vqa_features.npy', allow_pickle=True).tolist()
indices, features = vqa_features['vqa_idx'], vqa_features['features_fc1']
indices = np.array(indices)

dataset_sz = len(features)

In [4]:
print("shape of features: ", features.shape)
print("shape of indices: ", indices.shape)

shape of features:  (443757, 1000)
shape of indices:  (443757,)


In [3]:
from numpy import unique
from numpy import where
#from matplotlib import pyplot
from sklearn.datasets import make_classification
from sklearn.cluster import Birch

In [7]:
from sklearn.cluster import MiniBatchKMeans
from sklearn.utils import parallel_backend
import time

kmeans_model = MiniBatchKMeans(n_clusters=10000, batch_size=100, n_init=10, max_no_improvement=10, verbose=0, random_state=42)
print(time.time())

with parallel_backend('multiprocessing', n_jobs=-1):
    kmeans_model.partial_fit(features.squeeze())

print(time.time())

kmeans_labels = kmeans_model.predict(features.squeeze())
print(time.time())
print(kmeans_labels)


1681885227.7164066
1681886305.994714
1681886329.4778535
[2651 4422 4698 ... 3631 3676 5372]


In [4]:
import pickle

In [10]:
# save model and labels
with open('kmeans_model.pkl', 'wb') as f:
    pickle.dump(kmeans_model, f)

with open('labels.pkl', 'wb') as f:
    pickle.dump(kmeans_labels, f)


In [5]:
with open('kmeans_model.pkl', 'rb') as f:
    kmeans_model = pickle.load(f)

# load labels
with open('labels.pkl', 'rb') as f:
    kmeans_labels = pickle.load(f)

In [6]:
cluster_centers = kmeans_model.cluster_centers_
print(cluster_centers)

[[ 0.06863916  0.2534172   0.10789369 ... -0.2115441   0.08754913
   0.01806714]
 [ 0.13064805 -0.09504803  0.2234054  ... -0.16325949 -0.10266946
  -0.19426823]
 [ 0.06755423 -0.17248109  0.17541368 ... -0.09177572  0.00855821
   0.12957056]
 ...
 [ 0.24439205 -0.18738632  0.16365345 ... -0.07598827 -0.04716949
  -0.2192129 ]
 [ 0.01130708  0.01074852 -0.08228029 ... -0.00130801 -0.24420989
   0.29326463]
 [-0.0523099  -0.21168521  0.32329524 ...  0.11451166 -0.0624956
  -0.05655364]]


In [7]:
from scipy.spatial import KDTree
# Build k-d tree
kdtree = KDTree(cluster_centers)

In [8]:
print(len(cluster_centers))

10000


In [11]:
cluster_features = {}
for i in range(len(cluster_centers)):
    nearest_1_to_4_indices = kdtree.query(cluster_centers[i].reshape(1, -1), k=5)[1][0]
    nearest_20_to_24_indices = kdtree.query(cluster_centers[i].reshape(1, -1), k=25)[1][0][20:]
    cluster_indices = [ nearest_1_to_4_indices, nearest_20_to_24_indices]
    cluster_features[i] = cluster_indices


In [16]:
print(f"Shape of cluster_features: {len(cluster_features)}")
print(cluster_features[0])


Shape of cluster_features: 10000
[array([   0, 4745, 6873, 9693, 9714]), array([7989,  418, 8337, 2752, 7298])]


In [17]:
# Find the indices of points that belong to cluster number i
cluster_i_indices = np.where(kmeans_labels == 1)[0]

In [18]:
print(cluster_i_indices)

[  8834  18442  21009  21694  24316  28134  52735  56047  76627  81932
  88821 101078 115898 121222 144642 146692 149629 174251 175267 178251
 181591 189704 191021 230825 243681 247558 249036 249636 253849 254601
 275897 276378 294587 294601 309607 312158 317218 320043 327732 333131
 334649 340105 344357 344397 354653 366456 366809 369732 388423 392331
 396454 410622 414528 417691 418930 423911 428651]


In [None]:
cluster_lists = []
for i in range(len(cluster_centers)):
    cluster_idx = i
    cluster_points = np.where(kmeans_labels == cluster_idx)[0].tolist()
    nearest_1_to_4_indices = cluster_features[i][1:5].tolist()
    nearest_20_to_24_indices = cluster_features[i][5:].tolist()
    cluster_lists.append([cluster_points, nearest_1_to_4_indices, nearest_20_to_24_indices])


In [19]:
cluster_lists = []
for i in range(len(cluster_centers)):
    cluster_idx = i
    cluster_points = np.where(kmeans_labels == cluster_idx)[0].tolist()
    nearest_cluster_indices = []
    for idx in cluster_features[i][0]:
        nearest_cluster_indices += np.where(kmeans_labels == idx)[0].tolist()
    far_cluster_indices = []
    for idx in cluster_features[i][1]:
        far_cluster_indices += np.where(kmeans_labels == idx)[0].tolist()
        
        
    # Store the points in separate lists
    anchor_points = cluster_points
    near_points = nearest_cluster_indices
    far_points = far_cluster_indices

    # Append the lists to the cluster_lists list
    cluster_lists.append([anchor_points, near_points, far_points])


In [21]:
print(cluster_lists[0])

[[2917, 22072, 22824, 29547, 42478, 60025, 73119, 80742, 111300, 131918, 134449, 137043, 153894, 177236, 186996, 196206, 206481, 255108, 269355, 279607, 281385, 296386, 307309, 312727, 316695, 319151, 323610, 343111, 345244, 362457, 366670, 371390, 374831, 407810, 428155], [2917, 22072, 22824, 29547, 42478, 60025, 73119, 80742, 111300, 131918, 134449, 137043, 153894, 177236, 186996, 196206, 206481, 255108, 269355, 279607, 281385, 296386, 307309, 312727, 316695, 319151, 323610, 343111, 345244, 362457, 366670, 371390, 374831, 407810, 428155, 3537, 16539, 17433, 34340, 39380, 43037, 62081, 70236, 73926, 77376, 102800, 122003, 129730, 134712, 148365, 154099, 181919, 189891, 194910, 197819, 202990, 237612, 252352, 263149, 271299, 285780, 300652, 311307, 351096, 353613, 375092, 377509, 399984, 435105, 441985, 3702, 4215, 11198, 12742, 26998, 32697, 38847, 40963, 71958, 79458, 83555, 83565, 89197, 90477, 100112, 103839, 125595, 137606, 152052, 152282, 152454, 160765, 167483, 174436, 175769, 1

In [34]:
new_cluster_lists = []
for row in cluster_lists:
    for i in range(0, len(row[0])):
        new_row = [row[0][i], row[1], row[2]]
        new_cluster_lists.append(new_row)

In [35]:
print(len(new_cluster_lists))

443757


In [36]:
sorted_list = sorted(new_cluster_lists, key=lambda x: x[0])


In [37]:
print(sorted_list[0][0],sorted_list[1][0],sorted_list[2][0])

0 1 2


In [38]:
with open("list_an_pn.pkl", "wb") as f:
    # Use pickle to dump the list to the file
    pickle.dump(sorted_list, f)