In [None]:
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import itertools
import logging
import os
import pickle
import time
from collections import defaultdict
from typing import Any, Sequence, Optional

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import ml_collections
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from colabtools import adhoc_import
from scipy.stats import mode
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from sklearn.metrics import adjusted_mutual_info_score



In [None]:
def predict(model, state, batch):
  """Return the intermediate activations captured for one batch.

  Runs `model.apply` with train=False on `batch['image']`, using the EMA
  parameters and batch statistics stored in `state`, and asks for the
  "intermediates" collection via `capture_intermediates=True`.

  Args:
    model: Module whose `apply` accepts the keyword arguments used below.
    state: Train state exposing `ema_params` and `batch_stats`.
    batch: Mapping with an 'image' entry.

  Returns:
    The "intermediates" entry of the mutated collections returned by
    `model.apply`.
  """
  model_variables = dict(params=state.ema_params,
                         batch_stats=state.batch_stats)
  _, mutated_collections = model.apply(
      model_variables,
      batch['image'],
      capture_intermediates=True,
      mutable=["intermediates"],
      train=False,
  )
  return mutated_collections['intermediates']


def compute_purity(clusters, classes):
  """Compute the purity of a clustering with respect to reference labels.

  For every cluster, count the points carrying that cluster's most frequent
  reference label. Purity is the sum of those counts divided by the total
  number of points.

  Args:
    clusters: 1-D array-like of cluster assignments, one per point.
    classes: 1-D array-like of reference labels, aligned with `clusters`.

  Returns:
    Purity as a float in (0, 1].

  Raises:
    ValueError: if `clusters` is empty or its length differs from `classes`.
  """
  clusters = np.asarray(clusters)
  classes = np.asarray(classes)
  if clusters.shape[0] == 0:
    raise ValueError("compute_purity requires at least one point.")
  if clusters.shape[0] != classes.shape[0]:
    raise ValueError(
        f"clusters and classes must have the same length, got "
        f"{clusters.shape[0]} and {classes.shape[0]}.")
  n_cluster_points = 0
  for cluster_idx in np.unique(clusters):
    member_classes = classes[clusters == cluster_idx]
    # Size of the majority label in this cluster. np.unique is used rather than
    # scipy.stats.mode because mode's result shape changed in SciPy 1.11
    # (keepdims now defaults to False), which broke `mode(...)[1][0]`.
    _, counts = np.unique(member_classes, return_counts=True)
    n_cluster_points += counts.max()
  return n_cluster_points / clusters.shape[0]

def get_learning_rate(step: int,
                      *,
                      base_learning_rate: float,
                      steps_per_epoch: int,
                      num_epochs: int,
                      warmup_epochs: int = 5):
  """Cosine learning-rate schedule with linear warmup.

  The cosine part is delegated to `cosine_decay` (defined outside this view),
  called on the post-warmup epoch and the post-warmup horizon. During the first
  `warmup_epochs` epochs the result is scaled linearly from 0 up to 1.

  Args:
    step: Current optimizer step (may be a traced JAX value).
    base_learning_rate: Peak learning rate passed to `cosine_decay`.
    steps_per_epoch: Optimizer steps per epoch; must be positive.
    num_epochs: Total number of training epochs.
    warmup_epochs: Length of the linear warmup in epochs; 0 disables warmup.

  Returns:
    The learning rate for `step`.

  Raises:
    ValueError: if steps_per_epoch <= 0, warmup_epochs < 0, or
      warmup_epochs >= num_epochs.
  """
  logging.info(
      "get_learning_rate(step=%s, base_learning_rate=%s, steps_per_epoch=%s, "
      "num_epochs=%s, warmup_epochs=%s)",
      step, base_learning_rate, steps_per_epoch, num_epochs, warmup_epochs)
  if steps_per_epoch <= 0:
    raise ValueError(f"steps_per_epoch should be a positive integer but was "
                     f"{steps_per_epoch}.")
  if warmup_epochs < 0:
    raise ValueError(f"warmup_epochs should be non-negative but was "
                     f"{warmup_epochs}.")
  if warmup_epochs >= num_epochs:
    raise ValueError(f"warmup_epochs should be smaller than num_epochs. "
                     f"Currently warmup_epochs is {warmup_epochs}, "
                     f"and num_epochs is {num_epochs}.")
  epoch = step / steps_per_epoch
  lr = cosine_decay(base_learning_rate, epoch - warmup_epochs,
                    num_epochs - warmup_epochs)
  if warmup_epochs == 0:
    # No warmup requested: skip the ramp (it would divide by zero).
    return lr
  warmup = jnp.minimum(1., epoch / warmup_epochs)
  return lr * warmup

# Experiment config plus the LR schedule create_train_state expects. The
# schedule presumably only shapes the freshly built TrainState that checkpoints
# are restored into below; confirm if this notebook is ever used for training.
config = get_config()
learning_rate_fn = functools.partial(
    get_learning_rate,
    num_epochs=config.num_epochs,
    warmup_epochs=config.warmup_epochs,
    base_learning_rate=0.1,
    steps_per_epoch=40,
)

## Load dataset

In [None]:
# BREEDS "living17" metadata: superclasses, subclass split, and label map.
ret = make_breeds_dataset("living17", BREEDS_INFO_DIR, BREEDS_INFO_DIR,
                          split=None)
superclasses, subclass_split, label_map = ret
# First split entry: presumably, for each superclass, the list of subclass
# (ImageNet class) ids used for training.
train_subclasses = subclass_split[0]
print(train_subclasses)
num_classes = len(train_subclasses)  # one model output per superclass
print("Num_classes:", num_classes)
print(label_map)

In [None]:
# Build the batched evaluation pipeline `eval_ds` used by every analysis cell
# below. DATASET == 'imagenet' selects ImageNet restricted to the BREEDS
# training subclasses; any other value selects CelebA.
DATASET = 'imagenet'
if DATASET == 'imagenet':
  # Flat list of every subclass id across all superclasses.
  all_subclasses = list(itertools.chain(*train_subclasses))
  # Dense index 0..len(all_subclasses)-1 for each subclass id (a repeated id
  # keeps its last index).
  new_label_map = {}
  for subclass_idx, sub in enumerate(all_subclasses):
    new_label_map.update({sub: subclass_idx})
  print(new_label_map)
  # TF hash-table form of new_label_map; ids missing from the map look up -1.
  lookup_table = tf.lookup.StaticHashTable(
      initializer=tf.lookup.KeyValueTensorInitializer(
          keys=tf.constant(list(new_label_map.keys()), dtype=tf.int64),
          values=tf.constant(list(new_label_map.values()), dtype=tf.int64),
      ),
      default_value=tf.constant(-1, dtype=tf.int64))

  dataset_builder = tfds.builder("imagenet2012", try_gcs=True)
  eval_preprocess = preprocess_spec.PreprocessFn([
      RescaleValues(),
      ResizeSmall(256),
      CentralCrop(224),
      # NOTE(review): with the op below disabled, lookup_table is unused and
      # batch['label'] presumably stays the raw ImageNet class id. The
      # clustering cells rely on that: they test label membership in
      # `train_subclasses`. Confirm before re-enabling.
      #LabelMappingOp(lookup_table=lookup_table)
      ], only_jax_types=True)
else:
  dataset_builder = tfds.builder("celeb_a", try_gcs=True)
  eval_preprocess = preprocess_spec.PreprocessFn([
      RescaleValues(),
      ResizeSmall(256),
      CentralCrop(224),
      LabelMapping()
      ], only_jax_types=True)


# tf.data performance options.
# NOTE(review): newer TF releases expose `experimental_threading` as
# `Options.threading`; confirm against the installed TF version.
dataset_options = tf.data.Options()
dataset_options.experimental_optimization.map_parallelization = True
dataset_options.experimental_threading.private_threadpool_size = 48
dataset_options.experimental_threading.max_intra_op_parallelism = 1

# Unshuffled validation split, so every pass over eval_ds yields the same order.
read_config = tfds.ReadConfig(shuffle_seed=None, options=dataset_options)
eval_ds = dataset_builder.as_dataset(
    split=tfds.Split.VALIDATION,
    shuffle_files=False,
    read_config=read_config,
    decoders=None)

batch_size = 64
if DATASET == 'imagenet':
  # `predicate` is defined outside this view; presumably it keeps only examples
  # whose label is in all_subclasses — confirm.
  eval_ds = eval_ds.filter(functools.partial(predicate, all_subclasses=all_subclasses))
# Cache the filtered examples: the cells below iterate eval_ds many times.
# Preprocessing (the map) still runs on every pass.
eval_ds = eval_ds.cache()
eval_ds = eval_ds.map(eval_preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# drop_remainder=False keeps the final partial batch so no example is dropped.
eval_ds = eval_ds.batch(batch_size, drop_remainder=False)
eval_ds = eval_ds.prefetch(tf.data.experimental.AUTOTUNE)

## Compute cluster purity for the output of every block in each stage

In [None]:
model_dir = os.path.join(BASE_DIR, 'breeds/living17_400_epochs_ema_0.99_bn_0.99_squared_loss_hyperparam_tuned/')
checkpoint_path = os.path.join(model_dir, 'checkpoints-0/ckpt-41.flax')
# checkpoints.restore_checkpoint returns the passed-in state unchanged when the
# path does not exist. Fail loudly instead of silently analysing an untrained
# model.
if not tf.io.gfile.exists(checkpoint_path):
  raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
# The freshly initialised state serves as the restore target.
model, state = create_train_state(config, jax.random.PRNGKey(0), input_shape=(8, 224, 224, 3), num_classes=num_classes, learning_rate_fn=learning_rate_fn)
state = checkpoints.restore_checkpoint(checkpoint_path, state)
print("step:", state.step)

In [None]:
# For every residual block of every stage, cluster each superclass's eval
# examples on that block's flattened output and score the clustering against
# the true subclass labels.
overcluster_factors = [1, 2, 3, 4, 5]
metric = 'purity'  # Either 'purity' or 'ami' (adjusted mutual information).
if metric not in ('purity', 'ami'):
  raise ValueError(f"metric must be 'purity' or 'ami', got {metric!r}")

for stage_prefix in ['stage4', 'stage2', 'stage3', 'stage1']:
  # One pass over eval_ds per stage, so only that stage's (large, flattened)
  # activations are held in memory at a time.
  all_layer_intermediates = {}
  all_subclass_labels = []
  all_filenames = []
  all_images = []
  for step, batch in enumerate(eval_ds):
    if step % 20 == 0:
      print(step)
    intermediates = predict(model, state, batch)
    labels = batch['label'].numpy()
    bs = labels.shape[0]
    all_subclass_labels.append(labels)
    all_images.append(batch['image'].numpy())
    if 'file_name' in batch:
      all_filenames.append(batch['file_name'].numpy())

    for stage in sorted(intermediates.keys()):
      if not stage.startswith(stage_prefix):
        continue
      for block in sorted(intermediates[stage].keys()):
        if not block.startswith('block'):
          continue
        key = '_'.join([stage, block])
        # np.asarray copies the activation to host memory instead of keeping a
        # device buffer alive for every batch of the dataset.
        block_output = np.asarray(intermediates[stage][block]['__call__'][0])
        all_layer_intermediates.setdefault(key, []).append(
            block_output.reshape(bs, -1))
  print(all_layer_intermediates.keys())

  all_subclass_labels = np.hstack(all_subclass_labels)
  all_images = np.vstack(all_images)
  if len(all_filenames) > 0:
    all_filenames = np.hstack(all_filenames)
    all_filenames = [f.decode("utf-8") for f in all_filenames]
  print(all_subclass_labels.shape)

  # Row indices of the examples belonging to each superclass. They depend only
  # on the labels, so compute them once rather than per layer and per factor.
  superclass_indices = [
      np.flatnonzero(np.isin(all_subclass_labels, list(subclasses)))
      for subclasses in train_subclasses
  ]

  for key, all_intermediates in all_layer_intermediates.items():
    all_intermediates = np.vstack(all_intermediates)
    print(all_intermediates.shape)
    result_dict = {}

    for overcluster_factor in overcluster_factors:
      all_clfs = []
      for subclasses, subclass_idx in zip(train_subclasses, superclass_indices):
        hier_clustering = AgglomerativeClustering(
            n_clusters=len(subclasses) * overcluster_factor,
            linkage='ward').fit(all_intermediates[subclass_idx])
        all_clfs.append(hier_clustering)

      metric_list = []
      all_clf_labels = []
      for clf, subclass_idx in zip(all_clfs, superclass_indices):
        all_clf_labels.append(clf.labels_)
        subclass_labels = all_subclass_labels[subclass_idx]
        # The score gets its own name. Assigning it to `metric` would replace
        # the metric name with a float: later superclasses would then re-append
        # the first score, and nothing would be saved.
        if metric == 'purity':
          score = compute_purity(clf.labels_, subclass_labels)
        else:
          score = adjusted_mutual_info_score(subclass_labels, clf.labels_)
        metric_list.append(score)

      result_dict[overcluster_factor] = metric_list

    print(result_dict)
    if metric == 'purity':
      # NOTE(review): only the last overcluster factor's scores and cluster
      # labels are written here; result_dict holds every factor.
      with gfile.Open(os.path.join(model_dir, f'class_purity_ckpt_{key}.pkl'), 'wb') as f:
        pickle.dump(metric_list, f)
      with gfile.Open(os.path.join(model_dir, f'clf_labels_ckpt_{key}.pkl'), 'wb') as f:
        pickle.dump(all_clf_labels, f)
    elif metric == 'ami':
      with gfile.Open(os.path.join(model_dir, f'adjusted_mutual_info_ckpt_{key}.pkl'), 'wb') as f:
        pickle.dump(result_dict, f)

## Compute cluster purity across training checkpoints using second-to-last-layer (pooled stage-4) representations

In [None]:
# For a series of checkpoints, cluster each superclass's eval examples on
# pooled stage-4 features and score against the true subclass labels.
model_dir = os.path.join(BASE_DIR, 'breeds/living17_400_epochs_ema_0.99_bn_0.99/')
ckpt_list = list(range(11, 172, 10))
overcluster_factors = [1, 2, 3, 4, 5]
# Set here so this cell does not depend on whatever the previous cell left in
# `metric`.
metric = 'purity'  # Either 'purity' or 'ami' (adjusted mutual information).
if metric not in ('purity', 'ami'):
  raise ValueError(f"metric must be 'purity' or 'ami', got {metric!r}")

for ckpt_number in ckpt_list:
  checkpoint_path = os.path.join(model_dir, f'checkpoints-0/ckpt-{ckpt_number}.flax')
  # restore_checkpoint returns the passed-in state unchanged when the path is
  # missing, which would silently analyse the random initialisation.
  if not tf.io.gfile.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
  model, state = create_train_state(config, jax.random.PRNGKey(0), input_shape=(8, 224, 224, 3),
                                    num_classes=num_classes, learning_rate_fn=learning_rate_fn)
  state = checkpoints.restore_checkpoint(checkpoint_path, state)
  print("Ckpt number", ckpt_number, "Ckpt step:", state.step)

  result_dict = {}
  all_intermediates = []
  all_subclass_labels = []
  all_filenames = []
  all_images = []
  for step, batch in enumerate(eval_ds):
    if step % 20 == 0:
      print(step)
    intermediates = predict(model, state, batch)
    labels = batch['label'].numpy()
    bs = labels.shape[0]
    all_subclass_labels.append(labels)
    all_images.append(batch['image'].numpy())
    if 'file_name' in batch:
      all_filenames.append(batch['file_name'].numpy())
    # Mean over axes 1 and 2 of the stage-4 output (presumably the spatial
    # H, W axes of a (batch, H, W, C) map — confirm).
    all_intermediates.append(np.mean(intermediates['stage4']['__call__'][0], axis=(1, 2)).reshape(bs, -1))

  all_intermediates = np.vstack(all_intermediates)
  all_subclass_labels = np.hstack(all_subclass_labels)
  all_images = np.vstack(all_images)
  if len(all_filenames) > 0:
    all_filenames = np.hstack(all_filenames)
    all_filenames = [f.decode("utf-8") for f in all_filenames]
  print(all_intermediates.shape)
  print(all_subclass_labels.shape)

  # Row indices of each superclass's examples; shared by every overcluster
  # factor.
  superclass_indices = [
      np.flatnonzero(np.isin(all_subclass_labels, list(subclasses)))
      for subclasses in train_subclasses
  ]

  for overcluster_factor in overcluster_factors:
    all_clfs = []
    for subclasses, subclass_idx in zip(train_subclasses, superclass_indices):
      hier_clustering = AgglomerativeClustering(
          n_clusters=len(subclasses) * overcluster_factor,
          linkage='ward').fit(all_intermediates[subclass_idx])
      all_clfs.append(hier_clustering)

    metric_list = []
    all_clf_labels = []
    for clf, subclass_idx in zip(all_clfs, superclass_indices):
      all_clf_labels.append(clf.labels_)
      subclass_labels = all_subclass_labels[subclass_idx]
      # Keep the score separate from `metric`, which must stay the metric name.
      if metric == 'purity':
        score = compute_purity(clf.labels_, subclass_labels)
      else:
        score = adjusted_mutual_info_score(subclass_labels, clf.labels_)
      metric_list.append(score)

    result_dict[overcluster_factor] = metric_list

  print(result_dict)
  if metric == 'purity':
    # NOTE(review): only the last overcluster factor's scores and cluster
    # labels are written here; result_dict holds every factor.
    with gfile.Open(os.path.join(model_dir, f'class_purity_ckpt_{ckpt_number}.pkl'), 'wb') as f:
      pickle.dump(metric_list, f)
    with gfile.Open(os.path.join(model_dir, f'clf_labels_ckpt_{ckpt_number}.pkl'), 'wb') as f:
      pickle.dump(all_clf_labels, f)
  elif metric == 'ami':
    with gfile.Open(os.path.join(model_dir, f'adjusted_mutual_info_ckpt_{ckpt_number}.pkl'), 'wb') as f:
      pickle.dump(result_dict, f)

### CelebA

In [None]:
# CelebA: for a series of checkpoints, cluster the eval examples of each label
# value on pooled stage-4 features and save the cluster assignments.
model_dir = os.path.join(BASE_DIR, 'celebA/lr_0.0001_reg_0.0001_ema_0.99/')
if DATASET == 'imagenet':
  # eval_ds is the ImageNet pipeline, whose labels are not CelebA's 0/1 labels.
  print("Skipping the CelebA analysis: re-run the dataset cell with "
        "DATASET != 'imagenet' to build the CelebA eval_ds.")
  ckpt_list = []
else:
  ckpt_list = list(range(2, 22, 2))
overcluster_factor = 10
# Own name so the BREEDS `num_classes` used by the other cells is not clobbered.
celeba_num_classes = 2
for ckpt_number in ckpt_list:
  checkpoint_path = os.path.join(model_dir, f'checkpoints-0/ckpt-{ckpt_number}.flax')
  # restore_checkpoint returns the passed-in state unchanged when the path is
  # missing, which would silently analyse the random initialisation.
  if not tf.io.gfile.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
  model, state = create_train_state(config, jax.random.PRNGKey(0), input_shape=(8, 224, 224, 3),
                                    num_classes=celeba_num_classes, learning_rate_fn=learning_rate_fn)
  state = checkpoints.restore_checkpoint(checkpoint_path, state)
  print("Ckpt step:", state.step)

  all_intermediates = []
  all_labels = []
  all_filenames = []
  all_images = []
  for step, batch in enumerate(eval_ds):
    if step % 50 == 0:
      print(step)
    intermediates = predict(model, state, batch)
    labels = batch['label'].numpy()
    bs = labels.shape[0]
    all_labels.append(labels)
    all_images.append(batch['image'].numpy())
    if 'file_name' in batch:
      all_filenames.append(batch['file_name'].numpy())
    # Mean over axes 1 and 2 of the stage-4 output (presumably spatial H, W).
    all_intermediates.append(np.mean(intermediates['stage4']['__call__'][0], axis=(1, 2)).reshape(bs, -1))

  all_intermediates = np.vstack(all_intermediates)
  all_labels = np.hstack(all_labels)
  all_images = np.vstack(all_images)
  if len(all_filenames) > 0:
    all_filenames = np.hstack(all_filenames)
    all_filenames = [f.decode("utf-8") for f in all_filenames]
  print(all_intermediates.shape)
  print(all_labels.shape)

  # Row indices of the eval examples carrying each label value.
  class_indices = [np.flatnonzero(all_labels == label)
                   for label in range(celeba_num_classes)]

  all_clfs = []
  for subclass_idx in class_indices:
    hier_clustering = AgglomerativeClustering(n_clusters=overcluster_factor,
                                              linkage='ward').fit(all_intermediates[subclass_idx])
    all_clfs.append(hier_clustering)

  purity_list = []
  all_clf_labels = []
  for clf, subclass_idx in zip(all_clfs, class_indices):
    all_clf_labels.append(clf.labels_)
    subclass_labels = all_labels[subclass_idx]
    # NOTE(review): every example in subclass_idx has the same label, so this
    # purity is always 1.0. A meaningful score needs finer-grained labels
    # (e.g. a CelebA attribute) — TODO.
    purity_list.append(compute_purity(clf.labels_, subclass_labels))

  with gfile.Open(os.path.join(model_dir, f'class_purity_ckpt_{ckpt_number}.pkl'), 'wb') as f:
    pickle.dump(purity_list, f)
  with gfile.Open(os.path.join(model_dir, f'clf_labels_ckpt_{ckpt_number}.pkl'), 'wb') as f:
    pickle.dump(all_clf_labels, f)
  print(ckpt_number)

## Concatenate second-to-last-layer representations from two checkpoints and compute cluster purity

In [None]:
# Concatenate pooled stage-4 features from two checkpoints of one run, then
# cluster each superclass on the joint representation and report purity.
model_dir = os.path.join(BASE_DIR, 'breeds/breeds_training_level_3_nsubclass_20_400_epochs_ema_0.99_seed_2/')
overcluster_factor = 5
concat_intermediates = []
reference_labels = None
for ckpt_number in [151, 161]:
  # The path must be under model_dir. restore_checkpoint returns the passed-in
  # state unchanged when nothing is found, so a wrong path would silently
  # analyse random initial weights; the existence check turns that into an
  # error.
  checkpoint_path = os.path.join(model_dir, f'checkpoints-0/ckpt-{ckpt_number}.flax')
  if not tf.io.gfile.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
  model, state = create_train_state(config, jax.random.PRNGKey(0), input_shape=(8, 224, 224, 3),
                                    num_classes=num_classes, learning_rate_fn=learning_rate_fn)
  state = checkpoints.restore_checkpoint(checkpoint_path, state)
  print("Ckpt step:", state.step)

  all_intermediates = []
  all_subclass_labels = []
  all_filenames = []
  all_images = []
  for step, batch in enumerate(eval_ds):
    if step % 20 == 0:
      print(step)
    intermediates = predict(model, state, batch)
    labels = batch['label'].numpy()
    bs = labels.shape[0]
    all_subclass_labels.append(labels)
    all_images.append(batch['image'].numpy())
    if 'file_name' in batch:
      all_filenames.append(batch['file_name'].numpy())
    # Mean over axes 1 and 2 of the stage-4 output (presumably spatial H, W).
    all_intermediates.append(np.mean(intermediates['stage4']['__call__'][0], axis=(1, 2)).reshape(bs, -1))

  all_intermediates = np.vstack(all_intermediates)
  concat_intermediates.append(all_intermediates)
  all_subclass_labels = np.hstack(all_subclass_labels)
  # Features are concatenated column-wise across checkpoints, which is only
  # valid if both passes saw eval_ds in the same order. Comparing the label
  # sequences is a cheap sanity check of that.
  if reference_labels is None:
    reference_labels = all_subclass_labels
  elif not np.array_equal(reference_labels, all_subclass_labels):
    raise RuntimeError("eval_ds yielded examples in a different order for "
                       f"checkpoint {ckpt_number}; cannot concatenate features.")
  all_images = np.vstack(all_images)
  if len(all_filenames) > 0:
    all_filenames = np.hstack(all_filenames)
    all_filenames = [f.decode("utf-8") for f in all_filenames]
  print(all_intermediates.shape)
  print(all_subclass_labels.shape)

concat_intermediates = np.hstack(concat_intermediates)
# Row indices of each superclass's examples.
superclass_indices = [
    np.flatnonzero(np.isin(all_subclass_labels, list(subclasses)))
    for subclasses in train_subclasses
]

all_clfs = []
for subclasses, subclass_idx in zip(train_subclasses, superclass_indices):
  hier_clustering = AgglomerativeClustering(n_clusters=len(subclasses) * overcluster_factor,
                                            linkage='ward').fit(concat_intermediates[subclass_idx])
  all_clfs.append(hier_clustering)

purity_list = []
all_clf_labels = []
for clf, subclass_idx in zip(all_clfs, superclass_indices):
  all_clf_labels.append(clf.labels_)
  subclass_labels = all_subclass_labels[subclass_idx]
  purity_list.append(compute_purity(clf.labels_, subclass_labels))

print(ckpt_number)
print(purity_list)