# Overclustering table

In [8]:
%load_ext autoreload
%autoreload 2

import os
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ['JAX_ENABLE_X64'] = '1'

cache_path = "../../cache/"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
import numpy as np
import pandas as pd
import pickle
import corc.utils
import sklearn
import tqdm

In [3]:
dataset_names = {
    "noisy_circles",
    "noisy_moons",
    "varied",
    "aniso",
    "blobs",
    "uniform_circle",
    "clusterlab10",
    "blobs1_8",
    "blobs1_16",
    "blobs1_32",
    "blobs1_64",
    "blobs2_8",
    "blobs2_16",
    "blobs2_32",
    "blobs2_64",
    "densired8",
    "densired16",
    "densired32",
    "densired64",
    "densired_soft_8",
    "densired_soft_16",
    "densired_soft_32",
    "densired_soft_64",
    "mnist8",
    "mnist16",
    "mnist32",
    "mnist64",
}
dataset_name = "densired8"

In [20]:
all_aris = dict()
for dataset_name in tqdm.tqdm(dataset_names):

    file_path = cache_path + f"stability/overclustering_{dataset_name}.pkl"
    if not os.path.exists(file_path):
        continue
    with open(file_path, 'rb') as file:
        models = pickle.load(file)
    X, y, tsne = corc.utils.load_dataset(dataset_name,cache_path=cache_path)
    n_classes = len(np.unique(y))

    ari_scores = list()
    for model in models:
        predictions = model.predict_with_target(X, target_number_classes = n_classes)
        ari_scores.append(sklearn.metrics.adjusted_rand_score(predictions, y))
    all_aris[dataset_name] = ari_scores


  0%|          | 0/27 [00:00<?, ?it/s]

100%|██████████| 27/27 [00:38<00:00,  1.43s/it]


In [26]:
df

Unnamed: 0,aniso,blobs,uniform_circle,densired_soft_8,densired16,varied,densired8,noisy_circles,noisy_moons
0,0.997002,0.982117,1.0,0.736422,0.637769,0.901862,0.687471,-0.000858,0.156012
1,0.994003,0.970377,1.0,0.796314,0.916628,0.896403,0.911436,0.039654,0.15628
2,0.994003,0.970378,1.0,0.800031,0.929779,0.904646,0.913243,0.988024,1.0
3,0.994003,0.964522,1.0,0.884712,0.999153,0.893675,0.98808,1.0,1.0
4,0.994003,0.562925,1.0,0.884895,0.999153,0.00276,0.915238,0.996,1.0
5,0.994003,0.967466,1.0,0.889245,0.999288,0.002963,0.91369,0.996,1.0
6,0.991018,0.964567,1.0,0.888503,0.992804,0.547516,0.990956,1.0,1.0
7,0.994003,0.961668,1.0,0.889069,0.999315,0.54744,0.984158,0.996,1.0
8,0.994003,0.564362,1.0,0.889373,0.82354,0.546302,0.985286,0.996,1.0


In [31]:
def highlight_top_3(df, n=3):
   def highlight_series(s):
     top_n = s.nlargest(n)
     return ['background-color: green' if v in top_n.values else '' for v in s]
   return df.style.apply(highlight_series, axis=0)

df = pd.DataFrame(all_aris)
highlight_top_3(df)

Unnamed: 0,aniso,blobs,uniform_circle,densired_soft_8,densired16,varied,densired8,noisy_circles,noisy_moons
0,0.997002,0.982117,1.0,0.736422,0.637769,0.901862,0.687471,-0.000858,0.156012
1,0.994003,0.970377,1.0,0.796314,0.916628,0.896403,0.911436,0.039654,0.15628
2,0.994003,0.970378,1.0,0.800031,0.929779,0.904646,0.913243,0.988024,1.0
3,0.994003,0.964522,1.0,0.884712,0.999153,0.893675,0.98808,1.0,1.0
4,0.994003,0.562925,1.0,0.884895,0.999153,0.00276,0.915238,0.996,1.0
5,0.994003,0.967466,1.0,0.889245,0.999288,0.002963,0.91369,0.996,1.0
6,0.991018,0.964567,1.0,0.888503,0.992804,0.547516,0.990956,1.0,1.0
7,0.994003,0.961668,1.0,0.889069,0.999315,0.54744,0.984158,0.996,1.0
8,0.994003,0.564362,1.0,0.889373,0.82354,0.546302,0.985286,0.996,1.0
