In [None]:
import numpy as np
from sklearn.cluster import KMeans, MiniBatchKMeans
from threadpoolctl import threadpool_limits

import sys
sys.path.append('..')

from .utils import find_files
from .configs import KMeansClusterConfig

In [None]:
# Collect the embedding files stored under ../data/embeddings_xs/.
# NOTE(review): the extension argument is the bare string 'npy' (the original
# spelling ('npy') is a parenthesised string, not a 1-tuple) -- confirm that
# find_files accepts a plain string here.
files = find_files('../data/embeddings_xs/', 'npy')

# Number of files discovered (shown as the cell output).
len(files)

In [None]:
# Load every embedding file and join the arrays along the first axis
# (np.concatenate's default) into a single array.
# Fail early with an actionable message: on an empty list np.concatenate
# raises an opaque "need at least one array to concatenate" ValueError.
if len(files) == 0:
    raise ValueError(
        "No embedding files were found; check the directory and extension "
        "passed to find_files."
    )

# Build the array in one expression so `embeddings` is only ever an ndarray,
# never a temporary list of per-file arrays.
embeddings = np.concatenate([np.load(f) for f in files])

print(embeddings.shape)
# Memory taken by the embeddings (bytes divided by 1024**3).
print(embeddings.nbytes / 1024**3, 'GB')

In [None]:
# Fit a full-batch KMeans model (1024 clusters, fixed seed) on all embeddings,
# with the BLAS thread pool capped at 12 threads for the duration of the fit.
# NOTE(review): the following cell rebinds `kmeans` to a MiniBatchKMeans model,
# so on a top-to-bottom run this fitted estimator is discarded.
with threadpool_limits(user_api='blas', limits=12):
    kmeans = KMeans(random_state=0, n_clusters=1024)
    kmeans.fit(embeddings)

In [None]:
# Fit MiniBatchKMeans (1024 clusters) on all embeddings, with the BLAS thread
# pool capped at 18 threads. Hyper-parameters come from KMeansClusterConfig.
# This rebinds `kmeans`, replacing any estimator fitted in an earlier cell.
with threadpool_limits(limits=18, user_api='blas'):
    kmeans = MiniBatchKMeans(
        n_clusters=1024,
        max_iter=KMeansClusterConfig.max_iter,
        batch_size=KMeansClusterConfig.batch_size,
        max_no_improvement=KMeansClusterConfig.max_no_improvement,
        n_init=KMeansClusterConfig.n_init,
        reassignment_ratio=KMeansClusterConfig.reassignment_ratio,
        verbose=1,
        compute_labels=True,
        init_size=None,
        # Seed the centroid initialisation and mini-batch sampling so the
        # model is reproducible across runs; matches the seed used for the
        # full KMeans fit.
        random_state=0,
    )

    kmeans.fit(embeddings)

In [None]:
# Report the inertia_ attribute of the fitted model currently bound to `kmeans`.
print(kmeans.inertia_)

In [None]:
import joblib

In [None]:
# Persist the fitted estimator bound to `kmeans` to disk with joblib.
# NOTE: the resulting file is pickle-based; only load it from trusted sources.
joblib.dump(value=kmeans, filename='../data/kmeans_1024.pkl')