# Globals

In [None]:
import copy
import datetime
import os
from collections import defaultdict

import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb
import scipy.stats
import seaborn
import sklearn.metrics
import torchvision.datasets
from IPython.display import display
from tqdm.autonotebook import tqdm

In [None]:
# Directory for figure output. It is created eagerly; an already-existing
# directory is reused without error.
FIGS_DIR = "figs"
os.makedirs(name=FIGS_DIR, exist_ok=True)

In [None]:
# Validation datasets.
VALIDATION_DATASETS = ["imagenet", "imagenette", "imagewoof"]

# Encoders grouped by backbone. Names prefixed "random_" look like untrained
# baselines (assumption based on naming only — confirm).
RESNET50_MODELS = [
    "random_resnet50",
    "resnet50",
    "mocov3_resnet50",
    "dino_resnet50",
    "vicreg_resnet50",
    "clip_RN50",
]
VITB16_MODELS = [
    "random_vitb16",
    "vitb16",
    "mocov3_vit_base",
    "dino_vitb16",
    "timm_vit_base_patch16_224.mae",
    "mae_pretrain_vit_base_global",
    "clip_vitb16",
]
# Fine-tuned encoders, split by backbone, then combined.
FT_RESNET50_MODELS = ["ft_mocov3_resnet50", "ft_dino_resnet50", "ft_vicreg_resnet50"]
FT_VITB16_MODELS = ["ft_mocov3_vit_base", "ft_dino_vitb16", "mae_finetuned_vit_base_global"]
FT_MODELS = [*FT_RESNET50_MODELS, *FT_VITB16_MODELS]
# Every encoder, preceded by the "none" entry (no encoder).
ALL_MODELS = ["none", *RESNET50_MODELS, *VITB16_MODELS, *FT_MODELS]

# Same backbones, with each fine-tuned model placed right after its pretrained
# source. The CLIP encoders are not part of these orderings.
RESNET50_MODELS_INTERLEAVED = [
    "random_resnet50",
    "resnet50",
    "mocov3_resnet50", "ft_mocov3_resnet50",
    "dino_resnet50", "ft_dino_resnet50",
    "vicreg_resnet50", "ft_vicreg_resnet50",
]
VITB16_MODELS_INTERLEAVED = [
    "random_vitb16",
    "vitb16",
    "mocov3_vit_base", "ft_mocov3_vit_base",
    "dino_vitb16", "ft_dino_vitb16",
    "timm_vit_base_patch16_224.mae",
    "mae_pretrain_vit_base_global", "mae_finetuned_vit_base_global",
]

# Clustering algorithms under evaluation.
CLUSTERERS = [
    "KMeans",
    "LouvainCommunities",
    "AgglomerativeClustering",
    "AffinityPropagation",
    "SpectralClustering",
    "HDBSCAN",
    "OPTICS",
]
# A separate list object with the same entries (not an alias of CLUSTERERS).
ALL_CLUSTERERS = list(CLUSTERERS)
# Distance metrics considered by the distance-based clusterers.
DISTANCE_METRICS = [
    "euclidean",
    "l1",
    "chebyshev",
    "cosine",
    "arccos",
    "braycurtis",
    "canberra",
]

In [None]:
# Pretrained encoder name -> fine-tuned encoder name, and the inverse mapping.
# Most fine-tuned names are the pretrained name with an "ft_" prefix; the MAE
# pair is named differently, so it is added separately.
PRE2FT = dict(
    (pretrained, "ft_" + pretrained)
    for pretrained in (
        "mocov3_resnet50",
        "dino_resnet50",
        "vicreg_resnet50",
        "mocov3_vit_base",
        "dino_vitb16",
    )
)
PRE2FT.update(mae_pretrain_vit_base_global="mae_finetuned_vit_base_global")
FT2PRE = dict(zip(PRE2FT.values(), PRE2FT.keys()))

In [None]:
# Line-style string per validation dataset ("-.", "--" and ":" are matplotlib
# line-style codes).
DATASET2LS = dict(imagenet="-.", imagenette="--", imagewoof=":")

In [None]:
# Baseline configuration. The "all" entry holds preprocessing settings shared by
# every clusterer (the string "None" disables a reduction step). Each other entry
# holds one clusterer's own settings and is deep-copied per model into the
# BEST_PARAMS* tables.
# NOTE(review): the table-building code never merges "all" into those per-clusterer
# copies; presumably whatever consumes the tables does — confirm.
DEFAULT_PARAMS = dict(
    all=dict(
        dim_reducer="None",
        dim_reducer_man="None",
        zscore=False,
        normalize=False,
        zscore2=False,
        ndim_correction=False,
    ),
    KMeans=dict(clusterer="KMeans"),
    LouvainCommunities=dict(
        clusterer="LouvainCommunities",
        louvain_resolution=1.0,
        louvain_threshold=1e-7,
        louvain_remove_self_loops=False,
        # NOTE(review): "l2" is not one of DISTANCE_METRICS (which uses "euclidean").
        distance_metric="l2",
    ),
    AffinityPropagation=dict(
        clusterer="AffinityPropagation",
        affinity_damping=0.9,
        affinity_conv_iter=15,
    ),
    SpectralClustering=dict(
        clusterer="SpectralClustering",
        spectral_assigner="cluster_qr",
        spectral_affinity="nearest_neighbors",
        spectral_n_neighbors=10,
        spectral_n_components=None,
    ),
    AgglomerativeClustering=dict(
        clusterer="AgglomerativeClustering",
        distance_metric="euclidean",
        aggclust_linkage="ward",
    ),
    HDBSCAN=dict(
        clusterer="HDBSCAN",
        hdbscan_method="eom",
        min_samples=5,
        max_samples=0.2,
        distance_metric="euclidean",
    ),
    OPTICS=dict(
        clusterer="OPTICS",
        optics_method="xi",
        optics_xi=0.05,
        distance_metric="euclidean",
    ),
)

## Set best params

These values were found by the hyperparameter search in hpsearch.ipynb.

### Num dims

In [None]:
# Original per-clusterer, per-encoder parameter table (covers the ResNet-50 and
# ViT-B/16 encoders only). Every entry starts as an independent deep copy of
# DEFAULT_PARAMS[clusterer]; the sections below then override the dimensionality-
# reduction choices. Later .update() calls overwrite earlier ones, so order matters.
# NOTE(review): DEFAULT_PARAMS["all"] is not merged into these entries here;
# presumably the consumer of this table applies it — confirm.
# Keys used below: "dim_reducer" with either "ndim_reduced" (fixed number of dims)
# or "pca_variance" (fraction of variance explained) selects a PCA step;
# "dim_reducer_man" with "ndim_reduced_man" selects a UMAP step (presumably a
# separate manifold-learning stage, going by the "_man" suffix — confirm).
# The string "None" disables a step; "zscore" presumably standardises features
# before PCA — confirm.
models = RESNET50_MODELS + VITB16_MODELS
BEST_PARAMS = {clusterer: {model: copy.deepcopy(DEFAULT_PARAMS[clusterer]) for model in models} for clusterer in ALL_CLUSTERERS}

# KMeans
# Use UMAP (num dims unimportant; we select 50d for consistency) for every encoder except
# - clip_RN50 : a little better to use PCA with 500d than UMAP. UMAP beats PCA if you
#   reduce the PCA dims below 500.
# - clip_vitb16 : same behaviour as clip_RN50
# - timm_vit_base_patch16_224.mae : best is PCA 0.85 variance explained. Need at least
#   200 PCA dims, and PCA perf beats UMAP throughout

for model in RESNET50_MODELS + VITB16_MODELS:
    if model.startswith("clip") or model == "timm_vit_base_patch16_224.mae":
        continue
    BEST_PARAMS["KMeans"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50})

# Exceptions: fixed-dimension PCA for the CLIP encoders, variance-based PCA for MAE.
BEST_PARAMS["KMeans"]["clip_RN50"].update({"dim_reducer": "PCA", "ndim_reduced": 500, "zscore": True, "pca_variance": None})
BEST_PARAMS["KMeans"]["clip_vitb16"].update({"dim_reducer": "PCA", "ndim_reduced": 500, "zscore": True, "pca_variance": None})
BEST_PARAMS["KMeans"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.85, "zscore": True, "ndim_reduced": None}
)

# AffinityPropagation
# Use PCA with 10 dims for every encoder except
# - resnet50 (supervised) : original embeddings, no reduction (AMI=0.62);
#   perf gets worse if they are whitened (AMI=0.55) and although the perf increases
#   as num dims are reduced it doesn't quite recover. PCA perf peaks at 10-20 dim (AMI=0.57).
# - dino_resnet50 : does marginally better at UMAP 50 (AMI=0.52495) than PCA 10 (AMI=0.5044)
# - timm_vit_base_patch16_224.mae : PCA 0.95 variance explained (AMI=0.303).
#   Definite improvement from 10 to 20 dims, but not much improvement above that.
# NOTE(review): despite the dino_resnet50 bullet above, the code below configures
# dino_resnet50 with PCA at 95% variance explained, not UMAP 50.

for model in models:
    if model in ["resnet50", "dino_resnet50", "timm_vit_base_patch16_224.mae"]:
        continue
    BEST_PARAMS["AffinityPropagation"][model].update(
        {
            "dim_reducer": "PCA",
            "ndim_reduced": 10,
            "zscore": True,
            "pca_variance": None,
            "dim_reducer_man": "None",
        }
    )

BEST_PARAMS["AffinityPropagation"]["resnet50"].update({"dim_reducer": "None", "dim_reducer_man": "None", "zscore": False})
BEST_PARAMS["AffinityPropagation"]["dino_resnet50"].update(
    {
        "dim_reducer": "PCA",
        "pca_variance": 0.95,
        "zscore": True,
        "ndim_reduced": None,
        "dim_reducer_man": "None",
    }
)
BEST_PARAMS["AffinityPropagation"]["timm_vit_base_patch16_224.mae"].update(
    {
        "dim_reducer": "PCA",
        "pca_variance": 0.95,
        "zscore": True,
        "ndim_reduced": None,
        "dim_reducer_man": "None",
    }
)

# AgglomerativeClustering
# Use UMAP (num dims unimportant; we select 50d for consistency) for every encoder except
# - timm_vit_base_patch16_224.mae : PCA 0.98 variance explained (i.e. nearly all
#   dimensions kept), which is not noticeably better than using 500 dim PCA but there is
#   an increase compared to using less than 500d.

for model in models:
    if model == "timm_vit_base_patch16_224.mae":
        continue
    BEST_PARAMS["AgglomerativeClustering"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"})

BEST_PARAMS["AgglomerativeClustering"]["timm_vit_base_patch16_224.mae"].update(
    {
        "dim_reducer": "PCA",
        "pca_variance": 0.98,
        "zscore": True,
        "ndim_reduced": None,
        "dim_reducer_man": "None",
    }
)

# HDBSCAN
# Use UMAP for every encoder except
# - timm_vit_base_patch16_224.mae : PCA 0.95 variance explained (AMI=0.085) which is
#   not noticeably better than PCA with 50 dim

for model in models:
    if model in ["timm_vit_base_patch16_224.mae"]:
        continue
    BEST_PARAMS["HDBSCAN"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"})

BEST_PARAMS["HDBSCAN"]["timm_vit_base_patch16_224.mae"].update(
    {
        "dim_reducer": "PCA",
        "pca_variance": 0.95,
        "zscore": True,
        "ndim_reduced": None,
        "dim_reducer_man": "None",
    }
)

# OPTICS
# Use UMAP for every encoder, no exceptions necessary
# (LouvainCommunities and SpectralClustering keep their DEFAULT_PARAMS entries here.)
for model in models:
    BEST_PARAMS["OPTICS"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"})

In [None]:
# v1: an independent snapshot of the original table, tagged with its version string.
BEST_PARAMS_v1 = copy.deepcopy(BEST_PARAMS)
BEST_PARAMS_v1.update(_version="v1.0")

In [None]:
# v2: independent copy of the original table with a few revised dimensionality choices.
BEST_PARAMS_v2 = copy.deepcopy(BEST_PARAMS)
BEST_PARAMS_v2["_version"] = "v2.0"

print("Updating dim choices for new method")
# Updated dim choices
# (changed to this when we swapped to using weighted average instead of straight
# average between Imagenet-1k, Imagenette, Imagewoof)

# Changed KMeans clip_RN50 from PCA 500 to UMAP 50, so it uses fewer dimensions
# (probably more stable than using 500-d which is what PCA needs to marginally beat UMAP).
# The PCA step is disabled with the string sentinel "None". DEFAULT_PARAMS and every
# other entry in these tables use that same value, so consumers treat it consistently.
# (Assumption: consumers compare against the string "None" — confirm.)
BEST_PARAMS_v2["KMeans"]["clip_RN50"].update({"dim_reducer": "None", "ndim_reduced": None, "zscore": False, "pca_variance": None})
BEST_PARAMS_v2["KMeans"]["clip_RN50"].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50})
# Changed KMeans MAE from PCA 85% to PCA 200
# (since we see perf above plateaus at 200-d, there is no point going above that)
BEST_PARAMS_v2["KMeans"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "zscore": True, "ndim_reduced": 200, "pca_variance": None}
)
# Changed KMeans clip_vitb16 from PCA 500 to PCA 75%
# (gives a notably better train set AMI measurement above)
BEST_PARAMS_v2["KMeans"]["clip_vitb16"].update({"dim_reducer": "PCA", "zscore": True, "pca_variance": 0.75, "ndim_reduced": None})

# Changed AffinityPropagation dino_resnet50 from PCA 95% to PCA 10
# (performance is basically equal, so no point using higher-dim space;
# could have done UMAP 50 instead with basically equal train AMI to PCA 10,
# but didn't for consistency with other models)
BEST_PARAMS_v2["AffinityPropagation"]["dino_resnet50"].update(
    {"dim_reducer": "PCA", "zscore": True, "ndim_reduced": 10, "pca_variance": None}
)
# Changed AffinityPropagation MAE from PCA 95% to PCA 100
BEST_PARAMS_v2["AffinityPropagation"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "zscore": True, "ndim_reduced": 100, "pca_variance": None}
)

In [None]:
# v3: rebuilt from DEFAULT_PARAMS (not copied from an earlier version). It covers
# every encoder in ALL_MODELS, including "none", the random and the fine-tuned ones.
print(
    "Updating dim choices to use Affinity Prop dim results found with 0.9 damping,"
    " prefering PCA reduction by percentage variance explained"
)
BEST_PARAMS_v3 = {clusterer: {model: copy.deepcopy(DEFAULT_PARAMS[clusterer]) for model in ALL_MODELS} for clusterer in ALL_CLUSTERERS}
BEST_PARAMS_v3["_version"] = "v3.0"

# KMeans
# UMAP-50 for the trained encoders. PCA is used for the "none" and random entries,
# the CLIP encoders and timm MAE, which are set individually below.
# NOTE(review): "none" is not in this loop's list, so `model == "none"` is always False.
for model in RESNET50_MODELS + VITB16_MODELS + FT_MODELS:
    if model == "none" or model.startswith("random") or model.startswith("clip") or model == "timm_vit_base_patch16_224.mae":
        continue
    BEST_PARAMS_v3["KMeans"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50})

# "none" also gets image_size=32 (presumably the input resolution used when no encoder
# is applied — confirm).
BEST_PARAMS_v3["KMeans"]["none"].update({"image_size": 32, "dim_reducer": "PCA", "pca_variance": 0.98, "zscore": True})
BEST_PARAMS_v3["KMeans"]["random_resnet50"].update({"dim_reducer": "PCA", "pca_variance": 0.95, "zscore": True})
BEST_PARAMS_v3["KMeans"]["random_vitb16"].update({"dim_reducer": "PCA", "ndim_reduced": 100, "zscore": True})

BEST_PARAMS_v3["KMeans"]["clip_RN50"].update({"dim_reducer": "PCA", "pca_variance": 0.85, "zscore": True})
BEST_PARAMS_v3["KMeans"]["clip_vitb16"].update({"dim_reducer": "PCA", "pca_variance": 0.75, "zscore": True})
BEST_PARAMS_v3["KMeans"]["timm_vit_base_patch16_224.mae"].update({"dim_reducer": "PCA", "pca_variance": 0.85, "zscore": True})

# AffinityPropagation
# The v3 dim choices were found with damping 0.9, so it is pinned for every model.
for model in ALL_MODELS:
    BEST_PARAMS_v3["AffinityPropagation"][model].update({"affinity_damping": 0.9})

# UMAP-50 for these trained encoders and all fine-tuned ones.
for model in [
    "resnet50",
    "clip_RN50",
    "vitb16",
    "mocov3_vit_base",
    "mae_pretrain_vit_base_global",
    "dino_vitb16",
    "clip_vitb16",
] + FT_MODELS:
    BEST_PARAMS_v3["AffinityPropagation"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50})
# PaCMAP-50 for the self-supervised ResNet-50 encoders.
for model in ["mocov3_resnet50", "vicreg_resnet50", "dino_resnet50"]:
    BEST_PARAMS_v3["AffinityPropagation"][model].update(
        {
            "dim_reducer_man": "PaCMAP",
            "ndim_reduced_man": 50,
            "dim_reducer_man_nn": None,
        }
    )

# PCA by variance explained for the "none" and random baselines.
BEST_PARAMS_v3["AffinityPropagation"]["none"].update({"image_size": 32, "dim_reducer": "PCA", "pca_variance": 0.8, "zscore": True})
BEST_PARAMS_v3["AffinityPropagation"]["random_resnet50"].update({"dim_reducer": "PCA", "pca_variance": 0.99, "zscore": True})
BEST_PARAMS_v3["AffinityPropagation"]["random_vitb16"].update({"dim_reducer": "PCA", "pca_variance": 0.98, "zscore": True})

# timm MAE: PCA keeping 99% of variance. This belongs to the AffinityPropagation
# table (this section). Writing it into the KMeans table would silently override the
# KMeans MAE setting and leave AffinityPropagation MAE without any reduction.
BEST_PARAMS_v3["AffinityPropagation"]["timm_vit_base_patch16_224.mae"].update({"dim_reducer": "PCA", "pca_variance": 0.99, "zscore": True})

# AgglomerativeClustering
# UMAP-50 for trained encoders. PCA by variance explained is used for the "none"
# and random baselines and for timm MAE.
for model in ALL_MODELS:
    if model == "none" or model.startswith("random") or model == "timm_vit_base_patch16_224.mae":
        continue
    BEST_PARAMS_v3["AgglomerativeClustering"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"})

BEST_PARAMS_v3["AgglomerativeClustering"]["none"].update({"image_size": 32, "dim_reducer": "PCA", "pca_variance": 0.75, "zscore": True})
BEST_PARAMS_v3["AgglomerativeClustering"]["random_resnet50"].update({"dim_reducer": "PCA", "pca_variance": 0.98, "zscore": True})
BEST_PARAMS_v3["AgglomerativeClustering"]["random_vitb16"].update({"dim_reducer": "PCA", "pca_variance": 0.85, "zscore": True})
BEST_PARAMS_v3["AgglomerativeClustering"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.98, "zscore": True}
)

# HDBSCAN
# UMAP-50 for every model (including "none" and random) except timm MAE, which uses
# PCA keeping 95% of variance.
for model in ALL_MODELS:
    if model in ["timm_vit_base_patch16_224.mae"]:
        continue
    BEST_PARAMS_v3["HDBSCAN"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"})

BEST_PARAMS_v3["HDBSCAN"]["none"].update({"image_size": 32})
BEST_PARAMS_v3["HDBSCAN"]["timm_vit_base_patch16_224.mae"].update({"dim_reducer": "PCA", "pca_variance": 0.95, "zscore": True})

# OPTICS - TODO
# Use UMAP for every encoder, no exceptions necessary (not checked raw or random)
# NOTE(review): unlike KMeans/AffinityPropagation/AgglomerativeClustering/HDBSCAN,
# the OPTICS "none" entry is not given image_size=32 in v3.
for model in ALL_MODELS:
    BEST_PARAMS_v3["OPTICS"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"})

In [None]:
# v4: rebuilt from DEFAULT_PARAMS for every encoder in ALL_MODELS. Every
# "pca_variance" set below is at most 0.95, and mae_pretrain_vit_base_global
# now gets its own PCA settings for KMeans and AgglomerativeClustering.
print("Updating dim choices to use Affinity Prop dim results found with 0.9 damping," " stop PCA at 95%")
BEST_PARAMS_v4 = {clusterer: {model: copy.deepcopy(DEFAULT_PARAMS[clusterer]) for model in ALL_MODELS} for clusterer in ALL_CLUSTERERS}
BEST_PARAMS_v4["_version"] = "v4.0"
# Give the "none" entry image_size=32 under every clusterer. Keys starting with "_"
# (e.g. "_version") hold metadata, not per-model tables, so they are skipped.
for clusterer in BEST_PARAMS_v4:
    if clusterer.startswith("_"):
        continue
    BEST_PARAMS_v4[clusterer]["none"].update({"image_size": 32})

# KMeans
# UMAP-50 for trained encoders; the excluded models get PCA settings below.
# NOTE(review): "none" is not in this loop's list, so `model == "none"` is always False.
for model in RESNET50_MODELS + VITB16_MODELS + FT_MODELS:
    if (
        model == "none"
        or model.startswith("random")
        or model.startswith("clip")
        or model == "timm_vit_base_patch16_224.mae"
        or model == "mae_pretrain_vit_base_global"
    ):
        continue
    BEST_PARAMS_v4["KMeans"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50})

BEST_PARAMS_v4["KMeans"]["none"].update({"dim_reducer": "PCA", "pca_variance": 0.90, "zscore": True})
BEST_PARAMS_v4["KMeans"]["random_resnet50"].update({"dim_reducer": "PCA", "pca_variance": 0.95, "zscore": True})
BEST_PARAMS_v4["KMeans"]["random_vitb16"].update({"dim_reducer": "PCA", "ndim_reduced": 100, "zscore": True})

BEST_PARAMS_v4["KMeans"]["clip_RN50"].update({"dim_reducer": "PCA", "pca_variance": 0.85, "zscore": True})
BEST_PARAMS_v4["KMeans"]["clip_vitb16"].update({"dim_reducer": "PCA", "pca_variance": 0.75, "zscore": True})
BEST_PARAMS_v4["KMeans"]["timm_vit_base_patch16_224.mae"].update({"dim_reducer": "PCA", "pca_variance": 0.95, "zscore": True})
BEST_PARAMS_v4["KMeans"]["mae_pretrain_vit_base_global"].update({"dim_reducer": "PCA", "pca_variance": 0.9, "zscore": True})

# AffinityPropagation
# Damping pinned to 0.9 for every model (the dim choices were found with it).
for model in ALL_MODELS:
    BEST_PARAMS_v4["AffinityPropagation"][model].update({"affinity_damping": 0.9})

for model in [
    "resnet50",
    "clip_RN50",
    "vitb16",
    "mocov3_vit_base",
    "mae_pretrain_vit_base_global",
    "dino_vitb16",
    "clip_vitb16",
] + FT_MODELS:
    BEST_PARAMS_v4["AffinityPropagation"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50})
for model in ["mocov3_resnet50", "vicreg_resnet50", "dino_resnet50"]:
    # NOTE(review): these now use UMAP-50, the same setting as the list above
    # (the v3 table used PaCMAP for them); the "tbc" marker suggests this choice
    # is provisional.
    # tbc
    BEST_PARAMS_v4["AffinityPropagation"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer_man_nn": None})

BEST_PARAMS_v4["AffinityPropagation"]["none"].update({"dim_reducer": "PCA", "pca_variance": 0.8, "zscore": True})
BEST_PARAMS_v4["AffinityPropagation"]["random_resnet50"].update({"dim_reducer": "PCA", "pca_variance": 0.9, "zscore": True})
BEST_PARAMS_v4["AffinityPropagation"]["random_vitb16"].update({"dim_reducer": "PCA", "pca_variance": 0.9, "zscore": True})
BEST_PARAMS_v4["AffinityPropagation"]["timm_vit_base_patch16_224.mae"].update({"dim_reducer": "PCA", "ndim_reduced": 200, "zscore": True})

# AgglomerativeClustering
# UMAP-50 for trained encoders. PCA is used for the "none" and random baselines and both MAE variants.
for model in ALL_MODELS:
    if model == "none" or model.startswith("random") or model == "timm_vit_base_patch16_224.mae" or model == "mae_pretrain_vit_base_global":
        continue
    BEST_PARAMS_v4["AgglomerativeClustering"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"})

BEST_PARAMS_v4["AgglomerativeClustering"]["none"].update({"dim_reducer": "PCA", "ndim_reduced": 200, "zscore": True})
BEST_PARAMS_v4["AgglomerativeClustering"]["random_resnet50"].update({"dim_reducer": "PCA", "ndim_reduced": 200, "zscore": True})
BEST_PARAMS_v4["AgglomerativeClustering"]["random_vitb16"].update({"dim_reducer": "PCA", "pca_variance": 0.85, "zscore": True})
BEST_PARAMS_v4["AgglomerativeClustering"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.90, "zscore": True}
)
BEST_PARAMS_v4["AgglomerativeClustering"]["mae_pretrain_vit_base_global"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.85, "zscore": True}
)

# HDBSCAN
# UMAP-50 for every model except timm MAE (PCA keeping 95% of variance).
for model in ALL_MODELS:
    if model in ["timm_vit_base_patch16_224.mae"]:
        continue
    BEST_PARAMS_v4["HDBSCAN"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"})

BEST_PARAMS_v4["HDBSCAN"]["timm_vit_base_patch16_224.mae"].update({"dim_reducer": "PCA", "pca_variance": 0.95, "zscore": True})

# OPTICS - TODO
# Use UMAP for every encoder, no exceptions necessary (not checked raw or random)
for model in ALL_MODELS:
    BEST_PARAMS_v4["OPTICS"][model].update({"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"})

In [None]:
# v5: independent copy of v4, plus tuned preprocessing for SpectralClustering
# (this cell) and LouvainCommunities. The v4 table builder never sets reduction
# options for SpectralClustering (only image_size for "none"), so these entries
# start from DEFAULT_PARAMS values.
# Each entry below either disables both reduction steps ("None") or uses
# z-scored PCA with a variance fraction ("pca_variance") or a fixed size
# ("ndim_reduced").
BEST_PARAMS_v5 = copy.deepcopy(BEST_PARAMS_v4)
BEST_PARAMS_v5["_version"] = "v5.0"

BEST_PARAMS_v5["SpectralClustering"]["none"].update({"zscore": False, "dim_reducer": "None", "dim_reducer_man": "None"})
BEST_PARAMS_v5["SpectralClustering"]["random_resnet50"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "ndim_reduced": 200,
    }
)
BEST_PARAMS_v5["SpectralClustering"]["resnet50"].update({"zscore": False, "dim_reducer": "None", "dim_reducer_man": "None"})
BEST_PARAMS_v5["SpectralClustering"]["mocov3_resnet50"].update({"zscore": False, "dim_reducer": "None", "dim_reducer_man": "None"})
BEST_PARAMS_v5["SpectralClustering"]["dino_resnet50"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.8,
    }
)
BEST_PARAMS_v5["SpectralClustering"]["vicreg_resnet50"].update({"zscore": False, "dim_reducer": "None", "dim_reducer_man": "None"})
BEST_PARAMS_v5["SpectralClustering"]["clip_RN50"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.9,
    }
)
BEST_PARAMS_v5["SpectralClustering"]["random_vitb16"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.95,
    }
)
BEST_PARAMS_v5["SpectralClustering"]["vitb16"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.7,
    }
)
BEST_PARAMS_v5["SpectralClustering"]["mocov3_vit_base"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.85,
    }
)
BEST_PARAMS_v5["SpectralClustering"]["dino_vitb16"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.9,
    }
)
# Both MAE pretrained encoders: z-scored, but no dimensionality reduction.
BEST_PARAMS_v5["SpectralClustering"]["timm_vit_base_patch16_224.mae"].update(
    {"zscore": True, "dim_reducer": "None", "dim_reducer_man": "None"}
)
BEST_PARAMS_v5["SpectralClustering"]["mae_pretrain_vit_base_global"].update(
    {"zscore": True, "dim_reducer": "None", "dim_reducer_man": "None"}
)
BEST_PARAMS_v5["SpectralClustering"]["clip_vitb16"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.7,
    }
)
# Fine-tuned encoders.
BEST_PARAMS_v5["SpectralClustering"]["ft_mocov3_resnet50"].update({"zscore": False, "dim_reducer": "None", "dim_reducer_man": "None"})
BEST_PARAMS_v5["SpectralClustering"]["ft_dino_resnet50"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.8,
    }
)
BEST_PARAMS_v5["SpectralClustering"]["ft_vicreg_resnet50"].update({"zscore": False, "dim_reducer": "None", "dim_reducer_man": "None"})
BEST_PARAMS_v5["SpectralClustering"]["ft_mocov3_vit_base"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.95,
    }
)
BEST_PARAMS_v5["SpectralClustering"]["ft_dino_vitb16"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.9,
    }
)
BEST_PARAMS_v5["SpectralClustering"]["mae_finetuned_vit_base_global"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.75,
    }
)

In [None]:
# v5 LouvainCommunities preprocessing, per encoder. Each entry uses one of three
# options:
# - no reduction,
# - z-scored PCA (by variance fraction or fixed size), or
# - UMAP-50 with PCA disabled.
BEST_PARAMS_v5["LouvainCommunities"]["none"].update({"zscore": False, "dim_reducer": "None", "dim_reducer_man": "None"})
BEST_PARAMS_v5["LouvainCommunities"]["random_resnet50"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.7,
    }
)
BEST_PARAMS_v5["LouvainCommunities"]["resnet50"].update({"dim_reducer": "None", "dim_reducer_man": "UMAP", "ndim_reduced_man": 50})
BEST_PARAMS_v5["LouvainCommunities"]["mocov3_resnet50"].update({"dim_reducer": "None", "dim_reducer_man": "UMAP", "ndim_reduced_man": 50})
BEST_PARAMS_v5["LouvainCommunities"]["dino_resnet50"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.75,
    }
)
BEST_PARAMS_v5["LouvainCommunities"]["vicreg_resnet50"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.9,
    }
)
# NOTE(review): clip_RN50 and clip_vitb16 are not given Louvain-specific settings in
# this cell, so they keep their DEFAULT_PARAMS-derived values.
BEST_PARAMS_v5["LouvainCommunities"]["random_vitb16"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.75,
    }
)
BEST_PARAMS_v5["LouvainCommunities"]["vitb16"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "ndim_reduced": 10,
    }
)
BEST_PARAMS_v5["LouvainCommunities"]["mocov3_vit_base"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.75,
    }
)
BEST_PARAMS_v5["LouvainCommunities"]["dino_vitb16"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.75,
    }
)
BEST_PARAMS_v5["LouvainCommunities"]["timm_vit_base_patch16_224.mae"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "pca_variance": 0.75,
    }
)
BEST_PARAMS_v5["LouvainCommunities"]["mae_pretrain_vit_base_global"].update(
    {"zscore": True, "dim_reducer": "None", "dim_reducer_man": "None"}
)
# Fine-tuned encoders.
BEST_PARAMS_v5["LouvainCommunities"]["ft_mocov3_resnet50"].update(
    {"dim_reducer": "None", "dim_reducer_man": "UMAP", "ndim_reduced_man": 50}
)
BEST_PARAMS_v5["LouvainCommunities"]["ft_dino_resnet50"].update({"dim_reducer": "None", "dim_reducer_man": "UMAP", "ndim_reduced_man": 50})
BEST_PARAMS_v5["LouvainCommunities"]["ft_vicreg_resnet50"].update(
    {"dim_reducer": "None", "dim_reducer_man": "UMAP", "ndim_reduced_man": 50}
)  # adjusted 20 -> 50
BEST_PARAMS_v5["LouvainCommunities"]["ft_mocov3_vit_base"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "ndim_reduced": 100,
    }
)
BEST_PARAMS_v5["LouvainCommunities"]["ft_dino_vitb16"].update(
    {
        "zscore": True,
        "dim_reducer": "PCA",
        "dim_reducer_man": "None",
        "ndim_reduced": 10,
    }
)
BEST_PARAMS_v5["LouvainCommunities"]["mae_finetuned_vit_base_global"].update(
    {"dim_reducer": "None", "dim_reducer_man": "UMAP", "ndim_reduced_man": 50}
)  # adjusted 10 -> 50

### Agglomerative specific settings

In [None]:
# v1 AgglomerativeClustering distance metric and linkage, grouped by setting.
# Models not listed keep the DEFAULT_PARAMS choice (euclidean distance, ward linkage).
for model in ["resnet50", "mocov3_resnet50", "vicreg_resnet50", "vitb16", "timm_vit_base_patch16_224.mae"]:
    BEST_PARAMS_v1["AgglomerativeClustering"][model].update(distance_metric="euclidean", aggclust_linkage="ward")
for model in ["dino_resnet50", "clip_RN50", "dino_vitb16"]:
    BEST_PARAMS_v1["AgglomerativeClustering"][model].update(distance_metric="euclidean", aggclust_linkage="average")
for model in ["mocov3_vit_base", "clip_vitb16"]:
    BEST_PARAMS_v1["AgglomerativeClustering"][model].update(distance_metric="chebyshev", aggclust_linkage="average")

In [None]:
# vicreg_resnet50 is the only change from v1 to v2
for model in ["resnet50", "mocov3_resnet50", "vitb16", "timm_vit_base_patch16_224.mae"]:
    BEST_PARAMS_v2["AgglomerativeClustering"][model].update(
        {
            "distance_metric": "euclidean",
            "aggclust_linkage": "ward",
        }
    )
for model in ["vicreg_resnet50", "dino_resnet50", "clip_RN50", "dino_vitb16"]:
    BEST_PARAMS_v2["AgglomerativeClustering"][model].update(
        {
            "distance_metric": "euclidean",
            "aggclust_linkage": "average",
        }
    )
for model in ["mocov3_vit_base", "clip_vitb16"]:
    BEST_PARAMS_v2["AgglomerativeClustering"][model].update(
        {
            "distance_metric": "chebyshev",
            "aggclust_linkage": "average",
        }
    )

In [None]:
# v3 AgglomerativeClustering distance metric and linkage, grouped by setting.
# mae_pretrain_vit_base_global is not listed, so it keeps the DEFAULT_PARAMS choice
# (euclidean distance, ward linkage).
for model in ["none", "resnet50", "mocov3_resnet50", "vitb16"] + FT_MODELS:
    BEST_PARAMS_v3["AgglomerativeClustering"][model].update(distance_metric="euclidean", aggclust_linkage="ward")
for model in ["vicreg_resnet50", "dino_resnet50", "clip_RN50", "dino_vitb16"]:
    BEST_PARAMS_v3["AgglomerativeClustering"][model].update(distance_metric="euclidean", aggclust_linkage="average")
for model in ["mocov3_vit_base", "clip_vitb16", "random_resnet50", "random_vitb16"]:
    BEST_PARAMS_v3["AgglomerativeClustering"][model].update(distance_metric="chebyshev", aggclust_linkage="average")
for model in ["timm_vit_base_patch16_224.mae"]:
    BEST_PARAMS_v3["AgglomerativeClustering"][model].update(distance_metric="cosine", aggclust_linkage="average")

In [None]:
# TODO:
# - mae_pretrain_vit_base_global
# - clip_vitb16 (leaving as-is for now)
# (Both are still assigned a setting below; the TODO presumably means the choice is
# not yet verified — confirm.)
#
# Every model first gets a "tbd" placeholder as a sentinel for a missed assignment.
# The four groups below together cover all of ALL_MODELS, so no placeholder survives.
for model in ALL_MODELS:
    BEST_PARAMS_v4["AgglomerativeClustering"][model].update(distance_metric="tbd", aggclust_linkage="tbd")

for model in ["resnet50", "mocov3_resnet50", "vitb16"] + FT_MODELS:
    BEST_PARAMS_v4["AgglomerativeClustering"][model].update(distance_metric="euclidean", aggclust_linkage="ward")
for model in ["vicreg_resnet50", "dino_resnet50", "clip_RN50", "dino_vitb16"]:
    BEST_PARAMS_v4["AgglomerativeClustering"][model].update(distance_metric="euclidean", aggclust_linkage="average")
for model in ["mocov3_vit_base", "clip_vitb16", "random_resnet50", "random_vitb16"]:
    BEST_PARAMS_v4["AgglomerativeClustering"][model].update(distance_metric="chebyshev", aggclust_linkage="average")
for model in ["none", "timm_vit_base_patch16_224.mae", "mae_pretrain_vit_base_global"]:
    BEST_PARAMS_v4["AgglomerativeClustering"][model].update(distance_metric="cosine", aggclust_linkage="average")

In [None]:
# TODO:
# - mae_pretrain_vit_base_global
# - clip_vitb16 (leaving as-is for now)
# (Both are still assigned a setting below; the TODO presumably means the choice is
# not yet verified — confirm.)
#
# Same grouping as the v4 settings. The "tbd" placeholder acts as a sentinel, and
# the four groups below together cover all of ALL_MODELS.
for model in ALL_MODELS:
    BEST_PARAMS_v5["AgglomerativeClustering"][model].update(distance_metric="tbd", aggclust_linkage="tbd")

for model in ["resnet50", "mocov3_resnet50", "vitb16"] + FT_MODELS:
    BEST_PARAMS_v5["AgglomerativeClustering"][model].update(distance_metric="euclidean", aggclust_linkage="ward")
for model in ["vicreg_resnet50", "dino_resnet50", "clip_RN50", "dino_vitb16"]:
    BEST_PARAMS_v5["AgglomerativeClustering"][model].update(distance_metric="euclidean", aggclust_linkage="average")
for model in ["mocov3_vit_base", "clip_vitb16", "random_resnet50", "random_vitb16"]:
    BEST_PARAMS_v5["AgglomerativeClustering"][model].update(distance_metric="chebyshev", aggclust_linkage="average")
for model in ["none", "timm_vit_base_patch16_224.mae", "mae_pretrain_vit_base_global"]:
    BEST_PARAMS_v5["AgglomerativeClustering"][model].update(distance_metric="cosine", aggclust_linkage="average")

In [None]:
# For every version, derive two variants of the AgglomerativeClustering table:
# "AC w/ C" and "AC w/o C". "C" is presumably the number of clusters, since the
# w/o variant is given distance thresholds instead.
# Each variant is its own deep copy, so configuring one never affects the other or
# the source table.
BEST_PARAMS_v1["AC w/ C"], BEST_PARAMS_v1["AC w/o C"] = (copy.deepcopy(BEST_PARAMS_v1["AgglomerativeClustering"]) for _ in range(2))
BEST_PARAMS_v2["AC w/ C"], BEST_PARAMS_v2["AC w/o C"] = (copy.deepcopy(BEST_PARAMS_v2["AgglomerativeClustering"]) for _ in range(2))
BEST_PARAMS_v3["AC w/ C"], BEST_PARAMS_v3["AC w/o C"] = (copy.deepcopy(BEST_PARAMS_v3["AgglomerativeClustering"]) for _ in range(2))
BEST_PARAMS_v4["AC w/ C"], BEST_PARAMS_v4["AC w/o C"] = (copy.deepcopy(BEST_PARAMS_v4["AgglomerativeClustering"]) for _ in range(2))
BEST_PARAMS_v5["AC w/ C"], BEST_PARAMS_v5["AC w/o C"] = (copy.deepcopy(BEST_PARAMS_v5["AgglomerativeClustering"]) for _ in range(2))

In [None]:
# "AC w/ C" variant: every model gets an explicit distance threshold of None
# (presumably so the cluster count, rather than a threshold, decides when merging stops).
for model in BEST_PARAMS_v1["AC w/ C"]:
    BEST_PARAMS_v1["AC w/ C"][model]["aggclust_dist_thresh"] = None
for model in BEST_PARAMS_v2["AC w/ C"]:
    BEST_PARAMS_v2["AC w/ C"][model]["aggclust_dist_thresh"] = None
for model in BEST_PARAMS_v3["AC w/ C"]:
    BEST_PARAMS_v3["AC w/ C"][model]["aggclust_dist_thresh"] = None
for model in BEST_PARAMS_v4["AC w/ C"]:
    BEST_PARAMS_v4["AC w/ C"][model]["aggclust_dist_thresh"] = None
for model in BEST_PARAMS_v5["AC w/ C"]:
    BEST_PARAMS_v5["AC w/ C"][model]["aggclust_dist_thresh"] = None

In [None]:
# "AC w/o C" variant (v2 onwards; v1 is left unchanged): enable the second z-scoring
# mode "average" and the dimensionality correction for every model.
for model in BEST_PARAMS_v2["AC w/o C"]:
    BEST_PARAMS_v2["AC w/o C"][model]["zscore2"] = "average"
    BEST_PARAMS_v2["AC w/o C"][model]["ndim_correction"] = True
for model in BEST_PARAMS_v3["AC w/o C"]:
    BEST_PARAMS_v3["AC w/o C"][model]["zscore2"] = "average"
    BEST_PARAMS_v3["AC w/o C"][model]["ndim_correction"] = True
for model in BEST_PARAMS_v4["AC w/o C"]:
    BEST_PARAMS_v4["AC w/o C"][model]["zscore2"] = "average"
    BEST_PARAMS_v4["AC w/o C"][model]["ndim_correction"] = True
for model in BEST_PARAMS_v5["AC w/o C"]:
    BEST_PARAMS_v5["AC w/o C"][model]["zscore2"] = "average"
    BEST_PARAMS_v5["AC w/o C"][model]["ndim_correction"] = True

In [None]:
# Run AgglomerativeClustering experiments with number of clusters unknown.
# v1 distance thresholds:
#   resnet50, mocov3_resnet50, vicreg_resnet50, vitb16  -> 20.0
#   dino_resnet50, clip_RN50, mocov3_vit_base           ->  1.0
#   dino_vitb16                                         ->  2.0
#   clip_vitb16                                         ->  0.5
#   timm_vit_base_patch16_224.mae                       -> 200.0
for _thresh, _models in (
    (20.0, ["resnet50", "mocov3_resnet50", "vicreg_resnet50", "vitb16"]),
    (1.0, ["dino_resnet50", "clip_RN50", "mocov3_vit_base"]),
):
    for model in _models:
        BEST_PARAMS_v1["AC w/o C"][model]["aggclust_dist_thresh"] = _thresh
for _name, _thresh in (
    ("dino_vitb16", 2.0),
    ("clip_vitb16", 0.5),
    ("timm_vit_base_patch16_224.mae", 200.0),
):
    BEST_PARAMS_v1["AC w/o C"][_name]["aggclust_dist_thresh"] = _thresh
del _thresh, _models, _name

In [None]:
# v2 "AC w/o C" distance thresholds, applied in listed order.
for _model, _thresh in {
    "resnet50": 2.0,
    "mocov3_resnet50": 10.0,
    "vicreg_resnet50": 0.5,
    "dino_resnet50": 0.5,
    "clip_RN50": 0.5,
    "vitb16": 2.0,
    "mocov3_vit_base": 1.0,
    "timm_vit_base_patch16_224.mae": 5.0,
    "dino_vitb16": 0.2,
    "clip_vitb16": 1.0,
}.items():
    BEST_PARAMS_v2["AC w/o C"][_model]["aggclust_dist_thresh"] = _thresh
del _model, _thresh

In [None]:
# v3 "AC w/o C" distance thresholds, applied in listed order.
for _model, _thresh in {
    "none": 10.0,
    "random_resnet50": 10.0,
    "resnet50": 2.0,
    "mocov3_resnet50": 10.0,
    "dino_resnet50": 0.5,
    "vicreg_resnet50": 0.5,
    "clip_RN50": 0.5,
    "random_vitb16": 2.0,
    "vitb16": 2.0,
    "mocov3_vit_base": 1.0,
    "dino_vitb16": 0.2,
    "timm_vit_base_patch16_224.mae": 0.5,
    "clip_vitb16": 1.0,
    "ft_mocov3_resnet50": 1.0,
    "ft_dino_resnet50": 2.0,
    "ft_vicreg_resnet50": 2.0,
    "ft_mocov3_vit_base": 2.0,
    "ft_dino_vitb16": 2.0,
}.items():
    BEST_PARAMS_v3["AC w/o C"][_model]["aggclust_dist_thresh"] = _thresh
del _model, _thresh

In [None]:
# v4.0 "AC w/o C" distance thresholds.
# Still TODO at this version:
#   none, random_resnet50, timm_vit_base_patch16_224.mae,
#   mae_pretrain_vit_base_global, clip_vitb16 (leave as-is),
#   ft_mocov3_resnet50 (tbc), mae_finetuned_vit_base_global
for _model, _thresh in {
    "resnet50": 2.0,
    "mocov3_resnet50": 10.0,
    "dino_resnet50": 0.5,
    "vicreg_resnet50": 0.5,
    "random_vitb16": 2.0,
    "vitb16": 2.0,
    "mocov3_vit_base": 1.0,
    "dino_vitb16": 0.2,
    "ft_mocov3_resnet50": 2.0,  # tbc
    "ft_dino_resnet50": 2.0,
    "ft_vicreg_resnet50": 2.0,
    "ft_mocov3_vit_base": 2.0,
    "ft_dino_vitb16": 2.0,
}.items():
    BEST_PARAMS_v4["AC w/o C"][_model]["aggclust_dist_thresh"] = _thresh
del _model, _thresh
BEST_PARAMS_v4["_version"] = "v4.0"

In [None]:
# v4.1 "AC w/o C" distance thresholds.
# Still TODO at this version:
#   none, timm_vit_base_patch16_224.mae (tbc),
#   mae_pretrain_vit_base_global, clip_vitb16 (leave as-is)
for _model, _thresh in {
    "random_resnet50": 10.0,
    "resnet50": 2.0,
    "mocov3_resnet50": 10.0,
    "dino_resnet50": 0.5,
    "vicreg_resnet50": 0.5,
    "clip_RN50": 0.5,
    "random_vitb16": 2.0,
    "vitb16": 2.0,
    "mocov3_vit_base": 1.0,
    "dino_vitb16": 0.2,
    "timm_vit_base_patch16_224.mae": 0.5,  # tbc
    "clip_vitb16": 1.0,
    "ft_mocov3_resnet50": 2.0,
    "ft_dino_resnet50": 2.0,
    "ft_vicreg_resnet50": 2.0,
    "ft_mocov3_vit_base": 2.0,
    "ft_dino_vitb16": 2.0,
    "mae_finetuned_vit_base_global": 2.0,
}.items():
    BEST_PARAMS_v4["AC w/o C"][_model]["aggclust_dist_thresh"] = _thresh
del _model, _thresh
BEST_PARAMS_v4["_version"] = "v4.1"

In [None]:
# v4.4 "AC w/o C" distance thresholds (every model set explicitly).
for _model, _thresh in {
    "none": 0.71,
    "random_resnet50": 10.0,
    "resnet50": 2.0,
    "mocov3_resnet50": 10.0,
    "dino_resnet50": 0.5,
    "vicreg_resnet50": 0.5,
    "clip_RN50": 0.5,
    "random_vitb16": 2.0,
    "vitb16": 2.0,
    "mocov3_vit_base": 1.0,
    "dino_vitb16": 0.2,
    "timm_vit_base_patch16_224.mae": 0.71,
    "mae_pretrain_vit_base_global": 0.71,
    "clip_vitb16": 1.0,
    "ft_mocov3_resnet50": 2.0,
    "ft_dino_resnet50": 2.0,
    "ft_vicreg_resnet50": 2.0,
    "ft_mocov3_vit_base": 2.0,
    "ft_dino_vitb16": 2.0,
    "mae_finetuned_vit_base_global": 2.0,
}.items():
    BEST_PARAMS_v4["AC w/o C"][_model]["aggclust_dist_thresh"] = _thresh
del _model, _thresh
BEST_PARAMS_v4["_version"] = "v4.4"

In [None]:
# v5.0 "AC w/o C" distance thresholds (same values as v4.4).
for _model, _thresh in {
    "none": 0.71,
    "random_resnet50": 10.0,
    "resnet50": 2.0,
    "mocov3_resnet50": 10.0,
    "dino_resnet50": 0.5,
    "vicreg_resnet50": 0.5,
    "clip_RN50": 0.5,
    "random_vitb16": 2.0,
    "vitb16": 2.0,
    "mocov3_vit_base": 1.0,
    "dino_vitb16": 0.2,
    "timm_vit_base_patch16_224.mae": 0.71,
    "mae_pretrain_vit_base_global": 0.71,
    "clip_vitb16": 1.0,
    "ft_mocov3_resnet50": 2.0,
    "ft_dino_resnet50": 2.0,
    "ft_vicreg_resnet50": 2.0,
    "ft_mocov3_vit_base": 2.0,
    "ft_dino_vitb16": 2.0,
    "mae_finetuned_vit_base_global": 2.0,
}.items():
    BEST_PARAMS_v5["AC w/o C"][_model]["aggclust_dist_thresh"] = _thresh
del _model, _thresh

### Affinity Propagation

In [None]:
# Baseline affinity_damping applied to every model, per hyperparameter version.
# NOTE(review): v5 is not included in this loop — confirm each v5 model is set
# explicitly elsewhere.
for _params, _damping in (
    (BEST_PARAMS_v1, 0.5),
    (BEST_PARAMS_v2, 0.5),
    (BEST_PARAMS_v3, 0.9),
    (BEST_PARAMS_v4, 0.9),
):
    for model in _params["AffinityPropagation"]:
        _params["AffinityPropagation"][model]["affinity_damping"] = _damping
del _params, _damping

In [None]:
# v3 per-model affinity_damping overrides, applied in listed order.
for _model, _damping in {
    "none": 0.85,
    "random_resnet50": 0.5,
    "resnet50": 0.9,
    "mocov3_resnet50": 0.8,
    "dino_resnet50": 0.8,
    "vicreg_resnet50": 0.75,
    "clip_RN50": 0.85,
    "random_vitb16": 0.7,
    "vitb16": 0.9,
    "mocov3_vit_base": 0.75,
    "dino_vitb16": 0.85,
    "timm_vit_base_patch16_224.mae": 0.5,
    # Original note: "To match" (the reference model is not stated).
    "mae_pretrain_vit_base_global": 0.9,
    "clip_vitb16": 0.95,
    # Matches supervised/fine-tuned resnet50.
    "ft_mocov3_resnet50": 0.9,
    "ft_vicreg_resnet50": 0.9,
    "ft_dino_vitb16": 0.9,
    # Original note says "Match supervised/ft resnet50" although this is a ViT
    # model. NOTE(review): confirm the intended reference (vitb16 is also 0.9).
    "ft_mocov3_vit_base": 0.9,
    "mae_finetuned_vit_base_global": 0.9,
}.items():
    BEST_PARAMS_v3["AffinityPropagation"][_model]["affinity_damping"] = _damping
del _model, _damping

In [None]:
# v4 per-model affinity_damping, applied in listed order.
for _model, _damping in {
    "none": 0.85,
    "random_resnet50": 0.9,
    "resnet50": 0.9,
    "mocov3_resnet50": 0.75,
    "dino_resnet50": 0.9,
    "vicreg_resnet50": 0.8,
    "clip_RN50": 0.85,
    "random_vitb16": 0.95,
    "vitb16": 0.9,
    "mocov3_vit_base": 0.75,
    "dino_vitb16": 0.85,
    "timm_vit_base_patch16_224.mae": 0.6,
    "mae_pretrain_vit_base_global": 0.6,
    "clip_vitb16": 0.95,
    "ft_mocov3_resnet50": 0.95,
    "ft_dino_resnet50": 0.9,
    "ft_vicreg_resnet50": 0.9,
    "ft_mocov3_vit_base": 0.95,
    "ft_dino_vitb16": 0.9,
    "mae_finetuned_vit_base_global": 0.9,
}.items():
    BEST_PARAMS_v4["AffinityPropagation"][_model]["affinity_damping"] = _damping
del _model, _damping

In [None]:
# v5 per-model affinity_damping (same values as v4), applied in listed order.
for _model, _damping in {
    "none": 0.85,
    "random_resnet50": 0.9,
    "resnet50": 0.9,
    "mocov3_resnet50": 0.75,
    "dino_resnet50": 0.9,
    "vicreg_resnet50": 0.8,
    "clip_RN50": 0.85,
    "random_vitb16": 0.95,
    "vitb16": 0.9,
    "mocov3_vit_base": 0.75,
    "dino_vitb16": 0.85,
    "timm_vit_base_patch16_224.mae": 0.6,
    "mae_pretrain_vit_base_global": 0.6,
    "clip_vitb16": 0.95,
    "ft_mocov3_resnet50": 0.95,
    "ft_dino_resnet50": 0.9,
    "ft_vicreg_resnet50": 0.9,
    "ft_mocov3_vit_base": 0.95,
    "ft_dino_vitb16": 0.9,
    "mae_finetuned_vit_base_global": 0.9,
}.items():
    BEST_PARAMS_v5["AffinityPropagation"][_model]["affinity_damping"] = _damping
del _model, _damping

### HDBSCAN

In [None]:
# v1 HDBSCAN: euclidean distance with the "eom" method for every
# ResNet-50 and ViT-B/16 model.
for model in [*RESNET50_MODELS, *VITB16_MODELS]:
    BEST_PARAMS_v1["HDBSCAN"][model].update(
        {"distance_metric": "euclidean", "hdbscan_method": "eom"}
    )

v2 selection: the distance metric and HDBSCAN method chosen for each model, with the AMI it achieved.

|    | model                         | distance_metric   | hdbscan_method   |      AMI |
|---:|:------------------------------|:------------------|:-----------------|---------:|
|  0 | resnet50                      | euclidean         | eom              | 0.828368 |
|  1 | mocov3_resnet50               | euclidean         | eom              | 0.531644 |
|  2 | vicreg_resnet50               | l1                | eom              | 0.472324 |
|  3 | dino_resnet50                 | l1                | eom              | 0.503147 |
|  4 | clip_RN50                     | l1                | eom              | 0.461363 |
|  5 | vitb16                        | chebyshev         | eom              | 0.906110 |
|  6 | mocov3_vit_base               | euclidean         | eom              | 0.629966 |
|  7 | timm_vit_base_patch16_224.mae | euclidean         | eom              | 0.070495 |
|  8 | dino_vitb16                   | l1                | eom              | 0.691547 |
|  9 | clip_vitb16                   | l1                | eom              | 0.592489 |

In [None]:
# v2 HDBSCAN: start every ResNet-50 / ViT-B/16 model at euclidean + "eom",
# then apply the per-model distance metrics from the v2 selection table.
for model in [*RESNET50_MODELS, *VITB16_MODELS]:
    BEST_PARAMS_v2["HDBSCAN"][model].update(
        {"distance_metric": "euclidean", "hdbscan_method": "eom"}
    )
for model in ("vicreg_resnet50", "dino_resnet50", "clip_RN50", "dino_vitb16", "clip_vitb16"):
    BEST_PARAMS_v2["HDBSCAN"][model]["distance_metric"] = "l1"
BEST_PARAMS_v2["HDBSCAN"]["vitb16"]["distance_metric"] = "chebyshev"

In [None]:
# v3 HDBSCAN: every listed model uses the "eom" method; the distance metric
# is chosen per group.
for _metric, _models in (
    (
        "euclidean",
        ["resnet50", "mocov3_resnet50", "mocov3_vit_base", "timm_vit_base_patch16_224.mae"],
    ),
    (
        "l1",
        [
            "random_resnet50",
            "vicreg_resnet50",
            "dino_resnet50",
            "clip_RN50",
            "random_vitb16",
            "dino_vitb16",
            "clip_vitb16",
        ],
    ),
    ("chebyshev", ["vitb16"]),
):
    for model in _models:
        BEST_PARAMS_v3["HDBSCAN"][model].update(
            {"distance_metric": _metric, "hdbscan_method": "eom"}
        )
del _metric, _models

In [None]:
# v4 HDBSCAN: per-model distance metric; every model uses the "eom" method.
for _model, _metric in {
    "none": "euclidean",
    "random_resnet50": "l1",
    "resnet50": "euclidean",
    "mocov3_resnet50": "euclidean",
    "dino_resnet50": "l1",
    "vicreg_resnet50": "l1",
    "clip_RN50": "l1",
    "random_vitb16": "l1",
    "vitb16": "chebyshev",
    "mocov3_vit_base": "euclidean",
    "dino_vitb16": "l1",
    "timm_vit_base_patch16_224.mae": "euclidean",
    "mae_pretrain_vit_base_global": "l1",
    "clip_vitb16": "l1",
    "ft_mocov3_resnet50": "chebyshev",
    "ft_dino_resnet50": "l1",
    "ft_vicreg_resnet50": "euclidean",
    "ft_mocov3_vit_base": "chebyshev",
    "ft_dino_vitb16": "chebyshev",
    "mae_finetuned_vit_base_global": "chebyshev",
}.items():
    BEST_PARAMS_v4["HDBSCAN"][_model]["distance_metric"] = _metric
    BEST_PARAMS_v4["HDBSCAN"][_model]["hdbscan_method"] = "eom"
del _model, _metric

In [None]:
# v5 HDBSCAN (same as v4): per-model distance metric; every model uses "eom".
for _model, _metric in {
    "none": "euclidean",
    "random_resnet50": "l1",
    "resnet50": "euclidean",
    "mocov3_resnet50": "euclidean",
    "dino_resnet50": "l1",
    "vicreg_resnet50": "l1",
    "clip_RN50": "l1",
    "random_vitb16": "l1",
    "vitb16": "chebyshev",
    "mocov3_vit_base": "euclidean",
    "dino_vitb16": "l1",
    "timm_vit_base_patch16_224.mae": "euclidean",
    "mae_pretrain_vit_base_global": "l1",
    "clip_vitb16": "l1",
    "ft_mocov3_resnet50": "chebyshev",
    "ft_dino_resnet50": "l1",
    "ft_vicreg_resnet50": "euclidean",
    "ft_mocov3_vit_base": "chebyshev",
    "ft_dino_vitb16": "chebyshev",
    "mae_finetuned_vit_base_global": "chebyshev",
}.items():
    BEST_PARAMS_v5["HDBSCAN"][_model]["distance_metric"] = _metric
    BEST_PARAMS_v5["HDBSCAN"][_model]["hdbscan_method"] = "eom"
del _model, _metric

### Spectral

In [None]:
# v5 SpectralClustering: per-model spectral_n_neighbors.
for _model, _n_neighbors in {
    "none": 10,
    "random_resnet50": 50,
    "resnet50": 20,
    "mocov3_resnet50": 30,
    "dino_resnet50": 10,
    "vicreg_resnet50": 10,
    "clip_RN50": 30,
    "random_vitb16": 50,
    "vitb16": 30,
    "mocov3_vit_base": 50,
    "dino_vitb16": 10,
    "timm_vit_base_patch16_224.mae": 10,
    "mae_pretrain_vit_base_global": 30,
    "clip_vitb16": 20,
    "ft_mocov3_resnet50": 30,
    "ft_dino_resnet50": 20,
    "ft_vicreg_resnet50": 20,
    "ft_mocov3_vit_base": 50,
    "ft_dino_vitb16": 50,
    "mae_finetuned_vit_base_global": 50,
}.items():
    BEST_PARAMS_v5["SpectralClustering"][_model]["spectral_n_neighbors"] = _n_neighbors
del _model, _n_neighbors

### Louvain

In [None]:
# v5 LouvainCommunities: per-model (distance_metric, louvain_remove_self_loops).
# NOTE(review): clip_RN50 and clip_vitb16 are not set here — confirm intended.
for _model, (_metric, _drop_self_loops) in {
    "none": ("l2", False),
    "random_resnet50": ("l2", False),
    "resnet50": ("l2", False),
    "mocov3_resnet50": ("l1", False),
    "dino_resnet50": ("l2", False),
    "vicreg_resnet50": ("l2", False),
    "random_vitb16": ("l1", True),
    "vitb16": ("chebyshev", False),
    "mocov3_vit_base": ("l2", False),
    "dino_vitb16": ("l2", False),
    "timm_vit_base_patch16_224.mae": ("l1", False),
    "mae_pretrain_vit_base_global": ("l2", False),
    "ft_mocov3_resnet50": ("l1", False),
    "ft_dino_resnet50": ("l2", False),
    "ft_vicreg_resnet50": ("chebyshev", False),
    "ft_mocov3_vit_base": ("l2", True),
    "ft_dino_vitb16": ("l2", False),
    "mae_finetuned_vit_base_global": ("l2", False),
}.items():
    BEST_PARAMS_v5["LouvainCommunities"][_model].update(
        {"distance_metric": _metric, "louvain_remove_self_loops": _drop_self_loops}
    )
del _model, _metric, _drop_self_loops

In [None]:
# v5 LouvainCommunities: per-model louvain_resolution.
# Entries marked "tied" had tied results, so the default 1.0 was used.
for _model, _resolution in {
    "none": 1.2,
    "random_resnet50": 2.0,
    "resnet50": 1.0,  # tied
    "mocov3_resnet50": 1.4,
    "dino_resnet50": 1.0,
    "vicreg_resnet50": 1.0,
    "random_vitb16": 3.0,
    "vitb16": 1.1,
    "mocov3_vit_base": 1.0,
    "dino_vitb16": 1.0,
    "timm_vit_base_patch16_224.mae": 1.1,
    "mae_pretrain_vit_base_global": 1.1,
    "ft_mocov3_resnet50": 1.0,  # tied
    "ft_dino_resnet50": 1.0,  # tied
    "ft_vicreg_resnet50": 1.0,  # tied
    "ft_mocov3_vit_base": 1.0,
    "ft_dino_vitb16": 1.2,
    "mae_finetuned_vit_base_global": 1.0,  # tied
}.items():
    BEST_PARAMS_v5["LouvainCommunities"][_model]["louvain_resolution"] = _resolution
del _model, _resolution

### Finally, set overall hparams

In [None]:
# Select the hyperparameter set used by the rest of the notebook.
# NOTE: this binds a second name to the same dict object (no copy), so any
# later mutation through BEST_PARAMS also changes BEST_PARAMS_v5.
BEST_PARAMS = BEST_PARAMS_v5

## Utility functions

In [None]:
def categorical_cmap(nc, nsc, cmap="tab10", continuous=False):
    """
    Build a colormap holding ``nsc`` shades of each of ``nc`` base colours.

    Adapted from https://stackoverflow.com/a/47232942/1960959

    Parameters
    ----------
    nc : int
        Number of categories (base colours).
    nsc : int
        Number of shades generated per category.
    cmap : str, default=tab10
        Source colormap the base colours are drawn from.
    continuous : bool, default=False
        Whether ``cmap`` is continuous. If True, the ``nc`` base colours are
        sampled evenly across the whole colormap; otherwise ``cmap`` is
        treated as categorical and its first ``nc`` entries are used.

    Returns
    -------
    matplotlib.colors.ListedColormap
        Colormap of ``nc * nsc`` colours, grouped by category. Within each
        group, the shades go from the base colour towards saturation 0.25
        and value (brightness) 1.

    Raises
    ------
    ValueError
        If ``nc`` exceeds the number of colours in ``cmap``.
    """
    source = plt.get_cmap(cmap)
    if nc > source.N:
        raise ValueError("Too many categories for colormap.")
    sample_at = np.linspace(0, 1, nc) if continuous else np.arange(nc, dtype=int)
    palette = np.zeros((nc * nsc, 3))
    for idx, rgba in enumerate(source(sample_at)):
        base_hsv = matplotlib.colors.rgb_to_hsv(rgba[:3])
        shades_hsv = np.tile(base_hsv[np.newaxis, :], (nsc, 1))
        # Fade saturation down to 0.25 and push value up to 1 across the shades.
        shades_hsv[:, 1] = np.linspace(base_hsv[1], 0.25, nsc)
        shades_hsv[:, 2] = np.linspace(base_hsv[2], 1, nsc)
        palette[idx * nsc : (idx + 1) * nsc, :] = matplotlib.colors.hsv_to_rgb(shades_hsv)
    return matplotlib.colors.ListedColormap(palette)

In [None]:
# Show the palette as cell output: one hue per ResNet-50 model, one shade per validation dataset.
categorical_cmap(nc=len(RESNET50_MODELS), nsc=len(VALIDATION_DATASETS))

In [None]:
from zs_ssl_clustering.datasets import image_dataset_sizes


def clip_imgsize(dataset, target_image_size):
    """
    Cap a requested image size at the dataset's own image size.

    Parameters
    ----------
    dataset : str
        Dataset name, passed to ``image_dataset_sizes``.
    target_image_size : int or None
        Requested image size. ``None`` is returned unchanged without looking
        up the dataset.

    Returns
    -------
    int or None
        ``min(target_image_size, dataset_size)``. A sized dataset size (e.g. a
        tuple) is first reduced to its smallest entry. If the dataset size is
        ``None``, ``target_image_size`` is returned as-is.
    """
    if target_image_size is None:
        return None
    # NOTE(review): element [1] of image_dataset_sizes(...) is presumably the
    # native image size (a scalar or per-axis sequence) — confirm.
    native_size = image_dataset_sizes(dataset)[1]
    if native_size is not None:
        if hasattr(native_size, "__len__"):
            native_size = min(native_size)
        return min(target_image_size, native_size)
    return target_image_size

In [None]:
def fixup_filter(filters):
    """
    Adjust dataset-dependent values in a filter mapping.

    Parameters
    ----------
    filters : dict
        Filter mapping. The dataset name is read from ``"dataset_name"`` if
        that key is present, otherwise from ``"dataset"``.

    Returns
    -------
    dict
        A shallow copy of ``filters``. When a dataset name is set:

        - ``"image_size"`` (if present) is capped via ``clip_imgsize``;
        - ``"min_samples"`` (if present) is forced to 2 for the celeba,
          utkface and bioscan1m datasets (case-insensitive).

        The caller's mapping is never modified.
    """
    # Copy first: callers may pass a shared params dict, and the
    # dataset-specific overrides below must not leak back into it.
    filters = copy.copy(filters)
    dataset = filters.get("dataset_name", filters.get("dataset", None))
    if not dataset:
        return filters
    if "image_size" in filters:
        filters["image_size"] = clip_imgsize(dataset, filters["image_size"])
    if "min_samples" in filters and dataset.lower() in ("celeba", "utkface", "bioscan1m"):
        filters["min_samples"] = 2
    return filters

In [None]:
def select_rows(df, filters, allow_missing=True, fixup=True):
    """
    Select the rows of ``df`` that match every column/value pair in ``filters``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table of runs to filter.
    filters : dict
        Mapping from column name to required value. The keys ``"dataset"``
        and ``"clusterer"`` are aliases for the ``"dataset_name"`` and
        ``"clusterer_name"`` columns. A value of ``None``, ``"None"`` or
        ``"none"`` matches cells that are missing or hold the string
        ``"None"``/``"none"``. Any other value matches cells equal to the value
        itself or to its string form. ``filters`` is not modified.
    allow_missing : bool, default=True
        If True, missing (NaN/None) cells also match non-null filter values.
    fixup : bool, default=True
        Whether to pass ``filters`` through :func:`fixup_filter` first.

    Returns
    -------
    pandas.DataFrame
        The matching subset of ``df``.
    """
    if fixup:
        # Work on a copy so the caller's filter dict is never altered.
        filters = fixup_filter(dict(filters))
    aliases = {"dataset": "dataset_name", "clusterer": "clusterer_name"}
    select = np.ones(len(df), dtype=bool)
    for col, val in filters.items():
        col = aliases.get(col, col)
        column = df[col]
        if val is None or val == "None" or val == "none":
            select_i = pd.isna(column) | (column == "None") | (column == "none")
        else:
            # Values may have been logged either natively or as strings.
            select_i = (column == val) | (column == str(val))
            if allow_missing:
                select_i |= pd.isna(column)
        select &= np.asarray(select_i, dtype=bool)
    return df[select]

In [None]:
def find_differing_columns(df, cols=None):
    """
    List the columns of ``df`` that hold more than one distinct value.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to inspect.
    cols : iterable of str, optional
        Candidate column names, reported in this order. Defaults to all
        columns of ``df``. Names that are not columns of ``df`` are ignored.

    Returns
    -------
    list
        The candidates whose values are not all identical, counting missing
        values (NaN/None) as a distinct value.
    """
    candidates = df.columns if cols is None else cols
    return [
        name
        for name in candidates
        if name in df.columns and df[name].nunique(dropna=False) > 1
    ]

In [None]:
# Placeholder used in the memory tables below: matches any dataset whose name
# starts with "in9".
_IN9_ANY = "in9*"

# Dataset tiers shared by AffinityPropagation, AgglomerativeClustering,
# SpectralClustering, HDBSCAN and KMeans. Each tier is paired with a memory
# request (GB) in _MEM_TABLES_GB below.
_MEM_TIERS = (
    ("inaturalist",),  # 100,000 samples
    ("imagenet-sketch", "imagenet"),  # 50,000 samples
    ("places365", "imagenet-r", "svhn", "bioscan1m", "nabirds"),  # 24,600 - 36,500 samples
    ("celeba",),  # 20,000 samples
    ("imagenetv2", "cifar10", "cifar100", "lsun", "mnist", "fashionmnist", "stanfordcars"),  # 8,000 - 10,000
    (_IN9_ANY, "flowers102", "utkface", "eurosat", "aircraft", "breakhis", "imagenet-o", "dtd"),  # 1,900 - 6,200
    ("imagenette", "imagewoof"),  # 3,930 samples
)

# clusterer -> (((dataset names, GB), ...), fallback GB). The first matching row wins.
_MEM_TABLES_GB = {
    "LouvainCommunities": (
        (
            (("inaturalist",), 3_700),  # 100,000 samples
            (("imagenet-sketch", "imagenet"), 926),  # 50,000 samples
            (("places365",), 494),  # 36,500 samples
            (("imagenet-r",), 333),  # 30,000 samples
            (("svhn",), 250),  # 26,000 samples
            (("bioscan1m", "nabirds"), 224),  # 24,600 samples
            (("celeba",), 128),  # 20,000 samples
            (
                ("imagenetv2", "cifar10", "cifar100", "lsun", "mnist", "fashionmnist", "stanfordcars", "breakhis"),
                32,
            ),  # 8,000 - 10,000 samples
            (("flowers102", "utkface"), 18),  # 5,925 - 6,200 samples
            ((_IN9_ANY, "eurosat"), 8),  # 4,500 samples
            (("imagenette", "imagewoof", "aircraft"), 6),  # 3,333 - 3,930 samples
            (("imagenet-o", "dtd"), 4),  # 2,000 samples
        ),
        12,
    ),
    "AffinityPropagation": (tuple(zip(_MEM_TIERS, (292, 72, 48, 12, 6, 2, 1))), 8),
    "AgglomerativeClustering": (tuple(zip(_MEM_TIERS, (72, 20, 16, 12, 6, 4, 2))), 8),
    "HDBSCAN": (tuple(zip(_MEM_TIERS, (6, 4, 4, 4, 2, 2, 1))), 4),
}
_MEM_TABLES_GB["SpectralClustering"] = _MEM_TABLES_GB["AgglomerativeClustering"]
_MEM_TABLES_GB["KMeans"] = _MEM_TABLES_GB["HDBSCAN"]

# Boolean-like settings with dedicated CLI flags: key -> (arg when off, arg when on).
# An empty string means nothing is emitted for that state.
_BOOL_FLAG_ARGS = {
    "zscore": (" --no-zscore", " --zscore"),
    "normalize": ("", " --normalize"),
    "ndim_correction": (" --no-ndim-correction", " --ndim-correction"),
    "louvain_remove_self_loops": (" --louvain-keep-self", ""),
}


def _estimate_mem_gb(clusterer, dataset):
    """Return the RAM (GB) to request for ``clusterer`` on ``dataset`` (2 GB for untabulated clusterers)."""
    if clusterer not in _MEM_TABLES_GB:
        return 2
    rows, fallback = _MEM_TABLES_GB[clusterer]
    # Non-string datasets (e.g. None) simply fall through to the fallback.
    is_in9 = isinstance(dataset, str) and dataset.startswith("in9")
    for names, mem in rows:
        if dataset in names or (is_in9 and _IN9_ANY in names):
            return mem
    return fallback


def _is_off(value):
    """Return True for values that switch a boolean-like setting off: the string "False" or anything falsy."""
    return value == "False" or not value


def filter2command(*filters, partition="val"):
    """
    Build an ``sbatch`` command that submits ``slurm/cluster.slrm`` for one configuration.

    Parameters
    ----------
    *filters : dict
        Settings to merge; later dicts override earlier ones. Every entry
        whose value is not None becomes a command-line argument.
    partition : str, default="val"
        Data partition, forwarded as ``--partition``. It also selects the
        ``--array`` index: 100 for "val", 1 for "test", 0 otherwise.

    Returns
    -------
    str
        The command, or ``""`` when the estimated memory exceeds 300 GB.

    Notes
    -----
    The memory request is looked up from per-clusterer tables keyed on the
    dataset (``_MEM_TABLES_GB``). SpectralClustering requests are scaled down
    when ``spectral_n_neighbors`` is 20 or less (a missing/None value counts
    as 100). The settings ``zscore``, ``normalize``, ``zscore2``,
    ``ndim_correction`` and ``louvain_remove_self_loops`` map to dedicated
    flags; the string "False" and falsy values count as off.
    """
    f = {}
    for filt in filters:
        f.update(filt)
    dataset = f.get("dataset", "")
    clusterer = f.get("clusterer", "")

    mem = _estimate_mem_gb(clusterer, dataset)  # RAM in gigabytes
    if clusterer == "SpectralClustering":
        snn = f.get("spectral_n_neighbors")
        if snn is None:
            snn = 100
        snn = float(snn)
        if snn <= 10:
            mem = mem * 8 / 20
        elif snn <= 20:
            mem = mem * 3 / 4
        mem = int(np.ceil(mem))

    if mem > 300:
        # Too large to schedule; signal this with an empty command.
        return ""

    seed = {"val": 100, "test": 1}.get(partition, 0)
    parts = [
        f"sbatch --array={seed} --mem={mem}G",
        f' --job-name="zsc-{f.get("model", "")}-{dataset}-{clusterer}"',
        f" slurm/cluster.slrm --partition={partition}",
    ]
    for k, v in f.items():
        if v is None:
            continue
        if k in _BOOL_FLAG_ARGS:
            off_arg, on_arg = _BOOL_FLAG_ARGS[k]
            parts.append(off_arg if _is_off(v) else on_arg)
        elif k == "zscore2":
            if _is_off(v):
                parts.append(" --no-zscore2")
            elif v == "average":
                parts.append(" --azscore2")
            else:
                parts.append(" --zscore2")
        else:
            parts.append(f" --{k.replace('_', '-')}={v}")
    return "".join(parts)

# Final results

In [None]:
# Exclude CLIP from analysis: keep only the encoders whose name does not
# begin with "clip" in each per-architecture list.
RESNET50_MODELS = [name for name in RESNET50_MODELS if name[:4] != "clip"]
VITB16_MODELS = [name for name in VITB16_MODELS if name[:4] != "clip"]

In [None]:
# Test-partition datasets, in the column order used by the results tables below.
TEST_DATASETS = [
    "imagenet",
    "imagenetv2",
    "imagenet-o",
    "cifar10",
    "cifar100",
    "in9original",
    "in9mixedrand",
    # "in9onlybgt",
    "in9onlyfg",
    "imagenet-r",
    "imagenet-sketch",
    "aircraft",
    "stanfordcars",
    "flowers102",
    "bioscan1m",
    "nabirds",
    "inaturalist",
    "celeba",
    "utkface",
    "breakhis",
    "dtd",
    "eurosat",
    "lsun",
    "places365",
    "mnist",
    "fashionmnist",
    "svhn",
]
# Short dataset labels used for table column headers.
DATASET2SH = {
    "aircraft": "Air",
    "bioscan1m": "Bio",
    "breakhis": "BHis",
    "celeba": "CelA",
    "cifar10": "C10",
    "cifar100": "C100",
    "dtd": "DTD",
    "eurosat": "ESAT",
    "flowers102": "F102",
    "fashionmnist": "Fash",
    "imagenet": "IN1k",
    "imagenet-o": "IN-O",
    "imagenet-r": "IN-R",
    "imagenet-sketch": "IN-S",
    "imagenetv2": "INv2",
    "imagenette": "IN10",
    "imagewoof": "INwf",
    "in9original": "IN9",
    "in9mixednext": "9-MN",
    "in9mixedrand": "9-MR",
    "in9mixedsame": "9-MS",
    "in9nofg": "9-NoFG",
    "in9onlybgb": "9-BGB",
    "in9onlybgt": "9-BGT",
    "in9onlyfg": "9-FG",
    "inaturalist": "iNat",
    "lsun": "LSU",
    "mnist": "MNST",
    "nabirds": "Birds",
    "places365": "P365",
    "stanfordcars": "Cars",
    "svhn": "SVHN",
    "utkface": "UTKF",
}
# Encoder groups for the results tables. RESNET50_MODELS / VITB16_MODELS have
# had CLIP removed by the exclusion cell above, whereas ALL_MODELS was built
# before that filter and therefore still lists the CLIP encoders.
MODEL_GROUPS = {
    "ResNet-50": RESNET50_MODELS,
    "ViT-B": VITB16_MODELS,
    "ResNet-50 [FT]": FT_RESNET50_MODELS,
    "ViT-B [FT]": FT_VITB16_MODELS,
    "all": ALL_MODELS,
}
# Short encoder labels for tables and figures; the ResNet-50 and ViT-B
# variants of the same method share a label.
MODEL2SH = {
    "none": "Raw image",
    "random_resnet50": "Rand.",  # "Random",
    "random_vitb16": "Rand.",  # "Random",
    "resnet50": "X-Ent.",
    "mocov3_resnet50": "MoCo-v3",
    "dino_resnet50": "DINO",
    "vicreg_resnet50": "VICReg",
    "clip_RN50": "CLIP",
    "vitb16": "X-Ent.",
    "mocov3_vit_base": "MoCo-v3",
    "dino_vitb16": "DINO",
    "timm_vit_base_patch16_224.mae": "MAE (CLS)",
    "mae_pretrain_vit_base_global": "MAE (avg)",
    "clip_vitb16": "CLIP",
    "ft_mocov3_resnet50": "MoCo-v3 [FT]",
    "ft_dino_resnet50": "DINO [FT]",
    "ft_vicreg_resnet50": "VICReg [FT]",
    "ft_mocov3_vit_base": "MoCo-v3 [FT]",
    "ft_dino_vitb16": "DINO [FT]",
    "mae_finetuned_vit_base_global": "MAE (avg) [FT]",
}
# Short clusterer labels. Clusterers not listed keep their own name wherever
# the lookup falls back to the key (e.g. CLUSTERER2SH.get(name, name)).
# NOTE(review): the double space in "AC w/  C" may be deliberate padding to the
# width of "AC w/o C" -- confirm before changing it.
CLUSTERER2SH = {
    "KMeans": "K-Means",
    "SpectralClustering": "Spectral",
    "AffinityPropagation": "Affinity Prop",
    "AgglomerativeClustering": "AC",
    "AC w/ C": "AC w/  C",
}

In [None]:
# Same labels as MODEL2SH, prefixed with the backbone architecture when the
# model key identifies one ("resnet"/"RN50" -> ResNet-50, "vit" -> ViT-B).
MODEL2SH_ARCH = {}
for k, v in MODEL2SH.items():
    if "resnet" in k or "RN50" in k:
        MODEL2SH_ARCH[k] = "ResNet-50 " + v
    elif "vit" in k:
        MODEL2SH_ARCH[k] = "ViT-B " + v
    else:
        MODEL2SH_ARCH[k] = v

In [None]:
# Test datasets grouped into categories (in-domain, domain-shift, near-OOD,
# fine-grained, far-OOD); together the groups cover every entry of TEST_DATASETS.
TEST_DATASETS_GROUPED = {
    "In-domain": [
        "imagenet",
        "imagenetv2",
        "cifar10",
        "cifar100",
        "in9original",
    ],
    "Domain-shift": [
        "in9onlyfg",
        # "in9onlybgt",
        "in9mixedrand",
        "imagenet-r",
        "imagenet-sketch",
    ],
    "Near-OOD": [
        "imagenet-o",
        "lsun",
        "places365",
    ],
    "Fine-grained": [
        "aircraft",
        "stanfordcars",
        "flowers102",
        "bioscan1m",
        "nabirds",
        "inaturalist",
    ],
    "Far-OOD": [
        "celeba",
        "utkface",
        "breakhis",
        "dtd",
        "eurosat",
        "mnist",
        "fashionmnist",
        "svhn",
    ],
}

# Display titles for some dataset-group names; other groups presumably keep
# their key. NOTE(review): "Out-of-distribution" is not a key of
# TEST_DATASETS_GROUPED above -- possibly a leftover from an older grouping.
DATASETGROUP2TITLE = {
    "Domain-shift": "Domain-shifted",
    "Out-of-distribution": "OOD",
}

In [None]:
# ImageNet-9 ("in9") dataset variants, in display order.
IN9_DATASETS = [
    "in9original",
    "in9onlyfg",
    "in9nofg",
    "in9onlybgt",
    "in9mixedsame",
    "in9mixedrand",
]
# Short labels for the in9 variants. "in9bggap" is not a real dataset: in
# build_results_table it is the in9mixedsame score minus the in9mixedrand score.
IN92SH = {
    "in9original": "OG",
    "in9mixednext": "MN",
    "in9mixedrand": "MR",
    "in9mixedsame": "MS",
    "in9nofg": r"FG$^\text{C}$",
    "in9onlybgb": "BG(B)",
    "in9onlybgt": "BG",
    "in9onlyfg": "FG",
    "in9bggap": "Gap",
}

In [None]:
# Plot colour for each clusterer (matplotlib colour names) ...
CLUSTERER2COLORSTR = {
    "KMeans": "tab:purple",
    "SpectralClustering": "tab:cyan",
    "AC w/ C": "tab:red",
    "AC w/o C": "tab:orange",
    "AffinityPropagation": "tab:green",
    "HDBSCAN": "tab:blue",
}
# ... and the same colours as (r, g, b) float tuples.
CLUSTERER2COLORRGB = {k: matplotlib.colors.to_rgb(v) for k, v in CLUSTERER2COLORSTR.items()}

In [None]:
# ICLR2024 colour scheme: one matplotlib colour per encoder, shared by the
# ResNet-50 and ViT-B variants of a method. If the ICML2024 cell below is also
# run, it redefines MODEL2COLORSTR / MODEL2COLORRGB and replaces this scheme.
MODEL2COLORSTR = {
    "none": "black",
    "random_resnet50": "tab:grey",
    "random_vitb16": "tab:grey",
    "resnet50": "tab:red",
    "mocov3_resnet50": "tab:cyan",
    "dino_resnet50": "tab:green",
    "vicreg_resnet50": "tab:purple",
    "clip_RN50": "tab:blue",
    "vitb16": "tab:red",
    "mocov3_vit_base": "tab:cyan",
    "dino_vitb16": "tab:green",
    "timm_vit_base_patch16_224.mae": "tab:olive",
    "mae_pretrain_vit_base_global": "tab:brown",
    "clip_vitb16": "tab:blue",
    "mae_finetuned_vit_base_global": "tab:brown",
}
# Same colours as (r, g, b) float tuples.
MODEL2COLORRGB = {k: matplotlib.colors.to_rgb(v) for k, v in MODEL2COLORSTR.items()}

In [None]:
# ICML2024 colour scheme: one matplotlib colour per encoder, shared by the
# ResNet-50 and ViT-B variants of a method. Redefines the ICLR2024 scheme above.
MODEL2COLORSTR = {
    "none": "black",
    "random_resnet50": "dimgrey",
    "random_vitb16": "dimgrey",
    "resnet50": "tab:red",
    "mocov3_resnet50": "tab:green",
    "dino_resnet50": "tab:purple",
    "vicreg_resnet50": "tab:orange",
    "clip_RN50": "tab:olive",
    "vitb16": "tab:red",
    "mocov3_vit_base": "tab:green",
    "dino_vitb16": "tab:purple",
    "timm_vit_base_patch16_224.mae": "tab:blue",
    "mae_pretrain_vit_base_global": "tab:brown",
    "clip_vitb16": "tab:olive",
    "mae_finetuned_vit_base_global": "tab:brown",
}
# Same colours as (r, g, b) float tuples.
MODEL2COLORRGB = {k: matplotlib.colors.to_rgb(v) for k, v in MODEL2COLORSTR.items()}

In [None]:
# Fine-tuned encoders: their pre-trained counterpart's colour darkened to 80%.
# Pre-trained encoders: their colour blended 30% towards white.
# Both are derived from the colour strings (not from MODEL2COLORRGB itself),
# so re-running this cell does not compound the adjustments.
# FT2PRE presumably maps each fine-tuned model to its pre-trained model -- confirm.
for model in FT_MODELS:
    base_rgb = matplotlib.colors.to_rgb(MODEL2COLORSTR[FT2PRE[model]])
    MODEL2COLORRGB[model] = tuple(c * 0.8 for c in base_rgb)
for model in RESNET50_MODELS + VITB16_MODELS:
    base_rgb = matplotlib.colors.to_rgb(MODEL2COLORSTR[model])
    MODEL2COLORRGB[model] = tuple(1 - (1 - c) * 0.7 for c in base_rgb)

### Tabulate hyperparams

In [None]:
# LaTeX labels for distance metrics using \ell-norm notation. Aliases of the
# same norm share a label. NOTE: the next cell redefines METRIC2TABLE, so this
# version only applies if that cell is skipped.
METRIC2TABLE = {
    "l1": r"$\ell_1$",
    "cityblock": r"$\ell_1$",
    "manhattan": r"$\ell_1$",
    "l2": r"$\ell_2$",
    "euclidean": r"$\ell_2$",
    "chebyshev": r"$\ell_\infty$",
    "infinity": r"$\ell_\infty$",
}

In [None]:
# LaTeX labels for distance metrics using L-norm notation (replaces the
# previous definition). Metrics not listed here, e.g. "cosine", are printed
# under their own name by the METRIC2TABLE.get(metric, metric) lookup.
METRIC2TABLE = {
    "l1": r"$L1$",
    "cityblock": r"$L1$",
    "manhattan": r"$L1$",
    "l2": r"$L2$",
    "euclidean": r"$L2$",
    "chebyshev": r"$L\infty$",
    "infinity": r"$L\infty$",
}

In [None]:
# Build (and print) a LaTeX table of the tuned hyperparameters in BEST_PARAMS:
# one row per encoder x clusterer, with encoders grouped by backbone.
clusterers = ["KMeans", "AC w/o C", "AffinityPropagation", "HDBSCAN"]

model_groups = {
    "---": ["none"],
    "RN50": RESNET50_MODELS + FT_RESNET50_MODELS,
    "ViT-B": VITB16_MODELS + FT_VITB16_MODELS,
}

latex_table = r"% Hyperparameters " + f"{BEST_PARAMS['_version']}" + "\n"
now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
latex_table += r"% Generated " + now_str + "\n"
label = "hparams"
latex_table += r"\label{tab:" + label + r"}" + "\n"
latex_table += r"% \resizebox{\columnwidth}{!}{%" + "\n"
latex_table += r"\begin{tabular}{lllllrccrr}" + "\n"
latex_table += r"\toprule" + "\n"
# Begin main header row: group titles above columns 8-9 (AC) and 10 (AP).
latex_table += r"      &             &   "
latex_table += r" &           & &                &      & \multicolumn{2}{c}{Agg. Clustering} & Aff. Prop."
# Terminate the row exactly once; a second "\\" would insert an empty table row.
latex_table += r" \\" + "\n"
latex_table += r"\cmidrule(l){8-9} \cmidrule(l){10-10}" + "\n"
latex_table += r"Arch. & " + f"{'Encoder':<11s}" + r" & FT"
latex_table += r" & Clusterer & \multicolumn{2}{c}{Dim Reduction} & Metric & Linkage & Dist. Thr. & Damping"
latex_table += r" \\" + "\n"
# Begin table contents
latex_table += r"\midrule" + "\n"

for i_group, group in enumerate(model_groups):
    if i_group > 0:
        latex_table += r"\midrule" + "\n"
    # The group name fills the first (Arch.) cell of the group's first row.
    latex_table += group + "\n"
    for i_model, model in enumerate(model_groups[group]):
        if i_model != 0:
            latex_table += r"\cmidrule(l){2-10}" + "\n"
        model_name = MODEL2SH.get(model, model)
        for i_clusterer, clusterer in enumerate(clusterers):
            params = BEST_PARAMS[clusterer][model]
            # The encoder name is written only on its first row; the FT
            # checkmark appears on every row of a fine-tuned encoder.
            model_sh = model_name if i_clusterer == 0 else ""
            if model_name.endswith(" [FT]"):
                model_sh = f"{model_sh[:-4]:<10s}" + r" & \checkmark"
            else:
                model_sh = f"{model_sh:<10s} &"
            latex_table += f"& {model_sh:<23s}"
            clusterername = CLUSTERER2SH.get(clusterer, clusterer)
            latex_table += f" & {clusterername:<13s}"

            # The "*_man" reducer settings take precedence over "dim_reducer".
            dim_reducer = "None"
            dim_reduced = None
            if params.get("dim_reducer_man", None):
                dim_reducer = params["dim_reducer_man"]
                dim_reduced = params["ndim_reduced_man"]
            elif params.get("dim_reducer", None):
                dim_reducer = params["dim_reducer"]
                if params.get("ndim_reduced", None):
                    dim_reduced = params["ndim_reduced"]
                elif params.get("pca_variance", None):
                    dim_reduced = f"{params['pca_variance']:.2f}"
            distance_metric = params.get("distance_metric", None)
            affinity_damping = params.get("affinity_damping", None)
            aggclust_linkage = params.get("aggclust_linkage", None)
            aggclust_dist_thresh = params.get("aggclust_dist_thresh", None)

            # Unset (falsy) values are rendered with the \noval{} macro.
            dim_reduced = dim_reduced if dim_reduced else r"\noval{}"
            distance_metric = distance_metric if distance_metric else r"\noval{}"
            distance_metric = METRIC2TABLE.get(distance_metric, distance_metric)
            affinity_damping = f"{affinity_damping:8.2f}" if affinity_damping else r"\noval{}"
            aggclust_linkage = aggclust_linkage if aggclust_linkage else r"\noval{}"
            aggclust_dist_thresh = f"{aggclust_dist_thresh:8.2f}" if aggclust_dist_thresh else r"\noval{}"

            latex_table += f" & {dim_reducer:<4s} & {dim_reduced:<4}"
            latex_table += f" & {distance_metric:<14s}"
            latex_table += f" & {aggclust_linkage:<8s} & {aggclust_dist_thresh}"
            latex_table += f" & {affinity_damping}"
            latex_table += r" \\" + "\n"

latex_table += r"\bottomrule" + "\n"
latex_table += r"\end{tabular}" + "\n"
latex_table += r"% }" + "\n"

print(latex_table)

## Fetch results

In [None]:
# Start from an empty results table (a single "id" column) and no known keys.
runs_df_long = pd.DataFrame(data={"id": []})
config_keys, summary_keys = set(), set()

In [None]:
# Load previous results from CSV file
# Loading is currently disabled (the read below is commented out), so this
# cell only defines CSV_FNAME, which is also the path the results are saved to
# further down. Uncomment the read to reuse a previous download.
CSV_FNAME = "test_runs_df.csv"
if os.path.isfile(CSV_FNAME):
    pass
    # runs_df_long = test_runs_df = pd.read_csv(CSV_FNAME)

In [None]:
# Project is specified by <entity/project-name>
# Query the finished runs on the "test" partition; the final expression
# displays how many runs matched.
api = wandb.Api(timeout=720)
runs = api.runs(
    "uoguelph_mlrg/zs-ssl-clustering",
    filters={
        "state": "Finished",
        "config.partition": "test",
    },  # "config.predictions_dir": "y_pred"},
    per_page=10_000,
)
len(runs)

In [None]:
# Download every queried run that is not yet in runs_df_long and append it as
# a row of config values + summary metrics (keys starting with "_" dropped,
# except "_timestamp").
print(f"{len(runs_df_long)} runs currently in dataframe")
rows_to_add = []
existing_ids = set(runs_df_long["id"].values)
for run in tqdm(runs):
    if run.id in existing_ids:
        # Stop once as many new rows as the count difference have been
        # collected. NOTE(review): this assumes every run already in the
        # dataframe is still returned by the query; otherwise new runs can be missed.
        if len(rows_to_add) >= len(runs) - len(runs_df_long):
            break
        continue
    # .summary contains the output keys/values for metrics like accuracy.
    #  We call ._json_dict to omit large files
    summary = run.summary._json_dict
    # .config contains the hyperparameters.
    #  We remove special values that start with _.
    config = {k: v for k, v in run.config.items() if not k.startswith("_")}
    # .name is the human-readable name of the run.
    row = {"id": run.id, "name": run.name}
    row.update(config)
    row.update({k: v for k, v in summary.items() if not k.startswith("_")})
    if "_timestamp" in summary:
        row["_timestamp"] = summary["_timestamp"]
    rows_to_add.append(row)
    config_keys.update(config.keys())
    summary_keys.update(summary.keys())

if not rows_to_add:
    print("No new runs to add")
else:
    print(f"Adding {len(rows_to_add)} runs")
    # ignore_index avoids duplicate index labels when this cell is re-run.
    runs_df_long = pd.concat([runs_df_long, pd.DataFrame.from_records(rows_to_add)], ignore_index=True)
print(f"{len(runs_df_long)} runs")

In [None]:
# Remove entries without an AMI metric. Take an explicit copy: the next cell
# adds/overwrites columns on test_runs_df, which on a filtered slice would
# raise pandas' SettingWithCopyWarning.
test_runs_df = runs_df_long[~runs_df_long["AMI"].isna()].copy()
len(test_runs_df)

In [None]:
# Fill in config columns/values for runs logged before an option existed or
# while it was recorded inconsistently. Edits test_runs_df in place; the order
# in which missing columns are created here sets their order in the saved CSV.
# Handle changed default value for spectral_assigner after config arg was introduced
if "spectral_n_components" not in test_runs_df.columns:
    test_runs_df["spectral_n_components"] = None

if "spectral_assigner" not in test_runs_df.columns:
    test_runs_df["spectral_assigner"] = None
# spectral_assigner only applies to SpectralClustering; clear it elsewhere and
# give spectral runs with no recorded value "kmeans".
select = test_runs_df["clusterer_name"] != "SpectralClustering"
test_runs_df.loc[select, "spectral_assigner"] = None
select = (test_runs_df["clusterer_name"] == "SpectralClustering") & pd.isna(test_runs_df["spectral_assigner"])
test_runs_df.loc[select, "spectral_assigner"] = "kmeans"

# Accidentally wasn't clearing this hparam when it was unused
if "spectral_affinity" not in test_runs_df.columns:
    test_runs_df["spectral_affinity"] = None
select = test_runs_df["clusterer_name"] != "SpectralClustering"
test_runs_df.loc[select, "spectral_affinity"] = None

# Missing zscore2 / ndim_correction values are treated as False.
if "zscore2" not in test_runs_df.columns:
    test_runs_df["zscore2"] = False
test_runs_df.loc[pd.isna(test_runs_df["zscore2"]), "zscore2"] = False

if "ndim_correction" not in test_runs_df.columns:
    test_runs_df["ndim_correction"] = False
test_runs_df.loc[pd.isna(test_runs_df["ndim_correction"]), "ndim_correction"] = False

# Ensure these columns exist (as None) so filters on them do not raise KeyError.
if "dim_reducer_man_nn" not in test_runs_df.columns:
    test_runs_df["dim_reducer_man_nn"] = None

if "image_size" not in test_runs_df.columns:
    test_runs_df["image_size"] = None

In [None]:
# Save results to CSV file, so we can optionally skip downloading them
# (the row index is not written).
test_runs_df.to_csv(CSV_FNAME, index=False)

In [None]:
# Ignore config entries describing the worker count and memory of the job.
config_keys = config_keys - {"workers", "memory_avail_GB", "memory_total_GB", "memory_slurm"}

In [None]:
# Display the cleaned table of test-partition runs.
test_runs_df

## Result loading utility functions

In [None]:
# Spot-check one result: the median metric over matching test runs for a
# single model/dataset/clusterer, using that clusterer's tuned hyperparameters.
model = "mocov3_resnet50"
dataset = "imagenet"
clusterer = "AC w/o C"
metric_key = "AMI"

my_override_fields = {}

filter1 = {"model": model, "dataset": dataset}
# Later sources win: defaults < tuned params < model/dataset < overrides.
filter2 = fixup_filter(
    {
        **DEFAULT_PARAMS["all"],
        **BEST_PARAMS[clusterer][model],
        **filter1,
        **my_override_fields,
    }
)
sdf = select_rows(test_runs_df, filter2, allow_missing=False)
my_val = np.nanmedian(sdf[metric_key])

print(f"{metric_key} = {my_val * 100:.0f}%")

In [None]:
def build_results_table(
    models,
    clusterers,
    datasets,
    metric_keys="AMI",
    override_fields=None,
    return_cmds=False,
    verbosity=0,
):
    """
    Collect the median test score of every model/clusterer/dataset combination.

    For each combination the run filter is ``DEFAULT_PARAMS["all"]`` updated
    with ``BEST_PARAMS[clusterer][model]``, then the model/dataset, then
    ``override_fields``. Matching rows of ``test_runs_df`` are reduced with a
    NaN-ignoring median per metric.

    The pseudo-dataset ``"in9bggap"`` is reported as the score on
    ``"in9mixedsame"`` minus the score on ``"in9mixedrand"``.

    Parameters
    ----------
    models, clusterers, datasets : sequence of str
        Axes of the table.
    metric_keys : str or sequence of str, default="AMI"
        Metric column(s) to read. A single string drops the metric axis.
    override_fields : dict, optional
        Filter entries applied last, overriding the tuned hyperparameters.
    return_cmds : bool, default=False
        If True, also return sbatch commands (from ``filter2command``) for the
        combinations with no runs or a NaN metric.
    verbosity : int, default=0
        If >= 1, print each combination with missing results.

    Returns
    -------
    numpy.ndarray or tuple of (numpy.ndarray, list of str)
        Array of shape ``(len(models), len(clusterers), len(datasets),
        len(metric_keys))`` (last axis dropped when ``metric_keys`` is a
        string), NaN where no result exists; plus the commands if ``return_cmds``.
    """
    if override_fields is None:
        override_fields = {}

    do_squeeze = isinstance(metric_keys, str)
    if do_squeeze:
        metric_keys = [metric_keys]

    result_table = np.full((len(models), len(clusterers), len(datasets), len(metric_keys)), np.nan)
    cmds = []

    for i_model, model in enumerate(models):
        for i_clusterer, clusterer in enumerate(clusterers):
            for i_dataset, dataset in enumerate(datasets):
                is_bg_gap = dataset == "in9bggap"
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update({"model": model, "dataset": dataset})
                filter2.update(override_fields)
                if is_bg_gap:
                    # "in9bggap" is not a real dataset: resolve it to its first
                    # operand BEFORE fixup_filter, so dataset-dependent fields
                    # (e.g. image_size) are derived from a real dataset name.
                    filter2["dataset"] = "in9mixedsame"
                filter2 = fixup_filter(filter2)
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                missing_val = False
                if len(sdf) > 0:
                    for i_key, key in enumerate(metric_keys):
                        val = np.nanmedian(sdf[key])
                        result_table[i_model, i_clusterer, i_dataset, i_key] = val
                        if np.isnan(val):
                            missing_val = True
                if len(sdf) < 1 or missing_val:
                    if verbosity >= 1:
                        print(f"No data for {model}-{dataset}-{clusterer}\n{filter2}")
                    cmds.append(filter2command(filter2, partition="test"))

                if is_bg_gap:
                    # Subtract the in9mixedrand score to obtain the gap.
                    filter_rand = dict(filter2, dataset="in9mixedrand")
                    sdf = select_rows(test_runs_df, filter_rand, allow_missing=False)
                    for i_key, key in enumerate(metric_keys):
                        result_table[i_model, i_clusterer, i_dataset, i_key] -= np.nanmedian(sdf[key])

    if do_squeeze:
        result_table = np.squeeze(result_table, axis=3)

    if return_cmds:
        return result_table, cmds
    return result_table

In [None]:
def dict_generator(indict, pre=None):
    """
    Yield every root-to-leaf path through nested dicts (and lists/tuples).

    Each path is the list of keys leading to a leaf, followed by the leaf
    value. A list or tuple value is expanded element-wise under its key; any
    other value, or a non-dict ``indict``, is a leaf.

    Parameters
    ----------
    indict : dict or any
        Structure to walk.
    pre : list, optional
        Prefix prepended to every yielded path (not modified).

    Yields
    ------
    list
        ``[key1, key2, ..., leaf]``.
    """
    prefix = pre[:] if pre else []
    if not isinstance(indict, dict):
        yield prefix + [indict]
        return
    for key, value in indict.items():
        path = prefix + [key]
        if isinstance(value, dict):
            yield from dict_generator(value, path)
        elif isinstance(value, (list, tuple)):
            for item in value:
                yield from dict_generator(item, path)
        else:
            yield path + [value]

In [None]:
def make_flat_hierarchy_from_dict(indict, pad_right=True):
    """Flatten a nested dict into equal-length rows of path components.

    Every root-to-leaf path produced by ``dict_generator`` becomes one row.
    Rows shorter than the deepest path are padded with empty strings:
    after the leaf when ``pad_right`` is true, otherwise before the first key
    (so that the leaf value always sits in the last column).

    Parameters
    ----------
    indict : dict
        Nested dict whose values may be dicts, lists/tuples, or leaves.
    pad_right : bool, default True
        Side on which the ``""`` padding is inserted.

    Returns
    -------
    list of list
        One row per leaf; an empty list when ``indict`` has no leaves.
    """
    groups_flattened = list(dict_generator(indict))
    # default=0: an input with no leaves returns [] instead of max() raising ValueError.
    depth = max((len(m) for m in groups_flattened), default=0)
    if pad_right:
        return [m + [""] * (depth - len(m)) for m in groups_flattened]
    return [[""] * (depth - len(m)) + m for m in groups_flattened]

### Draft table

In [None]:
metric_key = "AMI"
show_pc = True  # render values as percentages (x100)
show_fmt = "{:5.1f}"
show_commands = False  # print the commands that would fill in missing results
eps = 0.001  # tolerance when testing for a tie with the best / second-best value
override_fields = {
    # "aggclust_dist_thresh": None,  # to flip between unknown/known n clusters for AC
    # "predictions_dir": "y_pred",
}

# KMeans  AffinityPropagation  AgglomerativeClustering  HDBSCAN
backbones = MODEL_GROUPS.keys()
clusterer = "AgglomerativeClustering"

# Two passes over the same results: the first (dummy) pass only collects every
# value per dataset into `best_results`; the second pass builds the LaTeX table
# and uses those values to mark the best (\tcf) and second-best (\tcs) entries.
best_results = {k: [] for k in TEST_DATASETS}
for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{metric_key}, {clusterer}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    latex_table += r"\label{tab:" + clusterer + r"}" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_DATASETS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for dataset in TEST_DATASETS:
        latex_table += r"&" + "{:^15s}".format(DATASET2SH.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    for i_group, model_group_name in enumerate(list(backbones)):
        if i_group > 0:
            latex_table += r"\midrule" + "\n"
        group_models = MODEL_GROUPS[model_group_name]
        for i_model, model in enumerate(group_models):
            if i_model == 0:
                # Rotated group label spanning every row of this group. The row
                # count used to be hard-coded to 5, which misplaces the label
                # for groups of any other size.
                latex_table += (
                    r"\parbox[t]{2mm}{\multirow{"
                    + str(len(group_models))
                    + r"}{*}{\rotatebox[origin=c]{90}{"
                    + model_group_name
                    + "}}}"
                )
                latex_table += "\n"
            latex_table += f"& {MODEL2SH.get(model, model):<10s}"
            for i_dataset, dataset in enumerate(TEST_DATASETS):
                latex_table += " &"
                filter1 = {
                    "model": model,
                    "dataset": dataset,
                    "clusterer": clusterer,
                }
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update(filter1)
                filter2.update(override_fields)
                filter2 = fixup_filter(filter2)
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                if len(sdf) < 1:
                    # print(f"No data for {filter2}")
                    if clusterer == "AffinityPropagation" and dataset in [
                        "imagenet",
                        "inaturalist",
                    ]:
                        # Leave the cell empty without queuing a run command
                        # (presumably these runs are deliberately not attempted).
                        continue
                    cmds.append(filter2command(filter2, partition="test"))
                    continue
                if len(sdf) > 1:
                    # Several runs match the filter: report them if they disagree.
                    if (np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6).any():
                        print()
                        print(
                            f"More than one result with {metric_key} values",
                            list(sdf[metric_key]),
                        )
                        print(f"for search {filter2}")
                        dif_cols = find_differing_columns(sdf, config_keys)
                        print(f"columns which differ: {dif_cols}")
                        if dif_cols:
                            for col in dif_cols:
                                print(f"  {col}: {list(sdf[col])}")
                my_val = np.nanmedian(sdf[metric_key])
                if dummy:
                    best_results[dataset].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                # Rank against the non-NaN values only: np.max over a list that
                # holds a NaN returns NaN and would disable highlighting for the
                # whole column. `ranked` is non-empty because my_val is in it.
                ranked = np.sort([v for v in best_results[dataset] if not np.isnan(v)])
                is_best = my_val + eps >= ranked[-1]
                is_secd = len(ranked) > 1 and my_val + eps >= ranked[-2]
                if show_pc:
                    my_val = my_val * 100
                latex_table += " $"
                if is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                latex_table += r"}" if is_best or is_secd else " "
                latex_table += "$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"

print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

print()
print("Done!")
print()
print(f"Here is your results table for {clusterer}:")
print()
print()
print(latex_table)

## Grouping by encoder

In [None]:
metric_key = "AMI"
show_pc = True  # render values as percentages (x100)
show_fmt = "{:4.0f}"
show_commands = False  # print the commands that would fill in missing results
eps = 0.001  # tolerance when testing for a tie with the best / second-best value
override_fields = {
    # "predictions_dir": "y_pred",
}

backbone = "ResNet-50"

# "AgglomerativeClustering" appears twice on purpose: its first occurrence is
# rendered as "AC  w/ C" (aggclust_dist_thresh forced to None), the second as
# "AC w/o C". NOTE: this rebinds the notebook-global CLUSTERERS.
CLUSTERERS = [
    "KMeans",
    "AgglomerativeClustering",
    "AgglomerativeClustering",
    "AffinityPropagation",
    "HDBSCAN",
]
print(MODEL2SH)

# Two passes: the dummy pass collects every value per dataset into
# `best_results`; the second pass renders the table and marks the best (\tcf)
# and second-best (\tcs) value in each column.
best_results = {k: [] for k in TEST_DATASETS}
for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    latex_table += r"\label{tab:" + backbone + r"}" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_DATASETS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Clusterer':<11s}"
    for dataset in TEST_DATASETS:
        latex_table += r"&" + "{:^15s}".format(DATASET2SH.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
        print(model)
        if i_group > 0:
            latex_table += r"\midrule" + "\n"

        first_agg = True
        for i_clusterer, clusterer in enumerate(CLUSTERERS):
            if i_clusterer == 0:
                # Rotated encoder label spanning one row per clusterer (was a
                # hard-coded 5 that only matched this particular list length).
                latex_table += (
                    r"\parbox[t]{2mm}{\multirow{"
                    + str(len(CLUSTERERS))
                    + r"}{*}{\rotatebox[origin=c]{90}{"
                    + MODEL2SH[model]
                    + "}}}"
                )
                latex_table += "\n"
            clusterername = CLUSTERER2SH.get(clusterer, clusterer)

            my_override_fields = override_fields.copy()
            if first_agg and clusterer == "AgglomerativeClustering" and metric_key != "num_cluster_pred":
                first_agg = False
                my_override_fields["aggclust_dist_thresh"] = None
                clusterername = "AC  w/ C"
            elif clusterer == "AgglomerativeClustering":
                clusterername = "AC w/o C"
                if "aggclust_dist_thresh" in my_override_fields:
                    del my_override_fields["aggclust_dist_thresh"]

            latex_table += f"& {clusterername:<10s}"
            for i_dataset, dataset in enumerate(TEST_DATASETS):
                latex_table += " &"
                # The HDBSCAN min_samples override depends on the dataset, so it
                # is decided per cell. It used to be decided once per clusterer,
                # before this loop, against a stale `dataset` left over from an
                # earlier loop (always the last entry of TEST_DATASETS).
                cell_override_fields = my_override_fields.copy()
                if clusterer == "HDBSCAN" and dataset in ["celeba", "utkface"]:
                    cell_override_fields["min_samples"] = 2
                else:
                    cell_override_fields.pop("min_samples", None)
                filter1 = {
                    "model": model,
                    "dataset": dataset,
                    "clusterer": clusterer,
                }
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update(filter1)
                filter2.update(cell_override_fields)
                filter2 = fixup_filter(filter2)
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                if len(sdf) < 1:
                    # print(f"No data for {filter2}")
                    cmds.append(filter2command(filter2, partition="test"))
                    continue
                if len(sdf) > 1:
                    # Several runs match the filter: report them if they disagree.
                    if (np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6).any():
                        print()
                        print(
                            f"More than one result with {metric_key} values",
                            list(sdf[metric_key]),
                        )
                        print(f"for search {filter2}")
                        dif_cols = find_differing_columns(sdf, config_keys)
                        print(f"columns which differ: {dif_cols}")
                        if dif_cols:
                            for col in dif_cols:
                                print(f"  {col}: {list(sdf[col])}")
                my_val = np.nanmedian(sdf[metric_key])
                if dummy:
                    best_results[dataset].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                # Rank against non-NaN values only (np.max returns NaN if any
                # value is NaN, which would disable highlighting for the column).
                ranked = np.sort([v for v in best_results[dataset] if not np.isnan(v)])
                is_best = my_val + eps >= ranked[-1]
                is_secd = len(ranked) > 1 and my_val + eps >= ranked[-2]
                if show_pc:
                    my_val = my_val * 100
                latex_table += " $"
                if is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                latex_table += r"}" if is_best or is_secd else " "
                latex_table += "$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"

print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

print()
print("Done!")
print()
# This table is per backbone; `clusterer` here would only be the last loop value.
print(f"Here is your results table for {backbone}:")
print()
print()
print(latex_table)

## Grouping by clusterer

In [None]:
metric_key = "AMI"  # AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
show_pc = True
show_fmt = "{:4.0f}"
show_commands = False
highlight_best = True
use_si_num = False
eps = 0.005
override_fields = {
    # "predictions_dir": "y_pred",
}

backbone = "all"  # "ResNet-50" or "ViT-B"

if metric_key == "num_cluster_pred":
    CLUSTERERS = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc = False
    show_fmt = "{:4.0f}"
    highlight_best = False
    use_si_num = True
else:
    CLUSTERERS = [
        "KMeans",
        "SpectralClustering",
        "AC w/ C",
        "AC w/o C",
        "AffinityPropagation",
        "HDBSCAN",
    ]
if metric_key.startswith("silhouette"):
    show_pc = False
    show_fmt = "{:5.2f}"

print(MODEL_GROUPS)

# best_results[dataset]: every clusterer-row value in that column (used for
# \tcf / \tcs and the background shading). best_results_grouped[dataset][name]:
# the values within one row block (used for \tcg).
best_results = {k: [] for k in TEST_DATASETS}
best_results_grouped = {k: defaultdict(list) for k in TEST_DATASETS}

for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = backbone
    if metric_key == "AMI":
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_DATASETS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for dataset in TEST_DATASETS:
        latex_table += r"&" + "{:^15s}".format(DATASET2SH.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    if metric_key == "num_cluster_pred":
        # Reference row: the true number of clusters recorded for each dataset.
        latex_table += r"& Num targets"
        for i_dataset, dataset in enumerate(TEST_DATASETS):
            sdf = select_rows(test_runs_df, {"dataset": dataset}, allow_missing=False)
            sdf = sdf[~pd.isna(sdf["num_cluster_true"])]
            latex_table += r"& "
            if len(sdf) < 1:
                # No run records the true count for this dataset.
                latex_table += r"   --  "
                continue
            latex_table += r"\num{" if use_si_num else r"$"
            # :.0f so a float-typed column still renders as an integer count.
            latex_table += f"{sdf.iloc[0]['num_cluster_true'].item():.0f}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\toprule" + "\n"
    elif metric_key.endswith("_pred"):
        # Reference block: the same metric computed against the ground-truth
        # labels (the matching "_true" column), one row per encoder.
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"
        for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
            latex_table += f"& {MODEL2SH[model]:<10s}"
            for i_dataset, dataset in enumerate(TEST_DATASETS):
                latex_table += " &"
                filter1 = {"model": model, "dataset": dataset}
                if model == "timm_vit_base_patch16_224.mae":
                    filter1["dim_reducer"] = "PCA"
                    filter1["pca_variance"] = 0.95
                else:
                    filter1["dim_reducer_man"] = "UMAP"
                    filter1["ndim_reduced_man"] = 50
                    filter1["dim_reducer_man_metric"] = "euclidean"
                sdf = select_rows(test_runs_df, filter1, allow_missing=False)
                sdf = sdf[~pd.isna(sdf[metric_key2])]
                # The G.T. row must report the "_true" metric; it previously
                # reported the median "_pred" value of whichever runs matched.
                my_val = np.nanmedian(sdf[metric_key2]) if len(sdf) > 0 else np.nan
                if dummy:
                    best_results_grouped[dataset][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                is_best_grp = my_val + eps >= np.nanmax(best_results_grouped[dataset][clusterername])
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\toprule" + "\n"

    first_agg = True
    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        clusterername = CLUSTERER2SH.get(clusterer, clusterer)
        my_override_fields = override_fields.copy()
        if first_agg and clusterer == "AgglomerativeClustering" and metric_key != "num_cluster_pred":
            first_agg = False
            my_override_fields["aggclust_dist_thresh"] = None
            clusterername = "AC  w/ C"
        elif clusterer == "AgglomerativeClustering":
            clusterername = "AC w/o C"
            if "aggclust_dist_thresh" in my_override_fields:
                del my_override_fields["aggclust_dist_thresh"]

        if i_clusterer > 0:
            latex_table += r"\midrule" + "\n"

        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"

        for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
            latex_table += f"& {MODEL2SH[model]:<10s}"
            for i_dataset, dataset in enumerate(TEST_DATASETS):
                latex_table += " &"
                filter1 = {"model": model, "dataset": dataset}
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update(filter1)
                filter2.update(my_override_fields)
                filter2 = fixup_filter(filter2)
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                if len(sdf) < 1:
                    # print(f"No data for {model}-{dataset}-{clusterer}\n{filter2}")
                    # NOTE(review): only AffinityPropagation runs on these datasets
                    # are queued; every other missing cell shows "--" without a
                    # command. The draft-table cell does the opposite (skips
                    # exactly these) - confirm which behaviour is intended.
                    if clusterer == "AffinityPropagation" and dataset in [
                        "imagenet",
                        "places365",
                        "imagenet-r",
                        "svhn",
                        "bioscan1m",
                        "nabirds",
                    ]:
                        cmds.append(filter2command(filter2, partition="test"))
                    if not dummy:
                        # latex_table += r"\multicolumn{1}{c}{--}"
                        latex_table += r"   --  "
                    continue
                if len(sdf) > 1:
                    # Several runs match the filter: report them if they disagree.
                    if (np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6).any():
                        print()
                        print(
                            f"More than one result with {metric_key} values",
                            list(sdf[metric_key]),
                        )
                        print(f"for search {filter2}")
                        dif_cols = find_differing_columns(sdf, config_keys)
                        print(f"columns which differ: {dif_cols}")
                        if dif_cols:
                            for col in dif_cols:
                                print(f"  {col}: {list(sdf[col])}")
                my_val = np.nanmedian(sdf[metric_key])
                if dummy:
                    best_results[dataset].append(my_val)
                    best_results_grouped[dataset][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                # Rank against the non-NaN column values only: np.max over a list
                # containing NaN returns NaN, which silently disabled the
                # highlighting and the shading for the whole column.
                ranked = np.sort([v for v in best_results[dataset] if not np.isnan(v)])
                is_best = my_val + eps >= ranked[-1]
                is_secd = len(ranked) > 1 and my_val + eps >= ranked[-2]
                is_best_grp = my_val + eps >= np.nanmax(best_results_grouped[dataset][clusterername])
                # Shading: 0 at or below the column median, 100 at the column max.
                sc_base = np.median(ranked)
                sc_top = ranked[-1]
                if sc_top > sc_base:
                    sc = 100 * max(0, (my_val - sc_base) / (sc_top - sc_base))
                else:
                    # Column max equals its median: avoid 0/0, leave unshaded.
                    sc = 0
                latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"

print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

print()
print("Done!")
print()
print(f"Here is your results table for {metric_key}, {backbone}:")
print()
print()
print(latex_table)

In [None]:
# Print every command queued in the global `cmds` list (left behind by
# whichever table-building cell ran last).
for cmd in cmds:
    print(cmd)

### With grouped datasets

In [None]:
# Display the available encoder-group names (shown as the notebook cell output).
MODEL_GROUPS.keys()

In [None]:
metric_key = "AMI"  # AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
show_pc = True
show_fmt = "{:4.0f}"
show_commands = False
highlight_best = True
colour_bg = False  # shade each cell by how far it sits above the column median
use_si_num = False
eps = 0.005
override_fields = {
    # "predictions_dir": "y_pred",
}

backbone = "ViT-B"  # "ResNet-50" "ViT-B"
model_group = MODEL_GROUPS[backbone]
# Manual toggle: the assignment that follows overrides the line above.
# model_group = RESNET50_MODELS + FT_RESNET50_MODELS
model_group = VITB16_MODELS + FT_VITB16_MODELS

if metric_key == "num_cluster_pred":
    CLUSTERERS = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc = False
    show_fmt = "{:4.0f}"
    highlight_best = False
    use_si_num = True
else:
    CLUSTERERS = [
        "KMeans",
        "SpectralClustering",
        "AC w/ C",
        "AC w/o C",
        "AffinityPropagation",
        "HDBSCAN",
    ]
if metric_key.startswith("silhouette"):
    show_pc = False
    show_fmt = "{:5.2f}"

print(model_group)

# Column order of the table: the dataset groups flattened in order.
test_datasets = []
for datagroupname, datagroupset in TEST_DATASETS_GROUPED.items():
    test_datasets.extend(datagroupset)

# best_results[dataset]: every clusterer-row value in that column (for \tcf,
# \tcs and shading); best_results_grouped[dataset][name]: values within one
# row block (for \tcg).
best_results = {k: [] for k in test_datasets}
best_results_grouped = {k: defaultdict(list) for k in test_datasets}

# Suffix marking fine-tuned encoders in MODEL2SH; rendered as a tick in the FT
# column. Strip the whole suffix (the old [:-4] left a trailing space).
ft_suffix = " [FT]"

for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = backbone
    if metric_key == "AMI":
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"\resizebox{\linewidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{lll" + r"r" * len(test_datasets) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'':<11s} &    "
    for datagroupname, datagroupset in TEST_DATASETS_GROUPED.items():
        latex_table += r" & \multicolumn{" + str(len(datagroupset)) + r"}{c}{" + datagroupname + r"}"
    latex_table += r"\\" + "\n"
    # Dataset columns start at column 4 (after group label, encoder, FT).
    icol = 4
    for datagroupname, datagroupset in TEST_DATASETS_GROUPED.items():
        latex_table += r"\cmidrule(l){" + f"{icol}-{icol + len(datagroupset) - 1}" + r"}"
        icol += len(datagroupset)
    latex_table += "\n"
    latex_table += r"& " + f"{'Encoder':<11s}" + r" & FT "
    for dataset in test_datasets:
        latex_table += r"&" + "{:^15s}".format(DATASET2SH.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    latex_table += r"\midrule" + "\n"
    print(model_group)
    if metric_key == "num_cluster_pred":
        # Reference row: the true number of clusters recorded for each dataset.
        latex_table += r"& \textit{\textnumero{} targets} & "
        for i_dataset, dataset in enumerate(test_datasets):
            sdf = select_rows(test_runs_df, {"dataset": dataset}, allow_missing=False)
            sdf = sdf[~pd.isna(sdf["num_cluster_true"])]
            latex_table += r"& "
            if len(sdf) < 1:
                # No run records the true count for this dataset.
                latex_table += r"   --  "
                continue
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{sdf.iloc[0]['num_cluster_true'].item():.0f}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\midrule" + "\n"
    elif metric_key.endswith("_pred"):
        # Reference block: the same metric computed against the ground-truth
        # labels (the matching "_true" column), one row per encoder.
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(model_group))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"
        for i_group, model in enumerate(list(model_group)):
            model_sh = MODEL2SH.get(model, model)
            if model_sh.endswith(ft_suffix):
                model_sh = f"{model_sh[: -len(ft_suffix)]:<10s}" + r" & \checkmark"
            else:
                model_sh = f"{model_sh:<10s}" + " &"
            latex_table += f"& {model_sh:<23s}"
            for i_dataset, dataset in enumerate(test_datasets):
                latex_table += " &"
                filter1 = {"model": model, "dataset": dataset}
                if model == "timm_vit_base_patch16_224.mae":
                    filter1["dim_reducer"] = "PCA"
                    filter1["pca_variance"] = 0.95
                else:
                    filter1["dim_reducer_man"] = "UMAP"
                    filter1["ndim_reduced_man"] = 50
                    filter1["dim_reducer_man_metric"] = "euclidean"
                sdf = select_rows(test_runs_df, filter1, allow_missing=False)
                sdf = sdf[~pd.isna(sdf[metric_key2])]
                # The G.T. row must report the "_true" metric; it previously
                # reported the median "_pred" value of whichever runs matched.
                my_val = np.nanmedian(sdf[metric_key2]) if len(sdf) > 0 else np.nan
                if dummy:
                    best_results_grouped[dataset][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                is_best_grp = my_val + eps >= np.nanmax(best_results_grouped[dataset][clusterername])
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\midrule" + "\n"

    first_agg = True
    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        clusterername = CLUSTERER2SH.get(clusterer, clusterer)
        my_override_fields = override_fields.copy()
        if first_agg and clusterer == "AgglomerativeClustering" and metric_key != "num_cluster_pred":
            first_agg = False
            my_override_fields["aggclust_dist_thresh"] = None
            clusterername = "AC  w/ C"
        elif clusterer == "AgglomerativeClustering":
            clusterername = "AC w/o C"
            if "aggclust_dist_thresh" in my_override_fields:
                del my_override_fields["aggclust_dist_thresh"]

        if i_clusterer > 0:
            latex_table += r"\midrule" + "\n"

        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(model_group))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"

        for i_group, model in enumerate(list(model_group)):
            model_sh = MODEL2SH.get(model, model)
            if model_sh.endswith(ft_suffix):
                model_sh = f"{model_sh[: -len(ft_suffix)]:<10s}" + r" & \checkmark"
            else:
                model_sh = f"{model_sh:<10s}" + " &"
            latex_table += f"& {model_sh:<23s}"
            for i_dataset, dataset in enumerate(test_datasets):
                latex_table += " &"
                filter1 = {"model": model, "dataset": dataset}
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update(filter1)
                filter2.update(my_override_fields)
                filter2 = fixup_filter(filter2)
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                if len(sdf) < 1:
                    # print(f"No data for {model}-{dataset}-{clusterer}\n{filter2}")
                    cmds.append(filter2command(filter2, partition="test"))
                    if not dummy:
                        # latex_table += r"\multicolumn{1}{c}{--}"
                        latex_table += r"   --  "
                    continue
                if len(sdf) > 1:
                    # Several runs match the filter: report them if they disagree.
                    if (np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6).any():
                        print()
                        print(
                            f"More than one result with {metric_key} values",
                            list(sdf[metric_key]),
                        )
                        print(f"for search {filter2}")
                        dif_cols = find_differing_columns(sdf, config_keys)
                        print(f"columns which differ: {dif_cols}")
                        if dif_cols:
                            for col in dif_cols:
                                print(f"  {col}: {list(sdf[col])}")
                my_val = np.nanmedian(sdf[metric_key])
                if dummy:
                    best_results[dataset].append(my_val)
                    best_results_grouped[dataset][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                # Rank against the non-NaN column values only: np.max over a list
                # containing NaN returns NaN, which silently disabled the
                # highlighting and the shading for the whole column.
                ranked = np.sort([v for v in best_results[dataset] if not np.isnan(v)])
                is_best = my_val + eps >= ranked[-1]
                is_secd = len(ranked) > 1 and my_val + eps >= ranked[-2]
                is_best_grp = my_val + eps >= np.nanmax(best_results_grouped[dataset][clusterername])
                # Shading: 0 at or below the column median, 100 at the column max.
                sc_base = np.median(ranked)
                sc_top = ranked[-1]
                if sc_top > sc_base:
                    sc = 100 * max(0, (my_val - sc_base) / (sc_top - sc_base))
                else:
                    # Column max equals its median: avoid 0/0, leave unshaded.
                    sc = 0
                if colour_bg:
                    latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"

print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

print()
print("Done!")
print()
print(f"Here is your results table for {metric_key}, {backbone}:")
print()
print()
print(latex_table)

### Number of clusters

In [None]:
# Settings for the number-of-clusters analysis. Only `clusterers` is read by
# code in this cell; the display settings are presumably read by later cells.
metric_key = "num_cluster_pred"  # AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
clusterers = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
show_pc = False
show_fmt = "{:4.0f}"
show_commands = False
# NOTE: overridden a few lines below; the effective value is True.
highlight_best = False
use_si_num = True
highlight_best = True
colour_bg = True
eps = 0.005


backbone = "ViT-B"  # "ResNet-50" "ViT-B"
# Manual toggle: the final assignment below is the one that takes effect.
model_group = MODEL_GROUPS[backbone]
# model_group = RESNET50_MODELS + FT_RESNET50_MODELS
model_group = VITB16_MODELS + FT_VITB16_MODELS
test_datasets_grouped = TEST_DATASETS_GROUPED

print(model_group)


######################################
# Get the number of clusters for everything

# Encoders to evaluate, grouped by backbone ("none" = no encoder).
model_groups = {
    "---": ["none"],
    "RN50": RESNET50_MODELS + FT_RESNET50_MODELS,
    "ViT-B": VITB16_MODELS + FT_VITB16_MODELS,
}

# pad_right=False keeps the leaf (the model name) in the last column, so
# [:, -1] extracts the flat, ordered list of model names.
model_groups_flattened = make_flat_hierarchy_from_dict(model_groups, pad_right=False)
model_groups_flattened = np.array(model_groups_flattened)
model_groups_flattened = model_groups_flattened[:, -1]

# Column order: the dataset groups flattened in order.
test_datasets = []
for datagroupname, datagroupset in test_datasets_grouped.items():
    test_datasets.extend(datagroupset)


print("Encoders:")
print(model_groups_flattened)

print("Datasets:")
print(test_datasets)

result_table = build_results_table(
    model_groups_flattened,
    clusterers,
    test_datasets,
    metric_keys=["num_cluster_pred", "num_cluster_true"],
)
# Indexed [model, clusterer, dataset, metric_key] in build_results_table. With two
# metric keys the trailing axis is presumably kept (0: num_cluster_pred,
# 1: num_cluster_true), as the [..., 0] / [..., 1] indexing in the next cells assumes.
print("result_table.shape", result_table.shape)

######################################

In [None]:
# Display the shape of the results array (last axis: num_cluster_pred, num_cluster_true)
result_table.shape

In [None]:
# Range of the predicted number of clusters across clusterers (axis 1), keeping
# a singleton clusterer axis so the result broadcasts against result_table.
result_table_min, result_table_max = (
    reduce_fn(result_table[..., 0], axis=1, keepdims=True) for reduce_fn in (np.nanmin, np.nanmax)
)

In [None]:
# log(predicted / true number of clusters): 0 when the two counts agree
result_table_logratio = np.log(np.divide(result_table[..., 0], result_table[..., 1]))
# Extremes per dataset, reduced over the model and clusterer axes
result_table_logratiomin, result_table_logratiomax = (
    reduce_fn(result_table_logratio, axis=(0, 1)) for reduce_fn in (np.nanmin, np.nanmax)
)

In [None]:
# Largest absolute log-ratio for each dataset, taken over models and clusterers
result_table_logratioabsmax = np.nanmax(abs(result_table_logratio), axis=(0, 1))

In [None]:
# Display the per-dataset maximum |log(pred / true)|
result_table_logratioabsmax

In [None]:
clusterers = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
show_fmt = "{:4.0f}"
show_commands = False
use_si_num = False
highlight_best = True
colour_bg = True
eps = 0.005
merge_model_group_column = True
# Set explicitly (matching DATASET2SH, used for the column headers below) rather
# than relying on whichever grouping an earlier cell left in the namespace.
test_datasets_grouped = TEST_DATASETS_GROUPED


######################################
# Get the number of clusters for everything

model_groups = {
    "": ["none"],
    "RN50": RESNET50_MODELS + FT_RESNET50_MODELS,
    "ViT-B": VITB16_MODELS + FT_VITB16_MODELS,
}

# Keep only the leaf (model name) of each flattened (group, model) entry
model_groups_flattened = make_flat_hierarchy_from_dict(model_groups, pad_right=False)
model_groups_flattened = np.array(model_groups_flattened)
model_groups_flattened = model_groups_flattened[:, -1]

test_datasets = []
for datagroupname, datagroupset in test_datasets_grouped.items():
    test_datasets.extend(datagroupset)


print("Encoders:")
print(model_groups_flattened)

print("Datasets:")
print(test_datasets)

metric_keys = ["num_cluster_pred", "num_cluster_true"]
result_table, cmds = build_results_table(
    model_groups_flattened,
    clusterers,
    test_datasets,
    metric_keys=metric_keys,
    return_cmds=True,
)
# Shaped [models, clusterers, datasets, metric_keys]
print("result_table.shape", result_table.shape)

# log(pred / true): 0 when the predicted number of clusters equals the number of
# GT classes, negative when too few clusters are found, positive when too many.
result_table_logratio = np.log(result_table[..., 0] / result_table[..., 1])
result_table_logratiomin = np.nanmin(result_table_logratio, axis=(0, 1))
result_table_logratiomax = np.nanmax(result_table_logratio, axis=(0, 1))
# Largest |log-ratio| per dataset; used to normalise the cell colour intensity
result_table_logratioabsmax = np.nanmax(np.abs(result_table_logratio), axis=(0, 1))

######################################

metric_key = metric_keys[0]

print("Encoders:")
print(model_groups_flattened)

print("Datasets:")
print(test_datasets)

print()
print()
# One LaTeX table per clusterer, all wrapped in a single landscape environment
latex_table = r"\begin{landscape}" + "\n"
for i_clusterer, clusterer in enumerate(clusterers):
    clusterername = CLUSTERER2SH.get(clusterer, clusterer)

    latex_table += r"\begin{table}" + "\n"
    latex_table += r"\captionsetup{width=.707\linewidth}" + "\n"
    latex_table += r"\caption{" + "\n"
    latex_table += r"\textbf{Number of clusters generated using " + clusterername + ".}\n"
    latex_table += r"}" + "\n"

    latex_table += r"% Results for " + f"{metric_key}, {clusterer}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = metric_key.replace("_", "-") + ":" + clusterer
    latex_table += r"\label{tab:" + label + r"}" + "\n"

    latex_table += r"\resizebox{\columnwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{"
    if not merge_model_group_column:
        latex_table += "l"
    latex_table += "ll" + r"r" * len(test_datasets) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    # Dataset group header row
    if len(test_datasets_grouped) > 1:
        if not merge_model_group_column:
            latex_table += r"& "
        latex_table += f"{'':<11s}" + r" &   "
        for datagroupname, datagroupset in test_datasets_grouped.items():
            latex_table += r" & \multicolumn{" + str(len(datagroupset)) + r"}{c}{" + datagroupname + r"}"
        latex_table += r"\\" + "\n"
        icol = 3
        if not merge_model_group_column:
            icol += 1
        for datagroupname, datagroupset in test_datasets_grouped.items():
            latex_table += r"\cmidrule(l){" + f"{icol}-{icol + len(datagroupset) - 1}" + r"}"
            icol += len(datagroupset)
        latex_table += "\n"
    # Main header row, with actual dataset names
    if merge_model_group_column:
        latex_table += r"\quad "
    else:
        latex_table += r"Arch. & "
    latex_table += f"{'Encoder':<11s}" + r" & FT "
    for dataset in test_datasets:
        latex_table += r"&" + "{:^15s}".format(DATASET2SH.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    latex_table += r"\midrule" + "\n"

    # Ground-truth number of classes (num_cluster_true averaged over every
    # encoder and clusterer). The leading "& " fills the Arch. column when it is
    # not merged into the encoder column; the trailing "& " is the FT column.
    if not merge_model_group_column:
        latex_table += r"& "
    latex_table += r"\textit{\textnumero{} GT classes} & "
    for i_dataset, dataset in enumerate(test_datasets):
        my_val = np.nanmean(result_table[:, :, i_dataset, 1])
        latex_table += r"& "
        latex_table += r"\num{" if use_si_num else r"$"
        latex_table += show_fmt.format(my_val)
        latex_table += r"}" if use_si_num else r"$"
    latex_table += r"\\" + "\n"
    latex_table += r"\midrule" + "\n"

    # Index into the model axis of result_table, running across all groups
    i_model_o = -1
    for i_group, group in enumerate(model_groups):
        if i_group > 0:
            latex_table += r"\midrule" + "\n"
        if merge_model_group_column:
            if not group:
                latex_table += r"\quad "
            else:
                latex_table += r"\textbf{" + group + r"} --- "
        elif not group:
            latex_table += "---" + "\n"
        else:
            latex_table += group + "\n"
        for i_model, model in enumerate(list(model_groups[group])):
            i_model_o += 1
            model_sh = MODEL2SH.get(model, model)
            if model_sh.endswith(" [FT]"):
                model_sh = f"{model_sh[:-4]:<10s}" + r" & \checkmark"
            else:
                model_sh = f"{model_sh:<10s}" + " &"
            if merge_model_group_column:
                if i_model > 0:
                    latex_table += r"\quad "
            else:
                # Step past the Arch. column, which holds the group name
                latex_table += r"& "
            latex_table += f"{model_sh:<23s}"
            for i_dataset, dataset in enumerate(test_datasets):
                latex_table += " &"
                my_val = result_table[i_model_o, i_clusterer, i_dataset, 0]
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                my_logratio = result_table_logratio[i_model_o, i_clusterer, i_dataset]
                # "Best" means closest to the GT count (smallest |log-ratio|),
                # over all encoders and clusterers, or within this clusterer only
                is_best = np.abs(my_logratio) - eps <= np.nanmin(np.abs(result_table_logratio[:, :, i_dataset]))
                is_best_grp = np.abs(my_logratio) - eps <= np.nanmin(
                    np.abs(result_table_logratio[:, i_clusterer, i_dataset])
                )
                if colour_bg:
                    # cbr for too few clusters (negative log-ratio), cbg for too
                    # many; intensity relative to the dataset's largest deviation.
                    # Skip the colour when it is undefined (0/0, or an infinite
                    # log-ratio) instead of writing "nan"/"inf" into the LaTeX.
                    sc = 100.0 * my_logratio / result_table_logratioabsmax[i_dataset]
                    if np.isfinite(sc):
                        latex_table += r"\cc{" + (r"cbr!" if sc < 0 else r"cbg!") + f"{abs(sc):.0f}" + "}"

                latex_table += r"\num{" if use_si_num else r"$"
                if highlight_best and is_best:
                    latex_table += r"\tcf{"
                if highlight_best and is_best_grp:
                    latex_table += r"\tcg{"
                latex_table += show_fmt.format(my_val)
                if highlight_best and is_best:
                    latex_table += r"}"
                if highlight_best and is_best_grp:
                    latex_table += r"}"
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"
    latex_table += r"\end{table}" + "\n"
latex_table += r"\end{landscape}"

print()
print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

print()
print("Done!")
print()
print(f"Here are your results tables for {metric_key}, {', '.join(clusterers)}:")
print()
print()
print(latex_table)

In [None]:
# Display the log-ratio array shape: [models, clusterers, datasets]
result_table_logratio.shape

In [None]:
# Display the absmax shape: one value per dataset (models and clusterers reduced)
result_table_logratioabsmax.shape

In [None]:
# Display log-ratios normalised by each dataset's max |log-ratio| (broadcast over the
# last, dataset axis); finite entries lie in [-1, 1]
result_table_logratio / result_table_logratioabsmax

## Single clusterer, grouped encoders, grouped datasets

In [None]:
# Display the module-level list of clusterer names
CLUSTERERS

In [None]:
clusterer = "HDBSCAN"  # "AC w/o C"
metric_key = "AMI"  # AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
show_pc = True
show_fmt = "{:4.0f}"
show_commands = False
highlight_best = True
use_si_num = False
eps = 0.005
override_fields = {
    # "predictions_dir": "y_pred",
}

model_groups = {
    "---": ["none"],
    "RN50": RESNET50_MODELS + FT_RESNET50_MODELS,
    "ViT-B": VITB16_MODELS + FT_VITB16_MODELS,
}

test_datasets_grouped, dataset2sh = TEST_DATASETS_GROUPED, DATASET2SH
# test_datasets_grouped, dataset2sh = {"in9_bg_challenge": IN9_DATASETS + ["in9bggap"]}, IN92SH  # IN9 special

# This cell renders a single `clusterer`; it deliberately leaves the module-level
# CLUSTERERS list alone (reassigning it here leaked into every later cell).
if metric_key == "num_cluster_pred":
    show_pc = False
    show_fmt = "{:4.0f}"
    highlight_best = False
    use_si_num = True
if metric_key.startswith("silhouette"):
    show_pc = False
    show_fmt = "{:5.2f}"

print(model_groups)


def _best_and_second(value, candidates, tol):
    """Return ``(is_best, is_second)`` for ``value`` among ``candidates``.

    ``value`` counts as best (second best) when it is within ``tol`` of the
    largest (second largest) candidate. NaN candidates are ignored, so one
    missing result does not disable highlighting for a whole column.
    """
    vals = np.asarray(candidates, dtype=float)
    vals = np.sort(vals[~np.isnan(vals)])
    if vals.size == 0:
        return False, False
    is_best = bool(value + tol >= vals[-1])
    is_second = bool(vals.size > 1 and value + tol >= vals[-2])
    return is_best, is_second


test_datasets = []
for datagroupname, datagroupset in test_datasets_grouped.items():
    test_datasets.extend(datagroupset)

best_results = {k: [] for k in test_datasets}
best_results_grouped = {k: defaultdict(list) for k in test_datasets}
# Ground-truth reference rows are tracked separately so they cannot affect the
# best-value highlighting of the clusterer results.
best_gt_grouped = {k: defaultdict(list) for k in test_datasets}

# Two passes: the first (dummy) pass only collects values, the second renders
# the table using them to decide which cells to highlight.
for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{metric_key}, {clusterer}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = clusterer
    if metric_key == "AMI":
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"\resizebox{\columnwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{lll" + r"r" * len(test_datasets) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    # Dataset group header row
    if len(test_datasets_grouped) > 1:
        latex_table += r"& " + f"{'':<11s}" + r" &   "
        for datagroupname, datagroupset in test_datasets_grouped.items():
            latex_table += r" & \multicolumn{" + str(len(datagroupset)) + r"}{c}{" + datagroupname + r"}"
        latex_table += r"\\" + "\n"
        icol = 4
        for datagroupname, datagroupset in test_datasets_grouped.items():
            latex_table += r"\cmidrule(l){" + f"{icol}-{icol + len(datagroupset) - 1}" + r"}"
            icol += len(datagroupset)
        latex_table += "\n"
    # Main header row, with actual dataset names
    latex_table += r"Arch. & " + f"{'Encoder':<11s}" + r" & FT "
    for dataset in test_datasets:
        latex_table += r"&" + "{:^15s}".format(dataset2sh.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    # Table contents
    latex_table += r"\midrule" + "\n"
    if metric_key == "num_cluster_pred":
        # Number of GT classes per dataset; the trailing "&" leaves the FT column empty
        latex_table += r"& Num targets &"
        for i_dataset, dataset in enumerate(test_datasets):
            sdf = select_rows(test_runs_df, {"dataset": dataset}, allow_missing=False)
            sdf = sdf[~pd.isna(sdf["num_cluster_true"])]
            latex_table += r"& "
            if len(sdf) == 0:
                latex_table += r"   --  "
                continue
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{sdf.iloc[0]['num_cluster_true'].item()}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\midrule" + "\n"
    elif metric_key.endswith("_pred"):
        # Reference rows: the same metric computed with the ground-truth labels
        # (metric_key2), one row per encoder, under a rotated "G.T." label.
        metric_key2 = metric_key.replace("_pred", "_true")
        n_gt_rows = sum(len(models) for models in model_groups.values())
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(n_gt_rows)
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + "G.T."
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"
        for gt_group, gt_models in model_groups.items():
            for model in gt_models:
                model_sh = MODEL2SH.get(model, model)
                if model_sh.endswith(" [FT]"):
                    model_sh = f"{model_sh[:-4]:<10s}" + r" & \checkmark"
                else:
                    model_sh = f"{model_sh:<10s}" + " &"
                latex_table += f"& {model_sh:<23s}"
                for i_dataset, dataset in enumerate(test_datasets):
                    latex_table += " &"
                    filter1 = {"model": model, "dataset": dataset}
                    if model == "timm_vit_base_patch16_224.mae":
                        filter1["dim_reducer"] = "PCA"
                        filter1["pca_variance"] = 0.95
                    else:
                        filter1["dim_reducer_man"] = "UMAP"
                        filter1["ndim_reduced_man"] = 50
                        filter1["dim_reducer_man_metric"] = "euclidean"
                    sdf = select_rows(test_runs_df, filter1, allow_missing=False)
                    sdf = sdf[~pd.isna(sdf[metric_key2])]
                    my_val = np.nanmedian(sdf[metric_key2]) if len(sdf) > 0 else np.nan
                    if dummy:
                        best_gt_grouped[dataset][gt_group].append(my_val)
                        continue
                    if np.isnan(my_val):
                        latex_table += r"   --  "
                        continue
                    is_best_grp = _best_and_second(my_val, best_gt_grouped[dataset][gt_group], eps)[0]
                    if show_pc:
                        my_val = my_val * 100
                    latex_table += r"\num{" if use_si_num else r"$"
                    if highlight_best and is_best_grp:
                        latex_table += r"\tcg{"
                    latex_table += show_fmt.format(my_val)
                    if highlight_best and is_best_grp:
                        latex_table += r"}"
                    latex_table += r"}" if use_si_num else r"$"
                latex_table += r" \\" + "\n"
        latex_table += r"\midrule" + "\n"

    my_override_fields = override_fields.copy()
    if clusterer == "AgglomerativeClustering":
        if metric_key != "num_cluster_pred":
            # "AC w/ C" variant: run without a distance threshold
            my_override_fields["aggclust_dist_thresh"] = None
        else:
            my_override_fields.pop("aggclust_dist_thresh", None)

    for i_group, group in enumerate(model_groups):
        if i_group > 0:
            latex_table += r"\midrule" + "\n"
        latex_table += group + "\n"
        for i_model, model in enumerate(list(model_groups[group])):
            model_sh = MODEL2SH.get(model, model)
            if model_sh.endswith(" [FT]"):
                model_sh = f"{model_sh[:-4]:<10s}" + r" & \checkmark"
            else:
                model_sh = f"{model_sh:<10s}" + " &"
            latex_table += f"& {model_sh:<23s}"
            for i_dataset, dataset in enumerate(test_datasets):
                latex_table += " &"
                filter1 = {"model": model, "dataset": dataset}
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update(filter1)
                filter2.update(my_override_fields)
                filter2 = fixup_filter(filter2)
                if dataset == "in9bggap":
                    filter2["dataset"] = "in9mixedsame"
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                if len(sdf) < 1:
                    # Missing result: record the command that would generate it
                    cmds.append(filter2command(filter2, partition="test"))
                    if not dummy:
                        latex_table += r"   --  "
                    continue
                if len(sdf) > 1:
                    if sum(np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6) > 0:
                        print()
                        print(
                            f"More than one result with {metric_key} values",
                            list(sdf[metric_key]),
                        )
                        print(f"for search {filter2}")
                        dif_cols = find_differing_columns(sdf, config_keys)
                        print(f"columns which differ: {dif_cols}")
                        if dif_cols:
                            for col in dif_cols:
                                print(f"  {col}: {list(sdf[col])}")
                my_val = np.nanmedian(sdf[metric_key])
                if dataset == "in9bggap":
                    # Background gap: mixed-same minus mixed-rand
                    filter_mr = dict(filter2, dataset="in9mixedrand")
                    sdf = select_rows(test_runs_df, filter_mr, allow_missing=False)
                    my_val = my_val - np.nanmedian(sdf[metric_key])
                if dummy:
                    best_results[dataset].append(my_val)
                    best_results_grouped[dataset][group].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                is_best, is_secd = _best_and_second(my_val, best_results[dataset], eps)
                is_best_grp = _best_and_second(my_val, best_results_grouped[dataset][group], eps)[0]
                is_best_grp = is_best_grp and len(best_results_grouped[dataset][group]) > 1
                sc_base = 0  # np.nanmedian(best_results[dataset])
                sc_top = 1  # np.max(best_results[dataset])
                sc = 100 * max(0, (my_val - sc_base) / (sc_top - sc_base))
                latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if highlight_best and is_best:
                    latex_table += r"\tcf{"
                elif highlight_best and is_secd:
                    latex_table += r"\tcs{"
                if highlight_best and is_best_grp:
                    latex_table += r"\tcg{"
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"

print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

print()
print("Done!")
print()
print(f"Here is your results table for {metric_key}, {clusterer}:")
print()
print()
print(latex_table)

In [None]:
# Show every command needed to generate the missing datapoints, one per line
if cmds:
    print(*cmds, sep="\n")

### Refactored to use a results matrix

In [None]:
# Settings: silhouette score of the ground-truth labels, on PCA-reduced,
# z-scored embeddings. Only KMeans is selected; the full candidate list is
# ["KMeans", "SpectralClustering", "AC w/ C", "AC w/o C", "AffinityPropagation", "HDBSCAN"]
clusterers = ["KMeans"]
metric_key = "silhouette-euclidean_true"  # AMI  AMI_clus  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
use_rank = False
show_pc = True
show_fmt = "{:4.0f}"
highlight_best = True
use_si_num = False
eps = 0.005
merge_model_group_column = True

# override_fields = {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None", "pca_variance": None, "ndim_reduced": None, "zscore": False}
override_fields = {
    "dim_reducer_man": "None",
    "ndim_reduced_man": None,
    "dim_reducer": "PCA",
    "pca_variance": 0.9,
    "ndim_reduced": None,
    "zscore": True,
}
# override_fields = {}
# Fixed lower end of the colour scale (instead of the per-dataset median)
fixed_sc_base = 0

In [None]:
# Settings: AMI for spectral clustering alone, with no override fields and a
# data-driven (median) colour-scale base
clusterers = ["SpectralClustering"]
metric_key = "AMI"  # AMI  AMI_clus  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
use_rank, show_pc = False, True
show_fmt, eps = "{:4.0f}", 0.005
highlight_best, use_si_num = True, False
merge_model_group_column = True

override_fields, fixed_sc_base = {}, None

In [None]:
# Settings: AMI for Louvain community detection alone, with no override fields
# and a data-driven (median) colour-scale base
clusterers = ["LouvainCommunities"]
metric_key = "AMI"  # AMI  AMI_clus  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
use_rank, show_pc = False, True
show_fmt, eps = "{:4.0f}", 0.005
highlight_best, use_si_num = True, False
merge_model_group_column = True

override_fields, fixed_sc_base = {}, None

In [None]:
# Settings: AMI averaged over the six main clusterers, with no override fields
# and a data-driven (median) colour-scale base
clusterers = list(
    (
        "KMeans",
        "SpectralClustering",
        "AC w/ C",
        "AC w/o C",
        "AffinityPropagation",
        "HDBSCAN",
    )
)
metric_key = "AMI"  # AMI  AMI_clus  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
use_rank, show_pc = False, True
show_fmt, eps = "{:4.0f}", 0.005
highlight_best, use_si_num = True, False
merge_model_group_column = True

override_fields, fixed_sc_base = {}, None

In [None]:
# override_fields = {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None", "pca_variance": None, "ndim_reduced": None, "zscore": False}
# override_fields = {"dim_reducer_man": "None", "ndim_reduced_man": None, "dim_reducer": "PCA", "pca_variance": 0.9, "ndim_reduced": None, "zscore": True}
# fixed_sc_base = 0

if metric_key == "num_cluster_pred":
    clusterers = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc = False
    show_fmt = "{:4.0f}"
    highlight_best = False
    use_si_num = True

if metric_key.startswith("silhouette"):
    show_pc = False
    show_fmt = "{:5.2f}"


model_groups = {
    "": ["none"],
    "RN50": RESNET50_MODELS + FT_RESNET50_MODELS,
    "ViT-B": VITB16_MODELS + FT_VITB16_MODELS,
}

test_datasets_grouped, dataset2sh = TEST_DATASETS_GROUPED, DATASET2SH


if False:
    # IN9 table
    model_groups = {
        # "": ["none"],
        "RN50": RESNET50_MODELS[1:] + FT_RESNET50_MODELS,
        "ViT-B": VITB16_MODELS[1:] + FT_VITB16_MODELS,
    }
    test_datasets_grouped, dataset2sh = {"in9_bg_challenge": IN9_DATASETS + ["in9bggap"]}, IN92SH


if len(clusterers) == 1:
    clustererstr = clusterers[0]
    if metric_key.endswith("_true"):
        clustererstr = "GT"
else:
    clustererstr = f"{len(clusterers)}c-avg"

model_groups_flattened = make_flat_hierarchy_from_dict(model_groups, pad_right=False)
model_groups_flattened = np.array(model_groups_flattened)
model_groups_flattened = model_groups_flattened[:, -1]

test_datasets = []
for datagroupname, datagroupset in test_datasets_grouped.items():
    test_datasets.extend(datagroupset)

# test_datasets = [d for d in test_datasets if "inaturalist" not in d]
# test_datasets = [d for d in test_datasets if "bioscan" not in d]

print("Encoders:")
print(model_groups_flattened)
print()
print("Datasets:")
print(test_datasets)
print()
print("Clusterers:")
print(clusterers)
print()

result_table, cmds = build_results_table(
    model_groups_flattened,
    clusterers,
    test_datasets,
    metric_keys=metric_key,
    override_fields=override_fields,
    return_cmds=True,
)
# Shaped [models, clusterers, datasets]
print("result_table.shape", result_table.shape)

# Remove clusterer-dataset combos which are NaN for any model
result_table[:, np.any(np.isnan(result_table), axis=0)] = np.nan

# Take mean over clusterers
result_table = np.nanmean(result_table, axis=1)
# Shaped [models, datasets]

print("result_table.shape", result_table.shape)

if use_rank:
    result_table_r = np.argsort(result_table, -1)[::-1, :] + 1
else:
    result_table_r = result_table


print(model_groups)


best_results = {k: [] for k in test_datasets}
best_results_grouped = {k: defaultdict(list) for k in test_datasets}

for dummy in [True, False]:
    latex_table = r"% Results for " + f"{metric_key}, {clustererstr}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = clustererstr
    if metric_key == "AMI":
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"\resizebox{\columnwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{"
    if not merge_model_group_column:
        latex_table += "l"
    latex_table += "ll" + r"r" * len(test_datasets) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    # Begin dataset group header row
    if len(test_datasets_grouped) > 1:
        if not merge_model_group_column:
            latex_table += r"& "
        latex_table += f"{'':<11s}" + r" &   "
        for datagroupname, datagroupset in test_datasets_grouped.items():
            latex_table += r" & \multicolumn{" + str(len(datagroupset)) + r"}{c}{" + datagroupname + r"}"
        latex_table += r"\\" + "\n"
        icol = 3
        if not merge_model_group_column:
            icol += 1
        for datagroupname, datagroupset in test_datasets_grouped.items():
            latex_table += r"\cmidrule(l){" + f"{icol}-{icol + len(datagroupset) - 1}" + r"}"
            icol += len(datagroupset)
        latex_table += "\n"
    # Begin main header row, with actual dataset names
    if merge_model_group_column:
        latex_table += r"\quad "
    else:
        latex_table += r"Arch. & "
    latex_table += f"{'Encoder':<11s}" + r" & FT "
    for dataset in test_datasets:
        latex_table += r"&" + "{:^15s}".format(dataset2sh.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    # Begin table contents
    latex_table += r"\midrule" + "\n"
    i_model_o = -1
    if metric_key == "num_cluster_pred":
        latex_table += r"& Num targets"
        for i_dataset, dataset in enumerate(test_datasets):
            sdf = select_rows(test_runs_df, {"dataset": dataset}, allow_missing=False)
            sdf = sdf[~pd.isna(sdf["num_cluster_true"])]
            latex_table += r"& "
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{sdf.iloc[0]['num_cluster_true'].item()}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\midrule" + "\n"
    elif metric_key.endswith("_pred"):
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(model_groups[group]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"
        for i_model, model in enumerate(list(model_groups[group])):
            i_model_o += 1
            latex_table += f"& {MODEL2SH[model]:<10s}"
            for i_dataset, dataset in enumerate(test_datasets):
                latex_table += " &"
                my_val = result_table[i_model_o, i_dataset]
                if sum(sdf[metric_key2] != my_val) > 0:
                    pass
                if dummy:
                    best_results_grouped[dataset][group].append(my_val)
                    continue
                is_best_grp = my_val + eps >= np.max(best_results_grouped[dataset][group])
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\midrule" + "\n"

    i_model_o = -1
    for i_group, group in enumerate(model_groups):
        if i_group > 0:
            latex_table += r"\midrule" + "\n"
        if merge_model_group_column:
            if not group:
                latex_table += r"\quad "
            else:
                latex_table += r"\textbf{" + group + r"} --- "
        elif not group:
            latex_table += "---" + "\n"
        else:
            latex_table += group + "\n"
        for i_model, model in enumerate(list(model_groups[group])):
            i_model_o += 1
            model_sh = MODEL2SH.get(model, model)
            if model_sh.endswith(" [FT]"):
                model_sh = f"{model_sh[:-4]:<10s}" + r" & \checkmark"
            else:
                model_sh = f"{model_sh:<10s}" + " &"
            if merge_model_group_column and i_model > 0:
                latex_table += r"\quad "
            latex_table += f"{model_sh:<23s}"
            for i_dataset, dataset in enumerate(test_datasets):
                latex_table += " &"
                my_val = result_table[i_model_o, i_dataset]
                if dummy:
                    best_results[dataset].append(my_val)
                    best_results_grouped[dataset][group].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                is_best = my_val + eps >= np.max(best_results[dataset])
                if len(best_results[dataset]) > 1:
                    is_secd = my_val + eps >= np.sort(best_results[dataset])[-2]
                else:
                    is_secd = False
                is_best_grp = my_val + eps >= np.max(best_results_grouped[dataset][group])
                is_best_grp &= len(best_results_grouped[dataset][group]) > 1
                sc_base = np.nanmedian(best_results[dataset])
                if fixed_sc_base is not None:
                    sc_base = fixed_sc_base
                sc_top = np.max(best_results[dataset])
                sc = 100 * max(0, (my_val - sc_base) / (sc_top - sc_base))
                if sc_top >= sc_base:
                    latex_table += r"\cc{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    # latex_table += "     "
                    pass
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    # latex_table += "     "
                    pass
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"

# Report how many runs are still missing, then emit the finished LaTeX table.
print(f"\nThere are {len(cmds)} commands to execute to generate missing datapoints")
print("\nDone!\n")
print(f"Here is your results table for {metric_key}, {clustererstr}:\n\n")
print(latex_table)

In [None]:
# Print each command that still needs to be run to fill in missing datapoints.
for cmd in cmds:
    print(cmd)

### Clusterer correlations

In [None]:
# Clusterer correlations: gather `metric_key` for every (model, clusterer,
# dataset) combination and reshape it so that each row is one (model, dataset)
# pair and each column is one clusterer, ready for pairwise scatter plots.
metric_key = "AMI"  # num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
clusterers = [
    "KMeans",
    "SpectralClustering",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]

model_groups = {
    "---": ["none"],
    "RN50": RESNET50_MODELS + FT_RESNET50_MODELS,
    "ViT-B": VITB16_MODELS + FT_VITB16_MODELS,
}

test_datasets_grouped = TEST_DATASETS_GROUPED

# Short label describing the clusterer selection (not used further in this cell).
if len(clusterers) == 1:
    clustererstr = clusterers[0]
else:
    clustererstr = f"{len(clusterers)}c-avg"

# Keep the last column of each flattened hierarchy entry; it is passed on
# below as the model name.
model_groups_flattened = make_flat_hierarchy_from_dict(model_groups, pad_right=False)
model_groups_flattened = np.array(model_groups_flattened)
model_groups_flattened = model_groups_flattened[:, -1]

# Concatenate all dataset groups into one flat list, preserving group order.
test_datasets = []
for datagroupname, datagroupset in test_datasets_grouped.items():
    test_datasets.extend(datagroupset)

result_table, cmds = build_results_table(
    model_groups_flattened,
    clusterers,
    test_datasets,
    metric_keys=metric_key,
    return_cmds=True,
)
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
print()
print("result_table.shape", result_table.shape)  # Shaped [models, clusterers, datasets]

# Label grids laid out like result_table ([models, clusterers, datasets]):
# xx = architecture label of the model, yy = clusterer, zz = dataset.
xx, yy, zz = np.meshgrid(
    [MODEL2SH_ARCH[k] for k in model_groups_flattened],
    clusterers,
    test_datasets,
    indexing="ij",
)
print("xx.shape", xx.shape)

result_table = np.swapaxes(result_table, 1, 2)  # Shaped [models, datasets, clusterers]
xx = np.swapaxes(xx, 1, 2)
yy = np.swapaxes(yy, 1, 2)
zz = np.swapaxes(zz, 1, 2)
result_table = np.reshape(result_table, (-1, result_table.shape[-1]))  # Shaped [models * datasets, clusterers]
xx = np.reshape(xx, (-1, xx.shape[-1]))
yy = np.reshape(yy, (-1, yy.shape[-1]))
zz = np.reshape(zz, (-1, zz.shape[-1]))
print("result_table.shape", result_table.shape)
print("xx.shape", xx.shape)

# One column per clusterer (short display names). Each row is labelled with the
# architecture of its model and its dataset; within a row these labels are the
# same in every column, so column 0 is used.
result_df = pd.DataFrame(data=result_table, columns=[CLUSTERER2SH.get(c, c) for c in clusterers])
result_df["encoder"] = xx[:, 0]
result_df["dataset"] = zz[:, 0]

In [None]:
# Print each command that still needs to be run to fill in missing datapoints.
for cmd in cmds:
    print(cmd)

In [None]:
# Show the DataFrame built in the previous cell (last expression auto-displays).
result_df

In [None]:
# Pairwise scatter plots between all numeric columns of result_df.
# seaborn.set_context("paper", rc={"axes.labelsize": 12})
seaborn.set_context("paper", font_scale=1.4)
seaborn.pairplot(result_df)

# plt.savefig writes the current pyplot figure, i.e. the pairplot just created.
plt.savefig(os.path.join(FIGS_DIR, f"scatter-clusterers_{metric_key}.pdf"), bbox_inches="tight")

In [None]:
# Pairwise scatter plots of result_df with enlarged axis labels and histogram
# diagonals. No hue is used here, so no colour palette is needed.
seaborn.set_context("paper", rc={"axes.labelsize": 25})
seaborn.pairplot(result_df, diag_kind="hist")

plt.savefig(
    os.path.join(FIGS_DIR, f"scatter-clusterers_{metric_key}_biglab.pdf"),
    bbox_inches="tight",
)

In [None]:
# Pairwise scatter plots of result_df, coloured by the "encoder" column.
seaborn.set_context("paper", font_scale=1.4)
# seaborn.set_context("paper", rc={"axes.labelsize": 22})
# Map each label value of MODEL2SH_ARCH to that model's colour. If several models
# share a label, the colour of the last one iterated wins.
palette = {v: MODEL2COLORRGB[k] for k, v in MODEL2SH_ARCH.items()}
seaborn.pairplot(result_df, hue="encoder", palette=palette, diag_kind="hist")

plt.savefig(
    os.path.join(FIGS_DIR, f"scatter-clusterers_{metric_key}_col-enc.pdf"),
    bbox_inches="tight",
)

In [None]:
# Pairwise scatter plots of result_df, coloured by the "dataset" column.
# seaborn.set_context("paper", rc={"axes.labelsize": 12})
seaborn.set_context("paper", font_scale=1.4)
seaborn.pairplot(result_df, hue="dataset")

plt.savefig(
    os.path.join(FIGS_DIR, f"scatter-clusterers_{metric_key}_col-dataset.pdf"),
    bbox_inches="tight",
)

### Encoder correlation

In [None]:
# Encoder correlations: gather `metric_key` for every (model, clusterer,
# dataset) combination and reshape it so that each row is one
# (dataset, clusterer) setting and each column is one encoder, ready for
# pairwise scatter plots comparing encoders.
metric_key = "AMI"  # num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
clusterers = [
    "KMeans",
    "SpectralClustering",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]

model_groups = {
    "---": ["none"],
    "RN50": RESNET50_MODELS + FT_RESNET50_MODELS,
    "ViT-B": VITB16_MODELS + FT_VITB16_MODELS,
}

test_datasets_grouped = TEST_DATASETS_GROUPED

# Short label describing the clusterer selection (not used further in this cell).
if len(clusterers) == 1:
    clustererstr = clusterers[0]
else:
    clustererstr = f"{len(clusterers)}c-avg"

# Keep the last column of each flattened hierarchy entry; it is passed on
# below as the model name.
model_groups_flattened = make_flat_hierarchy_from_dict(model_groups, pad_right=False)
model_groups_flattened = np.array(model_groups_flattened)
model_groups_flattened = model_groups_flattened[:, -1]

# Concatenate all dataset groups into one flat list, preserving group order.
test_datasets = []
for datagroupname, datagroupset in test_datasets_grouped.items():
    test_datasets.extend(datagroupset)

result_table, cmds = build_results_table(
    model_groups_flattened,
    clusterers,
    test_datasets,
    metric_keys=metric_key,
    return_cmds=True,
)
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
print()
print("result_table.shape", result_table.shape)  # Shaped [models, clusterers, datasets]

# Label grids laid out like result_table ([models, clusterers, datasets]):
# xx = architecture label of the model, yy = clusterer, zz = dataset.
xx, yy, zz = np.meshgrid(
    [MODEL2SH_ARCH[k] for k in model_groups_flattened],
    clusterers,
    test_datasets,
    indexing="ij",
)
print("xx.shape", xx.shape)

# swapaxes(0, 2) reverses the axis order: [models, clusterers, datasets]
# becomes [datasets, clusterers, models].
result_table = np.swapaxes(result_table, 0, 2)  # Shaped [datasets, clusterers, models]
xx = np.swapaxes(xx, 0, 2)
yy = np.swapaxes(yy, 0, 2)
zz = np.swapaxes(zz, 0, 2)
result_table = np.reshape(result_table, (-1, result_table.shape[-1]))  # Shaped [datasets * clusterers, models]
xx = np.reshape(xx, (-1, xx.shape[-1]))
yy = np.reshape(yy, (-1, yy.shape[-1]))
zz = np.reshape(zz, (-1, zz.shape[-1]))
print("result_table.shape", result_table.shape)
print("xx.shape", xx.shape)

# One column per encoder. Each row is one (dataset, clusterer) setting, so it
# is labelled by clusterer and dataset only. A per-row "encoder" label would
# be meaningless here: xx[:, 0] would just repeat the first model's label.
result_df = pd.DataFrame(data=result_table, columns=model_groups_flattened)
result_df["clusterer"] = yy[:, 0]
result_df["dataset"] = zz[:, 0]

In [None]:
# Show the DataFrame built in the previous cell (last expression auto-displays).
result_df

In [None]:
# Pairwise scatter plots between all numeric columns of result_df
# (presumably one column per encoder, as built in the previous cell).
# seaborn.set_context("paper", rc={"axes.labelsize": 12})
seaborn.set_context("paper", font_scale=1.4)
seaborn.pairplot(result_df)

# plt.savefig(os.path.join(FIGS_DIR, f"scatter-encoders_{metric_key}.pdf"), bbox_inches="tight")

## Correlation between AMI and Silhouette

In [None]:
# Scatter metric_key1 (x) against metric_key2 (y) for every
# (clusterer, dataset, encoder) combination, one subplot per backbone. Each
# combination uses the BEST_PARAMS for that clusterer/encoder. Prints Pearson
# correlation coefficients per clusterer, their average, and over all points.
metric_key1 = "silhouette-euclidean_pred"  # silhouette-euclidean_pred | silhouette-og-euclidean_pred
metric_key2 = "AMI"

# Extra filter fields applied on top of the selected parameters (none by default).
override_fields = {}

backbones = ["ResNet-50", "ViT-B"]
# NOTE: rebinds the notebook-level CLUSTERERS list.
CLUSTERERS = [
    "KMeans",
    "SpectralClustering",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]

fig, ax = plt.subplots(1, len(backbones), sharey=True, figsize=(6, 3))

for i_backbone, backbone in enumerate(backbones):
    my_valx_method = {clusterer: [] for clusterer in CLUSTERERS}
    my_valy_method = {clusterer: [] for clusterer in CLUSTERERS}

    print(backbone)
    print(CLUSTERERS)
    print(TEST_DATASETS)
    print(MODEL_GROUPS[backbone])
    print()

    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        for i_dataset, dataset in enumerate(TEST_DATASETS):
            for i_model, model in enumerate(list(MODEL_GROUPS[backbone])):
                # Leave out these two MAE checkpoints and any "random"-prefixed encoder.
                if model in [
                    "timm_vit_base_patch16_224.mae",
                    "mae_pretrain_vit_base_global",
                ]:
                    continue
                if model.startswith("random"):
                    continue
                # Defaults, overridden by the best params for this clusterer/model,
                # then pinned to this model and dataset (plus any overrides).
                filter1 = {"model": model, "dataset": dataset}
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update(filter1)
                filter2.update(override_fields)
                filter2 = fixup_filter(filter2)
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                # Median over all matching runs, ignoring NaNs.
                my_valx_method[clusterer].append(np.nanmedian(sdf[metric_key1]))
                my_valy_method[clusterer].append(np.nanmedian(sdf[metric_key2]))

    my_valx_method = {k: np.array(v) for k, v in my_valx_method.items()}
    my_valy_method = {k: np.array(v) for k, v in my_valy_method.items()}
    my_valx_overall = np.concatenate([my_valx_method[clusterer] for clusterer in CLUSTERERS])
    my_valy_overall = np.concatenate([my_valy_method[clusterer] for clusterer in CLUSTERERS])
    # One RGB colour per point, taken from its clusterer (black if unknown).
    my_cols = np.concatenate(
        [
            np.tile(
                CLUSTERER2COLORRGB.get(clusterer, (0.0, 0.0, 0.0)),
                [len(my_valx_method[clusterer]), 1],
            )
            for clusterer in CLUSTERERS
        ]
    )
    # Shuffle drawing order so that no clusterer's points systematically sit on top.
    indices = np.arange(len(my_valx_overall))
    np.random.shuffle(indices)
    ax[i_backbone].scatter(
        my_valx_overall[indices],
        my_valy_overall[indices],
        color=my_cols[indices],
        s=20,
        alpha=0.5,
    )
    ax[i_backbone].set_xlabel(r"$S$" if metric_key1.startswith("silhouette") else metric_key1)
    if i_backbone == 0:
        ax[i_backbone].set_ylabel(metric_key2)
    ax[i_backbone].set_xlim(-1.05, 1.05)
    # nanmax: a NaN median must not mask the true maximum (builtin max() on an
    # array containing NaN depends on element order).
    ax[i_backbone].set_ylim(-0.05, max(np.nanmax(my_valy_overall), 0.95))
    ax[i_backbone].set_title(backbone)
    print(f"{backbone:<20s} Correlation coef")
    cors = []
    for clusterer in CLUSTERERS:
        sel = (~np.isnan(my_valx_method[clusterer])) & (~np.isnan(my_valy_method[clusterer]))
        cor = np.corrcoef(my_valx_method[clusterer][sel], my_valy_method[clusterer][sel])[0, 1]
        cors.append(cor)
        print(f"{clusterer:<20s} {cor:.4f}")
    print(f"{'Average':<20s} {np.nanmean(cors):.4f}")
    sel = (~np.isnan(my_valx_overall)) & (~np.isnan(my_valy_overall))
    cor = np.corrcoef(my_valx_overall[sel], my_valy_overall[sel])[0, 1]
    print(f"{'Overall':<20s} {cor:.4f}")
    print()
    ax[i_backbone].text(-0.85, 0.85, f"$r={cor:.2f}$")
    # nanmean, matching the printed "Average", so that one undefined
    # per-clusterer correlation does not turn the annotation into "nan".
    ax[i_backbone].text(-0.85, 0.75, r"$\bar{r}=" + f"{np.nanmean(cors):.2f}$")

# Empty line artists used only as legend handles.
label_fn = lambda c, marker: plt.plot([], [], color=c, ls="None", marker=marker, linewidth=6)[0]  # noqa:E731
# Same black fallback as the scatter colours, so legend and points agree.
handles = [label_fn(CLUSTERER2COLORRGB.get(clusterer, (0.0, 0.0, 0.0)), "o") for clusterer in CLUSTERERS]
data_labels = [CLUSTERER2SH.get(c, c) for c in CLUSTERERS]
ax[-1].legend(handles, data_labels, loc="center left", bbox_to_anchor=(1, 0.5))

fig.savefig(
    os.path.join(FIGS_DIR, f"scatter__{metric_key1}__{metric_key2}.pdf"),
    bbox_inches="tight",
)

## Rankings

In [None]:
# Rank encoders (within each clusterer/dataset setting) and clusterers (within
# each encoder/dataset setting) by metric_key1, separately for each backbone.
# Mean +/- std rank is drawn as horizontal bars: one figure for encoders
# (figenc) and one for clusterers (figclus).
exclude_random = True

metric_key1 = "AMI"

backbones = ["ResNet-50", "ViT-B"]
CLUSTERERS = [
    "KMeans",
    "SpectralClustering",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]
test_datasets = TEST_DATASETS

figenc, axenc = plt.subplots(1, 2, figsize=(6, 2))
figclus, axclus = plt.subplots(1, 2, figsize=(6, 2))

for i_backbone, backbone in enumerate(backbones):
    model_group = MODEL_GROUPS[backbone]
    if exclude_random:
        model_group = [m for m in model_group if not m.startswith("random")]

    # Indexed below as [models, clusterers, datasets].
    result_table = build_results_table(
        model_group,
        CLUSTERERS,
        test_datasets,
        metric_keys=metric_key1,
    )
    # Replace missing results with a -100.0 sentinel; settings containing it are
    # skipped as "incomplete" below. After this line the table holds no NaNs, so
    # the np.isnan checks further down never trigger.
    result_table[np.isnan(result_table)] = -100.0

    print(backbone)
    print(model_group)

    # RANK PER ENCODER - go through each dataset, look at each clusterer,
    # and determine the rank of each encoder in that setting
    print(list(model_group))
    ranks_encoders = np.nan * np.ones((len(model_group), len(CLUSTERERS), len(test_datasets)))
    for i_dataset, dataset in enumerate(test_datasets):
        for i_clusterer, clusterer in enumerate(CLUSTERERS):
            cluster_data = result_table[:, i_clusterer, i_dataset]
            if np.all(cluster_data == cluster_data[0]) or np.all(np.isnan(cluster_data)):
                print(f"Skipping {dataset} {clusterer} (all same)")
                continue
            if np.any(cluster_data == -100.0):
                print(f"Skipping {dataset} {clusterer} (incomplete)")
                continue
            # Indices from highest to lowest score; argsort of that gives each
            # encoder's position, so rank 1 = highest score.
            rank = np.argsort(cluster_data)[::-1]
            ranks_encoders[:, i_clusterer, i_dataset] = 1 + rank.argsort()
    # Skipped settings stay NaN and are ignored by the nan-aware statistics.
    mean_rank_encoders = np.nanmean(ranks_encoders, axis=(1, 2))
    std_rank_encoders = np.nanstd(ranks_encoders, axis=(1, 2))
    # order = np.argsort(mean_rank_encoders)
    order = np.arange(len(model_group))

    for i_plot, i_model in enumerate(order):
        axenc[i_backbone].barh(
            i_plot,
            mean_rank_encoders[i_model],
            xerr=std_rank_encoders[i_model],
            align="center",
            alpha=0.6,
            ecolor="black",
            color=MODEL2COLORRGB.get(model_group[i_model], (0.0, 0.0, 0.0)),
            capsize=2,
            zorder=10,
        )

    axenc[i_backbone].invert_yaxis()
    axenc[i_backbone].set_yticks([])
    axenc[i_backbone].set_yticklabels([])
    axenc[i_backbone].set_xticks(np.arange(1, 1 + len(model_group)))
    axenc[i_backbone].set_xlim([0, 0.5 + len(model_group)])
    axenc[i_backbone].xaxis.grid(True, zorder=1, alpha=0.5)
    axenc[i_backbone].set_title(backbone)

    # RANK PER CLUSTERER - go through each dataset, look at each encoder,
    # and determine the rank of each clusterer in that setting

    print(CLUSTERERS)
    ranks_clusterers = np.nan * np.ones((len(model_group), len(CLUSTERERS), len(test_datasets)))
    for i_dataset, dataset in enumerate(test_datasets):
        for i_encoder, encoder in enumerate(model_group):
            encoder_data = result_table[i_encoder, :, i_dataset]
            if np.all(encoder_data == encoder_data[0]) or np.all(np.isnan(encoder_data)):
                print(f"Skipping {dataset} {encoder} (all same)")
                continue
            if np.any(encoder_data == -100.0):
                print(f"Skipping {dataset} {encoder} (incomplete)")
                continue
            # Same double-argsort ranking as above, across clusterers (1 = best).
            rank = np.argsort(encoder_data)[::-1]
            ranks_clusterers[i_encoder, :, i_dataset] = 1 + rank.argsort()
    mean_rank_clusters = np.nanmean(ranks_clusterers, axis=(0, 2))
    std_rank_clusters = np.nanstd(ranks_clusterers, axis=(0, 2))
    # order = np.argsort(mean_rank_clusters)
    order = np.arange(len(CLUSTERERS))

    for i_plot, i_clusterer in enumerate(order):
        axclus[i_backbone].barh(
            i_plot,
            mean_rank_clusters[i_clusterer],
            xerr=std_rank_clusters[i_clusterer],
            align="center",
            alpha=0.6,
            ecolor="black",
            color=CLUSTERER2COLORSTR.get(CLUSTERERS[i_clusterer], (0.0, 0.0, 0.0)),
            capsize=2,
            zorder=10,
        )

    axclus[i_backbone].invert_yaxis()
    axclus[i_backbone].set_yticks([])
    axclus[i_backbone].set_yticklabels([])
    axclus[i_backbone].set_xticks(np.arange(1, 1 + len(CLUSTERERS)))
    axclus[i_backbone].set_xlim([0, 0.6 + len(CLUSTERERS)])
    axclus[i_backbone].xaxis.grid(True, zorder=1, alpha=0.5)
    axclus[i_backbone].set_title(backbone)

    axclus[i_backbone].set_xlabel("Rank")
    axenc[i_backbone].set_xlabel("Rank")

# Empty line artists used only as legend handles.
label_fn = lambda c, ls: plt.plot([], [], color=c, ls=ls, linewidth=3)[0]  # noqa:E731

# Legend entries: all ResNet-50 models plus the last two ViT-B models
# (presumably the ViT-only encoders; TODO confirm against MODEL_GROUPS).
model_names = list(MODEL_GROUPS["ResNet-50"]) + MODEL_GROUPS["ViT-B"][-2:]
if exclude_random:
    model_names = [m for m in model_names if not m.startswith("random")]
handles_enc = [label_fn(MODEL2COLORRGB[idx], "-") for idx in model_names]
axenc[1].legend(
    handles_enc,
    [MODEL2SH[x] for x in model_names],
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)

# NOTE(review): bars are coloured via CLUSTERER2COLORSTR but the legend uses
# CLUSTERER2COLORRGB; presumably the same colours in another format -- confirm.
handles_clus = [label_fn(CLUSTERER2COLORRGB[clusterer], "-") for clusterer in CLUSTERERS]
axclus[1].legend(handles_clus, CLUSTERERS, loc="center left", bbox_to_anchor=(1, 0.5))

figenc.savefig(os.path.join(FIGS_DIR, "ranking_enc.pdf"), bbox_inches="tight")
figclus.savefig(os.path.join(FIGS_DIR, "ranking_clus.pdf"), bbox_inches="tight")

### With grouped datasets

In [None]:
# Same encoder/clusterer rankings as the previous cell, but computed separately
# for each dataset group in TEST_DATASETS_GROUPED. One pair of figures is
# produced per group. Random-init encoders are NOT filtered out here.
metric_key1 = "AMI"

backbones = ["ResNet-50", "ViT-B"]
CLUSTERERS = [
    "KMeans",
    "SpectralClustering",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]

for test_group, test_datasets in TEST_DATASETS_GROUPED.items():
    figenc, axenc = plt.subplots(1, 2, figsize=(6, 1.6))
    figclus, axclus = plt.subplots(1, 2, figsize=(6, 2))

    for i_backbone, backbone in enumerate(backbones):
        # Indexed below as [models, clusterers, datasets].
        result_table = build_results_table(
            MODEL_GROUPS[backbone],
            CLUSTERERS,
            test_datasets,
            metric_keys=metric_key1,
        )
        # Missing results -> -100.0 sentinel, skipped as "incomplete" below. After
        # this line the np.isnan checks further down can never trigger.
        result_table[np.isnan(result_table)] = -100.0

        print(backbone)
        print(MODEL_GROUPS[backbone])

        # RANK PER ENCODER - go through each dataset, look at each clusterer,
        # and determine the rank of each encoder in that setting
        print(list(MODEL_GROUPS[backbone]))
        ranks_encoders = np.nan * np.ones((len(MODEL_GROUPS[backbone]), len(CLUSTERERS), len(test_datasets)))
        for i_dataset, dataset in enumerate(test_datasets):
            for i_clusterer, clusterer in enumerate(CLUSTERERS):
                cluster_data = result_table[:, i_clusterer, i_dataset]
                if np.all(cluster_data == cluster_data[0]) or np.all(np.isnan(cluster_data)):
                    print(f"Skipping {dataset} {clusterer} (all same)")
                    continue
                if np.any(cluster_data == -100.0):
                    print(f"Skipping {dataset} {clusterer} (incomplete)")
                    continue
                # Double argsort: rank 1 = highest score.
                rank = np.argsort(cluster_data)[::-1]
                ranks_encoders[:, i_clusterer, i_dataset] = 1 + rank.argsort()
        # Skipped settings stay NaN and are ignored by the nan-aware statistics.
        mean_rank_encoders = np.nanmean(ranks_encoders, axis=(1, 2))
        std_rank_encoders = np.nanstd(ranks_encoders, axis=(1, 2))
        # order = np.argsort(mean_rank_encoders)
        order = np.arange(len(MODEL_GROUPS[backbone]))

        for i_plot, i_model in enumerate(order):
            axenc[i_backbone].barh(
                i_plot,
                mean_rank_encoders[i_model],
                xerr=std_rank_encoders[i_model],
                align="center",
                alpha=0.6,
                ecolor="black",
                color=MODEL2COLORRGB.get(MODEL_GROUPS[backbone][i_model], (0.0, 0.0, 0.0)),
                capsize=2,
                zorder=10,
            )

        axenc[i_backbone].invert_yaxis()
        axenc[i_backbone].set_yticks([])
        axenc[i_backbone].set_yticklabels([])
        axenc[i_backbone].set_xticks(np.arange(1, 1 + len(MODEL_GROUPS[backbone])))
        axenc[i_backbone].set_xlim([0, 0.5 + len(MODEL_GROUPS[backbone])])
        axenc[i_backbone].xaxis.grid(True, zorder=1, alpha=0.5)
        axenc[i_backbone].set_title(f"{DATASETGROUP2TITLE.get(test_group, test_group)}, {backbone}")
        axenc[i_backbone].set_xlabel("Rank")

        # RANK PER CLUSTERER - go through each dataset, look at each encoder,
        # and determine the rank of each clusterer in that setting

        print(CLUSTERERS)
        ranks_clusterers = np.nan * np.ones((len(MODEL_GROUPS[backbone]), len(CLUSTERERS), len(test_datasets)))
        for i_dataset, dataset in enumerate(test_datasets):
            for i_encoder, encoder in enumerate(MODEL_GROUPS[backbone]):
                encoder_data = result_table[i_encoder, :, i_dataset]
                if np.all(encoder_data == encoder_data[0]) or np.all(np.isnan(encoder_data)):
                    print(f"Skipping {dataset} {encoder} (all same)")
                    continue
                if np.any(encoder_data == -100.0):
                    print(f"Skipping {dataset} {encoder} (incomplete)")
                    continue
                # Double argsort across clusterers: rank 1 = highest score.
                rank = np.argsort(encoder_data)[::-1]
                ranks_clusterers[i_encoder, :, i_dataset] = 1 + rank.argsort()
        mean_rank_clusters = np.nanmean(ranks_clusterers, axis=(0, 2))
        std_rank_clusters = np.nanstd(ranks_clusterers, axis=(0, 2))
        # order = np.argsort(mean_rank_clusters)
        order = np.arange(len(CLUSTERERS))

        for i_plot, i_clusterer in enumerate(order):
            axclus[i_backbone].barh(
                i_plot,
                mean_rank_clusters[i_clusterer],
                xerr=std_rank_clusters[i_clusterer],
                align="center",
                alpha=0.6,
                ecolor="black",
                color=CLUSTERER2COLORSTR.get(CLUSTERERS[i_clusterer], (0.0, 0.0, 0.0)),
                capsize=2,
                zorder=10,
            )

        axclus[i_backbone].invert_yaxis()
        axclus[i_backbone].set_yticks([])
        axclus[i_backbone].set_yticklabels([])
        axclus[i_backbone].set_xticks(np.arange(1, 1 + len(CLUSTERERS)))
        axclus[i_backbone].set_xlim([0, 0.6 + len(CLUSTERERS)])
        axclus[i_backbone].xaxis.grid(True, zorder=1, alpha=0.5)
        axclus[i_backbone].set_title(f"{DATASETGROUP2TITLE.get(test_group, test_group)}, {backbone}")
        axclus[i_backbone].set_xlabel("Rank")

    # Empty line artists used only as legend handles.
    label_fn = lambda c, ls: plt.plot([], [], color=c, ls=ls, linewidth=3)[0]  # noqa:E731

    # Legend entries: all ResNet-50 models plus the last two ViT-B models
    # (presumably the ViT-only encoders; TODO confirm against MODEL_GROUPS).
    model_names = list(MODEL_GROUPS["ResNet-50"]) + MODEL_GROUPS["ViT-B"][-2:]
    handles_enc = [label_fn(MODEL2COLORRGB[idx], "-") for idx in model_names]
    axenc[1].legend(
        handles_enc,
        [MODEL2SH[x] for x in model_names],
        loc="center left",
        bbox_to_anchor=(1, 0.5),
    )

    # NOTE(review): bars use CLUSTERER2COLORSTR, legend uses CLUSTERER2COLORRGB;
    # presumably the same colours in another format -- confirm.
    handles_clus = [label_fn(CLUSTERER2COLORRGB[clusterer], "-") for clusterer in CLUSTERERS]
    axclus[1].legend(handles_clus, CLUSTERERS, loc="center left", bbox_to_anchor=(1, 0.5))

    figenc.savefig(os.path.join(FIGS_DIR, f"ranking_enc__{test_group}.pdf"), bbox_inches="tight")
    figclus.savefig(os.path.join(FIGS_DIR, f"ranking_clus__{test_group}.pdf"), bbox_inches="tight")

### Encoder rankings and Delta v2

In [None]:
def rank_models_descending(table):
    """Rank the models (rows) of ``table`` within each column.

    Parameters
    ----------
    table : np.ndarray
        Scores shaped [models, samples]; higher is better.

    Returns
    -------
    np.ndarray
        Integer array of the same shape where 1 marks the highest-scoring model
        in each column. Ties are broken by row order (stable sort), so the
        earlier model receives the better rank.
    """
    # argsort of the descending order gives each row's position, i.e. its rank.
    return np.argsort(np.argsort(-table, axis=0, kind="stable"), axis=0) + 1


def plot_ranking_encoders(
    model_group,
    clusterers,
    test_datasets,
    metric_key="AMI",
    ax=None,
    use_rank=True,
    hide_ft=False,
):
    """Draw a horizontal bar chart of each encoder's mean (+/- std) rank or score.

    Scores from ``build_results_table`` are averaged over ``clusterers``. Before
    averaging, any clusterer/dataset combination that is missing for some
    encoder is discarded for all encoders. Scores are then scaled to
    percentages and either ranked per dataset (``use_rank``) or used directly.
    Wilcoxon signed-rank tests comparing the lowest- and highest-scoring
    encoders against all others are printed.

    Parameters
    ----------
    model_group : sequence of str
        Encoder names. Names starting with "random" are dropped when the
        notebook-level ``exclude_random`` flag (read as a global) is truthy.
    clusterers : sequence of str
        Clusterers whose scores are averaged.
    test_datasets : sequence of str
        Datasets; each one is a sample for the mean/std and the tests. Datasets
        with no complete clusterer results are dropped.
    metric_key : str
        Metric passed to ``build_results_table``.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.
    use_rank : bool
        Plot per-dataset ranks (1 = best) instead of raw scores.
    hide_ft : bool
        Strip the " [FT]" suffix from the y tick labels.
    """
    show_error = use_rank

    if ax is None:
        ax = plt.gca()

    if exclude_random:
        model_group = [m for m in model_group if not m.startswith("random")]

    print("Encoders:")
    print(model_group)

    print("Datasets:")
    print(test_datasets)

    result_table = build_results_table(
        model_group,
        clusterers,
        test_datasets,
        metric_keys=metric_key,
    )
    # Shaped [models, clusterers, datasets]

    # Remove clusterer-dataset combos which are NaN for any model
    select = np.any(np.isnan(result_table), axis=0)
    result_table[:, select] = np.nan

    # Drop datasets where every clusterer was removed; otherwise their mean is
    # NaN and would turn every model's mean/std (and the tests) into NaN.
    has_data = ~np.all(select, axis=0)
    dropped = [d for d, keep in zip(test_datasets, has_data) if not keep]
    if dropped:
        print(f"Dropping datasets with no complete clusterer results: {dropped}")
    result_table = result_table[:, :, has_data]

    # Take mean over clusterers
    result_table = np.nanmean(result_table, axis=1)
    # Shaped [models, datasets]

    print(result_table.shape)

    # Scale up to be a percentage
    result_table *= 100.0

    if use_rank:
        # Rank models against each other within every dataset (1 = best).
        result_table_r = rank_models_descending(result_table)
    else:
        result_table_r = result_table

    # Take mean and stdev over samples
    mu = np.mean(result_table_r, axis=-1)
    sd = np.std(result_table_r, axis=-1)

    # Do statistical tests on the raw (percentage) scores. Ordering uses the
    # mean score rather than mu, so "Lowest"/"Highest" refer to the metric even
    # when use_rank=True (where a low mean rank means a high score).
    mu_metric = np.mean(result_table, axis=-1)
    print()
    print("Signed rank tests:")
    print(f"{result_table.shape[1]} samples")
    print()
    jj = np.argsort(mu_metric)
    print("Ordering:")
    for i in jj:
        print(f"  {model_group[i] + ' ':.<32s} {mu_metric[i]}")

    idx_low = jj[0]
    print(f"Lowest {metric_key}: {model_group[idx_low]}")
    for i in jj[1:]:
        wtest = scipy.stats.wilcoxon(result_table[idx_low, :], result_table[i, :], method="exact")
        print(f"  vs {model_group[i] + ' ':.<32s} Wilcoxon pvalue={wtest.pvalue}")

    idx_high = jj[-1]
    print(f"Highest {metric_key}: {model_group[idx_high]}")
    for i in jj[:-1]:
        wtest = scipy.stats.wilcoxon(result_table[idx_high, :], result_table[i, :], method="exact")
        print(f"  vs {model_group[i] + ' ':.<32s} Wilcoxon pvalue={wtest.pvalue}")

    # Bars are drawn in the order of model_group (top to bottom after invert_yaxis).
    order = np.arange(len(model_group))

    for i_plot, i_model in enumerate(order):
        ax.barh(
            i_plot,
            mu[i_model],
            xerr=sd[i_model],
            align="center",
            alpha=0.6,
            ecolor="black",
            color=MODEL2COLORRGB.get(model_group[i_model], (0.0, 0.0, 0.0)),
            capsize=4,
            zorder=10,
            hatch="//" if model_group[i_model] in FT_MODELS else None,
        )
        if show_error:
            # Mark the mean with an open circle on top of the error bar.
            ax.plot(
                mu[i_model],
                i_plot,
                "ok",
                markerfacecolor="none",
                zorder=11,
            )

    labels = [MODEL2SH.get(c, c) for c in model_group]
    if hide_ft:
        labels = [m.replace(" [FT]", "") for m in labels]
    ax.invert_yaxis()
    ax.set_yticks(np.arange(len(model_group)))
    ax.set_yticklabels(labels)
    if use_rank:
        ax.set_xticks(np.arange(1, 1 + len(model_group)))
        ax.set_xlim([0, 0.5 + len(model_group)])
        ax.xaxis.grid(True, zorder=1, alpha=0.5)

    if use_rank:
        ax.set_xlabel("Rank", fontsize=12)
    else:
        ax.set_xlabel(metric_key, fontsize=12)

In [None]:
def plot_encoders_vs_sup(
    model_group,
    clusterers,
    test_datasets,
    metric_key="AMI",
    ax=None,
    use_rank=False,
    hide_ft=False,
):
    """Plot each encoder's mean difference in ``metric_key`` from the first encoder.

    Scores from ``build_results_table`` are averaged over ``clusterers``. Before
    averaging, any clusterer/dataset combination missing for some encoder is
    discarded for all encoders. Scores are scaled to percentages, and the first
    entry of ``model_group`` is subtracted per dataset. Paired t-tests of each
    encoder against that reference are printed. Bars are green for a positive
    mean difference and red for a negative one.

    Parameters
    ----------
    model_group : sequence of str
        Encoder names. The first one is the reference. Names starting with
        "random" are dropped when the notebook-level ``exclude_random`` flag
        (read as a global) is truthy.
    clusterers : sequence of str
        Clusterers whose scores are averaged.
    test_datasets : sequence of str
        Datasets; each one is a sample for the mean/std and the t-tests.
        Datasets with no complete clusterer results are dropped.
    metric_key : str
        Metric passed to ``build_results_table``.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.
    use_rank : bool
        Plot per-dataset ranks (1 = best) instead of differences.
    hide_ft : bool
        Strip the " [FT]" suffix from the y tick labels.
    """
    if ax is None:
        ax = plt.gca()

    if exclude_random:
        model_group = [m for m in model_group if not m.startswith("random")]

    print("Encoders:")
    print(model_group)

    print("Datasets:")
    print(test_datasets)

    result_table = build_results_table(
        model_group,
        clusterers,
        test_datasets,
        metric_keys=metric_key,
    )
    # Shaped [models, clusterers, datasets]

    # Remove clusterer-dataset combos which are NaN for any model
    for ix, iy, iz in zip(*np.where(np.isnan(result_table))):
        print(f"    Missing value for {model_group[ix]}  {clusterers[iy]}  {test_datasets[iz]}")
    select = np.any(np.isnan(result_table), axis=0)
    for iy, iz in zip(*np.where(select)):
        print(f"    Removing datapoints for all models for  {clusterers[iy]}  {test_datasets[iz]}")
    result_table[:, select] = np.nan

    # Drop datasets where every clusterer was removed; otherwise their mean is
    # NaN and would turn every model's mean/std (and the t-tests) into NaN.
    has_data = ~np.all(select, axis=0)
    dropped = [d for d, keep in zip(test_datasets, has_data) if not keep]
    if dropped:
        print(f"    Dropping datasets with no complete clusterer results: {dropped}")
    result_table = result_table[:, :, has_data]

    # Take mean over clusterers
    result_table = np.nanmean(result_table, axis=1)
    # Shaped [models, datasets]

    print(result_table.shape)

    # Scale up to be a percentage
    result_table *= 100.0

    # Compare to baseline model
    print(f"Comparing vs {model_group[0]}")
    result_table = result_table - result_table[[0], :]

    if use_rank:
        # Rank models against each other within every dataset (1 = best).
        result_table_r = np.argsort(np.argsort(-result_table, axis=0, kind="stable"), axis=0) + 1
    else:
        result_table_r = result_table

    # Take mean and stdev over samples
    mu = np.mean(result_table_r, axis=-1)
    sd = np.std(result_table_r, axis=-1)

    # Row 0 is all zeros after the subtraction, so each paired t-test checks
    # whether model i's difference from the reference is non-zero.
    print()
    print("Paired t-tests:")
    print(f"{result_table.shape[1]} samples")
    print()
    for i in range(1, result_table.shape[0]):
        ttest = scipy.stats.ttest_rel(result_table[0, :], result_table[i, :])
        print(f"  {model_group[0]:<10s} vs {model_group[i] + ' ':.<32s} {mu[i]:+6.2f}  Paired t-test pvalue={ttest.pvalue}")

    # The reference model (index 0) is not drawn.
    order = np.arange(1, len(model_group))
    for i_model in order:
        ax.barh(
            i_model,
            mu[i_model],
            xerr=sd[i_model],
            align="center",
            alpha=0.6,
            ecolor="black",
            color="tab:red" if mu[i_model] < 0 else "tab:green",
            capsize=4,
            zorder=10,
        )
        ax.plot(
            mu[i_model],
            i_model,
            "ok",
            markerfacecolor="none",
            zorder=11,
        )

    labels = [MODEL2SH.get(model_group[a], model_group[a]) for a in order]
    if hide_ft:
        labels = [m.replace(" [FT]", "") for m in labels]
    ax.invert_yaxis()
    ax.set_yticks(order)
    ax.set_yticklabels(labels)
    if use_rank:
        ax.set_xticks(np.arange(1, 1 + len(model_group)))
        ax.set_xlim([0, 0.5 + len(model_group)])
        ax.xaxis.grid(True, zorder=1, alpha=0.5)
    else:
        # Vertical grid lines only. Axes.grid's first positional argument is
        # `visible`, so the axis must be given by keyword.
        ax.grid(True, axis="x")
        # Solid reference line at zero difference, without changing the y range.
        ymin, ymax = ax.get_ylim()
        ax.vlines([0], ymin, ymax, "k")
        ax.set_ylim([ymin, ymax])

    if use_rank:
        ax.set_xlabel("Rank", fontsize=12)
    else:
        ax.set_xlabel(f"Δ{metric_key} (p.p.)", fontsize=12)

In [None]:
# Single delta plot: mean difference of each encoder from the first one in
# model_group, averaged over the listed clusterers and datasets.
exclude_random = True
metric_key = "AMI"
clusterers = [
    "KMeans",
    "SpectralClustering",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]
test_datasets = TEST_DATASETS
# Alternative encoder selections:
# model_group = RESNET50_MODELS + VITB16_MODELS
# model_group = RESNET50_MODELS + FT_RESNET50_MODELS + VITB16_MODELS + FT_VITB16_MODELS
# model_group = RESNET50_MODELS + FT_RESNET50_MODELS
# model_group = VITB16_MODELS + FT_VITB16_MODELS

model_group = RESNET50_MODELS
# model_group = RESNET50_MODELS_INTERLEAVED
# model_group = VITB16_MODELS_INTERLEAVED

hf = plt.figure(figsize=(6, 4))

# Pass metric_key explicitly so that changing it above actually takes effect.
plot_encoders_vs_sup(model_group, clusterers, test_datasets, metric_key=metric_key)

In [None]:
# Vertical orientation of ranking or delta subplots
# Subplot grid: one row per dataset group, one column per backbone.

# Choose which per-axes plotting function to use (and the name used for it).
# plotting_fn, plotname = plot_ranking_encoders, "rank"
plotting_fn, plotname = plot_encoders_vs_sup, "delta"

exclude_random = True
metric_key = "AMI"
clusterers = [
    "KMeans",
    "SpectralClustering",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]
# clusterers = ["AC w/o C"]
test_datasets_grouped = TEST_DATASETS_GROUPED

# Default selection: non-"random" encoders of each backbone.
model_groups = {
    "ResNet-50": [m for m in RESNET50_MODELS if not m.startswith("random")],
    "ViT-B": [m for m in VITB16_MODELS if not m.startswith("random")],
}
fname_suffix = ""
hide_ft = False
xtickvalues = [-40, -20, 0, 20] if plotname == "delta" else None

# Toggle: interleaved lists (each encoder followed by its fine-tuned variant).
if False:
    model_groups = {
        "ResNet-50": [m for m in RESNET50_MODELS_INTERLEAVED if not m.startswith("random")],
        "ViT-B": [m for m in VITB16_MODELS_INTERLEAVED if not m.startswith("random")],
    }
    fname_suffix = "_FT"
    hide_ft = False

# Toggle: only the fine-tuned encoders, listed after "resnet50" / "vitb16"
# (the first model is the reference for the delta plot).
if False:
    model_groups = {
        "ResNet-50": ["resnet50"] + FT_RESNET50_MODELS,
        "ViT-B": ["vitb16"] + FT_VITB16_MODELS,
    }
    fname_suffix = "_FTonly"
    hide_ft = True
    xtickvalues = None


nmodelgroups = len(model_groups)
ndatagroups = len(test_datasets_grouped)

# Delta plots share one x scale across subplots; rank plots do not.
sharex = plotname == "delta"
fig, axs = plt.subplots(
    ndatagroups,
    nmodelgroups,
    figsize=(nmodelgroups * 3, ndatagroups * 2),
    sharex=sharex,
)
plt.subplots_adjust(wspace=0.5)

for i_backbone, backbone in enumerate(model_groups):
    for i_domain, domain in enumerate(test_datasets_grouped):
        print()
        print(f"{backbone}  {domain}")
        print()
        ax = axs[i_domain, i_backbone]
        plotting_fn(
            model_groups[backbone],
            clusterers,
            test_datasets_grouped[domain],
            metric_key=metric_key,
            ax=ax,
            hide_ft=hide_ft,
        )
        if i_backbone == 0:
            ax.set_ylabel(domain)
        if i_domain < ndatagroups - 1:
            if not sharex:
                ax.set_xticklabels([])
            ax.set_xlabel("")

if xtickvalues is not None:
    print("Overriding xtick values")
    ax.set_xticks(xtickvalues)

if len(clusterers) == 1:
    clusterer_suffix = clusterers[0].replace(" ", "").replace("/", "")
else:
    clusterer_suffix = f"{len(clusterers)}c-avg"
fname = f"enc_{plotname}_{metric_key}_bydomain{fname_suffix}_{clusterer_suffix}.pdf"
print(f"Saving to {fname}")
fig.savefig(os.path.join(FIGS_DIR, fname), bbox_inches="tight")

In [None]:
# Vertical orientation of subplots: one row per dataset domain, one column per
# encoder backbone, showing the delta vs. the supervised encoder.
exclude_random = True
metric_key = "AMI"
clusterers = [
    "KMeans",
    "SpectralClustering",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]
# clusterers = ["AC w/o C"]
test_datasets_grouped = TEST_DATASETS_GROUPED


model_groups = {
    "ResNet-50": [m for m in RESNET50_MODELS if not m.startswith("random")],
    "ViT-B": [m for m in VITB16_MODELS if not m.startswith("random")],
}
fname_suffix = ""
hide_ft = False

if False:
    model_groups = {
        "ResNet-50": [m for m in RESNET50_MODELS_INTERLEAVED if not m.startswith("random")],
        "ViT-B": [m for m in VITB16_MODELS_INTERLEAVED if not m.startswith("random")],
    }
    fname_suffix = "_FT"
    hide_ft = False

if False:
    model_groups = {
        "ResNet-50": ["resnet50"] + FT_RESNET50_MODELS,
        "ViT-B": ["vitb16"] + FT_VITB16_MODELS,
    }
    fname_suffix = "_FTonly"
    hide_ft = True


nmodelgroups = len(model_groups)
ndatagroups = len(test_datasets_grouped)

# squeeze=False keeps axs 2-D even with a single domain or a single backbone.
fig, axs = plt.subplots(
    ndatagroups,
    nmodelgroups,
    figsize=(nmodelgroups * 3, ndatagroups * 2),
    sharex=True,
    squeeze=False,
)
plt.subplots_adjust(wspace=0.5)

for i_backbone, backbone in enumerate(model_groups):
    for i_domain, domain in enumerate(test_datasets_grouped):
        print()
        print(f"{backbone}  {domain}")
        print()
        ax = axs[i_domain, i_backbone]
        plot_encoders_vs_sup(
            model_groups[backbone],
            clusterers,
            test_datasets_grouped[domain],
            # Pass the metric explicitly so metric_key above takes effect.
            metric_key=metric_key,
            ax=ax,
            hide_ft=hide_ft,
        )
        if i_backbone == 0:
            ax.set_ylabel(domain, fontsize=12)
        if i_domain < ndatagroups - 1:
            # ax.set_xticklabels([])
            ax.set_xlabel("")
        if i_domain == 0:
            ax.set_title(backbone)

# Derive the filename from metric_key (identical to the previous hard-coded
# "enc_delta_AMI_bydomain..." name when metric_key == "AMI").
fig.savefig(
    os.path.join(FIGS_DIR, f"enc_delta_{metric_key}_bydomain{fname_suffix}.pdf"),
    bbox_inches="tight",
)

In [None]:
# Horizontal orientation of ranking or delta subplots:
# one column per dataset domain, one row per encoder backbone.

# plotting_fn, plotname = plot_ranking_encoders, "rank"
plotting_fn, plotname = plot_encoders_vs_sup, "delta"

exclude_random = True
metric_key = "AMI"
# clusterers = ["KMeans", "AC w/ C", "AC w/o C", "AffinityPropagation"]
clusterers = [
    "KMeans",
    "SpectralClustering",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]
# clusterers = ["AC w/o C"]
test_datasets_grouped = TEST_DATASETS_GROUPED

model_groups = {
    "ResNet-50": [m for m in RESNET50_MODELS if not m.startswith("random")],
    "ViT-B": [m for m in VITB16_MODELS if not m.startswith("random")],
}
fname_suffix = ""
hide_ft = False
xtickvalues = [-40, -20, 0, 20] if plotname == "delta" else None

if False:
    model_groups = {
        "ResNet-50": [m for m in RESNET50_MODELS_INTERLEAVED if not m.startswith("random")],
        "ViT-B": [m for m in VITB16_MODELS_INTERLEAVED if not m.startswith("random")],
    }
    fname_suffix = "_FT"
    hide_ft = False

if True:
    model_groups = {
        "ResNet-50": ["resnet50"] + FT_RESNET50_MODELS,
        "ViT-B": ["vitb16"] + FT_VITB16_MODELS,
    }
    fname_suffix = "_FTonly"
    hide_ft = True
    xtickvalues = None


nmodelgroups = len(model_groups)
ndatagroups = len(test_datasets_grouped)

sharex = plotname == "delta"
# squeeze=False keeps axs 2-D, so the transpose and [i_domain, i_backbone]
# indexing below work even with a single backbone or a single domain.
fig, axs = plt.subplots(
    nmodelgroups,
    ndatagroups,
    figsize=(ndatagroups * 2, nmodelgroups * 1.75),
    sharex=sharex,
    squeeze=False,
)
# Transpose so axs is indexed as [domain, backbone].
axs = axs.T

for i_domain, domain in enumerate(test_datasets_grouped):
    for i_backbone, backbone in enumerate(model_groups):
        print()
        print(f"{backbone}  {domain}")
        print()
        ax = axs[i_domain, i_backbone]
        plotting_fn(
            model_groups[backbone],
            clusterers,
            test_datasets_grouped[domain],
            metric_key=metric_key,
            ax=ax,
            hide_ft=hide_ft,
        )
        if i_domain > 0:
            ax.set_yticklabels([])
        if i_backbone == 0:
            ax.set_title(domain)
        if i_backbone < nmodelgroups - 1:
            if not sharex:
                ax.set_xticklabels([])
            ax.set_xlabel("")
        if i_domain == 0:
            ax.set_ylabel(backbone, fontsize=12)

if xtickvalues is not None:
    # With sharex=True this applies to every subplot.
    print("Overriding xtick values")
    ax.set_xticks(xtickvalues)

if len(clusterers) == 1:
    clusterer_suffix = clusterers[0].replace(" ", "").replace("/", "")
else:
    clusterer_suffix = f"{len(clusterers)}c-avg"
fname = f"horiz_enc_{plotname}_{metric_key}_bydomain{fname_suffix}_{clusterer_suffix}.pdf"
print(f"Saving to {fname}")
fig.savefig(os.path.join(FIGS_DIR, fname), bbox_inches="tight")

### Compare clusterers

In [None]:
# Compare clusterers: pool (encoder, dataset) samples and summarise each
# clusterer's metric_key, either as a mean rank (use_rank=True, rank 1 = best)
# or as the mean metric value in percentage points.
use_rank = True
exclude_random = True
metric_key = "AMI"
clusterers = [
    "KMeans",
    "SpectralClustering",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]

test_datasets = [d for d in TEST_DATASETS if not d.startswith("in9onlybg")]
model_group = RESNET50_MODELS + VITB16_MODELS
# model_group = RESNET50_MODELS + FT_RESNET50_MODELS + VITB16_MODELS + FT_VITB16_MODELS


show_error = use_rank

hf = plt.figure(figsize=(4, 3))

if exclude_random:
    model_group = [m for m in model_group if not m.startswith("random")]

print("Encoders:")
print(model_group)

print("Datasets:")
print(test_datasets)

result_table = build_results_table(
    model_group,
    clusterers,
    test_datasets,
    metric_keys=metric_key,
)
# Shaped [models, clusterers, datasets]


# Report every missing (model, clusterer, dataset) value. Any (model, dataset)
# sample with a missing value for some clusterer is removed below.
for ix, iy, iz in zip(*np.where(np.isnan(result_table))):
    print(f"    Missing value for {model_group[ix]}  {clusterers[iy]}  {test_datasets[iz]}")

result_table = np.swapaxes(result_table, 1, 2)
# Shaped [models, datasets, clusterers]

result_table = np.reshape(result_table, [-1, len(clusterers)])
# Shaped [models x datasets, clusterers]
print(result_table.shape)

# Remove samples which are NaN for ANY clusterer, so that every clusterer is
# compared on exactly the same set of samples.
select = ~np.any(np.isnan(result_table), axis=-1)
result_table = result_table[select]

# Scale up to be a percentage
result_table *= 100.0

print(result_table.shape)
if use_rank:
    # Rank the clusterers within each sample, 1 = highest metric value.
    # Tied values share the average of their ranks, so the clusterer listed
    # first does not win ties just because of its position in the array.
    result_table_r = scipy.stats.rankdata(-result_table, method="average", axis=1)
    print(result_table_r.shape)
else:
    result_table_r = result_table

# Take mean and stdev over samples
mu = np.mean(result_table_r, axis=0)
sd = np.std(result_table_r, axis=0)

# Name of the quantity held in `mu`. With ranks the LOWEST mean rank is the
# BEST clusterer, so it must not be reported as the lowest metric value.
score_name = f"mean {metric_key} rank (1 = best)" if use_rank else metric_key

# Do statistical tests (paired Wilcoxon on the raw metric values, not ranks)
print()
print(f"{result_table.shape[0]} samples")
print()
jj = np.argsort(mu)
print("Ordering:")
for i in jj:
    print(f"  {clusterers[i]:<20s} = {mu[i]}")

idx_low = jj[0]
print(f"Lowest {score_name}: {clusterers[idx_low]}")
for i in jj[1:]:
    wtest = scipy.stats.wilcoxon(result_table[:, idx_low], result_table[:, i], method="exact")
    print(f"  vs {clusterers[i]:<20s}  pvalue={wtest.pvalue}")

idx_high = jj[-1]
print(f"Highest {score_name}: {clusterers[idx_high]}")
for i in jj[:-1]:
    wtest = scipy.stats.wilcoxon(result_table[:, idx_high], result_table[:, i], method="exact")
    print(f"  vs {clusterers[i]:<20s}  pvalue={wtest.pvalue}")

idx1 = clusterers.index("AC w/o C")
idx2 = clusterers.index("AffinityPropagation")
wtest = scipy.stats.wilcoxon(result_table[:, idx1], result_table[:, idx2], method="exact")
print(f"{metric_key} for {clusterers[idx1]} vs {clusterers[idx2]:<20s}  pvalue={wtest.pvalue}")

# order = np.argsort(mean_rank_clusters)
order = np.arange(len(clusterers))

for i_plot, i_clusterer in enumerate(order):
    plt.barh(
        i_plot,
        mu[i_clusterer],
        xerr=sd[i_clusterer] if show_error else None,
        align="center",
        alpha=0.6,
        ecolor="black",
        color=CLUSTERER2COLORSTR.get(clusterers[i_clusterer], (0.0, 0.0, 0.0)),
        capsize=4,
        zorder=10,
    )
    if show_error:
        plt.plot(
            mu[i_clusterer],
            i_plot,
            "ok",
            markerfacecolor="none",
        )

labels = [CLUSTERER2SH.get(c, c) for c in clusterers]
ax = plt.gca()
ax.tick_params(axis="x", labelsize=12)
ax.tick_params(axis="y", labelsize=12)
ax.invert_yaxis()
ax.set_yticks(np.arange(len(clusterers)))
ax.set_yticklabels(labels)
if use_rank:
    ax.set_xticks(np.arange(1, 1 + len(clusterers)))
    ax.set_xlim([0, 0.75 + len(clusterers)])
else:
    # Pad the data range by 7.5% on each side.
    XLIM = np.array([np.min(mu), np.max(mu)])
    XLIM = XLIM + 0.075 * np.array([-1, 1]) * (XLIM[1] - XLIM[0])
    ax.set_xlim(XLIM)
ax.xaxis.grid(True, zorder=1, alpha=0.5)

if use_rank:
    ax.set_xlabel("Rank", fontsize=15)
else:
    ax.set_xlabel(metric_key, fontsize=15)

if False:
    # Show legend
    label_fn = lambda c, ls: plt.plot([], [], color=c, ls=ls, linewidth=3)[0]  # noqa:E731
    handles_clus = [label_fn(CLUSTERER2COLORRGB[clusterer], "-") for clusterer in clusterers]
    ax.legend(handles_clus, labels, loc="center left", bbox_to_anchor=(1, 0.5))

prfx = "ranking" if use_rank else metric_key
hf.savefig(os.path.join(FIGS_DIR, f"{prfx}_clus_overall.pdf"), bbox_inches="tight")

## Plot sample images

In [None]:
import re

from zs_ssl_clustering.io import sanitize_filename


def get_pred_path(row):
    """
    Build the relative path of the saved cluster-prediction (``y_pred``) file for a run.

    Layout::

        <predictions_dir>/<partition>__z<zoom_ratio>/<partition>-<dataset_name>__<model>__<run_id>.npz

    Both the subdirectory name and the file name are passed through
    ``sanitize_filename``. The run ID is the last ``"__"``-separated component
    of ``row["name"]``.
    """
    *_, run_id = row["name"].split("__")
    partition = row["partition"]
    base_name = sanitize_filename(f"{partition}-{row['dataset_name']}__{row['model']}__{run_id}.npz")
    subdir = sanitize_filename(partition + f"__z{float(row['zoom_ratio'])}")
    return os.path.join(row["predictions_dir"], subdir, base_name)

In [None]:
import torch
from torchvision.transforms.functional import crop, get_dimensions
from torchvision.utils import _log_api_usage_once


def center_squaring(img):
    """Crop the given image to the largest square centred in it.

    The side length of the square is ``min(height, width)``, so the image is
    only ever cropped, never padded. An image that is already square is
    returned unchanged. If the image is a torch Tensor, it is expected to have
    ``[..., H, W]`` shape, where ``...`` means an arbitrary number of leading
    dimensions.

    Args:
        img (PIL Image or Tensor): Image to be cropped.

    Returns:
        PIL Image or Tensor: Square, centre-cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(center_squaring)

    _, image_height, image_width = get_dimensions(img)

    if image_height == image_width:
        return img

    crop_height = crop_width = min(image_height, image_width)

    # Offsets that centre the square crop along the longer dimension.
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop(img, crop_top, crop_left, crop_height, crop_width)


class CenterSquaring(torch.nn.Module):
    """Transform module that crops an image to its central square.

    Thin ``nn.Module`` wrapper around ``center_squaring`` so it can be passed
    as a dataset transform (e.g. ``transform_eval=CenterSquaring()``).
    """

    def __init__(self):
        super().__init__()
        _log_api_usage_once(self)

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Image cropped to a centred square.
        """
        squared = center_squaring(img)
        return squared

    def __repr__(self) -> str:
        return type(self).__name__

In [None]:
from zs_ssl_clustering import datasets


def show_samples(row, nsamp=12, ds=None, save=False, clusterer="", nclusters=None, skip_existing=None):
    """Plot a grid of example images from each predicted cluster of a run.

    Each grid row is one predicted cluster and shows up to ``nsamp`` of its
    members, chosen by a shuffle seeded with the cluster label (so the
    selection is reproducible). Panels beyond a cluster's size are left blank.

    Args:
        row: Run record with the fields used by ``get_pred_path`` plus
            ``dataset_name``, ``model``, ``clusterer_name`` and ``partition``.
        nsamp: Number of images shown per cluster.
        ds: Dataset to draw images from. If ``None``, it is fetched for
            ``row["dataset_name"]`` with center-square cropping, taking the
            first split for the "train" partition and the last for "test".
        save: Whether to save the figure under ``../samples``.
        clusterer: Clusterer name used in the output filename (with "/" and
            spaces removed); defaults to ``row["clusterer_name"]``.
        nclusters: Maximum number of clusters (grid rows) to show; default all.
        skip_existing: If true and the output file already exists, return
            without plotting. Defaults to the value of ``save``.

    Returns:
        The matplotlib Figure, or ``None`` if skipped because the output exists.

    Raises:
        NotImplementedError: If ``ds`` is None and ``row["partition"]`` is
            neither "train" nor "test".
    """
    if skip_existing is None:
        skip_existing = save

    if clusterer:
        clusterer = clusterer.replace("/", "").replace(" ", "")
    else:
        clusterer = row["clusterer_name"]

    output_dir = "../samples"
    output_fname = f"samples__{row['dataset_name']}__{row['model']}__{clusterer}.png"
    output_fname = os.path.join(output_dir, output_fname)
    if skip_existing and os.path.exists(output_fname):
        print(f"Output {output_fname} already exists. Skipping.")
        return None

    if ds is None:
        dses = datasets.fetch_image_dataset(row["dataset_name"], transform_eval=CenterSquaring())
        if row["partition"] == "train":
            ds = dses[0]
        elif row["partition"] == "test":
            ds = dses[-1]
        else:
            raise NotImplementedError()

    # The context manager closes the .npz archive once the labels are read.
    with np.load("../" + get_pred_path(row)) as pred_data:
        y_pred = pred_data["y_pred"]

    u_labels = np.unique(y_pred)

    if nclusters is None:
        nclusters = len(u_labels)
    else:
        nclusters = min(nclusters, len(u_labels))

    # squeeze=False keeps axs 2-D even for a single cluster or a single sample,
    # so axs[i_label, i] indexing always works.
    fig, axs = plt.subplots(nclusters, nsamp, figsize=(nsamp / 2, nclusters / 2), squeeze=False)

    for i_label, label in enumerate(u_labels[:nclusters]):
        indices = np.where(y_pred == label)[0]
        # default_rng needs a non-negative seed; wrap negative labels (e.g. a
        # -1 noise label). Non-negative labels keep exactly the same seed.
        np.random.default_rng(seed=int(label) % (2**63)).shuffle(indices)
        for i in range(nsamp):
            if i < len(indices):
                axs[i_label, i].imshow(ds[indices[i]][0].convert("RGB"))
            axs[i_label, i].axis("off")

    if save:
        print(f"Saving to {output_fname}")
        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(output_fname, bbox_inches="tight")

    return fig


def fetch_row(dataset, model, clusterer):
    """Look up the run for ``model`` + ``clusterer`` on ``dataset`` in ``test_runs_df``.

    The search filter is ``DEFAULT_PARAMS["all"]`` overridden by the best
    hyperparameters ``BEST_PARAMS[clusterer][model]``, and requires saved
    predictions (``predictions_dir="y_pred"``). For HDBSCAN on celeba/utkface
    ``min_samples=2`` is also required.

    Returns:
        The first matching row, or ``None`` if there is no match (the command
        to generate it is printed) or if several matches disagree on AMI (the
        differing columns are printed). Several matches with identical AMI are
        treated as duplicates, and the first one is returned.
    """
    override_fields = {
        "predictions_dir": "y_pred",
    }
    if clusterer == "HDBSCAN" and dataset in ["celeba", "utkface"]:
        override_fields["min_samples"] = 2
    filter1 = {"model": model, "dataset": dataset}
    filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
    filter2.update(filter1)
    filter2.update(override_fields)
    filter2 = fixup_filter(filter2)
    sdf = select_rows(test_runs_df, filter2, allow_missing=False)
    if len(sdf) < 1:
        print(f"No data for {filter2}")
        print(filter2command(filter2, partition="test"))
        return None
    if len(sdf) > 1:
        perf = sdf.iloc[0]["AMI"]
        if (sdf["AMI"] != perf).any():
            # Genuinely ambiguous: report the differences and return nothing.
            print()
            print("More than one result with AMIs:", list(sdf["AMI"]))
            print(f"for search {filter2}")
            dif_cols = find_differing_columns(sdf, config_keys)
            print(f"columns which differ: {dif_cols}")
            if dif_cols:
                for col in dif_cols:
                    print(f"  {col}: {list(sdf[col])}")
            return None
        # All matches agree on AMI (duplicate runs): fall through and use the first.
    return sdf.iloc[0]

In [None]:
dataset = "svhn"
model = "mocov3_resnet50"
clusterer = "AC w/ C"

# Same lookup as fetch_row, but shows the matching rows for inspection.
override_fields = {
    "predictions_dir": "y_pred",
    # "aggclust_dist_thresh": None,  # Use this to flip between unknown/known num clusters for Agglom
}
filter1 = {"model": model, "dataset": dataset}
# Later dicts take precedence: defaults < best hparams < model/dataset < overrides.
filter2 = {**DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model], **filter1, **override_fields}
filter2 = fixup_filter(filter2)
sdf = select_rows(test_runs_df, filter2, allow_missing=False)
n_matches = len(sdf)
if n_matches == 0:
    print(f"No data for {filter2}")
    print(filter2command(filter2, partition="test"))
elif n_matches > 1:
    perf = sdf.iloc[0]["AMI"]
    if (sdf["AMI"] != perf).any():
        print()
        print("More than one result with AMIs:", list(sdf["AMI"]))
        print(f"for search {filter2}")
        dif_cols = find_differing_columns(sdf, config_keys)
        print(f"columns which differ: {dif_cols}")
        for col in dif_cols or []:
            print(f"  {col}: {list(sdf[col])}")
else:
    display(sdf)
    row = sdf.iloc[0]
    run_id = row["name"].split("__")[-1]
    print(
        run_id,
        "\n" + row["name"],
        "\n  " + row["dataset_name"],
        "\n  " + row["model"],
        "\n  " + row["clusterer_name"],
        f"\n  AMI={row['AMI']}",
        f"\n  S_reduced={row['silhouette-euclidean_pred']}",
        f"\n  S_originl={row['silhouette-og-euclidean_pred']}",
    )

In [None]:
# Load the predicted cluster labels for `row`; the context manager closes the
# .npz archive once the array has been read.
with np.load("../" + get_pred_path(row)) as pred_archive:
    y_pred = pred_archive["y_pred"]

In [None]:
# Number of predicted cluster labels loaded for this run
len(y_pred)

In [None]:
# Fetch the dataset splits for this run and keep the last one (the split that
# show_samples uses for the "test" partition); no CenterSquaring transform here.
ds = datasets.fetch_image_dataset(row["dataset_name"])[-1]

In [None]:
# Indices of samples assigned to cluster 0, shuffled in place with a fixed seed
indices = np.where(y_pred == 0)[0]
np.random.default_rng(seed=0).shuffle(indices)

In [None]:
# First 10 shuffled member indices of cluster 0
indices[:10]

In [None]:
# Distinct predicted cluster labels
np.unique(y_pred)

In [None]:
# Show up to `nsamp` randomly chosen images from one predicted cluster.
label = 3
nsamp = 10

# Members of the chosen cluster, shuffled with a label-seeded RNG for reproducibility
indices = np.where(y_pred == label)[0]
np.random.default_rng(seed=label).shuffle(indices)

# squeeze=False so axs is always 2-D (even if nsamp == 1).
fig, axs = plt.subplots(1, nsamp, figsize=(6, 2), squeeze=False)

# Iterate over all nsamp panels (not a hard-coded 10); leave a panel blank when
# the cluster has fewer members than nsamp instead of raising IndexError.
for i, ax in enumerate(axs[0]):
    if i < len(indices):
        ax.imshow(ds[indices[i]][0])
    ax.axis("off")

In [None]:
# Grid of sample images: one row per predicted cluster, `nsamp` images per row.
nsamp = 10

u_labels = np.unique(y_pred)

# figsize is (width, height): width scales with the number of columns (nsamp)
# and height with the number of rows (clusters). squeeze=False keeps axs 2-D
# even when there is only one cluster.
fig, axs = plt.subplots(len(u_labels), nsamp, figsize=(nsamp / 2, len(u_labels) / 2), squeeze=False)

for i_label, label in enumerate(u_labels):
    indices = np.where(y_pred == label)[0]
    # default_rng needs a non-negative seed; wrap negative labels (e.g. a -1
    # noise label) while keeping the seed unchanged for non-negative labels.
    np.random.default_rng(seed=int(label) % (2**63)).shuffle(indices)
    for i in range(nsamp):
        # Leave panels blank for clusters smaller than nsamp.
        if i < len(indices):
            axs[i_label, i].imshow(ds[indices[i]][0])
        axs[i_label, i].axis("off")

# plt.savefig(f"{row['dataset_name']}_{row['model']}_{row['clusterer_name']}.png", bbox_inches='tight')
plt.show()

In [None]:
# flowers102 last split without any eval transform, to compare against CenterSquaring
ds = datasets.fetch_image_dataset("flowers102")[-1]

In [None]:
# Image of sample 1 (untransformed)
ds[1][0]

In [None]:
# Same split, with the center-square crop applied as the eval transform
ds = datasets.fetch_image_dataset("flowers102", transform_eval=CenterSquaring())[-1]

In [None]:
# Image of sample 1 after center-square cropping
ds[1][0]

In [None]:
# Fetch a single run and save its cluster sample grid.
dataset = "svhn"
model = "mocov3_resnet50"
clusterer = "AC w/ C"

row = fetch_row(dataset, model, clusterer)
# fetch_row returns None when no unique run is found; report that instead of
# failing with a TypeError on row["name"].
if row is None:
    print("No data with y_pred for", dataset, model, clusterer)
else:
    print(
        row["name"].split("__")[-1],
        "\n" + row["name"],
        "\n  " + row["dataset_name"],
        "\n  " + row["model"],
        "\n  " + row["clusterer_name"],
        f"\n  AMI        = {row['AMI']}",
        f"\n  S_reduced  = {row['silhouette-euclidean_pred']}",
        f"\n  S_original = {row['silhouette-og-euclidean_pred']}",
    )
    show_samples(row, save=True, clusterer=clusterer)

In [None]:
# Save cluster sample grids (up to 150 clusters) for selected SSL encoders.
# NOTE: the trailing `break`s make this a quick check. The clusterer and model
# loops always break after their first iteration (even if no dataset had a
# run), and the dataset loop stops at the first dataset for which fetch_row
# returns a row. Remove the breaks to process everything.
for clusterer in ["AC w/ C"]:  # , "AC w/o C"]:
    for model in ["mocov3_resnet50", "mocov3_vit_base", "dino_resnet50", "dino_vitb16"]:
        for dataset in [
            "mnist",
            "fashionmnist",
            "svhn",
            "cifar10",
            "cifar100",
            "flowers102",
            "aircraft",
        ]:
            print()
            print(f"{dataset:<16s} {model:<32s} {clusterer}")
            row = fetch_row(dataset, model, clusterer)
            if row is None:
                # No (unique) run found: try the next dataset
                print("No data with y_pred for", dataset, model, clusterer)
                continue
            print(
                row["name"].split("__")[-1],
                "\n" + row["name"],
                "\n  " + row["dataset_name"],
                "\n  " + row["model"],
                "\n  " + row["clusterer_name"],
                f"\n  AMI        = {row['AMI']}",
                f"\n  S_reduced  = {row['silhouette-euclidean_pred']}",
                f"\n  S_original = {row['silhouette-og-euclidean_pred']}",
            )
            fig = show_samples(row, save=True, clusterer=clusterer, nclusters=150)
            # plt.show()
            print("\n\nStopping early!")
            break
        break
    break

In [None]:
# Save cluster sample grids (up to 150 clusters) with "AC w/o C" for a list of datasets.
# NOTE: the trailing `break`s make this a quick check. Only the first model in
# RESNET50_MODELS + VITB16_MODELS is tried, and the dataset loop stops at the
# first dataset for which fetch_row returns a row.
for clusterer in ["AC w/o C"]:
    for model in RESNET50_MODELS + VITB16_MODELS:
        for dataset in [
            "mnist",
            "fashionmnist",
            "svhn",
            "cifar10",
            "cifar100",
            "flowers102",
            "aircraft",
        ]:
            print()
            print(f"{dataset:<16s} {model:<32s} {clusterer}")
            row = fetch_row(dataset, model, clusterer)
            if row is None:
                # No (unique) run found: try the next dataset
                print("No data with y_pred for", dataset, model, clusterer)
                continue
            print(
                row["name"].split("__")[-1],
                "\n" + row["name"],
                "\n  " + row["dataset_name"],
                "\n  " + row["model"],
                "\n  " + row["clusterer_name"],
                f"\n  AMI        = {row['AMI']}",
                f"\n  S_reduced  = {row['silhouette-euclidean_pred']}",
                f"\n  S_original = {row['silhouette-og-euclidean_pred']}",
            )
            fig = show_samples(row, save=True, clusterer=clusterer, nclusters=150)
            # plt.show()
            print("\n\nStopping early!")
            break
        break
    break

In [None]:
# Save cluster sample grids (up to 150 clusters) on iNaturalist with "AC w/o C".
# NOTE: the trailing `break`s mean only the first model in
# RESNET50_MODELS + VITB16_MODELS is processed.
for clusterer in ["AC w/o C"]:
    for model in RESNET50_MODELS + VITB16_MODELS:
        for dataset in ["inaturalist"]:
            print()
            print(f"{dataset:<16s} {model:<32s} {clusterer}")
            row = fetch_row(dataset, model, clusterer)
            if row is None:
                print("No data with y_pred for", dataset, model, clusterer)
                continue
            print(
                row["name"].split("__")[-1],
                "\n" + row["name"],
                "\n  " + row["dataset_name"],
                "\n  " + row["model"],
                "\n  " + row["clusterer_name"],
                f"\n  AMI        = {row['AMI']}",
                f"\n  S_reduced  = {row['silhouette-euclidean_pred']}",
                f"\n  S_original = {row['silhouette-og-euclidean_pred']}",
            )
            fig = show_samples(row, save=True, clusterer=clusterer, nclusters=150)
            # plt.show()
            # Close the current figure to free its memory
            plt.close()
            print("\n\nStopping early!")
            break
        break
    break

In [None]:
# Save cluster sample grids (up to 150 clusters) with "AC w/o C" for TEST_DATASETS.
# NOTE: the trailing `break`s make this a quick check. Only the first model in
# RESNET50_MODELS + VITB16_MODELS is tried, and the dataset loop stops after
# the first dataset for which fetch_row returns a row.
for clusterer in ["AC w/o C"]:
    for model in RESNET50_MODELS + VITB16_MODELS:
        for dataset in TEST_DATASETS:
            print()
            print(f"{dataset:<16s} {model:<32s} {clusterer}")
            row = fetch_row(dataset, model, clusterer)
            if row is None:
                print("No data with y_pred for", dataset, model, clusterer)
                continue
            print(
                row["name"].split("__")[-1],
                "\n" + row["name"],
                "\n  " + row["dataset_name"],
                "\n  " + row["model"],
                "\n  " + row["clusterer_name"],
                f"\n  AMI        = {row['AMI']}",
                f"\n  S_reduced  = {row['silhouette-euclidean_pred']}",
                f"\n  S_original = {row['silhouette-og-euclidean_pred']}",
            )
            try:
                fig = show_samples(row, save=True, clusterer=clusterer, nclusters=150)
            except Exception as err:
                # Best-effort: keep going, but report the actual failure (missing
                # prediction file, dataset, plotting error, ...) rather than
                # assuming the dataset is absent.
                print(f"{dataset} not found: {type(err).__name__}: {err}")
            try:
                plt.close()
            except Exception:
                pass
            print("\n\nStopping early!")
            break
        break
    break

## Breakdown information about datasets with multiple labels

### CelebA attributes

In [None]:
# CelebA test split loaded with attribute targets (target_type="attr"); its
# identity and attribute arrays are used below to score cluster assignments.
celeba_test = torchvision.datasets.CelebA(
    os.path.expanduser("~/Datasets"),
    target_type="attr",
    split="test",
)

In [None]:
# Names of the CelebA attribute columns
celeba_test.attr_names

In [None]:
# Build a LaTeX table of clustering AMI on CelebA, measured separately against
# identity and against each selected binary attribute, per clusterer x encoder.
metric_key = "AMI"  # AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
show_pc = True
show_fmt = "{:4.0f}"
show_commands = False
highlight_best = True
use_si_num = False
# Tolerance for counting a value as tied with the best / second-best value
eps = 0.005
override_fields = {
    "predictions_dir": "y_pred",
}

backbone = "ResNet-50"  # "ResNet-50" or "ViT-B"

if metric_key == "num_cluster_pred":
    CLUSTERERS = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc = False
    show_fmt = "{:4.0f}"
    highlight_best = False
    use_si_num = True
else:
    CLUSTERERS = [
        "KMeans",
        "SpectralClustering",
        "AC w/ C",
        "AC w/o C",
        "AffinityPropagation",
        "HDBSCAN",
    ]
if metric_key.startswith("silhouette"):
    show_pc = False
    show_fmt = "{:5.2f}"

print(MODEL_GROUPS)

dataset = "celeba"

# Full list would be ["Identity"] + celeba_test.attr_names[:-1]; use a curated subset.
TEST_ATTRS = [
    "Identity",
    "Attractive",
    "Bald",
    "Eyeglasses",
    "Heavy_Makeup",
    "Male",
    "No_Beard",
    "Wearing_Lipstick",
]
print(TEST_ATTRS)
best_results = {k: [] for k in TEST_ATTRS}
best_results_grouped = {k: defaultdict(list) for k in TEST_ATTRS}

# Two passes: the first (dummy=True) only collects every value so the best,
# second-best and per-clusterer best can be found; the second (dummy=False)
# writes the LaTeX table using those collected values.
for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{dataset} breakdown, {metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = backbone
    if metric_key == "AMI":
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_ATTRS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for attr in TEST_ATTRS:
        latex_table += r"&" + f"{attr.replace('_', ' '):^15s}"
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    if metric_key == "num_cluster_pred":
        latex_table += r"& Num targets"
        for attr in TEST_ATTRS:
            sdf = select_rows(test_runs_df, {"dataset": dataset}, allow_missing=False)
            sdf = sdf[~pd.isna(sdf["num_cluster_true"])]
            latex_table += r"& "
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{sdf.iloc[0]['num_cluster_true'].item()}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\toprule" + "\n"
    elif metric_key.endswith("_pred"):
        # Ground-truth row: the same metric computed on the true labels.
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"
        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            for attr in TEST_ATTRS:
                latex_table += " &"
                filter1 = {"model": model, "dataset": dataset}
                if model == "timm_vit_base_patch16_224.mae":
                    filter1["dim_reducer"] = "PCA"
                    filter1["pca_variance"] = 0.95
                else:
                    filter1["dim_reducer_man"] = "UMAP"
                    filter1["ndim_reduced_man"] = 50
                    filter1["dim_reducer_man_metric"] = "euclidean"
                sdf = select_rows(test_runs_df, filter1, allow_missing=False)
                sdf = sdf[~pd.isna(sdf[metric_key2])]
                # Report the ground-truth variant (metric_key2), which is the
                # column the rows were filtered on above.
                my_val = np.nanmedian(sdf[metric_key2])
                if dummy:
                    best_results_grouped[attr][clusterername].append(my_val)
                    continue
                is_best_grp = my_val + eps >= np.max(best_results_grouped[attr][clusterername])
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\toprule" + "\n"

    first_agg = True
    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        clusterername = CLUSTERER2SH.get(clusterer, clusterer)
        my_override_fields = override_fields.copy()
        if first_agg and clusterer == "AgglomerativeClustering" and metric_key != "num_cluster_pred":
            first_agg = False
            my_override_fields["aggclust_dist_thresh"] = None
            clusterername = "AC  w/ C"
        elif clusterer == "AgglomerativeClustering":
            clusterername = "AC w/o C"
            if "aggclust_dist_thresh" in my_override_fields:
                del my_override_fields["aggclust_dist_thresh"]

        if i_clusterer > 0:
            latex_table += r"\midrule" + "\n"

        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"

        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            filter1 = {"model": model, "dataset": dataset}
            filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
            filter2.update(filter1)
            filter2.update(my_override_fields)
            filter2 = fixup_filter(filter2)
            sdf = select_rows(test_runs_df, filter2, allow_missing=False)
            if len(sdf) < 1:
                print(f"No data for {model}-{dataset}-{clusterer}\n{filter2}")
                cmds.append(filter2command(filter2, partition="test"))
                if not dummy:
                    # One placeholder cell per attribute, then end the row, so
                    # the table keeps its column count and row structure.
                    latex_table += r" &   --  " * len(TEST_ATTRS) + r" \\" + "\n"
                continue
            if len(sdf) > 1:
                if sum(np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6) > 0:
                    print()
                    print(
                        f"More than one result with {metric_key} values",
                        list(sdf[metric_key]),
                    )
                    print(f"for search {filter2}")
                    dif_cols = find_differing_columns(sdf, config_keys)
                    print(f"columns which differ: {dif_cols}")
                    if dif_cols:
                        for col in dif_cols:
                            print(f"  {col}: {list(sdf[col])}")
            # Predicted cluster labels of the (first) matching run; the context
            # manager closes the .npz archive once read.
            with np.load("../" + get_pred_path(sdf.iloc[0])) as pred_data:
                y_pred = pred_data["y_pred"]
            for attr in TEST_ATTRS:
                latex_table += " &"
                if metric_key.lower() != "ami":
                    raise NotImplementedError()
                if attr.lower() == "identity":
                    my_val = sklearn.metrics.adjusted_mutual_info_score(celeba_test.identity[:, 0], y_pred)
                else:
                    attr_col = np.where(np.asarray(celeba_test.attr_names) == attr)[0][0]
                    my_val = sklearn.metrics.adjusted_mutual_info_score(celeba_test.attr[:, attr_col], y_pred)
                if dummy:
                    best_results[attr].append(my_val)
                    best_results_grouped[attr][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                is_best = my_val + eps >= np.max(best_results[attr])
                if len(best_results[attr]) > 1:
                    is_secd = my_val + eps >= np.sort(best_results[attr])[-2]
                else:
                    is_secd = False
                is_best_grp = my_val + eps >= np.max(best_results_grouped[attr][clusterername])
                # Cell shading: 0 at or below the median value, scaled to 100 at
                # the best value. Guard the zero-range case (all values equal).
                sc_base = np.nanmedian(best_results[attr])
                sc_top = np.max(best_results[attr])
                if sc_top > sc_base:
                    sc = 100 * max(0, (my_val - sc_base) / (sc_top - sc_base))
                else:
                    sc = 0
                latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"

print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

print()
print("Done!")
print()
print(f"Here is your {dataset} results table for {metric_key}, {backbone}:")
print()
print()
print(latex_table)

### UTKFace breakdown

In [None]:
import zs_ssl_clustering.datasets

# Keep only the last element returned by fetch_image_dataset; it is used as
# the UTKFace test split below (presumably the returned sequence ends with the
# test split — confirm against zs_ssl_clustering.datasets).
utkface_test = zs_ssl_clustering.datasets.fetch_image_dataset("utkface")[-1]

In [None]:
# Inspect the per-image metadata of the UTKFace test split (the attribute
# columns are selected from it in the next cell).
utkface_test.metadata

In [None]:
# Display names of the attributes, and the matching metadata columns (same
# order) stacked as one label column per attribute.
attr_names = ["age", "gender", "race"]
attrs = utkface_test.metadata[["age", "gender_id", "race_id"]].to_numpy()

In [None]:
# UTKFace attribute breakdown: for every clusterer and encoder, re-score the
# saved test-set predictions against each UTKFace attribute column of `attrs`
# and render the results as a LaTeX table.
metric_key = "AMI"  # AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
show_pc = True
show_fmt = "{:4.1f}"
show_commands = False
highlight_best = True
use_si_num = False
eps = 0.0005  # tolerance when testing whether a value ties the best one
override_fields = {
    "predictions_dir": "y_pred",
}

backbone = "ViT-B"  # "ResNet-50" or "ViT-B"

if metric_key == "num_cluster_pred":
    CLUSTERERS = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc = False
    show_fmt = "{:4.0f}"
    highlight_best = False
    use_si_num = True
else:
    CLUSTERERS = [
        "KMeans",
        "SpectralClustering",
        "AC w/ C",
        "AC w/o C",
        "AffinityPropagation",
        "HDBSCAN",
    ]
if metric_key.startswith("silhouette"):
    show_pc = False
    show_fmt = "{:5.2f}"

# Label-agreement scores that can be recomputed from a run's saved predictions
# against each attribute column. Keys are lower-cased metric names.
attr_metric_fns = {
    "ami": sklearn.metrics.adjusted_mutual_info_score,
    "nmi": sklearn.metrics.normalized_mutual_info_score,
    "mi": sklearn.metrics.mutual_info_score,
    "rand_score": sklearn.metrics.rand_score,
}

print(MODEL_GROUPS)

dataset = "utkface"

TEST_ATTRS = attr_names
print(TEST_ATTRS)
best_results = {k: [] for k in TEST_ATTRS}
best_results_grouped = {k: defaultdict(list) for k in TEST_ATTRS}

# Two passes over the same data: the first (dummy=True) only collects every
# value so the best / second-best / best-within-group can be determined; the
# second renders the LaTeX table using those collected values.
for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{dataset} breakdown, {metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = backbone
    if metric_key == "AMI":
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"%\resizebox{\textwidth}{!}{%" + "\n"  # Disabled
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_ATTRS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for attr in TEST_ATTRS:
        latex_table += r"&" + f"{attr.replace('_', ' '):^15s}"
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    if metric_key == "num_cluster_pred":
        # Reference row: the number of distinct values of each attribute
        # (previously the dataset-level target count was repeated under
        # every attribute).
        latex_table += r"& Num targets"
        for i_attr, attr in enumerate(TEST_ATTRS):
            n_targets = len(np.unique(attrs[:, i_attr]))
            latex_table += r"& "
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{n_targets}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\toprule" + "\n"
    elif metric_key.endswith("_pred"):
        # Reference block: the same score computed with the ground-truth
        # labels (the matching "_true" column) for each encoder.
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"
        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            filter1 = {"model": model, "dataset": dataset}
            if model == "timm_vit_base_patch16_224.mae":
                filter1["dim_reducer"] = "PCA"
                filter1["pca_variance"] = 0.95
            else:
                filter1["dim_reducer_man"] = "UMAP"
                filter1["ndim_reduced_man"] = 50
                filter1["dim_reducer_man_metric"] = "euclidean"
            sdf = select_rows(test_runs_df, filter1, allow_missing=False)
            sdf = sdf[~pd.isna(sdf[metric_key2])]
            # All matching runs are expected to report the same ground-truth
            # value; take the median of the "_true" column and flag any
            # disagreement (previously the "_pred" column was reported here).
            gt_val = np.nanmedian(sdf[metric_key2])
            if not dummy and np.any(sdf[metric_key2] != gt_val):
                print(f"{metric_key2} differs across runs for {model}: {list(sdf[metric_key2])}")
            for attr in TEST_ATTRS:
                latex_table += " &"
                my_val = gt_val
                if dummy:
                    best_results_grouped[attr][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                is_best_grp = my_val + eps >= np.max(best_results_grouped[attr][clusterername])
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\toprule" + "\n"

    first_agg = True
    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        clusterername = CLUSTERER2SH.get(clusterer, clusterer)
        my_override_fields = override_fields.copy()
        if first_agg and clusterer == "AgglomerativeClustering" and metric_key != "num_cluster_pred":
            first_agg = False
            my_override_fields["aggclust_dist_thresh"] = None
            clusterername = "AC  w/ C"
        elif clusterer == "AgglomerativeClustering":
            clusterername = "AC w/o C"
            if "aggclust_dist_thresh" in my_override_fields:
                del my_override_fields["aggclust_dist_thresh"]

        if i_clusterer > 0:
            latex_table += r"\midrule" + "\n"

        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"

        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            filter1 = {"model": model, "dataset": dataset}
            filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
            filter2.update(filter1)
            filter2.update(my_override_fields)
            filter2 = fixup_filter(filter2)
            sdf = select_rows(test_runs_df, filter2, allow_missing=False)
            if len(sdf) < 1:
                print(f"No data for {model}-{dataset}-{clusterer}\n{filter2}")
                cmds.append(filter2command(filter2, partition="test"))
                if not dummy:
                    # One placeholder per attribute column, then terminate the
                    # row so the next encoder starts on a fresh table row.
                    latex_table += r" &   --  " * len(TEST_ATTRS)
                    latex_table += r" \\" + "\n"
                continue
            if len(sdf) > 1:
                if sum(np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6) > 0:
                    print()
                    print(
                        f"More than one result with {metric_key} values",
                        list(sdf[metric_key]),
                    )
                    print(f"for search {filter2}")
                    dif_cols = find_differing_columns(sdf, config_keys)
                    print(f"columns which differ: {dif_cols}")
                    if dif_cols:
                        for col in dif_cols:
                            print(f"  {col}: {list(sdf[col])}")
            metric_fn = attr_metric_fns.get(metric_key.lower())
            if metric_fn is None:
                raise NotImplementedError(f"{metric_key} cannot be recomputed per attribute")
            y_pred = np.load("../" + get_pred_path(sdf.iloc[0]))["y_pred"]
            for i_attr, attr in enumerate(TEST_ATTRS):
                latex_table += " &"
                my_val = metric_fn(attrs[:, i_attr], y_pred)
                if dummy:
                    best_results[attr].append(my_val)
                    best_results_grouped[attr][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                is_best = my_val + eps >= np.max(best_results[attr])
                if len(best_results[attr]) > 1:
                    is_secd = my_val + eps >= np.sort(best_results[attr])[-2]
                else:
                    is_secd = False
                is_best_grp = my_val + eps >= np.max(best_results_grouped[attr][clusterername])
                # Cell shading: 0 at/below the column median, 100 at the best.
                sc_base = np.nanmedian(best_results[attr])
                sc_top = np.max(best_results[attr])
                if sc_top > sc_base:
                    sc = 100 * max(0, (my_val - sc_base) / (sc_top - sc_base))
                else:
                    sc = 0
                latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"%}" + "\n"  # Disabled

print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

print()
print("Done!")
print()
print(f"Here is your {dataset} results table for {metric_key}, {backbone}:")
print()
print()
print(latex_table)

### ImageNet-R

In [None]:
# `datasets` is not bound by any import in this part of the notebook (the
# UTKFace cell imports `zs_ssl_clustering.datasets` and uses the full path), so
# import and reference the module explicitly to avoid a NameError.
import zs_ssl_clustering.datasets

# Keep only the last element returned by fetch_image_dataset, used as the
# ImageNet-R test split below (presumably the sequence ends with the test
# split — confirm against zs_ssl_clustering.datasets).
imagenetr_test = zs_ssl_clustering.datasets.fetch_image_dataset("imagenet-r")[-1]

In [None]:
# Artform of each image = its file name up to the first underscore (the path
# is the first element of each `imgs` entry).
artforms = [os.path.basename(img_entry[0]).partition("_")[0] for img_entry in imagenetr_test.imgs]

In [None]:
# Integer-encode the artform strings: u_artform holds the sorted unique names
# and artform_ids gives, per image, the index of its artform in u_artform.
u_artform, artform_ids = np.unique(artforms, return_inverse=True)

In [None]:
# Number of distinct artforms.
len(u_artform)

In [None]:
# Inspect which class ids occur in the test targets.
np.unique(imagenetr_test.targets)

In [None]:
# Joint (class, artform) label: offset each artform's block of ids by the
# number of class ids, so distinct pairs never collide. The stride is derived
# from the targets (max id + 1) instead of hard-coding 200.
imagenetr_targets = np.asarray(imagenetr_test.targets)
class_artform_ids = imagenetr_targets + artform_ids * (imagenetr_targets.max() + 1)

In [None]:
# Number of (class, artform) combinations that actually occur.
len(np.unique(class_artform_ids))

In [None]:
# One row per image; columns are class id, artform id, and joint
# class-artform id (in that order).
attrs = np.column_stack((imagenetr_test.targets, artform_ids, class_artform_ids))

In [None]:
# ImageNet-R attribute breakdown: for every clusterer and encoder, re-score the
# saved test-set predictions against the class, artform and joint labels in
# `attrs` and render the results as a LaTeX table.
metric_key = "AMI"  # AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
show_pc = True
show_fmt = "{:4.0f}"
show_commands = False
highlight_best = True
use_si_num = False
eps = 0.005  # tolerance when testing whether a value ties the best one
override_fields = {
    "predictions_dir": "y_pred",
}

backbone = "ResNet-50"  # "ResNet-50" or "ViT-B"

if metric_key == "num_cluster_pred":
    CLUSTERERS = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc = False
    show_fmt = "{:4.0f}"
    highlight_best = False
    use_si_num = True
else:
    CLUSTERERS = [
        "KMeans",
        "SpectralClustering",
        "AC w/ C",
        "AC w/o C",
        "AffinityPropagation",
        "HDBSCAN",
    ]
if metric_key.startswith("silhouette"):
    show_pc = False
    show_fmt = "{:5.2f}"

# Label-agreement scores that can be recomputed from a run's saved predictions
# against each attribute column. Keys are lower-cased metric names.
attr_metric_fns = {
    "ami": sklearn.metrics.adjusted_mutual_info_score,
    "nmi": sklearn.metrics.normalized_mutual_info_score,
    "mi": sklearn.metrics.mutual_info_score,
    "rand_score": sklearn.metrics.rand_score,
}

print(MODEL_GROUPS)

dataset = "imagenet-r"

TEST_ATTRS = ["Class", "Artform", "Both"]
print(TEST_ATTRS)
best_results = {k: [] for k in TEST_ATTRS}
best_results_grouped = {k: defaultdict(list) for k in TEST_ATTRS}

# Two passes over the same data: the first (dummy=True) only collects every
# value so the best / second-best / best-within-group can be determined; the
# second renders the LaTeX table using those collected values.
for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{dataset} breakdown, {metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = backbone
    if metric_key == "AMI":
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"%\resizebox{\textwidth}{!}{%" + "\n"  # Disabled
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_ATTRS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for attr in TEST_ATTRS:
        latex_table += r"&" + f"{attr.replace('_', ' '):^15s}"
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    if metric_key == "num_cluster_pred":
        # Reference row: the number of distinct values of each attribute
        # (previously the dataset-level target count was repeated under
        # every attribute).
        latex_table += r"& Num targets"
        for i_attr, attr in enumerate(TEST_ATTRS):
            n_targets = len(np.unique(attrs[:, i_attr]))
            latex_table += r"& "
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{n_targets}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\midrule" + "\n"
    elif metric_key.endswith("_pred"):
        # Reference block: the same score computed with the ground-truth
        # labels (the matching "_true" column) for each encoder.
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"
        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            filter1 = {"model": model, "dataset": dataset}
            if model == "timm_vit_base_patch16_224.mae":
                filter1["dim_reducer"] = "PCA"
                filter1["pca_variance"] = 0.95
            else:
                filter1["dim_reducer_man"] = "UMAP"
                filter1["ndim_reduced_man"] = 50
                filter1["dim_reducer_man_metric"] = "euclidean"
            sdf = select_rows(test_runs_df, filter1, allow_missing=False)
            sdf = sdf[~pd.isna(sdf[metric_key2])]
            # All matching runs are expected to report the same ground-truth
            # value; take the median of the "_true" column and flag any
            # disagreement (previously the "_pred" column was reported here).
            gt_val = np.nanmedian(sdf[metric_key2])
            if not dummy and np.any(sdf[metric_key2] != gt_val):
                print(f"{metric_key2} differs across runs for {model}: {list(sdf[metric_key2])}")
            for attr in TEST_ATTRS:
                latex_table += " &"
                my_val = gt_val
                if dummy:
                    best_results_grouped[attr][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                is_best_grp = my_val + eps >= np.max(best_results_grouped[attr][clusterername])
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\midrule" + "\n"

    first_agg = True
    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        clusterername = CLUSTERER2SH.get(clusterer, clusterer)
        my_override_fields = override_fields.copy()
        if first_agg and clusterer == "AgglomerativeClustering" and metric_key != "num_cluster_pred":
            first_agg = False
            my_override_fields["aggclust_dist_thresh"] = None
            clusterername = "AC  w/ C"
        elif clusterer == "AgglomerativeClustering":
            clusterername = "AC w/o C"
            if "aggclust_dist_thresh" in my_override_fields:
                del my_override_fields["aggclust_dist_thresh"]

        if i_clusterer > 0:
            latex_table += r"\midrule" + "\n"

        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"

        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            filter1 = {"model": model, "dataset": dataset}
            filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
            filter2.update(filter1)
            filter2.update(my_override_fields)
            filter2 = fixup_filter(filter2)
            sdf = select_rows(test_runs_df, filter2, allow_missing=False)
            if len(sdf) < 1:
                print(f"No data for {model}-{dataset}-{clusterer}\n{filter2}")
                cmds.append(filter2command(filter2, partition="test"))
                if not dummy:
                    # One placeholder per attribute column, then terminate the
                    # row so the next encoder starts on a fresh table row.
                    latex_table += r" &   --  " * len(TEST_ATTRS)
                    latex_table += r" \\" + "\n"
                continue
            if len(sdf) > 1:
                if sum(np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6) > 0:
                    print()
                    print(
                        f"More than one result with {metric_key} values",
                        list(sdf[metric_key]),
                    )
                    print(f"for search {filter2}")
                    dif_cols = find_differing_columns(sdf, config_keys)
                    print(f"columns which differ: {dif_cols}")
                    if dif_cols:
                        for col in dif_cols:
                            print(f"  {col}: {list(sdf[col])}")
            metric_fn = attr_metric_fns.get(metric_key.lower())
            if metric_fn is None:
                raise NotImplementedError(f"{metric_key} cannot be recomputed per attribute")
            y_pred = np.load("../" + get_pred_path(sdf.iloc[0]))["y_pred"]
            for i_attr, attr in enumerate(TEST_ATTRS):
                latex_table += " &"
                my_val = metric_fn(attrs[:, i_attr], y_pred)
                if dummy:
                    best_results[attr].append(my_val)
                    best_results_grouped[attr][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                is_best = my_val + eps >= np.max(best_results[attr])
                if len(best_results[attr]) > 1:
                    is_secd = my_val + eps >= np.sort(best_results[attr])[-2]
                else:
                    is_secd = False
                is_best_grp = my_val + eps >= np.max(best_results_grouped[attr][clusterername])
                # Cell shading: 0 at/below the column median, 100 at the best.
                sc_base = np.nanmedian(best_results[attr])
                sc_top = np.max(best_results[attr])
                if sc_top > sc_base:
                    sc = 100 * max(0, (my_val - sc_base) / (sc_top - sc_base))
                else:
                    sc = 0
                latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"%}" + "\n"  # Disabled

print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

print()
print("Done!")
print()
print(f"Here is your {dataset} results table for {metric_key}, {backbone}:")
print()
print()
print(latex_table)

#### ImageNet-R: AMI per attribute, averaged over clusterers

In [None]:
def build_attr_ami_table(
    models, clusterers, dataset, attrs, return_cmds=False, verbosity=0, overrides=None
):
    """Score each encoder/clusterer's saved test predictions against each attribute.

    For every (model, clusterer) pair, the best-hparam test run is located in
    ``test_runs_df``; its saved cluster assignments are compared with every
    column of ``attrs`` using the adjusted mutual information (AMI).

    Parameters
    ----------
    models : sequence of str
        Encoder names (keys of ``BEST_PARAMS[clusterer]``).
    clusterers : sequence of str
        Clusterer names (keys of ``BEST_PARAMS``).
    dataset : str
        Dataset name used to filter ``test_runs_df``.
    attrs : array-like, shape (n_samples, n_attrs)
        Ground-truth labels, one column per attribute.
    return_cmds : bool, default False
        If True, also return the commands that would generate missing runs.
    verbosity : int, default 0
        At 1 or more, print details about missing or ambiguous runs.
    overrides : dict, optional
        Filter fields applied on top of the best hparams. Defaults to the
        notebook-global ``override_fields`` (the previous, implicit behaviour).

    Returns
    -------
    result_table : ndarray, shape (len(models), len(clusterers), n_attrs)
        AMI scores; NaN where no matching run exists.
    cmds : list
        Only returned when ``return_cmds`` is True.
    """
    if overrides is None:
        overrides = override_fields  # notebook-level default
    attrs = np.asarray(attrs)
    n_attrs = attrs.shape[1]
    result_table = np.full((len(models), len(clusterers), n_attrs), np.nan)
    cmds = []

    for i_model, model in enumerate(models):
        print(f"[{i_model + 1}/{len(models)}] {model}")
        for i_clusterer, clusterer in enumerate(clusterers):
            filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
            filter2.update({"model": model, "dataset": dataset})
            filter2.update(overrides)
            filter2 = fixup_filter(filter2)
            sdf = select_rows(test_runs_df, filter2, allow_missing=False)
            if len(sdf) < 1:
                if verbosity >= 1:
                    print(f"No data for {model}-{dataset}-{clusterer}\n{filter2}")
                cmds.append(filter2command(filter2, partition="test"))
                continue
            if len(sdf) > 1 and verbosity >= 1:
                # Several runs match the filter; the first one is used.
                print(f"{len(sdf)} runs match {model}-{dataset}-{clusterer}; using the first")
            y_pred = np.load("../" + get_pred_path(sdf.iloc[0]))["y_pred"]

            for i_attr in range(n_attrs):
                result_table[i_model, i_clusterer, i_attr] = (
                    sklearn.metrics.adjusted_mutual_info_score(attrs[:, i_attr], y_pred)
                )

    if return_cmds:
        return result_table, cmds
    return result_table

In [None]:
# Settings for the clusterer-averaged ImageNet-R attribute table.
dataset = "imagenet-r"
metric_key = "AMI"  # AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
use_rank = False
show_pc = True
show_fmt = "{:4.0f}"
show_commands = False
highlight_best = True
use_si_num = False
eps = 0.005

if metric_key != "num_cluster_pred":
    clusterers = [
        "KMeans",
        "SpectralClustering",
        "AC w/ C",
        "AC w/o C",
        "AffinityPropagation",
        "HDBSCAN",
    ]
else:
    clusterers = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc, show_fmt = False, "{:4.0f}"
    highlight_best, use_si_num = False, True
if metric_key.startswith("silhouette"):
    show_pc, show_fmt = False, "{:5.2f}"

# Encoders to tabulate, grouped by backbone; the first (random-init) entry of
# each backbone list is left out.
model_groups = {
    "RN50": RESNET50_MODELS[1:] + FT_RESNET50_MODELS,
    "ViT-B": VITB16_MODELS[1:] + FT_VITB16_MODELS,
}

# Short tag for the clusterer set, used in the table comment and label.
clustererstr = clusterers[0] if len(clusterers) == 1 else f"{len(clusterers)}c-avg"

# Keep only the last column (the encoder name) of each flattened group path.
model_groups_flattened = np.array(
    make_flat_hierarchy_from_dict(model_groups, pad_right=False)
)[:, -1]

print("Encoders:")
print(model_groups_flattened)

result_table = build_attr_ami_table(model_groups_flattened, clusterers, dataset, attrs=attrs)

In [None]:
# Keep a handle on the raw [models, clusterers, attributes] table. This is a
# reference, not a copy; the table cell below deep-copies it before editing.
result_table_actual = result_table

In [None]:
# Expected: (n_models, n_clusterers, n_attrs).
result_table.shape

In [None]:
# Column names for the ImageNet-R attrs columns (class id, artform id, joint id).
TEST_ATTRS = ["Class", "Artform", "Both"]

In [None]:
print("Encoders:")
print(model_groups_flattened)

# The last axis of result_table indexes attributes (TEST_ATTRS), not datasets.
print("Attributes:")
print(TEST_ATTRS)

result_table = copy.deepcopy(result_table_actual)

# Shaped [models, clusterers, attributes]
print("result_table.shape", result_table.shape)

# Remove clusterer-attribute combos which are NaN for any model, so every
# model is averaged over the same set of clusterers.
result_table[:, np.any(np.isnan(result_table), axis=0)] = np.nan

# Take mean over clusterers
result_table = np.nanmean(result_table, axis=1)
# Shaped [models, attributes]

print("result_table.shape", result_table.shape)

if use_rank:
    # Rank encoders against each other within every attribute column (axis 0),
    # 1 = highest score. Argsort of the argsort turns a sort order into
    # per-element ranks; NaNs sort last and so receive the worst ranks.
    rank_order = np.argsort(-result_table, axis=0, kind="stable")
    result_table_r = np.argsort(rank_order, axis=0, kind="stable") + 1
else:
    result_table_r = result_table

# Mean and stdev over attributes, per encoder
mu = np.mean(result_table_r, axis=-1)
sd = np.std(result_table_r, axis=-1)
print("mu.shape", mu.shape)


print(model_groups)


print(TEST_ATTRS)
best_results = {k: [] for k in TEST_ATTRS}
best_results_grouped = {k: defaultdict(list) for k in TEST_ATTRS}
ft_suffix = " [FT]"  # short-name suffix marking fine-tuned encoders

# Two passes: the first (dummy=True) collects all values so the best /
# second-best / best-within-group can be found; the second renders the table.
for dummy in [True, False]:
    # Missing runs are reported by build_attr_ami_table(return_cmds=True);
    # nothing in this cell adds to cmds.
    cmds = []
    latex_table = r"% Results for " + f"{metric_key}, {clustererstr}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = clustererstr
    if metric_key == "AMI":
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"%\resizebox{\columnwidth}{!}{%" + "\n"  # Disabled
    # Columns: group | encoder | fine-tuned checkmark | one per attribute.
    # (The checkmark column was previously missing from the spec and header,
    # giving body rows one more cell than the tabular declares.)
    latex_table += r"\begin{tabular}{llc" + r"r" * len(TEST_ATTRS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<10s}" + r" & FT"
    for attr in TEST_ATTRS:
        latex_table += r"&" + f"{attr.replace('_', ' '):^15s}"
    latex_table += r"\\" + "\n"
    # Begin table contents
    latex_table += r"\midrule" + "\n"
    i_model_o = -1  # row index into result_table across all groups
    for i_group, group in enumerate(model_groups):
        if i_group > 0:
            latex_table += r"\midrule" + "\n"
        latex_table += group + "\n"
        for model in model_groups[group]:
            i_model_o += 1
            model_sh = MODEL2SH.get(model, model)
            if model_sh.endswith(ft_suffix):
                model_sh = f"{model_sh[: -len(ft_suffix)]:<10s}" + r" & \checkmark"
            else:
                model_sh = f"{model_sh:<10s}" + " &"
            latex_table += f"& {model_sh:<23s}"
            for i_attr, attrname in enumerate(TEST_ATTRS):
                latex_table += " &"
                my_val = result_table[i_model_o, i_attr]
                if dummy:
                    best_results[attrname].append(my_val)
                    best_results_grouped[attrname][group].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                # Ignore missing (NaN) entries when locating the best values.
                col_vals = np.asarray(best_results[attrname], dtype=float)
                col_vals = np.sort(col_vals[~np.isnan(col_vals)])
                grp_vals = best_results_grouped[attrname][group]
                is_best = my_val + eps >= col_vals[-1]
                is_secd = len(col_vals) > 1 and my_val + eps >= col_vals[-2]
                # Only mark a within-group best when the group has >1 entry.
                is_best_grp = len(grp_vals) > 1 and my_val + eps >= np.nanmax(grp_vals)
                # Cell shading: 0 at/below the column median, 100 at the best.
                sc_base = np.median(col_vals)
                sc_top = col_vals[-1]
                if sc_top > sc_base:
                    sc = 100 * max(0, (my_val - sc_base) / (sc_top - sc_base))
                else:
                    sc = 0
                latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if highlight_best:
                    if is_best:
                        latex_table += r"\tcf{"
                    elif is_secd:
                        latex_table += r"\tcs{"
                    if is_best_grp:
                        latex_table += r"\tcg{"
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"%} % resizebox" + "\n"

print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

print()
print("Done!")
print()
print(f"Here is your results table for {metric_key}, {clustererstr}:")
print()
print()
print(latex_table)

In [None]:
# After the cell above, result_table has been averaged over clusterers:
# expected (n_models, n_attrs).
result_table.shape

### FGVC Aircraft

In [None]:
# FGVC-Aircraft test-split labels at each annotation granularity, stacked so
# that each column of `attrs` holds the labels for one annotation level.
annotation_levels = ["manufacturer", "family", "variant"]
_aircraft_root = os.path.expanduser("~/Datasets")
_level_labels = []
for _level in annotation_levels:
    _ds = torchvision.datasets.FGVCAircraft(
        _aircraft_root,
        split="test",
        annotation_level=_level,
    )
    _level_labels.append(_ds._labels)
attrs = np.stack(_level_labels, axis=-1)
del _aircraft_root, _level_labels, _level, _ds

In [None]:
# Number of distinct labels at each annotation level.
for level_name, level_labels in zip(annotation_levels, attrs.T):
    print(level_name, len(np.unique(level_labels)))

In [None]:
# Settings for the FGVC-Aircraft annotation-level breakdown table; the
# rendering loop that consumes them continues further down this cell.
metric_key = "AMI"  # AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
show_pc = True
show_fmt = "{:4.0f}"
show_commands = False
highlight_best = True
use_si_num = False
eps = 0.005  # tolerance when testing whether a value ties the best one
override_fields = {
    "predictions_dir": "y_pred",
}

backbone = "ResNet-50"  # "ResNet-50" or "ViT-B"

# The clusterer list and number formatting depend on the chosen metric.
if metric_key == "num_cluster_pred":
    CLUSTERERS = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc = False
    show_fmt = "{:4.0f}"
    highlight_best = False
    use_si_num = True
else:
    CLUSTERERS = [
        "KMeans",
        "SpectralClustering",
        "AC w/ C",
        "AC w/o C",
        "AffinityPropagation",
        "HDBSCAN",
    ]
if metric_key.startswith("silhouette"):
    show_pc = False
    show_fmt = "{:5.2f}"

print(MODEL_GROUPS)

dataset = "aircraft"

# Table columns are the annotation levels (one column of `attrs` each).
TEST_ATTRS = annotation_levels
print(TEST_ATTRS)
# Accumulators filled during the first (dummy) pass of the rendering loop.
best_results = {k: [] for k in TEST_ATTRS}
best_results_grouped = {k: defaultdict(list) for k in TEST_ATTRS}

def _rotated_row_label(text, nrows):
    """Return a LaTeX cell that prints *text* rotated 90 degrees, spanning *nrows* rows.

    Used as the left-most column labelling each group of model rows with the
    clusterer (or "G.T.") that produced them.
    """
    return (
        r"\parbox[t]{2mm}{\multirow{"
        + str(nrows)
        + r"}{*}{"
        + r"\scalebox{0.9}{"
        + r"\rotatebox[origin=c]{90}{"
        + text
        + r"}}}"
        + r"}"
    )


# The table is built in two passes. The first ("dummy") pass only records every
# finite score into best_results / best_results_grouped. The second pass uses
# those to decide highlighting and cell shading. Only the latex_table and cmds
# built in the second pass survive the loop.
for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{dataset} breakdown, {metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = backbone
    if metric_key == "AMI":
        # AMI tables additionally carry the bare backbone label.
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"%\resizebox{\textwidth}{!}{%" + "\n"  # Disabled
    # Two label columns (group, encoder) followed by one column per attribute.
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_ATTRS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for attr in TEST_ATTRS:
        latex_table += r"&" + f"{attr.replace('_', ' '):^15s}"
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    n_models = len(MODEL_GROUPS[backbone])

    if metric_key == "num_cluster_pred":
        # Extra header row with the ground-truth number of clusters.
        # NOTE(review): the same dataset-level num_cluster_true is printed under
        # every attribute; a per-attribute count (e.g. unique values of
        # attrs[:, i_attr]) may be what is intended -- TODO confirm.
        latex_table += r"& Num targets"
        sdf = select_rows(test_runs_df, {"dataset": dataset}, allow_missing=False)
        sdf = sdf[~pd.isna(sdf["num_cluster_true"])]
        for _ in TEST_ATTRS:
            latex_table += r"& "
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{sdf.iloc[0]['num_cluster_true'].item()}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\toprule" + "\n"
    elif metric_key.endswith("_pred"):
        # Extra "G.T." block: the median <metric>_pred value per model, taken
        # over runs that have a <metric>_true value.
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += _rotated_row_label(clusterername, n_models) + "\n"
        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            filter1 = {"model": model, "dataset": dataset}
            if model == "timm_vit_base_patch16_224.mae":
                filter1["dim_reducer"] = "PCA"
                filter1["pca_variance"] = 0.95
            else:
                filter1["dim_reducer_man"] = "UMAP"
                filter1["ndim_reduced_man"] = 50
                filter1["dim_reducer_man_metric"] = "euclidean"
            # The query does not depend on the attribute, so run it once per model.
            # NOTE(review): as a result every attribute column shows the same value.
            sdf = select_rows(test_runs_df, filter1, allow_missing=False)
            sdf = sdf[~pd.isna(sdf[metric_key2])]
            my_val = np.nanmedian(sdf[metric_key])
            for attr in TEST_ATTRS:
                latex_table += " &"
                if dummy:
                    # Record only finite values so a NaN cannot make np.max NaN.
                    if not np.isnan(my_val):
                        best_results_grouped[attr][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    # No matching runs: print a placeholder instead of "nan".
                    latex_table += r"   --  "
                    continue
                is_best_grp = my_val + eps >= np.max(best_results_grouped[attr][clusterername])
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if highlight_best:
                    latex_table += r"\tcg{" if is_best_grp else "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\toprule" + "\n"

    first_agg = True
    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        clusterername = CLUSTERER2SH.get(clusterer, clusterer)
        my_override_fields = override_fields.copy()
        if first_agg and clusterer == "AgglomerativeClustering" and metric_key != "num_cluster_pred":
            # First AgglomerativeClustering entry forces aggclust_dist_thresh=None.
            first_agg = False
            my_override_fields["aggclust_dist_thresh"] = None
            clusterername = "AC  w/ C"  # presumably padded to the width of "AC w/o C"
        elif clusterer == "AgglomerativeClustering":
            # Later entries drop the override, so the threshold comes from
            # DEFAULT_PARAMS / BEST_PARAMS instead.
            clusterername = "AC w/o C"
            my_override_fields.pop("aggclust_dist_thresh", None)

        if i_clusterer > 0:
            latex_table += r"\midrule" + "\n"
        latex_table += _rotated_row_label(clusterername, n_models) + "\n"

        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
            filter2.update({"model": model, "dataset": dataset})
            filter2.update(my_override_fields)
            filter2 = fixup_filter(filter2)
            sdf = select_rows(test_runs_df, filter2, allow_missing=False)
            if len(sdf) < 1:
                print(f"No data for {model}-{dataset}-{clusterer}\n{filter2}")
                cmds.append(filter2command(filter2, partition="test"))
                # Emit one placeholder cell per attribute and terminate the row,
                # so the column count stays correct and the next model starts a
                # new row (previously the "\\" was skipped by the continue).
                latex_table += (" &" + r"   --  ") * len(TEST_ATTRS)
                latex_table += r" \\" + "\n"
                continue
            if len(sdf) > 1 and sum(np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6) > 0:
                # Several runs match the filter and disagree on the metric:
                # report which config columns differ so the filter can be tightened.
                print()
                print(
                    f"More than one result with {metric_key} values",
                    list(sdf[metric_key]),
                )
                print(f"for search {filter2}")
                dif_cols = find_differing_columns(sdf, config_keys)
                print(f"columns which differ: {dif_cols}")
                for col in dif_cols or []:
                    print(f"  {col}: {list(sdf[col])}")
            if metric_key.lower() != "ami":
                # Per-attribute scores are recomputed from saved predictions, and
                # only AMI is implemented for that.
                raise NotImplementedError(
                    f"Per-attribute breakdown is only implemented for AMI, not {metric_key!r}"
                )
            # Close the .npz archive once the predictions are read.
            with np.load("../" + get_pred_path(sdf.iloc[0])) as pred_file:
                y_pred = pred_file["y_pred"]
            for i_attr, attr in enumerate(TEST_ATTRS):
                latex_table += " &"
                my_val = sklearn.metrics.adjusted_mutual_info_score(attrs[:, i_attr], y_pred)
                if dummy:
                    # Only finite scores feed the best/second-best/shading
                    # statistics; a NaN would otherwise make np.max return NaN
                    # and disable highlighting for the whole column.
                    if not np.isnan(my_val):
                        best_results[attr].append(my_val)
                        best_results_grouped[attr][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                column = best_results[attr]
                is_best = my_val + eps >= np.max(column)
                is_secd = len(column) > 1 and my_val + eps >= np.sort(column)[-2]
                is_best_grp = my_val + eps >= np.max(best_results_grouped[attr][clusterername])
                # Cell shading: 0 at or below the column median, 100 at the column maximum.
                sc_base = np.nanmedian(column)
                sc_top = np.max(column)
                if sc_top > sc_base:
                    sc = 100 * max(0, (my_val - sc_base) / (sc_top - sc_base))
                else:
                    sc = 0  # degenerate column; avoids a 0/0 division
                latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if highlight_best:
                    if is_best:
                        latex_table += r"\tcf{"
                    elif is_secd:
                        latex_table += r"\tcs{"
                    else:
                        latex_table += "     "
                    latex_table += r"\tcg{" if is_best_grp else "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"%}" + "\n"  # Disabled

# Summarise missing runs (optionally listing the commands that would produce
# them), then dump the finished LaTeX table.
print(f"\nThere are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands and cmds:
    # One command per line, identical to printing each command separately.
    print(*cmds, sep="\n")

print("\nDone!\n")
print(f"Here is your {dataset} results table for {metric_key}, {backbone}:", end="\n\n\n")
print(latex_table)