# Multi-label Pascal VOC 2007 CAM-Assisted Training

**References**

- Zhou, Bolei, Aditya Khosla, Agata Lapedriza, Aude Oliva, and Antonio Torralba. “Learning deep features for discriminative localization.” IEEE conference on computer vision and pattern recognition (CVPR), pp. 2921-2929. 2016. [1512.04150](https://arxiv.org/pdf/1512.04150.pdf)

In [None]:
#@title

! pip -qq install tensorflow-addons

In [None]:
import tensorflow as tf

class Config:
  """Experiment hyper-parameters grouped into nested namespaces.

  Values are read as class attributes, e.g. `Config.data.batch_size`.
  """

  class data:
    # Spatial size (height, width) that images are cropped/padded to.
    size = (224, 224)
    shape = (*size, 3)
    batch_size = 32
    shuffle_buffer_size = 8 * batch_size
    prefetch_buffer_size = tf.data.experimental.AUTOTUNE
    train_shuffle_seed = 120391
    shuffle = True

    # Keras VGG16 preprocessing ("caffe" mode: RGB -> BGR, then subtraction
    # of the per-channel BGR means used by `deprocess` below).
    preprocess = tf.keras.applications.vgg16.preprocess_input

    def deprocess(x):
      """Undo `preprocess` for display, returning uint8 RGB images.

      The means must be added back in BGR order *before* flipping the
      channels to RGB. Flipping first would add the blue mean to the red
      channel (and vice versa), shifting the displayed colours.
      """
      x = x + [103.939, 116.779, 123.68]
      return tf.cast(tf.clip_by_value(x[..., ::-1], 0, 255), tf.uint8)

  class aug:
    # Colour-jitter ranges passed to the tf.image.stateless_random_* ops
    # in augment_policy_fn.
    brightness_delta =  .2
    saturation_lower =  .2
    saturation_upper = 1.0
    contrast_lower   =  .5
    contrast_upper   = 1.5
    hue_delta        =  .0
    
  class model:
    # NOTE(review): these layer names are not referenced in this part of the
    # notebook; presumably used later for CAM extraction. 'avg_pool' matches
    # the default GAP layer name in build_specific_classifier.
    last_spatial_layer = 'block5_pool'
    first_dense_layer = 'avg_pool'
    backbone = tf.keras.applications.VGG16
  
  class training:
    # Top-classifier phase (backbone frozen), SGD learning rate.
    epochs = 50
    learning_rate = .002
    lr_first_decay_steps = 50
    
    # Fine-tuning phase; fine_tune_epochs = 0 skips it entirely.
    fine_tune_lr = .00001
    fine_tune_epochs = 20
    fine_tune_layers = .6  # fraction (60%) of backbone layers made trainable
    freeze_batch_norm = False  # keep BatchNormalization layers frozen when True

    early_stopping_patience = epochs // 4
  
  class explaining:
    # Not referenced in this part of the notebook.
    noise = tf.constant(.2)
    repetitions = tf.constant(8)

    score_activations = tf.constant(256)
  
  class segmentation:
    # Not referenced in this part of the notebook.
    class data:
      batch_size = 16
      shuffle_buffer_size = 8 * batch_size

    class training:
      epochs = 100
      early_stopping_patience = epochs // 3
      reduce_lr_on_plateau_patience = max(5, epochs // 10)

      # Loss params
      cl_ce_w = 0.05
      lr_initial = .005
      lr_first_decay_steps = 10  # None for constant learning-rate


  class experiment:
    # Seeds the global augmentation generator `R`.
    seed = 218402
    # If True, an existing `logs` directory is deleted before training;
    # otherwise training refuses to start.
    override = True
    logs              = '/content/drive/MyDrive/logs/pascal/vgg16-ce-cam-pascal-voc-2007/'
    fine_tune_logs    = '/content/drive/MyDrive/logs/pascal/vgg16-ce-cam-pascal-voc-2007-fine-tune/'
    segmentation_logs = '/content/drive/MyDrive/logs/pascal/vgg16-ce-cam-pascal-voc-2007-segmentation/'

    # Best-checkpoint paths written by ModelCheckpoint for each phase.
    training_weights     = '/content/drive/MyDrive/logs/pascal/vgg16-ce-cam-pascal-voc-2007/weights.h5'
    fine_tune_weights    = '/content/drive/MyDrive/logs/pascal/vgg16-ce-cam-pascal-voc-2007-fine-tune/weights.h5'
    segmentation_weights = '/content/drive/MyDrive/logs/pascal/vgg16-ce-cam-pascal-voc-2007-segmentation/weights.h5'

## Setup

In [None]:
# Mount Google Drive; all Config.experiment paths live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import shutil
from math import ceil

import numpy as np
import pandas as pd
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import seaborn as sns

from tensorflow.keras import callbacks

In [None]:
# Enable memory growth on every visible GPU so TensorFlow allocates memory on
# demand instead of reserving it all up front. Failures are only printed:
# this is best-effort setup.
for d in tf.config.list_physical_devices('GPU'):
  print(d)
  print(f'Setting device {d} to memory-growth mode.')
  try:
    tf.config.experimental.set_memory_growth(d, True)
  except Exception as e:
    print(e)

In [None]:
# Seeded stateful generator; augment_policy_fn draws its per-image seeds from it.
R = tf.random.Generator.from_seed(Config.experiment.seed, alg='philox')
# Colour palette as an array (C) and as a colormap (CMAP).
# NOTE(review): matplotlib's "Set1" defines 9 colours, so asking for 21
# presumably repeats colours — confirm that is acceptable.
C = np.asarray(sns.color_palette("Set1", 21))
CMAP = sns.color_palette("Set1", 21, as_cmap=True)

# Keep the white style but hide grid lines on all plots.
sns.set_style("whitegrid", {'axes.grid' : False})

In [None]:
def normalize(x, reduce_min=True, reduce_max=True, axis=(-3, -2)):
  """Min-max normalise `x` independently for each slice outside `axis`.

  With the default axes and a channels-last `(..., H, W, C)` input
  (assumed layout — confirm at call sites), each channel of each image is
  shifted to start at 0 and then scaled to a maximum of 1.

  Args:
    x: tensor or array to normalise. It is never modified in place.
    reduce_min: subtract the minimum over `axis`.
    reduce_max: divide by the maximum over `axis`. A zero maximum yields 0
      instead of NaN.
    axis: axes to reduce over; defaults to the original hard-coded (-3, -2).

  Returns:
    The normalised tensor.
  """
  if reduce_min:
    # Out-of-place on purpose: `x -= ...` would mutate a caller's NumPy array.
    x = x - tf.reduce_min(x, axis=axis, keepdims=True)
  if reduce_max:
    x = tf.math.divide_no_nan(x, tf.reduce_max(x, axis=axis, keepdims=True))

  return x


def visualize(
    image,
    title=None,
    rows=2,
    cols=None,
    figsize=(16, 7.2),
    cmap=None
):
  """Plot a single image, or a grid of images, with matplotlib.

  Args:
    image: one image (rank <= 3), a batch (rank > 3), or a list/tuple of
      images. With `None`, nothing is drawn but `title` and axis hiding
      still apply.
    title: title for a single image, or a sequence of per-image titles for
      a grid. Missing entries leave that subplot untitled.
    rows: number of grid rows when several images are given.
    cols: number of grid columns; derived from `rows` when falsy.
    figsize: size of the new figure created for a grid.
    cmap: colormap forwarded to `plt.imshow`.
  """
  is_grid = image is not None and (
      isinstance(image, (list, tuple)) or len(image.shape) > 3)

  if is_grid:
    plt.figure(figsize=figsize)
    grid_cols = cols or ceil(len(image) / rows)
    for index, item in enumerate(image):
      plt.subplot(rows, grid_cols, index + 1)
      item_title = None
      if title is not None and len(title) > index:
        item_title = title[index]
      visualize(item, cmap=cmap, title=item_title)
    plt.tight_layout()
    return

  if image is not None:
    pixels = image.numpy() if isinstance(image, tf.Tensor) else image
    # Drop a trailing singleton channel so imshow treats it as grayscale.
    if pixels.shape[-1] == 1:
      pixels = pixels[..., 0]
    plt.imshow(pixels, cmap=cmap)

  if title is not None:
    plt.title(title)
  plt.axis('off')

## Dataset

### Augmentation Policy

In [None]:
def default_policy_fn(image):
  """Deterministic (non-augmenting) policy for a single image.

  The image is centrally cropped and/or zero-padded to `Config.data.size`.
  """
  height, width = Config.data.size
  return tf.image.resize_with_crop_or_pad(image, height, width)


def augment_policy_fn(image, value_range=255.):
  """Randomly augment a single float image for training.

  The image is cropped/padded to `Config.data.size`, then randomly flipped
  horizontally and vertically, then colour-jittered (hue, brightness,
  contrast, saturation) using per-call seeds drawn from the global
  generator `R`.

  Args:
    image: float image tensor `(H, W, 3)`. Images from `load_fn` are
      presumably in [0, value_range]: they are cast to float32 without
      rescaling.
    value_range: upper bound of the pixel scale. It is used to express the
      brightness delta in pixel units and to clip the result.

  Returns:
    The augmented image, clipped to [0, value_range].
  """
  seeds = R.make_seeds(6)

  image = tf.image.resize_with_crop_or_pad(image, *Config.data.size)
  image = tf.image.stateless_random_flip_left_right(image, seed=seeds[:, 0])
  image = tf.image.stateless_random_flip_up_down(image, seed=seeds[:, 1])

  image = tf.image.stateless_random_hue(image, Config.aug.hue_delta, seed=seeds[:, 2])
  # Brightness adds an absolute delta, so a fractional delta (e.g. 0.2) must
  # be scaled to the pixel range; unscaled, it is a no-op on 0..255 images.
  image = tf.image.stateless_random_brightness(
      image, Config.aug.brightness_delta * value_range, seed=seeds[:, 3])
  image = tf.image.stateless_random_contrast(image, Config.aug.contrast_lower, Config.aug.contrast_upper, seed=seeds[:, 4])
  image = tf.image.stateless_random_saturation(image, Config.aug.saturation_lower, Config.aug.saturation_upper, seed=seeds[:, 5])

  # Brightness/contrast jitter can push pixels outside the valid range.
  return tf.clip_by_value(image, 0., value_range)

### Data Preparation and Performance Settings

In [None]:
# Load the VOC 2007 train/validation/test splits plus dataset metadata.
# shuffle_files=False keeps the file read order fixed across runs.
(train_dataset, val_dataset, test_dataset), info = tfds.load(
  'voc/2007',
  split=('train', 'validation', 'test'),
  with_info=True,
  shuffle_files=False
)

In [None]:
# Class names indexed by label id, via the public ClassLabel API (instead of
# the private `_int2str` attribute).
CLASSES = np.asarray(info.features['objects']['label'].names)
int2str = info.features['objects']['label'].int2str

In [None]:
from functools import partial


@tf.function
def load_fn(d, augment=False):
  """Decode one TFDS VOC example into `(preprocessed_image, multi_hot_labels)`.

  Args:
    d: TFDS feature dict; reads `d['image']` and `d['objects']['label']`.
    augment: Python bool selecting `augment_policy_fn` (training) or
      `default_policy_fn`.
  """
  policy = augment_policy_fn if augment else default_policy_fn

  resized, _ = adjust_resolution(tf.cast(d['image'], tf.float32))
  image = Config.data.preprocess(policy(resized))

  return image, labels_to_one_hot(d['objects']['label'])


def adjust_resolution(image):
  """Rescale `image` so it fits inside `Config.data.size`, keeping aspect ratio.

  Returns:
    `(resized_image, scale)`, where `scale` is the factor applied to the
    original height/width (nearest-neighbour resampling).
  """
  target_hw = tf.constant(Config.data.size, tf.float32)
  source_hw = tf.cast(tf.shape(image)[:2], tf.float32)

  # The smaller per-axis ratio makes both sides fit within the target.
  scale = tf.reduce_min(target_hw / source_hw)
  new_hw = tf.cast(tf.math.ceil(scale * source_hw), tf.int32)

  resized = tf.image.resize(image, new_hw, preserve_aspect_ratio=True, method='nearest')
  return resized, scale


def labels_to_one_hot(labels, depth=None):
  """Convert a vector of integer label ids into a multi-hot vector.

  Args:
    labels: 1-D integer tensor of label ids (duplicates allowed).
    depth: number of classes; defaults to `len(CLASSES)`.

  Returns:
    Float tensor of shape `(depth,)` with 1 for every label present.
    Returns all zeros when `labels` is empty.
  """
  if depth is None:
    depth = CLASSES.shape[0]
  one_hot = tf.one_hot(labels, depth=depth)
  # Sum-then-clip instead of reduce_max: reduce_max over an empty axis
  # returns -inf, whereas the sum of nothing is 0.
  return tf.minimum(tf.reduce_sum(one_hot, axis=0), 1.)


def prepare(ds, batch_size, cache=False, shuffle=False, augment=False, drop_remainder=None):
  """Build a batched, prefetched input pipeline from a raw TFDS split.

  Args:
    ds: raw split yielding TFDS feature dicts (decoded by `load_fn`).
    batch_size: examples per batch.
    cache: cache the raw examples in memory before decoding/augmentation.
    shuffle: shuffle with `Config.data.shuffle_buffer_size`, reshuffling
      every epoch.
    augment: use `augment_policy_fn` instead of `default_policy_fn`.
    drop_remainder: drop the final partial batch. Defaults to `shuffle`.
      Training pipelines therefore keep fixed-size batches, while
      evaluation pipelines keep every sample instead of silently
      discarding up to `batch_size - 1` of them.

  Returns:
    A `tf.data.Dataset` of `(images, multi_hot_labels)` batches.
  """
  if drop_remainder is None:
    drop_remainder = shuffle

  if cache:
    ds = ds.cache()
  if shuffle:
    ds = ds.shuffle(Config.data.shuffle_buffer_size, reshuffle_each_iteration=True, seed=Config.data.train_shuffle_seed)

  return (ds.map(partial(load_fn, augment=augment), num_parallel_calls=tf.data.AUTOTUNE)
            .batch(batch_size, drop_remainder=drop_remainder)
            .prefetch(Config.data.prefetch_buffer_size))

In [None]:
# Training data is shuffled and augmented; validation/test use the
# deterministic crop/pad policy.
train = prepare(train_dataset, Config.data.batch_size, shuffle=True, augment=True)
valid = prepare(val_dataset, Config.data.batch_size)
test = prepare(test_dataset, Config.data.batch_size)

### Examples from the Dataset

In [None]:
def talk_about(dataset, batches, tag):
  """Print a short summary of a raw split and its batched pipeline.

  Args:
    dataset: the raw (unbatched) split; its length is the sample count.
    batches: the batched pipeline; its length is the step count.
    tag: heading printed first.
  """
  def _report():
    # Lazily produce each line so output order matches the evaluation order.
    yield tag
    yield f'  {batches}'
    yield f'  samples: {len(dataset)}'
    yield f'  steps  : {len(batches)}'
    yield ''

  for line in _report():
    print(line)


# Report sample and step counts for the training and validation pipelines.
talk_about(train_dataset, train, 'Training')
talk_about(val_dataset, valid, 'Validation')

In [None]:
#@title

# Show the first 16 images of one (augmented) training batch, each titled
# with its ground-truth class names.
for images, labels in train.take(1):
  # Each multi-hot row becomes a boolean mask over CLASSES -> one name per line.
  gt = ['\n'.join(CLASSES[l].astype(str))
        for l in labels.numpy().astype(bool)]

  visualize(
    Config.data.deprocess(images[:16]),
    gt,
    rows=2,
    figsize=(16, 6)
  )

## Network

In [None]:
print(f'Loading {Config.model.backbone.__name__}')

# Convolutional feature extractor only (include_top=False); the classification
# head is added by build_specific_classifier.
# NOTE(review): classifier_activation presumably has no effect without the top.
backbone = Config.model.backbone(
  classifier_activation=None,
  include_top=False,
  input_shape=Config.data.shape
)

In [None]:
from tensorflow.keras.layers import Dropout, Dense, GlobalAveragePooling2D


def build_specific_classifier(
    backbone,
    classes,
    dropout_rate=0.5,
    name=None,
    gpl='avg_pool',
):
  """Attach a GAP -> dropout -> dense (logits) head on top of `backbone`.

  The returned model is built from `backbone.input`/`backbone.output`, so it
  shares the backbone's layers. Later changes to their `trainable` flags
  affect this model.

  Args:
    backbone: Keras model whose output is a spatial feature map.
    classes: number of output units (one raw logit per label, no activation).
    dropout_rate: dropout rate applied to the pooled features.
    name: name of the resulting model.
    gpl: name of the global-average-pooling layer. It was previously ignored
      (always 'avg_pool'); the default keeps that name.

  Returns:
    A `tf.keras.Model` mapping images to `classes` logits.
  """
  x = backbone.input
  y = backbone.output
  y = GlobalAveragePooling2D(name=gpl)(y)
  y = Dropout(rate=dropout_rate, name='top_dropout')(y)
  y = Dense(classes, name='predictions')(y)

  return tf.keras.Model(
    x,
    y,
    name=name
  )

# Train only the new head first; part of the backbone is unfrozen later.
backbone.trainable = False

# Derive the model name from the actual backbone, so it stays accurate
# whatever Config.model.backbone is (it used to read 'enb7_voc_20' for VGG16).
nn = build_specific_classifier(backbone, len(CLASSES), name=f'{backbone.name}_voc_{len(CLASSES)}')

In [None]:
# Print the layer table and trainable/non-trainable parameter counts.
nn.summary()

## Training

### Loss, Metrics and Model Compilation

In [None]:
class FromLogitsMixin:
  """Mixin that lets a Keras metric consume logits instead of probabilities.

  When `from_logits=True`, `update_state` applies a sigmoid to `y_pred`
  before delegating to the metric, which suits multi-label (independent
  binary) outputs. List it before the metric class in the bases, e.g.
  `class AUC(FromLogitsMixin, tf.metrics.AUC)`.
  """

  def __init__(self, from_logits=False, *args, **kwargs):
    # NOTE(review): `from_logits` is the first positional parameter, so a call
    # like `Precision(0.5)` binds 0.5 to it. Pass metric arguments by keyword.
    super().__init__(*args, **kwargs)
    self.from_logits = from_logits

  def update_state(self, y_true, y_pred, sample_weight=None):
    if self.from_logits:
      y_pred = tf.nn.sigmoid(y_pred)
    return super().update_state(y_true, y_pred, sample_weight)

  def get_config(self):
    """Add `from_logits` so `from_config`/model reloading preserves it."""
    config = super().get_config()
    config['from_logits'] = self.from_logits
    return config


class AUC(FromLogitsMixin, tf.metrics.AUC):
  """`tf.metrics.AUC` accepting a `from_logits` flag."""

class BinaryAccuracy(FromLogitsMixin, tf.metrics.BinaryAccuracy):
  """`tf.metrics.BinaryAccuracy` accepting a `from_logits` flag."""

class TruePositives(FromLogitsMixin, tf.metrics.TruePositives):
  """`tf.metrics.TruePositives` accepting a `from_logits` flag."""

class FalsePositives(FromLogitsMixin, tf.metrics.FalsePositives):
  """`tf.metrics.FalsePositives` accepting a `from_logits` flag."""

class TrueNegatives(FromLogitsMixin, tf.metrics.TrueNegatives):
  """`tf.metrics.TrueNegatives` accepting a `from_logits` flag."""

class FalseNegatives(FromLogitsMixin, tf.metrics.FalseNegatives):
  """`tf.metrics.FalseNegatives` accepting a `from_logits` flag."""

class Precision(FromLogitsMixin, tf.metrics.Precision):
  """`tf.metrics.Precision` accepting a `from_logits` flag."""

class Recall(FromLogitsMixin, tf.metrics.Recall):
  """`tf.metrics.Recall` accepting a `from_logits` flag."""

class F1Score(FromLogitsMixin, tfa.metrics.F1Score):
  """`tfa.metrics.F1Score` accepting a `from_logits` flag."""

In [None]:
# Top-classifier phase: SGD with Nesterov momentum on a binary cross-entropy
# over raw logits (one independent sigmoid per label). Every metric gets
# from_logits=True so it is fed sigmoid probabilities.
nn.compile(
    optimizer=tf.optimizers.SGD(learning_rate=Config.training.learning_rate, momentum=0.9, nesterov=True),
    loss=tf.losses.BinaryCrossentropy(from_logits=True),
    metrics=[
      AUC(from_logits=True),
      BinaryAccuracy(from_logits=True),
      F1Score(num_classes=len(CLASSES), from_logits=True),
      Precision(from_logits=True),
      Recall(from_logits=True),

      # TruePositives(from_logits=True),
      # FalsePositives(from_logits=True),
      # TrueNegatives(from_logits=True),
      # FalseNegatives(from_logits=True),
    ])

In [None]:
# Smoke test: run the compiled model on 2 random inputs.
# NOTE(review): no targets are passed, so the reported loss/metrics are not
# meaningful; presumably this only checks that the model and metrics run.
nn.evaluate(tf.random.normal((2, *Config.data.shape)));

### Top Classifier Training

In [None]:
# Callbacks for the top-classifier phase: abort on a NaN loss, save only the
# best weights (ModelCheckpoint's default monitor, val_loss) and log to
# TensorBoard without graph dumps or profiling.
cs = [
    callbacks.TerminateOnNaN(),
    callbacks.ModelCheckpoint(Config.experiment.training_weights,
                              save_best_only=True,
                              save_weights_only=True,
                              verbose=1),
    # callbacks.EarlyStopping(patience=Config.training.early_stopping_patience, verbose=1),
    callbacks.TensorBoard(
      Config.experiment.logs,
      write_graph=False,
      profile_batch=0)
]

In [None]:
try:
  # Refuse to clobber a previous run unless Config.experiment.override is set.
  # When it is set, delete the old log directory (which also holds the old
  # training_weights checkpoint).
  if os.path.exists(Config.experiment.logs):
    if not Config.experiment.override:
      raise ValueError(f'A training was found in {Config.experiment.logs}. '
                       f'Either move it or set experiment.override to True.')

    print(f'Overriding previous training at {Config.experiment.logs}.')
    shutil.rmtree(Config.experiment.logs)

  nn.fit(
    train,
    validation_data=valid,
    epochs=Config.training.epochs,
    callbacks=cs
  );

# Interrupting the kernel stops training gracefully instead of raising.
except KeyboardInterrupt: print('\ninterrupted')
else: print('\ndone')

### Fine-Tuning

In [None]:
# Epochs completed by the top-classifier phase; fine-tuning resumes the epoch
# counter from here. `nn.history` may be missing/None, or lack a 'loss' entry,
# if the fit above never ran or was interrupted before finishing an epoch.
_history = getattr(nn, 'history', None)
trained_epochs = len(_history.history.get('loss', [])) if _history is not None else 0

# Restore the best top-classifier checkpoint, using the same path ModelCheckpoint wrote.
nn.load_weights(Config.experiment.training_weights)

In [None]:
# Callbacks for the fine-tuning phase: like the top-classifier phase, plus
# early stopping. Profiling is disabled (profile_batch=0) for consistency with
# the first phase; TensorBoard otherwise profiles a batch by default.
cs = [
    callbacks.TerminateOnNaN(),
    callbacks.ModelCheckpoint(Config.experiment.fine_tune_weights,
                              save_best_only=True,
                              save_weights_only=True,
                              verbose=1),
    callbacks.EarlyStopping(patience=Config.training.early_stopping_patience, verbose=1),
    callbacks.TensorBoard(
      Config.experiment.fine_tune_logs,
      write_graph=False,
      profile_batch=0)
]

In [None]:
if Config.training.fine_tune_epochs:
  backbone.trainable = True

  # Only layers with index > frozen_layer_ix become trainable (roughly the
  # last `fine_tune_layers` fraction). BatchNormalization layers also stay
  # frozen when Config.training.freeze_batch_norm is set.
  frozen_layer_ix = int((1-Config.training.fine_tune_layers) * len(backbone.layers))
  for ix, l in enumerate(backbone.layers):
    l.trainable = (ix > frozen_layer_ix and
                   (not isinstance(l, tf.keras.layers.BatchNormalization) or
                    not Config.training.freeze_batch_norm))

  # Recompile so the new trainable flags take effect, with the much smaller
  # fine-tuning learning rate.
  nn.compile(
    optimizer=tf.optimizers.SGD(learning_rate=Config.training.fine_tune_lr, momentum=0.9, nesterov=True),
    loss=tf.losses.BinaryCrossentropy(from_logits=True),
    metrics=[
      AUC(from_logits=True),
      BinaryAccuracy(from_logits=True),
      F1Score(num_classes=len(CLASSES), from_logits=True),
      Precision(from_logits=True),
      Recall(from_logits=True),

      # TruePositives(from_logits=True),
      # FalsePositives(from_logits=True),
      # TrueNegatives(from_logits=True),
      # FalseNegatives(from_logits=True),
    ])

In [None]:
if Config.training.fine_tune_epochs:
  print('Fine tuning params:')
  print(f'  epochs:          {Config.training.fine_tune_epochs}')
  print(f'  learning rate:   {Config.training.fine_tune_lr}')
  # Layers with index > frozen_layer_ix were made trainable, so the first
  # unfrozen layer is frozen_layer_ix + 1 (the end index is exclusive).
  print(f'  layers unfrozen: {frozen_layer_ix + 1} to {len(backbone.layers)}')

  try:
    # With initial_epoch set, `epochs` is the index of the final epoch, so
    # epoch numbering continues after the top-classifier phase.
    history = nn.fit(
      train,
      validation_data=valid,
      epochs=trained_epochs + Config.training.fine_tune_epochs,
      callbacks=cs,
      initial_epoch=trained_epochs,
    );

  except KeyboardInterrupt: print('\ninterrupted')
  else: print('\ndone')

## Evaluation

In [None]:
# Re-apply the fine-tuning freeze pattern before loading weights. It must use
# exactly the same rule as the fine-tuning setup cell: unfreeze layers with
# index > frozen_layer_ix, optionally keeping BatchNormalization frozen.
if Config.training.fine_tune_epochs:
  backbone.trainable = True

  frozen_layer_ix = int((1-Config.training.fine_tune_layers) * len(backbone.layers))

  for ix, l in enumerate(backbone.layers):
    l.trainable = (ix > frozen_layer_ix and
                   (not isinstance(l, tf.keras.layers.BatchNormalization) or
                    not Config.training.freeze_batch_norm))

In [None]:
# Restore the best checkpoint of the last phase that actually ran. The
# fine-tuning checkpoint only exists when fine-tuning is enabled.
nn.load_weights(Config.experiment.fine_tune_weights
                if Config.training.fine_tune_epochs
                else Config.experiment.training_weights)

### Model Metrics

In [None]:
# One row per subset: Keras loss/metric values followed by the subset name.
# NOTE: `train` is the shuffled, augmented pipeline.
results = pd.DataFrame(
  [[*nn.evaluate(split), subset_name]
   for subset_name, split in (('train', train), ('valid', valid), ('test', test))],
  columns=[*nn.metrics_names, 'subset'],
)

In [None]:
# Display the per-subset evaluation table.
results

### Label-Specific Metrics

In [None]:
from sklearn import metrics
  
def metrics_per_label(gt, y_pred, threshold=0.5):
  """Compute per-label classification statistics for multi-label predictions.

  Args:
    gt: multi-hot ground truth `(samples, labels)`, same dtype as `y_pred`.
    y_pred: predicted probabilities `(samples, labels)`.
    threshold: probabilities strictly above it count as positive.

  Returns:
    `(acc, tpr, fpr, tnr, fnr, auc, mcm)`: per-label accuracy and rates
    (NaN/inf when a label has no positives or no negatives), sklearn
    per-label ROC AUC, and the sklearn multilabel confusion matrices.
  """
  cutoff = tf.cast(threshold, y_pred.dtype)
  hard = tf.cast(y_pred > cutoff, y_pred.dtype)
  soft_neg_gt = 1 - gt
  hard_neg = 1 - hard

  positives = tf.reduce_sum(gt, axis=0)
  negatives = tf.reduce_sum(soft_neg_gt, axis=0)

  tp = tf.reduce_sum(hard * gt, axis=0)
  fp = tf.reduce_sum(hard * soft_neg_gt, axis=0)
  tn = tf.reduce_sum(hard_neg * soft_neg_gt, axis=0)
  fn = tf.reduce_sum(hard_neg * gt, axis=0)

  acc = tf.reduce_mean(tf.cast(gt == hard, tf.float32), axis=0)

  return (acc,
          tp / positives,
          fp / negatives,
          tn / negatives,
          fn / positives,
          metrics.roc_auc_score(gt, y_pred, average=None),
          metrics.multilabel_confusion_matrix(gt, hard))

In [None]:
#@title


def labels_and_probs(nn, dataset):
  """Run `nn` over `dataset` and collect ground truth and sigmoid probabilities.

  Returns:
    `(labels, probs)`, each concatenated over all batches along axis 0.
  """
  pairs = [(labels, tf.nn.sigmoid(nn(images, training=False)))
           for images, labels in dataset]

  all_labels = tf.concat([pair[0] for pair in pairs], axis=0)
  all_probs = tf.concat([pair[1] for pair in pairs], axis=0)
  return all_labels, all_probs


def evaluate(l, p):
  """Build a per-label report DataFrame from labels `l` and probabilities `p`.

  Columns are accuracy, the four threshold rates, ROC AUC, support (number
  of positive samples) and the class name.
  """
  acc, tpr, fpr, tnr, fnr, auc, _ = metrics_per_label(l, p)
  support = tf.cast(tf.reduce_sum(l, axis=0), tf.int32)

  columns = {
    'accuracy': acc,
    'true positive r': tpr,
    'true negative r': tnr,
    'false positive r': fpr,
    'false negative r': fnr,
    'roc auc score': auc,
    'support': support,
    'label': CLASSES,
  }
  return pd.DataFrame(columns)

In [None]:
# Per-label report on the test split (l: multi-hot labels, p: probabilities).
l, p = labels_and_probs(nn, test)
test_report = evaluate(l, p)

In [None]:
# Display the per-label test report, rounded for readability.
test_report.round(4)

In [None]:
# Average every numeric column over labels. numeric_only=True skips the string
# 'label' column, which otherwise makes DataFrame.mean raise on pandas >= 2.0.
pd.DataFrame(test_report.mean(axis=0, numeric_only=True)).round(4).T

In [None]:
# Entry (i, j): number of test samples that contain both label i and label j.
co_occurrence = tf.matmul(l, l, transpose_a=True)
# The diagonal holds each label's own count, as a column vector for broadcasting.
occurrence = tf.reshape(tf.linalg.diag_part(co_occurrence), (-1, 1))

# Row-normalised: entry (i, j) is the fraction of samples with label i that
# also contain label j (0 where label i never occurs).
co_occurrence_rate = tf.math.divide_no_nan(co_occurrence, occurrence)

In [None]:
#@title Labels Occurrence Matrix

# Side-by-side heatmaps: raw co-occurrence counts (left) and row-normalised
# rates (right).
plt.figure(figsize=(16, 6))

_panels = (
  (121, co_occurrence.numpy().astype(int), 'd'),
  (122, co_occurrence_rate.numpy(), '.0%'),
)
for _position, _values, _fmt in _panels:
  plt.subplot(_position)
  sns.heatmap(
    _values,
    annot=True,
    fmt=_fmt,
    xticklabels=CLASSES,
    yticklabels=CLASSES,
    cmap="RdPu",
    cbar=False
  )

plt.tight_layout();