In [None]:
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import itertools
import logging
import os
import pickle
import time
from collections import defaultdict
from typing import Any, Sequence

from colabtools import adhoc_import
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import ml_collections
import numpy as np
from scipy.special import comb
from scipy.stats import mode
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from sklearn.metrics import adjusted_mutual_info_score
from sklearn.mixture import GaussianMixture
import tensorflow as tf
import tensorflow_datasets as tfds

from clu import preprocess_spec



In [None]:
def get_learning_rate(step: int,
                      *,
                      base_learning_rate: float,
                      steps_per_epoch: int,
                      num_epochs: int,
                      warmup_epochs: int = 5):
  """Learning rate from `cosine_decay`, scaled by a linear warmup factor.

  The decay is evaluated on the post-warmup clock (`epoch - warmup_epochs`
  over `num_epochs - warmup_epochs` epochs). During the first
  `warmup_epochs` epochs the result is multiplied by
  `epoch / warmup_epochs`, which ramps linearly up to 1.

  Args:
    step: Global optimizer step (Python number or JAX scalar).
    base_learning_rate: Peak learning rate passed to `cosine_decay`.
    steps_per_epoch: Number of steps per epoch; must be positive.
    num_epochs: Total number of training epochs.
    warmup_epochs: Number of linear-warmup epochs. Must be non-negative and
      smaller than `num_epochs`; 0 disables warmup.

  Returns:
    The learning rate for `step`.

  Raises:
    ValueError: if `steps_per_epoch <= 0`, `warmup_epochs < 0`, or
      `warmup_epochs >= num_epochs`.
  """
  logging.info(
      "get_learning_rate(step=%s, base_learning_rate=%s, steps_per_epoch=%s, "
      "num_epochs=%s, warmup_epochs=%s)",
      step, base_learning_rate, steps_per_epoch, num_epochs, warmup_epochs)
  if steps_per_epoch <= 0:
    raise ValueError(f"steps_per_epoch should be a positive integer but was "
                     f"{steps_per_epoch}.")
  if warmup_epochs < 0:
    raise ValueError(f"warmup_epochs should be non-negative but was "
                     f"{warmup_epochs}.")
  if warmup_epochs >= num_epochs:
    raise ValueError(f"warmup_epochs should be smaller than num_epochs. "
                     f"Currently warmup_epochs is {warmup_epochs}, "
                     f"and num_epochs is {num_epochs}.")
  epoch = step / steps_per_epoch
  lr = cosine_decay(base_learning_rate, epoch - warmup_epochs,
                    num_epochs - warmup_epochs)
  if warmup_epochs == 0:
    # No warmup requested; dividing by warmup_epochs below would fail.
    return lr
  warmup = jnp.minimum(1., epoch / warmup_epochs)
  return lr * warmup

def predict(model, state, batch):
  """Run `model` in eval mode on `batch['image']` and return captured intermediates.

  The model is applied with the EMA parameters and the stored batch statistics
  from `state`. `capture_intermediates=True` is passed together with a mutable
  "intermediates" collection, and that collection is what gets returned. The
  model's regular output is discarded.
  """
  model_variables = dict(params=state.ema_params, batch_stats=state.batch_stats)
  _, mutated_collections = model.apply(
      model_variables,
      batch['image'],
      capture_intermediates=True,
      mutable=["intermediates"],
      train=False,
  )
  return mutated_collections['intermediates']

def compute_purity(clusters, classes):
  """Compute the purity of a clustering with respect to ground-truth classes.

  Purity is the fraction of points whose class equals the majority class of
  the cluster they were assigned to. For each cluster, the size of its most
  common class is counted; the counts are summed and divided by the number of
  points.

  Args:
    clusters: Sequence or 1-D array of cluster ids, one per point.
    classes: Sequence or 1-D array of class labels aligned with `clusters`.

  Returns:
    Purity as a float in [0, 1].

  Raises:
    ValueError: if the inputs are empty or have different lengths.
  """
  clusters = np.asarray(clusters)
  classes = np.asarray(classes)
  n_points = len(clusters)
  if n_points == 0:
    raise ValueError("compute_purity needs at least one point.")
  if len(classes) != n_points:
    raise ValueError(f"clusters and classes must have the same length, got "
                     f"{n_points} and {len(classes)}.")
  n_majority_points = 0
  for cluster_id in np.unique(clusters):
    # Majority-class size via np.unique counts. scipy.stats.mode changed its
    # return shape in SciPy 1.11 (keepdims=False), which broke `mode(...)[1][0]`.
    _, counts = np.unique(classes[clusters == cluster_id], return_counts=True)
    n_majority_points += int(counts.max())
  return n_majority_points / n_points

def show_images_horizontally(images, labels, super_label=None, max_samples=20, show_labels=False):
  """Display images in a grid using matplotlib.

  If `max_samples` is set, at most that many images are shown in a single
  row. If `max_samples` is None, every image is shown in rows of 10.

  Args:
    images: Array of images indexed along axis 0; each item is passed to `imshow`.
    labels: Per-image titles (list, tuple or ndarray, one entry per image), or a
      single value used as the title of every image.
    super_label: Optional title for the whole figure (skipped if falsy).
    max_samples: Maximum number of images to show in one row, or None to show
      all images in a multi-row grid.
    show_labels: Whether to draw a title above each image.
  """
  n_total = images.shape[0]
  if max_samples is not None:
    n_shown = min(n_total, max_samples)
    n_rows, n_cols = 1, n_shown
  else:
    n_shown = n_total
    n_cols = 10
    n_rows = int(np.ceil(n_shown / n_cols))
  if n_shown == 0:
    return  # Nothing to draw; also avoids creating a zero-sized figure.
  fig = plt.figure(figsize=(1.5 * n_cols, 1.5 * n_rows))

  # Sequences give one title per image; anything else titles every image.
  per_image_titles = isinstance(labels, (list, tuple, np.ndarray))
  for i in range(n_shown):
    axes = fig.add_subplot(n_rows, n_cols, i + 1)
    axes.imshow(images[i])
    if show_labels:
      axes.set_title(labels[i] if per_image_titles else labels)
  if super_label:
    # Figure-level title. plt.title would retitle only the last subplot and
    # overwrite that image's own label.
    fig.suptitle(super_label)
  plt.show()

def visualize_all_clusters(all_images, cluster_labels, all_subclass_labels, hier, gender_labels, show_labels=False):
  """Show up to 10 sample images from every cluster, one figure per cluster.

  Args:
    all_images: Images indexable by an integer index array along axis 0,
      aligned with the label arrays.
    cluster_labels: Cluster id per image.
    all_subclass_labels: Ground-truth subclass label per image.
    hier: Optional object with a `LEAF_NUM_TO_NAME` mapping from subclass label
      to a comma-separated name string. If given, each figure is titled with
      the cluster's majority subclass name, and each image's label is its
      subclass name.
    gender_labels: Optional per-image numeric labels, used only when `hier` is
      None. Each figure is then titled with 100 * mean(gender_labels) over the
      cluster (presumably 1 = male, per the title text; confirm with the data
      source).
    show_labels: Whether to draw a title above each image.
  """
  cluster_labels = np.asarray(cluster_labels)
  all_subclass_labels = np.asarray(all_subclass_labels)
  if gender_labels is not None:
    gender_labels = np.asarray(gender_labels)

  for label in np.unique(cluster_labels):
    instance_idx = np.flatnonzero(cluster_labels == label)
    subclass_labels = all_subclass_labels[instance_idx]
    cluster_images = all_images[instance_idx]
    if hier is not None:
      # Majority subclass. np.unique returns sorted values, so argmax breaks
      # ties toward the smallest label (as scipy.stats.mode does) without
      # depending on mode's SciPy-version-specific return shape.
      values, counts = np.unique(subclass_labels, return_counts=True)
      cluster_label = values[np.argmax(counts)]
      cluster_label_str = hier.LEAF_NUM_TO_NAME[cluster_label].split(',')[0]
      image_titles = [hier.LEAF_NUM_TO_NAME[s].split(',')[0] for s in subclass_labels]
    elif gender_labels is not None:
      # `is not None` rather than truthiness: a numpy array has no single truth value.
      male_pc = gender_labels[instance_idx].sum() * 100.0 / len(instance_idx)
      cluster_label_str = f'Male ratio = {male_pc}%'
      image_titles = ['' for _ in subclass_labels]
    else:
      cluster_label_str = ''
      # A list, so show_images_horizontally titles each image individually.
      image_titles = list(subclass_labels)
    show_images_horizontally(cluster_images, image_titles, super_label=cluster_label_str,
                             max_samples=10, show_labels=show_labels)

config = get_config()

# Schedule handed to `create_train_state` inside `evaluate_purity`. This
# notebook only restores checkpoints and runs inference, so presumably these
# schedule values never drive an optimizer step here.
learning_rate_fn = functools.partial(
    get_learning_rate,
    steps_per_epoch=40,
    base_learning_rate=0.1,
    num_epochs=config.num_epochs,
    warmup_epochs=config.warmup_epochs,
)


def _extract_pooled_features(model, state, dataset):
  """Return (features, labels): spatially averaged 'stage4' activations and labels for every example."""
  feature_batches, label_batches = [], []
  for step, batch in enumerate(dataset):
    if step % 50 == 0:
      print(step)
    intermediates = predict(model, state, batch)
    labels = batch['label'].numpy()
    # Averages over axes 1 and 2; assumes the activation is
    # (batch, height, width, channels). TODO: confirm against the model definition.
    stage4 = intermediates['stage4']['__call__'][0]
    feature_batches.append(np.mean(stage4, axis=(1, 2)).reshape(labels.shape[0], -1))
    label_batches.append(labels)
  if not feature_batches:
    raise ValueError("The evaluation dataset yielded no batches.")
  return np.vstack(feature_batches), np.hstack(label_batches)


def evaluate_purity(eval_dataset, model_dir, ckpt_number, n_classes, overcluster_factors, n_subclasses):
  """Cluster a model's pooled stage-4 features and measure cluster purity.

  Restores checkpoint `ckpt_number` from `model_dir` and extracts spatially
  averaged 'stage4' activations for every example in `eval_dataset`. Then, for
  each overclustering factor, it runs Ward agglomerative clustering with
  `n_subclasses * factor` clusters and scores the clusters against the dataset
  labels with `compute_purity`.

  Relies on the notebook globals `config`, `learning_rate_fn`,
  `create_train_state` and `checkpoints`.

  Args:
    eval_dataset: Iterable of batches: dicts with 'image' and 'label' tensors
      exposing `.numpy()`.
    model_dir: Directory containing `checkpoints-0/ckpt-<n>.flax`.
    ckpt_number: Checkpoint number to restore.
    n_classes: Number of output classes the model was trained with.
    overcluster_factors: Iterable of integer multipliers for `n_subclasses`.
    n_subclasses: Base number of clusters.

  Returns:
    `(result_dict, all_clf_labels)`. `result_dict` maps each overclustering
    factor to its purity. `all_clf_labels` holds the cluster assignments from
    the last factor in `overcluster_factors`, or None if that is empty.
  """
  checkpoint_path = os.path.join(model_dir, f'checkpoints-0/ckpt-{ckpt_number}.flax')
  model, state = create_train_state(config, jax.random.PRNGKey(0), input_shape=(8, 224, 224, 3),
                                    num_classes=n_classes, learning_rate_fn=learning_rate_fn)
  state = checkpoints.restore_checkpoint(checkpoint_path, state)
  print("Ckpt number", ckpt_number, "Ckpt step:", state.step)

  # Read the dataset passed in (not the global `eval_ds`). Images are not kept:
  # they were never used here and held the whole image set in memory.
  features, subclass_labels = _extract_pooled_features(model, state, eval_dataset)

  result_dict = {}
  all_clf_labels = None
  for overcluster_factor in overcluster_factors:
    clf = AgglomerativeClustering(n_clusters=n_subclasses * overcluster_factor,
                                  linkage='ward').fit(features)
    all_clf_labels = clf.labels_
    result_dict[overcluster_factor] = compute_purity(all_clf_labels, subclass_labels)
  return result_dict, all_clf_labels

# Load OOD dataset

In [None]:
DATASET = 'oxford_iiit_pet'
SPLIT = tfds.Split.TRAIN
batch_size = 64

dataset_builder = tfds.builder(DATASET, try_gcs=True)

# Evaluation preprocessing pipeline (the op classes are defined elsewhere in the project).
eval_preprocess = preprocess_spec.PreprocessFn(
    [RescaleValues(), ResizeSmall(256), CentralCrop(224), GeneralPreprocessOp()],
    only_jax_types=True,
)

dataset_options = tf.data.Options()
dataset_options.experimental_optimization.map_parallelization = True
dataset_options.experimental_threading.private_threadpool_size = 48
dataset_options.experimental_threading.max_intra_op_parallelism = 1
read_config = tfds.ReadConfig(shuffle_seed=None, options=dataset_options)

# Files are read in order and nothing shuffles afterwards, so every pass over
# eval_ds yields examples in the same order.
eval_ds = (
    dataset_builder.as_dataset(
        split=SPLIT,
        shuffle_files=False,
        read_config=read_config,
        decoders=None,
    )
    .cache()
    .map(eval_preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size, drop_remainder=False)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Evaluate purity on OOD data for a single model

In [None]:
# Candidate checkpoints: name -> (model_dir relative to BASE_DIR, checkpoint
# number, number of training classes). Change SELECTED_MODEL to evaluate a
# different one.
SINGLE_MODEL_CANDIDATES = {
    'entity13': ('breeds/entity13_400_epochs_ema_0.99_bn_0.99/', 161, 13),
    'living17': ('breeds/living17_400_epochs_ema_0.99_bn_0.99/', 173, 17),
    'entity13_4_subclasses_shuffle': (
        'breeds/entity13_4_subclasses_shuffle_400_epochs_ema_0.99_bn_0.99/', 129, 13),
    'nonliving26': ('breeds/nonliving26_400_epochs_ema_0.99_bn_0.99/', 257, 26),
    'imagenet': ('breeds/imagenet_ema_0.99_bn_0.99/', 8, 1000),
}
SELECTED_MODEL = 'imagenet'
_model_subdir, ckpt_number, N_CLASSES = SINGLE_MODEL_CANDIDATES[SELECTED_MODEL]
model_dir = os.path.join(BASE_DIR, _model_subdir)

# Each factor clusters the features into n_subclasses * factor groups.
overcluster_factors = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
n_subclasses = 5

result_dict, all_clf_labels = evaluate_purity(eval_ds, model_dir, ckpt_number, N_CLASSES, overcluster_factors, n_subclasses)
# Model family = prefix of the checkpoint directory name; normpath makes this
# independent of a trailing slash.
train_subset = os.path.basename(os.path.normpath(model_dir)).split('_')[0]
print(result_dict)

fig, ax = plt.subplots()
ax.plot(list(result_dict.keys()), list(result_dict.values()), marker='o')
ax.set_xlabel("Overclustering factor")
ax.set_ylabel("Purity")
ax.set_title(f"{train_subset} -> {DATASET}")
plt.show()

# Evaluate purity on OOD data for multiple models

In [None]:
# (model_dir, checkpoint number, number of training classes) for each model to compare.
model_metadata = [(os.path.join(BASE_DIR, 'breeds/entity13_400_epochs_ema_0.99_bn_0.99/'), 161, 13),
                  (os.path.join(BASE_DIR, 'breeds/living17_400_epochs_ema_0.99_bn_0.99/'), 173, 17),
                  (os.path.join(BASE_DIR, 'breeds/nonliving26_400_epochs_ema_0.99_bn_0.99/'), 257, 26),
                  (os.path.join(BASE_DIR, 'breeds/imagenet_ema_0.99_bn_0.99/'), 8, 1000)
                  ]

# Uses overcluster_factors and n_subclasses from the single-model cell above.
# Loop variables are distinct from that cell's globals (model_dir, ckpt_number,
# N_CLASSES, result_dict), so running this cell does not silently change them.
multi_model_purity = {}
fig, ax = plt.subplots()
for mm_model_dir, mm_ckpt_number, mm_n_classes in model_metadata:
  mm_result, _mm_cluster_labels = evaluate_purity(
      eval_ds, mm_model_dir, mm_ckpt_number, mm_n_classes, overcluster_factors, n_subclasses)
  print(mm_result)
  legend_label = os.path.basename(os.path.normpath(mm_model_dir)).split('_')[0]
  multi_model_purity[legend_label] = mm_result
  ax.plot(list(mm_result.keys()), list(mm_result.values()), marker='o', label=legend_label)
ax.set_xlabel("Overclustering factor")
ax.set_ylabel("Purity")
ax.legend()
ax.set_title(DATASET)
plt.show()

# Visualize sample images from some clusters

In [None]:
N_CLUSTERS_TO_SHOW = 15

# `evaluate_purity` does not return the images or labels it reads, so rebuild
# them here. eval_ds is read with shuffle_files=False and never shuffled, so
# iterating it again yields examples in the order behind `all_clf_labels`.
image_batches, label_batches = [], []
for batch in eval_ds:
  image_batches.append(batch['image'].numpy())
  label_batches.append(batch['label'].numpy())
all_images = np.vstack(image_batches)
all_subclass_labels = np.hstack(label_batches)
clf_labels = np.asarray(all_clf_labels)
assert len(clf_labels) == len(all_images) == len(all_subclass_labels), (
    "cluster labels do not line up with eval_ds")

for clf_label in range(N_CLUSTERS_TO_SHOW):
  clf_idx = np.flatnonzero(clf_labels == clf_label)
  if clf_idx.size == 0:
    continue  # Fewer clusters than N_CLUSTERS_TO_SHOW.
  visualize_all_clusters(all_images[clf_idx], clf_labels[clf_idx], all_subclass_labels[clf_idx], None, None)