In [None]:
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import itertools
import os
import pickle
import time
from collections import defaultdict
from typing import Any, Sequence

from colabtools import adhoc_import
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import ml_collections
import numpy as np
from scipy.special import comb
from scipy.stats import mode
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.decomposition import PCA
from sklearn.metrics import rand_score, adjusted_rand_score, silhouette_score, adjusted_mutual_info_score
from sklearn.mixture import GaussianMixture
import tensorflow as tf
import tensorflow_datasets as tfds
from clu import preprocess_spec



# Compute linear fit of purity vs layers

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler

# For each run, regress class purity on the (1-based) index of the layer it
# was measured at and print the fitted slope.
for model_dir in [os.path.join(BASE_DIR, 'breeds/entity13_4_subclasses_400_epochs_ema_0.99_bn_0.99/'),
    os.path.join(BASE_DIR, 'breeds/living17_400_epochs_ema_0.99_bn_0.99/'),
    os.path.join(BASE_DIR, 'breeds/entity13_8_classes_400_epochs_ema_0.99_bn_0.99/'),
    os.path.join(BASE_DIR, 'breeds/living17_8_classes_400_epochs_ema_0.99_bn_0.99/'),
    os.path.join(BASE_DIR, 'breeds/breeds_training_level_3_nsubclass_20_400_epochs_ema_0.99/'),
    os.path.join(BASE_DIR, 'breeds/nonliving26_400_epochs_ema_0.99_bn_0.99/'),
    ]:
  # One pickle per (stage, block); layers without a file are simply skipped,
  # so the layer index below counts only the layers that were found.
  all_purity = []
  for stage, block in itertools.product([1,2,3,4], [1,2,3,4,5,6]):
    purity_data_path = os.path.join(model_dir, f'class_purity_ckpt_stage{stage}_block{block}.pkl')
    if gfile.Exists(purity_data_path):
      with gfile.Open(purity_data_path, 'rb') as f:
        all_purity.append(pickle.load(f))
  all_purity = np.vstack(all_purity)

  # Row r of all_purity pairs with x = r + 1; flatten row-major so every
  # purity value becomes one regression sample.
  n_layers, n_per_layer = all_purity.shape
  Xs = np.repeat(np.arange(1, n_layers + 1), n_per_layer).reshape(-1, 1)
  Ys = all_purity.reshape(-1, 1)

  reg = LinearRegression().fit(Xs, Ys.squeeze())
  print(model_dir)
  print(reg.coef_)

# Plot adjusted mutual info over time

In [None]:
# For each run, draw one panel per saved checkpoint showing adjusted mutual
# information (AMI) against the overclustering factor: one coloured line per
# entry of data[factor], plus the mean across entries in black.
for model_dir in [os.path.join(BASE_DIR, 'breeds/living17_400_epochs_ema_0.99_bn_0.99/'),
                  os.path.join(BASE_DIR, 'breeds/nonliving26_400_epochs_ema_0.99_bn_0.99/'),
                  os.path.join(BASE_DIR, 'breeds/breeds_training_level_3_nsubclass_20_400_epochs_ema_0.99/'),
                  os.path.join(BASE_DIR, 'breeds/entity13_4_subclasses_400_epochs_ema_0.99_bn_0.99/'),
                  os.path.join(BASE_DIR, 'breeds/entity13_4_subclasses_shuffle_400_epochs_ema_0.99_bn_0.99/')]:
  # Keep the per-checkpoint files and drop the per-layer
  # ("..._ckpt_stage*") ones; the checkpoint number is the last
  # underscore-separated token of the file stem.
  all_ckpt = [f for f in list(gfile.ListDir(model_dir)) if f.startswith('adjusted_mutual_info_') and not f.startswith('adjusted_mutual_info_ckpt_stage')]
  all_ckpt = sorted([int(f.split('.')[0].split('_')[-1]) for f in all_ckpt])
  #all_ckpt = [c for c in all_ckpt if c % 10 == 1 and c > 1]
  print(all_ckpt)
  print(model_dir)
  if not all_ckpt:
    # Nothing was saved for this run; an empty grid cannot be created.
    continue

  n_cols = 6
  n_rows = int(np.ceil(len(all_ckpt)/n_cols))
  # squeeze=False keeps `axs` two-dimensional even when there is a single row,
  # so axs[row, col] below is always valid.
  fig, axs = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 5*n_rows), squeeze=False)
  overcluster_factors = [1,2,3,4,5]
  for idx, ckpt_number in enumerate(all_ckpt):
    row = idx // n_cols
    col = idx % n_cols
    ax = axs[row, col]
    ami_file = os.path.join(model_dir, f'adjusted_mutual_info_ckpt_{ckpt_number}.pkl')
    with gfile.Open(ami_file, 'rb') as f:
      data = pickle.load(f)
    # `data` is indexed as data[overcluster_factor][subclass_idx] for factors 1..5.
    n_subclasses = len(data[1])
    all_subclass_ami = []
    color = iter(cm.rainbow(np.linspace(0, 1, n_subclasses)))
    for subclass_idx in range(n_subclasses):
      subclass_ami = [data[overcluster_factor][subclass_idx] for overcluster_factor in overcluster_factors]
      all_subclass_ami.append(subclass_ami)
      ax.plot(overcluster_factors, subclass_ami, marker='o', color=next(color))
    all_subclass_ami = np.vstack(all_subclass_ami)
    ax.plot(overcluster_factors, np.mean(all_subclass_ami, axis=0), marker='o', color='k')
    ax.set_title(f"Ckpt {ckpt_number}")#model_dir.split('/')[-2])
    ax.set_xticks(overcluster_factors)
    ax.set_xlabel('Overclustering factor')
    ax.set_ylabel('Adjusted mutual information')
    #ax.set_ylim([0.0, 1.0])
  plt.show()
  # Release the grid figure; previously a fresh empty figure was created per
  # checkpoint and never closed.
  plt.close(fig)

# Plot adjusted mutual info over layers

In [None]:
# For each run, draw a 4x6 grid (rows = stages 1-4, columns = blocks 1-6) of
# adjusted mutual information (AMI) against the overclustering factor. Panels
# whose pickle does not exist are left empty.
for model_dir in [os.path.join(BASE_DIR, 'breeds/living17_400_epochs_ema_0.99_bn_0.99/'),
                  os.path.join(BASE_DIR, 'breeds/nonliving26_400_epochs_ema_0.99_bn_0.99/'),
                  os.path.join(BASE_DIR, 'breeds/breeds_training_level_3_nsubclass_20_400_epochs_ema_0.99/'),
                  os.path.join(BASE_DIR, 'breeds/entity13_4_subclasses_400_epochs_ema_0.99_bn_0.99/'),
                  os.path.join(BASE_DIR, 'breeds/entity13_4_subclasses_shuffle_400_epochs_ema_0.99_bn_0.99/')]:
  print(model_dir)
  fig, axs = plt.subplots(4, 6, figsize=(5*6, 5*4))
  overcluster_factors = [1,2,3,4,5]
  for stage in [1,2,3,4]:
    for block in [1,2,3,4,5,6]:
      ami_file = os.path.join(model_dir, f'adjusted_mutual_info_ckpt_stage{stage}_block{block}.pkl')
      if not gfile.Exists(ami_file):
        continue
      with gfile.Open(ami_file, 'rb') as f:
        data = pickle.load(f)
      ax = axs[stage-1, block-1]
      # `data` is indexed as data[overcluster_factor][subclass_idx] for factors 1..5.
      n_subclasses = len(data[1])
      all_subclass_ami = []
      color = iter(cm.rainbow(np.linspace(0, 1, n_subclasses)))
      for subclass_idx in range(n_subclasses):
        subclass_ami = [data[overcluster_factor][subclass_idx] for overcluster_factor in overcluster_factors]
        all_subclass_ami.append(subclass_ami)
        ax.plot(overcluster_factors, subclass_ami, marker='o', c=next(color))
      all_subclass_ami = np.vstack(all_subclass_ami)
      # Black, thicker line: mean over the individual curves.
      ax.plot(overcluster_factors, np.mean(all_subclass_ami, axis=0), marker='o', color='k', linewidth=2)
      ax.set_xticks(overcluster_factors)
      ax.set_xlabel('Overclustering factor')
      ax.set_ylabel('Adjusted mutual information')
      ax.set_title(f"stage{stage}_block{block}")
      #ax.set_ylim([0.0, 1.0])
  plt.show()
  # Release the grid figure; previously a fresh empty figure was created per
  # panel and never closed.
  plt.close(fig)

# Plot purity vs training steps, and adjusted Rand index of clusters from successive checkpoints vs steps

In [None]:
# Build the "living17" BREEDS split. `make_breeds_dataset` is defined outside
# this section of the notebook; its result is unpacked below as a 3-tuple.
ret = make_breeds_dataset("living17", BREEDS_INFO_DIR, BREEDS_INFO_DIR,
                    split=None,
                    #num_subclasses=4, shuffle_subclasses=True
                    )
superclasses, subclass_split, label_map = ret
# The first element of the split is used as the training subclasses
# (presumably one collection of subclass ids per superclass -- TODO confirm).
train_subclasses = subclass_split[0]

In [None]:
# Load the cluster purity saved for every other checkpoint (3, 5, ..., 41),
# skipping checkpoints whose file is missing, and plot one curve per column
# of the stacked purity matrix.
all_purity = []
all_ckpt_numbers = list(range(3, 42, 2))
Xs = []  # checkpoint numbers that were actually found
for ckpt_number in all_ckpt_numbers:
  purity_data_path = os.path.join(BASE_DIR, f'breeds/1/class_purity_ckpt_{ckpt_number}.pkl')
  if not gfile.Exists(purity_data_path):
    continue
  Xs.append(ckpt_number)
  with gfile.Open(purity_data_path, 'rb') as f:
    all_purity.append(pickle.load(f))
all_purity = np.vstack(all_purity)
print(all_purity.shape)

# x-axis: checkpoint number converted with the notebook's 100 * (ckpt - 1) rule.
steps = [100*(ckpt-1) for ckpt in Xs]
line_colors = cm.rainbow(np.linspace(0, 1, all_purity.shape[1]))
fig = plt.figure(figsize=(15,12))
for i, line_color in enumerate(line_colors):
  plt.plot(steps, all_purity[:, i], label=label_map[i], color=line_color, marker='o')

plt.ylim((0, 1.0))
plt.xlabel("Step")
plt.ylabel("Cluster purity")
plt.legend()
plt.show()

In [None]:
# Adjusted Rand index between the cluster assignments of consecutive saved
# checkpoints (ckpt vs ckpt - 2), one curve per entry of train_subclasses.
all_rand_index = [[] for _ in range(len(train_subclasses))]
Xs = list(range(5, 42, 2))
labels_dir = os.path.join(BASE_DIR, 'breeds/living17_400_epochs_ema_0.99_bn_0.99_squared_loss')
for ckpt_number in Xs:
  with gfile.Open(os.path.join(labels_dir, f'clf_labels_ckpt_{ckpt_number}.pkl'), 'rb') as f:
    all_clf_labels = pickle.load(f)
  with gfile.Open(os.path.join(labels_dir, f'clf_labels_ckpt_{ckpt_number-2}.pkl'), 'rb') as f:
    all_clf_labels_prev = pickle.load(f)
  for i in range(len(train_subclasses)):
    all_rand_index[i].append(adjusted_rand_score(all_clf_labels[i], all_clf_labels_prev[i]))

# x-axis: checkpoint number converted with the notebook's 100 * (ckpt - 1) rule.
steps = [100*(ckpt-1) for ckpt in Xs]
# One colour per plotted curve; sized from train_subclasses so this cell does
# not depend on `all_purity` left behind by another cell.
line_colors = cm.rainbow(np.linspace(0, 1, len(train_subclasses)))
fig = plt.figure(figsize=(15,12))
for i, line_color in enumerate(line_colors):
  plt.plot(steps, all_rand_index[i], label=label_map[i], color=line_color, marker='o')

plt.xlabel("Step", fontsize=16)
plt.ylabel("Adjusted Rand Index", fontsize=16)
plt.legend()
plt.show()

# Plot purity vs layers, and adjusted Rand index of clusters from successive layers

In [None]:
# Load the cluster purity saved for each (stage, block) layer, skipping layers
# whose file is missing, and plot one curve per column of the stacked matrix.
all_purity = []
X_ticks = []  # "stage{s}_block{b}" name of every layer that was found
for stage in [1,2,3,4]:
  for block in [1,2,3,4,5,6]:
    purity_data_path = os.path.join(BASE_DIR, f'breeds/1/class_purity_ckpt_stage{stage}_block{block}.pkl')
    if not gfile.Exists(purity_data_path):
      continue
    X_ticks.append(f"stage{stage}_block{block}")
    with gfile.Open(purity_data_path, 'rb') as f:
      all_purity.append(pickle.load(f))
all_purity = np.vstack(all_purity)
print(all_purity.shape)

line_colors = cm.rainbow(np.linspace(0, 1, all_purity.shape[1]))
fig = plt.figure(figsize=(15,12))
for i, line_color in enumerate(line_colors):
  plt.plot(range(len(X_ticks)), all_purity[:, i], label=label_map[i], color=line_color, marker='o')

plt.xticks(ticks=range(len(X_ticks)), labels=X_ticks, rotation=90)
plt.ylim((0, 1.0))
plt.xlabel("Layer")
plt.ylabel("Cluster purity")
plt.legend()
plt.show()

In [None]:
# Adjusted Rand index between the cluster assignments of consecutive layers
# (in the order collected in X_ticks), one curve per entry of train_subclasses.
all_rand_index = [[] for _ in range(len(train_subclasses))]
for tick_idx in range(1, len(X_ticks)):
  current_block = X_ticks[tick_idx]
  prev_block = X_ticks[tick_idx-1]
  with gfile.Open(os.path.join(BASE_DIR, f'breeds/1/clf_labels_ckpt_{current_block}.pkl'), 'rb') as f:
    all_clf_labels = pickle.load(f)
  with gfile.Open(os.path.join(BASE_DIR, f'breeds/1/clf_labels_ckpt_{prev_block}.pkl'), 'rb') as f:
    all_clf_labels_prev = pickle.load(f)
  # A separate index name keeps the layer loop variable from being shadowed.
  for i in range(len(train_subclasses)):
    all_rand_index[i].append(adjusted_rand_score(all_clf_labels[i], all_clf_labels_prev[i]))

# One colour per plotted curve; sized from train_subclasses so this cell does
# not depend on `all_purity` left behind by another cell.
line_colors = cm.rainbow(np.linspace(0, 1, len(train_subclasses)))
fig = plt.figure(figsize=(15,12))
for i, line_color in enumerate(line_colors):
  plt.plot(range(1, len(X_ticks)), all_rand_index[i], label=label_map[i], color=line_color, marker='o')

# Each point compares a layer with the one before it, so the first layer has
# no point and ticks start at index 1.
plt.xticks(ticks=range(1, len(X_ticks)), labels=X_ticks[1:], rotation=90)
plt.xlabel("Layer", fontsize=16)
plt.ylabel("Adjusted Rand Index", fontsize=16)
plt.legend()
plt.show()

# Load eval data

In [None]:
# Rebuild the "living17" BREEDS split for the evaluation cells below.
# `make_breeds_dataset` is defined outside this section of the notebook.
ret = make_breeds_dataset("living17", BREEDS_INFO_DIR, BREEDS_INFO_DIR,
                          split=None)
superclasses, subclass_split, label_map = ret
# The first element of the split is used as the training subclasses.
train_subclasses = subclass_split[0]

In [None]:
def load_dataset(dataset_name, train_subclasses=None):
  """Builds the batched validation pipeline for ImageNet or CelebA.

  Args:
    dataset_name: 'imagenet' selects "imagenet2012"; any other value selects
      "celeb_a".
    train_subclasses: iterable of per-superclass subclass collections. Only
      read for 'imagenet', where it defines which examples are kept (through
      `predicate`, defined elsewhere) and the number of classes.

  Returns:
    (eval_ds, num_classes): the validation split filtered (ImageNet only),
    cached, preprocessed, batched by 64 and prefetched; and
    len(train_subclasses) for ImageNet or 2 for CelebA.
  """
  use_imagenet = dataset_name == 'imagenet'

  # Both datasets share the rescale / resize / crop steps; CelebA additionally
  # gets LabelMapping appended.
  preprocess_ops = [
      RescaleValues(),
      ResizeSmall(256),
      CentralCrop(224),
  ]
  if use_imagenet:
    all_subclasses = list(itertools.chain(*train_subclasses))
    # Position of every subclass id in the flattened list.
    new_label_map = {sub: subclass_idx for subclass_idx, sub in enumerate(all_subclasses)}
    print(new_label_map)
    table_init = tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(list(new_label_map.keys()), dtype=tf.int64),
        values=tf.constant(list(new_label_map.values()), dtype=tf.int64),
    )
    # Built as in the original cell; nothing below reads it.
    lookup_table = tf.lookup.StaticHashTable(
        initializer=table_init,
        default_value=tf.constant(-1, dtype=tf.int64))
    dataset_builder = tfds.builder("imagenet2012", try_gcs=True)
    num_classes = len(train_subclasses)
  else:
    dataset_builder = tfds.builder("celeb_a", try_gcs=True)
    preprocess_ops.append(LabelMapping())
    num_classes = 2
  eval_preprocess = preprocess_spec.PreprocessFn(preprocess_ops, only_jax_types=True)

  # tf.data tuning knobs for the input pipeline.
  dataset_options = tf.data.Options()
  dataset_options.experimental_optimization.map_parallelization = True
  dataset_options.experimental_threading.private_threadpool_size = 48
  dataset_options.experimental_threading.max_intra_op_parallelism = 1
  read_config = tfds.ReadConfig(shuffle_seed=None, options=dataset_options)

  eval_ds = dataset_builder.as_dataset(
      split=tfds.Split.VALIDATION,
      shuffle_files=False,
      read_config=read_config,
      decoders=None)

  batch_size = 64
  if use_imagenet:
    # Keep only the examples accepted by `predicate` for these subclasses.
    keep_example = functools.partial(predicate, all_subclasses=all_subclasses)
    eval_ds = eval_ds.filter(keep_example)
  eval_ds = (
      eval_ds.cache()
      .map(eval_preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .batch(batch_size, drop_remainder=False)
      .prefetch(tf.data.experimental.AUTOTUNE))
  return eval_ds, num_classes

In [None]:
# Which evaluation set to load: 'imagenet' (BREEDS subclasses) or 'celeb_a'.
DATASET = 'imagenet'#'celeb_a'
eval_ds, num_classes = load_dataset(DATASET, train_subclasses=train_subclasses)
if DATASET == 'celeb_a':
  # CelebA has two classes, each treated as its own single-label group.
  train_subclasses = [[0], [1]]

# Clustering algorithms

In [None]:
def get_learning_rate(step: int,
                      *,
                      base_learning_rate: float,
                      steps_per_epoch: int,
                      num_epochs: int,
                      warmup_epochs: int = 5):
  """Cosine learning rate schedule with a linear warmup.

  Args:
    step: current training step.
    base_learning_rate: learning rate handed to `cosine_decay`.
    steps_per_epoch: number of steps per epoch; must be positive.
    num_epochs: total number of epochs.
    warmup_epochs: number of epochs over which the rate is scaled linearly
      from 0 to 1. Must be smaller than `num_epochs`. A value <= 0 disables
      warmup.

  Returns:
    The decayed learning rate from `cosine_decay`, multiplied by the warmup
    factor min(1, epoch / warmup_epochs) when warmup is enabled.

  Raises:
    ValueError: if `steps_per_epoch` is not positive or `warmup_epochs` is not
      smaller than `num_epochs`.
  """
  logging.info(
      "get_learning_rate(step=%s, base_learning_rate=%s, steps_per_epoch=%s, num_epochs=%s)",
      step, base_learning_rate, steps_per_epoch, num_epochs)
  if steps_per_epoch <= 0:
    raise ValueError(f"steps_per_epoch should be a positive integer but was "
                     f"{steps_per_epoch}.")
  if warmup_epochs >= num_epochs:
    raise ValueError(f"warmup_epochs should be smaller than num_epochs. "
                     f"Currently warmup_epochs is {warmup_epochs}, "
                     f"and num_epochs is {num_epochs}.")
  epoch = step / steps_per_epoch
  lr = cosine_decay(base_learning_rate, epoch - warmup_epochs,
                    num_epochs - warmup_epochs)
  if warmup_epochs <= 0:
    # No warmup requested: dividing by warmup_epochs below would divide by
    # zero, so the decayed rate is returned unscaled.
    return lr
  warmup = jnp.minimum(1., epoch / warmup_epochs)
  return lr * warmup

# `get_config` is defined outside this section; only `num_epochs` and
# `warmup_epochs` are read from the returned config here.
config = get_config()
# Bind the fixed hyperparameters so the result is a function of `step` only.
learning_rate_fn = functools.partial(
      get_learning_rate,
      base_learning_rate=0.1,
      steps_per_epoch=40,
      num_epochs=config.num_epochs,
      warmup_epochs=config.warmup_epochs)

In [None]:
# Restore checkpoint 41 of the tuned living17 run, then walk the eval set
# once, collecting labels, images and the stage-4 activations averaged over
# axes (1, 2) for every batch.
model_dir = os.path.join(BASE_DIR, 'breeds/living17_400_epochs_ema_0.99_bn_0.99_squared_loss_hyperparam_tuned/')
checkpoint_path = os.path.join(model_dir, 'checkpoints-0/ckpt-41.flax')
model, state = create_train_state(
    config,
    jax.random.PRNGKey(0),
    input_shape=(8, 224, 224, 3),
    num_classes=num_classes,
    learning_rate_fn=learning_rate_fn)
state = checkpoints.restore_checkpoint(checkpoint_path, state)
print("step:", state.step)

all_intermediates, all_subclass_labels, all_images = [], [], []
for step, batch in enumerate(eval_ds):
  if not step % 20:
    # Lightweight progress indicator.
    print(step)
  intermediates = predict(model, state, batch)
  labels = batch['label'].numpy()
  bs = labels.shape[0]
  # Average the first stage-4 output over axes 1 and 2 (presumably the
  # spatial axes -- TODO confirm) and flatten to one row per example.
  stage4_pooled = np.mean(intermediates['stage4']['__call__'][0], axis=(1,2))
  all_intermediates.append(stage4_pooled.reshape(bs, -1))
  all_subclass_labels.append(labels)
  all_images.append(batch['image'].numpy())

# Concatenate the per-batch pieces into single arrays.
all_images = np.vstack(all_images)
all_subclass_labels = np.hstack(all_subclass_labels)
all_intermediates = np.vstack(all_intermediates)

In [None]:
# When True, the clustering cells below cluster each superclass separately.
cluster_each_superclass = True
#overcluster_factor, n_subclasses = 5, 20

# The collectors may still be per-batch lists (if the extraction cell was
# interrupted or edited); stack them only in that case so the cell can be
# re-run safely.
if isinstance(all_intermediates, list):
  all_intermediates = np.vstack(all_intermediates)
if isinstance(all_subclass_labels, list):
  all_subclass_labels = np.hstack(all_subclass_labels)
if isinstance(all_images, list):
  all_images = np.vstack(all_images)

# `all_filenames` is not populated by the extraction cell visible in this
# notebook; fall back to an empty list instead of failing with a NameError.
try:
  all_filenames
except NameError:
  all_filenames = []
if len(all_filenames) > 0:
  # Decode only raw bytes so re-running the cell on already-decoded names
  # does not fail.
  all_filenames = [f.decode("utf-8") if isinstance(f, bytes) else str(f)
                   for f in np.hstack(all_filenames)]
print(all_intermediates.shape)
print(all_subclass_labels.shape)

### Kmeans (with PCA 50 dim)

In [None]:
# Centre the features and project them onto the first 50 principal components.
all_intermediates_normalized = all_intermediates - np.mean(all_intermediates, axis=0)
pca = PCA(n_components=50)
all_intermediates_pca = pca.fit_transform(all_intermediates_normalized)

# For each overclustering factor, run KMeans with
# int(len(subclasses) * factor) clusters inside every superclass and report
# the mean silhouette score.
# NOTE(review): small factors can yield fewer than 2 clusters, which KMeans /
# silhouette_score reject -- confirm the factor list against the dataset used.
for overcluster_factor in [0.1, 0.25, 0.5, 1, 3, 5, 10]:
  all_clfs = []
  all_sil_scores = []
  if cluster_each_superclass:
    for subclasses in train_subclasses:
      # Indices of the examples whose label belongs to this superclass.
      subclass_idx = np.flatnonzero(np.isin(all_subclass_labels, list(subclasses)))
      kmeans = KMeans(n_clusters=int(len(subclasses)*overcluster_factor), random_state=1).fit(all_intermediates_pca[subclass_idx])
      all_sil_scores.append(silhouette_score(all_intermediates_pca[subclass_idx], kmeans.labels_))
      all_clfs.append(kmeans)
    print(all_sil_scores)
    print(f"overcluster factor = {overcluster_factor}, mean sil score = {np.mean(all_sil_scores)}")
  else:
    kmeans = KMeans(n_clusters=int(len(all_subclasses)*overcluster_factor), random_state=1).fit(all_intermediates_pca)

### Minibatch Kmeans

In [None]:
# Fit one MiniBatchKMeans over all (un-projected) features and time it.
# NOTE(review): `overcluster_factor` is whatever value an earlier cell's loop
# left behind, and `all_subclasses` is not defined at module level in the
# visible notebook -- confirm both are set before running this cell.
start = time.time()
kmeans_minibatch = MiniBatchKMeans(n_clusters=len(all_subclasses)*overcluster_factor, random_state=0, batch_size=1024, max_iter=300, max_no_improvement=100).fit(all_intermediates)
end = time.time()
print("Time taken:", end - start)
# Distinct labels actually assigned, which can be fewer than n_clusters.
print("number of clusters found:", len(set(kmeans_minibatch.labels_)))

### Hierarchical clustering

In [None]:
# Ward-linkage agglomerative clustering for a range of overclustering factors.
# In per-superclass mode one model is fit per superclass and the mean
# silhouette score is printed; otherwise a single model is fit on all features.
for overcluster_factor in [0.1, 0.25, 0.5, 1, 3, 5, 10]:
  all_clfs, all_sil_scores = [], []
  if not cluster_each_superclass:
    hier_clustering = AgglomerativeClustering(
        n_clusters=int(len(all_subclasses)*overcluster_factor),
        linkage='ward').fit(all_intermediates)
    continue

  for subclasses in train_subclasses:
    # Positions of the examples whose label belongs to this superclass.
    subclass_idx = np.array([idx for idx, lbl in enumerate(all_subclass_labels) if lbl in subclasses])
    superclass_features = all_intermediates[subclass_idx]
    n_clusters = int(len(subclasses)*overcluster_factor)
    hier_clustering = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward').fit(superclass_features)
    all_sil_scores.append(silhouette_score(superclass_features, hier_clustering.labels_))
    all_clfs.append(hier_clustering)
  print(f"overcluster factor = {overcluster_factor}, mean sil score = {np.mean(all_sil_scores)}")

### GMM

In [None]:
# Gaussian mixture clustering for a range of overclustering factors. In
# per-superclass mode a tied-covariance GMM is fit per superclass and the mean
# BIC is printed.
# NOTE(review): here `all_clfs` holds predicted label arrays, whereas the
# KMeans / hierarchical cells store fitted estimators exposing `.labels_` --
# confirm which form the downstream metric cells expect.
for overcluster_factor in [3, 5, 10, 15, 20, 25]:
  all_clfs = []
  all_bic_scores = []
  if cluster_each_superclass:
    for subclasses in train_subclasses:
      subclass_idx = np.array([i for i in range(len(all_subclass_labels)) if all_subclass_labels[i] in subclasses])
      superclass_features = all_intermediates[subclass_idx]
      gmm = GaussianMixture(n_components=len(subclasses)*overcluster_factor, covariance_type='tied',
                            random_state=1, reg_covar=1e-5)
      gmm.fit(superclass_features)
      all_clfs.append(gmm.predict(superclass_features))
      all_bic_scores.append(gmm.bic(superclass_features))
    print(f"overcluster factor = {overcluster_factor}, mean bic score = {np.mean(all_bic_scores)}")
  else:
    # Global fit: size the mixture from all subclasses, matching the other
    # clustering cells (previously this read the stale per-superclass
    # `subclasses` variable).
    gmm = GaussianMixture(n_components=len(all_subclasses)*overcluster_factor, covariance_type='full', random_state=1)
    gmm.fit(all_intermediates)

# Analyze images and labels in each cluster

In [None]:
def show_images_horizontally(images, labels, super_label=None, max_samples=20):
  """Shows images in a grid, each with its own title.

  Args:
    images: array of images indexed along its first axis.
    labels: either a sequence (list, tuple or ndarray) with one title per
      image, or a single value used as the title of every image.
    super_label: optional title for the whole figure.
    max_samples: if not None, at most this many images are shown in a single
      row; if None, all images are shown in rows of 10.
  """
  if max_samples is not None:
    number_of_files = min(images.shape[0], max_samples)
    n_rows, n_cols = 1, number_of_files
  else:
    number_of_files = images.shape[0]
    n_cols = 10
    n_rows = int(np.ceil(number_of_files/n_cols))
  if number_of_files == 0:
    # Nothing to draw; a zero-sized figure cannot be created.
    return
  fig = plt.figure(figsize=(1.5 * n_cols, 1.5 * n_rows))

  # ndarrays (e.g. labels sliced from a label array) are per-image titles too,
  # not one shared title.
  per_image_labels = isinstance(labels, (list, tuple)) or (
      isinstance(labels, np.ndarray) and labels.ndim > 0)
  for i in range(number_of_files):
    axes = fig.add_subplot(n_rows, n_cols, i+1)
    axes.imshow(images[i])
    axes.set_title(labels[i] if per_image_labels else labels)
  if super_label:
    # Figure-level title; plt.title would overwrite the last image's title.
    fig.suptitle(super_label)
  plt.show()

def visualize_all_clusters(all_images, cluster_labels, all_subclass_labels, hier, gender_labels):
  """Shows up to 10 images from every cluster, one figure per cluster.

  Args:
    all_images: array of images, indexed along its first axis.
    cluster_labels: cluster id of every image.
    all_subclass_labels: array with the subclass label of every image.
    hier: object exposing a `LEAF_NUM_TO_NAME` mapping, or None. When given,
      each image is titled with its subclass name and the figure with the
      name of the cluster's most frequent subclass.
    gender_labels: array of 0/1 values, or None. Used only when `hier` is
      None; the figure is then titled with the percentage of 1s in the
      cluster.
  """
  cluster_labels = np.asarray(cluster_labels)
  # np.unique visits clusters in a deterministic, sorted order.
  for label in np.unique(cluster_labels):
    instance_idx = np.where(cluster_labels == label)[0]
    subclass_labels = all_subclass_labels[instance_idx]
    cluster_images = all_images[instance_idx]
    if hier is not None:
      # atleast_1d handles SciPy versions whose mode() returns scalars as
      # well as those returning length-1 arrays.
      cluster_label = np.atleast_1d(mode(subclass_labels)[0])[0]
      cluster_label_str = hier.LEAF_NUM_TO_NAME[cluster_label].split(',')[0]
      subclass_labels = [hier.LEAF_NUM_TO_NAME[s].split(',')[0] for s in subclass_labels]
      show_images_horizontally(cluster_images, subclass_labels, super_label=cluster_label_str, max_samples=10)
    elif gender_labels is not None:
      # `is not None`: the truth value of a multi-element array is ambiguous
      # and would raise.
      male_pc = sum(gender_labels[instance_idx]) * 100.0 / len(instance_idx)
      cluster_label_str = f'Male ratio = {male_pc}%'
      subclass_labels = ['' for s in subclass_labels]
      show_images_horizontally(cluster_images, subclass_labels, super_label=cluster_label_str, max_samples=10)
    else:
      # Pass a list so every image gets its own label as title.
      show_images_horizontally(cluster_images, list(subclass_labels), super_label='', max_samples=10)


def visualize_cluster(cluster_idx, clusters, all_images, all_subclass_labels, hier):
  """Prints the dominant subclass of one cluster and shows all its images.

  Args:
    cluster_idx: id of the cluster to display.
    clusters: array with the cluster id of every image.
    all_images: array of images, indexed along its first axis.
    all_subclass_labels: array with the subclass label of every image.
    hier: object exposing a `LEAF_NUM_TO_NAME` mapping from label to name.
  """
  instance_idx = np.where(clusters == cluster_idx)[0]
  if len(instance_idx) == 0:
    # No image carries this cluster id; there is nothing to summarize or show.
    print("Cluster", cluster_idx, "is empty")
    return
  cluster_images = all_images[instance_idx]
  subclass_labels = all_subclass_labels[instance_idx]
  mode_stats = mode(subclass_labels)
  # atleast_1d handles SciPy versions whose mode() returns scalars as well as
  # those returning length-1 arrays.
  dominant_label = np.atleast_1d(mode_stats[0])[0]
  dominant_count = np.atleast_1d(mode_stats[1])[0]
  if dominant_count == 1:
    # The most frequent label occurs only once.
    print("No dominant subclass")
  else:
    cluster_label_str = hier.LEAF_NUM_TO_NAME[dominant_label].split(',')[0]
    print("Dominant subclass:", cluster_label_str)
  subclass_labels = [hier.LEAF_NUM_TO_NAME[s].split(',')[0] for s in subclass_labels]
  show_images_horizontally(cluster_images, subclass_labels, max_samples=None)

def visualize_error(cluster_label_str, pred_str, clusters, all_subclasses, all_images, hier, labels=None):
  """Shows clusters dominated by one subclass that also contain another.

  A cluster is shown when its most frequent label is the subclass whose name
  starts with `cluster_label_str`, that label occurs more than once, and the
  cluster contains at least one image of the subclass whose name starts with
  `pred_str`.

  Args:
    cluster_label_str: name prefix of the dominant subclass to look for.
    pred_str: name prefix of the subclass it is confused with.
    clusters: array with the cluster id of every image.
    all_subclasses: iterable of candidate subclass labels.
    all_images: array of images, indexed along its first axis.
    hier: object exposing a `LEAF_NUM_TO_NAME` mapping from label to name.
    labels: optional array with the subclass label of every image. Defaults
      to the module-level `all_subclass_labels`, which is what this function
      always read before.

  Raises:
    ValueError: if no subclass name starts with one of the two prefixes.
  """
  if labels is None:
    labels = all_subclass_labels

  # When several names share a prefix the last match wins.
  cluster_label = None
  pred = None
  for s in all_subclasses:
    name = hier.LEAF_NUM_TO_NAME[s]
    if name.startswith(cluster_label_str):
      cluster_label = s
    if name.startswith(pred_str):
      pred = s
  if cluster_label is None:
    raise ValueError(f"No subclass name starts with {cluster_label_str!r}")
  if pred is None:
    raise ValueError(f"No subclass name starts with {pred_str!r}")

  clusters = np.asarray(clusters)
  for cluster_idx in np.unique(clusters):
    instance_idx = np.where(clusters == cluster_idx)[0]
    subclass_labels = labels[instance_idx]
    mode_stats = mode(subclass_labels)
    # atleast_1d handles SciPy versions whose mode() returns scalars as well
    # as those returning length-1 arrays.
    dominant_label = np.atleast_1d(mode_stats[0])[0]
    dominant_count = np.atleast_1d(mode_stats[1])[0]
    if dominant_count == 1:
      continue
    if dominant_label != cluster_label:
      continue
    if pred not in subclass_labels:
      continue
    cluster_images = all_images[instance_idx]
    subclass_labels = [hier.LEAF_NUM_TO_NAME[s].split(',')[0] for s in subclass_labels]
    show_images_horizontally(cluster_images, subclass_labels, max_samples=None)

### Save image filenames from each cluster

In [None]:
# Group image filenames by the cluster id assigned by `clf` (whichever
# estimator an earlier cell left bound to that name), record the checkpoint
# the features came from, and pickle the mapping.
cluster_filenames_dict = {}
for cluster_idx in set(clf.labels_):
  instance_idx = np.flatnonzero(clf.labels_ == cluster_idx)
  filenames = [all_filenames[member] for member in instance_idx]
  cluster_filenames_dict[cluster_idx] = filenames
cluster_filenames_dict['ckpt'] = checkpoint_path

output_path = os.path.join(BASE_DIR, 'breeds/cluster_filenames')
with gfile.Open(output_path, 'wb') as f:
  pickle.dump(cluster_filenames_dict, f)

### Compute some clustering metrics

In [None]:
def compute_agg_metrics(clusters, classes, hier, all_subclasses):
  """Computes pair-counting and purity metrics for one clustering.

  Args:
    clusters: array of non-negative integer cluster ids, one per example.
    classes: array of non-negative integer class labels, one per example.
    hier: object exposing `LEAF_NUM_TO_NAME`; only read when some class does
      not dominate any cluster and `cluster_each_superclass` is true.
    all_subclasses: sequence of the class labels that may occur in `classes`;
      its order defines the rows/columns of the confusion matrix.

  Returns:
    (purity, rand_index_score, f_score, f_beta, confusion_matrix) where
    purity is the summed count of each cluster's most frequent class divided
    by the number of examples, the Rand index / F1 / F_beta (beta = 0.5) are
    computed over pairs of examples, and confusion_matrix[j][i] counts
    examples of class all_subclasses[i] that fall in a cluster dominated by
    all_subclasses[j].

  Note: reads the module-level flag `cluster_each_superclass` to decide
  whether to print subclass coverage.
  """
  clusters = np.asarray(clusters)
  classes = np.asarray(classes)
  all_subclasses = list(all_subclasses)
  cluster_ids = np.unique(clusters)

  # Pair counting: a "positive" is a pair of examples placed in the same
  # cluster; it is "true" when both also share the same class.
  tp_plus_fp = comb(np.bincount(clusters), 2).sum()
  tp_plus_fn = comb(np.bincount(classes), 2).sum()
  tp = sum(comb(np.bincount(classes[clusters == c]), 2).sum() for c in cluster_ids)
  fp = tp_plus_fp - tp
  fn = tp_plus_fn - tp
  tn = comb(len(clusters), 2) - tp - fp - fn
  print("TP=", tp)
  print("TN=", tn)
  print("FP=", fp)
  print("FN=", fn)
  rand_index_score = (tp + tn) / (tp + fp + fn + tn)
  f_score = tp / (tp + 0.5 * (fp + fn))
  beta = 0.5 # downweighs fn
  f_beta = (1 + beta**2) * tp / ((1 + beta**2) * tp + beta**2 * fn + fp)

  all_cluster_sizes = []
  all_cluster_label_count = defaultdict(int)
  clusters_with_no_mode = []
  n_cluster_points = 0
  confusion_matrix = np.zeros((len(all_subclasses), len(all_subclasses)))
  for cluster_idx in cluster_ids:
    instance_idx = np.where(clusters == cluster_idx)[0]
    all_cluster_sizes.append(len(instance_idx))
    subclass_labels = classes[instance_idx]
    mode_stats = mode(subclass_labels)
    # atleast_1d handles SciPy versions whose mode() returns scalars as well
    # as those returning length-1 arrays.
    dominant_label = np.atleast_1d(mode_stats[0])[0]
    dominant_count = int(np.atleast_1d(mode_stats[1])[0])
    n_cluster_points += dominant_count
    assert(dominant_label in all_subclasses)
    if dominant_count == 1: # no single class dominating
      if len(instance_idx) == 1:
        all_cluster_label_count[dominant_label] += 1
      else:
        clusters_with_no_mode.append(cluster_idx)
        continue
    else:
      all_cluster_label_count[dominant_label] += 1
    for l in subclass_labels:
      if l != dominant_label:
        i = all_subclasses.index(l)
        j = all_subclasses.index(dominant_label)
        confusion_matrix[j][i] += 1 #row=true label, col = wrong pred

  if cluster_each_superclass:
    percent_subclass_covered = len(all_cluster_label_count.keys()) / len(all_subclasses)
    print("Proportion of subclasses covered:", percent_subclass_covered)
    if percent_subclass_covered < 1:
      subclasses_uncovered = set(all_subclasses) - set(all_cluster_label_count.keys())
      print("Classes that don't dominate any cluster", [hier.LEAF_NUM_TO_NAME[s] for s in list(subclasses_uncovered)])
  print(f"{len(clusters_with_no_mode)} clusters with no mode:")
  print(clusters_with_no_mode)
  purity = n_cluster_points / len(clusters)

  return purity, rand_index_score, f_score, f_beta, confusion_matrix


# Score every fitted clustering in `all_clfs` (one per superclass when
# cluster_each_superclass is set), print its metrics, and finish with the
# unweighted average over all clusterings.
all_confusion_matrix = []
purity_avg = rand_index_score_avg = f_score_avg = f_beta_avg = 0
for i, clf in enumerate(all_clfs):
  if not cluster_each_superclass:
    subclasses = all_subclasses
    subclass_labels = all_subclass_labels
  else:
    print("----------------------------SUPERCLASS:", label_map[i])
    subclasses = train_subclasses[i]
    # Examples whose label belongs to superclass i.
    subclass_idx = np.array([j for j, lbl in enumerate(all_subclass_labels) if lbl in subclasses])
    subclass_labels = all_subclass_labels[subclass_idx]

  metrics = compute_agg_metrics(clf.labels_, subclass_labels, hier, subclasses)
  purity, rand_index_score, f_score, f_beta, confusion_matrix = metrics
  all_confusion_matrix.append(confusion_matrix)

  print("Purity:", purity)
  print("Rand index score:", rand_index_score)
  print("F1 score:", f_score)
  print("F_beta score:", f_beta)

  # Running sums; turned into means after the loop.
  purity_avg += purity
  rand_index_score_avg += rand_index_score
  f_score_avg += f_score
  f_beta_avg += f_beta

n_clfs = len(all_clfs)
purity_avg /= n_clfs
rand_index_score_avg /= n_clfs
f_score_avg /= n_clfs
f_beta_avg /= n_clfs
print("----------------------------AVERAGE")
print("Purity:", purity_avg)
print("Rand index score:", rand_index_score_avg)
print("F1 score:", f_score_avg)
print("F_beta score:", f_beta_avg)

### Visualize subclasses that tend to be confused together

In [None]:
# Show clusters dominated by the subclass whose name starts with 'cucumber'
# that also contain images of the subclass starting with 'bell pepper'.
# `clf` is whichever estimator an earlier cell left bound to that name.
visualize_error('cucumber', 'bell pepper', clf.labels_, all_subclasses, all_images, hier)

In [None]:
# Show clusters dominated by the subclass whose name starts with 'orange'
# that also contain images of the subclass starting with 'lemon'.
visualize_error('orange', 'lemon', clf.labels_, all_subclasses, all_images, hier)

In [None]:
# Show clusters dominated by the subclass whose name starts with 'Samoyed'
# that also contain images of the subclass starting with 'Persian cat'.
visualize_error('Samoyed', 'Persian cat', clf.labels_, all_subclasses, all_images, hier)

### Visualize clusters with no dominant subclass


In [None]:
# Display every image assigned to cluster id 24 of `clf`.
# NOTE(review): 24 is a hand-picked id -- confirm it still refers to a
# no-dominant-subclass cluster for the current `clf`.
visualize_cluster(24, clf.labels_, all_images, all_subclass_labels, hier)

### Visualize all clusters

In [None]:
# Rebuild the living17 split, its class hierarchy and the ImageNet eval
# pipeline, then collect every label and image (no model forward pass here).
ret = make_breeds_dataset("living17", BREEDS_INFO_DIR, BREEDS_INFO_DIR, split=None)
superclasses, subclass_split, label_map = ret
train_subclasses = subclass_split[0]
hier = ClassHierarchy()
eval_ds, num_classes = load_dataset('imagenet', train_subclasses=train_subclasses)

all_subclass_labels, all_images = [], []
for step, batch in enumerate(eval_ds):
  if not step % 20:
    # Lightweight progress indicator.
    print(step)
  labels = batch['label'].numpy()
  bs = labels.shape[0]
  all_images.append(batch['image'].numpy())
  all_subclass_labels.append(labels)

# Concatenate the per-batch pieces; `all_labels` is a second name for the
# same label array.
all_images = np.vstack(all_images)
all_labels = all_subclass_labels = np.hstack(all_subclass_labels)

In [None]:
# Load the saved cluster assignments (checkpoint 171, overclustering factor 3)
# and show the clusters of a single superclass.
model_path = os.path.join(BASE_DIR, 'breeds/living17_400_epochs_ema_0.99_bn_0.99/')
clf_labels_file = os.path.join(model_path, 'clf_labels_ckpt_171_ocfactor_3.pkl')
with gfile.Open(clf_labels_file, 'rb') as f:
  all_clf_labels = pickle.load(f)

superclass_idx = 15
print("----------SUPERCLASS:", label_map[superclass_idx])
# Positions of the examples whose label belongs to the chosen superclass.
subclass_idx = np.array([idx for idx, lbl in enumerate(all_subclass_labels)
                         if lbl in train_subclasses[superclass_idx]])
visualize_all_clusters(
    all_images[subclass_idx],
    all_clf_labels[superclass_idx],
    all_subclass_labels[subclass_idx],
    hier,
    gender_labels=None)

In [None]:
# CelebA only: collect the 'Male' attribute of every eval image and show the
# clusters of one class titled with their male ratio. The whole cell is
# guarded because `all_gender_labels` exists only for CelebA; previously the
# lines after the `if` raised a NameError for any other dataset.
if DATASET == 'celeb_a':
  all_gender_labels = []
  for step, batch in enumerate(eval_ds):
    gender = batch['attributes']['Male'].numpy().astype(int)
    all_gender_labels.append(gender)
  all_gender_labels = np.hstack(all_gender_labels)
  print(all_gender_labels.shape)

  model_path = os.path.join(BASE_DIR, 'celebA/lr_0.0001_reg_0.0001_ema_0.99/')
  with gfile.Open(os.path.join(model_path, 'clf_labels_ckpt_20.pkl'), 'rb') as f:
    all_clf_labels = pickle.load(f)
  # NOTE(review): this overwrites the labels just loaded from disk with the
  # in-memory clusterings -- confirm which source is intended.
  all_clf_labels = [clf.labels_ for clf in all_clfs]
  class_idx = 0
  subclass_idx = np.array([i for i in range(len(all_labels)) if all_labels[i] == class_idx])
  print(f"---------Male ratio in this class: {sum(all_gender_labels[subclass_idx]) * 100.0/ len(subclass_idx)}%")
  visualize_all_clusters(all_images[subclass_idx], all_clf_labels[class_idx], all_subclass_labels[subclass_idx], hier=None, gender_labels=all_gender_labels[subclass_idx])

## Visualize cluster reassignment

In [None]:
# Compare the co-clustering structure of one superclass at two checkpoints:
# entry (a, b) of each matrix is 1 when distinct examples a and b share a
# cluster. Examples are reordered so each subclass forms a contiguous block.
superclass_idx = 0
print("----------SUPERCLASS:", label_map[superclass_idx])
subclass_idx = np.array([i for i in range(len(all_subclass_labels)) if all_subclass_labels[i] in train_subclasses[superclass_idx]])
subclass_labels = all_subclass_labels[subclass_idx]
n_data = len(subclass_idx)

model_path = os.path.join(BASE_DIR, 'breeds/breeds_training_level_3_nsubclass_20_400_epochs_ema_0.99/')
with gfile.Open(os.path.join(model_path, 'clf_labels_ckpt_151.pkl'), 'rb') as f:
  all_clf_labels_prev = pickle.load(f)[superclass_idx]
with gfile.Open(os.path.join(model_path, 'clf_labels_ckpt_161.pkl'), 'rb') as f:
  all_clf_labels_curr = pickle.load(f)[superclass_idx]

# Group example positions by subclass and remember where each block starts,
# so ticks line up with the real block boundaries even when subclasses have
# different sizes.
reorder_idx = []
subclass_starts = []
for subclass in train_subclasses[superclass_idx]:
  subclass_starts.append(len(reorder_idx))
  reorder_idx.extend(np.where(subclass_labels == subclass)[0])

n_subclass_per_class = len(train_subclasses[superclass_idx])
subclass_names = [hier.LEAF_NUM_TO_NAME[s].split(',')[0] for s in train_subclasses[superclass_idx]]

fig = plt.figure(figsize=(20, 10))
for plot_idx, clf_labels in enumerate([all_clf_labels_prev, all_clf_labels_curr]):
  clf_labels = np.asarray(clf_labels)[reorder_idx]
  # Vectorized co-membership matrix; the diagonal is cleared because an
  # example is not paired with itself.
  same_cluster = clf_labels[:, None] == clf_labels[None, :]
  cluster_assignment = (same_cluster & ~np.eye(len(clf_labels), dtype=bool)).astype(int)

  ax = fig.add_subplot(1, 2, plot_idx+1)
  #fig, ax = plt.subplots(1,1, figsize=(15,15))
  img = ax.matshow(cluster_assignment)
  ax.set_xticks(subclass_starts)
  ax.set_xticklabels(subclass_names, rotation=90)
  ax.set_yticks(subclass_starts)
  ax.set_yticklabels(subclass_names)
  ax.grid(color='w', linewidth=1)
  #fig.colorbar(img)
plt.show()

In [None]:
# For every 10th checkpoint, measure for each subclass the fraction of
# ordered pairs of its examples (self-pairs count as misses) that share a
# cluster, and plot that value against the training step.
superclass_idx = 0
print("----------SUPERCLASS:", label_map[superclass_idx])
subclass_idx = np.array([i for i in range(len(all_subclass_labels)) if all_subclass_labels[i] in train_subclasses[superclass_idx]])
subclass_labels = all_subclass_labels[subclass_idx]
n_data = len(subclass_idx)
n_subclass_per_class = len(train_subclasses[superclass_idx])

model_path = os.path.join(BASE_DIR, 'breeds/breeds_training_level_3_nsubclass_20_400_epochs_ema_0.99/')
# Positions of the examples of each subclass; using the real membership keeps
# the measure correct when subclasses have different sizes.
subclass_members = [np.where(subclass_labels == subclass)[0]
                    for subclass in train_subclasses[superclass_idx]]

ckpt_numbers = list(range(11, 162, 10))
fig = plt.figure(figsize=(20, 10))
plot_data = [[] for _ in range(n_subclass_per_class)]
for ckpt_number in ckpt_numbers:
  with gfile.Open(os.path.join(model_path, f'clf_labels_ckpt_{ckpt_number}.pkl'), 'rb') as f:
    all_clf_labels = np.asarray(pickle.load(f)[superclass_idx])

  for i, members in enumerate(subclass_members):
    member_clusters = all_clf_labels[members]
    n_members = len(member_clusters)
    if n_members == 0:
      # No examples of this subclass: the ratio is undefined.
      plot_data[i].append(np.nan)
      continue
    # Same-cluster pairs inside the subclass block, excluding the diagonal.
    same_cluster = member_clusters[:, None] == member_clusters[None, :]
    count = (np.sum(same_cluster) - n_members) / n_members**2
    plot_data[i].append(count)

# x-axis: checkpoint number converted with this cell's 1000 * (ckpt - 1) rule.
steps = [(c-1) * 1000 for c in ckpt_numbers]
for i, s in enumerate(train_subclasses[superclass_idx]):
  plt.plot(steps, plot_data[i], label=hier.LEAF_NUM_TO_NAME[s])
plt.legend()
plt.xlabel("Step")
plt.ylabel('Recall')
plt.show()