# Overclustering table

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell execution and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
# XLA/JAX runtime switches, set in the environment before the project imports
# in the following cell (presumably so they are in place when JAX is first
# loaded — confirm against the project's import chain).
os.environ.update(
    {
        "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "JAX_ENABLE_X64": "1",
    }
)

# Directory prefix (relative to this notebook) under which cached models and
# datasets are looked up; the trailing slash matters because paths are built
# by plain string concatenation.
cache_path = "../../cache/"

In [2]:
import numpy as np
import pandas as pd
import pickle
import corc.utils
import sklearn
import tqdm
import corc.our_datasets

In [3]:
# Names of the datasets to evaluate, taken from the project's
# COMPLEX_DATASETS collection; get_all_aris() iterates over this global.
dataset_names = corc.our_datasets.COMPLEX_DATASETS

In [None]:
# Load the cached overclustering models of one example dataset for inspection.
dataset_name = "blobs2_32"
file_path = cache_path + f"stability/overclustering_{dataset_name}.pkl"
with open(file_path, mode="rb") as file:
    models = pickle.load(file)

In [None]:
def get_all_aris(mode='tmm'):
    """Compute the adjusted Rand index of every cached model, per dataset.

    Iterates over the notebook-level ``dataset_names``. For each dataset the
    pickled models are read from ``cache_path``, the dataset is loaded with
    ``corc.utils.load_dataset`` and every model is asked for predictions with
    the number of distinct ground-truth labels as target class count. Datasets
    without a cache file for the requested mode are skipped.

    Parameters
    ----------
    mode : str
        ``'tmm'`` reads ``stability/overclustering_<dataset>.pkl``;
        ``'gmm'`` reads ``stability/overclustering_<dataset>_gmm.pkl``.

    Returns
    -------
    pandas.DataFrame
        One column per dataset that had a cache file; row ``i`` holds the ARI
        of the ``i``-th model in that dataset's pickled collection.

    Raises
    ------
    NotImplementedError
        If ``mode`` is neither ``'tmm'`` nor ``'gmm'``.
    """
    # Validate the mode once, up front, instead of inside the dataset loop.
    if mode == 'tmm':
        suffix = ''
    elif mode == 'gmm':
        suffix = '_gmm'
    else:
        raise NotImplementedError(f"mode should be 'gmm' or 'tmm' not {mode}")

    # Imported explicitly: a bare `import sklearn` does not by itself make the
    # `sklearn.metrics` submodule available as an attribute.
    from sklearn.metrics import adjusted_rand_score

    all_aris = dict()
    for dataset_name in dataset_names:
        file_path = cache_path + f"stability/overclustering_{dataset_name}{suffix}.pkl"
        if not os.path.exists(file_path):
            # No cached models for this dataset/mode: leave it out of the table.
            continue
        # NOTE: unpickling runs arbitrary code; only read cache files you trust.
        with open(file_path, 'rb') as file:
            models = pickle.load(file)
        X, y, _tsne = corc.utils.load_dataset(dataset_name, cache_path=cache_path)
        n_classes = len(np.unique(y))

        all_aris[dataset_name] = [
            adjusted_rand_score(
                model.predict_with_target(X, target_number_classes=n_classes), y
            )
            for model in models
        ]
    return pd.DataFrame(all_aris)

In [None]:
# Build one ARI table per model family from the cached runs
# (columns: datasets, rows: models in stored order).
df_tmm = get_all_aris(mode='tmm')
df_gmm = get_all_aris(mode='gmm')

In [6]:
def highlight_top_3(df, n=3):
   def highlight_series(s):
     top_n = s.nlargest(n)
     return ['background-color: green' if v in top_n.values else '' for v in s]
   return df.style.apply(highlight_series, axis=0)

In [7]:
# GMM scores, with the three highest values of every dataset column highlighted.
highlight_top_3(df_gmm)

Unnamed: 0,blobs1_8,blobs1_16,blobs1_32,blobs2_8,blobs2_16,blobs2_32,blobs2_64,densired8,densired16,densired32,densired64,densired_soft_8,densired_soft_16,densired_soft_32,densired_soft_64,mnist8,mnist16,mnist32,mnist64
0,0.814423,0.920047,0.917967,0.572109,0.525852,0.365315,0.566471,0.648763,0.62744,0.495468,0.64202,0.502709,0.541895,0.407808,0.425392,0.889062,0.710529,0.615976,0.566612
1,0.63965,0.899501,0.712205,0.959739,0.854372,0.689198,0.373897,0.90802,0.916563,0.906647,0.916329,0.616039,0.525173,0.304359,0.227974,0.623094,0.555739,0.548065,0.420701
2,0.6299,0.558804,0.412534,0.807805,0.876167,0.682567,0.194747,0.909941,0.930681,0.808626,0.924748,0.603631,0.640347,0.570201,0.204136,0.662431,0.549071,0.512947,0.373812
3,0.232655,0.582265,0.523148,0.938953,0.681332,0.507696,0.250722,0.999584,1.0,0.804409,0.993287,0.700907,0.621934,0.356164,0.214128,0.599622,0.524694,0.416713,0.303265
4,0.298823,0.707846,0.430122,0.906359,0.852878,0.572709,0.001066,0.999584,1.0,1.0,0.993287,0.768248,0.832905,0.600331,0.228557,0.608077,0.494744,0.465637,0.051002
5,0.297672,0.502271,0.312582,0.933201,0.885998,0.613573,0.0,0.999584,1.0,0.931092,0.993287,0.747562,0.735223,0.667802,0.329901,0.651067,0.477122,0.396886,0.136353
6,0.440043,0.472384,0.155123,0.789923,0.536958,0.292272,0.0,0.999584,0.99918,0.927383,0.993789,0.536826,0.74125,0.795942,0.252562,0.59646,0.496604,0.301184,0.02452
7,0.367178,0.481279,0.151702,0.465022,0.771682,0.158333,0.0,0.999502,0.999315,0.928243,0.994337,0.511625,0.803407,0.756454,0.324911,0.595802,0.441954,0.117517,0.033171
8,0.356036,0.463524,0.095611,0.44307,0.582869,0.192373,0.0,0.996811,0.999426,1.0,0.994337,0.473312,0.834772,0.680902,0.620295,0.58745,0.609309,0.047356,0.185953


In [8]:
# TMM scores, with the three highest values of every dataset column highlighted.
highlight_top_3(df_tmm)

Unnamed: 0,blobs1_8,blobs1_16,blobs1_32,blobs1_64,blobs2_8,blobs2_16,blobs2_32,blobs2_64,densired8,densired16,densired32,densired64,densired_soft_8,densired_soft_16,densired_soft_32,densired_soft_64,mnist8,mnist16,mnist32,mnist64
0,0.775676,0.890631,0.857074,0.964115,0.619106,0.781662,0.515708,0.544242,0.687471,0.637769,0.571231,0.625237,0.736422,0.733834,0.750715,0.747331,0.883489,0.791376,0.753442,0.630026
1,0.658561,0.581141,0.845557,0.550005,0.757977,0.616129,0.876902,0.465187,0.911436,0.916628,0.911685,0.917171,0.796314,0.710559,0.683403,0.791607,0.773602,0.807896,0.777105,0.600599
2,0.61168,0.710024,0.565021,0.245518,0.922321,0.797421,0.696641,0.249065,0.913243,0.929779,0.808521,0.919783,0.800031,0.853632,0.714313,0.917523,0.666307,0.803425,0.81259,0.591537
3,0.410362,0.415039,0.354273,0.081056,0.870417,0.735078,0.449716,0.021155,0.98808,0.999153,0.808521,0.881878,0.884712,0.844114,0.854328,0.844104,0.771589,0.814932,0.641513,0.526682
4,0.314486,0.591087,0.085818,0.055507,0.866308,0.610421,0.504079,0.0,0.915238,0.999153,0.727185,0.918369,0.884895,0.850864,0.853959,0.842079,0.665501,0.736213,0.778578,0.62027
5,0.156415,0.395989,0.092861,0.029394,0.639056,0.678948,0.26249,0.3755,0.91369,0.999288,0.993328,0.894343,0.889245,0.849815,0.85295,0.689054,0.665063,0.666162,0.795096,0.611451
6,0.351681,0.306422,0.033905,0.0,0.748206,0.485853,0.397772,0.539496,0.990956,0.992804,0.992917,0.995852,0.888503,0.842696,0.868353,0.845843,0.662391,0.810379,0.649608,0.566011
7,0.176628,0.407914,0.038337,0.0,0.625236,0.187122,0.237228,0.269167,0.984158,0.999315,0.899923,0.986499,0.889069,0.854533,0.867293,0.84311,0.599919,0.676672,0.606006,0.519248
8,0.265082,0.140625,0.091169,0.178254,0.745674,0.178682,0.0,0.143935,0.985286,0.82354,0.99317,0.985841,0.889373,0.852529,0.857304,0.842752,0.664835,0.738509,0.710884,0.58856


In [9]:
# Stack the GMM rows and the TMM rows (each from position 2 onward) into one
# table and highlight the best three entries per dataset.
# NOTE(review): why the first two rows of each table are dropped is not
# visible in this notebook — confirm the intent.
highlight_top_3(
    pd.concat(
        [df_gmm.iloc[2:], df_tmm.iloc[2:]],
        ignore_index=True,
    )
)

Unnamed: 0,blobs1_8,blobs1_16,blobs1_32,blobs2_8,blobs2_16,blobs2_32,blobs2_64,densired8,densired16,densired32,densired64,densired_soft_8,densired_soft_16,densired_soft_32,densired_soft_64,mnist8,mnist16,mnist32,mnist64,blobs1_64
0,0.6299,0.558804,0.412534,0.807805,0.876167,0.682567,0.194747,0.909941,0.930681,0.808626,0.924748,0.603631,0.640347,0.570201,0.204136,0.662431,0.549071,0.512947,0.373812,
1,0.232655,0.582265,0.523148,0.938953,0.681332,0.507696,0.250722,0.999584,1.0,0.804409,0.993287,0.700907,0.621934,0.356164,0.214128,0.599622,0.524694,0.416713,0.303265,
2,0.298823,0.707846,0.430122,0.906359,0.852878,0.572709,0.001066,0.999584,1.0,1.0,0.993287,0.768248,0.832905,0.600331,0.228557,0.608077,0.494744,0.465637,0.051002,
3,0.297672,0.502271,0.312582,0.933201,0.885998,0.613573,0.0,0.999584,1.0,0.931092,0.993287,0.747562,0.735223,0.667802,0.329901,0.651067,0.477122,0.396886,0.136353,
4,0.440043,0.472384,0.155123,0.789923,0.536958,0.292272,0.0,0.999584,0.99918,0.927383,0.993789,0.536826,0.74125,0.795942,0.252562,0.59646,0.496604,0.301184,0.02452,
5,0.367178,0.481279,0.151702,0.465022,0.771682,0.158333,0.0,0.999502,0.999315,0.928243,0.994337,0.511625,0.803407,0.756454,0.324911,0.595802,0.441954,0.117517,0.033171,
6,0.356036,0.463524,0.095611,0.44307,0.582869,0.192373,0.0,0.996811,0.999426,1.0,0.994337,0.473312,0.834772,0.680902,0.620295,0.58745,0.609309,0.047356,0.185953,
7,0.61168,0.710024,0.565021,0.922321,0.797421,0.696641,0.249065,0.913243,0.929779,0.808521,0.919783,0.800031,0.853632,0.714313,0.917523,0.666307,0.803425,0.81259,0.591537,0.245518
8,0.410362,0.415039,0.354273,0.870417,0.735078,0.449716,0.021155,0.98808,0.999153,0.808521,0.881878,0.884712,0.844114,0.854328,0.844104,0.771589,0.814932,0.641513,0.526682,0.081056
9,0.314486,0.591087,0.085818,0.866308,0.610421,0.504079,0.0,0.915238,0.999153,0.727185,0.918369,0.884895,0.850864,0.853959,0.842079,0.665501,0.736213,0.778578,0.62027,0.055507


In [10]:
# Per-dataset ARI difference (TMM minus GMM). The subtraction aligns the two
# frames on their column labels, so a dataset present in only one of them
# comes out as an all-NaN column.
highlight_top_3(df_tmm - df_gmm)

Unnamed: 0,blobs1_16,blobs1_32,blobs1_64,blobs1_8,blobs2_16,blobs2_32,blobs2_64,blobs2_8,densired16,densired32,densired64,densired8,densired_soft_16,densired_soft_32,densired_soft_64,densired_soft_8,mnist16,mnist32,mnist64,mnist8
0,-0.029416,-0.060893,,-0.038747,0.255811,0.150393,-0.022229,0.046997,0.010329,0.075764,-0.016784,0.038708,0.191939,0.342907,0.321939,0.233713,0.080847,0.137466,0.063414,-0.005572
1,-0.31836,0.133353,,0.018911,-0.238243,0.187705,0.091291,-0.201762,6.5e-05,0.005039,0.000841,0.003416,0.185386,0.379044,0.563632,0.180275,0.252157,0.22904,0.179899,0.150508
2,0.15122,0.152487,,-0.01822,-0.078746,0.014074,0.054318,0.114516,-0.000902,-0.000106,-0.004965,0.003302,0.213285,0.144111,0.713387,0.196401,0.254354,0.299643,0.217725,0.003876
3,-0.167226,-0.168875,,0.177707,0.053746,-0.05798,-0.229567,-0.068536,-0.000847,0.004112,-0.111408,-0.011504,0.22218,0.498164,0.629976,0.183805,0.290238,0.2248,0.223417,0.171967
4,-0.116759,-0.344304,,0.015663,-0.242457,-0.06863,-0.001066,-0.040051,-0.000847,-0.272815,-0.074918,-0.084346,0.017959,0.253628,0.613522,0.116647,0.241468,0.312942,0.569267,0.057424
5,-0.106282,-0.21972,,-0.141257,-0.20705,-0.351083,0.3755,-0.294144,-0.000712,0.062236,-0.098944,-0.085894,0.114592,0.185147,0.359153,0.141683,0.18904,0.39821,0.475098,0.013995
6,-0.165962,-0.121218,,-0.088361,-0.051105,0.1055,0.539496,-0.041717,-0.006376,0.065534,0.002063,-0.008628,0.101446,0.072412,0.593281,0.351676,0.313775,0.348424,0.541491,0.065931
7,-0.073365,-0.113365,,-0.190549,-0.58456,0.078895,0.269167,0.160214,0.0,-0.02832,-0.007838,-0.015344,0.051126,0.110839,0.5182,0.377444,0.234718,0.488489,0.486077,0.004118
8,-0.322899,-0.004442,,-0.090953,-0.404187,-0.192373,0.143935,0.302604,-0.175886,-0.00683,-0.008496,-0.011524,0.017757,0.176402,0.222457,0.416061,0.129199,0.663528,0.402607,0.077385


In [14]:
# TMM-minus-GMM score difference with the three largest gains per dataset
# highlighted; columns are matched by label, unmatched ones become NaN.
highlight_top_3(df_tmm - df_gmm)

Unnamed: 0,blobs1_16,blobs1_32,blobs1_8,blobs2_16,blobs2_32,blobs2_64,blobs2_8,densired16,densired32,densired64,densired8,densired_soft_16,densired_soft_32,densired_soft_64,densired_soft_8,mnist16,mnist32,mnist64,mnist8
0,-0.029416,-0.060893,-0.038747,0.255811,,,0.046997,0.010329,0.075764,-0.016784,0.038708,0.191939,0.342907,0.321939,0.233713,0.080847,0.137466,0.063414,-0.005572
1,-0.31836,0.133353,0.018911,-0.238243,,,-0.201762,6.5e-05,0.005039,0.000841,0.003416,0.185386,0.379044,0.563632,0.180275,0.252157,0.22904,0.179899,0.150508
2,0.15122,0.152487,-0.01822,-0.078746,,,0.114516,-0.000902,-0.000106,-0.004965,0.003302,0.213285,0.144111,0.713387,0.196401,0.254354,0.299643,0.217725,0.003876
3,-0.167226,-0.168875,0.177707,0.053746,,,-0.068536,-0.000847,0.004112,-0.111408,-0.011504,0.22218,0.498164,0.629976,0.183805,0.290238,0.2248,0.223417,0.171967
4,-0.116759,-0.344304,0.015663,-0.242457,,,-0.040051,-0.000847,-0.272815,-0.074918,-0.084346,0.017959,0.253628,0.613522,0.116647,0.241468,0.312942,0.569267,0.057424
5,-0.106282,-0.21972,-0.141257,-0.20705,,,-0.294144,-0.000712,0.062236,-0.098944,-0.085894,0.114592,0.185147,0.359153,0.141683,0.18904,0.39821,0.475098,0.013995
6,-0.165962,-0.121218,-0.088361,-0.051105,,,-0.041717,-0.006376,0.065534,0.002063,-0.008628,0.101446,0.072412,0.593281,0.351676,0.313775,0.348424,0.541491,0.065931
7,-0.073365,-0.113365,-0.190549,-0.58456,,,0.160214,0.0,-0.02832,-0.007838,-0.015344,0.051126,0.110839,0.5182,0.377444,0.234718,0.488489,0.486077,0.004118
8,-0.322899,-0.004442,-0.090953,-0.404187,,,0.302604,-0.175886,-0.00683,-0.008496,-0.011524,0.017757,0.176402,0.222457,0.416061,0.129199,0.663528,0.402607,0.077385


In [45]:
def style_diff(val, max_shade_val=0.2):
    """Return a CSS style that colours a score difference by sign and size.

    Positive values get a green background, negative values a red one. The
    colour intensity grows linearly with ``abs(val)`` and saturates once the
    magnitude reaches ``max_shade_val``.

    Parameters
    ----------
    val : float
        The cell value to style.
    max_shade_val : float, default 0.2
        Magnitude at (and beyond) which the colour is fully saturated.

    Returns
    -------
    str or None
        A ``background-color``/``color`` CSS declaration, or ``None`` when
        ``val`` is neither positive nor negative (zero or NaN), which leaves
        the cell unstyled.
    """
    if not (val > 0 or val < 0):
        # Zero and NaN fail both comparisons: no styling for such cells.
        return None

    shade = int(255 * min(abs(val) / max_shade_val, 1))
    # Lowering the two "other" channels moves the cell from white to pure colour.
    faded = 255 - shade
    if val > 0:
        return f'background-color: rgb({faded}, 255, {faded}); color: black'
    return f'background-color: rgb(255, {faded}, {faded}); color: black'


In [46]:
# Shade every cell of the TMM-minus-GMM table by sign and size of the difference.
diff_styler = (df_tmm - df_gmm).style
# `Styler.applymap` is deprecated since pandas 2.1 in favour of `Styler.map`;
# use whichever element-wise method this pandas version provides.
elementwise_style = diff_styler.map if hasattr(diff_styler, "map") else diff_styler.applymap
elementwise_style(style_diff)

Unnamed: 0,noisy_circles,noisy_moons,varied,aniso,blobs,uniform_circle,clusterlab10,blobs1_8,blobs1_16,blobs1_32,blobs2_8,blobs2_16,blobs2_32,blobs2_64,densired8,densired16,densired32,densired64,densired_soft_8,densired_soft_16,densired_soft_32,densired_soft_64,mnist8,mnist16,mnist32,mnist64
0,0.000128,-0.341921,-0.045144,0.0,0.002949,0.0,0.0,-0.038747,-0.029416,-0.060893,0.046997,0.255811,0.150393,-0.022229,0.038708,0.010329,0.075764,-0.016784,0.233713,0.191939,0.342907,0.322129,-0.005572,0.080847,0.137466,0.063414
1,-0.055972,-0.84372,-0.059329,0.0,-1e-06,0.0,0.0,0.005933,-0.31836,0.133353,-0.201762,-0.238081,0.187705,0.578151,0.003416,6.5e-05,0.005039,0.000841,0.180275,0.185386,0.379044,0.56292,0.150508,0.252157,0.22904,0.179899
2,-0.011976,0.0,-0.048178,0.0,0.002912,0.0,0.0,-0.121095,0.034105,0.001198,0.114516,-0.078746,0.106072,0.569746,0.003302,-0.000902,-0.000106,-0.004965,0.196401,0.213285,0.144111,0.713335,0.003876,0.254354,0.299643,0.217725
3,0.0,0.0,-0.059161,0.023704,0.017242,0.0,0.008054,0.002232,-0.173974,-0.509388,-0.375713,-0.078009,-0.038221,0.128362,-0.011504,-0.000847,0.004112,-0.111408,0.183805,0.039297,0.385846,0.630156,0.171967,0.290238,0.2248,0.223417
4,-0.004,0.0,-0.952998,0.005961,-0.40164,0.0,0.0,0.025147,-0.436625,-0.35549,-0.3468,-0.449067,-0.371727,0.445616,-0.084346,-0.000847,-0.272815,-0.074918,0.116647,0.017959,-0.161142,0.613623,0.057424,0.241468,0.312942,0.569267
5,-0.004,0.0,-0.556081,0.008941,0.42221,0.0,-0.205517,0.003767,-0.228942,-0.038366,-0.424792,-0.400094,-0.337309,0.573,-0.085894,-0.000712,0.062236,-0.098944,0.141683,0.114592,-0.153545,0.587021,0.013995,0.18904,0.39821,0.475098
6,0.0,0.0,0.000951,0.002958,0.003009,0.0,0.028987,-0.102517,-0.384201,0.103558,-0.030941,-0.215009,-0.414069,0.393415,-0.008628,-0.006376,0.065534,0.002063,0.351676,0.101446,-0.795854,0.539709,0.065931,0.313775,0.348424,0.541491
7,-0.004,0.0,-0.405438,-0.002994,-0.002848,0.0,0.141778,-0.179199,-0.058146,-0.131691,-0.042699,-0.428192,-0.048715,-0.147007,-0.015344,0.0,-0.02832,-0.007307,0.377444,0.073982,-0.758126,-0.32592,0.004118,0.234718,0.488489,0.486077
8,-0.004,0.0,-0.397814,0.01193,-0.39725,0.0,0.021756,-0.021835,-0.099185,0.27495,0.155265,-0.428492,0.054129,-0.160103,-0.011524,-0.175886,-0.00683,-0.008496,0.416061,0.042507,-0.681476,-0.096251,0.077385,0.129199,0.663528,0.402607
