In [None]:
from pathlib import Path
import h5py
import numpy as np

In [None]:
# Directory that is globbed for the per-dataset `*.h5` embedding files.
feature_dir = Path(
    '/lustre/groups/shared/users/peng_marr/DinoBloomv2/vits_3M_350k_bs416_0.1ce+supcon_mlp_rbc/train_embeddings'
)

In [None]:
# Hand-picked subset of datasets. Each entry is used as a filename prefix when
# globbing `feature_dir` for the matching `.h5` file.
# NOTE(review): the "all features" cell that follows reassigns both `datasets`
# and `h5_datasets`, so running both cells discards this selection.
datasets = [
    'bm_train_patches',
    'bozdas_patches',
    'mll_mil_train_patches',
    'krd_wbc_patches',
    'ldwbc_patches',
    'lisc_refactor_patches',
    'matek_patches',
]

h5_datasets = []
for dataset in datasets:
    # Sort so that the chosen file does not depend on filesystem listing order
    # when several files share the same prefix.
    matches = sorted(feature_dir.glob(f'{dataset}*.h5'))
    if not matches:
        # Fail with the offending name instead of a bare IndexError.
        raise FileNotFoundError(
            f"No file matching '{dataset}*.h5' found in {feature_dir}"
        )
    h5_datasets.append(matches[0])

In [None]:
# all features: take every `.h5` file in `feature_dir`.
# Sorted so the file order (and therefore the row order of everything that is
# concatenated from these files later) is the same on every run.
h5_datasets = sorted(feature_dir.glob('*.h5'))

# Dataset name = the part of the filename before '_embeddings'.
datasets = [h5.name.split('_embeddings')[0] for h5 in h5_datasets]

# Map each distinct dataset name to an integer domain code. Iterating a bare
# `set` of strings follows hash order, which changes between interpreter runs;
# sorting keeps the name -> code assignment stable.
domain_labels = {dataset: i for i, dataset in enumerate(sorted(set(datasets)))}

In [None]:
# Open every embedding file read-only. The handles are kept open on purpose:
# the following cells read the 'features' and 'labels' datasets from them.
# `h5_datasets` is rebound from paths to open handles, so entries that are
# already `h5py.File` objects are reused; this keeps the cell safe to re-run.
h5_datasets = [
    entry if isinstance(entry, h5py.File) else h5py.File(entry, 'r')
    for entry in h5_datasets
]

In [None]:
# Join the per-file 'features' and 'labels' datasets along axis 0, in the
# order of `h5_datasets`, so row i of both arrays refers to the same sample.
features, labels = (
    np.concatenate([handle[key] for handle in h5_datasets], axis=0)
    for key in ('features', 'labels')
)

In [None]:
# Expand the name -> code mapping into one float code per sample: every row of
# a file gets the code of that file's dataset name, and the per-file arrays are
# joined in the order of `h5_datasets`.
# Note: this rebinds `domain_labels` from the mapping to the per-sample array,
# so the cell that builds the mapping must be re-run before re-running this one.
domain_labels = np.concatenate([
    np.full(len(handle['features']), domain_labels[name], dtype=np.float64)
    for name, handle in zip(datasets, h5_datasets)
])

In [None]:
# Sanity check: display the shapes of the three arrays side by side.
tuple(arr.shape for arr in (features, labels, domain_labels))

In [None]:
# plot two umaps, one with labels and one with domain labels
import umap
import matplotlib.pyplot as plt

# UMAP is stochastic; fixing the seed makes the 2-D layout reproducible across
# kernel restarts.
# NOTE(review): umap-learn documents that a fixed `random_state` limits
# parallelism, so this may run slower than the unseeded call - confirm that is
# acceptable for this data size.
UMAP_SEED = 42

reducer = umap.UMAP(random_state=UMAP_SEED)
embedding = reducer.fit_transform(features)

In [None]:
# randomly permute embeddings, labels, and domain_labels
# The same permutation is applied to all three arrays so they stay aligned
# row-for-row. A seeded generator makes the shuffle repeatable between runs.
# Note: `features` is NOT permuted here, so after this cell its rows no longer
# line up with `embedding` / `labels` / `domain_labels`.
rng = np.random.default_rng(0)
perm = rng.permutation(len(embedding))
embedding = embedding[perm]
labels = labels[perm]
domain_labels = domain_labels[perm]

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# Label value that marks an unlabelled sample, and the grey it is drawn in.
UNLABELED = -1
UNLABELED_COLOR = (0.5, 0.5, 0.5, 1)


def _categorical_colors(values, unlabeled=None):
    """Return one RGBA colour per entry of `values`, one colour per distinct value.

    Parameters
    ----------
    values : array-like
        Category value of every point; flattened before use.
    unlabeled : scalar, optional
        If given, points whose value equals it are drawn in `UNLABELED_COLOR`
        instead of a palette colour.

    Returns
    -------
    numpy.ndarray
        Array of shape (number of points, 4) holding RGBA rows.
    """
    flat = np.ravel(values)
    # `codes[i]` is the index of flat[i] within the sorted distinct values, so
    # the per-point lookup is a single fancy-indexing step instead of one
    # `np.where` call per point.
    categories, codes = np.unique(flat, return_inverse=True)
    n_categories = len(categories)
    if n_categories == 0:
        return np.empty((0, 4))

    # Use a palette that has enough distinct entries for the category count.
    if n_categories <= 10:
        cmap_name = 'tab10'
    elif n_categories <= 20:
        cmap_name = 'tab20'
    else:
        cmap_name = 'gist_ncar'
    # `plt.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` accepts
    # the same (name, lut) arguments.
    cmap = plt.get_cmap(cmap_name, n_categories)
    palette = np.array([cmap(i) for i in range(n_categories)])

    if unlabeled is not None:
        is_unlabeled = np.array([cat == unlabeled for cat in categories], dtype=bool)
        palette[is_unlabeled] = UNLABELED_COLOR

    return palette[np.ravel(codes)]


# Class labels: -1 is grey, every other value gets a categorical colour.
colors = _categorical_colors(labels, unlabeled=UNLABELED)
# Domain codes are categorical as well. A cyclic colormap such as 'twilight'
# would draw the lowest and highest code in almost the same colour.
domain_colors = _categorical_colors(domain_labels)

fig, ax = plt.subplots(1, 2, figsize=(20, 10))

# Same 2-D embedding in both panels; only the colouring differs.
panels = (
    (ax[0], colors, 'Class labels'),
    (ax[1], domain_colors, 'Domain labels'),
)
for axis, point_colors, title in panels:
    axis.scatter(embedding[:, 0], embedding[:, 1], c=point_colors, s=1, alpha=0.7)
    axis.set_aspect('equal', 'datalim')
    axis.set_title(title, fontsize=24)

plt.show()
