Copyright 2020 The Google Research Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Activation Clustering Model: Training

This notebook shows how to train an activation clustering model from a trained baseline Keras model.  Here we use a ResNet classification model trained on the CIFAR-10 dataset as an example.  The model is included as `model.h5`.


In [None]:
import numpy as np

import tensorflow as tf
import tensorflow_datasets as tfds

from activation_clustering import ac_model, utils

In [None]:
# The same dataset preprocessing as used in the baseline cifar10 model training.
def input_fn(batch_size, ds, label_key='label'):
    dataset = ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)

    def interface(batch):
        features = tf.cast(batch['image'], tf.float32) / 255
        labels = batch[label_key]

        return features, labels

    return dataset.map(interface)
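For readers who want to check the preprocessing without TensorFlow, the following NumPy sketch mirrors what `input_fn` does to each batch. `numpy_input_fn` and the synthetic data are hypothetical illustrations, not part of this notebook or of `activation_clustering`. The sketch shows the two details that matter. First, `drop_remainder=True` discards a trailing partial batch. Second, pixels are rescaled from `uint8` to `float32` in [0, 1]. With `batch_size = 500`, nothing is dropped in this notebook, because the CIFAR-10 train and test splits have 50,000 and 10,000 images, and their 10% slices have 5,000 and 1,000.

```python
import numpy as np


def numpy_input_fn(batch_size, images, labels):
    """Hypothetical NumPy mirror of `input_fn`, for illustration only.

    Yields (features, labels) batches: incomplete trailing batches are
    dropped (like drop_remainder=True) and uint8 pixels are rescaled to
    float32 values in [0, 1].
    """
    n_full_batches = len(images) // batch_size  # partial batch is discarded
    for b in range(n_full_batches):
        sl = slice(b * batch_size, (b + 1) * batch_size)
        features = images[sl].astype(np.float32) / 255
        yield features, labels[sl]


# Synthetic stand-in for CIFAR-10: 1,050 RGB images of 32x32 pixels.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(1050, 32, 32, 3), dtype=np.uint8)
labels = rng.integers(0, 10, size=1050)

batches = list(numpy_input_fn(500, images, labels))
print(len(batches))          # 2: the last 50 examples are dropped
print(batches[0][0].dtype)   # float32
```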

# Train an activation clustering model from a baseline model

Load the baseline model, which the activation clustering model uses to compute activations.

In [None]:
model = tf.keras.models.load_model('model.h5')

The activation clustering model is configured by a list of pairs. The first entry in each pair is the name of a layer in the baseline model whose output activations will be clustered; the second entry is a dict whose `n_clusters` key specifies the number of clusters for that layer.

This implementation uses deep embedded clustering (DEC) as the clustering algorithm. DEC has several other parameters; you can expose them by modifying the `activation_clustering` library and then set them here.
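For reference, here is a NumPy sketch of the quantities at the core of DEC as described by Xie et al. (2016). The soft assignment q_ij of an embedded point z_i to a centroid mu_j uses a Student's t kernel. The sharpened target distribution p_ij is what training pulls Q towards by minimizing KL(P || Q). This is the textbook formulation with the usual alpha = 1; whether `activation_clustering` uses exactly these formulas and defaults is an assumption, and the function names are illustrative.

```python
import numpy as np


def soft_assignment(z, centroids, alpha=1.0):
    """DEC soft assignment q_ij: Student's t kernel, normalized over clusters."""
    # Squared Euclidean distances, shape (n_samples, n_clusters).
    sq_dist = np.sum((z[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    q = (1.0 + sq_dist / alpha) ** (-(alpha + 1.0) / 2.0)
    return q / q.sum(axis=1, keepdims=True)


def target_distribution(q):
    """DEC target p_ij: square q_ij, divide by cluster frequency f_j = sum_i q_ij, renormalize."""
    weight = q ** 2 / q.sum(axis=0)
    return weight / weight.sum(axis=1, keepdims=True)


def kl_divergence(p, q):
    """Clustering loss KL(P || Q) that DEC minimizes."""
    return float(np.sum(p * np.log(p / q)))


# Two well-separated groups of 2-D "embeddings" and one centroid per group.
rng = np.random.default_rng(0)
z = np.concatenate([rng.normal(0.0, 0.1, (5, 2)), rng.normal(3.0, 0.1, (5, 2))])
centroids = np.array([[0.0, 0.0], [3.0, 3.0]])

q = soft_assignment(z, centroids)
p = target_distribution(q)
print(q.argmax(axis=1))  # [0 0 0 0 0 1 1 1 1 1]
print(kl_divergence(p, q))
```

Squaring q_ij emphasizes confident assignments, and dividing by f_j keeps large clusters from dominating. That is why P acts as a self-training target rather than a copy of Q.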

In [None]:
clustering_config = [
    ('activation', {'n_clusters': 15}),
    ('activation_18', {'n_clusters': 15}),
    ('activation_36', {'n_clusters': 15}),
    ('activation_54', {'n_clusters': 15})
]

# Uncomment this for a shorter training time in debugging/test runs.
# clustering_config = [
#     ('activation', {'n_clusters': 10}),
#     ('activation_54', {'n_clusters': 10, 'filters': [16, 16, 16, 8]})
# ]

work_dir = 'new_work_dir'
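A misspelled layer name or a missing `n_clusters` key in `clustering_config` may only surface later in the pipeline, so a quick sanity check can help. The helper below is a hypothetical sketch, not part of `activation_clustering`. In this notebook, `layer_names` could be built as `[layer.name for layer in model.layers]`; the example list here is made up. Rejecting duplicate layers is an assumption about what the library expects.

```python
def validate_clustering_config(clustering_config, layer_names):
    """Hypothetical check of the (layer_name, params) pairs used by ACModel.

    Raises ValueError on an unknown or duplicated layer name, or on a
    missing / non-positive integer `n_clusters`; returns True otherwise.
    """
    seen = set()
    for layer_name, params in clustering_config:
        if layer_name not in layer_names:
            raise ValueError(f'unknown layer: {layer_name!r}')
        if layer_name in seen:  # assumption: one clustering model per layer
            raise ValueError(f'duplicate layer: {layer_name!r}')
        seen.add(layer_name)
        n_clusters = params.get('n_clusters')
        if not isinstance(n_clusters, int) or n_clusters < 1:
            raise ValueError(f'{layer_name}: n_clusters must be a positive int')
    return True


# Hypothetical layer names; use [layer.name for layer in model.layers] in practice.
layer_names = ['input_1', 'conv2d', 'activation', 'activation_18',
               'activation_36', 'activation_54', 'dense']
config = [
    ('activation', {'n_clusters': 15}),
    ('activation_54', {'n_clusters': 15}),
]
print(validate_clustering_config(config, layer_names))  # True
```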

In [None]:
new_acm = ac_model.ACModel(model, clustering_config, work_dir=work_dir)

Calling `build_clustering_models` creates clustering models, one for each specified activation.

In [None]:
new_acm.build_clustering_models()
new_acm.clustering_models

In [None]:
train_ds = tfds.load(
    'cifar10:3.*.*',
    shuffle_files=False,
    split='train'
)

test_ds = tfds.load(
    'cifar10:3.*.*',
    shuffle_files=False,
    split='test'
)

# # Uncomment this to use only 10% of each split for a shorter training time.
# train_ds = tfds.load(
#     'cifar10:3.*.*',
#     shuffle_files=False,
#     split='train[:10%]'
# )

# test_ds = tfds.load(
#     'cifar10:3.*.*',
#     shuffle_files=False,
#     split='test[:10%]'
# )

In [None]:
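# Run the baseline model once over each split and cache its activations under `work_dir`,
# so that clustering training can iterate over them without recomputing them.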

batch_size = 500

ds = input_fn(batch_size, train_ds)
new_acm.cache_activations(ds, tag='train')
del ds

ds = input_fn(batch_size, test_ds)
new_acm.cache_activations(ds, tag='test')
del ds

In [None]:
activations_dict = new_acm.load_activations_dict(
    activations_filename=work_dir+'/activations/activations_train.npz')

test_activations_dict = new_acm.load_activations_dict(
    activations_filename=work_dir+'/activations/activations_test.npz')
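The cached files end in `.npz`, and the loaded result behaves like a dict keyed by layer name. Assuming the cache is a standard NumPy `.npz` archive with one array per layer (the internals of `load_activations_dict` are not shown here), it can also be inspected with plain NumPy. The sketch below writes and reads a small hypothetical archive; the array shapes are made up for illustration.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for a cached activations file: one array per layer name.
fake_activations = {
    'activation': np.zeros((8, 32, 32, 16), dtype=np.float32),
    'activation_54': np.zeros((8, 8, 8, 64), dtype=np.float32),
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'activations_train.npz')
    np.savez(path, **fake_activations)

    # np.load on an .npz returns a lazy NpzFile; copy it into a plain dict
    # so the arrays remain usable after the file is closed.
    with np.load(path) as npz:
        loaded = {name: npz[name] for name in npz.files}

for name, arr in loaded.items():
    print(name, arr.shape, arr.dtype)
```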

In [None]:
for k, v in activations_dict.items():
    print(k, v.shape)
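The printed shapes also indicate how much memory the activations occupy, which is one reason to try the 10% splits first. Assuming the activations are `float32` arrays, the footprint is the product of the shape times 4 bytes. The helper and the example shape below are hypothetical; substitute the shapes printed above.

```python
import numpy as np


def activation_nbytes(shape, dtype=np.float32):
    """Bytes needed to hold an activation array of `shape` and `dtype`."""
    return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize


# Hypothetical shape: all 50,000 CIFAR-10 training images through a layer
# whose output is 32x32x16.
shape = (50000, 32, 32, 16)
print(activation_nbytes(shape) / 2**30)  # about 3.05 GiB
```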

In [None]:
# Here we use a small number of epochs/iterations for shorter training time.
# The activation clustering training loop handles model saving in its `work_dir`.

epochs = 15
maxiter = 980

# # Uncomment this for shorter training time
# epochs = 2
# maxiter = 280

new_acm.fit(activations_dict=activations_dict, epochs=epochs, maxiter=maxiter)
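To give an intuition for what an iteration budget like `maxiter` controls, here is a toy version of DEC's self-training loop. It is a sketch under strong simplifications: the embeddings `z` are fixed and only the centroids move, whereas full DEC also fine-tunes the encoder. The target distribution is recomputed every `update_interval` iterations and held fixed in between. The function and parameter names (`dec_centroid_training`, `update_interval`, `lr`) are hypothetical, and whether `ACModel.fit`'s `maxiter` counts iterations in the same way is an assumption. The centroid update is plain gradient descent on KL(P || Q) with respect to each mu_j.

```python
import numpy as np


def soft_assignment(z, centroids, alpha=1.0):
    # Student's t kernel between embeddings and centroids, normalized per row.
    sq_dist = np.sum((z[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    q = (1.0 + sq_dist / alpha) ** (-(alpha + 1.0) / 2.0)
    return q / q.sum(axis=1, keepdims=True), sq_dist


def target_distribution(q):
    # Square q and divide by cluster frequency f_j = sum_i q_ij, then renormalize.
    weight = q ** 2 / q.sum(axis=0)
    return weight / weight.sum(axis=1, keepdims=True)


def dec_centroid_training(z, centroids, maxiter=100, update_interval=10,
                          lr=0.01, alpha=1.0):
    """Toy DEC self-training on fixed embeddings z: only centroids are updated."""
    centroids = centroids.copy()
    for it in range(maxiter):
        q, sq_dist = soft_assignment(z, centroids, alpha)
        if it % update_interval == 0:
            p = target_distribution(q)  # held fixed until the next update
        # dL/dmu_j = -(alpha+1)/alpha * sum_i (1 + d_ij/alpha)^-1 (p_ij - q_ij)(z_i - mu_j)
        coef = (p - q) / (1.0 + sq_dist / alpha)          # (n, k)
        diff = z[:, None, :] - centroids[None, :, :]       # (n, k, d)
        grad = -(alpha + 1.0) / alpha * np.sum(coef[:, :, None] * diff, axis=0)
        centroids -= lr * grad
    q, _ = soft_assignment(z, centroids, alpha)
    return centroids, q


rng = np.random.default_rng(0)
z = np.concatenate([rng.normal(0.0, 0.1, (20, 2)), rng.normal(3.0, 0.1, (20, 2))])
init = np.array([[0.5, 0.5], [2.5, 2.5]])
centroids, q = dec_centroid_training(z, init)
print(centroids)
print(q.argmax(axis=1))
```

Recomputing P only periodically keeps the target stable between updates; making `maxiter` larger or `update_interval` smaller trades training time for closer tracking of the target.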