Copyright 2020 The Google Research Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Activation Clustering Model: Similar Examples, Concepts

This notebook shows how to use an activation clustering model to find training examples that are similar to a given test example, and to visualize "concepts" (clusters of activations).

Here we use a trained activation clustering model (in the `work_dir` directory) whose baseline model is a ResNet classification model trained on the CIFAR-10 dataset.


In [None]:
import numpy as np

import tensorflow as tf
import tensorflow_datasets as tfds

from activation_clustering import ac_model, utils

# Restore an activation clustering model

In [None]:
# Load the trained activation clustering model saved in the `work_dir` directory.
acm = ac_model.ACModel.restore('work_dir')

In [None]:
# NOTE: `utils` is already imported in the first cell; this repeated import is
# redundant but harmless.
from activation_clustering import utils
# Presumably reports the output shape of each activation listed in
# `acm.activation_names` for the baseline model -- confirm in `utils`.
utils.get_activation_shapes(acm.baseline_model, acm.activation_names)

In [None]:
# Display the activation names stored on the restored model.
acm.activation_names

In [None]:
# The same dataset preprocessing as used in the baseline model training.
def input_fn(batch_size, ds, label_key='label'):
    """Batch `ds` and convert each batch into a `(features, labels)` pair.

    Args:
        batch_size: Number of examples per batch. Trailing examples that do
            not fill a whole batch are dropped (`drop_remainder=True`).
        ds: A `tf.data.Dataset` whose elements are dicts containing an
            `'image'` entry and a `label_key` entry.
        label_key: Key of the label entry in each element.

    Returns:
        A dataset of `(features, labels)` tuples, where `features` is the
        image batch cast to float32 and divided by 255.
    """

    def interface(batch):
        # Cast to float32 and divide by 255 (rescales to [0, 1], assuming
        # 8-bit pixel values).
        features = tf.cast(batch['image'], tf.float32) / 255
        labels = batch[label_key]

        return features, labels

    # Prefetch as the final step so that the `map` preprocessing is also
    # overlapped with the consumer, instead of only the batching.
    dataset = ds.batch(batch_size, drop_remainder=True).map(interface)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
# Load the CIFAR-10 test split (input files are not shuffled) and run it
# through the same preprocessing as the baseline model.
test_ds = input_fn(
    batch_size=10000,
    ds=tfds.load('cifar10:3.*.*', split='test', shuffle_files=False),
)

# Take the first batch and convert both tensors to NumPy arrays.
test_features, test_labels = (
    tensor.numpy() for tensor in list(test_ds.take(1))[0]
)
del test_ds

# An activation clustering model can be used as a surrogate model for its baseline model.

In [None]:
# Accuracy of the surrogate model, evaluated against the true test labels.
print('surrogate model accuracy: ', acm.evaluate(features=test_features, y=test_labels))

In [None]:
# Fidelity: how much the surrogate model agrees with the baseline model.
# The baseline model's argmax predictions are used as the reference labels.
baseline_labels = np.argmax(acm.baseline_model.predict(test_features), axis=-1)
print('fidelity: ', acm.evaluate(features=test_features, y=baseline_labels))

In [None]:
# Call `predict_proba` on the first three test examples (presumably the
# surrogate model's per-class probabilities -- confirm in `ac_model`).
acm.predict_proba(features=test_features[:3])

In [None]:
# Call `clustering_predict_labels` on the same three test examples
# (presumably labels derived from the clusters -- confirm in `ac_model`).
acm.clustering_predict_labels(features=test_features[:3])

In [None]:
# Accuracy of the baseline model: the fraction of test examples whose
# predicted label matches the true label. `np.mean` of the boolean matches
# gives that fraction; `np.sum` would only give the count of correct examples.
np.mean(test_labels == baseline_labels)

# Get similar training examples of random test examples

In [None]:
# Load the training features (images) for visualization, using the same
# preprocessing as the baseline model (input files are not shuffled).
train_ds = input_fn(
    batch_size=50000,
    ds=tfds.load('cifar10:3.*.*', split='train', shuffle_files=False),
)

# Take the first batch and convert both tensors to NumPy arrays.
train_features, train_labels = (
    tensor.numpy() for tensor in list(train_ds.take(1))[0]
)
del train_ds

In [None]:
# Sample 10 distinct test examples. The population size comes from the loaded
# data instead of a hard-coded 10000, so it stays correct if the number of
# loaded test examples changes.
test_indices = np.random.choice(len(test_features), size=10, replace=False)

test_feat = test_features[test_indices]
print(test_feat.shape)

print('test indices:    {}'.format(test_indices))
print(test_labels[test_indices])

In [None]:
# Weight vectors passed to `acm.query` below, one entry per activation
# (presumably in the order of `acm.activation_names` -- confirm in `ac_model`).
equal = [1.0, 1.0, 1.0, 1.0]  # every entry weighted the same
low = [2.0, 1.0, 0.0, 0.0]  # only the first two entries are non-zero
high = [0.0, 0.0, 1.0, 2.0]  # only the last two entries are non-zero

The distances between embeddings from different layers are averaged to output visually similar images.

In [None]:
# Query similar training examples using the `equal` weights.
ind = acm.query(features=test_feat, weights=equal)
train_image_arrays_list = train_features[ind]

# Show each sampled test image next to its retrieved training images.
utils.visualize_similar(
    test_image_arrays=test_feat,
    test_labels=test_labels[test_indices].tolist(),
    train_image_arrays_list=train_image_arrays_list,
    train_labels=train_labels[ind].tolist(),
)

In [None]:
# Query similar training examples using the `low` weights.
ind = acm.query(features=test_feat, weights=low)
train_image_arrays_list = train_features[ind]

# Show each sampled test image next to its retrieved training images.
utils.visualize_similar(
    test_image_arrays=test_feat,
    test_labels=test_labels[test_indices].tolist(),
    train_image_arrays_list=train_image_arrays_list,
    train_labels=train_labels[ind].tolist(),
)

In [None]:
# Query similar training examples using the `high` weights.
ind = acm.query(features=test_feat, weights=high)
train_image_arrays_list = train_features[ind]

# Show each sampled test image next to its retrieved training images.
utils.visualize_similar(
    test_image_arrays=test_feat,
    test_labels=test_labels[test_indices].tolist(),
    train_image_arrays_list=train_image_arrays_list,
    train_labels=train_labels[ind].tolist(),
)

# Compare with similar images based on the last activation

In [None]:
# get last activations

# Keep only the final entry of `acm.activation_names`.
activation_names = [acm.activation_names[-1]]

# Compute that activation for all training images and all test images.
train_activations = acm.get_activations_from_features(train_features, activation_names)
test_activations = acm.get_activations_from_features(test_features, activation_names)

# Both results are keyed mappings; show which keys they contain.
print(train_activations.keys(), test_activations.keys())

In [None]:
# query with activations
from scipy.spatial.distance import cdist


def query(test_acts, train_acts, weights=None):
    """Rank training examples by weighted activation distance to each test example.

    Args:
        test_acts: Mapping from activation name to an array of test
            activations (first axis indexes examples).
        train_acts: Mapping from activation name to an array of training
            activations; it is looked up for every name in
            `acm.activation_names` that is present in `test_acts`.
        weights: Optional per-activation weights, indexed by position in
            `acm.activation_names`. Defaults to 1.0 for every activation.

    Returns:
        Array of training indices, sorted by increasing summed weighted
        Euclidean distance along the last axis (one row per test example).
    """
    names = acm.activation_names
    if weights is None:
        weights = [1.0] * len(names)

    total = 0.0
    for position, name in enumerate(names):
        # Only activations supplied by the caller contribute to the distance.
        if name in test_acts:
            queries = test_acts[name]
            candidates = train_acts[name]

            # Flatten every example's activation into a single vector.
            queries = queries.reshape((len(queries), -1))
            candidates = candidates.reshape((len(candidates), -1))

            pairwise = cdist(queries, candidates, 'euclidean')
            total += pairwise * weights[position]

    return np.argsort(total, axis=-1)

In [None]:
# Restrict the test activations to the sampled test examples.
test_acts = {k:v[test_indices] for k, v in test_activations.items()}
# Every training activation is a candidate neighbour.
train_acts = train_activations

In [None]:
# using only the last activation: `test_acts` holds just that activation, so
# `query` skips every other name in `acm.activation_names`.
# (`%time` is an IPython magic that reports how long the statement takes.)
%time last_act_ind = query(test_acts, train_acts)

Using the distance between the last activations to determine which images are similar does not capture low-level visual features.

In [None]:
# `query` returns a full ranking, so `last_act_ind` has one column per training
# example. Keep only the nearest neighbours before indexing; otherwise
# `train_features[last_act_ind]` gathers the entire training set once per test
# example. The number of neighbours is taken from `ind` (the result of the
# `acm.query` calls above) so this figure is comparable with the earlier ones.
num_similar = np.shape(ind)[-1]
top_ind = last_act_ind[:, :num_similar]
train_image_arrays_list = train_features[top_ind]

utils.visualize_similar(
    test_image_arrays=test_feat,
    train_image_arrays_list=train_image_arrays_list,
    test_labels=test_labels[test_indices].tolist(),
    train_labels=train_labels[top_ind].tolist()
)

# Concepts

Here we think of each cluster as a "concept" and visualize images closest to each cluster centroid.

In [None]:
# Concept indices, looked up per activation below and used to index into
# `train_features`.
concept_indices = acm.concept_indices()

Each activation has its own list of clusters.  Earlier activations capture low-level visual features.

In [None]:
# Concepts for the first activation in `acm.activation_names`.
activation_index = 0
print('Concepts based on {}'.format(acm.activation_names[activation_index]))
# Select the training images that `concept_indices` lists for this activation.
concept_ind = concept_indices[activation_index]
train_image_arrays_list = train_features[concept_ind]

utils.visualize_concepts(train_image_arrays_list)

In [None]:
# Concepts for the third activation in `acm.activation_names`.
activation_index = 2
print('Concepts based on {}'.format(acm.activation_names[activation_index]))
# Select the training images that `concept_indices` lists for this activation.
concept_ind = concept_indices[activation_index]
train_image_arrays_list = train_features[concept_ind]

utils.visualize_concepts(train_image_arrays_list)

In [None]:
# Concepts for the fourth activation in `acm.activation_names`.
activation_index = 3
print('Concepts based on {}'.format(acm.activation_names[activation_index]))
# Select the training images that `concept_indices` lists for this activation.
concept_ind = concept_indices[activation_index]
train_image_arrays_list = train_features[concept_ind]

utils.visualize_concepts(train_image_arrays_list)